// ear/extension.rs

// SPDX-License-Identifier: Apache-2.0

3use std::collections::{BTreeMap, HashSet};
4use std::sync::{Arc, Mutex, RwLock};
5
6use lazy_static::lazy_static;
7use serde::de::Error as _;
8
9use crate::appraisal::Appraisal;
10use crate::ear::Ear;
11use crate::error::Error;
12use crate::raw::{RawValue, RawValueKind};
13
14#[derive(Debug, Clone)]
15struct ExtensionEntry {
16    pub kind: RawValueKind,
17    pub value: RawValue,
18}
19
20impl ExtensionEntry {
21    pub fn new(kind: RawValueKind) -> ExtensionEntry {
22        ExtensionEntry {
23            kind,
24            value: RawValue::Null,
25        }
26    }
27}
28
/// Identifier under which a deserialized value is parked when no extension has
/// been registered for it yet: either an integer key or a string name.
///
/// The derived ordering places every `Key` before every `Name`, then orders by
/// the contained value.
#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
enum CollectedKey {
    Key(i32),
    Name(String),
}

35#[derive(Debug)]
36pub struct Extensions {
37    by_key: BTreeMap<i32, Arc<RwLock<ExtensionEntry>>>,
38    by_name: BTreeMap<String, Arc<RwLock<ExtensionEntry>>>,
39    collected: BTreeMap<CollectedKey, RawValue>,
40}
41
42impl Default for Extensions {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl<'de> Extensions {
49    pub fn new() -> Extensions {
50        Extensions {
51            by_key: BTreeMap::new(),
52            by_name: BTreeMap::new(),
53            collected: BTreeMap::new(),
54        }
55    }
56
57    pub fn register(&mut self, name: &str, key: i32, kind: RawValueKind) -> Result<(), Error> {
58        if self.by_name.contains_key(name) {
59            return Err(Error::ExtensionError(
60                format!("name {name} already registered").to_string(),
61            ));
62        }
63
64        if self.by_key.contains_key(&key) {
65            return Err(Error::ExtensionError(
66                format!("key {key} already registered").to_string(),
67            ));
68        }
69
70        let entry = Arc::new(RwLock::new(ExtensionEntry::new(kind)));
71
72        // Check whether any of the values we previously collected match the key or name for
73        // this entry. If so, add the value to the entry, ensuring it is the right kind.
74        // Note: while it is theoretically possible for the collected HashMap to contain both,
75        // the key and the name, in practice that won't happen because:
76        // - collection only happens during deserialization
77        // - a new Extensions is created as part of each deserialization
78        // - depending on deserializer.is_human_reaadable, we'd be dealing only with keys or only
79        //   with names
80        let collected = self
81            .collected
82            .get(&CollectedKey::Key(key))
83            .or(self.collected.get(&CollectedKey::Name(name.to_string())));
84        match collected {
85            Some(v) => {
86                let entry_kind = &entry.read().unwrap().kind.clone();
87
88                if v.is(entry_kind) {
89                    entry.write().unwrap().value = v.clone();
90                    Ok(())
91                } else if v.can_convert(entry_kind) {
92                    entry.write().unwrap().value = v.convert(entry_kind)?;
93                    Ok(())
94                } else {
95                    Err(Error::ExtensionError(
96                        format!(
97                            "kind mismatch: value is {vk:?}, but want {ek:?}",
98                            vk = v.kind(),
99                            ek = entry.read().unwrap().kind
100                        )
101                        .to_string(),
102                    ))
103                }
104            }
105            None => Ok(()),
106        }?;
107
108        self.by_key.insert(key, Arc::clone(&entry));
109        self.by_name.insert(name.to_string(), Arc::clone(&entry));
110
111        Ok(())
112    }
113
114    pub fn have_key(&self, key: &i32) -> bool {
115        self.by_key.contains_key(key)
116    }
117
118    pub fn have_name(&self, name: &str) -> bool {
119        self.by_name.contains_key(name)
120    }
121
122    pub fn get_by_key(&self, key: &i32) -> Option<RawValue> {
123        self.by_key
124            .get(key)
125            .map(|entry| entry.read().unwrap().value.clone())
126    }
127
128    pub fn get_by_name(&self, name: &str) -> Option<RawValue> {
129        self.by_name
130            .get(name)
131            .map(|entry| entry.read().unwrap().value.clone())
132    }
133
134    pub fn get_kind_by_key(&self, key: &i32) -> RawValueKind {
135        match self.by_key.get(key) {
136            Some(entry) => entry.read().unwrap().kind.clone(),
137            None => RawValueKind::Null,
138        }
139    }
140
141    pub fn get_kind_by_name(&self, name: &str) -> RawValueKind {
142        match self.by_name.get(name) {
143            Some(entry) => entry.read().unwrap().kind.clone(),
144            None => RawValueKind::Null,
145        }
146    }
147
148    pub fn set_by_key(&mut self, key: i32, value: RawValue) -> Result<(), Error> {
149        let entry = self.by_key.get(&key).ok_or(Error::ExtensionError(
150            format!("{key} not registered").to_string(),
151        ))?;
152
153        if !value.is(&entry.read().unwrap().kind) {
154            return Err(Error::ExtensionError(format!(
155                "kind mismatch: value is {vk:?}, but want {ek:?}",
156                vk = value.kind(),
157                ek = entry.read().unwrap().kind
158            )));
159        }
160
161        entry.write().unwrap().value = value;
162
163        Ok(())
164    }
165
166    pub fn set_by_name(&mut self, name: &str, value: RawValue) -> Result<(), Error> {
167        let entry = self.by_name.get_mut(name).ok_or(Error::ExtensionError(
168            format!("{name} not registered").to_string(),
169        ))?;
170
171        if !value.is(&entry.read().unwrap().kind) {
172            return Err(Error::ExtensionError(format!(
173                "kind mismatch: value is {vk:?}, but want {ek:?}",
174                vk = value.kind(),
175                ek = entry.read().unwrap().kind
176            )));
177        }
178
179        entry.write().unwrap().value = value;
180
181        Ok(())
182    }
183
184    pub(crate) fn visit_map_entry_by_name<A>(
185        &mut self,
186        name: &str,
187        mut map: A,
188    ) -> Result<(), A::Error>
189    where
190        A: serde::de::MapAccess<'de>,
191    {
192        if !self.have_name(name) {
193            self.collected.insert(
194                CollectedKey::Name(name.to_string()),
195                map.next_value::<RawValue>()?,
196            );
197            return Ok(());
198        }
199
200        let value = map.next_value::<RawValue>()?;
201
202        self.set_by_name(name, value).map_err(A::Error::custom)?;
203
204        Ok(())
205    }
206
207    pub(crate) fn visit_map_entry_by_key<A>(&mut self, key: i32, mut map: A) -> Result<(), A::Error>
208    where
209        A: serde::de::MapAccess<'de>,
210    {
211        if !self.have_key(&key) {
212            self.collected
213                .insert(CollectedKey::Key(key), map.next_value::<RawValue>()?);
214            return Ok(());
215        }
216
217        let value = map.next_value::<RawValue>()?;
218
219        self.set_by_key(key, value).map_err(A::Error::custom)?;
220
221        Ok(())
222    }
223
224    pub(crate) fn serialize_to_map_by_name<M>(&self, map: &mut M) -> Result<(), M::Error>
225    where
226        M: serde::ser::SerializeMap,
227    {
228        for (name, val) in &self.by_name {
229            if val.read().unwrap().value.is(&RawValueKind::Null) {
230                continue;
231            }
232
233            map.serialize_entry(&name, &val.read().unwrap().value)?;
234        }
235
236        Ok(())
237    }
238
239    pub(crate) fn serialize_to_map_by_key<M>(&self, map: &mut M) -> Result<(), M::Error>
240    where
241        M: serde::ser::SerializeMap,
242    {
243        for (key, val) in &self.by_key {
244            if val.read().unwrap().value.is(&RawValueKind::Null) {
245                continue;
246            }
247
248            map.serialize_entry(&key, &val.read().unwrap().value)?;
249        }
250
251        Ok(())
252    }
253}
254
255impl PartialEq for Extensions {
256    fn eq(&self, other: &Self) -> bool {
257        for (name, val) in &self.by_name {
258            match other.get_by_name(name) {
259                Some(other_val) => {
260                    if val.read().unwrap().value != other_val {
261                        return false;
262                    }
263                }
264                None => return false,
265            }
266        }
267
268        for (key, val) in &self.by_key {
269            match other.get_by_key(key) {
270                Some(other_val) => {
271                    if val.read().unwrap().value != other_val {
272                        return false;
273                    }
274                }
275                None => return false,
276            }
277        }
278
279        true
280    }
281}
282
283#[derive(Debug, Clone)]
284struct RegisterEntry {
285    pub name: String,
286    pub key: i32,
287    pub kind: RawValueKind,
288}
289
290#[derive(Debug, Clone)]
291struct Register {
292    pub entries: Vec<RegisterEntry>,
293    names: HashSet<String>,
294    keys: HashSet<i32>,
295}
296
297impl Register {
298    pub fn new() -> Self {
299        Register {
300            entries: Vec::new(),
301            names: HashSet::new(),
302            keys: HashSet::new(),
303        }
304    }
305
306    pub fn register(&mut self, name: &str, key: i32, kind: RawValueKind) -> Result<(), Error> {
307        match self.names.get(name) {
308            Some(_) => Err(Error::ExtensionError(
309                format!("name {name} already registered").to_string(),
310            )),
311            None => Ok(()),
312        }?;
313
314        match self.keys.get(&key) {
315            Some(_) => Err(Error::ExtensionError(
316                format!("key {key} already registered").to_string(),
317            )),
318            None => Ok(()),
319        }?;
320
321        self.entries.push(RegisterEntry {
322            name: name.to_string(),
323            key,
324            kind,
325        });
326
327        Ok(())
328    }
329}
330
331impl IntoIterator for Register {
332    type Item = RegisterEntry;
333    type IntoIter = <Vec<RegisterEntry> as IntoIterator>::IntoIter;
334
335    fn into_iter(self) -> Self::IntoIter {
336        self.entries.into_iter()
337    }
338}
339
340#[derive(Debug, Clone)]
341pub struct Profile {
342    id: String,
343    ear: Register,
344    appraisal: Register,
345}
346
347impl Profile {
348    pub fn new(id: &str) -> Self {
349        Profile {
350            id: id.to_string(),
351            ear: Register::new(),
352            appraisal: Register::new(),
353        }
354    }
355
356    pub fn register_ear_extension(
357        &mut self,
358        name: &str,
359        key: i32,
360        kind: RawValueKind,
361    ) -> Result<(), Error> {
362        self.ear.register(name, key, kind)
363    }
364
365    pub fn register_appraisal_extension(
366        &mut self,
367        name: &str,
368        key: i32,
369        kind: RawValueKind,
370    ) -> Result<(), Error> {
371        self.appraisal.register(name, key, kind)
372    }
373
374    pub fn populate_ear_extensions(&self, ear: &mut Ear) -> Result<(), Error> {
375        if self.id != ear.profile {
376            return Err(Error::ProfileError(format!(
377                "ID mismatch: wanted {wid}, but got {gid}",
378                wid = self.id,
379                gid = ear.profile,
380            )));
381        }
382
383        for entry in self.ear.clone() {
384            ear.extensions
385                .register(&entry.name, entry.key, entry.kind)?
386        }
387
388        for (_, appraisal) in ear.submods.iter_mut() {
389            for entry in self.appraisal.clone() {
390                appraisal
391                    .extensions
392                    .register(&entry.name, entry.key, entry.kind)?
393            }
394        }
395
396        Ok(())
397    }
398
399    pub fn populate_appraisal_extensions(&self, appraisal: &mut Appraisal) -> Result<(), Error> {
400        for entry in self.appraisal.clone() {
401            appraisal
402                .extensions
403                .register(&entry.name, entry.key, entry.kind)?
404        }
405
406        Ok(())
407    }
408}
409
410lazy_static! {
411    static ref PROFILE_REGISTER: Mutex<BTreeMap<String, Profile>> = Mutex::new(BTreeMap::new());
412}
413
414pub fn register_profile(profile: &Profile) -> Result<(), Error> {
415    let mut register = PROFILE_REGISTER.lock().unwrap();
416
417    match register.get(&profile.id) {
418        Some(_) => Err(Error::ProfileError(format!(
419            "{id} already registered",
420            id = profile.id
421        ))),
422        None => {
423            register.insert(profile.id.clone(), profile.clone());
424            Ok(())
425        }
426    }?;
427
428    Ok(())
429}
430
431pub fn get_profile(id: &str) -> Option<Profile> {
432    let register = PROFILE_REGISTER.lock().unwrap();
433    register.get(id).cloned()
434}
435
#[cfg(test)]
mod test {
    use super::*;
    use crate::base64::Bytes;
    use crate::error::Error;

