rig/
json_utils.rs

1use serde::de::{self, Deserializer, SeqAccess, Visitor};
2use serde::Deserialize;
3use std::convert::Infallible;
4use std::fmt;
5use std::marker::PhantomData;
6use std::str::FromStr;
7
8pub fn merge(a: serde_json::Value, b: serde_json::Value) -> serde_json::Value {
9    match (a, b) {
10        (serde_json::Value::Object(mut a_map), serde_json::Value::Object(b_map)) => {
11            b_map.into_iter().for_each(|(key, value)| {
12                a_map.insert(key, value);
13            });
14            serde_json::Value::Object(a_map)
15        }
16        (a, _) => a,
17    }
18}
19
20pub fn merge_inplace(a: &mut serde_json::Value, b: serde_json::Value) {
21    if let (serde_json::Value::Object(a_map), serde_json::Value::Object(b_map)) = (a, b) {
22        b_map.into_iter().for_each(|(key, value)| {
23            a_map.insert(key, value);
24        });
25    }
26}
27
28/// This module is helpful in cases where raw json objects are serialized and deserialized as
29///  strings such as `"{\"key\": \"value\"}"`. This might seem odd but it's actually how some
30///  some providers such as OpenAI return function arguments (for some reason).
31pub mod stringified_json {
32    use serde::{self, Deserialize, Deserializer, Serializer};
33
34    pub fn serialize<S>(value: &serde_json::Value, serializer: S) -> Result<S::Ok, S::Error>
35    where
36        S: Serializer,
37    {
38        let s = value.to_string();
39        serializer.serialize_str(&s)
40    }
41
42    pub fn deserialize<'de, D>(deserializer: D) -> Result<serde_json::Value, D::Error>
43    where
44        D: Deserializer<'de>,
45    {
46        let s = String::deserialize(deserializer)?;
47        serde_json::from_str(&s).map_err(serde::de::Error::custom)
48    }
49}
50
51pub fn string_or_vec<'de, T, D>(deserializer: D) -> Result<Vec<T>, D::Error>
52where
53    T: Deserialize<'de> + FromStr<Err = Infallible>,
54    D: Deserializer<'de>,
55{
56    struct StringOrVec<T>(PhantomData<fn() -> T>);
57
58    impl<'de, T> Visitor<'de> for StringOrVec<T>
59    where
60        T: Deserialize<'de> + FromStr<Err = Infallible>,
61    {
62        type Value = Vec<T>;
63
64        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
65            formatter.write_str("a string, sequence, or null")
66        }
67
68        fn visit_str<E>(self, value: &str) -> Result<Vec<T>, E>
69        where
70            E: de::Error,
71        {
72            let item = FromStr::from_str(value).map_err(de::Error::custom)?;
73            Ok(vec![item])
74        }
75
76        fn visit_seq<A>(self, seq: A) -> Result<Vec<T>, A::Error>
77        where
78            A: SeqAccess<'de>,
79        {
80            Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
81        }
82
83        fn visit_none<E>(self) -> Result<Vec<T>, E>
84        where
85            E: de::Error,
86        {
87            Ok(vec![])
88        }
89
90        fn visit_unit<E>(self) -> Result<Vec<T>, E>
91        where
92            E: de::Error,
93        {
94            Ok(vec![])
95        }
96    }
97
98    deserializer.deserialize_any(StringOrVec(PhantomData))
99}
100
101pub fn null_or_vec<'de, T, D>(deserializer: D) -> Result<Vec<T>, D::Error>
102where
103    T: Deserialize<'de>,
104    D: Deserializer<'de>,
105{
106    struct NullOrVec<T>(PhantomData<fn() -> T>);
107
108    impl<'de, T> Visitor<'de> for NullOrVec<T>
109    where
110        T: Deserialize<'de>,
111    {
112        type Value = Vec<T>;
113
114        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
115            formatter.write_str("a sequence or null")
116        }
117
118        fn visit_seq<A>(self, seq: A) -> Result<Vec<T>, A::Error>
119        where
120            A: SeqAccess<'de>,
121        {
122            Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
123        }
124
125        fn visit_none<E>(self) -> Result<Vec<T>, E>
126        where
127            E: de::Error,
128        {
129            Ok(vec![])
130        }
131
132        fn visit_unit<E>(self) -> Result<Vec<T>, E>
133        where
134            E: de::Error,
135        {
136            Ok(vec![])
137        }
138    }
139
140    deserializer.deserialize_any(NullOrVec(PhantomData))
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Fixture whose `data` field goes through the `stringified_json` helpers.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct Dummy {
        #[serde(with = "stringified_json")]
        data: serde_json::Value,
    }

    /// Wire form of a `Dummy` whose payload is the object `{"key": "value"}` embedded as a string.
    const DUMMY_JSON: &str = r#"{"data":"{\"key\":\"value\"}"}"#;

    /// Builds the in-memory counterpart of `DUMMY_JSON`.
    fn sample_dummy() -> Dummy {
        Dummy {
            data: serde_json::json!({"key": "value"}),
        }
    }

    // Disjoint keys from both objects end up in the merged result.
    #[test]
    fn test_merge() {
        let merged = merge(
            serde_json::json!({"key1": "value1"}),
            serde_json::json!({"key2": "value2"}),
        );
        assert_eq!(
            merged,
            serde_json::json!({"key1": "value1", "key2": "value2"})
        );
    }

    // The in-place variant mutates its first argument to hold both entries.
    #[test]
    fn test_merge_inplace() {
        let mut target = serde_json::json!({"key1": "value1"});
        merge_inplace(&mut target, serde_json::json!({"key2": "value2"}));
        assert_eq!(
            target,
            serde_json::json!({"key1": "value1", "key2": "value2"})
        );
    }

    // Serializing writes the payload as an escaped JSON string.
    #[test]
    fn test_stringified_json_serialize() {
        let serialized = serde_json::to_string(&sample_dummy()).unwrap();
        assert_eq!(serialized, DUMMY_JSON);
    }

    // Deserializing parses the embedded string back into a JSON value.
    #[test]
    fn test_stringified_json_deserialize() {
        let parsed: Dummy = serde_json::from_str(DUMMY_JSON).unwrap();
        assert_eq!(parsed, sample_dummy());
    }
}
191}