agent_client_protocol_schema/
version.rs1use derive_more::{Display, From};
2use schemars::JsonSchema;
3use serde::Serialize;
4
5#[derive(
10 Debug, Clone, Copy, Serialize, JsonSchema, PartialEq, Eq, PartialOrd, Ord, From, Display,
11)]
12pub struct ProtocolVersion(u16);
13
14impl ProtocolVersion {
15 pub const V0: Self = Self(0);
21 pub const V1: Self = Self(1);
25 #[cfg(feature = "unstable_protocol_v2")]
32 pub const V2: Self = Self(2);
33 pub const LATEST: Self = Self::V1;
39
40 #[must_use]
42 pub const fn as_u16(self) -> u16 {
43 self.0
44 }
45
46 #[cfg(test)]
47 #[must_use]
48 const fn new(version: u16) -> Self {
49 Self(version)
50 }
51}
52
53use serde::{Deserialize, Deserializer};
54
55impl<'de> Deserialize<'de> for ProtocolVersion {
56 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
57 where
58 D: Deserializer<'de>,
59 {
60 use serde::de::{self, Visitor};
61 use std::fmt;
62
63 struct ProtocolVersionVisitor;
64
65 impl Visitor<'_> for ProtocolVersionVisitor {
66 type Value = ProtocolVersion;
67
68 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
69 formatter.write_str("a protocol version number or string")
70 }
71
72 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
73 where
74 E: de::Error,
75 {
76 match u16::try_from(value) {
77 Ok(value) => Ok(ProtocolVersion(value)),
78 Err(_) => Err(E::custom(format!("protocol version {value} is too large"))),
79 }
80 }
81
82 fn visit_str<E>(self, _value: &str) -> Result<Self::Value, E>
83 where
84 E: de::Error,
85 {
86 Ok(ProtocolVersion::V0)
88 }
89
90 fn visit_string<E>(self, _value: String) -> Result<Self::Value, E>
91 where
92 E: de::Error,
93 {
94 Ok(ProtocolVersion::V0)
96 }
97 }
98
99 deserializer.deserialize_any(ProtocolVersionVisitor)
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_deserialize_u64() {
        let json = "1";
        let version: ProtocolVersion = serde_json::from_str(json).unwrap();
        assert_eq!(version, ProtocolVersion::new(1));
    }

    #[test]
    fn test_deserialize_string() {
        let json = "\"1.0.0\"";
        let version: ProtocolVersion = serde_json::from_str(json).unwrap();
        assert_eq!(version, ProtocolVersion::new(0));
    }

    /// Any string, not just semver-looking ones, maps to version 0.
    #[test]
    fn test_deserialize_arbitrary_string() {
        let json = "\"not a version\"";
        let version: ProtocolVersion = serde_json::from_str(json).unwrap();
        assert_eq!(version, ProtocolVersion::V0);
    }

    #[test]
    fn test_deserialize_large_number() {
        let json = "100000";
        let result: Result<ProtocolVersion, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    /// One past `u16::MAX` is the smallest value that must be rejected.
    #[test]
    fn test_deserialize_just_above_max_u16() {
        let json = "65536";
        let result: Result<ProtocolVersion, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    /// Negative numbers can never be a protocol version.
    #[test]
    fn test_deserialize_negative() {
        let json = "-1";
        let result: Result<ProtocolVersion, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_zero() {
        let json = "0";
        let version: ProtocolVersion = serde_json::from_str(json).unwrap();
        assert_eq!(version, ProtocolVersion::new(0));
    }

    #[test]
    fn test_deserialize_max_u16() {
        let json = "65535";
        let version: ProtocolVersion = serde_json::from_str(json).unwrap();
        assert_eq!(version, ProtocolVersion::new(65535));
    }

    /// The derived `Serialize` emits the bare number, not a wrapper object.
    #[test]
    fn test_serialize_as_plain_integer() {
        assert_eq!(serde_json::to_string(&ProtocolVersion::V1).unwrap(), "1");
        assert_eq!(
            serde_json::to_string(&ProtocolVersion::new(65535)).unwrap(),
            "65535"
        );
    }

    /// Serializing and then deserializing yields the same version.
    #[test]
    fn test_round_trip() {
        for raw in [0u16, 1, 255, 65535] {
            let original = ProtocolVersion::new(raw);
            let json = serde_json::to_string(&original).unwrap();
            let decoded: ProtocolVersion = serde_json::from_str(&json).unwrap();
            assert_eq!(decoded, original);
        }
    }

    #[test]
    fn test_as_u16() {
        assert_eq!(ProtocolVersion::V0.as_u16(), 0);
        assert_eq!(ProtocolVersion::V1.as_u16(), 1);
        assert_eq!(ProtocolVersion::LATEST.as_u16(), 1);

        #[cfg(feature = "unstable_protocol_v2")]
        assert_eq!(ProtocolVersion::V2.as_u16(), 2);

        assert_eq!(ProtocolVersion::new(65535).as_u16(), 65535);
    }
}
153}