Skip to main content

khive_query/compilers/
sql.rs

//! Compile GQL AST to parameterized SQL.
//!
//! Two compilation paths:
//! - Fixed-length patterns (every edge is exactly one hop, `*1..1`) → JOIN chain
//! - Variable-length patterns (any edge `*N..M` where M > 1) → recursive CTE
//!
//! Security invariants (MAJ-1/MAJ-2/MAJ-3 from critic review):
//! - Namespace injection: the namespace filter in the WHERE clause always comes from
//!   `CompileOptions.scopes`, never from the query text.
//! - Edge property whitelist: only `relation` and `weight` are queryable edge columns.
//! - Depth cap: recursive CTE depth is min(requested, 10).
11
12use crate::ast::*;
13use crate::error::QueryError;
14use crate::validate::validate;
15use khive_storage::types::SqlValue;
16
17#[derive(Debug)]
18pub struct CompiledQuery {
19    pub sql: String,
20    pub params: Vec<SqlValue>,
21    pub return_vars: Vec<String>,
22}
23
/// Server-side settings applied to every compiled query.
///
/// These values come from the caller of [`compile`], never from the query text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompileOptions {
    /// Namespace scope. Empty = cross-namespace (all). Non-empty = filter to these namespaces.
    pub scopes: Vec<String>,
    /// Hard limit cap (server-side safety). Query limit is min(requested, max_limit);
    /// a query without its own LIMIT gets exactly `max_limit`.
    pub max_limit: usize,
}

