layered_nlp/
type_bucket.rs

1#![allow(dead_code)]
2
3use std::collections::hash_map;
4use std::marker::PhantomData;
5use std::{
6    any::{Any, TypeId},
7    collections::HashMap,
8    fmt::{self, Debug},
9};
10
11/// Prepared key-value pair used to specify a custom attribute for input tokens
12/// and is used internally for collecting attribute assignments.
13// `Box<dyn Bucket>` is the empty bucket for this type.
14// It is required to add a type not present in `TypeBucket`.
15pub struct AnyAttribute(TypeId, Box<dyn Bucket>, Box<dyn Any>);
16
17impl std::fmt::Debug for AnyAttribute {
18    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19        f.debug_struct("AnyAttribute").finish()
20    }
21}
22
23impl AnyAttribute {
24    pub fn new<T: 'static + Debug>(value: T) -> Self {
25        AnyAttribute(
26            TypeId::of::<T>(),
27            Box::new(Vec::<T>::new()),
28            Box::new(value),
29        )
30    }
31
32    pub fn extract<T: 'static>(self) -> Result<T, Self> {
33        let AnyAttribute(key, empty_bucket, value) = self;
34        value
35            .downcast()
36            .map(|boxed| *boxed)
37            .map_err(|e| AnyAttribute(key, empty_bucket, e))
38    }
39
40    pub fn type_id(&self) -> TypeId {
41        self.0
42    }
43}
44
/// A typed view into an occupied entry in a `TypeBucket`.
#[derive(Debug)]
pub struct OccupiedEntry<'a, T> {
    // The underlying type-erased map entry.
    data: hash_map::OccupiedEntry<'a, TypeId, Box<dyn Any>>,
    // Ties the entry to `T` without storing a `T`.
    marker: PhantomData<fn(T)>,
}

impl<'a, T: 'static> OccupiedEntry<'a, T> {
    /// Borrows the value held by the entry.
    ///
    /// # Panics
    ///
    /// Panics if the stored value is not a `T`.
    pub fn get(&self) -> &T {
        let stored: &dyn Any = &**self.data.get();
        stored.downcast_ref::<T>().unwrap()
    }

    /// Mutably borrows the value held by the entry.
    ///
    /// # Panics
    ///
    /// Panics if the stored value is not a `T`.
    pub fn get_mut(&mut self) -> &mut T {
        let stored: &mut dyn Any = &mut **self.data.get_mut();
        stored.downcast_mut::<T>().unwrap()
    }

    /// Consumes the entry and returns a mutable reference to its value whose
    /// lifetime is bound to the map itself.
    ///
    /// # Panics
    ///
    /// Panics if the stored value is not a `T`.
    pub fn into_mut(self) -> &'a mut T {
        self.data.into_mut().downcast_mut::<T>().unwrap()
    }

    /// Replaces the value of the entry and returns the value it held before.
    ///
    /// # Panics
    ///
    /// Panics if the previously stored value is not a `T`.
    pub fn insert(&mut self, value: T) -> T {
        let previous: Box<dyn Any> = self.data.insert(Box::new(value));
        *previous.downcast::<T>().unwrap()
    }

    /// Removes the entry from the map and returns the value it held.
    ///
    /// # Panics
    ///
    /// Panics if the stored value is not a `T`.
    pub fn remove(self) -> T {
        let taken: Box<dyn Any> = self.data.remove();
        *taken.downcast::<T>().unwrap()
    }
}
83
/// A typed view into a vacant entry in a `TypeBucket`.
#[derive(Debug)]
pub struct VacantEntry<'a, T> {
    // The underlying type-erased map entry.
    data: hash_map::VacantEntry<'a, TypeId, Box<dyn Any>>,
    // Ties the entry to `T` without storing a `T`.
    marker: PhantomData<fn(T)>,
}