    use std::str;
    use std::thread;

    use serde::ser::{SerializeMap, Serializer};

    /// Asserts that `res` is an `ExtensionError` whose message is exactly `expected`.
    fn assert_extension_error(res: Result<(), Error>, expected: &str) {
        assert!(
            matches!(&res, Err(Error::ExtensionError(msg)) if msg == expected),
            "unexpected result: {res:?}"
        );
    }

    #[test]
    fn crud() {
        let mut extensions = Extensions::new();
        extensions.register("foo", 1, RawValueKind::String).unwrap();

        // Neither the name nor the key may be reused.
        assert_extension_error(
            extensions.register("foo", 2, RawValueKind::String),
            "name foo already registered",
        );
        assert_extension_error(
            extensions.register("bad", 1, RawValueKind::String),
            "key 1 already registered",
        );

        // Kind lookups and presence checks agree for key and name.
        assert_eq!(extensions.get_kind_by_key(&1), RawValueKind::String);
        assert_eq!(extensions.get_kind_by_name("foo"), RawValueKind::String);
        assert!(extensions.have_name("foo"));
        assert!(extensions.have_key(&1));
        assert!(!extensions.have_name("bad"));
        assert!(!extensions.have_key(&-1));

        // A value written by key is visible by name...
        extensions
            .set_by_key(1, RawValue::String("bar".to_string()))
            .unwrap();
        match extensions.get_by_name("foo") {
            Some(RawValue::String(s)) => assert_eq!(s, "bar"),
            other => panic!("unexpected value: {other:?}"),
        }

        // ...and a value written by name is visible by key.
        extensions
            .set_by_name("foo", RawValue::String("buzz".to_string()))
            .unwrap();
        match extensions.get_by_key(&1) {
            Some(RawValue::String(s)) => assert_eq!(s, "buzz"),
            other => panic!("unexpected value: {other:?}"),
        }

        // Unregistered identifiers are rejected.
        assert_extension_error(
            extensions.set_by_name("bad", RawValue::String("bar".to_string())),
            "bad not registered",
        );
        assert_extension_error(
            extensions.set_by_key(-1, RawValue::String("bar".to_string())),
            "-1 not registered",
        );

        // Values of the wrong kind are rejected.
        assert_extension_error(
            extensions.set_by_name("foo", RawValue::Integer(42)),
            "kind mismatch: value is Integer, but want String",
        );
        assert_extension_error(
            extensions.set_by_key(1, RawValue::Bool(true)),
            "kind mismatch: value is Bool, but want String",
        );
    }

