agent_client_protocol_schema/
version.rs

1use schemars::JsonSchema;
2use serde::Serialize;
3
4pub const V0: ProtocolVersion = ProtocolVersion(0);
5pub const V1: ProtocolVersion = ProtocolVersion(1);
6pub const VERSION: ProtocolVersion = V1;
7
8/// Protocol version identifier.
9///
10/// This version is only bumped for breaking changes.
11/// Non-breaking changes should be introduced via capabilities.
12#[derive(Default, Debug, Clone, Serialize, JsonSchema, PartialEq, Eq, PartialOrd, Ord)]
13pub struct ProtocolVersion(u16);
14
15impl ProtocolVersion {
16    #[cfg(test)]
17    #[must_use]
18    pub const fn new(version: u16) -> Self {
19        Self(version)
20    }
21}
22
23use serde::{Deserialize, Deserializer};
24
25impl<'de> Deserialize<'de> for ProtocolVersion {
26    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
27    where
28        D: Deserializer<'de>,
29    {
30        use serde::de::{self, Visitor};
31        use std::fmt;
32
33        struct ProtocolVersionVisitor;
34
35        impl Visitor<'_> for ProtocolVersionVisitor {
36            type Value = ProtocolVersion;
37
38            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
39                formatter.write_str("a protocol version number or string")
40            }
41
42            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
43            where
44                E: de::Error,
45            {
46                match u16::try_from(value) {
47                    Ok(value) => Ok(ProtocolVersion(value)),
48                    Err(_) => Err(E::custom(format!("protocol version {value} is too large"))),
49                }
50            }
51
52            fn visit_str<E>(self, _value: &str) -> Result<Self::Value, E>
53            where
54                E: de::Error,
55            {
56                // Old versions used strings, we consider all of those version 0
57                Ok(ProtocolVersion(0))
58            }
59
60            fn visit_string<E>(self, _value: String) -> Result<Self::Value, E>
61            where
62                E: de::Error,
63            {
64                // Old versions used strings, we consider all of those version 0
65                Ok(ProtocolVersion(0))
66            }
67        }
68
69        deserializer.deserialize_any(ProtocolVersionVisitor)
70    }
71}
72
73#[cfg(test)]
74mod tests {
75    use super::*;
76
77    #[test]
78    fn test_deserialize_u64() {
79        let json = "1";
80        let version: ProtocolVersion = serde_json::from_str(json).unwrap();
81        assert_eq!(version, ProtocolVersion::new(1));
82    }
83
84    #[test]
85    fn test_deserialize_string() {
86        let json = "\"1.0.0\"";
87        let version: ProtocolVersion = serde_json::from_str(json).unwrap();
88        assert_eq!(version, ProtocolVersion::new(0));
89    }
90
91    #[test]
92    fn test_deserialize_large_number() {
93        let json = "100000";
94        let result: Result<ProtocolVersion, _> = serde_json::from_str(json);
95        assert!(result.is_err());
96    }
97
98    #[test]
99    fn test_deserialize_zero() {
100        let json = "0";
101        let version: ProtocolVersion = serde_json::from_str(json).unwrap();
102        assert_eq!(version, ProtocolVersion::new(0));
103    }
104
105    #[test]
106    fn test_deserialize_max_u16() {
107        let json = "65535";
108        let version: ProtocolVersion = serde_json::from_str(json).unwrap();
109        assert_eq!(version, ProtocolVersion::new(65535));
110    }
111}