Skip to main content

dbkit_core/
func.rs

1use crate::expr::{Expr, ExprNode, IntoExpr, VectorBinaryOp};
2use crate::PgVector;
3
4pub fn upper(arg: impl IntoExpr<String>) -> Expr<String> {
5    let expr = arg.into_expr();
6    Expr::new(ExprNode::Func {
7        name: "UPPER",
8        args: vec![expr.node],
9    })
10}
11
12pub fn count<T>(arg: impl IntoExpr<T>) -> Expr<i64> {
13    let expr = arg.into_expr();
14    Expr::new(ExprNode::Func {
15        name: "COUNT",
16        args: vec![expr.node],
17    })
18}
19
20pub fn sum<T>(arg: impl IntoExpr<T>) -> Expr<T> {
21    let expr = arg.into_expr();
22    Expr::new(ExprNode::Func {
23        name: "SUM",
24        args: vec![expr.node],
25    })
26}
27
28pub fn coalesce<T>(a: impl IntoExpr<T>, b: impl IntoExpr<T>) -> Expr<T> {
29    let left = a.into_expr();
30    let right = b.into_expr();
31    Expr::new(ExprNode::Func {
32        name: "COALESCE",
33        args: vec![left.node, right.node],
34    })
35}
36
37pub fn date_trunc<T>(part: impl IntoExpr<String>, value: impl IntoExpr<T>) -> Expr<T> {
38    let part = part.into_expr();
39    let value = value.into_expr();
40    Expr::new(ExprNode::Func {
41        name: "DATE_TRUNC",
42        args: vec![part.node, value.node],
43    })
44}
45
46/// Marker trait for values that can participate in vector distance/similarity expressions.
47pub trait VectorExpr<const N: usize> {}
48
49impl<const N: usize> VectorExpr<N> for PgVector<N> {}
50impl<const N: usize> VectorExpr<N> for Option<PgVector<N>> {}
51
52fn vector_binary_fn<const N: usize, L, R>(name: &'static str, left: impl IntoExpr<L>, right: impl IntoExpr<R>) -> Expr<f32>
53where
54    L: VectorExpr<N>,
55    R: VectorExpr<N>,
56{
57    let left = left.into_expr();
58    let right = right.into_expr();
59    Expr::new(ExprNode::Func {
60        name,
61        args: vec![left.node, right.node],
62    })
63}
64
65fn vector_binary_operator<const N: usize, L, R>(op: VectorBinaryOp, left: impl IntoExpr<L>, right: impl IntoExpr<R>) -> Expr<f32>
66where
67    L: VectorExpr<N>,
68    R: VectorExpr<N>,
69{
70    let left = left.into_expr();
71    let right = right.into_expr();
72    Expr::new(ExprNode::VectorBinary {
73        left: Box::new(left.node),
74        op,
75        right: Box::new(right.node),
76    })
77}
78
79/// Euclidean (L2) distance using pgvector's `<->` operator.
80///
81/// Lower is more similar.
82///
83/// ANN note:
84/// - This form is operator-based and can use pgvector ivfflat/hnsw indexes for
85///   `ORDER BY ... LIMIT` nearest-neighbor queries.
86pub fn l2_distance<const N: usize, L, R>(left: impl IntoExpr<L>, right: impl IntoExpr<R>) -> Expr<f32>
87where
88    L: VectorExpr<N>,
89    R: VectorExpr<N>,
90{
91    vector_binary_operator::<N, L, R>(VectorBinaryOp::L2Distance, left, right)
92}
93
94/// Cosine distance using pgvector's `<=>` operator.
95///
96/// Lower is more similar.
97///
98/// ANN note:
99/// - This form is operator-based and can use pgvector ivfflat/hnsw indexes for
100///   `ORDER BY ... LIMIT` nearest-neighbor queries.
101pub fn cosine_distance<const N: usize, L, R>(left: impl IntoExpr<L>, right: impl IntoExpr<R>) -> Expr<f32>
102where
103    L: VectorExpr<N>,
104    R: VectorExpr<N>,
105{
106    vector_binary_operator::<N, L, R>(VectorBinaryOp::CosineDistance, left, right)
107}
108
109/// True inner product as a function expression (`INNER_PRODUCT(a, b)`).
110///
111/// Higher is more similar (for normalized embeddings, identical vectors are `1.0`).
112///
113/// ANN warning:
114/// - This is intentionally a function call to preserve true inner-product semantics,
115///   but function expressions are generally not pgvector ANN index-compatible for
116///   `ORDER BY ... LIMIT`.
117/// - For ANN-indexed retrieval, use [`inner_product_distance`] with `ORDER BY ASC`.
118pub fn inner_product<const N: usize, L, R>(left: impl IntoExpr<L>, right: impl IntoExpr<R>) -> Expr<f32>
119where
120    L: VectorExpr<N>,
121    R: VectorExpr<N>,
122{
123    vector_binary_fn::<N, L, R>("INNER_PRODUCT", left, right)
124}
125
126/// L1 (Manhattan) distance using pgvector's `<+>` operator.
127///
128/// Lower is more similar.
129///
130/// ANN note:
131/// - This form is operator-based and can use pgvector ivfflat/hnsw indexes for
132///   `ORDER BY ... LIMIT` nearest-neighbor queries.
133pub fn l1_distance<const N: usize, L, R>(left: impl IntoExpr<L>, right: impl IntoExpr<R>) -> Expr<f32>
134where
135    L: VectorExpr<N>,
136    R: VectorExpr<N>,
137{
138    vector_binary_operator::<N, L, R>(VectorBinaryOp::L1Distance, left, right)
139}
140
141/// Negative inner-product distance using pgvector's `<#>` operator.
142///
143/// Lower is more similar, so nearest-neighbor queries should use `ORDER BY ASC`.
144///
145/// ANN note:
146/// - This form is operator-based and can use pgvector ivfflat/hnsw indexes for
147///   `ORDER BY ... LIMIT` nearest-neighbor queries.
148/// - Thresholds are inverted relative to true inner product
149///   (for example `inner_product > 0.9` corresponds to
150///   `inner_product_distance < -0.9`).
151pub fn inner_product_distance<const N: usize, L, R>(left: impl IntoExpr<L>, right: impl IntoExpr<R>) -> Expr<f32>
152where
153    L: VectorExpr<N>,
154    R: VectorExpr<N>,
155{
156    vector_binary_operator::<N, L, R>(VectorBinaryOp::InnerProductDistance, left, right)
157}