use crate::error::{AgentRootError, Result};
use chrono::{Duration, Utc};
use llama_cpp_2::{
    context::params::LlamaContextParams,
    llama_backend::LlamaBackend,
    llama_batch::LlamaBatch,
    model::{params::LlamaModelParams, LlamaModel},
};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParsedQuery {
    /// Main keywords extracted from the query.
    pub search_terms: String,

    /// Optional time-range constraint (e.g. "recently", "last week").
    pub temporal_filter: Option<TemporalFilter>,

    /// Hints for filtering on metadata fields such as author.
    pub metadata_filters: Vec<MetadataFilterHint>,

    /// Which search strategy to use for this query.
    pub search_type: SearchType,

    /// Parser confidence in the extraction, from 0.0 to 1.0.
    pub confidence: f64,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalFilter {
    /// Range start as an RFC 3339 timestamp.
    pub start: Option<String>,

    /// Range end as an RFC 3339 timestamp.
    pub end: Option<String>,

    /// Human-readable description of the range (e.g. "last week").
    pub description: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetadataFilterHint {
    /// Metadata field to filter on (e.g. "author").
    pub field: String,

    /// Value to match against.
    pub value: String,

    /// Match operator; defaults to "contains".
    pub operator: String,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SearchType {
    /// Keyword-only (BM25) search.
    Bm25,

    /// Embedding-based vector search.
    Vector,

    /// Combined keyword and vector search.
    Hybrid,
}

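/// Parses natural-language search queries into a structured [`ParsedQuery`]
/// by prompting a local Llama model through `llama_cpp_2`.
///
/// Minimal usage sketch (assumes the default GGUF model has already been
/// downloaded, e.g. via `agentroot embed`):
///
/// ```ignore
/// let parser = QueryParser::from_default()?;
/// let parsed = parser.parse("rust code by Alice from last week").await?;
/// println!("search terms: {}", parsed.search_terms);
/// ```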
pub struct QueryParser {
    #[allow(dead_code)]
    model_path: PathBuf,
}

impl QueryParser {
    /// Creates a parser backed by the GGUF model at `model_path`.
    pub fn new(model_path: PathBuf) -> Result<Self> {
        if !model_path.exists() {
            return Err(AgentRootError::ModelNotFound(
                model_path.to_string_lossy().to_string(),
            ));
        }
        Ok(Self { model_path })
    }

    /// Creates a parser using the default model location under the local data
    /// directory (`agentroot/models/llama-3.1-8b-instruct.Q4_K_M.gguf`).
    pub fn from_default() -> Result<Self> {
        let model_dir = dirs::data_local_dir()
            .ok_or_else(|| AgentRootError::Config("Cannot determine data directory".to_string()))?
            .join("agentroot")
            .join("models");

        let model_path = model_dir.join("llama-3.1-8b-instruct.Q4_K_M.gguf");

        if !model_path.exists() {
            return Err(AgentRootError::ModelNotFound(format!(
                "Model not found at {}. Run 'agentroot embed' first to download models.",
                model_path.display()
            )));
        }

        Ok(Self { model_path })
    }

    /// Parses `query` into a [`ParsedQuery`] using the local LLM.
    pub async fn parse(&self, query: &str) -> Result<ParsedQuery> {
        self.llm_parse(query).await
    }

    async fn llm_parse(&self, query: &str) -> Result<ParsedQuery> {
        tracing::debug!("Using LLM to parse query: {}", query);

        let mut backend = LlamaBackend::init()
            .map_err(|e| AgentRootError::Llm(format!("Failed to init LLM backend: {}", e)))?;
        backend.void_logs();

        let model_params = LlamaModelParams::default();
        let model = LlamaModel::load_from_file(&backend, &self.model_path, &model_params)
            .map_err(|e| AgentRootError::Llm(format!("Failed to load LLM model: {}", e)))?;

        let ctx_size = std::num::NonZeroU32::new(4096).unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(Some(ctx_size))
            .with_n_batch(512);

        let mut ctx = model
            .new_context(&backend, ctx_params)
            .map_err(|e| AgentRootError::Llm(format!("Failed to create LLM context: {}", e)))?;

        let prompt = self.build_parsing_prompt(query);

        let tokens = model
            .str_to_token(&prompt, llama_cpp_2::model::AddBos::Never)
            .map_err(|e| AgentRootError::Llm(format!("Tokenization error: {}", e)))?;

        let max_output_tokens = 256;
        let mut output_tokens = Vec::new();
        let mut current_pos = 0;

        // Feed the prompt in batches of at most 512 tokens (matching n_batch),
        // requesting logits only for the final prompt token.
        let chunks: Vec<_> = tokens.chunks(512).collect();
        for (chunk_idx, chunk) in chunks.iter().enumerate() {
            let is_last_chunk = chunk_idx == chunks.len() - 1;
            let mut batch = LlamaBatch::new(chunk.len(), 1);
            tracing::debug!(
                "Processing chunk {}/{}, size: {}, is_last: {}",
                chunk_idx + 1,
                chunks.len(),
                chunk.len(),
                is_last_chunk
            );

            for (i, token) in chunk.iter().enumerate() {
                let is_last_token_overall = is_last_chunk && i == chunk.len() - 1;
                if is_last_token_overall {
                    tracing::debug!(
                        "Marking token at position {} (offset {} in batch) for logits",
                        current_pos + i as i32,
                        i
                    );
                }
                batch
                    .add(*token, current_pos + i as i32, &[0], is_last_token_overall)
                    .map_err(|e| AgentRootError::Llm(format!("Batch error: {}", e)))?;
            }
            current_pos += chunk.len() as i32;

            ctx.decode(&mut batch)
                .map_err(|e| AgentRootError::Llm(format!("Decode error: {}", e)))?;
        }

        tracing::debug!(
            "Prompt processed, {} tokens total, current_pos = {}, will sample from position {}",
            tokens.len(),
            current_pos,
            current_pos - 1
        );

        let mut generated_text = String::new();
        let mut brace_count = 0;
        let mut json_started = false;

        // Greedy decoding: pick the highest-logit token each step and stop at EOS
        // or once the generated JSON object's braces are balanced.
        for i in 0..max_output_tokens {
            let token_data_array = ctx.token_data_array();

            let next_token = token_data_array
                .data
                .iter()
                .max_by(|a, b| a.logit().partial_cmp(&b.logit()).unwrap())
                .map(|td| td.id())
                .ok_or_else(|| AgentRootError::Llm("No token found".to_string()))?;

            if next_token == model.token_eos() {
                tracing::debug!("Hit EOS token after {} tokens", i);
                break;
            }

            let token_str = model
                .token_to_str(next_token, llama_cpp_2::model::Special::Tokenize)
                .map_err(|e| AgentRootError::Llm(format!("Token decode error: {}", e)))?;

            generated_text.push_str(&token_str);
            output_tokens.push(next_token);

            if token_str.contains('{') {
                json_started = true;
                brace_count += token_str.matches('{').count() as i32;
            }
            if token_str.contains('}') {
                brace_count -= token_str.matches('}').count() as i32;
                if json_started && brace_count == 0 {
                    tracing::debug!("JSON complete after {} tokens", i + 1);
                    break;
                }
            }

            if i % 50 == 0 && i > 0 {
                tracing::debug!(
                    "Generated {} tokens so far, text length: {}",
                    i,
                    generated_text.len()
                );
            }

            let mut batch = LlamaBatch::new(1, 1);
            batch
                .add(next_token, current_pos, &[0], true)
                .map_err(|e| AgentRootError::Llm(format!("Batch error: {}", e)))?;

            ctx.decode(&mut batch)
                .map_err(|e| AgentRootError::Llm(format!("Decode error: {}", e)))?;

            current_pos += 1;
        }

        tracing::debug!("LLM raw output: {}", generated_text);

        self.parse_llm_response(&generated_text, query)
    }

    /// Builds a Llama 3.1 chat-template prompt instructing the model to emit
    /// a single JSON object describing the query.
    fn build_parsing_prompt(&self, query: &str) -> String {
        format!(
            r#"<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a search query parser. Extract structured information from user queries.
Output ONLY valid JSON with these fields:
- search_terms: main keywords (string)
- temporal_filter: {{"description": "...", "relative_hours": N}} or null
- metadata_filters: [{{"field": "...", "value": "...", "operator": "contains"}}] or []
- confidence: 0.0-1.0

Examples:
Query: "files that were edit recently"
{{"search_terms": "files", "temporal_filter": {{"description": "recently", "relative_hours": 24}}, "metadata_filters": [], "confidence": 0.9}}

Query: "rust code by Alice from last week"
{{"search_terms": "rust code", "temporal_filter": {{"description": "last week", "relative_hours": 168}}, "metadata_filters": [{{"field": "author", "value": "Alice", "operator": "contains"}}], "confidence": 0.95}}

Query: "python functions"
{{"search_terms": "python functions", "temporal_filter": null, "metadata_filters": [], "confidence": 0.85}}

<|eot_id|><|start_header_id|>user<|end_header_id|>

Parse this query: "{}"<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"#,
            query
        )
    }

    /// Extracts the JSON object from the raw LLM output and maps it onto
    /// [`ParsedQuery`]. If no JSON object can be located, falls back to the
    /// original query verbatim with reduced confidence.
    fn parse_llm_response(&self, response: &str, original_query: &str) -> Result<ParsedQuery> {
        let json_start = response.find('{');
        let json_end = response.rfind('}');

        let json_str = match (json_start, json_end) {
            (Some(start), Some(end)) if end > start => &response[start..=end],
            _ => {
                tracing::warn!("Failed to extract JSON from LLM response, using fallback");
                return Ok(ParsedQuery {
                    search_terms: original_query.to_string(),
                    temporal_filter: None,
                    metadata_filters: vec![],
                    search_type: SearchType::Hybrid,
                    confidence: 0.5,
                });
            }
        };

        let parsed_json: serde_json::Value = serde_json::from_str(json_str).map_err(|e| {
            tracing::warn!("Failed to parse LLM JSON output: {}", e);
            AgentRootError::Llm(format!("JSON parse error: {}", e))
        })?;

        let search_terms = parsed_json["search_terms"]
            .as_str()
            .unwrap_or(original_query)
            .to_string();

        // Convert the model's relative_hours hint into a concrete RFC 3339
        // range ending at the current time.
        let temporal_filter = if let Some(tf) = parsed_json.get("temporal_filter") {
            if !tf.is_null() {
                let hours = tf["relative_hours"].as_i64().unwrap_or(24);
                let description = tf["description"].as_str().unwrap_or("").to_string();
                let now = Utc::now();
                let start = now - Duration::hours(hours);
                Some(TemporalFilter {
                    start: Some(start.to_rfc3339()),
                    end: Some(now.to_rfc3339()),
                    description,
                })
            } else {
                None
            }
        } else {
            None
        };

        let metadata_filters = if let Some(filters) = parsed_json["metadata_filters"].as_array() {
            filters
                .iter()
                .filter_map(|f| {
                    Some(MetadataFilterHint {
                        field: f["field"].as_str()?.to_string(),
                        value: f["value"].as_str()?.to_string(),
                        operator: f["operator"].as_str().unwrap_or("contains").to_string(),
                    })
                })
                .collect()
        } else {
            vec![]
        };

        let confidence = parsed_json["confidence"].as_f64().unwrap_or(0.8);

        Ok(ParsedQuery {
            search_terms,
            temporal_filter,
            metadata_filters,
            search_type: SearchType::Hybrid,
            confidence,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parse_requires_model() {
        let result = QueryParser::from_default();
        if result.is_err() {
            println!("Skipping test: LLM model not available");
            return;
        }

        let parser = result.unwrap();
        let parsed = parser.parse("test query").await;

        // Smoke test: when the model is present, parsing either succeeds or
        // surfaces an LLM error without panicking.
        assert!(parsed.is_ok() || parsed.is_err());
    }
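
    #[test]
    fn test_new_rejects_missing_model_path() {
        // `new` should reject a model path that does not exist on disk.
        let result = QueryParser::new(PathBuf::from("/nonexistent/model.gguf"));
        assert!(result.is_err());
    }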

    #[tokio::test]
    async fn test_llm_parse_temporal_query() {
        let result = QueryParser::from_default();
        if result.is_err() {
            println!("Skipping test: LLM model not available");
            return;
        }

        let parser = result.unwrap();
        let parsed = parser.parse("files that were edit recently").await;

        if let Ok(parsed) = parsed {
            println!("Parsed query: {:?}", parsed);
            assert!(!parsed.search_terms.is_empty());
        }
    }

    #[tokio::test]
    async fn test_llm_parse_metadata_query() {
        let result = QueryParser::from_default();
        if result.is_err() {
            println!("Skipping test: LLM model not available");
            return;
        }

        let parser = result.unwrap();
        let parsed = parser.parse("rust code by Alice").await;

        if let Ok(parsed) = parsed {
            println!("Parsed query: {:?}", parsed);
            assert!(!parsed.search_terms.is_empty());
        }
    }

    #[test]
    fn test_parse_llm_response_valid_json() {
        let parser = QueryParser {
            model_path: PathBuf::from("dummy"),
        };

        let response = r#"{"search_terms": "files", "temporal_filter": {"description": "recently", "relative_hours": 24}, "metadata_filters": [], "confidence": 0.9}"#;
        let result = parser.parse_llm_response(response, "files that were edit recently");

        assert!(result.is_ok());
        let parsed = result.unwrap();
        assert_eq!(parsed.search_terms, "files");
        assert!(parsed.temporal_filter.is_some());
    }

    #[test]
    fn test_parse_llm_response_invalid_json_fallback() {
        let parser = QueryParser {
            model_path: PathBuf::from("dummy"),
        };

        let response = "not valid json";
        let result = parser.parse_llm_response(response, "original query");

        assert!(result.is_ok());
        let parsed = result.unwrap();
        assert_eq!(parsed.search_terms, "original query");
        assert_eq!(parsed.confidence, 0.5);
    }
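
    #[test]
    fn test_parse_llm_response_metadata_filters() {
        // Mirrors the "rust code by Alice" example in the parsing prompt:
        // a null temporal_filter maps to None and metadata filters are
        // extracted with their field, value, and operator.
        let parser = QueryParser {
            model_path: PathBuf::from("dummy"),
        };

        let response = r#"{"search_terms": "rust code", "temporal_filter": null, "metadata_filters": [{"field": "author", "value": "Alice", "operator": "contains"}], "confidence": 0.95}"#;
        let result = parser.parse_llm_response(response, "rust code by Alice");

        assert!(result.is_ok());
        let parsed = result.unwrap();
        assert_eq!(parsed.search_terms, "rust code");
        assert!(parsed.temporal_filter.is_none());
        assert_eq!(parsed.metadata_filters.len(), 1);
        assert_eq!(parsed.metadata_filters[0].field, "author");
        assert_eq!(parsed.metadata_filters[0].value, "Alice");
        assert_eq!(parsed.metadata_filters[0].operator, "contains");
    }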

    #[test]
    fn test_build_parsing_prompt() {
        let parser = QueryParser {
            model_path: PathBuf::from("dummy"),
        };

        let prompt = parser.build_parsing_prompt("test query");
        assert!(prompt.contains("test query"));
        assert!(prompt.contains("search_terms"));
        assert!(prompt.contains("temporal_filter"));
    }
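
    #[test]
    fn test_search_type_serializes_lowercase() {
        // The #[serde(rename_all = "lowercase")] attribute should produce
        // lowercase variant names in JSON.
        assert_eq!(serde_json::to_string(&SearchType::Bm25).unwrap(), "\"bm25\"");
        assert_eq!(serde_json::to_string(&SearchType::Vector).unwrap(), "\"vector\"");
        assert_eq!(serde_json::to_string(&SearchType::Hybrid).unwrap(), "\"hybrid\"");
    }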
}