ai_agent/utils/
side_query.rs1use crate::AgentError;
5
6pub struct SideQueryOptions {
8 pub base_url: String,
9 pub api_key: String,
10 pub model: String,
11 pub system_prompt: String,
12 pub message: String,
13 pub max_tokens: Option<u32>,
14 pub tools: Option<Vec<serde_json::Value>>,
15}
16
17impl SideQueryOptions {
18 pub fn new(base_url: String, api_key: String, model: String) -> Self {
19 Self {
20 base_url,
21 api_key,
22 model,
23 system_prompt: String::new(),
24 message: String::new(),
25 max_tokens: Some(4096),
26 tools: None,
27 }
28 }
29
30 pub fn system_prompt(mut self, prompt: String) -> Self {
31 self.system_prompt = prompt;
32 self
33 }
34
35 pub fn message(mut self, message: String) -> Self {
36 self.message = message;
37 self
38 }
39
40 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
41 self.max_tokens = Some(max_tokens);
42 self
43 }
44
45 pub fn tools(mut self, tools: Vec<serde_json::Value>) -> Self {
46 self.tools = Some(tools);
47 self
48 }
49}
50
/// Filenames picked out of a model response, together with the model's explanation.
///
/// `Default` yields an empty selection; `PartialEq`/`Eq` allow whole selections
/// to be compared directly (for example in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SideQueryMemorySelection {
    /// Selected filenames, in the order they were found.
    pub filenames: Vec<String>,
    /// Explanation text that accompanied the selection.
    pub reasoning: String,
}
57
58impl SideQueryMemorySelection {
59 pub fn from_response(response: &str) -> Self {
60 if let Ok(val) = serde_json::from_str::<serde_json::Value>(response) {
61 let filenames = val
62 .get("filenames")
63 .and_then(|f| f.as_array())
64 .map(|arr| {
65 arr.iter()
66 .filter_map(|v| v.as_str().map(|s| s.to_string()))
67 .collect()
68 })
69 .unwrap_or_default();
70 let reasoning = val
71 .get("reasoning")
72 .and_then(|r| r.as_str())
73 .map(|s| s.to_string())
74 .unwrap_or_default();
75 return Self { filenames, reasoning };
76 }
77 let filenames = extract_filenames_from_text(response);
78 Self {
79 reasoning: response.to_string(),
80 filenames,
81 }
82 }
83}
84
/// Collects filename-looking lines from free-form text.
///
/// Each line is stripped of leading list decoration (whitespace, `-`, `*` and
/// backticks, in any order) and of trailing whitespace and backticks. What
/// remains is kept when it ends in `.md`, `.txt`, `.json` or `.rs`. Duplicates
/// are dropped and first-seen order is preserved.
///
/// Stripping the decoration as one set matters for lines such as "- `notes.md`":
/// trimming `-` and the backtick one after the other stops at the space between
/// them and would leave the opening backtick glued to the name.
fn extract_filenames_from_text(text: &str) -> Vec<String> {
    const EXTENSIONS: [&str; 4] = [".md", ".txt", ".json", ".rs"];

    /// Characters that may precede a filename in a Markdown-style list item.
    fn is_leading_decoration(c: char) -> bool {
        c.is_whitespace() || matches!(c, '-' | '*' | '`')
    }

    /// Characters that may trail a filename (closing backtick, padding).
    fn is_trailing_decoration(c: char) -> bool {
        c.is_whitespace() || c == '`'
    }

    let mut filenames: Vec<String> = Vec::new();
    for line in text.lines() {
        let clean = line
            .trim_start_matches(is_leading_decoration)
            .trim_end_matches(is_trailing_decoration);
        if clean.is_empty() || !EXTENSIONS.iter().any(|ext| clean.ends_with(ext)) {
            continue;
        }
        // Allocate only for names that are actually kept.
        if !filenames.iter().any(|known| known == clean) {
            filenames.push(clean.to_string());
        }
    }
    filenames
}
109
110pub async fn side_query(opts: &SideQueryOptions) -> Result<String, AgentError> {
112 let client = reqwest::Client::new();
113 let mut body = serde_json::json!({
114 "model": opts.model,
115 "max_tokens": opts.max_tokens.unwrap_or(4096),
116 "messages": [{ "role": "user", "content": opts.message }]
117 });
118 if !opts.system_prompt.is_empty() {
119 body.as_object_mut()
120 .unwrap()
121 .insert("system".to_string(), serde_json::json!(opts.system_prompt));
122 }
123 let url = format!("{}/v1/messages", opts.base_url.trim_end_matches('/'));
124 let resp = client
125 .post(&url)
126 .header("x-api-key", &opts.api_key)
127 .header("anthropic-version", "2023-06-01")
128 .header("content-type", "application/json")
129 .json(&body)
130 .send()
131 .await
132 .map_err(|e| AgentError::Api(e.to_string()))?;
133 if !resp.status().is_success() {
134 let status = resp.status();
135 let body_text = resp.text().await.unwrap_or_else(|_| "No error body".to_string());
136 return Err(AgentError::Api(format!(
137 "Side query failed with status {}: {}",
138 status, body_text
139 )));
140 }
141 let json: serde_json::Value =
142 resp.json().await.map_err(|e| AgentError::Api(e.to_string()))?;
143 let content = json
144 .get("content")
145 .and_then(|c| c.as_array())
146 .and_then(|arr| arr.first())
147 .and_then(|c| c.get("text"))
148 .and_then(|t| t.as_str())
149 .unwrap_or("")
150 .to_string();
151 Ok(content)
152}
153
154pub async fn side_query_simple(opts: &SideQueryOptions) -> Result<String, AgentError> {
156 let client = reqwest::Client::new();
157 let body = serde_json::json!({
158 "model": opts.model,
159 "max_tokens": opts.max_tokens.unwrap_or(4096),
160 "messages": [
161 { "role": "system", "content": opts.system_prompt },
162 { "role": "user", "content": opts.message }
163 ]
164 });
165 let url = format!("{}/v1/chat/completions", opts.base_url.trim_end_matches('/'));
166 let resp = client
167 .post(&url)
168 .header("Authorization", format!("Bearer {}", opts.api_key))
169 .header("content-type", "application/json")
170 .json(&body)
171 .send()
172 .await
173 .map_err(|e| AgentError::Api(e.to_string()))?;
174 if !resp.status().is_success() {
175 let status = resp.status();
176 let body_text = resp.text().await.unwrap_or_else(|_| "No error body".to_string());
177 return Err(AgentError::Api(format!(
178 "Side query failed with status {}: {}",
179 status, body_text
180 )));
181 }
182 let json: serde_json::Value =
183 resp.json().await.map_err(|e| AgentError::Api(e.to_string()))?;
184 let content = json
185 .get("choices")
186 .and_then(|c| c.as_array())
187 .and_then(|arr| arr.first())
188 .and_then(|c| c.get("message"))
189 .and_then(|m| m.get("content"))
190 .and_then(|c| c.as_str())
191 .unwrap_or("")
192 .to_string();
193 Ok(content)
194}
195
196pub async fn side_query_with_tools(
198 opts: &SideQueryOptions,
199) -> Result<serde_json::Value, AgentError> {
200 let client = reqwest::Client::new();
201 let mut body = serde_json::json!({
202 "model": opts.model,
203 "max_tokens": opts.max_tokens.unwrap_or(4096),
204 "messages": [{ "role": "user", "content": opts.message }]
205 });
206 if !opts.system_prompt.is_empty() {
207 body.as_object_mut()
208 .unwrap()
209 .insert("system".to_string(), serde_json::json!(opts.system_prompt));
210 }
211 if let Some(ref tools) = opts.tools {
212 body.as_object_mut()
213 .unwrap()
214 .insert("tools".to_string(), serde_json::json!(tools));
215 }
216 let url = format!("{}/v1/messages", opts.base_url.trim_end_matches('/'));
217 let resp = client
218 .post(&url)
219 .header("x-api-key", &opts.api_key)
220 .header("anthropic-version", "2023-06-01")
221 .header("content-type", "application/json")
222 .json(&body)
223 .send()
224 .await
225 .map_err(|e| AgentError::Api(e.to_string()))?;
226 if !resp.status().is_success() {
227 let status = resp.status();
228 let body_text = resp.text().await.unwrap_or_else(|_| "No error body".to_string());
229 return Err(AgentError::Api(format!(
230 "Side query with tools failed with status {}: {}",
231 status, body_text
232 )));
233 }
234 let json: serde_json::Value =
235 resp.json().await.map_err(|e| AgentError::Api(e.to_string()))?;
236 Ok(json)
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder methods overwrite exactly the field they are named after.
    #[test]
    fn test_side_query_options_builder() {
        let base = SideQueryOptions::new(
            "https://api.anthropic.com".to_string(),
            "test-key".to_string(),
            "claude-sonnet-4-6".to_string(),
        );
        let built = base
            .system_prompt("You are helpful.".to_string())
            .message("Hello".to_string())
            .max_tokens(2048);

        assert_eq!(built.base_url, "https://api.anthropic.com");
        assert_eq!(built.model, "claude-sonnet-4-6");
        assert_eq!(built.system_prompt, "You are helpful.");
        assert_eq!(built.message, "Hello");
        assert_eq!(built.max_tokens, Some(2048));
    }

    /// A JSON response is decoded field by field.
    #[test]
    fn test_memory_selection_from_json() {
        let raw = r#"{"filenames": ["notes.md", "ideas.txt"], "reasoning": "These files are relevant"}"#;
        let parsed = SideQueryMemorySelection::from_response(raw);

        assert_eq!(parsed.filenames, vec!["notes.md", "ideas.txt"]);
        assert_eq!(parsed.reasoning, "These files are relevant");
    }

    /// A non-JSON response falls back to line-based filename extraction.
    #[test]
    fn test_memory_selection_from_text() {
        let raw = "Based on the query, these files seem relevant:\n- notes.md\n- ideas.txt\n- project.rs\n";
        let parsed = SideQueryMemorySelection::from_response(raw);

        for expected in ["notes.md", "ideas.txt", "project.rs"] {
            assert!(parsed.filenames.contains(&expected.to_string()));
        }
    }

    /// Only lines ending in a known extension are kept; bullets are stripped.
    #[test]
    fn test_extract_filenames_from_text() {
        let raw = "Here are some files:\n- memory.md\n* scratch.txt\nconfig.json\nnot a file\nregular text\n";
        let found = extract_filenames_from_text(raw);

        assert_eq!(found.len(), 3);
        for expected in ["memory.md", "scratch.txt", "config.json"] {
            assert!(found.contains(&expected.to_string()));
        }
    }
}