Skip to main content

ai_agents_context/
manager.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use parking_lot::RwLock;
5use serde_json::{Value, json};
6
7use ai_agents_core::{AgentError, Result};
8
9use super::builtin::get_builtin_value;
10use super::provider::ContextProvider;
11use super::render::TemplateRenderer;
12use super::source::{ContextSource, RefreshPolicy};
13
/// Holds the declared context schema together with the current context values.
///
/// Values and providers sit behind `RwLock`s so they can be changed through `&self`.
pub struct ContextManager {
    /// Declared context keys and the source each one is resolved from.
    schema: HashMap<String, ContextSource>,
    /// Current context values, keyed by top-level context key.
    values: RwLock<HashMap<String, Value>>,
    /// Providers registered by name; looked up when resolving `ContextSource::Callback`.
    providers: RwLock<HashMap<String, Arc<dyn ContextProvider>>>,
    /// Agent name handed to `get_builtin_value` when resolving builtin sources.
    agent_name: String,
    /// Agent version handed to `get_builtin_value` when resolving builtin sources.
    agent_version: String,
    /// Renders templated file paths, URLs and header values against the current context.
    renderer: TemplateRenderer,
}
22
23impl ContextManager {
24    pub fn new(
25        schema: HashMap<String, ContextSource>,
26        agent_name: String,
27        agent_version: String,
28    ) -> Self {
29        Self {
30            schema,
31            values: RwLock::new(HashMap::new()),
32            providers: RwLock::new(HashMap::new()),
33            agent_name,
34            agent_version,
35            renderer: TemplateRenderer::new(),
36        }
37    }
38
39    pub fn set(&self, key: &str, value: Value) -> Result<()> {
40        self.values.write().insert(key.to_string(), value);
41        Ok(())
42    }
43
44    pub fn update(&self, path: &str, value: Value) -> Result<()> {
45        let parts: Vec<&str> = path.split('.').collect();
46        if parts.is_empty() {
47            return Err(AgentError::InvalidSpec("Empty path".into()));
48        }
49
50        let mut values = self.values.write();
51        let root_key = parts[0];
52
53        if parts.len() == 1 {
54            values.insert(root_key.to_string(), value);
55            return Ok(());
56        }
57
58        let root = values
59            .entry(root_key.to_string())
60            .or_insert_with(|| json!({}));
61
62        let mut current = root;
63        for part in &parts[1..parts.len() - 1] {
64            current = current
65                .as_object_mut()
66                .ok_or_else(|| AgentError::InvalidSpec(format!("Path {} is not an object", path)))?
67                .entry(*part)
68                .or_insert_with(|| json!({}));
69        }
70
71        if let Some(obj) = current.as_object_mut() {
72            obj.insert(parts[parts.len() - 1].to_string(), value);
73        }
74
75        Ok(())
76    }
77
78    pub fn remove(&self, key: &str) -> Option<Value> {
79        self.values.write().remove(key)
80    }
81
82    pub fn get(&self, key: &str) -> Option<Value> {
83        self.values.read().get(key).cloned()
84    }
85
86    pub fn get_path(&self, path: &str) -> Option<Value> {
87        let values = self.values.read();
88        ai_agents_core::get_dot_path_from_map(&values, path)
89    }
90
91    pub fn get_all(&self) -> HashMap<String, Value> {
92        self.values.read().clone()
93    }
94
95    pub async fn refresh(&self, key: &str) -> Result<()> {
96        let source = self
97            .schema
98            .get(key)
99            .ok_or_else(|| AgentError::InvalidSpec(format!("Unknown context key: {}", key)))?;
100
101        let value = self.resolve_source(key, source).await?;
102        if let Some(v) = value {
103            self.values.write().insert(key.to_string(), v);
104        }
105        Ok(())
106    }
107
108    pub async fn refresh_per_turn(&self) -> Result<()> {
109        for (key, source) in &self.schema {
110            if source.refresh_policy() == RefreshPolicy::PerTurn {
111                if let Some(value) = self.resolve_source(key, source).await? {
112                    self.values.write().insert(key.clone(), value);
113                }
114            }
115        }
116        Ok(())
117    }
118
119    pub async fn refresh_per_session(&self) -> Result<()> {
120        for (key, source) in &self.schema {
121            match source.refresh_policy() {
122                RefreshPolicy::PerSession | RefreshPolicy::Once => {
123                    if let Some(value) = self.resolve_source(key, source).await? {
124                        self.values.write().insert(key.clone(), value);
125                    }
126                }
127                _ => {}
128            }
129        }
130        Ok(())
131    }
132
133    pub async fn initialize(&self) -> Result<()> {
134        for (key, source) in &self.schema {
135            if let ContextSource::Runtime { default, .. } = source {
136                if let Some(default_value) = default {
137                    if !self.values.read().contains_key(key) {
138                        self.values
139                            .write()
140                            .insert(key.clone(), default_value.clone());
141                    }
142                }
143            } else if let Some(value) = self.resolve_source(key, source).await? {
144                self.values.write().insert(key.clone(), value);
145            }
146        }
147        Ok(())
148    }
149
150    pub fn register_provider(&self, name: &str, provider: Arc<dyn ContextProvider>) {
151        self.providers.write().insert(name.to_string(), provider);
152    }
153
154    pub fn validate(&self) -> Result<()> {
155        for (key, source) in &self.schema {
156            if source.is_required() && !self.values.read().contains_key(key) {
157                return Err(AgentError::InvalidSpec(format!(
158                    "Required context '{}' not provided",
159                    key
160                )));
161            }
162        }
163        Ok(())
164    }
165
166    pub fn snapshot(&self) -> HashMap<String, Value> {
167        self.values.read().clone()
168    }
169
170    pub fn restore(&self, snapshot: HashMap<String, Value>) {
171        *self.values.write() = snapshot;
172    }
173
174    async fn resolve_source(&self, key: &str, source: &ContextSource) -> Result<Option<Value>> {
175        match source {
176            ContextSource::Runtime { default, .. } => Ok(default.clone()),
177
178            ContextSource::Builtin { source: src, .. } => Ok(Some(get_builtin_value(
179                src,
180                &self.agent_name,
181                &self.agent_version,
182            ))),
183
184            ContextSource::File { path, fallback, .. } => {
185                let current_context = self.get_all();
186                let resolved_path = self.renderer.render_path(path, &current_context)?;
187
188                match tokio::fs::read_to_string(&resolved_path).await {
189                    Ok(content) => Ok(Some(Value::String(content))),
190                    Err(_) => {
191                        if let Some(fb) = fallback {
192                            let fallback_path = self.renderer.render_path(fb, &current_context)?;
193                            match tokio::fs::read_to_string(&fallback_path).await {
194                                Ok(content) => Ok(Some(Value::String(content))),
195                                Err(_) => Ok(None),
196                            }
197                        } else {
198                            Ok(None)
199                        }
200                    }
201                }
202            }
203
204            #[cfg(feature = "http-context")]
205            ContextSource::Http {
206                url,
207                method,
208                headers,
209                timeout_ms,
210                fallback,
211                ..
212            } => {
213                let current_context = self.get_all();
214                let resolved_url = self.renderer.render(url, &current_context)?;
215
216                let client = reqwest::Client::new();
217                let mut request = match method.to_uppercase().as_str() {
218                    "POST" => client.post(&resolved_url),
219                    "PUT" => client.put(&resolved_url),
220                    "DELETE" => client.delete(&resolved_url),
221                    _ => client.get(&resolved_url),
222                };
223
224                for (k, v) in headers {
225                    let resolved_value = self.renderer.render(v, &current_context)?;
226                    request = request.header(k, resolved_value);
227                }
228
229                if let Some(timeout) = timeout_ms {
230                    request = request.timeout(std::time::Duration::from_millis(*timeout));
231                }
232
233                match request.send().await {
234                    Ok(response) => {
235                        if response.status().is_success() {
236                            match response.json::<Value>().await {
237                                Ok(json) => Ok(Some(json)),
238                                Err(_) => Ok(fallback.clone()),
239                            }
240                        } else {
241                            Ok(fallback.clone())
242                        }
243                    }
244                    Err(_) => Ok(fallback.clone()),
245                }
246            }
247
248            #[cfg(not(feature = "http-context"))]
249            ContextSource::Http { fallback, .. } => {
250                // HTTP context sources require the "http-context" feature
251                tracing::warn!(
252                    "HTTP context source requested but 'http-context' feature is not enabled"
253                );
254                Ok(fallback.clone())
255            }
256
257            ContextSource::Env { name } => Ok(std::env::var(name).ok().map(Value::String)),
258
259            ContextSource::Callback { name, .. } => {
260                // Clone the provider to avoid holding the lock across await
261                let provider = {
262                    let providers = self.providers.read();
263                    providers.get(name).cloned()
264                };
265                if let Some(provider) = provider {
266                    let current_context = json!(self.get_all());
267                    Ok(Some(provider.get(key, &current_context).await?))
268                } else {
269                    Ok(None)
270                }
271            }
272        }
273    }
274
    /// Returns the context schema this manager was constructed with.
    pub fn schema(&self) -> &HashMap<String, ContextSource> {
        &self.schema
    }
278}
279
/// Unit tests for `ContextManager`: direct value access, snapshots, validation,
/// and resolution of builtin, environment and runtime-default sources.
#[cfg(test)]
mod tests {
    use super::super::source::BuiltinSource;
    use super::*;