impl<'a, T: 'static> VacantEntry<'a, T> {
    /// Stores `value` under this entry's key and returns a mutable reference
    /// to the stored value.
    pub fn insert(self, value: T) -> &'a mut T {
        let slot = self.data.insert(Box::new(value));
        // The box was created from a `T` just above, so the downcast succeeds.
        slot.downcast_mut::<T>().unwrap()
    }
}
97
98/// A view into a single entry in a map, which may either be vacant or occupied.
99#[derive(Debug)]
100pub enum Entry<'a, T> {
101    Occupied(OccupiedEntry<'a, T>),
102    Vacant(VacantEntry<'a, T>),
103}
104
105impl<'a, T: 'static> Entry<'a, T> {
106    /// Ensures a value is in the entry by inserting the default if empty, and returns
107    /// a mutable reference to the value in the entry.
108    pub fn or_insert(self, default: T) -> &'a mut T {
109        match self {
110            Entry::Occupied(inner) => inner.into_mut(),
111            Entry::Vacant(inner) => inner.insert(default),
112        }
113    }
114
115    /// Ensures a value is in the entry by inserting the result of the default function if empty, and returns
116    /// a mutable reference to the value in the entry.
117    pub fn or_insert_with<F: FnOnce() -> T>(self, default: F) -> &'a mut T {
118        match self {
119            Entry::Occupied(inner) => inner.into_mut(),
120            Entry::Vacant(inner) => inner.insert(default()),
121        }
122    }
123}
124
125#[derive(Debug, Default)]
126/// The TypeBucket container
127pub struct TypeBucket {
128    // dyn Bucket is always a Vec<T>
129    // Box<Vec<T>>
130    // Box<dyn Bucket>
131    map: HashMap<TypeId, Box<dyn Bucket>>,
132}
133
134impl fmt::Debug for dyn Bucket {
135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136        Bucket::debug(self, f)
137    }
138}
139
/// Type-erased storage for the values of one concrete type.
trait Bucket {
    /// Upcasts to `&dyn Any` so the caller can downcast to the concrete bucket type.
    fn as_any(&self) -> &dyn Any
    where
        Self: 'static;

    /// Upcasts to `&mut dyn Any` so the caller can downcast to the concrete bucket type.
    fn as_any_mut(&mut self) -> &mut dyn Any
    where
        Self: 'static;

    /// Adds a boxed, type-erased value to the bucket.
    fn insert_any(&mut self, val: Box<dyn Any>);

    /// Debug-formats the bucket; unless overridden, writes the literal text `Bucket`.
    fn debug(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.pad("Bucket")
    }

