agent_client_protocol_schema/
version.rs

1use derive_more::{Display, From};
2use schemars::JsonSchema;
3use serde::Serialize;
4
/// Protocol version `0`.
///
/// String-valued versions are also deserialized to this value.
pub const V0: ProtocolVersion = ProtocolVersion(0);
/// Protocol version `1`.
pub const V1: ProtocolVersion = ProtocolVersion(1);
/// Alias for [`V1`]; presumably the version this crate treats as current.
pub const VERSION: ProtocolVersion = V1;
8
9/// Protocol version identifier.
10///
11/// This version is only bumped for breaking changes.
12/// Non-breaking changes should be introduced via capabilities.
13#[derive(Debug, Clone, Serialize, JsonSchema, PartialEq, Eq, PartialOrd, Ord, From, Display)]
14pub struct ProtocolVersion(u16);
15
16impl ProtocolVersion {
17    #[cfg(test)]
18    #[must_use]
19    pub const fn new(version: u16) -> Self {
20        Self(version)
21    }
22}
23
24use serde::{Deserialize, Deserializer};
25
26impl<'de> Deserialize<'de> for ProtocolVersion {
27    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
28    where
29        D: Deserializer<'de>,
30    {
31        use serde::de::{self, Visitor};
32        use std::fmt;
33
34        struct ProtocolVersionVisitor;
35
36        impl Visitor<'_> for ProtocolVersionVisitor {
37            type Value = ProtocolVersion;
38
39            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
40                formatter.write_str("a protocol version number or string")
41            }
42
43            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
44            where
45                E: de::Error,
46            {
47                match u16::try_from(value) {
48                    Ok(value) => Ok(ProtocolVersion(value)),
49                    Err(_) => Err(E::custom(format!("protocol version {value} is too large"))),
50                }
51            }
52
53            fn visit_str<E>(self, _value: &str) -> Result<Self::Value, E>
54            where
55                E: de::Error,
56            {
57                // Old versions used strings, we consider all of those version 0
58                Ok(ProtocolVersion(0))
59            }
60
61            fn visit_string<E>(self, _value: String) -> Result<Self::Value, E>
62            where
63                E: de::Error,
64            {
65                // Old versions used strings, we consider all of those version 0
66                Ok(ProtocolVersion(0))
67            }
68        }
69
70        deserializer.deserialize_any(ProtocolVersionVisitor)
71    }
72}
73
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `json` as a [`ProtocolVersion`] using `serde_json`.
    fn parse(json: &str) -> serde_json::Result<ProtocolVersion> {
        serde_json::from_str(json)
    }

    #[test]
    fn test_deserialize_u64() {
        assert_eq!(parse("1").unwrap(), ProtocolVersion::new(1));
    }

    #[test]
    fn test_deserialize_string() {
        // Any string payload is treated as version 0.
        assert_eq!(parse(r#""1.0.0""#).unwrap(), ProtocolVersion::new(0));
    }

    #[test]
    fn test_deserialize_large_number() {
        // 100000 does not fit in a u16, so deserialization must fail.
        assert!(parse("100000").is_err());
    }

    #[test]
    fn test_deserialize_zero() {
        assert_eq!(parse("0").unwrap(), ProtocolVersion::new(0));
    }

    #[test]
    fn test_deserialize_max_u16() {
        // 65535 is the largest value that still fits.
        assert_eq!(parse("65535").unwrap(), ProtocolVersion::new(u16::MAX));
    }
}