    /// A value stored with `set` is returned by `get`.
    #[test]
    fn test_set_and_get() {
        let manager = ContextManager::new(HashMap::new(), "Test".into(), "1.0".into());
        manager.set("user", json!({"name": "Alice"})).unwrap();
        let user = manager.get("user").unwrap();
        assert_eq!(user.get("name").unwrap(), "Alice");
    }

    /// `update` adds a nested field without disturbing existing sibling fields.
    #[test]
    fn test_update_nested() {
        let manager = ContextManager::new(HashMap::new(), "Test".into(), "1.0".into());
        manager.set("user", json!({"name": "Alice"})).unwrap();
        manager.update("user.tier", json!("premium")).unwrap();
        let user = manager.get("user").unwrap();
        assert_eq!(user.get("tier").unwrap(), "premium");
        assert_eq!(user.get("name").unwrap(), "Alice");
    }

    /// `get_path` resolves a dot-separated path through nested objects.
    #[test]
    fn test_get_path() {
        let manager = ContextManager::new(HashMap::new(), "Test".into(), "1.0".into());
        manager
            .set("user", json!({"preferences": {"theme": "dark"}}))
            .unwrap();
        let theme = manager.get_path("user.preferences.theme").unwrap();
        assert_eq!(theme, "dark");
    }

