markdown_that/parser/
extset.rs

1//! Extension sets
2//!
3//! These things allow you to put custom data inside internal markdown-it structures.
4//!
5use downcast_rs::{Downcast, impl_downcast};
6use std::fmt::Debug;
7
/// Extension set member for the entire parser (only writable at init).
///
/// This is a marker trait with no methods of its own. Implementors must be
/// `Debug + Send + Sync` and support downcasting, so that a value can be
/// stored as `Box<dyn MarkdownThatExt>` inside `MarkdownThatExtSet` and later
/// recovered as its concrete type.
pub trait MarkdownThatExt: Debug + Downcast + Send + Sync {}
// Provides the `downcast*` helpers on `dyn MarkdownThatExt` that the set relies on.
impl_downcast!(MarkdownThatExt);
// Generates `MarkdownThatExtSet`, a map keyed by the `TypeKey` of each stored value.
extension_set!(MarkdownThatExtSet, MarkdownThatExt);
12
/// Extension set member for an arbitrary AST node.
///
/// This is a marker trait with no methods of its own. Implementors must be
/// `Debug + Send + Sync` and support downcasting, so that a value can be
/// stored as `Box<dyn NodeExt>` inside `NodeExtSet` and later recovered as
/// its concrete type.
pub trait NodeExt: Debug + Downcast + Send + Sync {}
// Provides the `downcast*` helpers on `dyn NodeExt` that the set relies on.
impl_downcast!(NodeExt);
// Generates `NodeExtSet`, a map keyed by the `TypeKey` of each stored value.
extension_set!(NodeExtSet, NodeExt);
17
/// Extension set member for an inline context.
///
/// This is a marker trait with no methods of its own. Implementors must be
/// `Debug + Send + Sync` and support downcasting, so that a value can be
/// stored as `Box<dyn InlineRootExt>` inside `InlineRootExtSet` and later
/// recovered as its concrete type.
pub trait InlineRootExt: Debug + Downcast + Send + Sync {}
// Provides the `downcast*` helpers on `dyn InlineRootExt` that the set relies on.
impl_downcast!(InlineRootExt);
// Generates `InlineRootExtSet`, a map keyed by the `TypeKey` of each stored value.
extension_set!(InlineRootExtSet, InlineRootExt);
22
/// Extension set member for a block context.
///
/// This is a marker trait with no methods of its own. Implementors must be
/// `Debug + Send + Sync` and support downcasting, so that a value can be
/// stored as `Box<dyn RootExt>` inside `RootExtSet` and later recovered as
/// its concrete type.
pub trait RootExt: Debug + Downcast + Send + Sync {}
// Provides the `downcast*` helpers on `dyn RootExt` that the set relies on.
impl_downcast!(RootExt);
// Generates `RootExtSet`, a map keyed by the `TypeKey` of each stored value.
extension_set!(RootExtSet, RootExt);
27
/// Extension set member for a renderer context.
///
/// This is a marker trait with no methods of its own. Implementors must be
/// `Debug + Send + Sync` and support downcasting, so that a value can be
/// stored as `Box<dyn RenderExt>` inside `RenderExtSet` and later recovered
/// as its concrete type.
pub trait RenderExt: Debug + Downcast + Send + Sync {}
// Provides the `downcast*` helpers on `dyn RenderExt` that the set relies on.
impl_downcast!(RenderExt);
// Generates `RenderExtSet`, a map keyed by the `TypeKey` of each stored value.
extension_set!(RenderExtSet, RenderExt);
32
// see https://github.com/malobre/erased_set for inspiration and API
// see https://lucumr.pocoo.org/2022/1/7/as-any-hack/ for additional impl details
//
// `extension_set!(Name, Trait)` generates a struct `Name` that stores boxed
// `dyn Trait` values in a `HashMap` keyed by `crate::common::TypeKey::of::<T>()`,
// so each concrete type occupies at most one slot.
//
// The trait passed in must provide `downcast_ref`, `downcast_mut` and
// `downcast` on its trait object (the call sites in this file get them from
// `downcast_rs::impl_downcast!`).
//
// Invariant shared by every method below: a value is only ever stored under
// the key computed from its own concrete type, so downcasting a value fetched
// by that key back to the same type cannot fail. The `expect` calls state this
// instead of using a bare `unwrap`.
macro_rules! extension_set {
    ($name: ident, $trait: ident) => {
        /// Type-keyed container holding at most one value per concrete type.
        #[derive(Debug, Default)]
        pub struct $name(::std::collections::HashMap<crate::common::TypeKey, Box<dyn $trait>>);

        impl $name {
            /// Creates an empty set.
            #[must_use]
            pub fn new() -> Self {
                Self::default()
            }

            /// Returns `true` if the set holds no values.
            #[must_use]
            pub fn is_empty(&self) -> bool {
                self.0.is_empty()
            }

            /// Returns the number of values (i.e. distinct types) stored.
            #[must_use]
            pub fn len(&self) -> usize {
                self.0.len()
            }

            /// Removes every value from the set.
            pub fn clear(&mut self) {
                self.0.clear();
            }

            /// Returns `true` if a value is stored under the key of type `T`.
            #[must_use]
            pub fn contains<T: 'static>(&self) -> bool {
                let key = crate::common::TypeKey::of::<T>();
                self.0.contains_key(&key)
            }

            /// Returns a shared reference to the stored value of type `T`,
            /// or `None` if there is none.
            #[must_use]
            pub fn get<T: $trait>(&self) -> Option<&T> {
                let key = crate::common::TypeKey::of::<T>();
                let stored = self.0.get(&key)?;
                stored.downcast_ref::<T>()
            }

            /// Returns a mutable reference to the stored value of type `T`,
            /// or `None` if there is none.
            #[must_use]
            pub fn get_mut<T: $trait>(&mut self) -> Option<&mut T> {
                let key = crate::common::TypeKey::of::<T>();
                let stored = self.0.get_mut(&key)?;
                stored.downcast_mut::<T>()
            }

            /// Returns the stored value of type `T`, inserting `value` first
            /// if the slot is empty. When a value is already present, `value`
            /// is dropped and the existing one is returned.
            ///
            /// # Panics
            ///
            /// Panics only if the set's internal invariant is broken (a value
            /// stored under a key that does not match its concrete type).
            pub fn get_or_insert<T: $trait>(&mut self, value: T) -> &mut T {
                let key = crate::common::TypeKey::of::<T>();
                let stored = self.0.entry(key).or_insert_with(|| Box::new(value));
                stored
                    .downcast_mut::<T>()
                    .expect("value is stored under the key of its own concrete type")
            }

            /// Returns the stored value of type `T`, inserting the result of
            /// `f` first if the slot is empty. `f` is not called when a value
            /// is already present.
            ///
            /// # Panics
            ///
            /// Panics only if the set's internal invariant is broken (a value
            /// stored under a key that does not match its concrete type).
            pub fn get_or_insert_with<T: $trait>(&mut self, f: impl FnOnce() -> T) -> &mut T {
                let key = crate::common::TypeKey::of::<T>();
                let stored = self.0.entry(key).or_insert_with(|| Box::new(f()));
                stored
                    .downcast_mut::<T>()
                    .expect("value is stored under the key of its own concrete type")
            }

            /// Returns the stored value of type `T`, inserting `T::default()`
            /// first if the slot is empty.
            ///
            /// # Panics
            ///
            /// Panics only if the set's internal invariant is broken (a value
            /// stored under a key that does not match its concrete type).
            pub fn get_or_insert_default<T: $trait + Default>(&mut self) -> &mut T {
                let key = crate::common::TypeKey::of::<T>();
                // The closure (rather than a bare function path) is what lets
                // `Box<T>` coerce to the boxed trait object stored in the map.
                let stored = self.0.entry(key).or_insert_with(|| Box::<T>::default());
                stored
                    .downcast_mut::<T>()
                    .expect("value is stored under the key of its own concrete type")
            }

            /// Stores `value`, replacing any existing value of type `T`.
            /// Returns the replaced value, or `None` if the slot was empty.
            ///
            /// # Panics
            ///
            /// Panics only if the set's internal invariant is broken (a value
            /// stored under a key that does not match its concrete type).
            pub fn insert<T: $trait>(&mut self, value: T) -> Option<T> {
                let key = crate::common::TypeKey::of::<T>();
                let previous = self.0.insert(key, Box::new(value))?;
                let previous = previous
                    .downcast::<T>()
                    .expect("value is stored under the key of its own concrete type");
                Some(*previous)
            }

            /// Removes and returns the stored value of type `T`, or `None` if
            /// there is none.
            ///
            /// # Panics
            ///
            /// Panics only if the set's internal invariant is broken (a value
            /// stored under a key that does not match its concrete type).
            pub fn remove<T: $trait>(&mut self) -> Option<T> {
                let key = crate::common::TypeKey::of::<T>();
                let removed = self.0.remove(&key)?;
                let removed = removed
                    .downcast::<T>()
                    .expect("value is stored under the key of its own concrete type");
                Some(*removed)
            }
        }
    };
}
112
113pub(crate) use extension_set;
114
#[cfg(test)]
mod tests {
    use super::extension_set;
    use downcast_rs::{Downcast, impl_downcast};
    use std::fmt::Debug;

