Skip to main content

khive_query/compilers/
sql.rs

1//! Compile GQL AST to parameterized SQL.
2//!
3//! Two compilation paths:
4//! - Fixed-length patterns (all edges *1..1) → JOIN chain
5//! - Variable-length patterns (any edge *N..M where M>1) → recursive CTE
6//!
7//! Security invariants (MAJ-1/MAJ-2/MAJ-3 from critic review):
8//! - Namespace injection: WHERE clause always comes from CompileOptions.scopes, never the query.
9//! - Edge property whitelist: only `relation` and `weight` are queryable edge columns.
10//! - Depth cap: recursive CTE depth is min(requested, 10).
11
12use crate::ast::*;
13use crate::error::QueryError;
14use crate::validate::validate;
15use khive_storage::types::SqlValue;
16
17#[derive(Debug)]
18pub struct CompiledQuery {
19    pub sql: String,
20    pub params: Vec<SqlValue>,
21    pub return_vars: Vec<ReturnItem>,
22}
23
/// Server-side settings applied to every compiled query.
///
/// These values come from the caller, never from the query text, which is why
/// the namespace scoping they drive cannot be overridden by a query.
#[derive(Debug, Clone)]
pub struct CompileOptions {
    /// Namespace scope. Empty = cross-namespace (all). Non-empty = filter to these namespaces.
    pub scopes: Vec<String>,
    /// Hard limit cap (server-side safety). Query limit is min(requested, max_limit).
    pub max_limit: usize,
}
30
31impl Default for CompileOptions {
32    fn default() -> Self {
33        Self {
34            scopes: Vec::new(),
35            max_limit: 500,
36        }
37    }
38}
39
40pub fn compile(query: &GqlQuery, opts: &CompileOptions) -> Result<CompiledQuery, QueryError> {
41    if query.pattern.elements.is_empty() {
42        return Err(QueryError::Compile("empty pattern".into()));
43    }
44
45    // Validate edge relations + structural rules before emitting SQL.
46    let mut query = query.clone();
47    validate(&mut query)?;
48
49    if query.pattern.has_variable_length() {
50        compile_variable_length(&query, opts)
51    } else {
52        compile_fixed_length(&query, opts)
53    }
54}
55
56fn namespace_filter(alias: &str, opts: &CompileOptions, params: &mut Vec<SqlValue>) -> String {
57    if opts.scopes.is_empty() {
58        String::new()
59    } else if opts.scopes.len() == 1 {
60        params.push(SqlValue::Text(opts.scopes[0].clone()));
61        format!(" AND {alias}.namespace = ?{}", params.len())
62    } else {
63        let placeholders: Vec<String> = opts
64            .scopes
65            .iter()
66            .map(|s| {
67                params.push(SqlValue::Text(s.clone()));
68                format!("?{}", params.len())
69            })
70            .collect();
71        format!(" AND {alias}.namespace IN ({})", placeholders.join(", "))
72    }
73}
74
75/// Compile fixed-length patterns to a chain of JOINs.
76///
77/// MATCH (a:concept)-[e:introduced_by]->(b:paper) WHERE ... RETURN a, e, b LIMIT 10
78/// →
79/// SELECT a.*, e.*, b.*
80/// FROM entities a
81/// JOIN graph_edges e ON e.source_id = a.id
82/// JOIN entities b ON b.id = e.target_id
83/// WHERE a.kind = 'concept' AND e.relation = 'introduced_by' AND b.kind = 'paper'
84///   AND a.deleted_at IS NULL AND b.deleted_at IS NULL
85/// LIMIT 10
86fn compile_fixed_length(
87    query: &GqlQuery,
88    opts: &CompileOptions,
89) -> Result<CompiledQuery, QueryError> {
90    let mut params: Vec<SqlValue> = Vec::new();
91    let mut from_parts: Vec<String> = Vec::new();
92    let mut join_parts: Vec<String> = Vec::new();
93    let mut where_parts: Vec<String> = Vec::new();
94    let mut select_parts: Vec<String> = Vec::new();
95
96    let mut node_aliases: Vec<String> = Vec::new();
97    let mut edge_aliases: Vec<String> = Vec::new();
98    let mut var_to_alias: std::collections::HashMap<String, (String, VarKind)> =
99        std::collections::HashMap::new();
100
101    let mut node_idx = 0usize;
102    let mut edge_idx = 0usize;
103
104    for element in &query.pattern.elements {
105        match element {
106            PatternElement::Node(np) => {
107                let alias = format!("n{node_idx}");
108                node_aliases.push(alias.clone());
109
110                if node_idx == 0 {
111                    from_parts.push(format!("entities {alias}"));
112                }
113
114                where_parts.push(format!("{alias}.deleted_at IS NULL"));
115
116                let ns_filter = namespace_filter(&alias, opts, &mut params);
117                if !ns_filter.is_empty() {
118                    where_parts.push(ns_filter.trim_start_matches(" AND ").to_string());
119                }
120
121                if let Some(ref kind) = np.kind {
122                    params.push(SqlValue::Text(kind.clone()));
123                    where_parts.push(format!("{alias}.kind = ?{}", params.len()));
124                }
125
126                for (key, val) in &np.properties {
127                    params.push(SqlValue::Text(val.clone()));
128                    if key == "name" {
129                        where_parts
130                            .push(format!("{alias}.name = ?{} COLLATE NOCASE", params.len()));
131                    } else {
132                        where_parts.push(format!(
133                            "json_extract({alias}.properties, '$.{}') = ?{} COLLATE NOCASE",
134                            key.replace('\'', "''"),
135                            params.len()
136                        ));
137                    }
138                }
139
140                if let Some(ref var) = np.variable {
141                    var_to_alias.insert(var.clone(), (alias.clone(), VarKind::Node));
142                }
143
144                node_idx += 1;
145            }
146            PatternElement::Edge(ep) => {
147                let e_alias = format!("e{edge_idx}");
148                let prev_node = &node_aliases[node_aliases.len() - 1];
149
150                edge_aliases.push(e_alias.clone());
151
152                let (source_join, target_join) = match ep.direction {
153                    EdgeDirection::Out => (
154                        format!("{e_alias}.source_id = {prev_node}.id"),
155                        "target_id",
156                    ),
157                    EdgeDirection::In => (
158                        format!("{e_alias}.target_id = {prev_node}.id"),
159                        "source_id",
160                    ),
161                    EdgeDirection::Both => (
162                        format!(
163                            "({e_alias}.source_id = {prev_node}.id OR {e_alias}.target_id = {prev_node}.id)"
164                        ),
165                        "CASE_BOTH",
166                    ),
167                };
168
169                let next_alias = format!("n{}", node_idx);
170
171                let next_join_col = if target_join == "CASE_BOTH" {
172                    format!(
173                        "CASE WHEN {e_alias}.source_id = {prev_node}.id THEN {e_alias}.target_id ELSE {e_alias}.source_id END"
174                    )
175                } else {
176                    format!("{e_alias}.{target_join}")
177                };
178
179                join_parts.push(format!("JOIN graph_edges {e_alias} ON {source_join}"));
180
181                let ens_filter = namespace_filter(&e_alias, opts, &mut params);
182                if !ens_filter.is_empty() {
183                    where_parts.push(ens_filter.trim_start_matches(" AND ").to_string());
184                }
185
186                join_parts.push(format!(
187                    "JOIN entities {next_alias} ON {next_alias}.id = {next_join_col}"
188                ));
189
190                if !ep.relations.is_empty() {
191                    if ep.relations.len() == 1 {
192                        params.push(SqlValue::Text(ep.relations[0].clone()));
193                        where_parts.push(format!("{e_alias}.relation = ?{}", params.len()));
194                    } else {
195                        let placeholders: Vec<String> = ep
196                            .relations
197                            .iter()
198                            .map(|r| {
199                                params.push(SqlValue::Text(r.clone()));
200                                format!("?{}", params.len())
201                            })
202                            .collect();
203                        where_parts.push(format!(
204                            "{e_alias}.relation IN ({})",
205                            placeholders.join(", ")
206                        ));
207                    }
208                }
209
210                if let Some(ref var) = ep.variable {
211                    var_to_alias.insert(var.clone(), (e_alias.clone(), VarKind::Edge));
212                }
213
214                edge_idx += 1;
215            }
216        }
217    }
218
219    // WHERE clause conditions from GQL WHERE
220    for cond in &query.where_clause {
221        let (alias, kind) = var_to_alias.get(&cond.variable).ok_or_else(|| {
222            QueryError::Compile(format!(
223                "unknown variable '{}' in WHERE clause",
224                cond.variable
225            ))
226        })?;
227
228        let col_expr = match kind {
229            VarKind::Node => {
230                if cond.property == "name"
231                    || cond.property == "kind"
232                    || cond.property == "namespace"
233                {
234                    format!("{alias}.{}", cond.property)
235                } else {
236                    format!(
237                        "json_extract({alias}.properties, '$.{}')",
238                        cond.property.replace('\'', "''")
239                    )
240                }
241            }
242            VarKind::Edge => {
243                // MAJ-1: edge property whitelist — only relation and weight are queryable
244                match cond.property.as_str() {
245                    "relation" | "weight" => format!("{alias}.{}", cond.property),
246                    other => {
247                        return Err(QueryError::Validation(format!(
248                            "edge property '{other}' not queryable; use 'relation' or 'weight'"
249                        )))
250                    }
251                }
252            }
253        };
254
255        let op_str = match cond.op {
256            CompareOp::Eq => "=",
257            CompareOp::Neq => "!=",
258            CompareOp::Gt => ">",
259            CompareOp::Lt => "<",
260            CompareOp::Gte => ">=",
261            CompareOp::Lte => "<=",
262            CompareOp::Like => "LIKE",
263        };
264
265        match &cond.value {
266            ConditionValue::String(s) => {
267                params.push(SqlValue::Text(s.clone()));
268                let collate = if matches!(cond.op, CompareOp::Eq | CompareOp::Like) {
269                    " COLLATE NOCASE"
270                } else {
271                    ""
272                };
273                where_parts.push(format!("{col_expr} {op_str} ?{}{}", params.len(), collate));
274            }
275            ConditionValue::Number(n) => {
276                params.push(SqlValue::Float(*n));
277                where_parts.push(format!("{col_expr} {op_str} ?{}", params.len()));
278            }
279            ConditionValue::Bool(b) => {
280                params.push(SqlValue::Integer(if *b { 1 } else { 0 }));
281                where_parts.push(format!("{col_expr} {op_str} ?{}", params.len()));
282            }
283        }
284    }
285
286    // SELECT clause
287    for item in &query.return_items {
288        let var = item.variable();
289        if let Some((alias, kind)) = var_to_alias.get(var) {
290            match item {
291                ReturnItem::Property(_, prop) => {
292                    let col = property_to_column(prop, kind)?;
293                    select_parts.push(format!("{alias}.{col} AS {var}_{prop}"));
294                }
295                ReturnItem::Variable(_) => match kind {
296                    VarKind::Node => {
297                        select_parts.push(format!(
298                            "{alias}.id AS {var}_id, {alias}.namespace AS {var}_namespace, \
299                             {alias}.kind AS {var}_kind, {alias}.name AS {var}_name, \
300                             {alias}.properties AS {var}_properties, \
301                             {alias}.created_at AS {var}_created_at, \
302                             {alias}.updated_at AS {var}_updated_at"
303                        ));
304                    }
305                    VarKind::Edge => {
306                        select_parts.push(format!(
307                            "{alias}.id AS {var}_id, {alias}.source_id AS {var}_source, \
308                             {alias}.target_id AS {var}_target, \
309                             {alias}.relation AS {var}_relation, \
310                             {alias}.weight AS {var}_weight"
311                        ));
312                    }
313                },
314            }
315        } else {
316            return Err(QueryError::Compile(format!(
317                "unknown variable '{var}' in RETURN clause"
318            )));
319        }
320    }
321
322    let limit = query.limit.unwrap_or(opts.max_limit).min(opts.max_limit);
323    params.push(SqlValue::Integer(limit as i64));
324
325    let sql = format!(
326        "SELECT {} FROM {} {} WHERE {} LIMIT ?{}",
327        select_parts.join(", "),
328        from_parts.join(", "),
329        join_parts.join(" "),
330        where_parts.join(" AND "),
331        params.len(),
332    );
333
334    Ok(CompiledQuery {
335        sql,
336        params,
337        return_vars: query.return_items.clone(),
338    })
339}
340
341/// Compile variable-length patterns to a recursive CTE.
342///
343/// Depth is capped at min(requested, 10) — MAJ-2 (parameterized min_depth, not literal).
344fn compile_variable_length(
345    query: &GqlQuery,
346    opts: &CompileOptions,
347) -> Result<CompiledQuery, QueryError> {
348    let mut params: Vec<SqlValue> = Vec::new();
349    let mut var_to_alias: std::collections::HashMap<String, (String, VarKind)> =
350        std::collections::HashMap::new();
351
352    // For variable-length, we expect exactly: start_node -[*N..M]-> end_node.
353    // Mixed fixed+variable chains and additional trailing pattern elements are
354    // not yet supported — reject explicitly rather than silently dropping them.
355    let nodes: Vec<&NodePattern> = query.pattern.nodes().collect();
356    let edges: Vec<&EdgePattern> = query.pattern.edges().collect();
357
358    if nodes.len() != 2 || edges.len() != 1 || query.pattern.elements.len() != 3 {
359        return Err(QueryError::Unsupported(
360            "variable-length patterns must be a single start_node -[*N..M]-> end_node \
361             (mixed fixed/variable chains are not yet implemented)"
362                .into(),
363        ));
364    }
365
366    let start = &nodes[0];
367    let edge = &edges[0];
368    let end = &nodes[1];
369
370    // MAJ-2: depth cap — always parameterized, never injected as literal
371    let max_depth = edge.max_hops.min(10);
372    let min_depth = edge.min_hops;
373
374    // Build start-node conditions
375    let mut start_conditions: Vec<String> = vec!["s.deleted_at IS NULL".to_string()];
376    let ns_filter = namespace_filter("s", opts, &mut params);
377    if !ns_filter.is_empty() {
378        start_conditions.push(ns_filter.trim_start_matches(" AND ").to_string());
379    }
380
381    if let Some(ref kind) = start.kind {
382        params.push(SqlValue::Text(kind.clone()));
383        start_conditions.push(format!("s.kind = ?{}", params.len()));
384    }
385    for (key, val) in &start.properties {
386        params.push(SqlValue::Text(val.clone()));
387        if key == "name" {
388            start_conditions.push(format!("s.name = ?{} COLLATE NOCASE", params.len()));
389        } else {
390            start_conditions.push(format!(
391                "json_extract(s.properties, '$.{}') = ?{} COLLATE NOCASE",
392                key.replace('\'', "''"),
393                params.len()
394            ));
395        }
396    }
397
398    // Relation filter
399    let mut relation_condition = String::new();
400    if !edge.relations.is_empty() {
401        if edge.relations.len() == 1 {
402            params.push(SqlValue::Text(edge.relations[0].clone()));
403            relation_condition = format!(" AND e.relation = ?{}", params.len());
404        } else {
405            let placeholders: Vec<String> = edge
406                .relations
407                .iter()
408                .map(|r| {
409                    params.push(SqlValue::Text(r.clone()));
410                    format!("?{}", params.len())
411                })
412                .collect();
413            relation_condition = format!(" AND e.relation IN ({})", placeholders.join(", "));
414        }
415    }
416
417    // Edge namespace filter
418    let e_ns_filter = namespace_filter("e", opts, &mut params);
419
420    // Direction-dependent JOIN
421    let (seed_join, seed_next, recurse_join, recurse_next) = match edge.direction {
422        EdgeDirection::Out => (
423            "e.source_id = s.id",
424            "e.target_id",
425            "e.source_id = t.current_id",
426            "e.target_id",
427        ),
428        EdgeDirection::In => (
429            "e.target_id = s.id",
430            "e.source_id",
431            "e.target_id = t.current_id",
432            "e.source_id",
433        ),
434        EdgeDirection::Both => (
435            "(e.source_id = s.id OR e.target_id = s.id)",
436            "CASE WHEN e.source_id = s.id THEN e.target_id ELSE e.source_id END",
437            "(e.source_id = t.current_id OR e.target_id = t.current_id)",
438            "CASE WHEN e.source_id = t.current_id THEN e.target_id ELSE e.source_id END",
439        ),
440    };
441
442    params.push(SqlValue::Integer(max_depth as i64));
443    let depth_param = params.len();
444
445    // End-node conditions (applied in outer WHERE). `r` is always joined
446    // unconditionally below so these references resolve regardless of whether
447    // the end variable is projected.
448    let mut end_conditions: Vec<String> = vec!["r.deleted_at IS NULL".to_string()];
449    let r_ns_filter = namespace_filter("r", opts, &mut params);
450    if !r_ns_filter.is_empty() {
451        end_conditions.push(r_ns_filter.trim_start_matches(" AND ").to_string());
452    }
453    if let Some(ref kind) = end.kind {
454        params.push(SqlValue::Text(kind.clone()));
455        end_conditions.push(format!("r.kind = ?{}", params.len()));
456    }
457    for (key, val) in &end.properties {
458        params.push(SqlValue::Text(val.clone()));
459        if key == "name" {
460            end_conditions.push(format!("r.name = ?{} COLLATE NOCASE", params.len()));
461        } else {
462            end_conditions.push(format!(
463                "json_extract(r.properties, '$.{}') = ?{} COLLATE NOCASE",
464                key.replace('\'', "''"),
465                params.len()
466            ));
467        }
468    }
469
470    // WHERE clause conditions
471    for cond in &query.where_clause {
472        // Map variables to appropriate aliases
473        let col_alias = if start.variable.as_deref() == Some(&cond.variable) {
474            "s"
475        } else if end.variable.as_deref() == Some(&cond.variable) {
476            "r"
477        } else {
478            return Err(QueryError::Compile(format!(
479                "variable '{}' in WHERE not supported in variable-length pattern (only start/end node variables)",
480                cond.variable
481            )));
482        };
483
484        let col_expr = if cond.property == "name" || cond.property == "kind" {
485            format!("{col_alias}.{}", cond.property)
486        } else {
487            format!(
488                "json_extract({col_alias}.properties, '$.{}')",
489                cond.property.replace('\'', "''")
490            )
491        };
492
493        let op_str = match cond.op {
494            CompareOp::Eq => "=",
495            CompareOp::Neq => "!=",
496            CompareOp::Gt => ">",
497            CompareOp::Lt => "<",
498            CompareOp::Gte => ">=",
499            CompareOp::Lte => "<=",
500            CompareOp::Like => "LIKE",
501        };
502
503        match &cond.value {
504            ConditionValue::String(s) => {
505                params.push(SqlValue::Text(s.clone()));
506                let collate = if matches!(cond.op, CompareOp::Eq | CompareOp::Like) {
507                    " COLLATE NOCASE"
508                } else {
509                    ""
510                };
511                if col_alias == "s" {
512                    start_conditions
513                        .push(format!("{col_expr} {op_str} ?{}{collate}", params.len()));
514                } else {
515                    end_conditions.push(format!("{col_expr} {op_str} ?{}{collate}", params.len()));
516                }
517            }
518            ConditionValue::Number(n) => {
519                params.push(SqlValue::Float(*n));
520                if col_alias == "s" {
521                    start_conditions.push(format!("{col_expr} {op_str} ?{}", params.len()));
522                } else {
523                    end_conditions.push(format!("{col_expr} {op_str} ?{}", params.len()));
524                }
525            }
526            ConditionValue::Bool(b) => {
527                params.push(SqlValue::Integer(if *b { 1 } else { 0 }));
528                if col_alias == "s" {
529                    start_conditions.push(format!("{col_expr} {op_str} ?{}", params.len()));
530                } else {
531                    end_conditions.push(format!("{col_expr} {op_str} ?{}", params.len()));
532                }
533            }
534        }
535    }
536
537    // MAJ-2: min_depth is always a bound parameter, never a literal
538    if min_depth > 0 {
539        params.push(SqlValue::Integer(min_depth as i64));
540        end_conditions.push(format!("t.depth >= ?{}", params.len()));
541    }
542
543    let limit = query.limit.unwrap_or(opts.max_limit).min(opts.max_limit);
544    params.push(SqlValue::Integer(limit as i64));
545    let limit_param = params.len();
546
547    // Register variables
548    if let Some(ref var) = start.variable {
549        var_to_alias.insert(var.clone(), ("s".to_string(), VarKind::Node));
550    }
551    if let Some(ref var) = end.variable {
552        var_to_alias.insert(var.clone(), ("r".to_string(), VarKind::Node));
553    }
554    if let Some(ref var) = edge.variable {
555        var_to_alias.insert(var.clone(), ("e".to_string(), VarKind::Edge));
556    }
557
558    // Build SELECT based on RETURN items
559    let mut select_parts: Vec<String> = Vec::new();
560    let mut has_start = false;
561
562    for item in &query.return_items {
563        let var = item.variable();
564        if let Some((_, kind)) = var_to_alias.get(var) {
565            match item {
566                ReturnItem::Property(_, prop) => {
567                    let is_start = start.variable.as_deref() == Some(var);
568                    if *kind == VarKind::Node {
569                        let tbl = if is_start { "s" } else { "r" };
570                        if is_start {
571                            has_start = true;
572                        }
573                        let col = property_to_column(prop, kind)?;
574                        select_parts.push(format!("{tbl}.{col} AS {var}_{prop}"));
575                    } else {
576                        let col = match prop.as_str() {
577                            "id" => "via_edge",
578                            "relation" => "via_relation",
579                            "weight" => "via_weight",
580                            _ => {
581                                return Err(QueryError::Compile(format!(
582                                    "unknown edge property '{prop}' in RETURN projection. \
583                                     Valid: id, source_id, target_id, relation, weight"
584                                )));
585                            }
586                        };
587                        select_parts.push(format!("t.{col} AS {var}_{prop}"));
588                    }
589                }
590                ReturnItem::Variable(_) => match kind {
591                    VarKind::Node => {
592                        if start.variable.as_deref() == Some(var) {
593                            has_start = true;
594                            select_parts.push(format!(
595                                "s.id AS {var}_id, s.namespace AS {var}_namespace, \
596                                 s.kind AS {var}_kind, s.name AS {var}_name, \
597                                 s.properties AS {var}_properties, \
598                                 s.created_at AS {var}_created_at, \
599                                 s.updated_at AS {var}_updated_at"
600                            ));
601                        } else {
602                            select_parts.push(format!(
603                                "r.id AS {var}_id, r.namespace AS {var}_namespace, \
604                                 r.kind AS {var}_kind, r.name AS {var}_name, \
605                                 r.properties AS {var}_properties, \
606                                 r.created_at AS {var}_created_at, \
607                                 r.updated_at AS {var}_updated_at"
608                            ));
609                        }
610                    }
611                    VarKind::Edge => {
612                        select_parts.push(format!(
613                            "t.via_edge AS {var}_id, t.via_relation AS {var}_relation, \
614                             t.via_weight AS {var}_weight"
615                        ));
616                    }
617                },
618            }
619        } else {
620            return Err(QueryError::Compile(format!(
621                "unknown variable '{var}' in RETURN clause"
622            )));
623        }
624    }
625
626    // Always include traversal metadata
627    select_parts.push("t.depth AS _depth".to_string());
628    select_parts.push("t.total_weight AS _total_weight".to_string());
629
630    // `s` is optional (only joined if the start variable is projected); `r` is
631    // always joined because the outer WHERE always references `r.deleted_at`,
632    // `r.namespace` (and possibly r.kind / r.properties) regardless of whether
633    // it appears in RETURN.
634    let join_start = if has_start {
635        "JOIN entities s ON s.id = t.start_id"
636    } else {
637        ""
638    };
639    let join_end = "JOIN entities r ON r.id = t.current_id";
640
641    let sql = format!(
642        "WITH RECURSIVE traverse(start_id, current_id, depth, path, total_weight, via_edge, via_relation, via_weight) AS (\
643             SELECT s.id, {seed_next}, 1, s.id || ',' || {seed_next}, e.weight, \
644                    e.id, e.relation, e.weight \
645             FROM entities s \
646             JOIN graph_edges e ON {seed_join}{e_ns_filter}{relation_condition} \
647             WHERE {start_where} \
648             UNION ALL \
649             SELECT t.start_id, {recurse_next}, t.depth + 1, \
650                    t.path || ',' || {recurse_next}, \
651                    t.total_weight + e.weight, \
652                    e.id, e.relation, e.weight \
653             FROM traverse t \
654             JOIN graph_edges e ON {recurse_join}{e_ns_filter}{relation_condition} \
655             WHERE t.depth < ?{depth_param} \
656               AND (',' || t.path || ',') NOT LIKE '%,' || {recurse_next} || ',%' \
657         ) \
658         SELECT DISTINCT {select_cols} \
659         FROM traverse t \
660         {join_start} {join_end} \
661         WHERE {end_where} \
662         ORDER BY t.depth, t.total_weight DESC \
663         LIMIT ?{limit_param}",
664        seed_next = seed_next,
665        seed_join = seed_join,
666        e_ns_filter = e_ns_filter,
667        relation_condition = relation_condition,
668        start_where = start_conditions.join(" AND "),
669        recurse_next = recurse_next,
670        recurse_join = recurse_join,
671        depth_param = depth_param,
672        select_cols = select_parts.join(", "),
673        join_start = join_start,
674        join_end = join_end,
675        end_where = end_conditions.join(" AND "),
676        limit_param = limit_param,
677    );
678
679    Ok(CompiledQuery {
680        sql,
681        params,
682        return_vars: query.return_items.clone(),
683    })
684}
685
/// Whether a pattern variable is bound to a node or to an edge.
///
/// The kind selects which column whitelist and which projection a variable
/// gets in WHERE and RETURN.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VarKind {
    /// Variable bound to a node pattern (a row of `entities`).
    Node,
    /// Variable bound to an edge pattern (a row of `graph_edges`).
    Edge,
}
691
/// Node columns that may be projected with `RETURN var.column`.
///
/// The order is significant: it is the order in which the names are listed in
/// the "unknown property" error message.
const NODE_COLUMNS: &[&str] = &[
    "id", "name", "kind", "namespace", "description", "properties", "created_at", "updated_at",
];