    /// A snapshot taken from one manager can be restored into another.
    #[test]
    fn test_snapshot_restore() {
        let manager = ContextManager::new(HashMap::new(), "Test".into(), "1.0".into());
        manager.set("key1", json!("value1")).unwrap();
        manager.set("key2", json!(42)).unwrap();

        let snapshot = manager.snapshot();
        assert_eq!(snapshot.len(), 2);

        let manager2 = ContextManager::new(HashMap::new(), "Test".into(), "1.0".into());
        manager2.restore(snapshot);
        assert_eq!(manager2.get("key1").unwrap(), "value1");
        assert_eq!(manager2.get("key2").unwrap(), 42);
    }

    /// `validate` fails while a required key is missing and passes once it is set.
    #[test]
    fn test_validate_required() {
        let mut schema = HashMap::new();
        schema.insert(
            "user".into(),
            ContextSource::Runtime {
                required: true,
                schema: None,
                default: None,
            },
        );

        let manager = ContextManager::new(schema, "Test".into(), "1.0".into());
        assert!(manager.validate().is_err());

        manager.set("user", json!({"name": "Alice"})).unwrap();
        assert!(manager.validate().is_ok());
    }

    /// `initialize` resolves a builtin datetime source into a value that has
    /// `date` and `time` fields.
    #[tokio::test]
    async fn test_builtin_datetime() {
        let mut schema = HashMap::new();
        schema.insert(
            "time".into(),
            ContextSource::Builtin {
                source: BuiltinSource::Datetime,
                refresh: RefreshPolicy::PerTurn,
            },
        );

        let manager = ContextManager::new(schema, "Test".into(), "1.0".into());
        manager.initialize().await.unwrap();

        let time = manager.get("time").unwrap();
        assert!(time.get("date").is_some());
        assert!(time.get("time").is_some());
    }

    /// `initialize` reads an `Env` source from the process environment.
    #[tokio::test]
    async fn test_env_source() {
        // SAFETY: `set_var` is only sound while no other thread reads or writes the
        // process environment. This variable name is used by this test only.
        // NOTE(review): the default test harness runs tests on parallel threads, so
        // this relies on no other test touching the environment concurrently — confirm.
        unsafe {
            std::env::set_var("TEST_CONTEXT_VAR", "test_value");
        }

        let mut schema = HashMap::new();
        schema.insert(
            "test_env".into(),
            ContextSource::Env {
                name: "TEST_CONTEXT_VAR".into(),
            },
        );

        let manager = ContextManager::new(schema, "Test".into(), "1.0".into());
        manager.initialize().await.unwrap();

        let value = manager.get("test_env").unwrap();
        assert_eq!(value, "test_value");
    }

    /// `initialize` seeds a runtime source with its configured default.
    #[tokio::test]
    async fn test_runtime_default() {
        let mut schema = HashMap::new();
        schema.insert(
            "settings".into(),
            ContextSource::Runtime {
                required: false,
                schema: None,
                default: Some(json!({"theme": "light"})),
            },
        );

        let manager = ContextManager::new(schema, "Test".into(), "1.0".into());
        manager.initialize().await.unwrap();

        let settings = manager.get("settings").unwrap();
        assert_eq!(settings.get("theme").unwrap(), "light");
    }
}