agentroot_core/llm/query_parser.rs

//! Natural language query parser using LLM
//!
//! Parses user queries like "files edited last hour" into structured search parameters
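//!
//! A minimal usage sketch (an illustration only; it assumes this module is publicly
//! reachable as `agentroot_core::llm::query_parser` and that a GGUF model has already
//! been downloaded to the default location):
//!
//! ```ignore
//! use agentroot_core::llm::query_parser::QueryParser;
//!
//! # async fn demo() -> agentroot_core::error::Result<()> {
//! let parser = QueryParser::from_default()?;
//! let parsed = parser.parse("rust code by Alice from last week").await?;
//! println!("search terms: {}", parsed.search_terms);
//! # Ok(())
//! # }
//! ```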

use crate::error::{AgentRootError, Result};
use chrono::{Duration, Utc};
use llama_cpp_2::{
    context::params::LlamaContextParams,
    llama_backend::LlamaBackend,
    llama_batch::LlamaBatch,
    model::{params::LlamaModelParams, LlamaModel},
};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;

/// Parsed query with extracted intent and filters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParsedQuery {
    /// The cleaned search terms to use
    pub search_terms: String,

    /// Temporal constraints
    pub temporal_filter: Option<TemporalFilter>,

    /// Metadata filters extracted from query
    pub metadata_filters: Vec<MetadataFilterHint>,

    /// Suggested search type
    pub search_type: SearchType,

    /// Confidence in the parse (0.0 - 1.0)
    pub confidence: f64,
}

/// Temporal filter for time-based queries
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalFilter {
    /// Start datetime (ISO 8601)
    pub start: Option<String>,

    /// End datetime (ISO 8601)
    pub end: Option<String>,

    /// Human-readable description
    pub description: String,
}

/// Metadata filter hint extracted from query
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetadataFilterHint {
    /// Field name
    pub field: String,

    /// Expected value
    pub value: String,

    /// Operator (eq, contains, gt, lt)
    pub operator: String,
}

/// Search type recommendation
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SearchType {
    /// BM25 full-text search
    Bm25,

    /// Vector semantic search
    Vector,

    /// Hybrid (both + reranking)
    Hybrid,
}

/// Query parser using local LLM
pub struct QueryParser {
    #[allow(dead_code)]
    model_path: PathBuf,
}

impl QueryParser {
    /// Create a new query parser with custom model
    pub fn new(model_path: PathBuf) -> Result<Self> {
        if !model_path.exists() {
            return Err(AgentRootError::ModelNotFound(
                model_path.to_string_lossy().to_string(),
            ));
        }
        Ok(Self { model_path })
    }

    /// Create parser with default model
    pub fn from_default() -> Result<Self> {
        let model_dir = dirs::data_local_dir()
            .ok_or_else(|| AgentRootError::Config("Cannot determine data directory".to_string()))?
            .join("agentroot")
            .join("models");

        let model_path = model_dir.join("llama-3.1-8b-instruct.Q4_K_M.gguf");

        if !model_path.exists() {
            return Err(AgentRootError::ModelNotFound(format!(
                "Model not found at {}. Run 'agentroot embed' first to download models.",
                model_path.display()
            )));
        }

        Ok(Self { model_path })
    }

    /// Parse natural language query into structured search
    pub async fn parse(&self, query: &str) -> Result<ParsedQuery> {
        self.llm_parse(query).await
    }

    /// Parse query using LLM
    async fn llm_parse(&self, query: &str) -> Result<ParsedQuery> {
        tracing::debug!("Using LLM to parse query: {}", query);

        let mut backend = LlamaBackend::init()
            .map_err(|e| AgentRootError::Llm(format!("Failed to init LLM backend: {}", e)))?;
        backend.void_logs();

        let model_params = LlamaModelParams::default();
        let model = LlamaModel::load_from_file(&backend, &self.model_path, &model_params)
            .map_err(|e| AgentRootError::Llm(format!("Failed to load LLM model: {}", e)))?;

        let ctx_size = std::num::NonZeroU32::new(4096).unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(Some(ctx_size))
            .with_n_batch(512);

        let mut ctx = model
            .new_context(&backend, ctx_params)
            .map_err(|e| AgentRootError::Llm(format!("Failed to create LLM context: {}", e)))?;

        let prompt = self.build_parsing_prompt(query);

        let tokens = model
            .str_to_token(&prompt, llama_cpp_2::model::AddBos::Never)
            .map_err(|e| AgentRootError::Llm(format!("Tokenization error: {}", e)))?;

        let max_output_tokens = 256;
        let mut output_tokens = Vec::new();
        let mut current_pos = 0;

        // Process prompt tokens in chunks - request logits only for the final
        // prompt token, which is the position the first output token is sampled from.
        let chunks: Vec<_> = tokens.chunks(512).collect();
        for (chunk_idx, chunk) in chunks.iter().enumerate() {
            let is_last_chunk = chunk_idx == chunks.len() - 1;
            let mut batch = LlamaBatch::new(chunk.len(), 1);
            tracing::debug!(
                "Processing chunk {}/{}, size: {}, is_last: {}",
                chunk_idx + 1,
                chunks.len(),
                chunk.len(),
                is_last_chunk
            );

            for (i, token) in chunk.iter().enumerate() {
                let is_last_token_overall = is_last_chunk && i == chunk.len() - 1;
                if is_last_token_overall {
                    tracing::debug!(
                        "Marking token at position {} (offset {} in batch) for logits",
                        current_pos + i as i32,
                        i
                    );
                }
                batch
                    .add(*token, current_pos + i as i32, &[0], is_last_token_overall)
                    .map_err(|e| AgentRootError::Llm(format!("Batch error: {}", e)))?;
            }
            current_pos += chunk.len() as i32;

            ctx.decode(&mut batch)
                .map_err(|e| AgentRootError::Llm(format!("Decode error: {}", e)))?;
        }

        tracing::debug!(
            "Prompt processed, {} tokens total, current_pos = {}, will sample from position {}",
            tokens.len(),
            current_pos,
            current_pos - 1
        );

        let mut generated_text = String::new();
        let mut brace_count = 0;
        let mut json_started = false;

        // Greedy decoding: take the highest-logit token each step and stop at EOS
        // or once the generated JSON object's braces balance out.
        for i in 0..max_output_tokens {
            let token_data_array = ctx.token_data_array();

            let next_token = token_data_array
                .data
                .iter()
                .max_by(|a, b| a.logit().partial_cmp(&b.logit()).unwrap())
                .map(|td| td.id())
                .ok_or_else(|| AgentRootError::Llm("No token found".to_string()))?;

            if next_token == model.token_eos() {
                tracing::debug!("Hit EOS token after {} tokens", i);
                break;
            }

            let token_str = model
                .token_to_str(next_token, llama_cpp_2::model::Special::Tokenize)
                .map_err(|e| AgentRootError::Llm(format!("Token decode error: {}", e)))?;

            generated_text.push_str(&token_str);
            output_tokens.push(next_token);

            // Track brace depth so generation can stop as soon as the JSON object closes.
            if token_str.contains("{") {
                json_started = true;
                brace_count += token_str.matches("{").count() as i32;
            }
            if token_str.contains("}") {
                brace_count -= token_str.matches("}").count() as i32;
                if json_started && brace_count == 0 {
                    tracing::debug!("JSON complete after {} tokens", i + 1);
                    break;
                }
            }

            if i % 50 == 0 && i > 0 {
                tracing::debug!(
                    "Generated {} tokens so far, text length: {}",
                    i,
                    generated_text.len()
                );
            }

            // Feed the sampled token back into the context so the next iteration
            // has logits for it.
            let mut batch = LlamaBatch::new(1, 1);
            batch
                .add(next_token, current_pos, &[0], true)
                .map_err(|e| AgentRootError::Llm(format!("Batch error: {}", e)))?;

            ctx.decode(&mut batch)
                .map_err(|e| AgentRootError::Llm(format!("Decode error: {}", e)))?;

            current_pos += 1;
        }

        tracing::debug!("LLM raw output: {}", generated_text);

        self.parse_llm_response(&generated_text, query)
    }

    fn build_parsing_prompt(&self, query: &str) -> String {
        format!(
            r#"<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a search query parser. Extract structured information from user queries.
Output ONLY valid JSON with these fields:
- search_terms: main keywords (string)
- temporal_filter: {{"description": "...", "relative_hours": N}} or null
- metadata_filters: [{{"field": "...", "value": "...", "operator": "contains"}}] or []
- confidence: 0.0-1.0

Examples:
Query: "files that were edited recently"
{{"search_terms": "files", "temporal_filter": {{"description": "recently", "relative_hours": 24}}, "metadata_filters": [], "confidence": 0.9}}

Query: "rust code by Alice from last week"
{{"search_terms": "rust code", "temporal_filter": {{"description": "last week", "relative_hours": 168}}, "metadata_filters": [{{"field": "author", "value": "Alice", "operator": "contains"}}], "confidence": 0.95}}

Query: "python functions"
{{"search_terms": "python functions", "temporal_filter": null, "metadata_filters": [], "confidence": 0.85}}

<|eot_id|><|start_header_id|>user<|end_header_id|>

Parse this query: "{}"<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"#,
            query
        )
    }

    fn parse_llm_response(&self, response: &str, original_query: &str) -> Result<ParsedQuery> {
        let json_start = response.find('{');
        let json_end = response.rfind('}');

        let json_str = match (json_start, json_end) {
            (Some(start), Some(end)) if end > start => &response[start..=end],
            _ => {
                tracing::warn!("Failed to extract JSON from LLM response, using fallback");
                return Ok(ParsedQuery {
                    search_terms: original_query.to_string(),
                    temporal_filter: None,
                    metadata_filters: vec![],
                    search_type: SearchType::Hybrid,
                    confidence: 0.5,
                });
            }
        };

        let parsed_json: serde_json::Value = serde_json::from_str(json_str).map_err(|e| {
            tracing::warn!("Failed to parse LLM JSON output: {}", e);
            AgentRootError::Llm(format!("JSON parse error: {}", e))
        })?;

        let search_terms = parsed_json["search_terms"]
            .as_str()
            .unwrap_or(original_query)
            .to_string();

        let temporal_filter = if let Some(tf) = parsed_json.get("temporal_filter") {
            if !tf.is_null() {
                let hours = tf["relative_hours"].as_i64().unwrap_or(24);
                let description = tf["description"].as_str().unwrap_or("").to_string();
                let now = Utc::now();
                let start = now - Duration::hours(hours);
                Some(TemporalFilter {
                    start: Some(start.to_rfc3339()),
                    end: Some(now.to_rfc3339()),
                    description,
                })
            } else {
                None
            }
        } else {
            None
        };

        let metadata_filters = if let Some(filters) = parsed_json["metadata_filters"].as_array() {
            filters
                .iter()
                .filter_map(|f| {
                    Some(MetadataFilterHint {
                        field: f["field"].as_str()?.to_string(),
                        value: f["value"].as_str()?.to_string(),
                        operator: f["operator"].as_str().unwrap_or("contains").to_string(),
                    })
                })
                .collect()
        } else {
            vec![]
        };

        let confidence = parsed_json["confidence"].as_f64().unwrap_or(0.8);

        Ok(ParsedQuery {
            search_terms,
            temporal_filter,
            metadata_filters,
            search_type: SearchType::Hybrid,
            confidence,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parse_requires_model() {
        let result = QueryParser::from_default();
        if result.is_err() {
            println!("Skipping test: LLM model not available");
            return;
        }

        let parser = result.unwrap();
        // Smoke test: with a model present we only check that parsing completes
        // without panicking; the result may be Ok or Err depending on the model.
        let _ = parser.parse("test query").await;
    }

    #[tokio::test]
    async fn test_llm_parse_temporal_query() {
        let result = QueryParser::from_default();
        if result.is_err() {
            println!("Skipping test: LLM model not available");
            return;
        }

        let parser = result.unwrap();
        let parsed = parser.parse("files that were edited recently").await;

        if let Ok(parsed) = parsed {
            println!("Parsed query: {:?}", parsed);
            assert!(!parsed.search_terms.is_empty());
        }
    }

    #[tokio::test]
    async fn test_llm_parse_metadata_query() {
        let result = QueryParser::from_default();
        if result.is_err() {
            println!("Skipping test: LLM model not available");
            return;
        }

        let parser = result.unwrap();
        let parsed = parser.parse("rust code by Alice").await;

        if let Ok(parsed) = parsed {
            println!("Parsed query: {:?}", parsed);
            assert!(!parsed.search_terms.is_empty());
        }
    }

    #[test]
    fn test_parse_llm_response_valid_json() {
        let parser = QueryParser {
            model_path: PathBuf::from("dummy"),
        };

        let response = r#"{"search_terms": "files", "temporal_filter": {"description": "recently", "relative_hours": 24}, "metadata_filters": [], "confidence": 0.9}"#;
        let result = parser.parse_llm_response(response, "files that were edited recently");

        assert!(result.is_ok());
        let parsed = result.unwrap();
        assert_eq!(parsed.search_terms, "files");
        assert!(parsed.temporal_filter.is_some());
    }
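
    // An illustrative addition: checks that a "relative_hours" hint is turned into
    // an RFC 3339 window whose end is that many hours after its start.
    #[test]
    fn test_parse_llm_response_temporal_window() {
        let parser = QueryParser {
            model_path: PathBuf::from("dummy"),
        };

        let response = r#"{"search_terms": "files", "temporal_filter": {"description": "last week", "relative_hours": 168}, "metadata_filters": [], "confidence": 0.9}"#;
        let parsed = parser
            .parse_llm_response(response, "files from last week")
            .unwrap();

        let tf = parsed.temporal_filter.expect("expected a temporal filter");
        assert_eq!(tf.description, "last week");

        let start = chrono::DateTime::parse_from_rfc3339(&tf.start.unwrap()).unwrap();
        let end = chrono::DateTime::parse_from_rfc3339(&tf.end.unwrap()).unwrap();
        assert_eq!((end - start).num_hours(), 168);
    }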

    #[test]
    fn test_parse_llm_response_invalid_json_fallback() {
        let parser = QueryParser {
            model_path: PathBuf::from("dummy"),
        };

        let response = "not valid json";
        let result = parser.parse_llm_response(response, "original query");

        assert!(result.is_ok());
        let parsed = result.unwrap();
        assert_eq!(parsed.search_terms, "original query");
        assert_eq!(parsed.confidence, 0.5);
    }
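
    // An illustrative addition: exercises metadata filter extraction in
    // parse_llm_response without requiring a local model.
    #[test]
    fn test_parse_llm_response_metadata_filters() {
        let parser = QueryParser {
            model_path: PathBuf::from("dummy"),
        };

        let response = r#"{"search_terms": "rust code", "temporal_filter": null, "metadata_filters": [{"field": "author", "value": "Alice", "operator": "contains"}], "confidence": 0.95}"#;
        let result = parser.parse_llm_response(response, "rust code by Alice");

        assert!(result.is_ok());
        let parsed = result.unwrap();
        assert_eq!(parsed.search_terms, "rust code");
        assert_eq!(parsed.metadata_filters.len(), 1);
        assert_eq!(parsed.metadata_filters[0].field, "author");
        assert_eq!(parsed.metadata_filters[0].value, "Alice");
        assert_eq!(parsed.metadata_filters[0].operator, "contains");
    }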

    #[test]
    fn test_build_parsing_prompt() {
        let parser = QueryParser {
            model_path: PathBuf::from("dummy"),
        };

        let prompt = parser.build_parsing_prompt("test query");
        assert!(prompt.contains("test query"));
        assert!(prompt.contains("search_terms"));
        assert!(prompt.contains("temporal_filter"));
    }
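
    // An illustrative addition: documents the wire format implied by
    // #[serde(rename_all = "lowercase")] on SearchType.
    #[test]
    fn test_search_type_serde_lowercase() {
        let json = serde_json::to_string(&SearchType::Hybrid).unwrap();
        assert_eq!(json, "\"hybrid\"");

        let parsed: SearchType = serde_json::from_str("\"bm25\"").unwrap();
        assert_eq!(parsed, SearchType::Bm25);
    }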
}