1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::sync::{Arc, Mutex, OnceLock};
4
5use crate::error::ContextError;
6use crate::value::ContextValue;
7
8type DeserializeFn =
10 Box<dyn Fn(&[u8]) -> Result<Box<dyn ContextValue>, ContextError> + Send + Sync>;
11
12type SerializeFn = Arc<dyn Fn(&dyn ContextValue) -> Result<Vec<u8>, ContextError> + Send + Sync>;
14
15pub(crate) type RegistryMap = HashMap<&'static str, Registration>;
16
17pub(crate) struct Registration {
19 pub key: &'static str,
20 pub type_id: TypeId,
21 pub key_version: u32,
23 pub deserializers: HashMap<u32, DeserializeFn>,
25 pub type_name: &'static str,
26 pub local_only: bool,
28 pub serialize_fn: Option<SerializeFn>,
30 pub cached: bool,
35 pub metadata: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
38}
39
40static FROZEN: OnceLock<RegistryMap> = OnceLock::new();
51
52static BUILD: std::sync::LazyLock<Mutex<Option<RegistryMap>>> =
55 std::sync::LazyLock::new(|| Mutex::new(Some(HashMap::new())));
56static EMPTY_MAP: std::sync::LazyLock<RegistryMap> = std::sync::LazyLock::new(HashMap::new);
57
58fn lock_build() -> std::sync::MutexGuard<'static, Option<RegistryMap>> {
59 BUILD
60 .lock()
61 .unwrap_or_else(|poisoned| poisoned.into_inner())
62}
63
64pub(crate) struct Registry<'a> {
66 map: &'a RegistryMap,
67}
68
69impl<'a> Registry<'a> {
70 pub(crate) fn new(map: &'a RegistryMap) -> Self {
71 Self { map }
72 }
73
74 pub(crate) fn empty() -> Registry<'static> {
75 Registry { map: &EMPTY_MAP }
76 }
77
78 pub(crate) fn with_registration<R>(
79 &self,
80 key: &str,
81 f: impl FnOnce(&Registration) -> R,
82 ) -> Option<R> {
83 self.map.get(key).map(f)
84 }
85
86 pub(crate) fn get_serialization_info(&self, key: &str) -> Option<SerializationInfo> {
87 self.map.get(key).map(|r| SerializationInfo {
88 key_version: r.key_version,
89 serialize_fn: r.serialize_fn.clone(),
90 })
91 }
92
93 pub(crate) fn cached_keys(&self) -> Vec<&'static str> {
94 self.map
95 .iter()
96 .filter(|(_, r)| r.cached)
97 .map(|(&k, _)| k)
98 .collect()
99 }
100
101 pub(crate) fn is_local_key(&self, key: &str) -> bool {
102 self.map.get(key).is_some_and(|r| r.local_only)
103 }
104
105 pub(crate) fn is_valid_value(&self, key: &str, value: &dyn ContextValue) -> bool {
106 self.map
107 .get(key)
108 .is_some_and(|r| r.type_id == value.as_any().type_id())
109 }
110
111 pub(crate) fn with_metadata<M: 'static, R>(
112 &self,
113 key: &str,
114 f: impl FnOnce(&M) -> R,
115 ) -> Option<R> {
116 self.with_registration(key, |r| {
117 r.metadata
118 .get(&TypeId::of::<M>())
119 .and_then(|boxed| boxed.downcast_ref::<M>())
120 .map(f)
121 })
122 .flatten()
123 }
124
125 pub(crate) fn keys_with_metadata<M: 'static, R>(
126 &self,
127 f: impl Fn(&'static str, &M) -> R,
128 ) -> Vec<R> {
129 self.map
130 .iter()
131 .filter_map(|(&key, reg)| {
132 reg.metadata
133 .get(&TypeId::of::<M>())
134 .and_then(|boxed| boxed.downcast_ref::<M>())
135 .map(|meta| f(key, meta))
136 })
137 .collect()
138 }
139}
140
141pub(crate) fn with_global_registry<R>(f: impl FnOnce(&Registry<'_>) -> R) -> R {
144 if let Some(frozen) = FROZEN.get() {
145 return f(&Registry::new(frozen));
146 }
147
148 let guard = lock_build();
149 match guard.as_ref() {
150 Some(map) => f(&Registry::new(map)),
151 None => f(&Registry::empty()),
152 }
153}
154
/// Per-key settings applied by the `configure` closure of `register_with` and
/// `try_register_with`.
///
/// Every setter consumes the options and returns the updated value, so calls chain and
/// the final value must be returned from the closure.
// The boxed codec closures have long but intentional types; aliasing them would only
// move the complexity elsewhere.
#[allow(clippy::type_complexity)]
pub struct RegistrationOptions<T: 'static> {
    version: u32,
    local_only: bool,
    cached: bool,
    encode: Option<Box<dyn Fn(&T) -> Result<Vec<u8>, String> + Send + Sync>>,
    decode: Option<Box<dyn Fn(&[u8]) -> Result<T, String> + Send + Sync>>,
    metadata: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl<T: 'static> RegistrationOptions<T> {
    /// Default options: version `1`, not local-only, not cached, no codec, no metadata.
    #[must_use]
    fn new() -> Self {
        Self {
            version: 1,
            local_only: false,
            cached: false,
            encode: None,
            decode: None,
            metadata: HashMap::new(),
        }
    }

    /// Sets the key's current wire version (default `1`).
    ///
    /// Registration rejects a version other than `1` on local-only keys.
    #[must_use]
    pub fn version(mut self, v: u32) -> Self {
        self.version = v;
        self
    }

    /// Marks the key as local-only.
    ///
    /// A local-only key gets no deserializer. Registration rejects combining it with a
    /// codec or a non-default version.
    #[must_use]
    pub fn local_only(mut self) -> Self {
        self.local_only = true;
        self
    }

    /// Marks the key as cached. The key is then reported by the registry's `cached_keys`.
    #[must_use]
    pub fn cached(mut self) -> Self {
        self.cached = true;
        self
    }

    /// Replaces the default `bincode` decoding with custom `encode`/`decode` closures.
    ///
    /// A `String` error from `encode` surfaces as `ContextError::SerializationFailed`.
    /// A `String` error from `decode` surfaces as `ContextError::DeserializationFailed`.
    #[must_use]
    pub fn codec(
        mut self,
        encode: impl Fn(&T) -> Result<Vec<u8>, String> + Send + Sync + 'static,
        decode: impl Fn(&[u8]) -> Result<T, String> + Send + Sync + 'static,
    ) -> Self {
        self.encode = Some(Box::new(encode));
        self.decode = Some(Box::new(decode));
        self
    }

    /// Attaches a metadata value, keyed by its type `M`.
    ///
    /// Attaching a second value of the same type replaces the first.
    #[must_use]
    pub fn with_metadata<M: Any + Send + Sync + 'static>(mut self, value: M) -> Self {
        self.metadata.insert(TypeId::of::<M>(), Box::new(value));
        self
    }
}
246
247fn do_register_with<T>(
252 registry: &mut RegistryMap,
253 key: &'static str,
254 configure: impl FnOnce(RegistrationOptions<T>) -> RegistrationOptions<T>,
255) -> Result<(), ContextError>
256where
257 T: Clone + Default + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
258{
259 let opts = configure(RegistrationOptions::new());
260
261 if opts.local_only {
262 if opts.encode.is_some() || opts.decode.is_some() {
263 return Err(ContextError::SerializationFailed(
264 "local_only and codec are mutually exclusive: \
265 local-only entries are excluded from serialization"
266 .into(),
267 ));
268 }
269 if opts.version != 1 {
270 return Err(ContextError::SerializationFailed(
271 "local_only and version are mutually exclusive: \
272 local-only entries have no wire format"
273 .into(),
274 ));
275 }
276 }
277
278 let tid = TypeId::of::<T>();
279
280 if let Some(existing) = registry.get(key) {
281 if existing.type_id == tid {
282 return Ok(()); }
284 return Err(ContextError::AlreadyRegistered(key.to_string()));
285 }
286
287 let mut deserializers: HashMap<u32, DeserializeFn> = HashMap::new();
288
289 if !opts.local_only {
290 if let Some(decode) = opts.decode {
291 deserializers.insert(
292 opts.version,
293 Box::new(
294 move |bytes: &[u8]| -> Result<Box<dyn ContextValue>, ContextError> {
295 decode(bytes)
296 .map(|v| Box::new(v) as Box<dyn ContextValue>)
297 .map_err(ContextError::DeserializationFailed)
298 },
299 ),
300 );
301 } else {
302 deserializers.insert(
303 opts.version,
304 Box::new(
305 |bytes: &[u8]| -> Result<Box<dyn ContextValue>, ContextError> {
306 bincode::deserialize::<T>(bytes)
307 .map(|v| Box::new(v) as Box<dyn ContextValue>)
308 .map_err(|e| ContextError::DeserializationFailed(e.to_string()))
309 },
310 ),
311 );
312 }
313 }
314
315 let serialize_fn = opts.encode.map(|encode| -> SerializeFn {
316 Arc::new(move |val: &dyn ContextValue| {
317 let typed = val.as_any().downcast_ref::<T>().ok_or_else(|| {
318 ContextError::SerializationFailed(
319 "type mismatch during custom serialization".into(),
320 )
321 })?;
322 encode(typed).map_err(ContextError::SerializationFailed)
323 })
324 });
325
326 registry.insert(
327 key,
328 Registration {
329 key,
330 type_id: tid,
331 key_version: opts.version,
332 deserializers,
333 type_name: std::any::type_name::<T>(),
334 local_only: opts.local_only,
335 serialize_fn,
336 cached: opts.cached,
337 metadata: opts.metadata,
338 },
339 );
340 Ok(())
341}
342
343fn do_register_migration<TOld, TCurrent>(
344 registry: &mut RegistryMap,
345 key: &'static str,
346 old_version: u32,
347 migrate: impl Fn(TOld) -> TCurrent + Send + Sync + 'static,
348) -> Result<(), ContextError>
349where
350 TOld: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
351 TCurrent: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
352{
353 let reg = registry
354 .get_mut(key)
355 .ok_or_else(|| ContextError::NotRegistered(key.to_string()))?;
356
357 if reg.type_id != TypeId::of::<TCurrent>() {
358 return Err(ContextError::TypeMismatch(
359 key.to_string(),
360 reg.type_name.to_string(),
361 std::any::type_name::<TCurrent>().to_string(),
362 ));
363 }
364
365 if reg.local_only {
366 return Err(ContextError::SerializationFailed(format!(
367 "cannot register migration for local-only key '{}'",
368 key
369 )));
370 }
371
372 if old_version == reg.key_version {
373 return Err(ContextError::DeserializationFailed(format!(
374 "cannot register migration for key '{}' at current version {} \
375 (would overwrite the native deserializer)",
376 key, old_version
377 )));
378 }
379
380 reg.deserializers.insert(
381 old_version,
382 Box::new(
383 move |bytes: &[u8]| -> Result<Box<dyn ContextValue>, ContextError> {
384 let old_val = bincode::deserialize::<TOld>(bytes)
385 .map_err(|e| ContextError::DeserializationFailed(e.to_string()))?;
386 let current_val = migrate(old_val);
387 Ok(Box::new(current_val) as Box<dyn ContextValue>)
388 },
389 ),
390 );
391
392 Ok(())
393}
394
395pub struct RegistryBuilder {
415 map: RegistryMap,
416}
417
418impl RegistryBuilder {
419 pub fn new() -> Self {
421 Self {
422 map: HashMap::new(),
423 }
424 }
425
426 pub fn register<T>(&mut self, key: &'static str)
433 where
434 T: Clone + Default + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
435 {
436 self.try_register::<T>(key)
437 .unwrap_or_else(|e| panic!("RegistryBuilder::register failed for key '{key}': {e}"));
438 }
439
440 pub fn try_register<T>(&mut self, key: &'static str) -> Result<(), ContextError>
442 where
443 T: Clone + Default + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
444 {
445 do_register_with::<T>(&mut self.map, key, |opts| opts)
446 }
447
448 pub fn register_with<T>(
455 &mut self,
456 key: &'static str,
457 configure: impl FnOnce(RegistrationOptions<T>) -> RegistrationOptions<T>,
458 ) where
459 T: Clone + Default + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
460 {
461 self.try_register_with::<T>(key, configure)
462 .unwrap_or_else(|e| {
463 panic!("RegistryBuilder::register_with failed for key '{key}': {e}")
464 });
465 }
466
467 pub fn try_register_with<T>(
469 &mut self,
470 key: &'static str,
471 configure: impl FnOnce(RegistrationOptions<T>) -> RegistrationOptions<T>,
472 ) -> Result<(), ContextError>
473 where
474 T: Clone + Default + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
475 {
476 do_register_with(&mut self.map, key, configure)
477 }
478
479 pub fn register_migration<TOld, TCurrent>(
487 &mut self,
488 key: &'static str,
489 old_version: u32,
490 migrate: impl Fn(TOld) -> TCurrent + Send + Sync + 'static,
491 ) where
492 TOld: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
493 TCurrent: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
494 {
495 self.try_register_migration::<TOld, TCurrent>(key, old_version, migrate)
496 .unwrap_or_else(|e| {
497 panic!("RegistryBuilder::register_migration failed for key '{key}': {e}")
498 });
499 }
500
501 pub fn try_register_migration<TOld, TCurrent>(
503 &mut self,
504 key: &'static str,
505 old_version: u32,
506 migrate: impl Fn(TOld) -> TCurrent + Send + Sync + 'static,
507 ) -> Result<(), ContextError>
508 where
509 TOld: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
510 TCurrent: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
511 {
512 do_register_migration(&mut self.map, key, old_version, migrate)
513 }
514}
515
516impl Default for RegistryBuilder {
517 fn default() -> Self {
518 Self::new()
519 }
520}
521
#[cfg(test)]
impl RegistryBuilder {
    /// Consumes the builder and returns its raw map.
    ///
    /// Tests use this to build a `Registry` view without touching process-wide state.
    pub(crate) fn into_map(self) -> RegistryMap {
        let Self { map } = self;
        map
    }
}
528
529pub fn initialize(builder: RegistryBuilder) {
549 try_initialize(builder).expect("dcontext::initialize called more than once");
550}
551
552pub fn try_initialize(builder: RegistryBuilder) -> Result<(), ContextError> {
554 FROZEN
555 .set(builder.map)
556 .map_err(|_| ContextError::RegistryFrozen)
557}
558
/// Test-only: registers `T` under `key` in the global build map with default options.
///
/// # Errors
///
/// Same as [`try_register_with`].
#[cfg(test)]
pub(crate) fn try_register<T>(key: &'static str) -> Result<(), ContextError>
where
    T: Clone + Default + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
{
    try_register_with::<T>(key, std::convert::identity)
}

/// Test-only: registers `T` under `key` in the global build map, customised by `configure`.
///
/// # Errors
///
/// Returns `ContextError::RegistryFrozen` if the build slot is empty. Otherwise returns
/// whatever `do_register_with` reports.
// NOTE(review): nothing in this file ever empties the build slot, so after `initialize`
// registrations still succeed here but are shadowed by the frozen map in lookups.
#[cfg(test)]
pub(crate) fn try_register_with<T>(
    key: &'static str,
    configure: impl FnOnce(RegistrationOptions<T>) -> RegistrationOptions<T>,
) -> Result<(), ContextError>
where
    T: Clone + Default + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
{
    let mut build = lock_build();
    match build.as_mut() {
        Some(map) => do_register_with(map, key, configure),
        None => Err(ContextError::RegistryFrozen),
    }
}

/// Test-only: like [`try_register`], but panics on failure.
#[cfg(test)]
pub(crate) fn register<T>(key: &'static str)
where
    T: Clone + Default + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
{
    if let Err(err) = try_register::<T>(key) {
        panic!("dcontext::register failed: {err:?}");
    }
}

/// Test-only: like [`try_register_with`], but panics on failure.
#[cfg(test)]
pub(crate) fn register_with<T>(
    key: &'static str,
    configure: impl FnOnce(RegistrationOptions<T>) -> RegistrationOptions<T>,
) where
    T: Clone + Default + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
{
    if let Err(err) = try_register_with::<T>(key, configure) {
        panic!("dcontext::register_with failed: {err:?}");
    }
}
603
/// Test-only: adds a migration for `key` to the global build map.
///
/// # Errors
///
/// Returns `ContextError::RegistryFrozen` if the build slot is empty. Otherwise returns
/// whatever `do_register_migration` reports.
#[cfg(test)]
pub(crate) fn try_register_migration<TOld, TCurrent>(
    key: &'static str,
    old_version: u32,
    migrate: impl Fn(TOld) -> TCurrent + Send + Sync + 'static,
) -> Result<(), ContextError>
where
    TOld: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
    TCurrent: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
{
    let mut build = lock_build();
    let Some(map) = build.as_mut() else {
        return Err(ContextError::RegistryFrozen);
    };
    do_register_migration(map, key, old_version, migrate)
}

/// Test-only: like [`try_register_migration`], but panics on failure.
#[cfg(test)]
pub(crate) fn register_migration<TOld, TCurrent>(
    key: &'static str,
    old_version: u32,
    migrate: impl Fn(TOld) -> TCurrent + Send + Sync + 'static,
) where
    TOld: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
    TCurrent: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
{
    if let Err(err) = try_register_migration::<TOld, TCurrent>(key, old_version, migrate) {
        panic!("dcontext::register_migration failed: {err:?}");
    }
}
631
632#[allow(dead_code)]
639pub(crate) fn with_registration<R>(key: &str, f: impl FnOnce(&Registration) -> R) -> Option<R> {
640 with_global_registry(|registry| registry.with_registration(key, f))
641}
642
643pub(crate) struct SerializationInfo {
645 pub key_version: u32,
646 pub serialize_fn: Option<SerializeFn>,
647}
648
649#[allow(dead_code)]
651pub(crate) fn get_serialization_info(key: &str) -> Option<SerializationInfo> {
652 with_global_registry(|registry| registry.get_serialization_info(key))
653}
654
655#[allow(dead_code)]
659pub(crate) fn cached_keys() -> Vec<&'static str> {
660 with_global_registry(|registry| registry.cached_keys())
661}
662
663#[allow(dead_code)]
665pub(crate) fn is_local_key(key: &str) -> bool {
666 with_global_registry(|registry| registry.is_local_key(key))
667}
668
669#[allow(dead_code)]
671pub(crate) fn is_valid_value(key: &str, value: &dyn ContextValue) -> bool {
672 with_global_registry(|registry| registry.is_valid_value(key, value))
673}
674
675pub fn with_metadata<M: 'static, R>(key: &str, f: impl FnOnce(&M) -> R) -> Option<R> {
682 with_global_registry(|registry| registry.with_metadata(key, f))
683}
684
685pub fn keys_with_metadata<M: 'static, R>(f: impl Fn(&'static str, &M) -> R) -> Vec<R> {
690 with_global_registry(|registry| registry.keys_with_metadata(f))
691}
692
/// Test-only: reports whether `key` is present in the map that lookups currently use.
///
/// That is the frozen map once one is installed, otherwise the build map.
#[cfg(test)]
pub(crate) fn is_registered(key: &str) -> bool {
    if let Some(frozen) = FROZEN.get() {
        return frozen.contains_key(key);
    }
    let guard = lock_build();
    // `is_some_and` replaces `map_or(false, ..)` (clippy::unnecessary_map_or) and matches
    // the style used by `Registry::is_local_key`.
    guard.as_ref().is_some_and(|map| map.contains_key(key))
}
701
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// String-backed payload used as the registered type.
    #[derive(Clone, Default, Debug, Serialize, Deserialize)]
    struct TestVal(String);

    /// A second, distinct payload type used to provoke conflicts.
    #[derive(Clone, Default, Debug, Serialize, Deserialize)]
    struct OtherVal(u64);

    /// Builds a `'static` key that is unique to the calling test.
    ///
    /// Tests share the process-wide build map, so each needs its own key. The string is
    /// leaked on purpose because registration requires `&'static str`.
    fn unique_reg_key(name: &str) -> &'static str {
        Box::leak(format!("reg_test_{name}").into_boxed_str())
    }

    #[test]
    fn register_and_lookup() {
        let key = unique_reg_key("lookup");
        try_register::<TestVal>(key).unwrap();

        assert!(is_registered(key));
        assert!(!is_registered("reg_test_missing_xxx"));
    }

    #[test]
    fn idempotent_registration() {
        let key = unique_reg_key("idem");
        for _ in 0..2 {
            try_register::<TestVal>(key).unwrap();
        }
    }

    #[test]
    fn conflicting_registration() {
        let key = unique_reg_key("conflict");
        try_register::<TestVal>(key).unwrap();

        let outcome = try_register::<OtherVal>(key);
        assert!(matches!(outcome, Err(ContextError::AlreadyRegistered(_))));
    }

    #[test]
    fn registry_supports_injected_builder_map() {
        let key = unique_reg_key("injected");
        let mut builder = RegistryBuilder::new();
        builder.register_with::<TestVal>(key, |opts| opts.cached().with_metadata(7usize));

        let owned_map = builder.into_map();
        let view = Registry::new(&owned_map);

        assert!(view.with_registration(key, |_| ()).is_some());
        assert_eq!(view.cached_keys(), vec![key]);
        assert!(view.is_valid_value(key, &TestVal::default()));
        assert_eq!(view.with_metadata::<usize, _>(key, |n| *n), Some(7));
    }
}
755}