// openapi_schema/extension/extension.rs

1use std::collections::BTreeMap;
2use std::fmt;
3
4use serde::de::{MapAccess, Visitor};
5use serde::ser::SerializeMap;
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7
8/// Contains openapi specification extensions
9/// see https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.1.md#specificationExtensions
10#[derive(Debug, Eq, PartialEq, Clone)]
11pub struct Extensions(BTreeMap<String, serde_json::Value>);
12
13impl Extensions {
14    fn add(&mut self, ext_id: String, value: serde_json::Value) {
15        self.0.insert(ext_id, value);
16    }
17
18    /// Fetch extension by name
19    pub fn get(&self, ext_id: &str) -> Option<&serde_json::Value> {
20        self.0.get(ext_id)
21    }
22
23    /// A reference to all the captured extensions
24    pub fn all(&self) -> &BTreeMap<String, serde_json::Value> {
25        &self.0
26    }
27}
28
29impl Default for Extensions {
30    fn default() -> Self {
31        Self(BTreeMap::new())
32    }
33}
34
35impl<'de> Deserialize<'de> for Extensions {
36    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
37    where
38        D: Deserializer<'de>,
39    {
40        struct ExtensionsVisitor;
41        impl<'de> Visitor<'de> for ExtensionsVisitor {
42            type Value = Extensions;
43
44            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
45                formatter.write_str("struct Extensions")
46            }
47
48            fn visit_map<V>(self, mut map: V) -> Result<Extensions, V::Error>
49            where
50                V: MapAccess<'de>,
51            {
52                let mut extensions = Extensions::default();
53                while let Some(key) = map.next_key::<String>()? {
54                    if key.starts_with("x-") {
55                        extensions.add(key, map.next_value()?);
56                    }
57                }
58                Ok(extensions)
59            }
60        }
61        deserializer.deserialize_map(ExtensionsVisitor)
62    }
63}
64
65impl Serialize for Extensions {
66    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
67    where
68        S: Serializer,
69    {
70        let mut map = serializer.serialize_map(Some(self.0.len()))?;
71        for (k, v) in self.0.clone() {
72            map.serialize_entry(&k, &v)?;
73        }
74        map.end()
75    }
76}
77
#[cfg(test)]
mod tests {
    use serde_json::Value;
    use serde_test::{assert_tokens, Token};

    use crate::extension::Extensions;

    /// Builds an `Extensions` holding a single `x-test` entry set to `value`.
    fn single(value: &Value) -> Extensions {
        let mut exts = Extensions::default();
        exts.add("x-test".to_owned(), value.clone());
        exts
    }

    #[test]
    fn test_serde_extensions() {
        let exts = single(&Value::from("val"));
        let expected = [
            Token::Map { len: Some(1) },
            Token::String("x-test"),
            Token::String("val"),
            Token::MapEnd,
        ];
        assert_tokens(&exts, &expected);
    }

    #[test]
    fn test_get_extension() {
        let expected = Value::from("val");
        let exts = single(&expected);
        assert_eq!(exts.get("x-test"), Some(&expected));
    }

    #[test]
    fn test_all_extensions() {
        let expected = Value::from("val");
        let exts = single(&expected);
        let owned_key = String::from("x-test");
        assert_eq!(
            exts.all().get_key_value("x-test"),
            Some((&owned_key, &expected))
        );
    }
}