    /// Test-only extension trait with the same bounds as the real ones.
    pub trait TestExt: Debug + Downcast + Send + Sync {}
    impl_downcast!(TestExt);

    extension_set!(TestExtSet, TestExt);

    // Blanket impl so plain values (integers, `&str`, unit structs) can be
    // stored in the set without a dedicated impl per type.
    impl<T: Debug + Downcast + Send + Sync> TestExt for T {}

    #[test]
    fn empty_set() {
        let set = TestExtSet::new();
        assert_eq!(set.len(), 0);
        assert!(set.is_empty());
    }

    #[test]
    fn insert_elements() {
        let mut set = TestExtSet::new();
        set.insert(42u8);
        assert_eq!(set.len(), 1);
        assert!(!set.is_empty());
        set.insert(42u16);
        assert_eq!(set.len(), 2);
        assert!(!set.is_empty());
    }

    #[test]
    fn insert_and_remove_return_previous_value() {
        let mut set = TestExtSet::new();
        // First insert fills an empty slot, so nothing is returned.
        assert_eq!(set.insert(42u16), None);
        // Second insert of the same type hands back the value it replaced.
        assert_eq!(set.insert(123u16), Some(42u16));
        assert_eq!(set.remove::<u16>(), Some(123u16));
        // The slot is empty again after removal.
        assert_eq!(set.remove::<u16>(), None);
        assert!(set.is_empty());
    }

