nodedb_sql/ddl_ast/graph_parse/
fusion_params.rs1use super::super::statement::GraphDirection;
16use super::helpers::{array_floats_after, float_pair_after, quoted_after, usize_after, word_after};
17use super::tokenizer::{Tok, tokenize};
18
/// Keyword spellings for one fusion grammar.
///
/// `FusionParams::extract` hands each field to the matching lookup helper,
/// so a single extraction routine can serve grammars that spell the same
/// clause differently (compare `RAG_FUSION_KEYWORDS` with
/// `SEARCH_FUSION_KEYWORDS`).
///
/// The type only holds `&'static str` values, so it derives the common
/// value-type traits: it can be debug-printed, copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FusionKeywords {
    /// Keyword looked up with `usize_after` to fill `FusionParams::vector_top_k`.
    pub vector_top_k: &'static str,
    /// Keyword looked up with `usize_after` to fill `FusionParams::expansion_depth`.
    pub expansion_depth: &'static str,
    /// Keyword looked up with `quoted_after` to fill `FusionParams::edge_label`.
    pub edge_label: &'static str,
    /// Keyword looked up with `usize_after` to fill `FusionParams::final_top_k`.
    pub final_top_k: &'static str,
    /// Keyword looked up with `float_pair_after` to fill `FusionParams::rrf_k`.
    pub rrf_k: &'static str,
    /// Keyword looked up with `quoted_after` to fill `FusionParams::vector_field`.
    pub vector_field: &'static str,
    /// Keyword looked up with `word_after` to fill `FusionParams::direction`.
    pub direction: &'static str,
    /// Keyword looked up with `usize_after` to fill `FusionParams::max_visited`.
    pub max_visited: &'static str,
    /// Anchor text passed, together with the raw SQL, to `array_floats_after`
    /// to fill `FusionParams::query_vector`.
    pub query_anchor: &'static str,
}
37
38pub const RAG_FUSION_KEYWORDS: FusionKeywords = FusionKeywords {
40 vector_top_k: "VECTOR_TOP_K",
41 expansion_depth: "EXPANSION_DEPTH",
42 edge_label: "EDGE_LABEL",
43 final_top_k: "FINAL_TOP_K",
44 rrf_k: "RRF_K",
45 vector_field: "VECTOR_FIELD",
46 direction: "DIRECTION",
47 max_visited: "MAX_VISITED",
48 query_anchor: "QUERY",
49};
50
51pub const SEARCH_FUSION_KEYWORDS: FusionKeywords = FusionKeywords {
53 vector_top_k: "VECTOR_TOP_K",
54 expansion_depth: "DEPTH",
55 edge_label: "LABEL",
56 final_top_k: "TOP",
57 rrf_k: "RRF_K",
58 vector_field: "VECTOR_FIELD",
59 direction: "DIRECTION",
60 max_visited: "MAX_VISITED",
61 query_anchor: "ARRAY",
62};
63
64#[derive(Debug, Clone, Default, PartialEq)]
69pub struct FusionParams {
70 pub query_vector: Option<Vec<f32>>,
71 pub vector_top_k: Option<usize>,
72 pub expansion_depth: Option<usize>,
73 pub edge_label: Option<String>,
74 pub final_top_k: Option<usize>,
75 pub rrf_k: Option<(f64, f64)>,
76 pub vector_field: Option<String>,
77 pub direction: Option<GraphDirection>,
78 pub max_visited: Option<usize>,
79}
80
81impl FusionParams {
82 pub(super) fn extract(toks: &[Tok<'_>], sql: &str, kw: &FusionKeywords) -> Self {
83 let direction = match word_after(toks, kw.direction)
84 .as_deref()
85 .map(str::to_ascii_uppercase)
86 .as_deref()
87 {
88 Some("IN") => Some(GraphDirection::In),
89 Some("BOTH") => Some(GraphDirection::Both),
90 Some("OUT") => Some(GraphDirection::Out),
91 _ => None,
92 };
93 Self {
94 query_vector: array_floats_after(sql, kw.query_anchor),
95 vector_top_k: usize_after(toks, kw.vector_top_k),
96 expansion_depth: usize_after(toks, kw.expansion_depth),
97 edge_label: quoted_after(toks, kw.edge_label),
98 final_top_k: usize_after(toks, kw.final_top_k),
99 rrf_k: float_pair_after(toks, kw.rrf_k),
100 vector_field: quoted_after(toks, kw.vector_field),
101 direction,
102 max_visited: usize_after(toks, kw.max_visited),
103 }
104 }
105}
106
107pub fn parse_search_using_fusion(sql: &str) -> Option<(String, FusionParams)> {
115 let toks = tokenize(sql);
116 let collection = match toks.as_slice() {
117 [Tok::Word(s), Tok::Word(c), Tok::Word(u), Tok::Word(f), ..]
118 if s.eq_ignore_ascii_case("SEARCH")
119 && u.eq_ignore_ascii_case("USING")
120 && f.eq_ignore_ascii_case("FUSION") =>
121 {
122 (*c).to_string()
123 }
124 _ => return None,
125 };
126 Some((
127 collection,
128 FusionParams::extract(&toks, sql, &SEARCH_FUSION_KEYWORDS),
129 ))
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// A statement carrying the vector, top-k, depth, label, final top and
    /// RRF clauses yields each of those values.
    #[test]
    fn search_fusion_full_surface_parses() {
        let sql = "SEARCH mycol USING FUSION(ARRAY[0.1, 0.2] VECTOR_TOP_K 5 DEPTH 2 \
                   LABEL 'related' TOP 10 RRF_K (60.0, 35.0))";
        let (collection, params) = parse_search_using_fusion(sql).unwrap();

        assert_eq!(collection, "mycol");
        assert_eq!(params.query_vector.as_ref().map(Vec::len), Some(2));
        assert_eq!(params.vector_top_k, Some(5));
        assert_eq!(params.expansion_depth, Some(2));
        assert_eq!(params.edge_label.as_deref(), Some("related"));
        assert_eq!(params.final_top_k, Some(10));
        assert_eq!(params.rrf_k, Some((60.0, 35.0)));
    }

    /// A quoted label that begins with `TOP` must not be taken for the `TOP`
    /// keyword.
    #[test]
    fn search_fusion_label_literal_that_shadows_top_keyword() {
        let sql = "SEARCH c USING FUSION(ARRAY[0.5] LABEL 'TOP_SECRET' TOP 7)";
        let (_, params) = parse_search_using_fusion(sql).unwrap();

        assert_eq!(params.edge_label.as_deref(), Some("TOP_SECRET"));
        assert_eq!(params.final_top_k, Some(7));
    }

    /// Statements that do not start with `SEARCH <name> USING FUSION` are
    /// rejected.
    #[test]
    fn search_fusion_rejects_wrong_prefix() {
        let rejected = [
            "INSERT INTO x VALUES (1)",
            "SEARCH x USING VECTOR(ARRAY[1.0])",
        ];
        for sql in rejected.iter() {
            assert!(parse_search_using_fusion(sql).is_none());
        }
    }
}