Skip to main content

khive_query/compilers/
sql.rs

//! Compile GQL AST to parameterized SQL.
//!
//! Two compilation paths:
//! - Fixed-length patterns (all edges *1..1) → JOIN chain
//! - Variable-length patterns (any edge *N..M where M>1) → recursive CTE
//!
//! Security invariants (MAJ-1/MAJ-2/MAJ-3 from critic review):
//! - Namespace injection: the namespace filter always comes from `CompileOptions.scopes`,
//!   never from the query text.
//! - Edge property whitelist: only `relation` and `weight` are queryable edge columns.
//! - Depth cap: recursive CTE depth is min(requested, 10).
11
12use crate::ast::*;
13use crate::error::QueryError;
14use crate::validate::{validate_with_warnings, MAX_DEPTH};
15use khive_storage::types::SqlValue;
16
17#[derive(Debug)]
18pub struct CompiledQuery {
19    pub sql: String,
20    pub params: Vec<SqlValue>,
21    pub return_vars: Vec<ReturnItem>,
22    pub warnings: Vec<String>,
23}
24
25pub struct CompileOptions {
26    /// Namespace scope. Empty = cross-namespace (all). Non-empty = filter to these namespaces.
27    pub scopes: Vec<String>,
28    /// Hard limit cap (server-side safety). Query limit is min(requested, max_limit).
29    pub max_limit: usize,
30}
31
32impl Default for CompileOptions {
33    fn default() -> Self {
34        Self {
35            scopes: Vec::new(),
36            max_limit: 500,
37        }
38    }
39}
40
41pub fn compile(query: &GqlQuery, opts: &CompileOptions) -> Result<CompiledQuery, QueryError> {
42    if query.pattern.elements.is_empty() {
43        return Err(QueryError::Compile("empty pattern".into()));
44    }
45
46    // Validate edge relations + structural rules before emitting SQL.
47    let mut query = query.clone();
48    let warnings = validate_with_warnings(&mut query)?;
49
50    let mut compiled = if query.pattern.has_variable_length() {
51        compile_variable_length(&query, opts)?
52    } else {
53        compile_fixed_length(&query, opts)?
54    };
55    compiled.warnings = warnings;
56    Ok(compiled)
57}
58
59fn namespace_filter(alias: &str, opts: &CompileOptions, params: &mut Vec<SqlValue>) -> String {
60    if opts.scopes.is_empty() {
61        String::new()
62    } else if opts.scopes.len() == 1 {
63        params.push(SqlValue::Text(opts.scopes[0].clone()));
64        format!(" AND {alias}.namespace = ?{}", params.len())
65    } else {
66        let placeholders: Vec<String> = opts
67            .scopes
68            .iter()
69            .map(|s| {
70                params.push(SqlValue::Text(s.clone()));
71                format!("?{}", params.len())
72            })
73            .collect();
74        format!(" AND {alias}.namespace IN ({})", placeholders.join(", "))
75    }
76}
77
78/// Compile fixed-length patterns to a chain of JOINs.
79///
80/// MATCH (a:concept)-[e:introduced_by]->(b:paper) WHERE ... RETURN a, e, b LIMIT 10
81/// →
82/// SELECT a.*, e.*, b.*
83/// FROM entities a
84/// JOIN graph_edges e ON e.source_id = a.id
85/// JOIN entities b ON b.id = e.target_id
86/// WHERE a.kind = 'concept' AND e.relation = 'introduced_by' AND b.kind = 'paper'
87///   AND a.deleted_at IS NULL AND b.deleted_at IS NULL
88/// LIMIT 10
89fn compile_fixed_length(
90    query: &GqlQuery,
91    opts: &CompileOptions,
92) -> Result<CompiledQuery, QueryError> {
93    let mut params: Vec<SqlValue> = Vec::new();
94    let mut from_parts: Vec<String> = Vec::new();
95    let mut join_parts: Vec<String> = Vec::new();
96    let mut where_parts: Vec<String> = Vec::new();
97    let mut select_parts: Vec<String> = Vec::new();
98
99    let mut node_aliases: Vec<String> = Vec::new();
100    let mut edge_aliases: Vec<String> = Vec::new();
101    let mut var_to_alias: std::collections::HashMap<String, (String, VarKind)> =
102        std::collections::HashMap::new();
103
104    let mut node_idx = 0usize;
105    let mut edge_idx = 0usize;
106
107    for element in &query.pattern.elements {
108        match element {
109            PatternElement::Node(np) => {
110                let alias = format!("n{node_idx}");
111                node_aliases.push(alias.clone());
112
113                if node_idx == 0 {
114                    from_parts.push(format!("entities {alias}"));
115                }
116
117                where_parts.push(format!("{alias}.deleted_at IS NULL"));
118
119                let ns_filter = namespace_filter(&alias, opts, &mut params);
120                if !ns_filter.is_empty() {
121                    where_parts.push(ns_filter.trim_start_matches(" AND ").to_string());
122                }
123
124                if let Some(ref kind) = np.kind {
125                    params.push(SqlValue::Text(kind.clone()));
126                    where_parts.push(format!("{alias}.kind = ?{}", params.len()));
127                }
128
129                for (key, val) in &np.properties {
130                    params.push(SqlValue::Text(val.clone()));
131                    if key == "name" {
132                        where_parts
133                            .push(format!("{alias}.name = ?{} COLLATE NOCASE", params.len()));
134                    } else {
135                        where_parts.push(format!(
136                            "json_extract({alias}.properties, '$.{}') = ?{} COLLATE NOCASE",
137                            key.replace('\'', "''"),
138                            params.len()
139                        ));
140                    }
141                }
142
143                if let Some(ref var) = np.variable {
144                    var_to_alias.insert(var.clone(), (alias.clone(), VarKind::Node));
145                }
146
147                node_idx += 1;
148            }
149            PatternElement::Edge(ep) => {
150                let e_alias = format!("e{edge_idx}");
151                let prev_node = &node_aliases[node_aliases.len() - 1];
152
153                edge_aliases.push(e_alias.clone());
154
155                let (source_join, target_join) = match ep.direction {
156                    EdgeDirection::Out => (
157                        format!("{e_alias}.source_id = {prev_node}.id"),
158                        "target_id",
159                    ),
160                    EdgeDirection::In => (
161                        format!("{e_alias}.target_id = {prev_node}.id"),
162                        "source_id",
163                    ),
164                    EdgeDirection::Both => (
165                        format!(
166                            "({e_alias}.source_id = {prev_node}.id OR {e_alias}.target_id = {prev_node}.id)"
167                        ),
168                        "CASE_BOTH",
169                    ),
170                };
171
172                let next_alias = format!("n{}", node_idx);
173
174                let next_join_col = if target_join == "CASE_BOTH" {
175                    format!(
176                        "CASE WHEN {e_alias}.source_id = {prev_node}.id THEN {e_alias}.target_id ELSE {e_alias}.source_id END"
177                    )
178                } else {
179                    format!("{e_alias}.{target_join}")
180                };
181
182                join_parts.push(format!("JOIN graph_edges {e_alias} ON {source_join}"));
183
184                let ens_filter = namespace_filter(&e_alias, opts, &mut params);
185                if !ens_filter.is_empty() {
186                    where_parts.push(ens_filter.trim_start_matches(" AND ").to_string());
187                }
188
189                join_parts.push(format!(
190                    "JOIN entities {next_alias} ON {next_alias}.id = {next_join_col}"
191                ));
192
193                if !ep.relations.is_empty() {
194                    if ep.relations.len() == 1 {
195                        params.push(SqlValue::Text(ep.relations[0].clone()));
196                        where_parts.push(format!("{e_alias}.relation = ?{}", params.len()));
197                    } else {
198                        let placeholders: Vec<String> = ep
199                            .relations
200                            .iter()
201                            .map(|r| {
202                                params.push(SqlValue::Text(r.clone()));
203                                format!("?{}", params.len())
204                            })
205                            .collect();
206                        where_parts.push(format!(
207                            "{e_alias}.relation IN ({})",
208                            placeholders.join(", ")
209                        ));
210                    }
211                }
212
213                if let Some(ref var) = ep.variable {
214                    var_to_alias.insert(var.clone(), (e_alias.clone(), VarKind::Edge));
215                }
216
217                edge_idx += 1;
218            }
219        }
220    }
221
222    // WHERE clause conditions from GQL WHERE
223    for cond in &query.where_clause {
224        let (alias, kind) = var_to_alias.get(&cond.variable).ok_or_else(|| {
225            QueryError::Compile(format!(
226                "unknown variable '{}' in WHERE clause",
227                cond.variable
228            ))
229        })?;
230
231        let col_expr = match kind {
232            VarKind::Node => {
233                if cond.property == "name"
234                    || cond.property == "kind"
235                    || cond.property == "namespace"
236                {
237                    format!("{alias}.{}", cond.property)
238                } else {
239                    format!(
240                        "json_extract({alias}.properties, '$.{}')",
241                        cond.property.replace('\'', "''")
242                    )
243                }
244            }
245            VarKind::Edge => {
246                // MAJ-1: edge property whitelist — only relation and weight are queryable
247                match cond.property.as_str() {
248                    "relation" | "weight" => format!("{alias}.{}", cond.property),
249                    other => {
250                        return Err(QueryError::Validation(format!(
251                            "edge property '{other}' not queryable; use 'relation' or 'weight'"
252                        )))
253                    }
254                }
255            }
256        };
257
258        let op_str = match cond.op {
259            CompareOp::Eq => "=",
260            CompareOp::Neq => "!=",
261            CompareOp::Gt => ">",
262            CompareOp::Lt => "<",
263            CompareOp::Gte => ">=",
264            CompareOp::Lte => "<=",
265            CompareOp::Like => "LIKE",
266        };
267
268        match &cond.value {
269            ConditionValue::String(s) => {
270                params.push(SqlValue::Text(s.clone()));
271                let collate = if matches!(cond.op, CompareOp::Eq | CompareOp::Like) {
272                    " COLLATE NOCASE"
273                } else {
274                    ""
275                };
276                where_parts.push(format!("{col_expr} {op_str} ?{}{}", params.len(), collate));
277            }
278            ConditionValue::Number(n) => {
279                params.push(SqlValue::Float(*n));
280                where_parts.push(format!("{col_expr} {op_str} ?{}", params.len()));
281            }
282            ConditionValue::Bool(b) => {
283                params.push(SqlValue::Integer(if *b { 1 } else { 0 }));
284                where_parts.push(format!("{col_expr} {op_str} ?{}", params.len()));
285            }
286        }
287    }
288
289    // SELECT clause
290    for item in &query.return_items {
291        let var = item.variable();
292        if let Some((alias, kind)) = var_to_alias.get(var) {
293            match item {
294                ReturnItem::Property(_, prop) => {
295                    let col = property_to_column(prop, kind)?;
296                    select_parts.push(format!("{alias}.{col} AS {var}_{prop}"));
297                }
298                ReturnItem::Variable(_) => match kind {
299                    VarKind::Node => {
300                        select_parts.push(format!(
301                            "{alias}.id AS {var}_id, {alias}.namespace AS {var}_namespace, \
302                             {alias}.kind AS {var}_kind, {alias}.name AS {var}_name, \
303                             {alias}.properties AS {var}_properties, \
304                             {alias}.created_at AS {var}_created_at, \
305                             {alias}.updated_at AS {var}_updated_at"
306                        ));
307                    }
308                    VarKind::Edge => {
309                        select_parts.push(format!(
310                            "{alias}.id AS {var}_id, {alias}.source_id AS {var}_source, \
311                             {alias}.target_id AS {var}_target, \
312                             {alias}.relation AS {var}_relation, \
313                             {alias}.weight AS {var}_weight"
314                        ));
315                    }
316                },
317            }
318        } else {
319            return Err(QueryError::Compile(format!(
320                "unknown variable '{var}' in RETURN clause"
321            )));
322        }
323    }
324
325    let limit = query.limit.unwrap_or(opts.max_limit).min(opts.max_limit);
326    params.push(SqlValue::Integer(limit as i64));
327
328    let sql = format!(
329        "SELECT {} FROM {} {} WHERE {} LIMIT ?{}",
330        select_parts.join(", "),
331        from_parts.join(", "),
332        join_parts.join(" "),
333        where_parts.join(" AND "),
334        params.len(),
335    );
336
337    Ok(CompiledQuery {
338        sql,
339        params,
340        return_vars: query.return_items.clone(),
341        warnings: Vec::new(),
342    })
343}
344
345/// Compile variable-length patterns to a recursive CTE.
346///
347/// Depth is capped at min(requested, 10) — MAJ-2 (parameterized min_depth, not literal).
348fn compile_variable_length(
349    query: &GqlQuery,
350    opts: &CompileOptions,
351) -> Result<CompiledQuery, QueryError> {
352    let mut params: Vec<SqlValue> = Vec::new();
353    let mut var_to_alias: std::collections::HashMap<String, (String, VarKind)> =
354        std::collections::HashMap::new();
355
356    // For variable-length, we expect exactly: start_node -[*N..M]-> end_node.
357    // Mixed fixed+variable chains and additional trailing pattern elements are
358    // not yet supported — reject explicitly rather than silently dropping them.
359    let nodes: Vec<&NodePattern> = query.pattern.nodes().collect();
360    let edges: Vec<&EdgePattern> = query.pattern.edges().collect();
361
362    if nodes.len() != 2 || edges.len() != 1 || query.pattern.elements.len() != 3 {
363        return Err(QueryError::Unsupported(
364            "variable-length patterns must be a single start_node -[*N..M]-> end_node \
365             (mixed fixed/variable chains are not yet implemented)"
366                .into(),
367        ));
368    }
369
370    let start = &nodes[0];
371    let edge = &edges[0];
372    let end = &nodes[1];
373
374    // MAJ-2: depth cap — always parameterized, never injected as literal
375    let max_depth = edge.max_hops.min(MAX_DEPTH);
376    let min_depth = edge.min_hops;
377
378    // Build start-node conditions
379    let mut start_conditions: Vec<String> = vec!["s.deleted_at IS NULL".to_string()];
380    let ns_filter = namespace_filter("s", opts, &mut params);
381    if !ns_filter.is_empty() {
382        start_conditions.push(ns_filter.trim_start_matches(" AND ").to_string());
383    }
384
385    if let Some(ref kind) = start.kind {
386        params.push(SqlValue::Text(kind.clone()));
387        start_conditions.push(format!("s.kind = ?{}", params.len()));
388    }
389    for (key, val) in &start.properties {
390        params.push(SqlValue::Text(val.clone()));
391        if key == "name" {
392            start_conditions.push(format!("s.name = ?{} COLLATE NOCASE", params.len()));
393        } else {
394            start_conditions.push(format!(
395                "json_extract(s.properties, '$.{}') = ?{} COLLATE NOCASE",
396                key.replace('\'', "''"),
397                params.len()
398            ));
399        }
400    }
401
402    // Relation filter
403    let mut relation_condition = String::new();
404    if !edge.relations.is_empty() {
405        if edge.relations.len() == 1 {
406            params.push(SqlValue::Text(edge.relations[0].clone()));
407            relation_condition = format!(" AND e.relation = ?{}", params.len());
408        } else {
409            let placeholders: Vec<String> = edge
410                .relations
411                .iter()
412                .map(|r| {
413                    params.push(SqlValue::Text(r.clone()));
414                    format!("?{}", params.len())
415                })
416                .collect();
417            relation_condition = format!(" AND e.relation IN ({})", placeholders.join(", "));
418        }
419    }
420
421    // Edge namespace filter
422    let e_ns_filter = namespace_filter("e", opts, &mut params);
423
424    // Direction-dependent JOIN
425    let (seed_join, seed_next, recurse_join, recurse_next) = match edge.direction {
426        EdgeDirection::Out => (
427            "e.source_id = s.id",
428            "e.target_id",
429            "e.source_id = t.current_id",
430            "e.target_id",
431        ),
432        EdgeDirection::In => (
433            "e.target_id = s.id",
434            "e.source_id",
435            "e.target_id = t.current_id",
436            "e.source_id",
437        ),
438        EdgeDirection::Both => (
439            "(e.source_id = s.id OR e.target_id = s.id)",
440            "CASE WHEN e.source_id = s.id THEN e.target_id ELSE e.source_id END",
441            "(e.source_id = t.current_id OR e.target_id = t.current_id)",
442            "CASE WHEN e.source_id = t.current_id THEN e.target_id ELSE e.source_id END",
443        ),
444    };
445
446    params.push(SqlValue::Integer(max_depth as i64));
447    let depth_param = params.len();
448
449    // End-node conditions (applied in outer WHERE). `r` is always joined
450    // unconditionally below so these references resolve regardless of whether
451    // the end variable is projected.
452    let mut end_conditions: Vec<String> = vec!["r.deleted_at IS NULL".to_string()];
453    let r_ns_filter = namespace_filter("r", opts, &mut params);
454    if !r_ns_filter.is_empty() {
455        end_conditions.push(r_ns_filter.trim_start_matches(" AND ").to_string());
456    }
457    if let Some(ref kind) = end.kind {
458        params.push(SqlValue::Text(kind.clone()));
459        end_conditions.push(format!("r.kind = ?{}", params.len()));
460    }
461    for (key, val) in &end.properties {
462        params.push(SqlValue::Text(val.clone()));
463        if key == "name" {
464            end_conditions.push(format!("r.name = ?{} COLLATE NOCASE", params.len()));
465        } else {
466            end_conditions.push(format!(
467                "json_extract(r.properties, '$.{}') = ?{} COLLATE NOCASE",
468                key.replace('\'', "''"),
469                params.len()
470            ));
471        }
472    }
473
474    // WHERE clause conditions
475    for cond in &query.where_clause {
476        // Map variables to appropriate aliases
477        let col_alias = if start.variable.as_deref() == Some(&cond.variable) {
478            "s"
479        } else if end.variable.as_deref() == Some(&cond.variable) {
480            "r"
481        } else {
482            return Err(QueryError::Compile(format!(
483                "variable '{}' in WHERE not supported in variable-length pattern (only start/end node variables)",
484                cond.variable
485            )));
486        };
487
488        let col_expr = if cond.property == "name" || cond.property == "kind" {
489            format!("{col_alias}.{}", cond.property)
490        } else {
491            format!(
492                "json_extract({col_alias}.properties, '$.{}')",
493                cond.property.replace('\'', "''")
494            )
495        };
496
497        let op_str = match cond.op {
498            CompareOp::Eq => "=",
499            CompareOp::Neq => "!=",
500            CompareOp::Gt => ">",
501            CompareOp::Lt => "<",
502            CompareOp::Gte => ">=",
503            CompareOp::Lte => "<=",
504            CompareOp::Like => "LIKE",
505        };
506
507        match &cond.value {
508            ConditionValue::String(s) => {
509                params.push(SqlValue::Text(s.clone()));
510                let collate = if matches!(cond.op, CompareOp::Eq | CompareOp::Like) {
511                    " COLLATE NOCASE"
512                } else {
513                    ""
514                };
515                if col_alias == "s" {
516                    start_conditions
517                        .push(format!("{col_expr} {op_str} ?{}{collate}", params.len()));
518                } else {
519                    end_conditions.push(format!("{col_expr} {op_str} ?{}{collate}", params.len()));
520                }
521            }
522            ConditionValue::Number(n) => {
523                params.push(SqlValue::Float(*n));
524                if col_alias == "s" {
525                    start_conditions.push(format!("{col_expr} {op_str} ?{}", params.len()));
526                } else {
527                    end_conditions.push(format!("{col_expr} {op_str} ?{}", params.len()));
528                }
529            }
530            ConditionValue::Bool(b) => {
531                params.push(SqlValue::Integer(if *b { 1 } else { 0 }));
532                if col_alias == "s" {
533                    start_conditions.push(format!("{col_expr} {op_str} ?{}", params.len()));
534                } else {
535                    end_conditions.push(format!("{col_expr} {op_str} ?{}", params.len()));
536                }
537            }
538        }
539    }
540
541    // MAJ-2: min_depth is always a bound parameter, never a literal
542    if min_depth > 0 {
543        params.push(SqlValue::Integer(min_depth as i64));
544        end_conditions.push(format!("t.depth >= ?{}", params.len()));
545    }
546
547    let limit = query.limit.unwrap_or(opts.max_limit).min(opts.max_limit);
548    params.push(SqlValue::Integer(limit as i64));
549    let limit_param = params.len();
550
551    // Register variables
552    if let Some(ref var) = start.variable {
553        var_to_alias.insert(var.clone(), ("s".to_string(), VarKind::Node));
554    }
555    if let Some(ref var) = end.variable {
556        var_to_alias.insert(var.clone(), ("r".to_string(), VarKind::Node));
557    }
558    if let Some(ref var) = edge.variable {
559        var_to_alias.insert(var.clone(), ("e".to_string(), VarKind::Edge));
560    }
561
562    // Build SELECT based on RETURN items
563    let mut select_parts: Vec<String> = Vec::new();
564    let mut has_start = false;
565
566    for item in &query.return_items {
567        let var = item.variable();
568        if let Some((_, kind)) = var_to_alias.get(var) {
569            match item {
570                ReturnItem::Property(_, prop) => {
571                    let is_start = start.variable.as_deref() == Some(var);
572                    if *kind == VarKind::Node {
573                        let tbl = if is_start { "s" } else { "r" };
574                        if is_start {
575                            has_start = true;
576                        }
577                        let col = property_to_column(prop, kind)?;
578                        select_parts.push(format!("{tbl}.{col} AS {var}_{prop}"));
579                    } else {
580                        let col = match prop.as_str() {
581                            "id" => "via_edge",
582                            "relation" => "via_relation",
583                            "weight" => "via_weight",
584                            _ => {
585                                return Err(QueryError::Compile(format!(
586                                    "unknown edge property '{prop}' in RETURN projection. \
587                                     Valid: id, source_id, target_id, relation, weight"
588                                )));
589                            }
590                        };
591                        select_parts.push(format!("t.{col} AS {var}_{prop}"));
592                    }
593                }
594                ReturnItem::Variable(_) => match kind {
595                    VarKind::Node => {
596                        if start.variable.as_deref() == Some(var) {
597                            has_start = true;
598                            select_parts.push(format!(
599                                "s.id AS {var}_id, s.namespace AS {var}_namespace, \
600                                 s.kind AS {var}_kind, s.name AS {var}_name, \
601                                 s.properties AS {var}_properties, \
602                                 s.created_at AS {var}_created_at, \
603                                 s.updated_at AS {var}_updated_at"
604                            ));
605                        } else {
606                            select_parts.push(format!(
607                                "r.id AS {var}_id, r.namespace AS {var}_namespace, \
608                                 r.kind AS {var}_kind, r.name AS {var}_name, \
609                                 r.properties AS {var}_properties, \
610                                 r.created_at AS {var}_created_at, \
611                                 r.updated_at AS {var}_updated_at"
612                            ));
613                        }
614                    }
615                    VarKind::Edge => {
616                        select_parts.push(format!(
617                            "t.via_edge AS {var}_id, t.via_relation AS {var}_relation, \
618                             t.via_weight AS {var}_weight"
619                        ));
620                    }
621                },
622            }
623        } else {
624            return Err(QueryError::Compile(format!(
625                "unknown variable '{var}' in RETURN clause"
626            )));
627        }
628    }
629
630    // Always include traversal metadata
631    select_parts.push("t.depth AS _depth".to_string());
632    select_parts.push("t.total_weight AS _total_weight".to_string());
633
634    // `s` is optional (only joined if the start variable is projected); `r` is
635    // always joined because the outer WHERE always references `r.deleted_at`,
636    // `r.namespace` (and possibly r.kind / r.properties) regardless of whether
637    // it appears in RETURN.
638    let join_start = if has_start {
639        "JOIN entities s ON s.id = t.start_id"
640    } else {
641        ""
642    };
643    let join_end = "JOIN entities r ON r.id = t.current_id";
644
645    let sql = format!(
646        "WITH RECURSIVE traverse(start_id, current_id, depth, path, total_weight, via_edge, via_relation, via_weight) AS (\
647             SELECT s.id, {seed_next}, 1, s.id || ',' || {seed_next}, e.weight, \
648                    e.id, e.relation, e.weight \
649             FROM entities s \
650             JOIN graph_edges e ON {seed_join}{e_ns_filter}{relation_condition} \
651             WHERE {start_where} \
652             UNION ALL \
653             SELECT t.start_id, {recurse_next}, t.depth + 1, \
654                    t.path || ',' || {recurse_next}, \
655                    t.total_weight + e.weight, \
656                    e.id, e.relation, e.weight \
657             FROM traverse t \
658             JOIN graph_edges e ON {recurse_join}{e_ns_filter}{relation_condition} \
659             WHERE t.depth < ?{depth_param} \
660               AND (',' || t.path || ',') NOT LIKE '%,' || {recurse_next} || ',%' \
661         ) \
662         SELECT DISTINCT {select_cols} \
663         FROM traverse t \
664         {join_start} {join_end} \
665         WHERE {end_where} \
666         ORDER BY t.depth, t.total_weight DESC \
667         LIMIT ?{limit_param}",
668        seed_next = seed_next,
669        seed_join = seed_join,
670        e_ns_filter = e_ns_filter,
671        relation_condition = relation_condition,
672        start_where = start_conditions.join(" AND "),
673        recurse_next = recurse_next,
674        recurse_join = recurse_join,
675        depth_param = depth_param,
676        select_cols = select_parts.join(", "),
677        join_start = join_start,
678        join_end = join_end,
679        end_where = end_conditions.join(" AND "),
680        limit_param = limit_param,
681    );
682
683    Ok(CompiledQuery {
684        sql,
685        params,
686        return_vars: query.return_items.clone(),
687        warnings: Vec::new(),
688    })
689}
690
691#[derive(Clone, Copy, PartialEq, Eq)]
692enum VarKind {
693    Node,
694    Edge,
695}
696
697const NODE_COLUMNS: &[&str] = &[
698    "id",
699    "name",
700    "kind",
701    "namespace",
702    "description",
703    "properties",
704    "created_at",
705    "updated_at",
706];
707const EDGE_COLUMNS: &[&str] = &["id", "source_id", "target_id", "relation", "weight"];
708
709fn property_to_column<'a>(prop: &'a str, kind: &VarKind) -> Result<&'a str, QueryError> {
710    let valid = match kind {
711        VarKind::Node => NODE_COLUMNS,
712        VarKind::Edge => EDGE_COLUMNS,
713    };
714    if valid.contains(&prop) {
715        Ok(prop)
716    } else {
717        let kind_name = match kind {
718            VarKind::Node => "node",
719            VarKind::Edge => "edge",
720        };
721        Err(QueryError::Compile(format!(
722            "unknown {kind_name} property '{prop}' in RETURN projection. \
723             Valid: {}",
724            valid.join(", ")
725        )))
726    }
727}
728
729#[cfg(test)]
730mod tests {
731    use super::*;
732    use crate::parsers::gql;
733
734    fn opts() -> CompileOptions {
735        CompileOptions::default()
736    }
737
738    fn scoped(namespace: &str) -> CompileOptions {
739        CompileOptions {
740            scopes: vec![namespace.to_string()],
741            max_limit: 500,
742        }
743    }
744
745    #[test]
746    fn fixed_length_basic() {
747        let q =
748            gql::parse("MATCH (a:concept)-[e:introduced_by]->(b:paper) RETURN a, e, b LIMIT 10")
749                .unwrap();
750        let compiled = compile(&q, &opts()).unwrap();
751        assert!(compiled.sql.contains("JOIN graph_edges"));
752        assert!(compiled.sql.contains("LIMIT"));
753        assert_eq!(
754            compiled.return_vars,
755            vec![
756                ReturnItem::Variable("a".into()),
757                ReturnItem::Variable("e".into()),
758                ReturnItem::Variable("b".into()),
759            ]
760        );
761        // No recursive CTE for fixed-length
762        assert!(!compiled.sql.contains("WITH RECURSIVE"));
763    }
764
765    #[test]
766    fn namespace_scoping_injected() {
767        // Namespace must come from opts, never from the query
768        let q =
769            gql::parse("MATCH (a:concept)-[e:introduced_by]->(b:paper) RETURN a LIMIT 5").unwrap();
770        let compiled = compile(&q, &scoped("research")).unwrap();
771        assert!(compiled.sql.contains("namespace"));
772        // The namespace value must appear as a parameter, not a literal in SQL
773        let has_ns_param = compiled
774            .params
775            .iter()
776            .any(|p| matches!(p, SqlValue::Text(s) if s == "research"));
777        assert!(has_ns_param, "namespace must be a bound parameter");
778    }
779
780    #[test]
781    fn edge_property_whitelist_rejects_unknown() {
782        // MAJ-1: only 'relation' and 'weight' are queryable edge properties
783        let q = gql::parse("MATCH (a)-[e:introduced_by]->(b) WHERE e.source_id = 'x' RETURN a")
784            .unwrap();
785        let result = compile(&q, &opts());
786        assert!(result.is_err());
787        let err = result.unwrap_err().to_string();
788        assert!(
789            err.contains("source_id") || err.contains("not queryable"),
790            "error: {err}"
791        );
792    }
793
794    #[test]
795    fn edge_property_relation_allowed() {
796        let q = gql::parse("MATCH (a)-[e]->(b) WHERE e.relation = 'extends' RETURN a").unwrap();
797        let result = compile(&q, &opts());
798        assert!(
799            result.is_ok(),
800            "relation should be allowed: {:?}",
801            result.err()
802        );
803    }
804
805    #[test]
806    fn edge_property_weight_allowed() {
807        let q = gql::parse("MATCH (a)-[e]->(b) WHERE e.weight > 0.5 RETURN a").unwrap();
808        let result = compile(&q, &opts());
809        assert!(
810            result.is_ok(),
811            "weight should be allowed: {:?}",
812            result.err()
813        );
814    }
815
816    #[test]
817    fn variable_length_uses_cte() {
818        let q =
819            gql::parse("MATCH (a {name: 'LoRA'})-[:extends*1..3]->(b) RETURN b LIMIT 20").unwrap();
820        let compiled = compile(&q, &opts()).unwrap();
821        assert!(compiled.sql.contains("WITH RECURSIVE"));
822        assert!(compiled.sql.contains("traverse"));
823    }
824
825    #[test]
826    fn depth_cap_at_ten() {
827        // MAJ-2: depth capped at 10 regardless of query request
828        let q = gql::parse("MATCH (a)-[:extends*1..50]->(b) RETURN b").unwrap();
829        let compiled = compile(&q, &opts()).unwrap();
830        // The depth parameter must be <= 10
831        let depth_val = compiled.params.iter().find_map(|p| {
832            if let SqlValue::Integer(n) = p {
833                Some(*n)
834            } else {
835                None
836            }
837        });
838        assert!(depth_val.unwrap() <= 10, "depth must be capped at 10");
839    }
840
841    #[test]
842    fn limit_capped_by_max_limit() {
843        // Query requests 1000, max_limit is 500 — result should be 500
844        let q = gql::parse("MATCH (a:concept)-[e]->(b) RETURN a LIMIT 1000").unwrap();
845        let compiled = compile(&q, &opts()).unwrap();
846        let limit_param = compiled.params.last().unwrap();
847        assert!(
848            matches!(limit_param, SqlValue::Integer(500)),
849            "expected Integer(500), got {limit_param:?}"
850        );
851    }
852
853    #[test]
854    fn compile_rejects_unknown_relation() {
855        let q = gql::parse("MATCH (a)-[:not_a_relation]->(b) RETURN a").unwrap();
856        let err = compile(&q, &opts()).unwrap_err();
857        let msg = err.to_string();
858        assert!(msg.contains("not_a_relation"), "msg: {msg}");
859    }
860
861    #[test]
862    fn compile_unknown_kind_passes_through() {
863        // Pack-agnostic: any string is accepted as an entity kind at the query layer.
864        // Validation is a pack-handler concern.
865        let q = gql::parse("MATCH (a:gizmo)-[:extends]->(b) RETURN a").unwrap();
866        let compiled = compile(&q, &opts()).unwrap();
867        let has_gizmo = compiled
868            .params
869            .iter()
870            .any(|p| matches!(p, SqlValue::Text(s) if s == "gizmo"));
871        assert!(
872            has_gizmo,
873            "pack-agnostic: unknown kind must pass through into SQL params"
874        );
875    }
876
877    #[test]
878    fn compile_kind_passes_through_unchanged() {
879        // Pack-agnostic: 'paper' is no longer normalized to 'document' at the query layer.
880        // The string passes through as-is.
881        let q =
882            gql::parse("MATCH (a:paper)-[:introduced_by]->(b:concept) RETURN a LIMIT 1").unwrap();
883        let compiled = compile(&q, &opts()).unwrap();
884        let has_paper = compiled
885            .params
886            .iter()
887            .any(|p| matches!(p, SqlValue::Text(s) if s == "paper"));
888        assert!(
889            has_paper,
890            "kind 'paper' must pass through unchanged into SQL params"
891        );
892    }
893
894    #[test]
895    fn compile_rejects_namespace_in_where() {
896        let q =
897            gql::parse("MATCH (a:concept)-[:extends]->(b) WHERE a.namespace = 'other' RETURN a")
898                .unwrap();
899        let err = compile(&q, &opts()).unwrap_err();
900        assert!(err.to_string().contains("namespace"), "msg: {err}");
901    }
902
903    #[test]
904    fn compile_rejects_unknown_relation_in_where() {
905        let q = gql::parse("MATCH (a)-[e:extends]->(b) WHERE e.relation = 'related_to' RETURN a")
906            .unwrap();
907        let err = compile(&q, &opts()).unwrap_err();
908        assert!(err.to_string().contains("related_to"), "msg: {err}");
909    }
910
911    #[test]
912    fn compile_kind_in_where_passes_through_unchanged() {
913        // Pack-agnostic: kind strings in WHERE conditions pass through as-is.
914        let q = gql::parse("MATCH (a)-[:extends]->(b) WHERE a.kind = 'paper' RETURN a").unwrap();
915        let compiled = compile(&q, &opts()).unwrap();
916        let has_paper = compiled
917            .params
918            .iter()
919            .any(|p| matches!(p, SqlValue::Text(s) if s == "paper"));
920        assert!(
921            has_paper,
922            "kind 'paper' must pass through unchanged into SQL params"
923        );
924    }
925
926    #[test]
927    fn variable_length_return_start_only_joins_end_entity() {
928        // Even when only the start variable is projected, the outer query
929        // references `r.deleted_at` / `r.namespace`, so entities r must be
930        // joined unconditionally.
931        let q = gql::parse("MATCH (a:concept)-[:extends*1..3]->(b) RETURN a LIMIT 10").unwrap();
932        let compiled = compile(&q, &opts()).unwrap();
933        assert!(
934            compiled.sql.contains("JOIN entities r"),
935            "entities r must always be joined when r.* conditions are emitted; sql: {}",
936            compiled.sql
937        );
938    }
939
940    #[test]
941    fn variable_length_trailing_pattern_unsupported() {
942        let q = gql::parse("MATCH (a)-[:extends*1..3]->(b)-[:implements]->(c) RETURN b").unwrap();
943        let err = compile(&q, &opts()).unwrap_err();
944        assert!(
945            matches!(err, QueryError::Unsupported(_)),
946            "expected Unsupported, got {err:?}"
947        );
948    }
949
950    #[test]
951    fn variable_length_mixed_chain_unsupported() {
952        // Mixed fixed + variable in one chain — has_variable_length() triggers
953        // the variable-length path, which must reject because edges.len() > 1.
954        let q = gql::parse("MATCH (a)-[:extends]->(b)-[:implements*1..2]->(c) RETURN c").unwrap();
955        let err = compile(&q, &opts()).unwrap_err();
956        assert!(matches!(err, QueryError::Unsupported(_)), "got {err:?}");
957    }
958
959    #[test]
960    fn sparql_star_rejected_as_unsupported() {
961        use crate::parsers::sparql;
962        let err = sparql::parse("SELECT ?a ?b WHERE { ?a :extends* ?b . }").unwrap_err();
963        assert!(matches!(err, QueryError::Unsupported(_)), "got {err:?}");
964    }
965
966    /// Regression guard for ISSUE #231.
967    ///
968    /// Verifies the full SPARQL subject→predicate→object direction contract:
969    ///   ?a :extends ?b  must compile so that ?a binds `source_id` and ?b binds `target_id`.
970    ///
971    /// A swap (subject→target_id, object→source_id) would cause a query for
972    /// A–extends→B to return rows where B–extends→A, silently returning wrong results.
973    #[test]
974    fn sparql_subject_object_direction_compiles_outbound() {
975        use crate::parsers::sparql;
976
977        let q = sparql::parse("SELECT ?a ?b WHERE { ?a :extends ?b . }").unwrap();
978        let compiled = compile(&q, &opts()).unwrap();
979
980        assert!(
981            compiled
982                .sql
983                .contains("JOIN graph_edges e0 ON e0.source_id = n0.id"),
984            "SPARQL subject must bind graph_edges.source_id; sql: {}",
985            compiled.sql
986        );
987        assert!(
988            compiled
989                .sql
990                .contains("JOIN entities n1 ON n1.id = e0.target_id"),
991            "SPARQL object must bind graph_edges.target_id; sql: {}",
992            compiled.sql
993        );
994        assert!(
995            compiled.sql.contains("e0.relation = ?1"),
996            "SPARQL predicate must bind graph_edges.relation; sql: {}",
997            compiled.sql
998        );
999    }
1000
1001    #[test]
1002    fn return_property_projection_compiles() {
1003        let q =
1004            gql::parse("MATCH (a:concept)-[e:extends]->(b:concept) RETURN a.name, b.name LIMIT 5")
1005                .unwrap();
1006        let compiled = compile(&q, &opts()).unwrap();
1007        // Node aliases are n0, n1; the SQL uses `alias.col AS var_prop`
1008        assert!(
1009            compiled.sql.contains(".name AS a_name"),
1010            "sql: {}",
1011            compiled.sql
1012        );
1013        assert!(
1014            compiled.sql.contains(".name AS b_name"),
1015            "sql: {}",
1016            compiled.sql
1017        );
1018        assert!(
1019            !compiled.sql.contains("a_kind"),
1020            "should not emit full node columns"
1021        );
1022    }
1023
1024    #[test]
1025    fn return_unknown_node_property_rejected() {
1026        let q = gql::parse("MATCH (a:concept)-[:extends]->(b) RETURN a.domain LIMIT 5").unwrap();
1027        let err = compile(&q, &opts()).unwrap_err();
1028        assert!(
1029            matches!(err, QueryError::Compile(ref msg) if msg.contains("unknown node property 'domain'")),
1030            "got {err:?}"
1031        );
1032    }
1033
1034    #[test]
1035    fn return_unknown_edge_property_rejected() {
1036        let q = gql::parse("MATCH (a)-[e:extends]->(b) RETURN e.label LIMIT 5").unwrap();
1037        let err = compile(&q, &opts()).unwrap_err();
1038        assert!(
1039            matches!(err, QueryError::Compile(ref msg) if msg.contains("unknown edge property 'label'")),
1040            "got {err:?}"
1041        );
1042    }
1043
1044    #[test]
1045    fn return_valid_edge_property_compiles() {
1046        let q =
1047            gql::parse("MATCH (a)-[e:extends]->(b) RETURN e.relation, e.weight LIMIT 5").unwrap();
1048        let compiled = compile(&q, &opts()).unwrap();
1049        // Edge alias is e0; SQL: `e0.relation AS e_relation`
1050        assert!(
1051            compiled.sql.contains(".relation AS e_relation"),
1052            "sql: {}",
1053            compiled.sql
1054        );
1055        assert!(
1056            compiled.sql.contains(".weight AS e_weight"),
1057            "sql: {}",
1058            compiled.sql
1059        );
1060    }
1061}