    #[test]
    fn contains() {
        let mut set = TestExtSet::new();
        set.insert(42u8);
        assert!(!set.contains::<u16>());
        set.insert(42u16);
        assert!(set.contains::<u16>());
        set.remove::<u16>();
        assert!(!set.contains::<u16>());
    }

    #[test]
    fn get() {
        let mut set = TestExtSet::new();
        set.insert(42u8);
        assert_eq!(set.get::<u16>(), None);
        set.insert(42u16);
        set.insert(123u16);
        // The later insert of the same type wins.
        assert_eq!(set.get::<u16>(), Some(&123u16));
    }

    #[test]
    fn get_mut() {
        let mut set = TestExtSet::new();
        set.insert(42u16);
        *set.get_mut::<u16>().unwrap() = 123u16;
        assert_eq!(set.get::<u16>(), Some(&123u16));
    }

    #[test]
    fn or_insert() {
        let mut set = TestExtSet::new();
        set.insert(123u8);
        // An existing value is kept by every `get_or_insert*` variant.
        assert_eq!(set.get_or_insert(0u8), &mut 123u8);
        assert_eq!(set.get_or_insert_default::<u8>(), &mut 123u8);
        assert_eq!(set.get_or_insert_with(|| 0u8), &mut 123u8);
        // On an empty slot, each variant inserts its own fallback.
        set.clear();
        assert_eq!(set.get_or_insert(10u8), &mut 10u8);
        set.clear();
        assert_eq!(set.get_or_insert_with(|| 20u8), &mut 20u8);
        set.clear();
        assert_eq!(set.get_or_insert_default::<u8>(), &mut 0u8);
    }

    #[test]
    fn same_type_stored_once() {
        // All three values share the type `&'static str`, so they occupy a
        // single slot and each insert replaces the previous value.
        let mut set = TestExtSet::new();
        set.insert("foo");
        set.insert("bar");
        set.insert("quux");
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn zero_sized_types() {
        #[derive(Debug, PartialEq, Eq)]
        struct A;
        #[derive(Debug, PartialEq, Eq)]
        struct B;
        // Distinct zero-sized types still get distinct slots.
        let mut set = TestExtSet::new();
        set.insert(A);
        set.insert(B);
        assert_eq!(set.len(), 2);
        assert_eq!(set.get::<A>(), Some(&A));
        assert_eq!(set.get::<B>(), Some(&B));
    }

    #[test]
    fn clear() {
        let mut set = TestExtSet::new();
        set.insert(42u8);
        set.insert(42u16);
        assert_eq!(set.len(), 2);
        set.clear();
        assert_eq!(set.len(), 0);
    }

    #[test]
    fn debug() {
        let mut set = TestExtSet::new();
        set.insert(42);
        set.insert("test");
        let rendered = format!("{:?}", set);
        // there are no guarantees about field order, so check both
        assert!(
            rendered == "TestExtSet({i32: 42, &str: \"test\"})"
                || rendered == "TestExtSet({&str: \"test\", i32: 42})"
        );
    }
}