ai_agents_context/
source.rs1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
5#[serde(rename_all = "snake_case")]
6pub enum RefreshPolicy {
7 Once,
8 #[default]
9 PerSession,
10 PerTurn,
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
14#[serde(rename_all = "lowercase")]
15pub enum BuiltinSource {
16 Datetime,
17 Session,
18 Agent,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22#[serde(tag = "type", rename_all = "lowercase")]
23pub enum ContextSource {
24 Runtime {
25 #[serde(default)]
26 required: bool,
27 #[serde(default)]
28 schema: Option<serde_json::Value>,
29 #[serde(default)]
30 default: Option<serde_json::Value>,
31 },
32 Builtin {
33 source: BuiltinSource,
34 #[serde(default)]
35 refresh: RefreshPolicy,
36 },
37 File {
38 path: String,
39 #[serde(default)]
40 refresh: RefreshPolicy,
41 #[serde(default)]
42 fallback: Option<String>,
43 },
44 Http {
45 url: String,
46 #[serde(default = "default_method")]
47 method: String,
48 #[serde(default)]
49 headers: HashMap<String, String>,
50 #[serde(default)]
51 refresh: RefreshPolicy,
52 #[serde(default)]
53 cache_ttl: Option<u64>,
54 #[serde(default)]
55 timeout_ms: Option<u64>,
56 #[serde(default)]
57 fallback: Option<serde_json::Value>,
58 },
59 Env {
60 name: String,
61 },
62 Callback {
63 name: String,
64 #[serde(default)]
65 refresh: RefreshPolicy,
66 },
67}
68
/// Serde default for `ContextSource::Http::method`: yields `"GET"` when the
/// field is absent from the input.
fn default_method() -> String {
    String::from("GET")
}
72
73impl ContextSource {
74 pub fn refresh_policy(&self) -> RefreshPolicy {
75 match self {
76 ContextSource::Runtime { .. } => RefreshPolicy::Once,
77 ContextSource::Builtin { refresh, .. } => refresh.clone(),
78 ContextSource::File { refresh, .. } => refresh.clone(),
79 ContextSource::Http { refresh, .. } => refresh.clone(),
80 ContextSource::Env { .. } => RefreshPolicy::Once,
81 ContextSource::Callback { refresh, .. } => refresh.clone(),
82 }
83 }
84
85 pub fn is_required(&self) -> bool {
86 matches!(self, ContextSource::Runtime { required: true, .. })
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a YAML snippet into a `ContextSource`, panicking if the
    /// snippet does not parse.
    fn parse(yaml: &str) -> ContextSource {
        serde_yaml::from_str(yaml).unwrap()
    }

    /// A required runtime source parses and always reports the `Once` policy.
    #[test]
    fn test_runtime_source() {
        let parsed = parse(
            r#"
type: runtime
required: true
default:
  name: "Guest"
"#,
        );
        assert!(parsed.is_required());
        assert_eq!(parsed.refresh_policy(), RefreshPolicy::Once);
    }

    /// A builtin source honours an explicit `per_turn` refresh policy.
    #[test]
    fn test_builtin_source() {
        let parsed = parse(
            r#"
type: builtin
source: datetime
refresh: per_turn
"#,
        );
        assert_eq!(parsed.refresh_policy(), RefreshPolicy::PerTurn);
    }

    /// A file source keeps its templated path verbatim and exposes the fallback.
    #[test]
    fn test_file_source() {
        let parsed = parse(
            r#"
type: file
path: "./rules/{{ context.user.language }}/support.txt"
refresh: per_session
fallback: "./rules/en/support.txt"
"#,
        );
        match parsed {
            ContextSource::File { path, fallback, .. } => {
                assert!(path.contains("{{ context.user.language }}"));
                assert_eq!(fallback.as_deref(), Some("./rules/en/support.txt"));
            }
            _ => panic!("Expected File source"),
        }
    }

    /// An http source round-trips url, method, headers, cache TTL and timeout.
    #[test]
    fn test_http_source() {
        let parsed = parse(
            r#"
type: http
url: "https://api.example.com/users/{{ context.user.id }}"
method: GET
headers:
  Authorization: "Bearer {{ env.API_TOKEN }}"
refresh: per_session
cache_ttl: 300
timeout_ms: 5000
fallback:
  theme: "default"
"#,
        );
        match parsed {
            ContextSource::Http {
                url,
                method,
                headers,
                cache_ttl,
                timeout_ms,
                ..
            } => {
                assert!(url.contains("{{ context.user.id }}"));
                assert_eq!(method, "GET");
                assert!(headers.contains_key("Authorization"));
                assert_eq!(cache_ttl, Some(300));
                assert_eq!(timeout_ms, Some(5000));
            }
            _ => panic!("Expected Http source"),
        }
    }

    /// An env source exposes the configured variable name.
    #[test]
    fn test_env_source() {
        let parsed = parse(
            r#"
type: env
name: API_TOKEN
"#,
        );
        match parsed {
            ContextSource::Env { name } => assert_eq!(name, "API_TOKEN"),
            _ => panic!("Expected Env source"),
        }
    }

    /// A callback source exposes its name and explicit refresh policy.
    #[test]
    fn test_callback_source() {
        let parsed = parse(
            r#"
type: callback
name: get_user_analytics
refresh: per_session
"#,
        );
        match parsed {
            ContextSource::Callback { name, refresh } => {
                assert_eq!(name, "get_user_analytics");
                assert_eq!(refresh, RefreshPolicy::PerSession);
            }
            _ => panic!("Expected Callback source"),
        }
    }
}