// hirn_engine/ql/analyzer.rs
//! Semantic analysis — validates a parsed AST before planning.
//!
//! Checks `SET` field names, value types and ranges, temporal format validity,
//! and other semantic constraints that go beyond what the PEG grammar can
//! enforce.
5
6use std::collections::HashSet;
7
8use hirn_query::ast::*;
9
/// A semantic error discovered during analysis.
///
/// [`analyze`] returns a list of these; an empty list means the statement is
/// semantically valid.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AnalysisError {
    /// Human-readable description of the problem.
    pub message: String,
    /// Machine-readable category of the problem.
    pub kind: AnalysisErrorKind,
}

/// Category of an [`AnalysisError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AnalysisErrorKind {
    /// Unknown or unsupported field name (currently raised for `SET`
    /// assignment fields; WHERE-clause field names are not validated).
    UnknownField,
    /// A value has the wrong type for its field (e.g. comparing importance
    /// with a string, or a non-string `description`).
    TypeMismatch,
    /// A timestamp is neither `YYYY-MM-DD` nor RFC 3339.
    InvalidTemporal,
    /// A value is outside its allowed range or otherwise not permitted
    /// (e.g. importance > 1.0, `DEPTH 0`, `BUDGET 0`, or a duplicated
    /// MERGE MEMORY source).
    ValueOutOfRange,
    /// A required clause or value is missing or blank.
    MissingRequired,
    /// Unknown relation type for CONNECT. Not emitted by the checks in this
    /// module.
    UnknownRelation,
    /// Invalid layer for an operation. Not emitted by the checks in this
    /// module.
    InvalidLayer,
}

