1use crate::ast::expr::Literal;
2use crate::executor::evaluator::vector_ops::VectorMetric;
3use crate::planner::logical_plan::LogicalPlan;
4use crate::planner::typed_expr::{Projection, SortExpr, TypedExprKind};
5
6#[derive(Debug, Clone, PartialEq)]
8pub struct KnnPattern {
9 pub table: String,
10 pub column: String,
11 pub query_vector: Vec<f32>,
12 pub metric: VectorMetric,
13 pub k: u64,
14 pub sort_direction: SortDirection,
15}
16
/// Ordering direction of the sort expression that drives a KNN query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    /// Ascending order (the sort expression's `asc` flag is set).
    Asc,
    /// Descending order (the sort expression's `asc` flag is clear).
    Desc,
}
23
24pub fn detect_knn_pattern(plan: &LogicalPlan) -> Option<KnnPattern> {
30 let (sort_plan, k) = extract_limit(plan)?;
31 let (order_expr, input_after_sort) = extract_sort(sort_plan)?;
32 let sort_direction = if order_expr.asc {
33 SortDirection::Asc
34 } else {
35 SortDirection::Desc
36 };
37
38 let (table, _projection, _filter) = extract_scan_context(input_after_sort)?;
39
40 let (func_name, args) = match &order_expr.expr.kind {
41 TypedExprKind::FunctionCall { name, args } => (name.to_ascii_lowercase(), args),
42 _ => return None,
43 };
44
45 if func_name != "vector_similarity" && func_name != "vector_distance" {
46 return None;
47 }
48
49 if args.len() != 3 {
50 return None;
51 }
52
53 let column_name = extract_column_name(&args[0], &table)?;
54 let query_vector = extract_query_vector(&args[1])?;
55 let metric = extract_metric(&args[2])?;
56
57 if !is_valid_knn_direction(metric, sort_direction) {
58 return None;
59 }
60
61 Some(KnnPattern {
62 table,
63 column: column_name,
64 query_vector,
65 metric,
66 k,
67 sort_direction,
68 })
69}
70
71fn extract_limit(plan: &LogicalPlan) -> Option<(&LogicalPlan, u64)> {
72 match plan {
73 LogicalPlan::Limit {
74 input,
75 limit: Some(k),
76 offset,
77 } if offset.unwrap_or(0) == 0 => Some((input.as_ref(), *k)),
78 _ => None,
79 }
80}
81
82fn extract_sort(plan: &LogicalPlan) -> Option<(&SortExpr, &LogicalPlan)> {
83 if let LogicalPlan::Sort { input, order_by } = plan
84 && order_by.len() == 1
85 {
86 return Some((&order_by[0], input.as_ref()));
87 }
88 None
89}
90
91fn extract_scan_context(
92 plan: &LogicalPlan,
93) -> Option<(
94 String,
95 Projection,
96 Option<crate::planner::typed_expr::TypedExpr>,
97)> {
98 match plan {
99 LogicalPlan::Filter { input, predicate } => {
100 if let LogicalPlan::Scan { table, projection } = input.as_ref() {
101 return Some((table.clone(), projection.clone(), Some(predicate.clone())));
102 }
103 None
104 }
105 LogicalPlan::Scan { table, projection } => Some((table.clone(), projection.clone(), None)),
106 _ => None,
107 }
108}
109
110fn extract_column_name(
111 expr: &crate::planner::typed_expr::TypedExpr,
112 table: &str,
113) -> Option<String> {
114 match &expr.kind {
115 TypedExprKind::ColumnRef {
116 table: tbl, column, ..
117 } if tbl == table => Some(column.clone()),
118 _ => None,
119 }
120}
121
122fn extract_query_vector(expr: &crate::planner::typed_expr::TypedExpr) -> Option<Vec<f32>> {
123 match &expr.kind {
124 TypedExprKind::VectorLiteral(values) if !values.is_empty() => {
125 Some(values.iter().map(|v| *v as f32).collect())
126 }
127 _ => None,
128 }
129}
130
131fn extract_metric(expr: &crate::planner::typed_expr::TypedExpr) -> Option<VectorMetric> {
132 match &expr.kind {
133 TypedExprKind::Literal(Literal::String(s)) => s.parse().ok(),
134 _ => None,
135 }
136}
137
138fn is_valid_knn_direction(metric: VectorMetric, dir: SortDirection) -> bool {
139 matches!(
140 (metric, dir),
141 (VectorMetric::Cosine, SortDirection::Desc)
142 | (VectorMetric::Inner, SortDirection::Desc)
143 | (VectorMetric::L2, SortDirection::Asc)
144 )
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::ddl::VectorMetric as AstVectorMetric;
    use crate::ast::span::Span;
    use crate::planner::logical_plan::LogicalPlan;
    use crate::planner::typed_expr::{Projection, SortExpr, TypedExpr};
    use crate::planner::types::ResolvedType;

    /// Scan of the `items` table with an `embedding` projection.
    fn items_scan() -> LogicalPlan {
        LogicalPlan::Scan {
            table: "items".to_string(),
            projection: Projection::All(vec!["embedding".to_string()]),
        }
    }

    /// Builds `vector_similarity(items.embedding, [1.0, 0.0], <metric_literal>)`.
    fn similarity_call(metric_literal: &str) -> TypedExpr {
        let span = Span::empty();
        let embedding = TypedExpr::column_ref(
            "items".to_string(),
            "embedding".to_string(),
            0,
            ResolvedType::Vector {
                dimension: 2,
                metric: AstVectorMetric::Cosine,
            },
            span,
        );
        let query = TypedExpr::vector_literal(vec![1.0, 0.0], 2, span);
        let metric = TypedExpr::literal(
            Literal::String(metric_literal.to_string()),
            ResolvedType::Text,
            span,
        );

        TypedExpr::function_call(
            "vector_similarity".to_string(),
            vec![embedding, query, metric],
            ResolvedType::Double,
            span,
        )
    }

    /// Builds `Limit(2, offset) -> Sort(similarity, order_asc) -> Scan(items)`.
    fn build_plan(order_asc: bool, metric_literal: &str, offset: Option<u64>) -> LogicalPlan {
        let sorted = LogicalPlan::Sort {
            input: Box::new(items_scan()),
            order_by: vec![SortExpr::new(
                similarity_call(metric_literal),
                order_asc,
                false,
            )],
        };

        LogicalPlan::Limit {
            input: Box::new(sorted),
            limit: Some(2),
            offset,
        }
    }

    #[test]
    fn detect_knn_pattern_cosine_desc() {
        let KnnPattern {
            table,
            column,
            query_vector,
            metric,
            k,
            sort_direction,
        } = detect_knn_pattern(&build_plan(false, "cosine", None)).expect("should detect pattern");

        assert_eq!(table, "items");
        assert_eq!(column, "embedding");
        assert_eq!(k, 2);
        assert_eq!(metric, VectorMetric::Cosine);
        assert_eq!(sort_direction, SortDirection::Desc);
        assert_eq!(query_vector, vec![1.0, 0.0]);
    }

    #[test]
    fn reject_invalid_direction() {
        // Cosine ordered ascending is not an accepted combination.
        let ascending_cosine = build_plan(true, "cosine", None);
        assert!(detect_knn_pattern(&ascending_cosine).is_none());
    }

    #[test]
    fn reject_missing_limit_or_offset() {
        // A bare Sort (no Limit on top) is not a KNN shape.
        let sort_only = LogicalPlan::Sort {
            input: Box::new(items_scan()),
            order_by: vec![],
        };
        assert!(detect_knn_pattern(&sort_only).is_none());

        // A non-zero offset disqualifies an otherwise valid plan.
        let with_offset = build_plan(false, "cosine", Some(1));
        assert!(detect_knn_pattern(&with_offset).is_none());
    }

    #[test]
    fn reject_unknown_metric() {
        let unparsable_metric = build_plan(false, "unknown", None);
        assert!(detect_knn_pattern(&unparsable_metric).is_none());
    }
}