Skip to main content

mofa_foundation/agent/context/
ext.rs

1//! Context Extension Traits
2//!
3//! Provides a generic extension mechanism for CoreAgentContext
4//!
5//! # Design Philosophy
6//!
7//! The context system should be extensible for different use cases:
8//! - RichAgentContext: metrics and output tracking
9//! - PromptContext: prompt building capabilities
10//! - Custom contexts: user-specific extensions
11//!
12//! # Example
13//!
14//! ```rust,ignore
15//! use mofa_foundation::agent::context::{ContextExt, RichAgentContext};
16//! use mofa_kernel::agent::context::CoreAgentContext;
17//!
18//! let core = CoreAgentContext::new("exec-123");
19//! let rich = RichAgentContext::new(core);
20//!
21//! // Use extension methods
22//! rich.record_output("llm", json!("response")).await;
23//! ```
24
25use mofa_kernel::agent::context::AgentContext;
26use std::any::{Any, TypeId};
27use std::collections::HashMap;
28use std::sync::Arc;
29use tokio::sync::RwLock;
30
31/// Generic context extension trait
32///
33/// Allows adding custom data to any context implementation
34pub trait ContextExt {
35    /// Set extension data
36    fn set_extension<T: Send + Sync + serde::Serialize + 'static>(
37        &self,
38        value: T,
39    ) -> impl std::future::Future<Output = ()> + Send;
40    /// Get extension data
41    fn get_extension<T: Send + Sync + serde::de::DeserializeOwned + 'static>(
42        &self,
43    ) -> impl std::future::Future<Output = Option<T>> + Send;
44    /// Remove extension data
45    fn remove_extension<T: Send + Sync + serde::de::DeserializeOwned + 'static>(
46        &self,
47    ) -> impl std::future::Future<Output = Option<T>> + Send;
48    /// Check if extension exists
49    fn has_extension<T: Send + Sync + 'static>(
50        &self,
51    ) -> impl std::future::Future<Output = bool> + Send;
52}
53
54/// Extension storage for context
55///
56/// Stores type-safe extension data
57#[derive(Clone, Default)]
58pub struct ExtensionStorage {
59    inner: Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>,
60}
61
62impl ExtensionStorage {
63    /// Create new storage
64    pub fn new() -> Self {
65        Self::default()
66    }
67
68    /// Set extension value
69    pub async fn set<T: Send + Sync + 'static>(&self, value: T) {
70        let mut inner = self.inner.write().await;
71        inner.insert(TypeId::of::<T>(), Box::new(value));
72    }
73
74    /// Get extension value
75    pub async fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
76        let inner = self.inner.read().await;
77        inner
78            .get(&TypeId::of::<T>())
79            .and_then(|v| v.downcast_ref::<T>())
80            .cloned()
81    }
82
83    /// Remove extension value
84    pub async fn remove<T: Send + Sync + 'static>(&self) -> Option<T> {
85        let mut inner = self.inner.write().await;
86        inner
87            .remove(&TypeId::of::<T>())
88            .and_then(|v| v.downcast::<T>().ok())
89            .map(|v| *v)
90    }
91
92    /// Check if extension exists
93    pub async fn has<T: Send + Sync + 'static>(&self) -> bool {
94        let inner = self.inner.read().await;
95        inner.contains_key(&TypeId::of::<T>())
96    }
97}
98
99/// Implement ContextExt for CoreAgentContext using extension storage
100impl ContextExt for AgentContext {
101    async fn set_extension<T: Send + Sync + serde::Serialize + 'static>(&self, value: T) {
102        // Store in the generic K/V store
103        let type_name = std::any::type_name::<T>();
104        let key = format!("__ext__:{}", type_name);
105        if let Ok(v) = serde_json::to_value(&value) {
106            self.set(&key, v).await;
107        }
108    }
109
110    async fn get_extension<T: Send + Sync + serde::de::DeserializeOwned + 'static>(
111        &self,
112    ) -> Option<T> {
113        let type_name = std::any::type_name::<T>();
114        let key = format!("__ext__:{}", type_name);
115        self.get(&key).await
116    }
117
118    async fn remove_extension<T: Send + Sync + serde::de::DeserializeOwned + 'static>(
119        &self,
120    ) -> Option<T> {
121        let type_name = std::any::type_name::<T>();
122        let key = format!("__ext__:{}", type_name);
123        self.remove(&key)
124            .await
125            .and_then(|v| serde_json::from_value(v).ok())
126    }
127
128    async fn has_extension<T: Send + Sync + 'static>(&self) -> bool {
129        let type_name = std::any::type_name::<T>();
130        let key = format!("__ext__:{}", type_name);
131        self.contains(&key).await
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
    struct TestExtension {
        value: String,
        count: u32,
    }

    /// Fixture value shared by the tests below.
    fn sample() -> TestExtension {
        TestExtension {
            value: "test".to_string(),
            count: 42,
        }
    }

    #[tokio::test]
    async fn test_extension_storage() {
        let storage = ExtensionStorage::new();

        storage.set(sample()).await;

        assert!(storage.has::<TestExtension>().await);
        assert_eq!(storage.get::<TestExtension>().await, Some(sample()));
    }

    #[tokio::test]
    async fn test_extension_storage_missing_and_remove() {
        let storage = ExtensionStorage::new();

        // An empty storage reports nothing for any type.
        assert!(!storage.has::<TestExtension>().await);
        assert_eq!(storage.get::<TestExtension>().await, None);
        assert_eq!(storage.remove::<TestExtension>().await, None);

        storage.set(sample()).await;

        // `remove` hands back the stored value and empties the slot.
        assert_eq!(storage.remove::<TestExtension>().await, Some(sample()));
        assert!(!storage.has::<TestExtension>().await);
        assert_eq!(storage.get::<TestExtension>().await, None);
        assert_eq!(storage.remove::<TestExtension>().await, None);
    }

    #[tokio::test]
    async fn test_extension_storage_overwrite_and_distinct_types() {
        let storage = ExtensionStorage::new();

        storage.set(1u32).await;
        storage.set(7u32).await;
        storage.set(String::from("seven")).await;

        // A second `set` for the same type replaces the first...
        assert_eq!(storage.get::<u32>().await, Some(7));
        // ...while other types occupy their own slots.
        assert_eq!(storage.get::<String>().await, Some("seven".to_string()));
        assert!(!storage.has::<u64>().await);
    }

    #[tokio::test]
    async fn test_extension_storage_clone_shares_state() {
        let storage = ExtensionStorage::new();
        let handle = storage.clone();

        handle.set(sample()).await;

        // Clones share the same `Arc`-backed map.
        assert_eq!(storage.get::<TestExtension>().await, Some(sample()));
    }

    #[tokio::test]
    async fn test_context_ext() {
        let ctx = AgentContext::new("test-exec");

        ctx.set_extension(sample()).await;

        assert!(ctx.has_extension::<TestExtension>().await);
        assert_eq!(ctx.get_extension::<TestExtension>().await, Some(sample()));

        // Remove and verify
        assert_eq!(ctx.remove_extension::<TestExtension>().await, Some(sample()));
        assert!(!ctx.has_extension::<TestExtension>().await);
    }

    #[tokio::test]
    async fn test_context_ext_missing() {
        let ctx = AgentContext::new("test-exec");

        assert!(!ctx.has_extension::<TestExtension>().await);
        assert_eq!(ctx.get_extension::<TestExtension>().await, None);
        assert_eq!(ctx.remove_extension::<TestExtension>().await, None);
    }
}