alopex_sql/planner/
knn_optimizer.rs

1use crate::ast::expr::Literal;
2use crate::executor::evaluator::vector_ops::VectorMetric;
3use crate::planner::logical_plan::LogicalPlan;
4use crate::planner::typed_expr::{Projection, SortExpr, TypedExprKind};
5
/// KNN optimization pattern extracted from an `ORDER BY` + `LIMIT` plan.
///
/// Built by [`detect_knn_pattern`] when a plan matches the supported shape.
#[derive(Debug, Clone, PartialEq)]
pub struct KnnPattern {
    /// Name of the table read by the underlying scan.
    pub table: String,
    /// Name of the column passed as the first argument of the vector function.
    pub column: String,
    /// Query vector taken from the vector-literal argument, each component cast to `f32`.
    pub query_vector: Vec<f32>,
    /// Metric parsed from the string-literal third argument of the vector function.
    pub metric: VectorMetric,
    /// The `LIMIT` value, i.e. how many rows the query asks for.
    pub k: u64,
    /// Direction of the single `ORDER BY` key.
    pub sort_direction: SortDirection,
}
16
/// Sort direction (ASC / DESC).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    /// Ascending order (`ASC`).
    Asc,
    /// Descending order (`DESC`).
    Desc,
}
23
24/// ORDER BY vector_similarity/vector_distance + LIMIT K の形になっているか検出する。
25///
26/// - LIMIT が存在し、OFFSET が無い(または 0)の場合のみ最適化対象。
27/// - ORDER BY は単一のベクトル関数呼び出しであること。
28/// - メトリクスとソート方向の整合性を満たす場合のみ Some を返す。
29pub fn detect_knn_pattern(plan: &LogicalPlan) -> Option<KnnPattern> {
30    let (sort_plan, k) = extract_limit(plan)?;
31    let (order_expr, input_after_sort) = extract_sort(sort_plan)?;
32    let sort_direction = if order_expr.asc {
33        SortDirection::Asc
34    } else {
35        SortDirection::Desc
36    };
37
38    let (table, _projection, _filter) = extract_scan_context(input_after_sort)?;
39
40    let (func_name, args) = match &order_expr.expr.kind {
41        TypedExprKind::FunctionCall { name, args } => (name.to_ascii_lowercase(), args),
42        _ => return None,
43    };
44
45    if func_name != "vector_similarity" && func_name != "vector_distance" {
46        return None;
47    }
48
49    if args.len() != 3 {
50        return None;
51    }
52
53    let column_name = extract_column_name(&args[0], &table)?;
54    let query_vector = extract_query_vector(&args[1])?;
55    let metric = extract_metric(&args[2])?;
56
57    if !is_valid_knn_direction(metric, sort_direction) {
58        return None;
59    }
60
61    Some(KnnPattern {
62        table,
63        column: column_name,
64        query_vector,
65        metric,
66        k,
67        sort_direction,
68    })
69}
70
71fn extract_limit(plan: &LogicalPlan) -> Option<(&LogicalPlan, u64)> {
72    match plan {
73        LogicalPlan::Limit {
74            input,
75            limit: Some(k),
76            offset,
77        } if offset.unwrap_or(0) == 0 => Some((input.as_ref(), *k)),
78        _ => None,
79    }
80}
81
82fn extract_sort(plan: &LogicalPlan) -> Option<(&SortExpr, &LogicalPlan)> {
83    if let LogicalPlan::Sort { input, order_by } = plan
84        && order_by.len() == 1
85    {
86        return Some((&order_by[0], input.as_ref()));
87    }
88    None
89}
90
91fn extract_scan_context(
92    plan: &LogicalPlan,
93) -> Option<(
94    String,
95    Projection,
96    Option<crate::planner::typed_expr::TypedExpr>,
97)> {
98    match plan {
99        LogicalPlan::Filter { input, predicate } => {
100            if let LogicalPlan::Scan { table, projection } = input.as_ref() {
101                return Some((table.clone(), projection.clone(), Some(predicate.clone())));
102            }
103            None
104        }
105        LogicalPlan::Scan { table, projection } => Some((table.clone(), projection.clone(), None)),
106        _ => None,
107    }
108}
109
110fn extract_column_name(
111    expr: &crate::planner::typed_expr::TypedExpr,
112    table: &str,
113) -> Option<String> {
114    match &expr.kind {
115        TypedExprKind::ColumnRef {
116            table: tbl, column, ..
117        } if tbl == table => Some(column.clone()),
118        _ => None,
119    }
120}
121
122fn extract_query_vector(expr: &crate::planner::typed_expr::TypedExpr) -> Option<Vec<f32>> {
123    match &expr.kind {
124        TypedExprKind::VectorLiteral(values) if !values.is_empty() => {
125            Some(values.iter().map(|v| *v as f32).collect())
126        }
127        _ => None,
128    }
129}
130
131fn extract_metric(expr: &crate::planner::typed_expr::TypedExpr) -> Option<VectorMetric> {
132    match &expr.kind {
133        TypedExprKind::Literal(Literal::String(s)) => s.parse().ok(),
134        _ => None,
135    }
136}
137
138fn is_valid_knn_direction(metric: VectorMetric, dir: SortDirection) -> bool {
139    matches!(
140        (metric, dir),
141        (VectorMetric::Cosine, SortDirection::Desc)
142            | (VectorMetric::Inner, SortDirection::Desc)
143            | (VectorMetric::L2, SortDirection::Asc)
144    )
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::ddl::VectorMetric as AstVectorMetric;
    use crate::ast::span::Span;
    use crate::planner::logical_plan::LogicalPlan;
    use crate::planner::typed_expr::{Projection, SortExpr, TypedExpr};
    use crate::planner::types::ResolvedType;

    /// Type of the `embedding` column used by the fixtures: a 2-dimensional
    /// vector declared with the cosine metric.
    fn base_vector_type() -> ResolvedType {
        let metric = AstVectorMetric::Cosine;
        ResolvedType::Vector {
            dimension: 2,
            metric,
        }
    }

    /// Scan of the `items` table with `Projection::All` over `embedding`.
    fn items_scan() -> LogicalPlan {
        LogicalPlan::Scan {
            table: "items".to_string(),
            projection: Projection::All(vec!["embedding".to_string()]),
        }
    }

    /// Builds `Limit(2, offset) -> Sort(vector_similarity(...)) -> Scan(items)`.
    ///
    /// `order_asc` selects the sort direction and `metric_literal` is the
    /// string passed as the third argument of `vector_similarity`.
    fn build_plan(order_asc: bool, metric_literal: &str, offset: Option<u64>) -> LogicalPlan {
        let span = Span::empty();

        let embedding_column = TypedExpr::column_ref(
            "items".to_string(),
            "embedding".to_string(),
            0,
            base_vector_type(),
            span,
        );
        let query_vector = TypedExpr::vector_literal(vec![1.0, 0.0], 2, span);
        let metric_arg = TypedExpr::literal(
            Literal::String(metric_literal.to_string()),
            ResolvedType::Text,
            span,
        );

        let similarity_call = TypedExpr::function_call(
            "vector_similarity".to_string(),
            vec![embedding_column, query_vector, metric_arg],
            ResolvedType::Double,
            span,
        );

        let sorted = LogicalPlan::Sort {
            input: Box::new(items_scan()),
            order_by: vec![SortExpr::new(similarity_call, order_asc, false)],
        };

        LogicalPlan::Limit {
            input: Box::new(sorted),
            limit: Some(2),
            offset,
        }
    }

    /// Cosine similarity sorted descending with a limit is recognised as KNN.
    #[test]
    fn detect_knn_pattern_cosine_desc() {
        let expected = KnnPattern {
            table: "items".to_string(),
            column: "embedding".to_string(),
            query_vector: vec![1.0, 0.0],
            metric: VectorMetric::Cosine,
            k: 2,
            sort_direction: SortDirection::Desc,
        };

        let detected =
            detect_knn_pattern(&build_plan(false, "cosine", None)).expect("should detect pattern");
        assert_eq!(detected, expected);
    }

    /// Cosine sorted ascending is not an accepted metric/direction pairing.
    #[test]
    fn reject_invalid_direction() {
        assert_eq!(detect_knn_pattern(&build_plan(true, "cosine", None)), None);
    }

    /// A plan without a `Limit` root, or with a non-zero offset, is rejected.
    #[test]
    fn reject_missing_limit_or_offset() {
        let sort_without_limit = LogicalPlan::Sort {
            input: Box::new(items_scan()),
            order_by: vec![],
        };
        assert_eq!(detect_knn_pattern(&sort_without_limit), None);

        let limit_with_offset = build_plan(false, "cosine", Some(1));
        assert_eq!(detect_knn_pattern(&limit_with_offset), None);
    }

    /// A metric string that does not parse prevents detection.
    #[test]
    fn reject_unknown_metric() {
        assert_eq!(detect_knn_pattern(&build_plan(false, "unknown", None)), None);
    }
}