// mofa_foundation/agent/context/ext.rs
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;

use mofa_kernel::agent::context::AgentContext;
use tokio::sync::RwLock;
30
31pub trait ContextExt {
35 fn set_extension<T: Send + Sync + serde::Serialize + 'static>(
37 &self,
38 value: T,
39 ) -> impl std::future::Future<Output = ()> + Send;
40 fn get_extension<T: Send + Sync + serde::de::DeserializeOwned + 'static>(
42 &self,
43 ) -> impl std::future::Future<Output = Option<T>> + Send;
44 fn remove_extension<T: Send + Sync + serde::de::DeserializeOwned + 'static>(
46 &self,
47 ) -> impl std::future::Future<Output = Option<T>> + Send;
48 fn has_extension<T: Send + Sync + 'static>(
50 &self,
51 ) -> impl std::future::Future<Output = bool> + Send;
52}
53
54#[derive(Clone, Default)]
58pub struct ExtensionStorage {
59 inner: Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>,
60}
61
62impl ExtensionStorage {
63 pub fn new() -> Self {
65 Self::default()
66 }
67
68 pub async fn set<T: Send + Sync + 'static>(&self, value: T) {
70 let mut inner = self.inner.write().await;
71 inner.insert(TypeId::of::<T>(), Box::new(value));
72 }
73
74 pub async fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
76 let inner = self.inner.read().await;
77 inner
78 .get(&TypeId::of::<T>())
79 .and_then(|v| v.downcast_ref::<T>())
80 .cloned()
81 }
82
83 pub async fn remove<T: Send + Sync + 'static>(&self) -> Option<T> {
85 let mut inner = self.inner.write().await;
86 inner
87 .remove(&TypeId::of::<T>())
88 .and_then(|v| v.downcast::<T>().ok())
89 .map(|v| *v)
90 }
91
92 pub async fn has<T: Send + Sync + 'static>(&self) -> bool {
94 let inner = self.inner.read().await;
95 inner.contains_key(&TypeId::of::<T>())
96 }
97}
98
99impl ContextExt for AgentContext {
101 async fn set_extension<T: Send + Sync + serde::Serialize + 'static>(&self, value: T) {
102 let type_name = std::any::type_name::<T>();
104 let key = format!("__ext__:{}", type_name);
105 if let Ok(v) = serde_json::to_value(&value) {
106 self.set(&key, v).await;
107 }
108 }
109
110 async fn get_extension<T: Send + Sync + serde::de::DeserializeOwned + 'static>(
111 &self,
112 ) -> Option<T> {
113 let type_name = std::any::type_name::<T>();
114 let key = format!("__ext__:{}", type_name);
115 self.get(&key).await
116 }
117
118 async fn remove_extension<T: Send + Sync + serde::de::DeserializeOwned + 'static>(
119 &self,
120 ) -> Option<T> {
121 let type_name = std::any::type_name::<T>();
122 let key = format!("__ext__:{}", type_name);
123 self.remove(&key)
124 .await
125 .and_then(|v| serde_json::from_value(v).ok())
126 }
127
128 async fn has_extension<T: Send + Sync + 'static>(&self) -> bool {
129 let type_name = std::any::type_name::<T>();
130 let key = format!("__ext__:{}", type_name);
131 self.contains(&key).await
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
    struct TestExtension {
        value: String,
        count: u32,
    }

    /// Fixture value shared by the tests below.
    fn sample() -> TestExtension {
        TestExtension {
            value: "test".to_string(),
            count: 42,
        }
    }

    #[tokio::test]
    async fn test_extension_storage() {
        let storage = ExtensionStorage::new();

        storage.set(sample()).await;

        assert!(storage.has::<TestExtension>().await);
        assert_eq!(storage.get::<TestExtension>().await, Some(sample()));
    }

    #[tokio::test]
    async fn test_extension_storage_missing_type() {
        let storage = ExtensionStorage::new();

        assert!(!storage.has::<TestExtension>().await);
        assert_eq!(storage.get::<TestExtension>().await, None);
        assert_eq!(storage.remove::<TestExtension>().await, None);
    }

    #[tokio::test]
    async fn test_extension_storage_overwrite_and_remove() {
        let storage = ExtensionStorage::new();
        storage.set(sample()).await;

        // A second `set` of the same type replaces the first value.
        let replacement = TestExtension {
            value: "other".to_string(),
            count: 7,
        };
        storage.set(replacement.clone()).await;
        assert_eq!(
            storage.get::<TestExtension>().await,
            Some(replacement.clone())
        );

        // `remove` hands back the value and leaves the slot empty.
        assert_eq!(storage.remove::<TestExtension>().await, Some(replacement));
        assert!(!storage.has::<TestExtension>().await);
    }

    #[tokio::test]
    async fn test_extension_storage_keys_by_type() {
        let storage = ExtensionStorage::new();
        storage.set(7_u32).await;
        storage.set(sample()).await;

        assert_eq!(storage.get::<u32>().await, Some(7));
        assert_eq!(storage.get::<TestExtension>().await, Some(sample()));
        assert!(!storage.has::<u64>().await);
    }

    #[tokio::test]
    async fn test_extension_storage_clones_share_state() {
        let storage = ExtensionStorage::new();
        let handle = storage.clone();

        handle.set(sample()).await;

        // The clone shares the `Arc`-wrapped map, so the original sees the write.
        assert!(storage.has::<TestExtension>().await);
    }

    #[tokio::test]
    async fn test_context_ext() {
        let ctx = AgentContext::new("test-exec");

        ctx.set_extension(sample()).await;

        assert!(ctx.has_extension::<TestExtension>().await);
        assert_eq!(ctx.get_extension::<TestExtension>().await, Some(sample()));

        let removed = ctx.remove_extension::<TestExtension>().await;
        assert_eq!(removed, Some(sample()));

        assert!(!ctx.has_extension::<TestExtension>().await);
    }
}