Skip to main content

ai_agents_context/
source.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
5#[serde(rename_all = "snake_case")]
6pub enum RefreshPolicy {
7    Once,
8    #[default]
9    PerSession,
10    PerTurn,
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
14#[serde(rename_all = "lowercase")]
15pub enum BuiltinSource {
16    Datetime,
17    Session,
18    Agent,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22#[serde(tag = "type", rename_all = "lowercase")]
23pub enum ContextSource {
24    Runtime {
25        #[serde(default)]
26        required: bool,
27        #[serde(default)]
28        schema: Option<serde_json::Value>,
29        #[serde(default)]
30        default: Option<serde_json::Value>,
31    },
32    Builtin {
33        source: BuiltinSource,
34        #[serde(default)]
35        refresh: RefreshPolicy,
36    },
37    File {
38        path: String,
39        #[serde(default)]
40        refresh: RefreshPolicy,
41        #[serde(default)]
42        fallback: Option<String>,
43    },
44    Http {
45        url: String,
46        #[serde(default = "default_method")]
47        method: String,
48        #[serde(default)]
49        headers: HashMap<String, String>,
50        #[serde(default)]
51        refresh: RefreshPolicy,
52        #[serde(default)]
53        cache_ttl: Option<u64>,
54        #[serde(default)]
55        timeout_ms: Option<u64>,
56        #[serde(default)]
57        fallback: Option<serde_json::Value>,
58    },
59    Env {
60        name: String,
61    },
62    Callback {
63        name: String,
64        #[serde(default)]
65        refresh: RefreshPolicy,
66    },
67}
68
/// Serde default for the `method` field of `ContextSource::Http`: `"GET"`.
fn default_method() -> String {
    String::from("GET")
}
72
73impl ContextSource {
74    pub fn refresh_policy(&self) -> RefreshPolicy {
75        match self {
76            ContextSource::Runtime { .. } => RefreshPolicy::Once,
77            ContextSource::Builtin { refresh, .. } => refresh.clone(),
78            ContextSource::File { refresh, .. } => refresh.clone(),
79            ContextSource::Http { refresh, .. } => refresh.clone(),
80            ContextSource::Env { .. } => RefreshPolicy::Once,
81            ContextSource::Callback { refresh, .. } => refresh.clone(),
82        }
83    }
84
85    pub fn is_required(&self) -> bool {
86        matches!(self, ContextSource::Runtime { required: true, .. })
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a YAML snippet into a `ContextSource`, panicking if the
    /// snippet does not match any variant.
    fn parse(yaml: &str) -> ContextSource {
        serde_yaml::from_str(yaml).unwrap()
    }

    /// A required runtime source reports itself as required and refreshes once.
    #[test]
    fn test_runtime_source() {
        let source = parse(
            r#"
type: runtime
required: true
default:
  name: "Guest"
"#,
        );
        assert_eq!(source.refresh_policy(), RefreshPolicy::Once);
        assert!(source.is_required());
    }

    /// A builtin source honours an explicit `refresh` value.
    #[test]
    fn test_builtin_source() {
        let source = parse(
            r#"
type: builtin
source: datetime
refresh: per_turn
"#,
        );
        assert_eq!(source.refresh_policy(), RefreshPolicy::PerTurn);
    }

    /// A file source keeps the templated path verbatim and carries its fallback.
    #[test]
    fn test_file_source() {
        let parsed = parse(
            r#"
type: file
path: "./rules/{{ context.user.language }}/support.txt"
refresh: per_session
fallback: "./rules/en/support.txt"
"#,
        );
        match parsed {
            ContextSource::File { path, fallback, .. } => {
                assert!(path.contains("{{ context.user.language }}"));
                assert_eq!(fallback.as_deref(), Some("./rules/en/support.txt"));
            }
            _ => panic!("Expected File source"),
        }
    }

    /// An http source round-trips url, method, headers, and the numeric knobs.
    #[test]
    fn test_http_source() {
        let parsed = parse(
            r#"
type: http
url: "https://api.example.com/users/{{ context.user.id }}"
method: GET
headers:
  Authorization: "Bearer {{ env.API_TOKEN }}"
refresh: per_session
cache_ttl: 300
timeout_ms: 5000
fallback:
  theme: "default"
"#,
        );
        match parsed {
            ContextSource::Http {
                url,
                method,
                headers,
                cache_ttl,
                timeout_ms,
                ..
            } => {
                assert!(url.contains("{{ context.user.id }}"));
                assert_eq!(method, "GET");
                assert!(headers.contains_key("Authorization"));
                assert_eq!(cache_ttl, Some(300));
                assert_eq!(timeout_ms, Some(5000));
            }
            _ => panic!("Expected Http source"),
        }
    }

    /// An env source exposes the configured name.
    #[test]
    fn test_env_source() {
        let parsed = parse(
            r#"
type: env
name: API_TOKEN
"#,
        );
        match parsed {
            ContextSource::Env { name } => assert_eq!(name, "API_TOKEN"),
            _ => panic!("Expected Env source"),
        }
    }

    /// A callback source exposes its name and refresh policy.
    #[test]
    fn test_callback_source() {
        let parsed = parse(
            r#"
type: callback
name: get_user_analytics
refresh: per_session
"#,
        );
        match parsed {
            ContextSource::Callback { name, refresh } => {
                assert_eq!(name, "get_user_analytics");
                assert_eq!(refresh, RefreshPolicy::PerSession);
            }
            _ => panic!("Expected Callback source"),
        }
    }
}