/// Edge columns that may be projected with `RETURN var.column`.
const EDGE_COLUMNS: &[&str] = &[
    "id",
    "source_id",
    "target_id",
    "relation",
    "weight",
];
703
704fn property_to_column<'a>(prop: &'a str, kind: &VarKind) -> Result<&'a str, QueryError> {
705    let valid = match kind {
706        VarKind::Node => NODE_COLUMNS,
707        VarKind::Edge => EDGE_COLUMNS,
708    };
709    if valid.contains(&prop) {
710        Ok(prop)
711    } else {
712        let kind_name = match kind {
713            VarKind::Node => "node",
714            VarKind::Edge => "edge",
715        };
716        Err(QueryError::Compile(format!(
717            "unknown {kind_name} property '{prop}' in RETURN projection. \
718             Valid: {}",
719            valid.join(", ")
720        )))
721    }
722}
723
724#[cfg(test)]
725mod tests {
726    use super::*;
727    use crate::parsers::gql;
728
729    fn opts() -> CompileOptions {
730        CompileOptions::default()
731    }
732
733    fn scoped(namespace: &str) -> CompileOptions {
734        CompileOptions {
735            scopes: vec![namespace.to_string()],
736            max_limit: 500,
737        }
738    }
739
740    #[test]
741    fn fixed_length_basic() {
742        let q =
743            gql::parse("MATCH (a:concept)-[e:introduced_by]->(b:paper) RETURN a, e, b LIMIT 10")
744                .unwrap();
745        let compiled = compile(&q, &opts()).unwrap();
746        assert!(compiled.sql.contains("JOIN graph_edges"));
747        assert!(compiled.sql.contains("LIMIT"));
748        assert_eq!(
749            compiled.return_vars,
750            vec![
751                ReturnItem::Variable("a".into()),
752                ReturnItem::Variable("e".into()),
753                ReturnItem::Variable("b".into()),
754            ]
755        );
756        // No recursive CTE for fixed-length
757        assert!(!compiled.sql.contains("WITH RECURSIVE"));
758    }
759
760    #[test]
761    fn namespace_scoping_injected() {
762        // Namespace must come from opts, never from the query
763        let q =
764            gql::parse("MATCH (a:concept)-[e:introduced_by]->(b:paper) RETURN a LIMIT 5").unwrap();
765        let compiled = compile(&q, &scoped("research")).unwrap();
766        assert!(compiled.sql.contains("namespace"));
767        // The namespace value must appear as a parameter, not a literal in SQL
768        let has_ns_param = compiled
769            .params
770            .iter()
771            .any(|p| matches!(p, SqlValue::Text(s) if s == "research"));
772        assert!(has_ns_param, "namespace must be a bound parameter");
773    }
774
775    #[test]
776    fn edge_property_whitelist_rejects_unknown() {
777        // MAJ-1: only 'relation' and 'weight' are queryable edge properties
778        let q = gql::parse("MATCH (a)-[e:introduced_by]->(b) WHERE e.source_id = 'x' RETURN a")
779            .unwrap();
780        let result = compile(&q, &opts());
781        assert!(result.is_err());
782        let err = result.unwrap_err().to_string();
783        assert!(
784            err.contains("source_id") || err.contains("not queryable"),
785            "error: {err}"
786        );
787    }
788
789    #[test]
790    fn edge_property_relation_allowed() {
791        let q = gql::parse("MATCH (a)-[e]->(b) WHERE e.relation = 'extends' RETURN a").unwrap();
792        let result = compile(&q, &opts());
793        assert!(
794            result.is_ok(),
795            "relation should be allowed: {:?}",
796            result.err()
797        );
798    }
799
800    #[test]
801    fn edge_property_weight_allowed() {
802        let q = gql::parse("MATCH (a)-[e]->(b) WHERE e.weight > 0.5 RETURN a").unwrap();
803        let result = compile(&q, &opts());
804        assert!(
805            result.is_ok(),
806            "weight should be allowed: {:?}",
807            result.err()
808        );
809    }
810
811    #[test]
812    fn variable_length_uses_cte() {
813        let q =
814            gql::parse("MATCH (a {name: 'LoRA'})-[:extends*1..3]->(b) RETURN b LIMIT 20").unwrap();
815        let compiled = compile(&q, &opts()).unwrap();
816        assert!(compiled.sql.contains("WITH RECURSIVE"));
817        assert!(compiled.sql.contains("traverse"));
818    }
819
820    #[test]
821    fn depth_cap_at_ten() {
822        // MAJ-2: depth capped at 10 regardless of query request
823        let q = gql::parse("MATCH (a)-[:extends*1..50]->(b) RETURN b").unwrap();
824        let compiled = compile(&q, &opts()).unwrap();
825        // The depth parameter must be <= 10
826        let depth_val = compiled.params.iter().find_map(|p| {
827            if let SqlValue::Integer(n) = p {
828                Some(*n)
829            } else {
830                None
831            }
832        });
833        assert!(depth_val.unwrap() <= 10, "depth must be capped at 10");
834    }
835
836    #[test]
837    fn limit_capped_by_max_limit() {
838        // Query requests 1000, max_limit is 500 — result should be 500
839        let q = gql::parse("MATCH (a:concept)-[e]->(b) RETURN a LIMIT 1000").unwrap();
840        let compiled = compile(&q, &opts()).unwrap();
841        let limit_param = compiled.params.last().unwrap();
842        assert!(
843            matches!(limit_param, SqlValue::Integer(500)),
844            "expected Integer(500), got {limit_param:?}"
845        );
846    }
847
848    #[test]
849    fn compile_rejects_unknown_relation() {
850        let q = gql::parse("MATCH (a)-[:not_a_relation]->(b) RETURN a").unwrap();
851        let err = compile(&q, &opts()).unwrap_err();
852        let msg = err.to_string();
853        assert!(msg.contains("not_a_relation"), "msg: {msg}");
854    }
855
856    #[test]
857    fn compile_unknown_kind_passes_through() {
858        // Pack-agnostic: any string is accepted as an entity kind at the query layer.
859        // Validation is a pack-handler concern.
860        let q = gql::parse("MATCH (a:gizmo)-[:extends]->(b) RETURN a").unwrap();
861        let compiled = compile(&q, &opts()).unwrap();
862        let has_gizmo = compiled
863            .params
864            .iter()
865            .any(|p| matches!(p, SqlValue::Text(s) if s == "gizmo"));
866        assert!(
867            has_gizmo,
868            "pack-agnostic: unknown kind must pass through into SQL params"
869        );
870    }
871
872    #[test]
873    fn compile_kind_passes_through_unchanged() {
874        // Pack-agnostic: 'paper' is no longer normalized to 'document' at the query layer.
875        // The string passes through as-is.
876        let q =
877            gql::parse("MATCH (a:paper)-[:introduced_by]->(b:concept) RETURN a LIMIT 1").unwrap();
878        let compiled = compile(&q, &opts()).unwrap();
879        let has_paper = compiled
880            .params
881            .iter()
882            .any(|p| matches!(p, SqlValue::Text(s) if s == "paper"));
883        assert!(
884            has_paper,
885            "kind 'paper' must pass through unchanged into SQL params"
886        );
887    }
888
889    #[test]
890    fn compile_rejects_namespace_in_where() {
891        let q =
892            gql::parse("MATCH (a:concept)-[:extends]->(b) WHERE a.namespace = 'other' RETURN a")
893                .unwrap();
894        let err = compile(&q, &opts()).unwrap_err();
895        assert!(err.to_string().contains("namespace"), "msg: {err}");
896    }
897
898    #[test]
899    fn compile_rejects_unknown_relation_in_where() {
900        let q = gql::parse("MATCH (a)-[e:extends]->(b) WHERE e.relation = 'related_to' RETURN a")
901            .unwrap();
902        let err = compile(&q, &opts()).unwrap_err();
903        assert!(err.to_string().contains("related_to"), "msg: {err}");
904    }
905
906    #[test]
907    fn compile_kind_in_where_passes_through_unchanged() {
908        // Pack-agnostic: kind strings in WHERE conditions pass through as-is.
909        let q = gql::parse("MATCH (a)-[:extends]->(b) WHERE a.kind = 'paper' RETURN a").unwrap();
910        let compiled = compile(&q, &opts()).unwrap();
911        let has_paper = compiled
912            .params
913            .iter()
914            .any(|p| matches!(p, SqlValue::Text(s) if s == "paper"));
915        assert!(
916            has_paper,
917            "kind 'paper' must pass through unchanged into SQL params"
918        );
919    }
920
921    #[test]
922    fn variable_length_return_start_only_joins_end_entity() {
923        // Even when only the start variable is projected, the outer query
924        // references `r.deleted_at` / `r.namespace`, so entities r must be
925        // joined unconditionally.
926        let q = gql::parse("MATCH (a:concept)-[:extends*1..3]->(b) RETURN a LIMIT 10").unwrap();
927        let compiled = compile(&q, &opts()).unwrap();
928        assert!(
929            compiled.sql.contains("JOIN entities r"),
930            "entities r must always be joined when r.* conditions are emitted; sql: {}",
931            compiled.sql
932        );
933    }
934
935    #[test]
936    fn variable_length_trailing_pattern_unsupported() {
937        let q = gql::parse("MATCH (a)-[:extends*1..3]->(b)-[:implements]->(c) RETURN b").unwrap();
938        let err = compile(&q, &opts()).unwrap_err();
939        assert!(
940            matches!(err, QueryError::Unsupported(_)),
941            "expected Unsupported, got {err:?}"
942        );
943    }
944
945    #[test]
946    fn variable_length_mixed_chain_unsupported() {
947        // Mixed fixed + variable in one chain — has_variable_length() triggers
948        // the variable-length path, which must reject because edges.len() > 1.
949        let q = gql::parse("MATCH (a)-[:extends]->(b)-[:implements*1..2]->(c) RETURN c").unwrap();
950        let err = compile(&q, &opts()).unwrap_err();
951        assert!(matches!(err, QueryError::Unsupported(_)), "got {err:?}");
952    }
953
954    #[test]
955    fn sparql_star_rejected_as_unsupported() {
956        use crate::parsers::sparql;
957        let err = sparql::parse("SELECT ?a ?b WHERE { ?a :extends* ?b . }").unwrap_err();
958        assert!(matches!(err, QueryError::Unsupported(_)), "got {err:?}");
959    }
960
961    #[test]
962    fn return_property_projection_compiles() {
963        let q =
964            gql::parse("MATCH (a:concept)-[e:extends]->(b:concept) RETURN a.name, b.name LIMIT 5")
965                .unwrap();
966        let compiled = compile(&q, &opts()).unwrap();
967        // Node aliases are n0, n1; the SQL uses `alias.col AS var_prop`
968        assert!(
969            compiled.sql.contains(".name AS a_name"),
970            "sql: {}",
971            compiled.sql
972        );
973        assert!(
974            compiled.sql.contains(".name AS b_name"),
975            "sql: {}",
976            compiled.sql
977        );
978        assert!(
979            !compiled.sql.contains("a_kind"),
980            "should not emit full node columns"
981        );
982    }
983
984    #[test]
985    fn return_unknown_node_property_rejected() {
986        let q = gql::parse("MATCH (a:concept)-[:extends]->(b) RETURN a.domain LIMIT 5").unwrap();
987        let err = compile(&q, &opts()).unwrap_err();
988        assert!(
989            matches!(err, QueryError::Compile(ref msg) if msg.contains("unknown node property 'domain'")),
990            "got {err:?}"
991        );
992    }
993
994    #[test]
995    fn return_unknown_edge_property_rejected() {
996        let q = gql::parse("MATCH (a)-[e:extends]->(b) RETURN e.label LIMIT 5").unwrap();
997        let err = compile(&q, &opts()).unwrap_err();
998        assert!(
999            matches!(err, QueryError::Compile(ref msg) if msg.contains("unknown edge property 'label'")),
1000            "got {err:?}"
1001        );
1002    }
1003
1004    #[test]
1005    fn return_valid_edge_property_compiles() {
1006        let q =
1007            gql::parse("MATCH (a)-[e:extends]->(b) RETURN e.relation, e.weight LIMIT 5").unwrap();
1008        let compiled = compile(&q, &opts()).unwrap();
1009        // Edge alias is e0; SQL: `e0.relation AS e_relation`
1010        assert!(
1011            compiled.sql.contains(".relation AS e_relation"),
1012            "sql: {}",
1013            compiled.sql
1014        );
1015        assert!(
1016            compiled.sql.contains(".weight AS e_weight"),
1017            "sql: {}",
1018            compiled.sql
1019        );
1020    }
1021}