1use serde::{Deserialize, Serialize};
2use smol_str::SmolStr;
3use std::fmt;
4
5#[derive(Debug, Clone, PartialEq, Eq)]
7pub enum ProtocolVersion {
8 ZeroZeroThree,
9
10 Unknown(SmolStr),
11}
12
13impl fmt::Display for ProtocolVersion {
14 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
15 match self {
16 Self::ZeroZeroThree => f.write_str("0.03"),
17 Self::Unknown(v) => f.write_str(v.as_str()),
18 }
19 }
20}
21
22impl Serialize for ProtocolVersion {
23 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
24 where
25 S: serde::Serializer,
26 {
27 match self {
28 Self::ZeroZeroThree => "0.03".serialize(serializer),
29 Self::Unknown(v) => v.serialize(serializer),
30 }
31 }
32}
33
34struct ProtocolVisitor;
35
36impl<'de> serde::de::Visitor<'de> for ProtocolVisitor {
37 type Value = ProtocolVersion;
38
39 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
40 formatter.write_str("protocol version")
41 }
42
43 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
44 where
45 E: serde::de::Error,
46 {
47 match v {
48 "0.03" => Ok(ProtocolVersion::ZeroZeroThree),
49 _ => Ok(ProtocolVersion::Unknown(v.into())),
50 }
51 }
52}
53
54impl<'de> Deserialize<'de> for ProtocolVersion {
55 #[inline]
56 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
57 where
58 D: serde::Deserializer<'de>,
59 {
60 deserializer.deserialize_str(ProtocolVisitor)
61 }
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use serde_test::{assert_de_tokens, assert_tokens, Token};

    #[test]
    fn protocol_version_zerozerothree_serde() {
        let version = ProtocolVersion::ZeroZeroThree;

        assert_tokens(&version, &[Token::Str("0.03")])
    }

    #[test]
    fn protocol_version_other_serde() {
        let version = ProtocolVersion::Unknown("other".into());

        assert_tokens(&version, &[Token::Str("other")])
    }

    /// Owned and borrowed string tokens must deserialize exactly like `Str`
    /// (serde's default `visit_string` / `visit_borrowed_str` forward to `visit_str`).
    #[test]
    fn protocol_version_deserializes_from_owned_and_borrowed_strings() {
        assert_de_tokens(&ProtocolVersion::ZeroZeroThree, &[Token::String("0.03")]);
        assert_de_tokens(&ProtocolVersion::ZeroZeroThree, &[Token::BorrowedStr("0.03")]);
        assert_de_tokens(
            &ProtocolVersion::Unknown("other".into()),
            &[Token::String("other")],
        );
    }

    /// `Display` must produce the same text that is written on the wire.
    #[test]
    fn protocol_version_display_matches_wire_format() {
        assert_eq!(ProtocolVersion::ZeroZeroThree.to_string(), "0.03");
        assert_eq!(ProtocolVersion::Unknown("other".into()).to_string(), "other");
    }
}