impl Default for CompileOptions {
    /// Cross-namespace scope (no namespace filter) with a row cap of 500.
    fn default() -> Self {
        Self {
            scopes: Vec::new(),
            max_limit: 500,
        }
    }
}
39
40pub fn compile(query: &GqlQuery, opts: &CompileOptions) -> Result<CompiledQuery, QueryError> {
41    if query.pattern.elements.is_empty() {
42        return Err(QueryError::Compile("empty pattern".into()));
43    }
44
45    // Validate + normalise into canonical taxonomies before emitting SQL.
46    let mut query = query.clone();
47    validate(&mut query)?;
48
49    if query.pattern.has_variable_length() {
50        compile_variable_length(&query, opts)
51    } else {
52        compile_fixed_length(&query, opts)
53    }
54}
55
56fn namespace_filter(alias: &str, opts: &CompileOptions, params: &mut Vec<SqlValue>) -> String {
57    if opts.scopes.is_empty() {
58        String::new()
59    } else if opts.scopes.len() == 1 {
60        params.push(SqlValue::Text(opts.scopes[0].clone()));
61        format!(" AND {alias}.namespace = ?{}", params.len())
62    } else {
63        let placeholders: Vec<String> = opts
64            .scopes
65            .iter()
66            .map(|s| {
67                params.push(SqlValue::Text(s.clone()));
68                format!("?{}", params.len())
69            })
70            .collect();
71        format!(" AND {alias}.namespace IN ({})", placeholders.join(", "))
72    }
73}
74
75/// Compile fixed-length patterns to a chain of JOINs.
76///
77/// MATCH (a:concept)-[e:introduced_by]->(b:paper) WHERE ... RETURN a, e, b LIMIT 10
78/// →
79/// SELECT a.*, e.*, b.*
80/// FROM entities a
81/// JOIN graph_edges e ON e.source_id = a.id
82/// JOIN entities b ON b.id = e.target_id
83/// WHERE a.kind = 'concept' AND e.relation = 'introduced_by' AND b.kind = 'paper'
84///   AND a.deleted_at IS NULL AND b.deleted_at IS NULL
85/// LIMIT 10
86fn compile_fixed_length(
87    query: &GqlQuery,
88    opts: &CompileOptions,
89) -> Result<CompiledQuery, QueryError> {
90    let mut params: Vec<SqlValue> = Vec::new();
91    let mut from_parts: Vec<String> = Vec::new();
92    let mut join_parts: Vec<String> = Vec::new();
93    let mut where_parts: Vec<String> = Vec::new();
94    let mut select_parts: Vec<String> = Vec::new();
95
96    let mut node_aliases: Vec<String> = Vec::new();
97    let mut edge_aliases: Vec<String> = Vec::new();
98    let mut var_to_alias: std::collections::HashMap<String, (String, VarKind)> =
99        std::collections::HashMap::new();
100
101    let mut node_idx = 0usize;
102    let mut edge_idx = 0usize;
103
104    for element in &query.pattern.elements {
105        match element {
106            PatternElement::Node(np) => {
107                let alias = format!("n{node_idx}");
108                node_aliases.push(alias.clone());
109
110                if node_idx == 0 {
111                    from_parts.push(format!("entities {alias}"));
112                }
113
114                where_parts.push(format!("{alias}.deleted_at IS NULL"));
115
116                let ns_filter = namespace_filter(&alias, opts, &mut params);
117                if !ns_filter.is_empty() {
118                    where_parts.push(ns_filter.trim_start_matches(" AND ").to_string());
119                }
120
121                if let Some(ref kind) = np.kind {
122                    params.push(SqlValue::Text(kind.clone()));
123                    where_parts.push(format!("{alias}.kind = ?{}", params.len()));
124                }
125
126                for (key, val) in &np.properties {
127                    params.push(SqlValue::Text(val.clone()));
128                    if key == "name" {
129                        where_parts
130                            .push(format!("{alias}.name = ?{} COLLATE NOCASE", params.len()));
131                    } else {
132                        where_parts.push(format!(
133                            "json_extract({alias}.properties, '$.{}') = ?{} COLLATE NOCASE",
134                            key.replace('\'', "''"),
135                            params.len()
136                        ));
137                    }
138                }
139
140                if let Some(ref var) = np.variable {
141                    var_to_alias.insert(var.clone(), (alias.clone(), VarKind::Node));
142                }
143
144                node_idx += 1;
145            }
146            PatternElement::Edge(ep) => {
147                let e_alias = format!("e{edge_idx}");
148                let prev_node = &node_aliases[node_aliases.len() - 1];
149
150                edge_aliases.push(e_alias.clone());
151
152                let (source_join, target_join) = match ep.direction {
153                    EdgeDirection::Out => (
154                        format!("{e_alias}.source_id = {prev_node}.id"),
155                        "target_id",
156                    ),
157                    EdgeDirection::In => (
158                        format!("{e_alias}.target_id = {prev_node}.id"),
159                        "source_id",
160                    ),
161                    EdgeDirection::Both => (
162                        format!(
163                            "({e_alias}.source_id = {prev_node}.id OR {e_alias}.target_id = {prev_node}.id)"
164                        ),
165                        "CASE_BOTH",
166                    ),
167                };
168
169                let next_alias = format!("n{}", node_idx);
170
171                let next_join_col = if target_join == "CASE_BOTH" {
172                    format!(
173                        "CASE WHEN {e_alias}.source_id = {prev_node}.id THEN {e_alias}.target_id ELSE {e_alias}.source_id END"
174                    )
175                } else {
176                    format!("{e_alias}.{target_join}")
177                };
178
179                join_parts.push(format!("JOIN graph_edges {e_alias} ON {source_join}"));
180
181                let ens_filter = namespace_filter(&e_alias, opts, &mut params);
182                if !ens_filter.is_empty() {
183                    where_parts.push(ens_filter.trim_start_matches(" AND ").to_string());
184                }
185
186                join_parts.push(format!(
187                    "JOIN entities {next_alias} ON {next_alias}.id = {next_join_col}"
188                ));
189
190                if !ep.relations.is_empty() {
191                    if ep.relations.len() == 1 {
192                        params.push(SqlValue::Text(ep.relations[0].clone()));
193                        where_parts.push(format!("{e_alias}.relation = ?{}", params.len()));
194                    } else {
195                        let placeholders: Vec<String> = ep
196                            .relations
197                            .iter()
198                            .map(|r| {
199                                params.push(SqlValue::Text(r.clone()));
200                                format!("?{}", params.len())
201                            })
202                            .collect();
203                        where_parts.push(format!(
204                            "{e_alias}.relation IN ({})",
205                            placeholders.join(", ")
206                        ));
207                    }
208                }
209
210                if let Some(ref var) = ep.variable {
211                    var_to_alias.insert(var.clone(), (e_alias.clone(), VarKind::Edge));
212                }
213
214                edge_idx += 1;
215            }
216        }
217    }
218
219    // WHERE clause conditions from GQL WHERE
220    for cond in &query.where_clause {
221        let (alias, kind) = var_to_alias.get(&cond.variable).ok_or_else(|| {
222            QueryError::Compile(format!(
223                "unknown variable '{}' in WHERE clause",
224                cond.variable
225            ))
226        })?;
227
228        let col_expr = match kind {
229            VarKind::Node => {
230                if cond.property == "name"
231                    || cond.property == "kind"
232                    || cond.property == "namespace"
233                {
234                    format!("{alias}.{}", cond.property)
235                } else {
236                    format!(
237                        "json_extract({alias}.properties, '$.{}')",
238                        cond.property.replace('\'', "''")
239                    )
240                }
241            }
242            VarKind::Edge => {
243                // MAJ-1: edge property whitelist — only relation and weight are queryable
244                match cond.property.as_str() {
245                    "relation" | "weight" => format!("{alias}.{}", cond.property),
246                    other => {
247                        return Err(QueryError::Validation(format!(
248                            "edge property '{other}' not queryable; use 'relation' or 'weight'"
249                        )))
250                    }
251                }
252            }
253        };
254
255        let op_str = match cond.op {
256            CompareOp::Eq => "=",
257            CompareOp::Neq => "!=",
258            CompareOp::Gt => ">",
259            CompareOp::Lt => "<",
260            CompareOp::Gte => ">=",
261            CompareOp::Lte => "<=",
262            CompareOp::Like => "LIKE",
263        };
264
265        match &cond.value {
266            ConditionValue::String(s) => {
267                params.push(SqlValue::Text(s.clone()));
268                let collate = if matches!(cond.op, CompareOp::Eq | CompareOp::Like) {
269                    " COLLATE NOCASE"
270                } else {
271                    ""
272                };
273                where_parts.push(format!("{col_expr} {op_str} ?{}{}", params.len(), collate));
274            }
275            ConditionValue::Number(n) => {
276                params.push(SqlValue::Float(*n));
277                where_parts.push(format!("{col_expr} {op_str} ?{}", params.len()));
278            }
279            ConditionValue::Bool(b) => {
280                params.push(SqlValue::Integer(if *b { 1 } else { 0 }));
281                where_parts.push(format!("{col_expr} {op_str} ?{}", params.len()));
282            }
283        }
284    }
285
286    // SELECT clause
287    for var in &query.return_items {
288        if let Some((alias, kind)) = var_to_alias.get(var) {
289            match kind {
290                VarKind::Node => {
291                    select_parts.push(format!(
292                        "{alias}.id AS {var}_id, {alias}.namespace AS {var}_namespace, \
293                         {alias}.kind AS {var}_kind, {alias}.name AS {var}_name, \
294                         {alias}.properties AS {var}_properties, \
295                         {alias}.created_at AS {var}_created_at, \
296                         {alias}.updated_at AS {var}_updated_at"
297                    ));
298                }
299                VarKind::Edge => {
300                    select_parts.push(format!(
301                        "{alias}.id AS {var}_id, {alias}.source_id AS {var}_source, \
302                         {alias}.target_id AS {var}_target, \
303                         {alias}.relation AS {var}_relation, \
304                         {alias}.weight AS {var}_weight"
305                    ));
306                }
307            }
308        } else {
309            return Err(QueryError::Compile(format!(
310                "unknown variable '{var}' in RETURN clause"
311            )));
312        }
313    }
314
315    let limit = query.limit.unwrap_or(opts.max_limit).min(opts.max_limit);
316    params.push(SqlValue::Integer(limit as i64));
317
318    let sql = format!(
319        "SELECT {} FROM {} {} WHERE {} LIMIT ?{}",
320        select_parts.join(", "),
321        from_parts.join(", "),
322        join_parts.join(" "),
323        where_parts.join(" AND "),
324        params.len(),
325    );
326
327    Ok(CompiledQuery {
328        sql,
329        params,
330        return_vars: query.return_items.clone(),
331    })
332}
333
334/// Compile variable-length patterns to a recursive CTE.
335///
336/// Depth is capped at min(requested, 10) — MAJ-2 (parameterized min_depth, not literal).
337fn compile_variable_length(
338    query: &GqlQuery,
339    opts: &CompileOptions,
340) -> Result<CompiledQuery, QueryError> {
341    let mut params: Vec<SqlValue> = Vec::new();
342    let mut var_to_alias: std::collections::HashMap<String, (String, VarKind)> =
343        std::collections::HashMap::new();
344
345    // For variable-length, we expect exactly: start_node -[*N..M]-> end_node.
346    // Mixed fixed+variable chains and additional trailing pattern elements are
347    // not yet supported — reject explicitly rather than silently dropping them.
348    let nodes: Vec<&NodePattern> = query.pattern.nodes().collect();
349    let edges: Vec<&EdgePattern> = query.pattern.edges().collect();
350
351    if nodes.len() != 2 || edges.len() != 1 || query.pattern.elements.len() != 3 {
352        return Err(QueryError::Unsupported(
353            "variable-length patterns must be a single start_node -[*N..M]-> end_node \
354             (mixed fixed/variable chains are not yet implemented)"
355                .into(),
356        ));
357    }
358
359    let start = &nodes[0];
360    let edge = &edges[0];
361    let end = &nodes[1];
362
363    // MAJ-2: depth cap — always parameterized, never injected as literal
364    let max_depth = edge.max_hops.min(10);
365    let min_depth = edge.min_hops;
366
367    // Build start-node conditions
368    let mut start_conditions: Vec<String> = vec!["s.deleted_at IS NULL".to_string()];
369    let ns_filter = namespace_filter("s", opts, &mut params);
370    if !ns_filter.is_empty() {
371        start_conditions.push(ns_filter.trim_start_matches(" AND ").to_string());
372    }
373
374    if let Some(ref kind) = start.kind {
375        params.push(SqlValue::Text(kind.clone()));
376        start_conditions.push(format!("s.kind = ?{}", params.len()));
377    }
378    for (key, val) in &start.properties {
379        params.push(SqlValue::Text(val.clone()));
380        if key == "name" {
381            start_conditions.push(format!("s.name = ?{} COLLATE NOCASE", params.len()));
382        } else {
383            start_conditions.push(format!(
384                "json_extract(s.properties, '$.{}') = ?{} COLLATE NOCASE",
385                key.replace('\'', "''"),
386                params.len()
387            ));
388        }
389    }
390
391    // Relation filter
392    let mut relation_condition = String::new();
393    if !edge.relations.is_empty() {
394        if edge.relations.len() == 1 {
395            params.push(SqlValue::Text(edge.relations[0].clone()));
396            relation_condition = format!(" AND e.relation = ?{}", params.len());
397        } else {
398            let placeholders: Vec<String> = edge
399                .relations
400                .iter()
401                .map(|r| {
402                    params.push(SqlValue::Text(r.clone()));
403                    format!("?{}", params.len())
404                })
405                .collect();
406            relation_condition = format!(" AND e.relation IN ({})", placeholders.join(", "));
407        }
408    }
409
410    // Edge namespace filter
411    let e_ns_filter = namespace_filter("e", opts, &mut params);
412
413    // Direction-dependent JOIN
414    let (seed_join, seed_next, recurse_join, recurse_next) = match edge.direction {
415        EdgeDirection::Out => (
416            "e.source_id = s.id",
417            "e.target_id",
418            "e.source_id = t.current_id",
419            "e.target_id",
420        ),
421        EdgeDirection::In => (
422            "e.target_id = s.id",
423            "e.source_id",
424            "e.target_id = t.current_id",
425            "e.source_id",
426        ),
427        EdgeDirection::Both => (
428            "(e.source_id = s.id OR e.target_id = s.id)",
429            "CASE WHEN e.source_id = s.id THEN e.target_id ELSE e.source_id END",
430            "(e.source_id = t.current_id OR e.target_id = t.current_id)",
431            "CASE WHEN e.source_id = t.current_id THEN e.target_id ELSE e.source_id END",
432        ),
433    };
434
435    params.push(SqlValue::Integer(max_depth as i64));
436    let depth_param = params.len();
437
438    // End-node conditions (applied in outer WHERE). `r` is always joined
439    // unconditionally below so these references resolve regardless of whether
440    // the end variable is projected.
441    let mut end_conditions: Vec<String> = vec!["r.deleted_at IS NULL".to_string()];
442    let r_ns_filter = namespace_filter("r", opts, &mut params);
443    if !r_ns_filter.is_empty() {
444        end_conditions.push(r_ns_filter.trim_start_matches(" AND ").to_string());
445    }
446    if let Some(ref kind) = end.kind {
447        params.push(SqlValue::Text(kind.clone()));
448        end_conditions.push(format!("r.kind = ?{}", params.len()));
449    }
450    for (key, val) in &end.properties {
451        params.push(SqlValue::Text(val.clone()));
452        if key == "name" {
453            end_conditions.push(format!("r.name = ?{} COLLATE NOCASE", params.len()));
454        } else {
455            end_conditions.push(format!(
456                "json_extract(r.properties, '$.{}') = ?{} COLLATE NOCASE",
457                key.replace('\'', "''"),
458                params.len()
459            ));
460        }
461    }
462
463    // WHERE clause conditions
464    for cond in &query.where_clause {
465        // Map variables to appropriate aliases
466        let col_alias = if start.variable.as_deref() == Some(&cond.variable) {
467            "s"
468        } else if end.variable.as_deref() == Some(&cond.variable) {
469            "r"
470        } else {
471            return Err(QueryError::Compile(format!(
472                "variable '{}' in WHERE not supported in variable-length pattern (only start/end node variables)",
473                cond.variable
474            )));
475        };
476
477        let col_expr = if cond.property == "name" || cond.property == "kind" {
478            format!("{col_alias}.{}", cond.property)
479        } else {
480            format!(
481                "json_extract({col_alias}.properties, '$.{}')",
482                cond.property.replace('\'', "''")
483            )
484        };
485
486        let op_str = match cond.op {
487            CompareOp::Eq => "=",
488            CompareOp::Neq => "!=",
489            CompareOp::Gt => ">",
490            CompareOp::Lt => "<",
491            CompareOp::Gte => ">=",
492            CompareOp::Lte => "<=",
493            CompareOp::Like => "LIKE",
494        };
495
496        match &cond.value {
497            ConditionValue::String(s) => {
498                params.push(SqlValue::Text(s.clone()));
499                let collate = if matches!(cond.op, CompareOp::Eq | CompareOp::Like) {
500                    " COLLATE NOCASE"
501                } else {
502                    ""
503                };
504                if col_alias == "s" {
505                    start_conditions
506                        .push(format!("{col_expr} {op_str} ?{}{collate}", params.len()));
507                } else {
508                    end_conditions.push(format!("{col_expr} {op_str} ?{}{collate}", params.len()));
509                }
510            }
511            ConditionValue::Number(n) => {
512                params.push(SqlValue::Float(*n));
513                if col_alias == "s" {
514                    start_conditions.push(format!("{col_expr} {op_str} ?{}", params.len()));
515                } else {
516                    end_conditions.push(format!("{col_expr} {op_str} ?{}", params.len()));
517                }
518            }
519            ConditionValue::Bool(b) => {
520                params.push(SqlValue::Integer(if *b { 1 } else { 0 }));
521                if col_alias == "s" {
522                    start_conditions.push(format!("{col_expr} {op_str} ?{}", params.len()));
523                } else {
524                    end_conditions.push(format!("{col_expr} {op_str} ?{}", params.len()));
525                }
526            }
527        }
528    }
529
530    // MAJ-2: min_depth is always a bound parameter, never a literal
531    if min_depth > 0 {
532        params.push(SqlValue::Integer(min_depth as i64));
533        end_conditions.push(format!("t.depth >= ?{}", params.len()));
534    }
535
536    let limit = query.limit.unwrap_or(opts.max_limit).min(opts.max_limit);
537    params.push(SqlValue::Integer(limit as i64));
538    let limit_param = params.len();
539
540    // Register variables
541    if let Some(ref var) = start.variable {
542        var_to_alias.insert(var.clone(), ("s".to_string(), VarKind::Node));
543    }
544    if let Some(ref var) = end.variable {
545        var_to_alias.insert(var.clone(), ("r".to_string(), VarKind::Node));
546    }
547    if let Some(ref var) = edge.variable {
548        var_to_alias.insert(var.clone(), ("e".to_string(), VarKind::Edge));
549    }
550
551    // Build SELECT based on RETURN items
552    let mut select_parts: Vec<String> = Vec::new();
553    let mut has_start = false;
554
555    for var in &query.return_items {
556        if let Some((_, kind)) = var_to_alias.get(var) {
557            match kind {
558                VarKind::Node => {
559                    if start.variable.as_deref() == Some(var.as_str()) {
560                        has_start = true;
561                        select_parts.push(format!(
562                            "s.id AS {var}_id, s.namespace AS {var}_namespace, \
563                             s.kind AS {var}_kind, s.name AS {var}_name, \
564                             s.properties AS {var}_properties, \
565                             s.created_at AS {var}_created_at, \
566                             s.updated_at AS {var}_updated_at"
567                        ));
568                    } else {
569                        select_parts.push(format!(
570                            "r.id AS {var}_id, r.namespace AS {var}_namespace, \
571                             r.kind AS {var}_kind, r.name AS {var}_name, \
572                             r.properties AS {var}_properties, \
573                             r.created_at AS {var}_created_at, \
574                             r.updated_at AS {var}_updated_at"
575                        ));
576                    }
577                }
578                VarKind::Edge => {
579                    select_parts.push(format!(
580                        "t.via_edge AS {var}_id, t.via_relation AS {var}_relation, \
581                         t.via_weight AS {var}_weight"
582                    ));
583                }
584            }
585        } else {
586            return Err(QueryError::Compile(format!(
587                "unknown variable '{var}' in RETURN clause"
588            )));
589        }
590    }
591
592    // Always include traversal metadata
593    select_parts.push("t.depth AS _depth".to_string());
594    select_parts.push("t.total_weight AS _total_weight".to_string());
595
596    // `s` is optional (only joined if the start variable is projected); `r` is
597    // always joined because the outer WHERE always references `r.deleted_at`,
598    // `r.namespace` (and possibly r.kind / r.properties) regardless of whether
599    // it appears in RETURN.
600    let join_start = if has_start {
601        "JOIN entities s ON s.id = t.start_id"
602    } else {
603        ""
604    };
605    let join_end = "JOIN entities r ON r.id = t.current_id";
606
607    let sql = format!(
608        "WITH RECURSIVE traverse(start_id, current_id, depth, path, total_weight, via_edge, via_relation, via_weight) AS (\
609             SELECT s.id, {seed_next}, 1, s.id || ',' || {seed_next}, e.weight, \
610                    e.id, e.relation, e.weight \
611             FROM entities s \
612             JOIN graph_edges e ON {seed_join}{e_ns_filter}{relation_condition} \
613             WHERE {start_where} \
614             UNION ALL \
615             SELECT t.start_id, {recurse_next}, t.depth + 1, \
616                    t.path || ',' || {recurse_next}, \
617                    t.total_weight + e.weight, \
618                    e.id, e.relation, e.weight \
619             FROM traverse t \
620             JOIN graph_edges e ON {recurse_join}{e_ns_filter}{relation_condition} \
621             WHERE t.depth < ?{depth_param} \
622               AND (',' || t.path || ',') NOT LIKE '%,' || {recurse_next} || ',%' \
623         ) \
624         SELECT DISTINCT {select_cols} \
625         FROM traverse t \
626         {join_start} {join_end} \
627         WHERE {end_where} \
628         ORDER BY t.depth, t.total_weight DESC \
629         LIMIT ?{limit_param}",
630        seed_next = seed_next,
631        seed_join = seed_join,
632        e_ns_filter = e_ns_filter,
633        relation_condition = relation_condition,
634        start_where = start_conditions.join(" AND "),
635        recurse_next = recurse_next,
636        recurse_join = recurse_join,
637        depth_param = depth_param,
638        select_cols = select_parts.join(", "),
639        join_start = join_start,
640        join_end = join_end,
641        end_where = end_conditions.join(" AND "),
642        limit_param = limit_param,
643    );
644
645    Ok(CompiledQuery {
646        sql,
647        params,
648        return_vars: query.return_items.clone(),
649    })
650}
651
/// What a pattern variable is bound to, which decides the columns it can expose.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VarKind {
    /// Bound to a node pattern; compiled against an alias of the `entities` table.
    Node,
    /// Bound to an edge pattern; compiled against `graph_edges` columns.
    Edge,
}
657
658#[cfg(test)]
659mod tests {
660    use super::*;
661    use crate::parsers::gql;
662
663    fn opts() -> CompileOptions {
664        CompileOptions::default()
665    }
666
667    fn scoped(namespace: &str) -> CompileOptions {
668        CompileOptions {
669            scopes: vec![namespace.to_string()],
670            max_limit: 500,
671        }
672    }
673
674    #[test]
675    fn fixed_length_basic() {
676        let q =
677            gql::parse("MATCH (a:concept)-[e:introduced_by]->(b:paper) RETURN a, e, b LIMIT 10")
678                .unwrap();
679        let compiled = compile(&q, &opts()).unwrap();
680        assert!(compiled.sql.contains("JOIN graph_edges"));
681        assert!(compiled.sql.contains("LIMIT"));
682        assert_eq!(compiled.return_vars, vec!["a", "e", "b"]);
683        // No recursive CTE for fixed-length
684        assert!(!compiled.sql.contains("WITH RECURSIVE"));
685    }
686
687    #[test]
688    fn namespace_scoping_injected() {
689        // Namespace must come from opts, never from the query
690        let q =
691            gql::parse("MATCH (a:concept)-[e:introduced_by]->(b:paper) RETURN a LIMIT 5").unwrap();
692        let compiled = compile(&q, &scoped("research")).unwrap();
693        assert!(compiled.sql.contains("namespace"));
694        // The namespace value must appear as a parameter, not a literal in SQL
695        let has_ns_param = compiled
696            .params
697            .iter()
698            .any(|p| matches!(p, SqlValue::Text(s) if s == "research"));
699        assert!(has_ns_param, "namespace must be a bound parameter");
700    }
701
702    #[test]
703    fn edge_property_whitelist_rejects_unknown() {
704        // MAJ-1: only 'relation' and 'weight' are queryable edge properties
705        let q = gql::parse("MATCH (a)-[e:introduced_by]->(b) WHERE e.source_id = 'x' RETURN a")
706            .unwrap();
707        let result = compile(&q, &opts());
708        assert!(result.is_err());
709        let err = result.unwrap_err().to_string();
710        assert!(
711            err.contains("source_id") || err.contains("not queryable"),
712            "error: {err}"
713        );
714    }
715
716    #[test]
717    fn edge_property_relation_allowed() {
718        let q = gql::parse("MATCH (a)-[e]->(b) WHERE e.relation = 'extends' RETURN a").unwrap();
719        let result = compile(&q, &opts());
720        assert!(
721            result.is_ok(),
722            "relation should be allowed: {:?}",
723            result.err()
724        );
725    }
726
727    #[test]
728    fn edge_property_weight_allowed() {
729        let q = gql::parse("MATCH (a)-[e]->(b) WHERE e.weight > 0.5 RETURN a").unwrap();
730        let result = compile(&q, &opts());
731        assert!(
732            result.is_ok(),
733            "weight should be allowed: {:?}",
734            result.err()
735        );
736    }
737
738    #[test]
739    fn variable_length_uses_cte() {
740        let q =
741            gql::parse("MATCH (a {name: 'LoRA'})-[:extends*1..3]->(b) RETURN b LIMIT 20").unwrap();
742        let compiled = compile(&q, &opts()).unwrap();
743        assert!(compiled.sql.contains("WITH RECURSIVE"));
744        assert!(compiled.sql.contains("traverse"));
745    }
746
747    #[test]
748    fn depth_cap_at_ten() {
749        // MAJ-2: depth capped at 10 regardless of query request
750        let q = gql::parse("MATCH (a)-[:extends*1..50]->(b) RETURN b").unwrap();
751        let compiled = compile(&q, &opts()).unwrap();
752        // The depth parameter must be <= 10
753        let depth_val = compiled.params.iter().find_map(|p| {
754            if let SqlValue::Integer(n) = p {
755                Some(*n)
756            } else {
757                None
758            }
759        });
760        assert!(depth_val.unwrap() <= 10, "depth must be capped at 10");
761    }
762
763    #[test]
764    fn limit_capped_by_max_limit() {
765        // Query requests 1000, max_limit is 500 — result should be 500
766        let q = gql::parse("MATCH (a:concept)-[e]->(b) RETURN a LIMIT 1000").unwrap();
767        let compiled = compile(&q, &opts()).unwrap();
768        let limit_param = compiled.params.last().unwrap();
769        assert!(
770            matches!(limit_param, SqlValue::Integer(500)),
771            "expected Integer(500), got {limit_param:?}"
772        );
773    }
774
775    #[test]
776    fn compile_rejects_unknown_relation() {
777        let q = gql::parse("MATCH (a)-[:not_a_relation]->(b) RETURN a").unwrap();
778        let err = compile(&q, &opts()).unwrap_err();
779        let msg = err.to_string();
780        assert!(msg.contains("not_a_relation"), "msg: {msg}");
781    }
782
783    #[test]
784    fn compile_rejects_unknown_kind() {
785        let q = gql::parse("MATCH (a:gizmo)-[:extends]->(b) RETURN a").unwrap();
786        let err = compile(&q, &opts()).unwrap_err();
787        let msg = err.to_string();
788        assert!(msg.contains("gizmo"), "msg: {msg}");
789    }
790
#[test]
fn compile_normalises_kind_aliases_to_canonical() {
    // `paper` is an alias of the canonical `document` kind. The bound SQL
    // parameters must carry `document` (the stored form) and never the raw
    // alias.
    let parsed =
        gql::parse("MATCH (a:paper)-[:introduced_by]->(b:concept) RETURN a LIMIT 1").unwrap();
    let compiled = compile(&parsed, &opts()).unwrap();

    // Single pass over the text parameters, tracking both spellings.
    let mut has_document = false;
    let mut has_paper = false;
    for param in &compiled.params {
        if let SqlValue::Text(text) = param {
            has_document |= text == "document";
            has_paper |= text == "paper";
        }
    }

    assert!(has_document, "expected canonical 'document' in params");
    assert!(
        !has_paper,
        "raw alias 'paper' must not leak into SQL params"
    );
}
812
#[test]
fn compile_rejects_namespace_in_where() {
    // Filtering on `namespace` from inside the query text must be refused;
    // the error has to mention the property name.
    let parsed =
        gql::parse("MATCH (a:concept)-[:extends]->(b) WHERE a.namespace = 'other' RETURN a")
            .unwrap();
    let err = compile(&parsed, &opts()).unwrap_err();
    let rendered = err.to_string();
    assert!(rendered.contains("namespace"), "msg: {err}");
}
821
#[test]
fn compile_rejects_unknown_relation_in_where() {
    // A relation literal compared in WHERE goes through the same check as an
    // edge label: `related_to` must be rejected and named in the error.
    let parsed =
        gql::parse("MATCH (a)-[e:extends]->(b) WHERE e.relation = 'related_to' RETURN a").unwrap();
    let err = compile(&parsed, &opts()).unwrap_err();
    let rendered = err.to_string();
    assert!(rendered.contains("related_to"), "msg: {err}");
}
829
#[test]
fn compile_normalises_kind_alias_in_where_param() {
    // A kind alias used as a WHERE literal must be rewritten to its
    // canonical name before it is bound as a SQL parameter.
    let parsed = gql::parse("MATCH (a)-[:extends]->(b) WHERE a.kind = 'paper' RETURN a").unwrap();
    let compiled = compile(&parsed, &opts()).unwrap();

    // True when any text bind parameter equals `needle`.
    let binds_text = |needle: &str| {
        compiled
            .params
            .iter()
            .any(|param| matches!(param, SqlValue::Text(text) if text == needle))
    };

    let has_document = binds_text("document");
    let has_paper = binds_text("paper");
    assert!(
        has_document,
        "WHERE a.kind = 'paper' must normalise to 'document'"
    );
    assert!(!has_paper, "raw 'paper' must not leak into SQL params");
}
848
#[test]
fn variable_length_return_start_only_joins_end_entity() {
    // Projecting only the start variable must not drop the join on the end
    // entity: the outer query still filters on `r.deleted_at` /
    // `r.namespace`, so `entities r` has to be joined unconditionally.
    let parsed = gql::parse("MATCH (a:concept)-[:extends*1..3]->(b) RETURN a LIMIT 10").unwrap();
    let compiled = compile(&parsed, &opts()).unwrap();

    let sql = &compiled.sql;
    let joins_end_entity = sql.contains("JOIN entities r");
    assert!(
        joins_end_entity,
        "entities r must always be joined when r.* conditions are emitted; sql: {}",
        sql
    );
}
862
#[test]
fn variable_length_trailing_pattern_unsupported() {
    // A variable-length edge followed by a further fixed edge must be
    // reported as Unsupported.
    let parsed =
        gql::parse("MATCH (a)-[:extends*1..3]->(b)-[:implements]->(c) RETURN b").unwrap();
    let err = compile(&parsed, &opts()).unwrap_err();

    let is_unsupported = matches!(err, QueryError::Unsupported(_));
    assert!(is_unsupported, "expected Unsupported, got {err:?}");
}
872
#[test]
fn variable_length_mixed_chain_unsupported() {
    // A fixed edge followed by a variable-length edge: the pattern contains a
    // variable-length edge, so compilation takes the variable-length path,
    // which is expected to reject chains with more than one edge.
    let parsed =
        gql::parse("MATCH (a)-[:extends]->(b)-[:implements*1..2]->(c) RETURN c").unwrap();
    let err = compile(&parsed, &opts()).unwrap_err();

    match &err {
        QueryError::Unsupported(_) => {}
        _ => panic!("got {err:?}"),
    }
}
881
#[test]
fn sparql_star_rejected_as_unsupported() {
    // The SPARQL `*` path operator must be refused by the SPARQL parser
    // itself with an Unsupported error.
    let err = crate::parsers::sparql::parse("SELECT ?a ?b WHERE { ?a :extends* ?b . }")
        .unwrap_err();

    let is_unsupported = matches!(err, QueryError::Unsupported(_));
    assert!(is_unsupported, "got {err:?}");
}
888}