cedar_policy_core/
jsonvalue.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17//! This module provides general-purpose JSON utilities not specific to Cedar.
18
19use std::marker::PhantomData;
20
21use linked_hash_map::LinkedHashMap;
22use serde::de::{MapAccess, SeqAccess, Visitor};
23use serde::{Deserialize, Deserializer, Serialize};
24use std::fmt;
25
26/// Wrapper around `serde_json::Value`, with a different `Deserialize`
27/// implementation, such that duplicate keys in JSON objects (maps/records) are
28/// not allowed (result in an error).
29//
30// CAUTION: this type is publicly exported in `cedar-policy`.
31// Don't make fields `pub`, don't make breaking changes, and use caution
32// when adding public methods.
33#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
34pub struct JsonValueWithNoDuplicateKeys(serde_json::Value);
35
36impl std::ops::Deref for JsonValueWithNoDuplicateKeys {
37    type Target = serde_json::Value;
38    fn deref(&self) -> &Self::Target {
39        &self.0
40    }
41}
42
43// this implementation heavily borrows from the `Deserialize` implementation
44// for `serde_json::Value`
45impl<'de> Deserialize<'de> for JsonValueWithNoDuplicateKeys {
46    fn deserialize<D>(deserializer: D) -> Result<JsonValueWithNoDuplicateKeys, D::Error>
47    where
48        D: serde::Deserializer<'de>,
49    {
50        struct ValueVisitor;
51
52        impl<'de> Visitor<'de> for ValueVisitor {
53            type Value = JsonValueWithNoDuplicateKeys;
54
55            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56                formatter.write_str("any valid JSON value")
57            }
58
59            fn visit_bool<E>(self, value: bool) -> Result<JsonValueWithNoDuplicateKeys, E> {
60                Ok(JsonValueWithNoDuplicateKeys(serde_json::Value::Bool(value)))
61            }
62
63            fn visit_i64<E>(self, value: i64) -> Result<JsonValueWithNoDuplicateKeys, E> {
64                Ok(JsonValueWithNoDuplicateKeys(serde_json::Value::Number(
65                    value.into(),
66                )))
67            }
68
69            fn visit_u64<E>(self, value: u64) -> Result<JsonValueWithNoDuplicateKeys, E> {
70                Ok(JsonValueWithNoDuplicateKeys(serde_json::Value::Number(
71                    value.into(),
72                )))
73            }
74
75            fn visit_f64<E>(self, value: f64) -> Result<JsonValueWithNoDuplicateKeys, E> {
76                Ok(JsonValueWithNoDuplicateKeys(
77                    serde_json::Number::from_f64(value)
78                        .map_or(serde_json::Value::Null, serde_json::Value::Number),
79                ))
80            }
81
82            fn visit_str<E>(self, value: &str) -> Result<JsonValueWithNoDuplicateKeys, E>
83            where
84                E: serde::de::Error,
85            {
86                self.visit_string(String::from(value))
87            }
88
89            fn visit_string<E>(self, value: String) -> Result<JsonValueWithNoDuplicateKeys, E> {
90                Ok(JsonValueWithNoDuplicateKeys(serde_json::Value::String(
91                    value,
92                )))
93            }
94
95            fn visit_none<E>(self) -> Result<JsonValueWithNoDuplicateKeys, E> {
96                Ok(JsonValueWithNoDuplicateKeys(serde_json::Value::Null))
97            }
98
99            fn visit_some<D>(
100                self,
101                deserializer: D,
102            ) -> Result<JsonValueWithNoDuplicateKeys, D::Error>
103            where
104                D: serde::Deserializer<'de>,
105            {
106                Deserialize::deserialize(deserializer)
107            }
108
109            fn visit_unit<E>(self) -> Result<JsonValueWithNoDuplicateKeys, E> {
110                Ok(JsonValueWithNoDuplicateKeys(serde_json::Value::Null))
111            }
112
113            fn visit_seq<A>(self, mut access: A) -> Result<JsonValueWithNoDuplicateKeys, A::Error>
114            where
115                A: SeqAccess<'de>,
116            {
117                let mut vec: Vec<serde_json::Value> = Vec::new();
118
119                while let Some(elem) = access.next_element::<JsonValueWithNoDuplicateKeys>()? {
120                    vec.push(elem.0);
121                }
122
123                Ok(JsonValueWithNoDuplicateKeys(serde_json::Value::Array(vec)))
124            }
125
126            fn visit_map<A>(self, mut access: A) -> Result<JsonValueWithNoDuplicateKeys, A::Error>
127            where
128                A: MapAccess<'de>,
129            {
130                let mut map: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
131
132                while let Some((k, v)) =
133                    access.next_entry::<String, JsonValueWithNoDuplicateKeys>()?
134                {
135                    match map.entry(k) {
136                        serde_json::map::Entry::Vacant(ventry) => {
137                            ventry.insert(v.0);
138                        }
139                        serde_json::map::Entry::Occupied(oentry) => {
140                            return Err(serde::de::Error::custom(format!(
141                                "the key `{}` occurs two or more times in the same JSON object",
142                                oentry.key()
143                            )));
144                        }
145                    }
146                }
147
148                Ok(JsonValueWithNoDuplicateKeys(serde_json::Value::Object(map)))
149            }
150        }
151
152        deserializer.deserialize_any(ValueVisitor)
153    }
154}
155
156impl std::str::FromStr for JsonValueWithNoDuplicateKeys {
157    type Err = serde_json::Error;
158    fn from_str(s: &str) -> Result<Self, Self::Err> {
159        serde_json::from_str(s)
160    }
161}
162
163impl From<serde_json::Value> for JsonValueWithNoDuplicateKeys {
164    fn from(value: serde_json::Value) -> Self {
165        // the `serde_json::Value` representation cannot represent duplicate keys, so we can just wrap.
166        // If there were any duplicate keys, they're already gone as a result of creating the `serde_json::Value`.
167        Self(value)
168    }
169}
170
171impl From<JsonValueWithNoDuplicateKeys> for serde_json::Value {
172    fn from(value: JsonValueWithNoDuplicateKeys) -> Self {
173        value.0
174    }
175}
176
177struct LinkedHashMapVisitor<K, V> {
178    marker: PhantomData<fn() -> LinkedHashMap<K, V>>,
179}
180
181impl<K, V> LinkedHashMapVisitor<K, V> {
182    fn new() -> Self {
183        LinkedHashMapVisitor {
184            marker: PhantomData,
185        }
186    }
187}
188
189impl<'de, K, V> Visitor<'de> for LinkedHashMapVisitor<K, V>
190where
191    K: serde::Deserialize<'de> + std::hash::Hash + Eq,
192    V: serde::Deserialize<'de>,
193{
194    type Value = LinkedHashMap<K, V>;
195
196    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
197        formatter.write_str("a linked hash map")
198    }
199
200    fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
201    where
202        M: MapAccess<'de>,
203    {
204        let mut map = LinkedHashMap::new();
205
206        while let Some((key, value)) = access.next_entry()? {
207            if map.contains_key(&key) {
208                return Err(serde::de::Error::custom(
209                    "invalid entry: found duplicate key",
210                ));
211            }
212            map.insert(key, value);
213        }
214
215        Ok(map)
216    }
217}
218
219pub(crate) fn deserialize_linked_hash_map_no_duplicates<'de, D, K, V>(
220    deserializer: D,
221) -> Result<LinkedHashMap<K, V>, D::Error>
222where
223    D: Deserializer<'de>,
224    K: serde::Deserialize<'de> + std::hash::Hash + Eq,
225    V: serde::Deserialize<'de>,
226{
227    deserializer.deserialize_map(LinkedHashMapVisitor::new())
228}