assemble_core/plugins/
extensions.rs

1//! Extensions that plugins can add
2
3use std::any;
4use std::any::Any;
5use std::collections::HashMap;
6use std::fmt::{Debug, Formatter};
7use std::ops::{Index, IndexMut};
8
9use crate::prelude::{ProjectError, ProjectResult};
10use thiserror::Error;
11
/// A helper trait bundling the bounds a value needs in order to be stored as an extension.
///
/// It is implemented automatically for every `'static + Send + Sync` type, so it never needs
/// to be implemented by hand.
pub trait Extension: 'static + Send + Sync {}

impl<E: 'static + Send + Sync> Extension for E {}
16
17/// A type that contains extensions
18pub trait ExtensionAware {
19    /// Gets the extension container
20    fn extensions(&self) -> &ExtensionContainer;
21    /// Gets a mutable reference to the extension container
22    fn extensions_mut(&mut self) -> &mut ExtensionContainer;
23
24    /// If a single extension is registered with a given type, a reference to that value is returned
25    /// as `Some(_)`
26    fn extension<E: Extension>(&self) -> ProjectResult<&E> {
27        self.extensions().get_by_type()
28    }
29
30    /// If a single extension is registered with a given type, a mutable reference to that value is returned
31    /// as `Some(_)`
32    fn extension_mut<E: Extension>(&mut self) -> ProjectResult<&mut E> {
33        self.extensions_mut().get_by_type_mut()
34    }
35}
36
/// A type-erased, thread-shareable boxed extension value.
type AnyExtension = Box<dyn Any + Sync + Send>;

