agent_client_protocol_schema/
version.rs

1use derive_more::{Display, From};
2use schemars::JsonSchema;
3use serde::Serialize;
4
/// Protocol version `0`.
///
/// The `Deserialize` implementation for [`ProtocolVersion`] in this file also
/// produces this value for any version that arrives as a string.
pub const V0: ProtocolVersion = ProtocolVersion(0);
/// Protocol version `1`.
pub const V1: ProtocolVersion = ProtocolVersion(1);
/// The version exported under the generic name `VERSION`; currently equal to [`V1`].
// NOTE(review): presumably the version this crate speaks by default — confirm with callers.
pub const VERSION: ProtocolVersion = V1;
8
9/// Protocol version identifier.
10///
11/// This version is only bumped for breaking changes.
12/// Non-breaking changes should be introduced via capabilities.
13#[derive(
14    Default, Debug, Clone, Serialize, JsonSchema, PartialEq, Eq, PartialOrd, Ord, From, Display,
15)]
16pub struct ProtocolVersion(u16);
17
18impl ProtocolVersion {
19    #[cfg(test)]
20    #[must_use]
21    pub const fn new(version: u16) -> Self {
22        Self(version)
23    }
24}
25
26use serde::{Deserialize, Deserializer};
27
28impl<'de> Deserialize<'de> for ProtocolVersion {
29    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
30    where
31        D: Deserializer<'de>,
32    {
33        use serde::de::{self, Visitor};
34        use std::fmt;
35
36        struct ProtocolVersionVisitor;
37
38        impl Visitor<'_> for ProtocolVersionVisitor {
39            type Value = ProtocolVersion;
40
41            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
42                formatter.write_str("a protocol version number or string")
43            }
44
45            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
46            where
47                E: de::Error,
48            {
49                match u16::try_from(value) {
50                    Ok(value) => Ok(ProtocolVersion(value)),
51                    Err(_) => Err(E::custom(format!("protocol version {value} is too large"))),
52                }
53            }
54
55            fn visit_str<E>(self, _value: &str) -> Result<Self::Value, E>
56            where
57                E: de::Error,
58            {
59                // Old versions used strings, we consider all of those version 0
60                Ok(ProtocolVersion(0))
61            }
62
63            fn visit_string<E>(self, _value: String) -> Result<Self::Value, E>
64            where
65                E: de::Error,
66            {
67                // Old versions used strings, we consider all of those version 0
68                Ok(ProtocolVersion(0))
69            }
70        }
71
72        deserializer.deserialize_any(ProtocolVersionVisitor)
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `json` through `serde_json` and returns the parsed version or the error.
    fn parse(json: &str) -> serde_json::Result<ProtocolVersion> {
        serde_json::from_str(json)
    }

    #[test]
    fn test_deserialize_u64() {
        assert_eq!(parse("1").unwrap(), ProtocolVersion::new(1));
    }

    #[test]
    fn test_deserialize_string() {
        // Any string-valued version collapses to version 0.
        assert_eq!(parse("\"1.0.0\"").unwrap(), ProtocolVersion::new(0));
    }

    #[test]
    fn test_deserialize_large_number() {
        // 100000 does not fit in a u16 and must be rejected.
        assert!(parse("100000").is_err());
    }

    #[test]
    fn test_deserialize_zero() {
        assert_eq!(parse("0").unwrap(), ProtocolVersion::new(0));
    }

    #[test]
    fn test_deserialize_max_u16() {
        // Largest value that still fits.
        assert_eq!(parse("65535").unwrap(), ProtocolVersion::new(65535));
    }
}