nodedb_sql/ddl_ast/graph_parse/
fusion_params.rs1use super::super::statement::GraphDirection;
18use super::helpers::{
19 array_floats_after, float_pair_after, float_triple_after, quoted_after, usize_after, word_after,
20};
21use super::tokenizer::{Tok, tokenize};
22
/// Keyword spellings that `FusionParams::extract` looks for in a fusion statement.
///
/// Different statement forms spell the same clause differently (for example
/// `EXPANSION_DEPTH` versus `DEPTH`), so the extractor is driven by one of
/// these tables instead of hard-coding the words.
///
/// All fields are `&'static str`, so the table is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FusionKeywords {
    /// Keyword followed by the integer stored in `FusionParams::vector_top_k`.
    pub vector_top_k: &'static str,
    /// Keyword followed by the integer stored in `FusionParams::expansion_depth`.
    pub expansion_depth: &'static str,
    /// Keyword followed by the quoted string stored in `FusionParams::edge_label`.
    pub edge_label: &'static str,
    /// Keyword followed by the integer stored in `FusionParams::final_top_k`.
    pub final_top_k: &'static str,
    /// Keyword followed by a float triple or a float pair; the triple form is
    /// tried first and lands in `FusionParams::rrf_k_triple`, the pair form in
    /// `FusionParams::rrf_k`.
    pub rrf_k: &'static str,
    /// Keyword followed by the quoted string stored in `FusionParams::vector_field`.
    pub vector_field: &'static str,
    /// Keyword followed by a direction word; `IN`, `OUT` and `BOTH` are
    /// recognised case-insensitively, anything else yields no direction.
    pub direction: &'static str,
    /// Keyword followed by the integer stored in `FusionParams::max_visited`.
    pub max_visited: &'static str,
    /// Anchor handed to `array_floats_after` together with the raw SQL text
    /// (not the token stream) to locate the query vector.
    pub query_anchor: &'static str,
    /// Keyword followed by the quoted BM25 query text. When this is the empty
    /// string, neither BM25 field is extracted.
    pub bm25_query: &'static str,
    /// Keyword followed by the quoted string stored in `FusionParams::bm25_field`.
    /// Only consulted when `bm25_query` is non-empty.
    pub bm25_field: &'static str,
}
46
47pub const RAG_FUSION_KEYWORDS: FusionKeywords = FusionKeywords {
49 vector_top_k: "VECTOR_TOP_K",
50 expansion_depth: "EXPANSION_DEPTH",
51 edge_label: "EDGE_LABEL",
52 final_top_k: "FINAL_TOP_K",
53 rrf_k: "RRF_K",
54 vector_field: "VECTOR_FIELD",
55 direction: "DIRECTION",
56 max_visited: "MAX_VISITED",
57 query_anchor: "QUERY",
58 bm25_query: "BM25",
59 bm25_field: "ON",
60};
61
62pub const SEARCH_FUSION_KEYWORDS: FusionKeywords = FusionKeywords {
64 vector_top_k: "VECTOR_TOP_K",
65 expansion_depth: "DEPTH",
66 edge_label: "LABEL",
67 final_top_k: "TOP",
68 rrf_k: "RRF_K",
69 vector_field: "VECTOR_FIELD",
70 direction: "DIRECTION",
71 max_visited: "MAX_VISITED",
72 query_anchor: "ARRAY",
73 bm25_query: "BM25",
74 bm25_field: "ON",
75};
76
77#[derive(Debug, Clone, Default, PartialEq)]
88pub struct FusionParams {
89 pub query_vector: Option<Vec<f32>>,
90 pub vector_top_k: Option<usize>,
91 pub expansion_depth: Option<usize>,
92 pub edge_label: Option<String>,
93 pub final_top_k: Option<usize>,
94 pub rrf_k: Option<(f64, f64)>,
97 pub rrf_k_triple: Option<(f64, f64, f64)>,
100 pub vector_field: Option<String>,
101 pub direction: Option<GraphDirection>,
102 pub max_visited: Option<usize>,
103 pub bm25_query: Option<String>,
106 pub bm25_field: Option<String>,
108}
109
110impl FusionParams {
111 pub(super) fn extract(toks: &[Tok<'_>], sql: &str, kw: &FusionKeywords) -> Self {
112 let direction = match word_after(toks, kw.direction)
113 .as_deref()
114 .map(str::to_ascii_uppercase)
115 .as_deref()
116 {
117 Some("IN") => Some(GraphDirection::In),
118 Some("BOTH") => Some(GraphDirection::Both),
119 Some("OUT") => Some(GraphDirection::Out),
120 _ => None,
121 };
122
123 let rrf_k_triple = float_triple_after(toks, kw.rrf_k);
128 let rrf_k = if rrf_k_triple.is_some() {
129 None
130 } else {
131 float_pair_after(toks, kw.rrf_k)
132 };
133
134 let (bm25_query, bm25_field) = if !kw.bm25_query.is_empty() {
136 (
137 quoted_after(toks, kw.bm25_query),
138 quoted_after(toks, kw.bm25_field),
139 )
140 } else {
141 (None, None)
142 };
143
144 Self {
145 query_vector: array_floats_after(sql, kw.query_anchor),
146 vector_top_k: usize_after(toks, kw.vector_top_k),
147 expansion_depth: usize_after(toks, kw.expansion_depth),
148 edge_label: quoted_after(toks, kw.edge_label),
149 final_top_k: usize_after(toks, kw.final_top_k),
150 rrf_k,
151 rrf_k_triple,
152 vector_field: quoted_after(toks, kw.vector_field),
153 direction,
154 max_visited: usize_after(toks, kw.max_visited),
155 bm25_query,
156 bm25_field,
157 }
158 }
159}
160
161pub fn parse_search_using_fusion(sql: &str) -> Option<(String, FusionParams)> {
169 let toks = tokenize(sql);
170 let collection = match toks.as_slice() {
171 [Tok::Word(s), Tok::Word(c), Tok::Word(u), Tok::Word(f), ..]
172 if s.eq_ignore_ascii_case("SEARCH")
173 && u.eq_ignore_ascii_case("USING")
174 && f.eq_ignore_ascii_case("FUSION") =>
175 {
176 (*c).to_string()
177 }
178 _ => return None,
179 };
180 Some((
181 collection,
182 FusionParams::extract(&toks, sql, &SEARCH_FUSION_KEYWORDS),
183 ))
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-source statement: every short-form clause plus the pair form of `RRF_K`.
    #[test]
    fn search_fusion_full_surface_parses() {
        let sql = "SEARCH mycol USING FUSION(ARRAY[0.1, 0.2] VECTOR_TOP_K 5 DEPTH 2 \
                   LABEL 'related' TOP 10 RRF_K (60.0, 35.0))";
        let (collection, params) = parse_search_using_fusion(sql).unwrap();

        assert_eq!(collection, "mycol");
        assert_eq!(params.query_vector.as_ref().map(Vec::len), Some(2));
        assert_eq!(params.vector_top_k, Some(5));
        assert_eq!(params.expansion_depth, Some(2));
        assert_eq!(params.edge_label.as_deref(), Some("related"));
        assert_eq!(params.final_top_k, Some(10));
        assert_eq!(params.rrf_k, Some((60.0, 35.0)));
        assert!(params.rrf_k_triple.is_none());
    }

    /// Three-source statement: BM25 clauses plus the triple form of `RRF_K`,
    /// which must leave the pair field empty.
    #[test]
    fn search_fusion_three_source_parses() {
        let sql = "SEARCH entities USING FUSION(ARRAY[0.1, 0.3] VECTOR_FIELD 'embedding' \
                   VECTOR_TOP_K 50 BM25 'transformer attention' ON 'body' \
                   DEPTH 2 LABEL 'related_to' TOP 10 RRF_K (60.0, 35.0, 50.0))";
        let (collection, params) = parse_search_using_fusion(sql).unwrap();

        assert_eq!(collection, "entities");
        assert!(params.rrf_k.is_none());
        assert_eq!(params.rrf_k_triple, Some((60.0, 35.0, 50.0)));
        assert_eq!(params.bm25_query.as_deref(), Some("transformer attention"));
        assert_eq!(params.bm25_field.as_deref(), Some("body"));
        assert_eq!(params.expansion_depth, Some(2));
        assert_eq!(params.edge_label.as_deref(), Some("related_to"));
        assert_eq!(params.final_top_k, Some(10));
    }

    /// A quoted label that begins with `TOP` must not be mistaken for the
    /// `TOP` keyword.
    #[test]
    fn search_fusion_label_literal_that_shadows_top_keyword() {
        let sql = "SEARCH c USING FUSION(ARRAY[0.5] LABEL 'TOP_SECRET' TOP 7)";
        let (_, params) = parse_search_using_fusion(sql).unwrap();

        assert_eq!(params.edge_label.as_deref(), Some("TOP_SECRET"));
        assert_eq!(params.final_top_k, Some(7));
    }

    /// Statements without the `SEARCH <name> USING FUSION` prefix are rejected.
    #[test]
    fn search_fusion_rejects_wrong_prefix() {
        let rejected = [
            "INSERT INTO x VALUES (1)",
            "SEARCH x USING VECTOR(ARRAY[1.0])",
        ];
        for sql in rejected {
            assert!(parse_search_using_fusion(sql).is_none());
        }
    }
}