/// Stores extensions, each registered under a unique name.
#[derive(Default)]
pub struct ExtensionContainer {
    /// Extension values keyed by the name they were registered with.
    ob_map: HashMap<String, AnyExtension>,
}
44
45impl ExtensionContainer {
46    /// Adds a new extension to this container
47    ///
48    /// # Error
49    /// Will return an error if `name` is already registered to this container
50    pub fn add<E: Extension, S: AsRef<str>>(
51        &mut self,
52        name: S,
53        value: E,
54    ) -> Result<(), ExtensionError> {
55        let name = name.as_ref();
56        if self.ob_map.contains_key(name) {
57            return Err(ExtensionError::AlreadyRegistered(name.to_string()));
58        }
59        let boxed = Box::new(value) as AnyExtension;
60        self.ob_map.insert(name.to_string(), boxed);
61        Ok(())
62    }
63
64    /// Gets a reference to an extension, if it exists
65    pub fn get<S: AsRef<str>>(&self, name: S) -> ProjectResult<&AnyExtension> {
66        self.ob_map
67            .get(name.as_ref())
68            .ok_or(ProjectError::ExtensionNotRegistered(name.as_ref().to_string()).into())
69    }
70
71    /// Gets a mutable reference to an extension, if it exists
72    pub fn get_mut<S: AsRef<str>>(&mut self, name: S) -> ProjectResult<&mut AnyExtension> {
73        self.ob_map
74            .get_mut(name.as_ref())
75            .ok_or(ProjectError::ExtensionNotRegistered(name.as_ref().to_string()).into())
76    }
77
78    /// If a single extension is registered with a given type, a reference to that value is returned
79    /// as `Ok(_)`
80    pub fn get_by_type<E: Extension>(&self) -> ProjectResult<&E> {
81        let mut output: Vec<&E> = vec![];
82        for value in self.ob_map.values() {
83            if let Some(ext) = value.downcast_ref::<E>() {
84                output.push(ext);
85            }
86        }
87        match output.len() {
88            1 => Ok(output.remove(0)),
89            _ => {
90                Err(ProjectError::ExtensionNotRegistered(any::type_name::<E>().to_string()).into())
91            }
92        }
93    }
94
95    /// If a single extension is registered with a given type, a mutable reference to that value is returned
96    /// as `Some(_)`
97    pub fn get_by_type_mut<E: Extension>(&mut self) -> ProjectResult<&mut E> {
98        let mut output: Vec<String> = vec![];
99        for (name, ext) in &self.ob_map {
100            if ext.is::<E>() {
101                output.push(name.clone());
102            }
103        }
104        match output.len() {
105            1 => {
106                let index = output.remove(0);
107                self.ob_map
108                    .get_mut(&index)
109                    .and_then(|b| b.downcast_mut())
110                    .ok_or_else(|| unreachable!())
111            }
112            _ => {
113                Err(ProjectError::ExtensionNotRegistered(any::type_name::<E>().to_string()).into())
114            }
115        }
116    }
117}
118
119impl Index<&str> for ExtensionContainer {
120    type Output = AnyExtension;
121
122    fn index(&self, index: &str) -> &Self::Output {
123        self.get(index).unwrap()
124    }
125}
126
127impl IndexMut<&str> for ExtensionContainer {
128    fn index_mut(&mut self, index: &str) -> &mut Self::Output {
129        self.get_mut(index).unwrap()
130    }
131}
132
133impl Index<String> for ExtensionContainer {
134    type Output = AnyExtension;
135
136    fn index(&self, index: String) -> &Self::Output {
137        self.get(index).unwrap()
138    }
139}
140
141impl IndexMut<String> for ExtensionContainer {
142    fn index_mut(&mut self, index: String) -> &mut Self::Output {
143        self.get_mut(index).unwrap()
144    }
145}
146
147impl Debug for ExtensionContainer {
148    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
149        f.debug_struct("ExtensionContainer").finish()
150    }
151}
152
153#[derive(Debug, Error)]
154pub enum ExtensionError {
155    #[error("Extension with name {0:?} already registered")]
156    AlreadyRegistered(String),
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `ExtensionAware` implementor used to exercise the trait's provided methods.
    #[derive(Default)]
    struct Holder {
        container: ExtensionContainer,
    }

    impl ExtensionAware for Holder {
        fn extensions(&self) -> &ExtensionContainer {
            &self.container
        }

        fn extensions_mut(&mut self) -> &mut ExtensionContainer {
            &mut self.container
        }
    }

    #[test]
    fn use_extensions() {
        let mut ext = ExtensionContainer::default();
        ext.add("test", String::from("Hello, World")).unwrap();

        let value = ext.get("test").unwrap().downcast_ref::<String>().unwrap();
        assert_eq!(value, "Hello, World")
    }

    #[test]
    fn disallow_same_name_extensions() {
        let mut ext = ExtensionContainer::default();
        ext.add("test", String::from("Hello, World")).unwrap();
        assert!(matches!(
            ext.add("test", String::from("Hello, World")),
            Err(ExtensionError::AlreadyRegistered(_))
        ));
    }

    #[test]
    fn missing_name_is_an_error() {
        let mut ext = ExtensionContainer::default();
        assert!(ext.get("missing").is_err());
        assert!(ext.get_mut("missing").is_err());
    }

    #[test]
    fn lookup_by_type() {
        let mut ext = ExtensionContainer::default();
        ext.add("count", 5_i32).unwrap();
        assert_eq!(*ext.get_by_type::<i32>().unwrap(), 5);

        *ext.get_by_type_mut::<i32>().unwrap() += 1;
        assert_eq!(*ext.get_by_type::<i32>().unwrap(), 6);

        assert!(ext.get_by_type::<u64>().is_err());
        assert!(ext.get_by_type_mut::<u64>().is_err());
    }

    #[test]
    fn ambiguous_type_lookup_is_an_error() {
        let mut ext = ExtensionContainer::default();
        ext.add("first", String::from("a")).unwrap();
        ext.add("second", String::from("b")).unwrap();
        assert!(ext.get_by_type::<String>().is_err());
        assert!(ext.get_by_type_mut::<String>().is_err());
    }

    #[test]
    fn extension_aware_delegates_to_container() {
        let mut holder = Holder::default();
        holder.extensions_mut().add("count", 1_u32).unwrap();
        *holder.extension_mut::<u32>().unwrap() += 1;
        assert_eq!(*holder.extension::<u32>().unwrap(), 2);
    }
}