assemble_core/plugins/
extensions.rs1use std::any;
4use std::any::Any;
5use std::collections::HashMap;
6use std::fmt::{Debug, Formatter};
7use std::ops::{Index, IndexMut};
8
9use crate::prelude::{ProjectError, ProjectResult};
10use thiserror::Error;
11
/// Marker trait for any value that can be stored in an [`ExtensionContainer`].
///
/// Extensions must be `'static` (so they can be type-erased behind `dyn Any`)
/// and `Send + Sync` (so containers can be shared across threads).
pub trait Extension: Send + Sync + 'static {}

/// Every `'static + Send + Sync` type is automatically an [`Extension`].
impl<E> Extension for E where E: Send + Sync + 'static {}
16
17pub trait ExtensionAware {
19 fn extensions(&self) -> &ExtensionContainer;
21 fn extensions_mut(&mut self) -> &mut ExtensionContainer;
23
24 fn extension<E: Extension>(&self) -> ProjectResult<&E> {
27 self.extensions().get_by_type()
28 }
29
30 fn extension_mut<E: Extension>(&mut self) -> ProjectResult<&mut E> {
33 self.extensions_mut().get_by_type_mut()
34 }
35}
36
/// Type-erased, heap-allocated extension value as stored inside the container.
type AnyExtension = Box<dyn Any + Sync + Send>;
38
39#[derive(Default)]
41pub struct ExtensionContainer {
42 ob_map: HashMap<String, AnyExtension>,
43}
44
45impl ExtensionContainer {
46 pub fn add<E: Extension, S: AsRef<str>>(
51 &mut self,
52 name: S,
53 value: E,
54 ) -> Result<(), ExtensionError> {
55 let name = name.as_ref();
56 if self.ob_map.contains_key(name) {
57 return Err(ExtensionError::AlreadyRegistered(name.to_string()));
58 }
59 let boxed = Box::new(value) as AnyExtension;
60 self.ob_map.insert(name.to_string(), boxed);
61 Ok(())
62 }
63
64 pub fn get<S: AsRef<str>>(&self, name: S) -> ProjectResult<&AnyExtension> {
66 self.ob_map
67 .get(name.as_ref())
68 .ok_or(ProjectError::ExtensionNotRegistered(name.as_ref().to_string()).into())
69 }
70
71 pub fn get_mut<S: AsRef<str>>(&mut self, name: S) -> ProjectResult<&mut AnyExtension> {
73 self.ob_map
74 .get_mut(name.as_ref())
75 .ok_or(ProjectError::ExtensionNotRegistered(name.as_ref().to_string()).into())
76 }
77
78 pub fn get_by_type<E: Extension>(&self) -> ProjectResult<&E> {
81 let mut output: Vec<&E> = vec![];
82 for value in self.ob_map.values() {
83 if let Some(ext) = value.downcast_ref::<E>() {
84 output.push(ext);
85 }
86 }
87 match output.len() {
88 1 => Ok(output.remove(0)),
89 _ => {
90 Err(ProjectError::ExtensionNotRegistered(any::type_name::<E>().to_string()).into())
91 }
92 }
93 }
94
95 pub fn get_by_type_mut<E: Extension>(&mut self) -> ProjectResult<&mut E> {
98 let mut output: Vec<String> = vec![];
99 for (name, ext) in &self.ob_map {
100 if ext.is::<E>() {
101 output.push(name.clone());
102 }
103 }
104 match output.len() {
105 1 => {
106 let index = output.remove(0);
107 self.ob_map
108 .get_mut(&index)
109 .and_then(|b| b.downcast_mut())
110 .ok_or_else(|| unreachable!())
111 }
112 _ => {
113 Err(ProjectError::ExtensionNotRegistered(any::type_name::<E>().to_string()).into())
114 }
115 }
116 }
117}
118
119impl Index<&str> for ExtensionContainer {
120 type Output = AnyExtension;
121
122 fn index(&self, index: &str) -> &Self::Output {
123 self.get(index).unwrap()
124 }
125}
126
127impl IndexMut<&str> for ExtensionContainer {
128 fn index_mut(&mut self, index: &str) -> &mut Self::Output {
129 self.get_mut(index).unwrap()
130 }
131}
132
133impl Index<String> for ExtensionContainer {
134 type Output = AnyExtension;
135
136 fn index(&self, index: String) -> &Self::Output {
137 self.get(index).unwrap()
138 }
139}
140
141impl IndexMut<String> for ExtensionContainer {
142 fn index_mut(&mut self, index: String) -> &mut Self::Output {
143 self.get_mut(index).unwrap()
144 }
145}
146
147impl Debug for ExtensionContainer {
148 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
149 f.debug_struct("ExtensionContainer").finish()
150 }
151}
152
153#[derive(Debug, Error)]
154pub enum ExtensionError {
155 #[error("Extension with name {0:?} already registered")]
156 AlreadyRegistered(String),
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn use_extensions() {
        let mut ext = ExtensionContainer::default();
        ext.add("test", String::from("Hello, World")).unwrap();

        let value = ext.get("test").unwrap().downcast_ref::<String>().unwrap();
        assert_eq!(value, "Hello, World")
    }

    #[test]
    fn disallow_same_name_extensions() {
        let mut ext = ExtensionContainer::default();
        ext.add("test", String::from("Hello, World")).unwrap();
        assert!(matches!(
            ext.add("test", String::from("Hello, World")),
            Err(ExtensionError::AlreadyRegistered(_))
        ));
    }

    #[test]
    fn missing_name_is_an_error() {
        let mut ext = ExtensionContainer::default();
        assert!(ext.get("missing").is_err());
        assert!(ext.get_mut("missing").is_err());
    }

    #[test]
    fn get_by_type_finds_unique_extension() {
        let mut ext = ExtensionContainer::default();
        ext.add("text", String::from("Hello, World")).unwrap();
        ext.add("number", 42_u32).unwrap();
        assert_eq!(ext.get_by_type::<String>().unwrap(), "Hello, World");
        assert_eq!(*ext.get_by_type::<u32>().unwrap(), 42);
    }

    #[test]
    fn get_by_type_rejects_missing_and_ambiguous_types() {
        let mut ext = ExtensionContainer::default();
        assert!(ext.get_by_type::<String>().is_err());
        assert!(ext.get_by_type_mut::<String>().is_err());

        ext.add("first", String::from("a")).unwrap();
        ext.add("second", String::from("b")).unwrap();
        assert!(ext.get_by_type::<String>().is_err());
        assert!(ext.get_by_type_mut::<String>().is_err());
    }

    #[test]
    fn get_by_type_mut_allows_modification() {
        let mut ext = ExtensionContainer::default();
        ext.add("counter", 1_u32).unwrap();
        *ext.get_by_type_mut::<u32>().unwrap() += 1;
        assert_eq!(*ext.get_by_type::<u32>().unwrap(), 2);
    }

    #[test]
    fn index_by_str_and_string() {
        let mut ext = ExtensionContainer::default();
        ext.add("test", 7_i32).unwrap();
        assert_eq!(ext["test"].downcast_ref::<i32>(), Some(&7));
        *ext[String::from("test")].downcast_mut::<i32>().unwrap() = 8;
        assert_eq!(ext[String::from("test")].downcast_ref::<i32>(), Some(&8));
    }

    #[test]
    #[should_panic]
    fn index_panics_on_missing_name() {
        let ext = ExtensionContainer::default();
        let _ = &ext["missing"];
    }

    #[derive(Default)]
    struct Holder {
        extensions: ExtensionContainer,
    }

    impl ExtensionAware for Holder {
        fn extensions(&self) -> &ExtensionContainer {
            &self.extensions
        }

        fn extensions_mut(&mut self) -> &mut ExtensionContainer {
            &mut self.extensions
        }
    }

    #[test]
    fn extension_aware_delegates_to_container() {
        let mut holder = Holder::default();
        holder.extensions_mut().add("number", 5_u64).unwrap();
        *holder.extension_mut::<u64>().unwrap() *= 2;
        assert_eq!(*holder.extension::<u64>().unwrap(), 10);
        assert!(holder.extension::<String>().is_err());
    }
}