Skip to main content

diesel_clickhouse/
vectors.rs

1//! Helpers for ClickHouse vector-search expressions.
2
3use std::marker::PhantomData;
4
5use diesel::backend::Backend;
6use diesel::expression::{AppearsOnTable, Expression, SelectableExpression, ValidGrouping};
7use diesel::query_builder::{AstPass, QueryFragment, QueryId};
8use diesel::result::{Error, QueryResult};
9use diesel::sql_types::{Double, Float, SqlType};
10
11use crate::types::Array;
12
13/// Build a ClickHouse array literal typed as `Array(Float32)`.
14pub fn vector_f32<I>(values: I) -> VectorLiteral<Float>
15where
16    I: IntoIterator<Item = f32>,
17{
18    VectorLiteral::new(values.into_iter().map(f64::from).collect())
19}
20
21/// Build a ClickHouse array literal typed as `Array(Float64)`.
22pub fn vector_f64<I>(values: I) -> VectorLiteral<Double>
23where
24    I: IntoIterator<Item = f64>,
25{
26    VectorLiteral::new(values.into_iter().collect())
27}
28
29/// Reinterpret a binary string expression as `Array(Float32)`.
30///
31/// ClickHouse expects bytes in little-endian element order. For clients that
32/// cannot bind arbitrary bytes as a string literal, use [`vector_f32_hex`]
33/// with [`vector_f32_le_hex`] instead.
34pub fn vector_f32_binary<Expr>(bytes: Expr) -> VectorBytes<Expr, Float>
35where
36    Expr: Expression,
37{
38    VectorBytes::new(bytes, "Float32", VectorBytesEncoding::Raw)
39}
40
41/// Reinterpret a binary string expression as `Array(Float64)`.
42pub fn vector_f64_binary<Expr>(bytes: Expr) -> VectorBytes<Expr, Double>
43where
44    Expr: Expression,
45{
46    VectorBytes::new(bytes, "Float64", VectorBytesEncoding::Raw)
47}
48
49/// Decode a hex string expression with `unhex` and reinterpret it as `Array(Float32)`.
50pub fn vector_f32_hex<Expr>(hex: Expr) -> VectorBytes<Expr, Float>
51where
52    Expr: Expression,
53{
54    VectorBytes::new(hex, "Float32", VectorBytesEncoding::Hex)
55}
56
57/// Decode a hex string expression with `unhex` and reinterpret it as `Array(Float64)`.
58pub fn vector_f64_hex<Expr>(hex: Expr) -> VectorBytes<Expr, Double>
59where
60    Expr: Expression,
61{
62    VectorBytes::new(hex, "Float64", VectorBytesEncoding::Hex)
63}
64
/// Serialises `f32` components into one contiguous little-endian byte buffer.
///
/// Every component contributes exactly four bytes, in input order. An empty
/// input yields an empty buffer.
pub fn vector_f32_le_bytes<I>(values: I) -> Vec<u8>
where
    I: IntoIterator<Item = f32>,
{
    let mut encoded = Vec::new();
    for component in values {
        encoded.extend_from_slice(&component.to_le_bytes());
    }
    encoded
}
72
/// Serialises `f64` components into one contiguous little-endian byte buffer.
///
/// Every component contributes exactly eight bytes, in input order. An empty
/// input yields an empty buffer.
pub fn vector_f64_le_bytes<I>(values: I) -> Vec<u8>
where
    I: IntoIterator<Item = f64>,
{
    values.into_iter().fold(Vec::new(), |mut encoded, component| {
        encoded.extend_from_slice(&component.to_le_bytes());
        encoded
    })
}
80
81/// Convert `f32` vector values into a lower-case hex string of little-endian bytes.
82pub fn vector_f32_le_hex<I>(values: I) -> String
83where
84    I: IntoIterator<Item = f32>,
85{
86    bytes_to_hex(vector_f32_le_bytes(values))
87}
88
89/// Convert `f64` vector values into a lower-case hex string of little-endian bytes.
90pub fn vector_f64_le_hex<I>(values: I) -> String
91where
92    I: IntoIterator<Item = f64>,
93{
94    bytes_to_hex(vector_f64_le_bytes(values))
95}
96
/// Inline ClickHouse array literal, rendered into the query text as
/// `[x, y, ...]`.
///
/// `ST` is the Diesel SQL type of a single element; it exists only at the
/// type level and is never stored.
#[derive(Clone, Debug)]
pub struct VectorLiteral<ST> {
    /// Components of the vector, always held as `f64` regardless of `ST`.
    values: Vec<f64>,
    /// Type-level marker for the element SQL type.
    _sql_type: PhantomData<ST>,
}
103
104/// ClickHouse binary-vector reinterpret expression.
105#[derive(Debug, Clone)]
106pub struct VectorBytes<Expr, ST> {
107    expr: Expr,
108    element_type: &'static str,
109    encoding: VectorBytesEncoding,
110    _sql_type: PhantomData<ST>,
111}
112
/// Encoding of the expression handed to a [`VectorBytes`] wrapper, i.e. what
/// has to happen to it before it can be reinterpreted as an array.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum VectorBytesEncoding {
    /// The expression is already a binary `String`/`FixedString` value and is
    /// used unchanged.
    Raw,
    /// The expression is a hex string; it is wrapped in `unhex(...)` before
    /// reinterpretation.
    Hex,
}
121
122impl<ST> VectorLiteral<ST> {
123    fn new(values: Vec<f64>) -> Self {
124        Self {
125            values,
126            _sql_type: PhantomData,
127        }
128    }
129}
130
131impl<Expr, ST> VectorBytes<Expr, ST> {
132    fn new(expr: Expr, element_type: &'static str, encoding: VectorBytesEncoding) -> Self {
133        Self {
134            expr,
135            element_type,
136            encoding,
137            _sql_type: PhantomData,
138        }
139    }
140}
141
142impl<ST> Expression for VectorLiteral<ST>
143where
144    ST: SqlType,
145{
146    type SqlType = Array<ST>;
147}
148
149impl<Expr, ST> Expression for VectorBytes<Expr, ST>
150where
151    Expr: Expression,
152    ST: SqlType,
153{
154    type SqlType = Array<ST>;
155}
156
157impl<ST, GB> ValidGrouping<GB> for VectorLiteral<ST> {
158    type IsAggregate = diesel::expression::is_aggregate::No;
159}
160
161impl<Expr, ST, GB> ValidGrouping<GB> for VectorBytes<Expr, ST>
162where
163    Expr: ValidGrouping<GB>,
164{
165    type IsAggregate = Expr::IsAggregate;
166}
167
168impl<ST, QS> AppearsOnTable<QS> for VectorLiteral<ST> where Self: Expression {}
169impl<Expr, ST, QS> AppearsOnTable<QS> for VectorBytes<Expr, ST>
170where
171    Expr: AppearsOnTable<QS>,
172    Self: Expression,
173{
174}
175impl<ST, QS> SelectableExpression<QS> for VectorLiteral<ST> where Self: AppearsOnTable<QS> {}
176impl<Expr, ST, QS> SelectableExpression<QS> for VectorBytes<Expr, ST> where Self: AppearsOnTable<QS> {}
177
178impl<ST> QueryId for VectorLiteral<ST> {
179    type QueryId = ();
180    const HAS_STATIC_QUERY_ID: bool = false;
181}
182
183impl<Expr, ST> QueryId for VectorBytes<Expr, ST> {
184    type QueryId = ();
185    const HAS_STATIC_QUERY_ID: bool = false;
186}
187
188impl<ST, DB> QueryFragment<DB> for VectorLiteral<ST>
189where
190    DB: Backend,
191{
192    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
193        if self.values.is_empty() {
194            return Err(Error::QueryBuilderError(
195                "ClickHouse vector literal requires at least one value".into(),
196            ));
197        }
198
199        out.push_sql("[");
200        for (idx, value) in self.values.iter().enumerate() {
201            if !value.is_finite() {
202                return Err(Error::QueryBuilderError(
203                    format!("ClickHouse vector literal value must be finite, got {value}").into(),
204                ));
205            }
206            if idx > 0 {
207                out.push_sql(", ");
208            }
209            out.push_sql(&value.to_string());
210        }
211        out.push_sql("]");
212        Ok(())
213    }
214}
215
216impl<Expr, ST, DB> QueryFragment<DB> for VectorBytes<Expr, ST>
217where
218    DB: Backend,
219    Expr: QueryFragment<DB>,
220{
221    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
222        out.push_sql("reinterpret(");
223        if self.encoding == VectorBytesEncoding::Hex {
224            out.push_sql("unhex(");
225        }
226        self.expr.walk_ast(out.reborrow())?;
227        if self.encoding == VectorBytesEncoding::Hex {
228            out.push_sql(")");
229        }
230        out.push_sql(", '");
231        out.push_sql("Array(");
232        out.push_sql(self.element_type);
233        out.push_sql(")");
234        out.push_sql("')");
235        Ok(())
236    }
237}
238
/// Renders `bytes` as lower-case hexadecimal, two digits per byte with the
/// high nibble first. An empty input yields an empty string.
fn bytes_to_hex(bytes: Vec<u8>) -> String {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];

    bytes
        .iter()
        .flat_map(|&byte| {
            let high = DIGITS[usize::from(byte >> 4)];
            let low = DIGITS[usize::from(byte & 0x0f)];
            [high, low]
        })
        .collect()
}