1use std::collections::{BTreeMap, HashSet};
4use std::sync::{Arc, Mutex, RwLock};
5
6use lazy_static::lazy_static;
7use serde::de::Error as _;
8
9use crate::appraisal::Appraisal;
10use crate::ear::Ear;
11use crate::error::Error;
12use crate::raw::{RawValue, RawValueKind};
13
14#[derive(Debug, Clone)]
15struct ExtensionEntry {
16 pub kind: RawValueKind,
17 pub value: RawValue,
18}
19
20impl ExtensionEntry {
21 pub fn new(kind: RawValueKind) -> ExtensionEntry {
22 ExtensionEntry {
23 kind,
24 value: RawValue::Null,
25 }
26 }
27}
28
29#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
30enum CollectedKey {
31 Key(i32),
32 Name(String),
33}
34
35#[derive(Debug)]
36pub struct Extensions {
37 by_key: BTreeMap<i32, Arc<RwLock<ExtensionEntry>>>,
38 by_name: BTreeMap<String, Arc<RwLock<ExtensionEntry>>>,
39 collected: BTreeMap<CollectedKey, RawValue>,
40}
41
42impl Default for Extensions {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl<'de> Extensions {
49 pub fn new() -> Extensions {
50 Extensions {
51 by_key: BTreeMap::new(),
52 by_name: BTreeMap::new(),
53 collected: BTreeMap::new(),
54 }
55 }
56
57 pub fn register(&mut self, name: &str, key: i32, kind: RawValueKind) -> Result<(), Error> {
58 if self.by_name.contains_key(name) {
59 return Err(Error::ExtensionError(
60 format!("name {name} already registered").to_string(),
61 ));
62 }
63
64 if self.by_key.contains_key(&key) {
65 return Err(Error::ExtensionError(
66 format!("key {key} already registered").to_string(),
67 ));
68 }
69
70 let entry = Arc::new(RwLock::new(ExtensionEntry::new(kind)));
71
72 let collected = self
81 .collected
82 .get(&CollectedKey::Key(key))
83 .or(self.collected.get(&CollectedKey::Name(name.to_string())));
84 match collected {
85 Some(v) => {
86 let entry_kind = &entry.read().unwrap().kind.clone();
87
88 if v.is(entry_kind) {
89 entry.write().unwrap().value = v.clone();
90 Ok(())
91 } else if v.can_convert(entry_kind) {
92 entry.write().unwrap().value = v.convert(entry_kind)?;
93 Ok(())
94 } else {
95 Err(Error::ExtensionError(
96 format!(
97 "kind mismatch: value is {vk:?}, but want {ek:?}",
98 vk = v.kind(),
99 ek = entry.read().unwrap().kind
100 )
101 .to_string(),
102 ))
103 }
104 }
105 None => Ok(()),
106 }?;
107
108 self.by_key.insert(key, Arc::clone(&entry));
109 self.by_name.insert(name.to_string(), Arc::clone(&entry));
110
111 Ok(())
112 }
113
114 pub fn have_key(&self, key: &i32) -> bool {
115 self.by_key.contains_key(key)
116 }
117
118 pub fn have_name(&self, name: &str) -> bool {
119 self.by_name.contains_key(name)
120 }
121
122 pub fn get_by_key(&self, key: &i32) -> Option<RawValue> {
123 self.by_key
124 .get(key)
125 .map(|entry| entry.read().unwrap().value.clone())
126 }
127
128 pub fn get_by_name(&self, name: &str) -> Option<RawValue> {
129 self.by_name
130 .get(name)
131 .map(|entry| entry.read().unwrap().value.clone())
132 }
133
134 pub fn get_kind_by_key(&self, key: &i32) -> RawValueKind {
135 match self.by_key.get(key) {
136 Some(entry) => entry.read().unwrap().kind.clone(),
137 None => RawValueKind::Null,
138 }
139 }
140
141 pub fn get_kind_by_name(&self, name: &str) -> RawValueKind {
142 match self.by_name.get(name) {
143 Some(entry) => entry.read().unwrap().kind.clone(),
144 None => RawValueKind::Null,
145 }
146 }
147
148 pub fn set_by_key(&mut self, key: i32, value: RawValue) -> Result<(), Error> {
149 let entry = self.by_key.get(&key).ok_or(Error::ExtensionError(
150 format!("{key} not registered").to_string(),
151 ))?;
152
153 if !value.is(&entry.read().unwrap().kind) {
154 return Err(Error::ExtensionError(format!(
155 "kind mismatch: value is {vk:?}, but want {ek:?}",
156 vk = value.kind(),
157 ek = entry.read().unwrap().kind
158 )));
159 }
160
161 entry.write().unwrap().value = value;
162
163 Ok(())
164 }
165
166 pub fn set_by_name(&mut self, name: &str, value: RawValue) -> Result<(), Error> {
167 let entry = self.by_name.get_mut(name).ok_or(Error::ExtensionError(
168 format!("{name} not registered").to_string(),
169 ))?;
170
171 if !value.is(&entry.read().unwrap().kind) {
172 return Err(Error::ExtensionError(format!(
173 "kind mismatch: value is {vk:?}, but want {ek:?}",
174 vk = value.kind(),
175 ek = entry.read().unwrap().kind
176 )));
177 }
178
179 entry.write().unwrap().value = value;
180
181 Ok(())
182 }
183
184 pub(crate) fn visit_map_entry_by_name<A>(
185 &mut self,
186 name: &str,
187 mut map: A,
188 ) -> Result<(), A::Error>
189 where
190 A: serde::de::MapAccess<'de>,
191 {
192 if !self.have_name(name) {
193 self.collected.insert(
194 CollectedKey::Name(name.to_string()),
195 map.next_value::<RawValue>()?,
196 );
197 return Ok(());
198 }
199
200 let value = map.next_value::<RawValue>()?;
201
202 self.set_by_name(name, value).map_err(A::Error::custom)?;
203
204 Ok(())
205 }
206
207 pub(crate) fn visit_map_entry_by_key<A>(&mut self, key: i32, mut map: A) -> Result<(), A::Error>
208 where
209 A: serde::de::MapAccess<'de>,
210 {
211 if !self.have_key(&key) {
212 self.collected
213 .insert(CollectedKey::Key(key), map.next_value::<RawValue>()?);
214 return Ok(());
215 }
216
217 let value = map.next_value::<RawValue>()?;
218
219 self.set_by_key(key, value).map_err(A::Error::custom)?;
220
221 Ok(())
222 }
223
224 pub(crate) fn serialize_to_map_by_name<M>(&self, map: &mut M) -> Result<(), M::Error>
225 where
226 M: serde::ser::SerializeMap,
227 {
228 for (name, val) in &self.by_name {
229 if val.read().unwrap().value.is(&RawValueKind::Null) {
230 continue;
231 }
232
233 map.serialize_entry(&name, &val.read().unwrap().value)?;
234 }
235
236 Ok(())
237 }
238
239 pub(crate) fn serialize_to_map_by_key<M>(&self, map: &mut M) -> Result<(), M::Error>
240 where
241 M: serde::ser::SerializeMap,
242 {
243 for (key, val) in &self.by_key {
244 if val.read().unwrap().value.is(&RawValueKind::Null) {
245 continue;
246 }
247
248 map.serialize_entry(&key, &val.read().unwrap().value)?;
249 }
250
251 Ok(())
252 }
253}
254
255impl PartialEq for Extensions {
256 fn eq(&self, other: &Self) -> bool {
257 for (name, val) in &self.by_name {
258 match other.get_by_name(name) {
259 Some(other_val) => {
260 if val.read().unwrap().value != other_val {
261 return false;
262 }
263 }
264 None => return false,
265 }
266 }
267
268 for (key, val) in &self.by_key {
269 match other.get_by_key(key) {
270 Some(other_val) => {
271 if val.read().unwrap().value != other_val {
272 return false;
273 }
274 }
275 None => return false,
276 }
277 }
278
279 true
280 }
281}
282
283#[derive(Debug, Clone)]
284struct RegisterEntry {
285 pub name: String,
286 pub key: i32,
287 pub kind: RawValueKind,
288}
289
290#[derive(Debug, Clone)]
291struct Register {
292 pub entries: Vec<RegisterEntry>,
293 names: HashSet<String>,
294 keys: HashSet<i32>,
295}
296
297impl Register {
298 pub fn new() -> Self {
299 Register {
300 entries: Vec::new(),
301 names: HashSet::new(),
302 keys: HashSet::new(),
303 }
304 }
305
306 pub fn register(&mut self, name: &str, key: i32, kind: RawValueKind) -> Result<(), Error> {
307 match self.names.get(name) {
308 Some(_) => Err(Error::ExtensionError(
309 format!("name {name} already registered").to_string(),
310 )),
311 None => Ok(()),
312 }?;
313
314 match self.keys.get(&key) {
315 Some(_) => Err(Error::ExtensionError(
316 format!("key {key} already registered").to_string(),
317 )),
318 None => Ok(()),
319 }?;
320
321 self.entries.push(RegisterEntry {
322 name: name.to_string(),
323 key,
324 kind,
325 });
326
327 Ok(())
328 }
329}
330
331impl IntoIterator for Register {
332 type Item = RegisterEntry;
333 type IntoIter = <Vec<RegisterEntry> as IntoIterator>::IntoIter;
334
335 fn into_iter(self) -> Self::IntoIter {
336 self.entries.into_iter()
337 }
338}
339
340#[derive(Debug, Clone)]
341pub struct Profile {
342 id: String,
343 ear: Register,
344 appraisal: Register,
345}
346
347impl Profile {
348 pub fn new(id: &str) -> Self {
349 Profile {
350 id: id.to_string(),
351 ear: Register::new(),
352 appraisal: Register::new(),
353 }
354 }
355
356 pub fn register_ear_extension(
357 &mut self,
358 name: &str,
359 key: i32,
360 kind: RawValueKind,
361 ) -> Result<(), Error> {
362 self.ear.register(name, key, kind)
363 }
364
365 pub fn register_appraisal_extension(
366 &mut self,
367 name: &str,
368 key: i32,
369 kind: RawValueKind,
370 ) -> Result<(), Error> {
371 self.appraisal.register(name, key, kind)
372 }
373
374 pub fn populate_ear_extensions(&self, ear: &mut Ear) -> Result<(), Error> {
375 if self.id != ear.profile {
376 return Err(Error::ProfileError(format!(
377 "ID mismatch: wanted {wid}, but got {gid}",
378 wid = self.id,
379 gid = ear.profile,
380 )));
381 }
382
383 for entry in self.ear.clone() {
384 ear.extensions
385 .register(&entry.name, entry.key, entry.kind)?
386 }
387
388 for (_, appraisal) in ear.submods.iter_mut() {
389 for entry in self.appraisal.clone() {
390 appraisal
391 .extensions
392 .register(&entry.name, entry.key, entry.kind)?
393 }
394 }
395
396 Ok(())
397 }
398
399 pub fn populate_appraisal_extensions(&self, appraisal: &mut Appraisal) -> Result<(), Error> {
400 for entry in self.appraisal.clone() {
401 appraisal
402 .extensions
403 .register(&entry.name, entry.key, entry.kind)?
404 }
405
406 Ok(())
407 }
408}
409
410lazy_static! {
411 static ref PROFILE_REGISTER: Mutex<BTreeMap<String, Profile>> = Mutex::new(BTreeMap::new());
412}
413
414pub fn register_profile(profile: &Profile) -> Result<(), Error> {
415 let mut register = PROFILE_REGISTER.lock().unwrap();
416
417 match register.get(&profile.id) {
418 Some(_) => Err(Error::ProfileError(format!(
419 "{id} already registered",
420 id = profile.id
421 ))),
422 None => {
423 register.insert(profile.id.clone(), profile.clone());
424 Ok(())
425 }
426 }?;
427
428 Ok(())
429}
430
431pub fn get_profile(id: &str) -> Option<Profile> {
432 let register = PROFILE_REGISTER.lock().unwrap();
433 register.get(id).cloned()
434}
435
#[cfg(test)]
mod test {
    use super::*;
    use crate::base64::Bytes;
    use crate::error::Error;