    /// Builds the bucket type's `Default` value.
    fn default() -> Self
    where
        Self: Sized + Default,
    {
        // Fully qualified so the call resolves to `Default::default`, not this method.
        <Self as Default>::default()
    }
}
158
159impl<T: 'static + Debug> Bucket for Vec<T> {
160    #[inline]
161    fn as_any(&self) -> &dyn Any {
162        self
163    }
164    #[inline]
165    fn as_any_mut(&mut self) -> &mut dyn Any {
166        self
167    }
168    fn insert_any(&mut self, val: Box<dyn Any>) {
169        self.push(*val.downcast().expect("type doesn't match"));
170    }
171
172    fn debug(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173        Debug::fmt(self, f)
174    }
175}
176
177impl TypeBucket {
178    /// Create an empty `TypeBucket`.
179    #[inline]
180    pub fn new() -> Self {
181        Self {
182            map: Default::default(),
183        }
184    }
185
186    /// Insert a prepared `KvPair` into this `TypeBucket`.
187    ///
188    /// If a value of this type already exists, it will be returned.
189    pub fn insert_any_attribute(&mut self, AnyAttribute(key, empty_value, value): AnyAttribute) {
190        self.map.entry(key).or_insert(empty_value).insert_any(value)
191    }
192
193    /// Insert a value into this `TypeBucket`.
194    ///
195    /// If a value of this type already exists, it will be returned.
196    #[track_caller]
197    pub fn insert<T: 'static + Debug>(&mut self, val: T) {
198        self.map
199            .entry(TypeId::of::<T>())
200            .or_insert_with(|| Box::new(Vec::<T>::new()))
201            .as_any_mut()
202            .downcast_mut::<Vec<T>>()
203            .unwrap()
204            .push(val);
205    }
206
207    // /// Check if container contains value for type
208    // pub fn contains<T: 'static>(&self) -> bool {
209    //     self.map
210    //         .as_ref()
211    //         .and_then(|m| m.get(&TypeId::of::<T>()))
212    //         .is_some()
213    // }
214
215    /// Get a reference to a value previously inserted on this `TypeBucket`.
216    pub fn get<T: 'static>(&self) -> &[T] {
217        self.map
218            .get(&TypeId::of::<T>())
219            .map(|boxed_vec| boxed_vec.as_any().downcast_ref::<Vec<T>>().unwrap())
220            .map(|vec| vec.as_slice())
221            .unwrap_or_else(|| &[])
222    }
223
224    pub fn get_debug<T: 'static + Debug>(&self) -> Vec<String> {
225        self.map
226            .get(&TypeId::of::<T>())
227            .map(|vec| {
228                vec.as_any()
229                    .downcast_ref::<Vec<T>>()
230                    .unwrap()
231                    .iter()
232                    .map(|item| format!("{:?}", item))
233                    .collect()
234            })
235            .unwrap_or_default()
236    }
237
238    // /// Get a mutable reference to a value previously inserted on this `TypeBucket`.
239    // pub fn get_mut<T: 'static>(&mut self) -> Option<&mut T> {
240    //     self.map
241    //         .as_mut()
242    //         .and_then(|m| m.get_mut(&TypeId::of::<T>()))
243    //         .and_then(|boxed| boxed.downcast_mut())
244    // }
245
246    // /// Remove a value from this `TypeBucket`.
247    // ///
248    // /// If a value of this type exists, it will be returned.
249    // pub fn remove<T: 'static>(&mut self) -> Option<T> {
250    //     self.map
251    //         .as_mut()
252    //         .and_then(|m| m.remove(&TypeId::of::<T>()))
253    //         .and_then(|boxed| boxed.downcast().ok().map(|boxed| *boxed))
254    // }
255
256    /// Clear the `TypeBucket` of all inserted values.
257    #[inline]
258    pub fn clear(&mut self) {
259        self.map = Default::default();
260    }
261
262    // /// Get an entry in the `TypeBucket` for in-place manipulation.
263    // pub fn entry<T: 'static>(&mut self) -> Entry<T> {
264    //     match self
265    //         .map
266    //         .get_or_insert_with(|| HashMap::default())
267    //         .entry(TypeId::of::<T>())
268    //     {
269    //         hash_map::Entry::Occupied(e) => Entry::Occupied(OccupiedEntry {
270    //             data: e,
271    //             marker: PhantomData,
272    //         }),
273    //         hash_map::Entry::Vacant(e) => Entry::Vacant(VacantEntry {
274    //             data: e,
275    //             marker: PhantomData,
276    //         }),
277    //     }
278    // }
279}
280
/// Exercises `TypeBucket::insert` and `TypeBucket::get` across several types,
/// including a type that was never inserted.
#[test]
fn test_type_map() {
    #[derive(Debug, PartialEq)]
    struct MyType(i32);

    #[derive(Debug, PartialEq, Default)]
    struct MyType2(String);

    let mut bucket = TypeBucket::new();

    bucket.insert(5i32);
    bucket.insert(MyType(10));

    let ints: &[i32] = bucket.get();
    assert_eq!(ints, &[5i32]);
    // assert_eq!(map.get_mut(), Some(&mut 5i32));

    // assert_eq!(map.remove::<i32>(), Some(5i32));
    // assert!(map.get::<i32>().is_empty());

    // A type that was never inserted yields an empty slice.
    let bools: &[bool] = bucket.get();
    assert_eq!(bools, &[] as &[bool]);

    let custom: &[MyType] = bucket.get();
    assert_eq!(custom, &[MyType(10)]);

    // A second value of the same type is appended after the first.
    bucket.insert(MyType(20));
    let custom: &[MyType] = bucket.get();
    assert_eq!(custom, &[MyType(10), MyType(20)]);
    // let entry = map.entry::<MyType2>();

    // let mut v = entry.or_insert_with(MyType2::default);

    // v.0 = "Hello".into();

    // assert_eq!(map.get(), Some(&MyType2("Hello".into())));
}