Skip to main content

agent_client_protocol_schema/
version.rs

1use derive_more::{Display, From};
2use schemars::JsonSchema;
3use serde::Serialize;
4
5/// Protocol version identifier.
6///
7/// This version is only bumped for breaking changes.
8/// Non-breaking changes should be introduced via capabilities.
9#[derive(
10    Debug, Clone, Copy, Serialize, JsonSchema, PartialEq, Eq, PartialOrd, Ord, From, Display,
11)]
12pub struct ProtocolVersion(u16);
13
14impl ProtocolVersion {
15    /// Version `0` of the protocol.
16    ///
17    /// This was a pre-release version that shouldn't be used in production.
18    /// It is used as a fallback for any request whose version cannot be parsed
19    /// as a valid version, and should likely be treated as unsupported.
20    pub const V0: Self = Self(0);
21    /// Version `1` of the protocol.
22    ///
23    /// <https://agentclientprotocol.com/protocol/overview>
24    pub const V1: Self = Self(1);
25    /// Version `2` of the protocol.
26    ///
27    /// This is an unstable draft used for protocol iteration. It is only
28    /// available when the `unstable_protocol_v2` feature is enabled and is
29    /// **not** advertised by [`ProtocolVersion::LATEST`] yet — callers must
30    /// opt into V2 explicitly.
31    #[cfg(feature = "unstable_protocol_v2")]
32    pub const V2: Self = Self(2);
33    /// The latest stable supported version of the protocol.
34    ///
35    /// Currently this is version `1`. Enabling the `unstable_protocol_v2`
36    /// feature exposes `ProtocolVersion::V2` but does **not** change the
37    /// value of `LATEST` — v2 will only become the latest once it stabilizes.
38    pub const LATEST: Self = Self::V1;
39
40    /// Returns the numeric protocol version.
41    #[must_use]
42    pub const fn as_u16(self) -> u16 {
43        self.0
44    }
45
46    #[cfg(test)]
47    #[must_use]
48    const fn new(version: u16) -> Self {
49        Self(version)
50    }
51}
52
53use serde::{Deserialize, Deserializer};
54
55impl<'de> Deserialize<'de> for ProtocolVersion {
56    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
57    where
58        D: Deserializer<'de>,
59    {
60        use serde::de::{self, Visitor};
61        use std::fmt;
62
63        struct ProtocolVersionVisitor;
64
65        impl Visitor<'_> for ProtocolVersionVisitor {
66            type Value = ProtocolVersion;
67
68            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
69                formatter.write_str("a protocol version number or string")
70            }
71
72            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
73            where
74                E: de::Error,
75            {
76                match u16::try_from(value) {
77                    Ok(value) => Ok(ProtocolVersion(value)),
78                    Err(_) => Err(E::custom(format!("protocol version {value} is too large"))),
79                }
80            }
81
82            fn visit_str<E>(self, _value: &str) -> Result<Self::Value, E>
83            where
84                E: de::Error,
85            {
86                // Old versions used strings, we consider all of those version 0
87                Ok(ProtocolVersion::V0)
88            }
89
90            fn visit_string<E>(self, _value: String) -> Result<Self::Value, E>
91            where
92                E: de::Error,
93            {
94                // Old versions used strings, we consider all of those version 0
95                Ok(ProtocolVersion::V0)
96            }
97        }
98
99        deserializer.deserialize_any(ProtocolVersionVisitor)
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `json` through `serde_json` and returns the parse result.
    fn parse(json: &str) -> serde_json::Result<ProtocolVersion> {
        serde_json::from_str(json)
    }

    #[test]
    fn test_deserialize_u64() {
        assert_eq!(parse("1").unwrap(), ProtocolVersion::new(1));
    }

    #[test]
    fn test_deserialize_string() {
        // Any string input collapses to version 0.
        assert_eq!(parse(r#""1.0.0""#).unwrap(), ProtocolVersion::new(0));
    }

    #[test]
    fn test_deserialize_large_number() {
        // 100000 does not fit in a u16, so parsing must fail.
        assert!(parse("100000").is_err());
    }

    #[test]
    fn test_deserialize_zero() {
        assert_eq!(parse("0").unwrap(), ProtocolVersion::new(0));
    }

    #[test]
    fn test_deserialize_max_u16() {
        // The largest representable version is still accepted.
        assert_eq!(parse("65535").unwrap(), ProtocolVersion::new(u16::MAX));
    }

    #[test]
    fn test_as_u16() {
        let expectations = [
            (ProtocolVersion::V0, 0),
            (ProtocolVersion::V1, 1),
            (ProtocolVersion::LATEST, 1),
            (ProtocolVersion::new(65535), 65535),
        ];
        for (version, number) in expectations {
            assert_eq!(version.as_u16(), number);
        }

        #[cfg(feature = "unstable_protocol_v2")]
        assert_eq!(ProtocolVersion::V2.as_u16(), 2);
    }
}