    use std::str;
    use std::thread;

    use serde::ser::SerializeMap;
    use serde::ser::Serializer;

    /// Fails unless `res` is an `Error::ExtensionError` carrying exactly `want`.
    fn assert_extension_error(res: Result<(), Error>, want: &str) {
        match res {
            Err(Error::ExtensionError(msg)) => assert_eq!(msg, want),
            other => panic!("expected ExtensionError({want:?}), got {other:?}"),
        }
    }

    /// Unwraps a looked-up value that must be `Some(RawValue::String(_))`.
    fn expect_string(found: Option<RawValue>) -> String {
        match found {
            Some(RawValue::String(s)) => s,
            other => panic!("unexpected value: {other:?}"),
        }
    }

    #[test]
    fn crud() {
        let mut exts = Extensions::new();
        exts.register("foo", 1, RawValueKind::String).unwrap();

        // Duplicate registrations are rejected by name and by key.
        assert_extension_error(
            exts.register("foo", 2, RawValueKind::String),
            "name foo already registered",
        );
        assert_extension_error(
            exts.register("bad", 1, RawValueKind::String),
            "key 1 already registered",
        );

        // Both identifiers resolve to the same registration.
        assert_eq!(exts.get_kind_by_key(&1), RawValueKind::String);
        assert_eq!(exts.get_kind_by_name("foo"), RawValueKind::String);

        assert!(exts.have_name("foo"));
        assert!(exts.have_key(&1));
        assert!(!exts.have_name("bad"));
        assert!(!exts.have_key(&-1));

        // A write through one identifier is visible through the other.
        exts.set_by_key(1, RawValue::String("bar".to_string())).unwrap();
        assert_eq!(expect_string(exts.get_by_name("foo")), "bar");

        exts.set_by_name("foo", RawValue::String("buzz".to_string())).unwrap();
        assert_eq!(expect_string(exts.get_by_key(&1)), "buzz");

        // Unknown identifiers are rejected.
        assert_extension_error(
            exts.set_by_name("bad", RawValue::String("bar".to_string())),
            "bad not registered",
        );
        assert_extension_error(
            exts.set_by_key(-1, RawValue::String("bar".to_string())),
            "-1 not registered",
        );

        // Values of the wrong kind are rejected.
        assert_extension_error(
            exts.set_by_name("foo", RawValue::Integer(42)),
            "kind mismatch: value is Integer, but want String",
        );
        assert_extension_error(
            exts.set_by_key(1, RawValue::Bool(true)),
            "kind mismatch: value is Bool, but want String",
        );
    }

