1use std::collections::HashMap;
2use std::sync::Arc;
3
4use parking_lot::RwLock;
5use serde_json::{Value, json};
6
7use ai_agents_core::{AgentError, Result};
8
9use super::builtin::get_builtin_value;
10use super::provider::ContextProvider;
11use super::render::TemplateRenderer;
12use super::source::{ContextSource, RefreshPolicy};
13
14pub struct ContextManager {
15 schema: HashMap<String, ContextSource>,
16 values: RwLock<HashMap<String, Value>>,
17 providers: RwLock<HashMap<String, Arc<dyn ContextProvider>>>,
18 agent_name: String,
19 agent_version: String,
20 renderer: TemplateRenderer,
21}
22
23impl ContextManager {
24 pub fn new(
25 schema: HashMap<String, ContextSource>,
26 agent_name: String,
27 agent_version: String,
28 ) -> Self {
29 Self {
30 schema,
31 values: RwLock::new(HashMap::new()),
32 providers: RwLock::new(HashMap::new()),
33 agent_name,
34 agent_version,
35 renderer: TemplateRenderer::new(),
36 }
37 }
38
39 pub fn set(&self, key: &str, value: Value) -> Result<()> {
40 self.values.write().insert(key.to_string(), value);
41 Ok(())
42 }
43
44 pub fn update(&self, path: &str, value: Value) -> Result<()> {
45 let parts: Vec<&str> = path.split('.').collect();
46 if parts.is_empty() {
47 return Err(AgentError::InvalidSpec("Empty path".into()));
48 }
49
50 let mut values = self.values.write();
51 let root_key = parts[0];
52
53 if parts.len() == 1 {
54 values.insert(root_key.to_string(), value);
55 return Ok(());
56 }
57
58 let root = values
59 .entry(root_key.to_string())
60 .or_insert_with(|| json!({}));
61
62 let mut current = root;
63 for part in &parts[1..parts.len() - 1] {
64 current = current
65 .as_object_mut()
66 .ok_or_else(|| AgentError::InvalidSpec(format!("Path {} is not an object", path)))?
67 .entry(*part)
68 .or_insert_with(|| json!({}));
69 }
70
71 if let Some(obj) = current.as_object_mut() {
72 obj.insert(parts[parts.len() - 1].to_string(), value);
73 }
74
75 Ok(())
76 }
77
78 pub fn remove(&self, key: &str) -> Option<Value> {
79 self.values.write().remove(key)
80 }
81
82 pub fn get(&self, key: &str) -> Option<Value> {
83 self.values.read().get(key).cloned()
84 }
85
86 pub fn get_path(&self, path: &str) -> Option<Value> {
87 let values = self.values.read();
88 ai_agents_core::get_dot_path_from_map(&values, path)
89 }
90
91 pub fn get_all(&self) -> HashMap<String, Value> {
92 self.values.read().clone()
93 }
94
95 pub async fn refresh(&self, key: &str) -> Result<()> {
96 let source = self
97 .schema
98 .get(key)
99 .ok_or_else(|| AgentError::InvalidSpec(format!("Unknown context key: {}", key)))?;
100
101 let value = self.resolve_source(key, source).await?;
102 if let Some(v) = value {
103 self.values.write().insert(key.to_string(), v);
104 }
105 Ok(())
106 }
107
108 pub async fn refresh_per_turn(&self) -> Result<()> {
109 for (key, source) in &self.schema {
110 if source.refresh_policy() == RefreshPolicy::PerTurn {
111 if let Some(value) = self.resolve_source(key, source).await? {
112 self.values.write().insert(key.clone(), value);
113 }
114 }
115 }
116 Ok(())
117 }
118
119 pub async fn refresh_per_session(&self) -> Result<()> {
120 for (key, source) in &self.schema {
121 match source.refresh_policy() {
122 RefreshPolicy::PerSession | RefreshPolicy::Once => {
123 if let Some(value) = self.resolve_source(key, source).await? {
124 self.values.write().insert(key.clone(), value);
125 }
126 }
127 _ => {}
128 }
129 }
130 Ok(())
131 }
132
133 pub async fn initialize(&self) -> Result<()> {
134 for (key, source) in &self.schema {
135 if let ContextSource::Runtime { default, .. } = source {
136 if let Some(default_value) = default {
137 if !self.values.read().contains_key(key) {
138 self.values
139 .write()
140 .insert(key.clone(), default_value.clone());
141 }
142 }
143 } else if let Some(value) = self.resolve_source(key, source).await? {
144 self.values.write().insert(key.clone(), value);
145 }
146 }
147 Ok(())
148 }
149
150 pub fn register_provider(&self, name: &str, provider: Arc<dyn ContextProvider>) {
151 self.providers.write().insert(name.to_string(), provider);
152 }
153
154 pub fn validate(&self) -> Result<()> {
155 for (key, source) in &self.schema {
156 if source.is_required() && !self.values.read().contains_key(key) {
157 return Err(AgentError::InvalidSpec(format!(
158 "Required context '{}' not provided",
159 key
160 )));
161 }
162 }
163 Ok(())
164 }
165
166 pub fn snapshot(&self) -> HashMap<String, Value> {
167 self.values.read().clone()
168 }
169
170 pub fn restore(&self, snapshot: HashMap<String, Value>) {
171 *self.values.write() = snapshot;
172 }
173
174 async fn resolve_source(&self, key: &str, source: &ContextSource) -> Result<Option<Value>> {
175 match source {
176 ContextSource::Runtime { default, .. } => Ok(default.clone()),
177
178 ContextSource::Builtin { source: src, .. } => Ok(Some(get_builtin_value(
179 src,
180 &self.agent_name,
181 &self.agent_version,
182 ))),
183
184 ContextSource::File { path, fallback, .. } => {
185 let current_context = self.get_all();
186 let resolved_path = self.renderer.render_path(path, ¤t_context)?;
187
188 match tokio::fs::read_to_string(&resolved_path).await {
189 Ok(content) => Ok(Some(Value::String(content))),
190 Err(_) => {
191 if let Some(fb) = fallback {
192 let fallback_path = self.renderer.render_path(fb, ¤t_context)?;
193 match tokio::fs::read_to_string(&fallback_path).await {
194 Ok(content) => Ok(Some(Value::String(content))),
195 Err(_) => Ok(None),
196 }
197 } else {
198 Ok(None)
199 }
200 }
201 }
202 }
203
204 #[cfg(feature = "http-context")]
205 ContextSource::Http {
206 url,
207 method,
208 headers,
209 timeout_ms,
210 fallback,
211 ..
212 } => {
213 let current_context = self.get_all();
214 let resolved_url = self.renderer.render(url, ¤t_context)?;
215
216 let client = reqwest::Client::new();
217 let mut request = match method.to_uppercase().as_str() {
218 "POST" => client.post(&resolved_url),
219 "PUT" => client.put(&resolved_url),
220 "DELETE" => client.delete(&resolved_url),
221 _ => client.get(&resolved_url),
222 };
223
224 for (k, v) in headers {
225 let resolved_value = self.renderer.render(v, ¤t_context)?;
226 request = request.header(k, resolved_value);
227 }
228
229 if let Some(timeout) = timeout_ms {
230 request = request.timeout(std::time::Duration::from_millis(*timeout));
231 }
232
233 match request.send().await {
234 Ok(response) => {
235 if response.status().is_success() {
236 match response.json::<Value>().await {
237 Ok(json) => Ok(Some(json)),
238 Err(_) => Ok(fallback.clone()),
239 }
240 } else {
241 Ok(fallback.clone())
242 }
243 }
244 Err(_) => Ok(fallback.clone()),
245 }
246 }
247
248 #[cfg(not(feature = "http-context"))]
249 ContextSource::Http { fallback, .. } => {
250 tracing::warn!(
252 "HTTP context source requested but 'http-context' feature is not enabled"
253 );
254 Ok(fallback.clone())
255 }
256
257 ContextSource::Env { name } => Ok(std::env::var(name).ok().map(Value::String)),
258
259 ContextSource::Callback { name, .. } => {
260 let provider = {
262 let providers = self.providers.read();
263 providers.get(name).cloned()
264 };
265 if let Some(provider) = provider {
266 let current_context = json!(self.get_all());
267 Ok(Some(provider.get(key, ¤t_context).await?))
268 } else {
269 Ok(None)
270 }
271 }
272 }
273 }
274
275 pub fn schema(&self) -> &HashMap<String, ContextSource> {
276 &self.schema
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::super::source::BuiltinSource;
    use super::*;

    /// Builds a manager over `schema` with the agent identity every test uses.
    fn manager_for(schema: HashMap<String, ContextSource>) -> ContextManager {
        ContextManager::new(schema, "Test".into(), "1.0".into())
    }

    /// Builds a manager with an empty schema.
    fn bare_manager() -> ContextManager {
        manager_for(HashMap::new())
    }

    #[test]
    fn test_set_and_get() {
        let ctx = bare_manager();
        ctx.set("user", json!({"name": "Alice"})).unwrap();

        let stored = ctx.get("user").unwrap();
        assert_eq!(stored.get("name").unwrap(), "Alice");
    }

    #[test]
    fn test_update_nested() {
        let ctx = bare_manager();
        ctx.set("user", json!({"name": "Alice"})).unwrap();
        ctx.update("user.tier", json!("premium")).unwrap();

        // The nested write adds the new field without clobbering its sibling.
        let stored = ctx.get("user").unwrap();
        assert_eq!(stored.get("tier").unwrap(), "premium");
        assert_eq!(stored.get("name").unwrap(), "Alice");
    }

    #[test]
    fn test_get_path() {
        let ctx = bare_manager();
        ctx.set("user", json!({"preferences": {"theme": "dark"}}))
            .unwrap();

        assert_eq!(ctx.get_path("user.preferences.theme").unwrap(), "dark");
    }

    #[test]
    fn test_snapshot_restore() {
        let origin = bare_manager();
        origin.set("key1", json!("value1")).unwrap();
        origin.set("key2", json!(42)).unwrap();

        let snap = origin.snapshot();
        assert_eq!(snap.len(), 2);

        // Restoring into a fresh manager reproduces both entries.
        let target = bare_manager();
        target.restore(snap);
        assert_eq!(target.get("key1").unwrap(), "value1");
        assert_eq!(target.get("key2").unwrap(), 42);
    }

    #[test]
    fn test_validate_required() {
        let ctx = manager_for(HashMap::from([(
            "user".to_string(),
            ContextSource::Runtime {
                required: true,
                schema: None,
                default: None,
            },
        )]));

        // A required key with no default fails validation until it is set.
        assert!(ctx.validate().is_err());
        ctx.set("user", json!({"name": "Alice"})).unwrap();
        assert!(ctx.validate().is_ok());
    }

    #[tokio::test]
    async fn test_builtin_datetime() {
        let ctx = manager_for(HashMap::from([(
            "time".to_string(),
            ContextSource::Builtin {
                source: BuiltinSource::Datetime,
                refresh: RefreshPolicy::PerTurn,
            },
        )]));
        ctx.initialize().await.unwrap();

        let stamp = ctx.get("time").unwrap();
        for field in ["date", "time"] {
            assert!(stamp.get(field).is_some());
        }
    }

    #[tokio::test]
    async fn test_env_source() {
        // SAFETY: `set_var` is unsafe because other threads may read the
        // environment concurrently. Only this test uses TEST_CONTEXT_VAR.
        // NOTE(review): the test harness runs tests on parallel threads; confirm
        // no other test reads the environment while this one runs.
        unsafe {
            std::env::set_var("TEST_CONTEXT_VAR", "test_value");
        }

        let ctx = manager_for(HashMap::from([(
            "test_env".to_string(),
            ContextSource::Env {
                name: "TEST_CONTEXT_VAR".into(),
            },
        )]));
        ctx.initialize().await.unwrap();

        assert_eq!(ctx.get("test_env").unwrap(), "test_value");
    }

    #[tokio::test]
    async fn test_runtime_default() {
        let ctx = manager_for(HashMap::from([(
            "settings".to_string(),
            ContextSource::Runtime {
                required: false,
                schema: None,
                default: Some(json!({"theme": "light"})),
            },
        )]));
        ctx.initialize().await.unwrap();

        let settings = ctx.get("settings").unwrap();
        assert_eq!(settings.get("theme").unwrap(), "light");
    }
}