    #[test]
    fn serde() {
        let mut extensions = Extensions::new();
        extensions.register("foo", 1, RawValueKind::String).unwrap();
        extensions
            .set_by_name("foo", RawValue::String("bar".to_string()))
            .unwrap();

        let mut buf = Vec::new();
        {
            let mut serializer = serde_json::Serializer::new(&mut buf);
            let mut map = serializer.serialize_map(None).unwrap();
            extensions.serialize_to_map_by_name(&mut map).unwrap();
            map.end().unwrap();
        }

        assert_eq!(str::from_utf8(&buf).unwrap(), r#"{"foo":"bar"}"#);
    }

    #[test]
    fn value_convert() {
        let converted = RawValue::String("3q2-7w".to_string())
            .convert(&RawValueKind::Bytes)
            .unwrap();

        match converted {
            RawValue::Bytes(bytes) => {
                assert_eq!(bytes, Bytes::from(&[0xde_u8, 0xad, 0xbe, 0xef][..]))
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn test_send() {
        let mut extensions = Extensions::new();
        extensions.register("foo", 1, RawValueKind::String).unwrap();
        extensions
            .set_by_name("foo", RawValue::String("test".to_string()))
            .unwrap();

        // Moving the extensions into another thread must compile, and the value
        // must still be readable there.
        thread::spawn(move || match extensions.get_by_name("foo") {
            Some(RawValue::String(v)) => assert_eq!(v, "test"),
            _ => panic!(),
        })
        .join()
        .unwrap();
    }
}