pliron/
uniqued_any.rs

1//! Store, in [Context], a single unique copy of any object.
2
3use std::{any::Any, cell::Ref, hash::Hash, marker::PhantomData};
4
5use crate::{
6    context::{ArenaIndex, Context},
7    storage_uniquer::TypeValueHash,
8};
9
/// Type-erased owner of a single uniqued object.
///
/// The object is held as a [Box]ed [Any] so that values of arbitrary `'static`
/// types can live in one store; the concrete type is recovered with `downcast_ref`.
pub(crate) struct UniquedAny(Box<dyn Any + 'static>);
12
13/// A handle to the stored unique copy of an object.
14#[derive(PartialEq, Eq, Debug)]
15pub struct UniquedKey<T> {
16    index: ArenaIndex,
17    _dummy: PhantomData<T>,
18}
19
20impl<T> Clone for UniquedKey<T> {
21    fn clone(&self) -> Self {
22        *self
23    }
24}
25impl<T> Copy for UniquedKey<T> {}
26
27impl<T: 'static> Hash for UniquedKey<T> {
28    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
29        self.index.hash(state);
30        std::any::TypeId::of::<T>().hash(state);
31    }
32}
33
34/// Save a unique copy of an object and get a handle to the saved copy.
35pub fn save<T: Any + Hash + Eq>(ctx: &mut Context, t: T) -> UniquedKey<T> {
36    let hash = TypeValueHash::new(&t);
37    let t = UniquedAny(Box::new(t));
38    let eq = |t1: &UniquedAny, t2: &UniquedAny| -> bool {
39        t1.0.downcast_ref::<T>() == t2.0.downcast_ref::<T>()
40    };
41    UniquedKey {
42        index: ctx.uniqued_any_store.get_or_create_unique(t, hash, &eq),
43        _dummy: PhantomData,
44    }
45}
46
47/// Given a handle to a stored unique copy of an object, get a reference to the object itself.
48pub fn get<T: Any + Hash + Eq>(ctx: &Context, key: UniquedKey<T>) -> Ref<'_, T> {
49    let r = ctx
50        .uniqued_any_store
51        .unique_store
52        .get(key.index)
53        .expect("Key not found in uniqued store")
54        .borrow();
55    Ref::map(r, |ua| {
56        ua.0.downcast_ref::<T>()
57            .expect("Type mismatch in UniquedAny retrieval")
58    })
59}
60
#[cfg(test)]
mod tests {
    use crate::context::Context;

    use super::{get, save};

    #[test]
    fn test_uniqued_any() {
        let ctx = &mut Context::new();

        // Saving a value yields a handle through which the value can be read back.
        let s1_handle = save(ctx, String::from("Hello"));
        assert_eq!(*get(ctx, s1_handle), "Hello");

        // Saving an equal value yields the same handle: the value is uniqued.
        let s2_handle = save(ctx, String::from("Hello"));
        assert_eq!(s1_handle, s2_handle);

        // A different value gets a different handle and reads back correctly.
        let s3_handle = save(ctx, String::from("World"));
        assert_ne!(s1_handle, s3_handle);
        assert_eq!(*get(ctx, s3_handle), "World");

        // Earlier handles remain valid after further saves.
        assert_eq!(*get(ctx, s1_handle), "Hello");

        // Values of other types can be stored in the same context.
        let i1 = 71i64;
        let i1_handle = save(ctx, i1);
        assert_eq!(*get(ctx, i1_handle), i1);
    }
}
87}