1use std::collections::HashMap;
8use std::hash::{DefaultHasher, Hash, Hasher};
9
10use hirn_core::HirnError;
11use parking_lot::RwLock;
12
13use super::analyzer::{self, AnalysisError, AnalysisErrorKind};
14use super::planner::{self, QueryPlan};
15use crate::db::DbStats;
16use hirn_query::ast::Statement;
17use hirn_query::parser::{self, ParseError};
18
/// A query that has been parsed, analyzed and planned.
///
/// Produced by [`compile`] (and by [`bind`] for prepared statements).
#[derive(Debug, Clone)]
pub struct CompiledQuery {
    /// Query text the AST was parsed from. For [`bind`] this is the text
    /// after parameter substitution, not the original template.
    pub source: String,
    /// Statement returned by the parser for `source`.
    pub ast: Statement,
    /// Plan produced by the planner for this statement.
    pub plan: QueryPlan,
}
29
/// Failure modes of the compile pipeline.
#[derive(Debug, Clone)]
pub enum CompileError {
    /// The parser rejected the query text.
    Parse(ParseError),
    /// The query parsed, but the analyzer reported one or more problems
    /// (also used by [`bind`] to report a missing parameter value).
    Analysis(Vec<AnalysisError>),
}
38
39impl std::fmt::Display for CompileError {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 Self::Parse(e) => write!(f, "{e}"),
43 Self::Analysis(errors) => {
44 for (i, e) in errors.iter().enumerate() {
45 if i > 0 {
46 write!(f, "; ")?;
47 }
48 write!(f, "{e}")?;
49 }
50 Ok(())
51 }
52 }
53 }
54}
55
// Marker impl so `CompileError` can be used wherever `std::error::Error` is
// expected; no source error is exposed (default `source()` is used).
impl std::error::Error for CompileError {}
57
58impl From<CompileError> for HirnError {
59 fn from(e: CompileError) -> Self {
60 match e {
61 CompileError::Parse(pe) => HirnError::InvalidInput(format!("parse error: {pe}")),
62 CompileError::Analysis(errors) => {
63 let msg = errors
64 .iter()
65 .map(|e| e.message.clone())
66 .collect::<Vec<_>>()
67 .join("; ");
68 HirnError::InvalidInput(msg)
69 }
70 }
71 }
72}
73
74pub fn compile(query: &str, stats: Option<&DbStats>) -> Result<CompiledQuery, CompileError> {
79 let ast = parser::parse(query).map_err(CompileError::Parse)?;
81
82 let errors = analyzer::analyze(&ast);
84 if !errors.is_empty() {
85 return Err(CompileError::Analysis(errors));
86 }
87
88 let plan = planner::plan(&ast, stats);
90
91 Ok(CompiledQuery {
92 source: query.to_string(),
93 ast,
94 plan,
95 })
96}
97
/// A parameterized query that has been parsed and planned once and can be
/// bound to concrete values repeatedly via [`bind`].
#[derive(Debug, Clone)]
pub struct PreparedStatement {
    /// Original query text, still containing its parameter placeholders.
    pub source: String,
    /// Placeholder names found in the parsed statement, as reported by
    /// `hirn_query::ast::collect_parameters` (e.g. `"$1"`, `"$query"`).
    pub params: Vec<String>,
    /// Plan built from the parameterized statement; reused by every bind.
    pub plan: QueryPlan,
}
114
115pub fn prepare(query: &str, stats: Option<&DbStats>) -> Result<PreparedStatement, CompileError> {
123 let ast = parser::parse(query).map_err(CompileError::Parse)?;
125
126 let params = hirn_query::ast::collect_parameters(&ast);
128
129 if params.is_empty() {
133 let errors = analyzer::analyze(&ast);
134 if !errors.is_empty() {
135 return Err(CompileError::Analysis(errors));
136 }
137 }
138
139 let plan = planner::plan(&ast, stats);
141
142 Ok(PreparedStatement {
143 source: query.to_string(),
144 params,
145 plan,
146 })
147}
148
149pub fn bind(
156 prepared: &PreparedStatement,
157 values: &HashMap<String, String>,
158) -> Result<CompiledQuery, CompileError> {
159 for param in &prepared.params {
161 if !values.contains_key(param) {
162 return Err(CompileError::Analysis(vec![AnalysisError {
163 message: format!("missing value for parameter {param}"),
164 kind: AnalysisErrorKind::UnknownField,
165 }]));
166 }
167 }
168
169 let mut bound_query = prepared.source.clone();
171 for (name, value) in values {
172 let replacement = if value.parse::<f64>().is_ok() || value.parse::<i64>().is_ok() {
175 value.clone()
176 } else {
177 format!("\"{}\"", value.replace('\\', "\\\\").replace('"', "\\\""))
179 };
180 bound_query = bound_query.replace(name.as_str(), &replacement);
181 }
182
183 let ast = parser::parse(&bound_query).map_err(CompileError::Parse)?;
185
186 let errors = analyzer::analyze(&ast);
188 if !errors.is_empty() {
189 return Err(CompileError::Analysis(errors));
190 }
191
192 Ok(CompiledQuery {
193 source: bound_query,
194 ast,
195 plan: prepared.plan.clone(),
196 })
197}
198
199#[derive(Debug)]
204pub struct PlanCache {
205 cache: RwLock<PlanCacheInner>,
206 capacity: usize,
207}
208
209#[derive(Debug)]
210struct PlanCacheInner {
211 entries: HashMap<u64, CacheEntry>,
212 stats_fingerprint: u64,
214}
215
216#[derive(Debug, Clone)]
217struct CacheEntry {
218 compiled: CompiledQuery,
219 hits: u64,
220}
221
222impl PlanCache {
223 pub fn new(capacity: usize) -> Self {
225 Self {
226 cache: RwLock::new(PlanCacheInner {
227 entries: HashMap::with_capacity(capacity),
228 stats_fingerprint: 0,
229 }),
230 capacity,
231 }
232 }
233
234 pub fn compile(
239 &self,
240 query: &str,
241 stats: Option<&DbStats>,
242 ) -> Result<CompiledQuery, CompileError> {
243 let key = hash_query(query);
244 let fingerprint = stats_fingerprint(stats);
245
246 {
248 let cache = self.cache.read();
249 if cache.stats_fingerprint == fingerprint {
250 if let Some(entry) = cache.entries.get(&key) {
251 return Ok(entry.compiled.clone());
252 }
253 }
254 }
255
256 let compiled = compile(query, stats)?;
258
259 {
261 let mut cache = self.cache.write();
262
263 if cache.stats_fingerprint != fingerprint {
265 cache.entries.clear();
266 cache.stats_fingerprint = fingerprint;
267 }
268
269 if cache.entries.len() >= self.capacity {
271 if let Some((&evict_key, _)) = cache.entries.iter().min_by_key(|(_, e)| e.hits) {
272 cache.entries.remove(&evict_key);
273 }
274 }
275
276 cache.entries.insert(
277 key,
278 CacheEntry {
279 compiled: compiled.clone(),
280 hits: 1,
281 },
282 );
283 }
284
285 Ok(compiled)
286 }
287
288 pub fn len(&self) -> usize {
290 self.cache.read().entries.len()
291 }
292
293 pub fn is_empty(&self) -> bool {
295 self.len() == 0
296 }
297
298 pub fn clear(&self) {
300 let mut cache = self.cache.write();
301 cache.entries.clear();
302 }
303}
304
/// Hashes `query` into a cache key that is insensitive to insignificant
/// whitespace.
///
/// Outside of string literals, leading/trailing whitespace is dropped and
/// every run of whitespace counts as a single space. Text inside
/// double-quoted literals is hashed verbatim, so `ABOUT "a  b"` and
/// `ABOUT "a b"` — which are different queries — get different keys.
///
/// NOTE(review): treats `\` as an escape character inside literals, matching
/// the escaping used when binding string parameters — confirm against the lexer.
fn hash_query(query: &str) -> u64 {
    let mut normalized = String::with_capacity(query.len());
    let mut in_literal = false;
    let mut escaped = false;
    let mut pending_space = false;

    for ch in query.chars() {
        if in_literal {
            normalized.push(ch);
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_literal = false;
            }
        } else if ch.is_whitespace() {
            // Defer the separator: leading and trailing runs are never emitted.
            pending_space = !normalized.is_empty();
        } else {
            if pending_space {
                normalized.push(' ');
                pending_space = false;
            }
            if ch == '"' {
                in_literal = true;
            }
            normalized.push(ch);
        }
    }

    let mut hasher = DefaultHasher::new();
    normalized.hash(&mut hasher);
    hasher.finish()
}
312
313fn stats_fingerprint(stats: Option<&DbStats>) -> u64 {
314 let Some(s) = stats else { return 0 };
315 let mut hasher = DefaultHasher::new();
316 s.total_count.hash(&mut hasher);
317 s.episodic_count.hash(&mut hasher);
318 s.semantic_count.hash(&mut hasher);
319 hasher.finish()
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    // ── compile ─────────────────────────────────────────────────────────

    #[test]
    fn compile_valid_recall() {
        let result = compile(r#"RECALL episodic ABOUT "test""#, None);
        assert!(result.is_ok());
        let compiled = result.unwrap();
        assert!(matches!(compiled.ast, Statement::Recall(_)));
        assert!(!compiled.plan.steps.is_empty());
    }

    #[test]
    fn compile_invalid_syntax() {
        let result = compile("NOT_A_QUERY", None);
        assert!(matches!(result, Err(CompileError::Parse(_))));
    }

    #[test]
    fn compile_semantic_error() {
        // Parses fine, but the importance threshold is rejected by the analyzer.
        let result = compile(r#"RECALL episodic ABOUT "x" WHERE importance > 2.0"#, None);
        assert!(matches!(result, Err(CompileError::Analysis(_))));
        if let Err(CompileError::Analysis(errors)) = result {
            assert_eq!(errors[0].kind, analyzer::AnalysisErrorKind::ValueOutOfRange);
        }
    }

    #[test]
    fn compile_error_display() {
        let result = compile("INVALID", None);
        let err = result.unwrap_err();
        let msg = err.to_string();
        assert!(!msg.is_empty());
    }

    #[test]
    fn compile_same_query_deterministic() {
        let q = r#"RECALL episodic ABOUT "test" LIMIT 5"#;
        let c1 = compile(q, None).unwrap();
        let c2 = compile(q, None).unwrap();
        assert_eq!(c1.plan, c2.plan);
    }

    #[test]
    fn compile_think_with_budget() {
        let result = compile(r#"THINK ABOUT "optimize" BUDGET 4096"#, None);
        assert!(result.is_ok());
        let compiled = result.unwrap();
        assert!(matches!(compiled.ast, Statement::Think(_)));
    }

    #[test]
    fn compile_remember() {
        // REMEMBER must be rejected with a message pointing at the removal.
        let result = compile(r#"REMEMBER episode CONTENT "data""#, None);
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("REMEMBER is not supported via embedded HirnQL anymore")
        );
    }

    #[test]
    fn compile_complex_recall() {
        let q = r#"
            RECALL semantic, episodic
            ABOUT "vector database"
            INVOLVING "HNSW"
            AFTER "2026-03-01"
            EXPAND GRAPH DEPTH 2 MIN_WEIGHT 0.3 ACTIVATION spreading
            WHERE importance > 0.4
            WHERE confidence > 0.8
            AS NARRATIVE
            BUDGET 4096
            NAMESPACE shared
            LIMIT 20
        "#;
        let result = compile(q, None);
        assert!(result.is_ok());
        let compiled = result.unwrap();
        assert!(compiled.plan.steps.len() > 5);
    }

    #[test]
    fn compile_error_is_hirn_error() {
        let result = compile("INVALID", None);
        let err = result.unwrap_err();
        let hirn_err: HirnError = err.into();
        assert!(matches!(hirn_err, HirnError::InvalidInput(_)));
    }

    // ── plan cache ──────────────────────────────────────────────────────

    #[test]
    fn cache_hit_returns_same_plan() {
        let cache = PlanCache::new(100);
        let q = r#"RECALL episodic ABOUT "test""#;
        let c1 = cache.compile(q, None).unwrap();
        let c2 = cache.compile(q, None).unwrap();
        assert_eq!(c1.plan, c2.plan);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn cache_different_queries_stored_separately() {
        let cache = PlanCache::new(100);
        cache.compile(r#"RECALL episodic ABOUT "a""#, None).unwrap();
        cache.compile(r#"RECALL episodic ABOUT "b""#, None).unwrap();
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn cache_invalidated_on_stats_change() {
        let cache = PlanCache::new(100);
        let stats1 = DbStats {
            working_count: 0,
            episodic_count: 100,
            semantic_count: 50,
            edge_count: 0,
            procedural_count: 0,
            total_count: 150,
            file_size_bytes: 0,
        };
        let stats2 = DbStats {
            working_count: 0,
            episodic_count: 5000,
            semantic_count: 2000,
            edge_count: 0,
            procedural_count: 0,
            total_count: 7000,
            file_size_bytes: 0,
        };

        cache
            .compile(r#"RECALL episodic ABOUT "test""#, Some(&stats1))
            .unwrap();
        assert_eq!(cache.len(), 1);

        // Different statistics: the old entry is dropped and replaced.
        cache
            .compile(r#"RECALL episodic ABOUT "test""#, Some(&stats2))
            .unwrap();
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn cache_eviction_at_capacity() {
        let cache = PlanCache::new(2);
        cache.compile(r#"RECALL episodic ABOUT "a""#, None).unwrap();
        cache.compile(r#"RECALL episodic ABOUT "b""#, None).unwrap();
        assert_eq!(cache.len(), 2);

        // A third distinct query must evict one entry, not grow the cache.
        cache.compile(r#"RECALL episodic ABOUT "c""#, None).unwrap();
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn cache_clear_empties_all() {
        let cache = PlanCache::new(100);
        cache.compile(r#"RECALL episodic ABOUT "a""#, None).unwrap();
        cache.compile(r#"RECALL episodic ABOUT "b""#, None).unwrap();
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn cache_whitespace_normalized() {
        let cache = PlanCache::new(100);
        cache
            .compile(r#"RECALL episodic ABOUT "test""#, None)
            .unwrap();
        // Same query, differing only in whitespace between tokens: it must
        // map to the same cache entry.
        cache
            .compile(r#"  RECALL   episodic   ABOUT   "test"  "#, None)
            .unwrap();
        assert_eq!(cache.len(), 1);
    }

    // ── parser throughput ───────────────────────────────────────────────

    #[test]
    fn parse_10k_queries_under_1_second() {
        let q = r#"RECALL episodic ABOUT "test query" INVOLVING "auth" AFTER "2026-01-01" WHERE importance > 0.5 LIMIT 10"#;
        // Budgets are looser than the test name suggests; debug builds get
        // the larger one.
        let max_elapsed = if cfg!(debug_assertions) {
            std::time::Duration::from_millis(2500)
        } else {
            std::time::Duration::from_millis(1500)
        };
        let start = std::time::Instant::now();
        for _ in 0..10_000 {
            let _ = parser::parse(q).unwrap();
        }
        let elapsed = start.elapsed();
        assert!(
            elapsed <= max_elapsed,
            "10K parses took {:.2}s (>{:.2}s limit)",
            elapsed.as_secs_f64(),
            max_elapsed.as_secs_f64()
        );
    }

    // ── prepare / bind ──────────────────────────────────────────────────

    #[test]
    fn prepare_extracts_positional_params() {
        let stmt = prepare(r#"RECALL episodic ABOUT $1 LIMIT 10"#, None).unwrap();
        assert_eq!(stmt.params, vec!["$1"]);
    }

    #[test]
    fn prepare_extracts_named_params() {
        let stmt = prepare(
            r#"RECALL episodic ABOUT $query WHERE importance > $threshold"#,
            None,
        )
        .unwrap();
        assert!(stmt.params.contains(&"$query".to_string()));
        assert!(stmt.params.contains(&"$threshold".to_string()));
        assert_eq!(stmt.params.len(), 2);
    }

    #[test]
    fn prepare_no_params_runs_analysis() {
        let result = prepare(r#"RECALL episodic ABOUT "x" WHERE importance > 2.0"#, None);
        assert!(matches!(result, Err(CompileError::Analysis(_))));
    }

    #[test]
    fn prepare_with_params_skips_analysis() {
        let result = prepare(r#"RECALL episodic ABOUT $1 WHERE importance > $2"#, None);
        assert!(result.is_ok());
    }

    #[test]
    fn bind_substitutes_string_param() {
        let stmt = prepare(r#"RECALL episodic ABOUT $1 LIMIT 10"#, None).unwrap();
        let mut values = HashMap::new();
        values.insert("$1".to_string(), "authentication".to_string());

        let compiled = bind(&stmt, &values).unwrap();
        match &compiled.ast {
            Statement::Recall(r) => assert_eq!(r.about, "authentication"),
            _ => panic!("expected Recall"),
        }
    }

    #[test]
    fn bind_substitutes_numeric_param() {
        let stmt = prepare(r#"RECALL episodic ABOUT $query LIMIT $limit"#, None).unwrap();
        let mut values = HashMap::new();
        values.insert("$query".to_string(), "test".to_string());
        values.insert("$limit".to_string(), "20".to_string());

        let compiled = bind(&stmt, &values).unwrap();
        match &compiled.ast {
            Statement::Recall(r) => {
                assert_eq!(r.about, "test");
                assert_eq!(r.limit, Some(20));
            }
            _ => panic!("expected Recall"),
        }
    }

    #[test]
    fn bind_missing_param_returns_error() {
        let stmt = prepare(r#"RECALL episodic ABOUT $1 LIMIT 10"#, None).unwrap();
        // No value supplied for `$1`.
        let values = HashMap::new();
        let result = bind(&stmt, &values);
        assert!(result.is_err());
    }

    #[test]
    fn bind_reuses_plan() {
        let stmt = prepare(r#"RECALL episodic ABOUT $1 LIMIT 10"#, None).unwrap();
        let plan_before = stmt.plan.clone();

        let mut values = HashMap::new();
        values.insert("$1".to_string(), "auth".to_string());
        let compiled = bind(&stmt, &values).unwrap();

        assert_eq!(
            compiled.plan, plan_before,
            "plan should be reused from prepare"
        );
    }

    #[test]
    fn bind_different_values_produce_different_asts() {
        let stmt = prepare(r#"RECALL episodic ABOUT $1 LIMIT 10"#, None).unwrap();

        let mut v1 = HashMap::new();
        v1.insert("$1".to_string(), "auth".to_string());
        let c1 = bind(&stmt, &v1).unwrap();

        let mut v2 = HashMap::new();
        v2.insert("$1".to_string(), "deployment".to_string());
        let c2 = bind(&stmt, &v2).unwrap();

        match (&c1.ast, &c2.ast) {
            (Statement::Recall(r1), Statement::Recall(r2)) => {
                assert_eq!(r1.about, "auth");
                assert_eq!(r2.about, "deployment");
            }
            _ => panic!("expected Recall"),
        }
    }

    #[test]
    fn bind_validates_bound_values() {
        let stmt = prepare(r#"RECALL episodic ABOUT $1 WHERE importance > $2"#, None).unwrap();
        let mut values = HashMap::new();
        values.insert("$1".to_string(), "test".to_string());
        // Out-of-range importance: must be caught by analysis at bind time.
        values.insert("$2".to_string(), "5.0".to_string());
        let result = bind(&stmt, &values);
        assert!(matches!(result, Err(CompileError::Analysis(_))));
    }

    #[test]
    fn prepared_stmt_faster_than_cold_compile() {
        let q = r#"RECALL episodic ABOUT $1 INVOLVING "auth" AFTER "2026-01-01" WHERE importance > 0.5 LIMIT 10"#;
        let stmt = prepare(q, None).unwrap();

        let mut values = HashMap::new();
        values.insert("$1".to_string(), "test".to_string());
        let start = std::time::Instant::now();
        for _ in 0..1_000 {
            let _ = bind(&stmt, &values).unwrap();
        }
        let bind_elapsed = start.elapsed();

        let q_concrete = r#"RECALL episodic ABOUT "test" INVOLVING "auth" AFTER "2026-01-01" WHERE importance > 0.5 LIMIT 10"#;
        let start = std::time::Instant::now();
        for _ in 0..1_000 {
            let _ = compile(q_concrete, None).unwrap();
        }
        let compile_elapsed = start.elapsed();

        // Only an absolute bound on bind time is asserted; the cold-compile
        // timing is measured but deliberately not compared.
        assert!(
            bind_elapsed.as_secs_f64() < 2.0,
            "1K binds took {:.2}s",
            bind_elapsed.as_secs_f64()
        );
        let _ = compile_elapsed;
    }

    // ── EXPLAIN ─────────────────────────────────────────────────────────

    #[test]
    fn compile_explain_succeeds() {
        let cq = compile(r#"EXPLAIN RECALL episodic ABOUT "hello""#, None).unwrap();
        assert!(matches!(cq.ast, Statement::Explain(_)));
    }

    #[test]
    fn compile_explain_analyze_succeeds() {
        let cq = compile(
            r#"EXPLAIN ANALYZE RECALL episodic ABOUT "hello" LIMIT 5"#,
            None,
        )
        .unwrap();
        match &cq.ast {
            Statement::Explain(e) => {
                assert!(e.analyze);
                assert!(matches!(*e.inner, Statement::Recall(_)));
            }
            _ => panic!("expected Explain"),
        }
    }

    #[test]
    fn compile_explain_invalid_inner_fails() {
        // EXPLAIN with nothing to explain.
        let result = compile(r#"EXPLAIN"#, None);
        assert!(result.is_err());
    }
}