agent_client_protocol_schema/
version.rs

1use derive_more::{Display, From};
2use schemars::JsonSchema;
3use serde::Serialize;
4
5/// Protocol version identifier.
6///
7/// This version is only bumped for breaking changes.
8/// Non-breaking changes should be introduced via capabilities.
9#[derive(Debug, Clone, Serialize, JsonSchema, PartialEq, Eq, PartialOrd, Ord, From, Display)]
10pub struct ProtocolVersion(u16);
11
12impl ProtocolVersion {
13    /// Version `0` of the protocol.
14    ///
15    /// This was a pre-release version that shouldn't be used in production.
16    /// It is used as a fallback for any request whose version cannot be parsed
17    /// as a valid version, and should likely be treated as unsupported.
18    pub const V0: Self = Self(0);
19    /// Version `1` of the protocol.
20    ///
21    /// <https://agentclientprotocol.com/protocol/overview>
22    pub const V1: Self = Self(1);
23    /// The latest supported version of the protocol.
24    ///
25    /// Currently, this is version `1`.
26    pub const LATEST: Self = Self::V1;
27
28    #[cfg(test)]
29    #[must_use]
30    pub const fn new(version: u16) -> Self {
31        Self(version)
32    }
33}
34
35use serde::{Deserialize, Deserializer};
36
37impl<'de> Deserialize<'de> for ProtocolVersion {
38    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
39    where
40        D: Deserializer<'de>,
41    {
42        use serde::de::{self, Visitor};
43        use std::fmt;
44
45        struct ProtocolVersionVisitor;
46
47        impl Visitor<'_> for ProtocolVersionVisitor {
48            type Value = ProtocolVersion;
49
50            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
51                formatter.write_str("a protocol version number or string")
52            }
53
54            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
55            where
56                E: de::Error,
57            {
58                match u16::try_from(value) {
59                    Ok(value) => Ok(ProtocolVersion(value)),
60                    Err(_) => Err(E::custom(format!("protocol version {value} is too large"))),
61                }
62            }
63
64            fn visit_str<E>(self, _value: &str) -> Result<Self::Value, E>
65            where
66                E: de::Error,
67            {
68                // Old versions used strings, we consider all of those version 0
69                Ok(ProtocolVersion::V0)
70            }
71
72            fn visit_string<E>(self, _value: String) -> Result<Self::Value, E>
73            where
74                E: de::Error,
75            {
76                // Old versions used strings, we consider all of those version 0
77                Ok(ProtocolVersion::V0)
78            }
79        }
80
81        deserializer.deserialize_any(ProtocolVersionVisitor)
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `json` as a `ProtocolVersion` using `serde_json`.
    fn parse(json: &str) -> Result<ProtocolVersion, serde_json::Error> {
        serde_json::from_str(json)
    }

    #[test]
    fn test_deserialize_u64() {
        assert_eq!(parse("1").unwrap(), ProtocolVersion::new(1));
    }

    #[test]
    fn test_deserialize_string() {
        assert_eq!(parse("\"1.0.0\"").unwrap(), ProtocolVersion::new(0));
    }

    #[test]
    fn test_deserialize_any_string_is_v0() {
        // Even strings that look like a valid numeric version map to V0.
        for json in ["\"\"", "\"1\"", "\"2024-01-01\""] {
            assert_eq!(parse(json).unwrap(), ProtocolVersion::V0, "input: {json}");
        }
    }

    #[test]
    fn test_deserialize_large_number() {
        assert!(parse("100000").is_err());
    }

    #[test]
    fn test_deserialize_just_above_max_u16() {
        assert!(parse("65536").is_err());
    }

    #[test]
    fn test_deserialize_negative_number() {
        assert!(parse("-1").is_err());
    }

    #[test]
    fn test_deserialize_rejects_non_integers() {
        for json in ["1.5", "true", "null", "[1]", "{}"] {
            assert!(parse(json).is_err(), "{json} should not deserialize");
        }
    }

    #[test]
    fn test_deserialize_zero() {
        assert_eq!(parse("0").unwrap(), ProtocolVersion::new(0));
    }

    #[test]
    fn test_deserialize_max_u16() {
        assert_eq!(parse("65535").unwrap(), ProtocolVersion::new(65535));
    }

    #[test]
    fn test_serialize_as_bare_integer() {
        assert_eq!(serde_json::to_string(&ProtocolVersion::V1).unwrap(), "1");
    }

    #[test]
    fn test_roundtrip_latest() {
        let json = serde_json::to_string(&ProtocolVersion::LATEST).unwrap();
        assert_eq!(parse(&json).unwrap(), ProtocolVersion::LATEST);
    }

    #[test]
    fn test_ordering_and_latest() {
        assert!(ProtocolVersion::V0 < ProtocolVersion::V1);
        assert_eq!(ProtocolVersion::LATEST, ProtocolVersion::V1);
    }
}