Skip to main content

nodedb_sql/ddl_ast/graph_parse/
fusion_params.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Shared parameter extraction for graph-vector fusion SQL surfaces.
4//!
5//! Two syntaxes reach the same `GraphOp::RagFusion` executor today:
6//!
7//! - `GRAPH RAG FUSION ON <col> QUERY ARRAY[...] ...` (DSL form)
8//! - `SEARCH <col> USING FUSION(ARRAY[...] ...)` (wrapped form)
9//!
10//! They use different keyword aliases for the same parameters
11//! (`EXPANSION_DEPTH` vs `DEPTH`, `EDGE_LABEL` vs `LABEL`, `FINAL_TOP_K`
12//! vs `TOP`). Both must extract the same typed bag so future fusion
13//! variants (hybrid text+vector, multi-vector, etc.) can share this
14//! code and cannot silently drop parameters the way substring-find
15//! parsing did.
16
17use super::super::statement::GraphDirection;
18use super::helpers::{
19    array_floats_after, float_pair_after, float_triple_after, quoted_after, usize_after, word_after,
20};
21use super::tokenizer::{Tok, tokenize};
22
/// Keyword aliases for the shared fusion parameters.
///
/// Each fusion SQL surface picks one of the `*_KEYWORDS` constants below.
/// New fusion variants add their own constant rather than editing the
/// extractor.
///
/// Every field holds the SQL keyword that introduces the corresponding
/// value on one particular surface. The table is made of `&'static str`
/// only, so it is cheap to copy, compare, and print while debugging.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FusionKeywords {
    /// Keyword that precedes the vector-search candidate count.
    pub vector_top_k: &'static str,
    /// Keyword that precedes the graph expansion depth.
    pub expansion_depth: &'static str,
    /// Keyword that precedes the quoted edge label.
    pub edge_label: &'static str,
    /// Keyword that precedes the size of the final fused result set.
    pub final_top_k: &'static str,
    /// Keyword that precedes the parenthesised RRF k constants.
    pub rrf_k: &'static str,
    /// Keyword that precedes the quoted vector field name.
    pub vector_field: &'static str,
    /// Keyword that precedes the traversal direction word.
    pub direction: &'static str,
    /// Keyword that precedes the visited-node cap.
    pub max_visited: &'static str,
    /// Keyword that precedes `ARRAY[...]` in raw SQL (e.g. `QUERY` or
    /// `ARRAY` itself when there is no leading keyword).
    pub query_anchor: &'static str,
    /// Keyword that precedes the BM25 query string for three-source fusion.
    /// Empty string disables BM25 parsing for surfaces that do not support it.
    pub bm25_query: &'static str,
    /// Keyword that precedes the BM25 field name in three-source fusion.
    pub bm25_field: &'static str,
}
46
/// Keywords used by `GRAPH RAG FUSION ON ...`.
///
/// This surface spells the aliased parameters with their long names
/// (`EXPANSION_DEPTH`, `EDGE_LABEL`, `FINAL_TOP_K`) and introduces the
/// query vector with a leading `QUERY` keyword.
pub const RAG_FUSION_KEYWORDS: FusionKeywords = FusionKeywords {
    vector_top_k: "VECTOR_TOP_K",
    expansion_depth: "EXPANSION_DEPTH",
    edge_label: "EDGE_LABEL",
    final_top_k: "FINAL_TOP_K",
    rrf_k: "RRF_K",
    vector_field: "VECTOR_FIELD",
    direction: "DIRECTION",
    max_visited: "MAX_VISITED",
    // The vector literal is written `QUERY ARRAY[...]` on this surface.
    query_anchor: "QUERY",
    bm25_query: "BM25",
    // NOTE(review): `ON` is also the header keyword in
    // `GRAPH RAG FUSION ON <col>` — confirm the BM25 field lookup cannot
    // resolve to the header occurrence instead of the BM25 clause.
    bm25_field: "ON",
};
61
/// Keywords used by `SEARCH ... USING FUSION(...)`.
///
/// Differs from the `GRAPH RAG FUSION` table in three aliases (`DEPTH`,
/// `LABEL`, `TOP`) and in the query anchor: the vector literal has no
/// leading keyword here, so `ARRAY` itself is the anchor.
pub const SEARCH_FUSION_KEYWORDS: FusionKeywords = FusionKeywords {
    vector_top_k: "VECTOR_TOP_K",
    // Short alias of `EXPANSION_DEPTH`.
    expansion_depth: "DEPTH",
    // Short alias of `EDGE_LABEL`.
    edge_label: "LABEL",
    // Short alias of `FINAL_TOP_K`.
    final_top_k: "TOP",
    rrf_k: "RRF_K",
    vector_field: "VECTOR_FIELD",
    direction: "DIRECTION",
    max_visited: "MAX_VISITED",
    // No leading keyword on this surface: anchor on the literal itself.
    query_anchor: "ARRAY",
    bm25_query: "BM25",
    bm25_field: "ON",
};
76
/// Typed parameter bag for every graph-vector fusion SQL surface.
///
/// All fields are optional at parse time — bounds, caps, and
/// "absent but required" errors are enforced at the pgwire boundary.
///
/// Three-source fusion (vector + text + graph) is enabled by populating
/// `bm25_query` and `bm25_field` together with `rrf_k_triple`. When only
/// `rrf_k` is set (two values), behaviour is unchanged from the two-source
/// form. When `rrf_k_triple` is set it takes precedence and the BM25 leg
/// participates in the fusion.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct FusionParams {
    /// Query vector parsed from the `ARRAY[...]` literal that follows the
    /// surface's query anchor keyword.
    pub query_vector: Option<Vec<f32>>,
    /// Unsigned integer following the vector top-k keyword
    /// (`VECTOR_TOP_K`).
    pub vector_top_k: Option<usize>,
    /// Unsigned integer following the expansion-depth keyword
    /// (`EXPANSION_DEPTH` or `DEPTH`, depending on the surface).
    pub expansion_depth: Option<usize>,
    /// Quoted string following the edge-label keyword (`EDGE_LABEL` or
    /// `LABEL`, depending on the surface).
    pub edge_label: Option<String>,
    /// Unsigned integer following the final top-k keyword (`FINAL_TOP_K`
    /// or `TOP`, depending on the surface).
    pub final_top_k: Option<usize>,
    /// Two-source RRF k constants: `(vector_k, graph_k)`. Used when no
    /// `bm25_query` is present (backwards-compatible two-source form).
    pub rrf_k: Option<(f64, f64)>,
    /// Three-source RRF k constants: `(vector_k, text_k, graph_k)`. Set
    /// when `RRF_K (kv, kt, kg)` is parsed and three values are found.
    pub rrf_k_triple: Option<(f64, f64, f64)>,
    /// Quoted string following the `VECTOR_FIELD` keyword.
    pub vector_field: Option<String>,
    /// Traversal direction parsed from the word after `DIRECTION`.
    /// Only `IN`, `OUT`, and `BOTH` (any ASCII case) are recognised; any
    /// other word leaves this `None`.
    pub direction: Option<GraphDirection>,
    /// Unsigned integer following the `MAX_VISITED` keyword.
    pub max_visited: Option<usize>,
    /// BM25 query string for the text leg of three-source fusion. Parsed
    /// from `BM25 'query string' ON 'field_name'` in the FUSION DSL.
    pub bm25_query: Option<String>,
    /// Document field on which BM25 scoring is applied in three-source fusion.
    pub bm25_field: Option<String>,
}
109
110impl FusionParams {
111    pub(super) fn extract(toks: &[Tok<'_>], sql: &str, kw: &FusionKeywords) -> Self {
112        let direction = match word_after(toks, kw.direction)
113            .as_deref()
114            .map(str::to_ascii_uppercase)
115            .as_deref()
116        {
117            Some("IN") => Some(GraphDirection::In),
118            Some("BOTH") => Some(GraphDirection::Both),
119            Some("OUT") => Some(GraphDirection::Out),
120            _ => None,
121        };
122
123        // Try to parse a three-value RRF_K triple first; fall back to the
124        // two-value pair. This way `RRF_K (60.0, 35.0, 50.0)` populates
125        // `rrf_k_triple` and leaves `rrf_k` as None, while the legacy
126        // `RRF_K (60.0, 35.0)` continues to populate only `rrf_k`.
127        let rrf_k_triple = float_triple_after(toks, kw.rrf_k);
128        let rrf_k = if rrf_k_triple.is_some() {
129            None
130        } else {
131            float_pair_after(toks, kw.rrf_k)
132        };
133
134        // BM25 text leg — only parsed when the keyword is non-empty.
135        let (bm25_query, bm25_field) = if !kw.bm25_query.is_empty() {
136            (
137                quoted_after(toks, kw.bm25_query),
138                quoted_after(toks, kw.bm25_field),
139            )
140        } else {
141            (None, None)
142        };
143
144        Self {
145            query_vector: array_floats_after(sql, kw.query_anchor),
146            vector_top_k: usize_after(toks, kw.vector_top_k),
147            expansion_depth: usize_after(toks, kw.expansion_depth),
148            edge_label: quoted_after(toks, kw.edge_label),
149            final_top_k: usize_after(toks, kw.final_top_k),
150            rrf_k,
151            rrf_k_triple,
152            vector_field: quoted_after(toks, kw.vector_field),
153            direction,
154            max_visited: usize_after(toks, kw.max_visited),
155            bm25_query,
156            bm25_field,
157        }
158    }
159}
160
161/// Parse `SEARCH <collection> USING FUSION(...)` into its collection name
162/// and a typed [`FusionParams`]. Returns `None` when the SQL does not
163/// match the expected shape.
164///
165/// Body extraction uses the same quote- and bracket-aware tokenizer as
166/// the `GRAPH RAG FUSION` path, so a keyword-shaped string literal (e.g.
167/// a label value `'TOP'`) cannot shadow a real parameter keyword.
168pub fn parse_search_using_fusion(sql: &str) -> Option<(String, FusionParams)> {
169    let toks = tokenize(sql);
170    let collection = match toks.as_slice() {
171        [Tok::Word(s), Tok::Word(c), Tok::Word(u), Tok::Word(f), ..]
172            if s.eq_ignore_ascii_case("SEARCH")
173                && u.eq_ignore_ascii_case("USING")
174                && f.eq_ignore_ascii_case("FUSION") =>
175        {
176            (*c).to_string()
177        }
178        _ => return None,
179    };
180    Some((
181        collection,
182        FusionParams::extract(&toks, sql, &SEARCH_FUSION_KEYWORDS),
183    ))
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `sql` with `parse_search_using_fusion`, panicking when the
    /// statement is rejected.
    fn parse_ok(sql: &str) -> (String, FusionParams) {
        parse_search_using_fusion(sql).unwrap()
    }

    #[test]
    fn search_fusion_full_surface_parses() {
        let sql = "SEARCH mycol USING FUSION(ARRAY[0.1, 0.2] VECTOR_TOP_K 5 DEPTH 2 \
                   LABEL 'related' TOP 10 RRF_K (60.0, 35.0))";
        let (collection, params) = parse_ok(sql);

        assert_eq!(collection, "mycol");
        assert_eq!(params.query_vector.as_ref().map(Vec::len), Some(2));
        assert_eq!(params.vector_top_k, Some(5));
        assert_eq!(params.expansion_depth, Some(2));
        assert_eq!(params.edge_label, Some(String::from("related")));
        assert_eq!(params.final_top_k, Some(10));
        // Two values after RRF_K select the two-source form only.
        assert_eq!(params.rrf_k, Some((60.0, 35.0)));
        assert!(params.rrf_k_triple.is_none());
    }

    #[test]
    fn search_fusion_three_source_parses() {
        let sql = "SEARCH entities USING FUSION(ARRAY[0.1, 0.3] VECTOR_FIELD 'embedding' \
                   VECTOR_TOP_K 50 BM25 'transformer attention' ON 'body' \
                   DEPTH 2 LABEL 'related_to' TOP 10 RRF_K (60.0, 35.0, 50.0))";
        let (collection, params) = parse_ok(sql);

        assert_eq!(collection, "entities");
        // Three values after RRF_K populate the triple and leave the pair empty.
        assert!(params.rrf_k.is_none());
        assert_eq!(params.rrf_k_triple, Some((60.0, 35.0, 50.0)));
        assert_eq!(
            params.bm25_query,
            Some(String::from("transformer attention"))
        );
        assert_eq!(params.bm25_field, Some(String::from("body")));
        assert_eq!(params.expansion_depth, Some(2));
        assert_eq!(params.edge_label, Some(String::from("related_to")));
        assert_eq!(params.final_top_k, Some(10));
    }

    #[test]
    fn search_fusion_label_literal_that_shadows_top_keyword() {
        // The label literal starts with `TOP`; because quoted strings are
        // kept whole by the tokenizer, only the bare `TOP 7` may be read as
        // the numeric parameter.
        let (_, params) = parse_ok("SEARCH c USING FUSION(ARRAY[0.5] LABEL 'TOP_SECRET' TOP 7)");

        assert_eq!(params.edge_label, Some(String::from("TOP_SECRET")));
        assert_eq!(params.final_top_k, Some(7));
    }

    #[test]
    fn search_fusion_rejects_wrong_prefix() {
        let rejected = [
            "INSERT INTO x VALUES (1)",
            "SEARCH x USING VECTOR(ARRAY[1.0])",
        ];
        for sql in rejected.iter() {
            assert!(parse_search_using_fusion(sql).is_none());
        }
    }
}