    #[test]
    fn serde() {
        let mut exts = Extensions::new();
        exts.register("foo", 1, RawValueKind::String).unwrap();
        exts.set_by_name("foo", RawValue::String("bar".to_string())).unwrap();

        let mut buf = Vec::new();
        {
            let mut ser = serde_json::Serializer::new(&mut buf);
            let mut map = ser.serialize_map(None).unwrap();
            exts.serialize_to_map_by_name(&mut map).unwrap();
            map.end().unwrap();
        }

        assert_eq!(str::from_utf8(&buf).unwrap(), r#"{"foo":"bar"}"#);
    }

    #[test]
    fn value_convert() {
        let expected: [u8; 4] = [0xde, 0xad, 0xbe, 0xef];
        let converted = RawValue::String("3q2-7w".to_string())
            .convert(&RawValueKind::Bytes)
            .unwrap();

        match converted {
            RawValue::Bytes(bs) => assert_eq!(bs, Bytes::from(&expected[..])),
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn test_send() {
        let mut exts = Extensions::new();
        exts.register("foo", 1, RawValueKind::String).unwrap();
        exts.set_by_name("foo", RawValue::String("test".to_string())).unwrap();

        // Moving the whole set into another thread must compile and keep the value.
        thread::spawn(move || {
            assert_eq!(expect_string(exts.get_by_name("foo")), "test");
        })
        .join()
        .unwrap();
    }
}