impl std::fmt::Display for AnalysisError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Written verbatim: formatter width/fill flags are intentionally ignored.
        f.write_str("analysis error: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for AnalysisError {}
42
/// WHERE-clause fields that must be compared against numeric values.
///
/// A condition on any of these fields whose value is a string literal is
/// reported as a type mismatch. Lookup is an exact string match, so qualified
/// names such as `episodic.access_count` are listed explicitly.
const NUMERIC_FIELDS: &[&str] = &[
    "importance", "confidence", "surprise",
    "access_count", "evidence_count", "relevance_score",
    "success_rate", "invocation_count", "trust",
    "episodic.access_count",
];
56
57/// Analyze a parsed statement for semantic correctness.
58///
59/// Returns a list of errors (empty = valid).
60pub fn analyze(stmt: &Statement) -> Vec<AnalysisError> {
61    match stmt {
62        Statement::Recall(r) => analyze_recall(r),
63        Statement::Think(t) => analyze_think(t),
64        Statement::Correct(c) => analyze_correct(c),
65        Statement::Supersede(s) => analyze_supersede(s),
66        Statement::MergeMemory(m) => analyze_merge_memory(m),
67        Statement::Retract(r) => analyze_retract(r),
68        Statement::Inspect(_) | Statement::History(_) | Statement::Trace(_) => vec![],
69        Statement::Traverse(t) => analyze_traverse(t),
70        Statement::Explain(e) => analyze(&e.inner),
71        Statement::CreateRealm(_)
72        | Statement::DropRealm(_)
73        | Statement::Grant(_)
74        | Statement::Revoke(_)
75        | Statement::ShowPolicies(_)
76        | Statement::ExplainPolicy(_)
77        | Statement::RecallEvents(_)
78        | Statement::ShowCluster
79        | Statement::SetTierPolicy(_)
80        | Statement::ExplainCauses(_)
81        | Statement::WhatIf(_)
82        | Statement::Counterfactual(_) => vec![],
83    }
84}
85
86fn semantic_target_is_empty(target: &SemanticTargetRef) -> bool {
87    target.raw_value().trim().is_empty()
88}
89
90fn analyze_recall(r: &RecallStmt) -> Vec<AnalysisError> {
91    let mut errors = Vec::new();
92
93    if r.about.trim().is_empty() {
94        errors.push(AnalysisError {
95            message: "ABOUT clause cannot be empty".into(),
96            kind: AnalysisErrorKind::MissingRequired,
97        });
98    }
99
100    if r.layers.is_empty() {
101        errors.push(AnalysisError {
102            message: "RECALL requires at least one layer".into(),
103            kind: AnalysisErrorKind::MissingRequired,
104        });
105    }
106
107    errors.extend(analyze_where_clauses(&r.where_clauses));
108    errors.extend(analyze_temporal(r.temporal.as_ref()));
109    errors.extend(analyze_expand(r.expand.as_ref()));
110    errors.extend(analyze_budget(r.budget));
111    errors
112}
113
114fn analyze_think(t: &ThinkStmt) -> Vec<AnalysisError> {
115    let mut errors = Vec::new();
116
117    if t.about.trim().is_empty() {
118        errors.push(AnalysisError {
119            message: "THINK ABOUT clause cannot be empty".into(),
120            kind: AnalysisErrorKind::MissingRequired,
121        });
122    }
123
124    errors.extend(analyze_where_clauses(&t.where_clauses));
125    errors.extend(analyze_temporal(t.temporal.as_ref()));
126    errors.extend(analyze_expand(t.expand.as_ref()));
127    errors.extend(analyze_budget(t.budget));
128    errors
129}
130
131fn analyze_correct(c: &CorrectStmt) -> Vec<AnalysisError> {
132    let mut errors = Vec::new();
133
134    if semantic_target_is_empty(&c.target) {
135        errors.push(AnalysisError {
136            message: "CORRECT target cannot be empty".into(),
137            kind: AnalysisErrorKind::MissingRequired,
138        });
139    }
140
141    errors.extend(analyze_semantic_updates(&c.updates, "CORRECT", true));
142    errors.extend(analyze_semantic_observed_at(
143        c.observed_at.as_ref(),
144        "CORRECT",
145    ));
146
147    errors
148}
149
150fn analyze_semantic_updates(
151    updates: &[SetAssignment],
152    verb: &str,
153    require_updates: bool,
154) -> Vec<AnalysisError> {
155    let mut errors = Vec::new();
156
157    if require_updates && updates.is_empty() {
158        errors.push(AnalysisError {
159            message: format!("{verb} requires at least one field assignment"),
160            kind: AnalysisErrorKind::MissingRequired,
161        });
162    }
163
164    for update in updates {
165        match update.field.as_str() {
166            "description" => {
167                if !matches!(update.value, SetValue::String(_)) {
168                    errors.push(AnalysisError {
169                        message: format!("{verb} description requires a string value"),
170                        kind: AnalysisErrorKind::TypeMismatch,
171                    });
172                }
173            }
174            "confidence" => {
175                let value = match update.value {
176                    SetValue::Float(v) => Some(v),
177                    SetValue::Int(v) => Some(v as f64),
178                    _ => None,
179                };
180
181                if let Some(value) = value {
182                    if !(0.0..=1.0).contains(&value) {
183                        errors.push(AnalysisError {
184                            message: format!(
185                                "{verb} confidence must be between 0.0 and 1.0, got {value}"
186                            ),
187                            kind: AnalysisErrorKind::ValueOutOfRange,
188                        });
189                    }
190                } else {
191                    errors.push(AnalysisError {
192                        message: format!("{verb} confidence requires a numeric value"),
193                        kind: AnalysisErrorKind::TypeMismatch,
194                    });
195                }
196            }
197            "evidence_count" => match update.value {
198                SetValue::Int(v) if v >= 0 => {}
199                SetValue::Int(v) => errors.push(AnalysisError {
200                    message: format!("{verb} evidence_count must be non-negative, got {v}"),
201                    kind: AnalysisErrorKind::ValueOutOfRange,
202                }),
203                _ => errors.push(AnalysisError {
204                    message: format!("{verb} evidence_count requires a non-negative integer"),
205                    kind: AnalysisErrorKind::TypeMismatch,
206                }),
207            },
208            other => errors.push(AnalysisError {
209                message: format!(
210                    "unknown {verb} field '{other}' (allowed: description, confidence, evidence_count)"
211                ),
212                kind: AnalysisErrorKind::UnknownField,
213            }),
214        }
215    }
216
217    errors
218}
219
220fn analyze_semantic_observed_at(observed_at: Option<&String>, verb: &str) -> Vec<AnalysisError> {
221    let mut errors = Vec::new();
222
223    if let Some(observed_at) = observed_at
224        && !is_valid_temporal(observed_at)
225    {
226        errors.push(AnalysisError {
227            message: format!("invalid {verb} OBSERVED AT temporal format: '{observed_at}'"),
228            kind: AnalysisErrorKind::InvalidTemporal,
229        });
230    }
231
232    errors
233}
234
235fn analyze_supersede(s: &SupersedeStmt) -> Vec<AnalysisError> {
236    let mut errors = Vec::new();
237
238    if semantic_target_is_empty(&s.target) {
239        errors.push(AnalysisError {
240            message: "SUPERSEDE target cannot be empty".into(),
241            kind: AnalysisErrorKind::MissingRequired,
242        });
243    }
244
245    errors.extend(analyze_semantic_updates(&s.updates, "SUPERSEDE", true));
246    errors.extend(analyze_semantic_observed_at(
247        s.observed_at.as_ref(),
248        "SUPERSEDE",
249    ));
250
251    errors
252}
253
254fn analyze_merge_memory(m: &MergeMemoryStmt) -> Vec<AnalysisError> {
255    let mut errors = Vec::new();
256
257    if m.sources.is_empty() {
258        errors.push(AnalysisError {
259            message: "MERGE MEMORY requires at least one source memory".into(),
260            kind: AnalysisErrorKind::MissingRequired,
261        });
262    }
263
264    if semantic_target_is_empty(&m.target) {
265        errors.push(AnalysisError {
266            message: "MERGE MEMORY target cannot be empty".into(),
267            kind: AnalysisErrorKind::MissingRequired,
268        });
269    }
270
271    let mut seen_sources = HashSet::new();
272    for source in &m.sources {
273        let normalized = source.raw_value().trim();
274        if normalized.is_empty() {
275            errors.push(AnalysisError {
276                message: "MERGE MEMORY source cannot be empty".into(),
277                kind: AnalysisErrorKind::MissingRequired,
278            });
279            continue;
280        }
281
282        let canonical = source.to_string();
283        if !seen_sources.insert(canonical.clone()) {
284            errors.push(AnalysisError {
285                message: format!("MERGE MEMORY source '{}' is duplicated", source.raw_value()),
286                kind: AnalysisErrorKind::ValueOutOfRange,
287            });
288        }
289
290        if canonical == m.target.to_string() {
291            errors.push(AnalysisError {
292                message: format!(
293                    "MERGE MEMORY source '{}' cannot also be the target",
294                    source.raw_value()
295                ),
296                kind: AnalysisErrorKind::ValueOutOfRange,
297            });
298        }
299    }
300
301    errors.extend(analyze_semantic_updates(&m.updates, "MERGE MEMORY", false));
302    errors.extend(analyze_semantic_observed_at(
303        m.observed_at.as_ref(),
304        "MERGE MEMORY",
305    ));
306
307    errors
308}
309
310fn analyze_retract(r: &RetractStmt) -> Vec<AnalysisError> {
311    let mut errors = Vec::new();
312
313    if semantic_target_is_empty(&r.target) {
314        errors.push(AnalysisError {
315            message: "RETRACT target cannot be empty".into(),
316            kind: AnalysisErrorKind::MissingRequired,
317        });
318    }
319
320    if let Some(ref observed_at) = r.observed_at {
321        if !is_valid_temporal(observed_at) {
322            errors.push(AnalysisError {
323                message: format!("invalid OBSERVED AT temporal format: '{observed_at}'"),
324                kind: AnalysisErrorKind::InvalidTemporal,
325            });
326        }
327    }
328
329    errors
330}
331
332fn analyze_traverse(t: &TraverseStmt) -> Vec<AnalysisError> {
333    let mut errors = Vec::new();
334
335    if t.from.trim().is_empty() {
336        errors.push(AnalysisError {
337            message: "TRAVERSE FROM cannot be empty".into(),
338            kind: AnalysisErrorKind::MissingRequired,
339        });
340    }
341
342    if t.depth == 0 {
343        errors.push(AnalysisError {
344            message: "TRAVERSE DEPTH must be at least 1".into(),
345            kind: AnalysisErrorKind::ValueOutOfRange,
346        });
347    }
348
349    errors.extend(analyze_where_clauses(&t.where_clauses));
350    errors
351}
352
353fn analyze_where_clauses(clauses: &[WhereCondition]) -> Vec<AnalysisError> {
354    let mut errors = Vec::new();
355
356    for wc in clauses {
357        // Check that numeric fields are compared with numeric values.
358        if NUMERIC_FIELDS.contains(&wc.field.as_str()) {
359            if matches!(wc.value, ConditionValue::String(_)) {
360                errors.push(AnalysisError {
361                    message: format!("field '{}' expects a numeric value, got string", wc.field),
362                    kind: AnalysisErrorKind::TypeMismatch,
363                });
364            }
365        }
366
367        // Check numeric range for known bounded fields.
368        match wc.field.as_str() {
369            "importance" | "confidence" | "trust" | "relevance_score" | "success_rate" => {
370                let v = match &wc.value {
371                    ConditionValue::Float(v) => Some(*v),
372                    ConditionValue::Int(v) => Some(*v as f64),
373                    _ => None,
374                };
375                if let Some(v) = v {
376                    if !(0.0..=1.0).contains(&v) {
377                        errors.push(AnalysisError {
378                            message: format!(
379                                "field '{}' threshold should be between 0.0 and 1.0, got {}",
380                                wc.field, v
381                            ),
382                            kind: AnalysisErrorKind::ValueOutOfRange,
383                        });
384                    }
385                }
386            }
387            _ => {}
388        }
389    }
390
391    errors
392}
393
394fn analyze_temporal(temporal: Option<&TemporalClause>) -> Vec<AnalysisError> {
395    let Some(tc) = temporal else { return vec![] };
396    let mut errors = Vec::new();
397
398    let timestamps = match tc {
399        TemporalClause::After(s) => vec![s.as_str()],
400        TemporalClause::Before(s) => vec![s.as_str()],
401        TemporalClause::Between { start, end } => vec![start.as_str(), end.as_str()],
402    };
403
404    for ts in timestamps {
405        if !is_valid_temporal(ts) {
406            errors.push(AnalysisError {
407                message: format!(
408                    "invalid temporal value: '{ts}' (expected YYYY-MM-DD or RFC 3339)"
409                ),
410                kind: AnalysisErrorKind::InvalidTemporal,
411            });
412        }
413    }
414
415    errors
416}
417
418fn analyze_expand(expand: Option<&ExpandClause>) -> Vec<AnalysisError> {
419    let Some(ex) = expand else { return vec![] };
420    let mut errors = Vec::new();
421
422    if ex.depth == 0 {
423        errors.push(AnalysisError {
424            message: "EXPAND GRAPH DEPTH must be at least 1".into(),
425            kind: AnalysisErrorKind::ValueOutOfRange,
426        });
427    }
428
429    if let Some(mw) = ex.min_weight {
430        if !(0.0..=1.0).contains(&mw) {
431            errors.push(AnalysisError {
432                message: format!("MIN_WEIGHT must be between 0.0 and 1.0, got {mw}"),
433                kind: AnalysisErrorKind::ValueOutOfRange,
434            });
435        }
436    }
437
438    errors
439}
440
441fn analyze_budget(budget: Option<usize>) -> Vec<AnalysisError> {
442    if let Some(b) = budget {
443        if b == 0 {
444            return vec![AnalysisError {
445                message: "BUDGET must be greater than 0".into(),
446                kind: AnalysisErrorKind::ValueOutOfRange,
447            }];
448        }
449    }
450    vec![]
451}
452
453fn is_valid_temporal(s: &str) -> bool {
454    use chrono::NaiveDate;
455    // Accept YYYY-MM-DD.
456    if NaiveDate::parse_from_str(s, "%Y-%m-%d").is_ok() {
457        return true;
458    }
459    // Accept RFC 3339 / ISO 8601.
460    if chrono::DateTime::parse_from_rfc3339(s).is_ok() {
461        return true;
462    }
463    false
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `query` (which must be syntactically valid) and analyzes it.
    fn analyze_query(query: &str) -> Vec<AnalysisError> {
        let stmt = hirn_query::parse(query)
            .unwrap_or_else(|error| panic!("query should parse: {query:?}: {error}"));
        analyze(&stmt)
    }

    /// Asserts that the parser rejects `query` with an error mentioning `needle`.
    fn assert_rejected_by_parser(query: &str, needle: &str) {
        let error = hirn_query::parse(query).unwrap_err();
        let text = error.to_string();
        assert!(
            text.contains(needle),
            "expected parse error containing {needle:?}, got {text:?}"
        );
    }

    // ── RECALL ──

    #[test]
    fn valid_recall_passes() {
        assert!(analyze_query(r#"RECALL episodic ABOUT "test""#).is_empty());
    }

    #[test]
    fn recall_with_valid_where() {
        assert!(analyze_query(r#"RECALL episodic ABOUT "x" WHERE importance > 0.5"#).is_empty());
    }

    #[test]
    fn recall_with_out_of_range_importance() {
        let errors = analyze_query(r#"RECALL episodic ABOUT "x" WHERE importance > 2.0"#);
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].kind, AnalysisErrorKind::ValueOutOfRange);
    }

    #[test]
    fn recall_with_invalid_temporal() {
        let errors = analyze_query(r#"RECALL episodic ABOUT "x" AFTER "not-a-date""#);
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].kind, AnalysisErrorKind::InvalidTemporal);
    }

    #[test]
    fn between_with_valid_dates() {
        let errors =
            analyze_query(r#"RECALL episodic ABOUT "x" BETWEEN "2026-01-01" AND "2026-03-01""#);
        assert!(errors.is_empty());
    }

    #[test]
    fn between_with_invalid_end_date_is_rejected() {
        let errors = analyze_query(r#"RECALL episodic ABOUT "x" BETWEEN "2026-01-01" AND "nope""#);
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].kind, AnalysisErrorKind::InvalidTemporal);
    }

    #[test]
    fn budget_zero_rejected() {
        let errors = analyze_query(r#"RECALL episodic ABOUT "x" BUDGET 0"#);
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].kind, AnalysisErrorKind::ValueOutOfRange);
    }

    // ── THINK ──

    #[test]
    fn think_valid_passes() {
        assert!(analyze_query(r#"THINK ABOUT "test" BUDGET 4096"#).is_empty());
    }

    #[test]
    fn think_global_valid() {
        assert!(analyze_query(r#"THINK GLOBAL ABOUT "test""#).is_empty());
    }

    #[test]
    fn think_budget_zero_rejected() {
        let errors = analyze_query(r#"THINK ABOUT "test" BUDGET 0"#);
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].kind, AnalysisErrorKind::ValueOutOfRange);
    }

    // ── Semantic mutations ──

    #[test]
    fn remember_is_rejected_before_analysis() {
        assert_rejected_by_parser(
            r#"REMEMBER episode CONTENT "x" IMPORTANCE 1.5"#,
            "REMEMBER is not supported via embedded HirnQL anymore",
        );
    }

    #[test]
    fn correct_unknown_field_is_rejected() {
        let errors = analyze_query(r#"CORRECT "x" SET unsupported = 1"#);
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].kind, AnalysisErrorKind::UnknownField);
    }

    #[test]
    fn correct_confidence_out_of_range_is_rejected() {
        let errors = analyze_query(r#"CORRECT "x" SET confidence = 1.5"#);
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].kind, AnalysisErrorKind::ValueOutOfRange);
    }

    #[test]
    fn supersede_unknown_field_is_rejected() {
        let errors = analyze_query(r#"SUPERSEDE "x" SET unsupported = 1"#);
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].kind, AnalysisErrorKind::UnknownField);
    }

    #[test]
    fn retract_invalid_observed_at_is_rejected() {
        let errors = analyze_query(r#"RETRACT "x" OBSERVED AT "not-a-date""#);
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].kind, AnalysisErrorKind::InvalidTemporal);
    }

    #[test]
    fn retract_valid_observed_at_passes() {
        assert!(analyze_query(r#"RETRACT "x" OBSERVED AT "2026-03-01""#).is_empty());
    }

    // ── Statements no longer accepted by embedded HirnQL ──

    #[test]
    fn connect_with_unknown_relation_is_rejected_by_parser() {
        assert_rejected_by_parser(
            r#"CONNECT "a" TO "b" AS unknown_rel"#,
            "CONNECT is not supported via embedded HirnQL anymore",
        );
    }

    #[test]
    fn connect_with_valid_relation_is_rejected_by_parser() {
        assert_rejected_by_parser(
            r#"CONNECT "a" TO "b" AS related_to WEIGHT 0.5"#,
            "CONNECT is not supported via embedded HirnQL anymore",
        );
    }

    #[test]
    fn connect_with_out_of_range_weight_is_rejected_by_parser() {
        assert_rejected_by_parser(
            r#"CONNECT "a" TO "b" AS causes WEIGHT 1.5"#,
            "CONNECT is not supported via embedded HirnQL anymore",
        );
    }

    #[test]
    fn consolidate_is_rejected_by_parser() {
        assert_rejected_by_parser("CONSOLIDATE", "CONSOLIDATE is not supported via HirnQL anymore");
    }

    #[test]
    fn watch_is_rejected_by_parser() {
        assert_rejected_by_parser(
            r#"WATCH ALL"#,
            "WATCH is not supported via embedded HirnQL anymore",
        );
    }

    #[test]
    fn batch_forget_is_rejected_by_parser() {
        assert_rejected_by_parser(
            r#"FORGET episodic WHERE importance < 0.1 ARCHIVE"#,
            "FORGET is not supported via embedded HirnQL anymore",
        );
    }

    #[test]
    fn forget_hard_mode_is_rejected_by_parser() {
        assert_rejected_by_parser(
            r#"FORGET "id123" HARD"#,
            "FORGET is not supported via embedded HirnQL anymore",
        );
    }

    // ── Temporal parsing ──

    #[test]
    fn valid_temporal_formats() {
        assert!(is_valid_temporal("2026-03-01"));
        assert!(is_valid_temporal("2026-03-01T12:00:00Z"));
        assert!(is_valid_temporal("2026-03-01T12:00:00+01:00"));
        assert!(!is_valid_temporal("not-a-date"));
        assert!(!is_valid_temporal("March 1st"));
        assert!(!is_valid_temporal(""));
    }

    // ── TRAVERSE ──

    #[test]
    fn traverse_valid() {
        assert!(analyze_query(r#"TRAVERSE FROM "node1" DEPTH 3"#).is_empty());
    }

    #[test]
    fn traverse_with_via_and_where() {
        let errors =
            analyze_query(r#"TRAVERSE FROM "node1" VIA causes DEPTH 2 WHERE weight > 0.5"#);
        assert!(errors.is_empty());
    }

    // ── EXPLAIN ──

    #[test]
    fn explain_valid_recall_has_no_errors() {
        assert!(analyze_query(r#"EXPLAIN RECALL episodic ABOUT "test""#).is_empty());
    }

    #[test]
    fn explain_analyze_delegates_to_inner() {
        // EXPLAIN ANALYZE must still report the wrapped statement's range error.
        let errors =
            analyze_query(r#"EXPLAIN ANALYZE RECALL episodic ABOUT "test" WHERE importance > 2.0"#);
        assert!(
            errors
                .iter()
                .any(|e| matches!(e.kind, AnalysisErrorKind::ValueOutOfRange)),
            "should propagate inner analysis errors: {errors:?}"
        );
    }
}