1use std::marker::PhantomData;
4
5use diesel::backend::Backend;
6use diesel::expression::{AppearsOnTable, Expression, SelectableExpression, ValidGrouping};
7use diesel::query_builder::{AstPass, QueryFragment, QueryId};
8use diesel::result::{Error, QueryResult};
9use diesel::sql_types::{Double, Float, SqlType};
10
11use crate::types::Array;
12
13pub fn vector_f32<I>(values: I) -> VectorLiteral<Float>
15where
16 I: IntoIterator<Item = f32>,
17{
18 VectorLiteral::new(values.into_iter().map(f64::from).collect())
19}
20
21pub fn vector_f64<I>(values: I) -> VectorLiteral<Double>
23where
24 I: IntoIterator<Item = f64>,
25{
26 VectorLiteral::new(values.into_iter().collect())
27}
28
29pub fn vector_f32_binary<Expr>(bytes: Expr) -> VectorBytes<Expr, Float>
35where
36 Expr: Expression,
37{
38 VectorBytes::new(bytes, "Float32", VectorBytesEncoding::Raw)
39}
40
41pub fn vector_f64_binary<Expr>(bytes: Expr) -> VectorBytes<Expr, Double>
43where
44 Expr: Expression,
45{
46 VectorBytes::new(bytes, "Float64", VectorBytesEncoding::Raw)
47}
48
49pub fn vector_f32_hex<Expr>(hex: Expr) -> VectorBytes<Expr, Float>
51where
52 Expr: Expression,
53{
54 VectorBytes::new(hex, "Float32", VectorBytesEncoding::Hex)
55}
56
57pub fn vector_f64_hex<Expr>(hex: Expr) -> VectorBytes<Expr, Double>
59where
60 Expr: Expression,
61{
62 VectorBytes::new(hex, "Float64", VectorBytesEncoding::Hex)
63}
64
/// Packs `f32` values into a little-endian byte buffer, 4 bytes per value, in input order.
pub fn vector_f32_le_bytes<I>(values: I) -> Vec<u8>
where
    I: IntoIterator<Item = f32>,
{
    let mut bytes = Vec::new();
    for value in values {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
72
/// Packs `f64` values into a little-endian byte buffer, 8 bytes per value, in input order.
pub fn vector_f64_le_bytes<I>(values: I) -> Vec<u8>
where
    I: IntoIterator<Item = f64>,
{
    let mut bytes = Vec::new();
    for value in values {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
80
81pub fn vector_f32_le_hex<I>(values: I) -> String
83where
84 I: IntoIterator<Item = f32>,
85{
86 bytes_to_hex(vector_f32_le_bytes(values))
87}
88
89pub fn vector_f64_le_hex<I>(values: I) -> String
91where
92 I: IntoIterator<Item = f64>,
93{
94 bytes_to_hex(vector_f64_le_bytes(values))
95}
96
/// An inline ClickHouse array literal of floating-point values.
///
/// `ST` is the element SQL type reported through `Expression::SqlType`
/// (`Array<ST>`); the values themselves are always held as `f64`.
#[derive(Clone, Debug)]
pub struct VectorLiteral<ST> {
    /// The vector components, rendered in this order.
    values: Vec<f64>,
    /// Zero-sized marker carrying the element SQL type.
    _sql_type: PhantomData<ST>,
}
103
104#[derive(Debug, Clone)]
106pub struct VectorBytes<Expr, ST> {
107 expr: Expr,
108 element_type: &'static str,
109 encoding: VectorBytesEncoding,
110 _sql_type: PhantomData<ST>,
111}
112
/// How the payload of a [`VectorBytes`] expression is encoded.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum VectorBytesEncoding {
    /// The expression already yields raw bytes and is passed to `reinterpret` as-is.
    Raw,
    /// The expression yields hex text, decoded with `unhex` before `reinterpret`.
    Hex,
}
121
122impl<ST> VectorLiteral<ST> {
123 fn new(values: Vec<f64>) -> Self {
124 Self {
125 values,
126 _sql_type: PhantomData,
127 }
128 }
129}
130
131impl<Expr, ST> VectorBytes<Expr, ST> {
132 fn new(expr: Expr, element_type: &'static str, encoding: VectorBytesEncoding) -> Self {
133 Self {
134 expr,
135 element_type,
136 encoding,
137 _sql_type: PhantomData,
138 }
139 }
140}
141
142impl<ST> Expression for VectorLiteral<ST>
143where
144 ST: SqlType,
145{
146 type SqlType = Array<ST>;
147}
148
149impl<Expr, ST> Expression for VectorBytes<Expr, ST>
150where
151 Expr: Expression,
152 ST: SqlType,
153{
154 type SqlType = Array<ST>;
155}
156
157impl<ST, GB> ValidGrouping<GB> for VectorLiteral<ST> {
158 type IsAggregate = diesel::expression::is_aggregate::No;
159}
160
161impl<Expr, ST, GB> ValidGrouping<GB> for VectorBytes<Expr, ST>
162where
163 Expr: ValidGrouping<GB>,
164{
165 type IsAggregate = Expr::IsAggregate;
166}
167
168impl<ST, QS> AppearsOnTable<QS> for VectorLiteral<ST> where Self: Expression {}
169impl<Expr, ST, QS> AppearsOnTable<QS> for VectorBytes<Expr, ST>
170where
171 Expr: AppearsOnTable<QS>,
172 Self: Expression,
173{
174}
175impl<ST, QS> SelectableExpression<QS> for VectorLiteral<ST> where Self: AppearsOnTable<QS> {}
176impl<Expr, ST, QS> SelectableExpression<QS> for VectorBytes<Expr, ST> where Self: AppearsOnTable<QS> {}
177
178impl<ST> QueryId for VectorLiteral<ST> {
179 type QueryId = ();
180 const HAS_STATIC_QUERY_ID: bool = false;
181}
182
183impl<Expr, ST> QueryId for VectorBytes<Expr, ST> {
184 type QueryId = ();
185 const HAS_STATIC_QUERY_ID: bool = false;
186}
187
188impl<ST, DB> QueryFragment<DB> for VectorLiteral<ST>
189where
190 DB: Backend,
191{
192 fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
193 if self.values.is_empty() {
194 return Err(Error::QueryBuilderError(
195 "ClickHouse vector literal requires at least one value".into(),
196 ));
197 }
198
199 out.push_sql("[");
200 for (idx, value) in self.values.iter().enumerate() {
201 if !value.is_finite() {
202 return Err(Error::QueryBuilderError(
203 format!("ClickHouse vector literal value must be finite, got {value}").into(),
204 ));
205 }
206 if idx > 0 {
207 out.push_sql(", ");
208 }
209 out.push_sql(&value.to_string());
210 }
211 out.push_sql("]");
212 Ok(())
213 }
214}
215
216impl<Expr, ST, DB> QueryFragment<DB> for VectorBytes<Expr, ST>
217where
218 DB: Backend,
219 Expr: QueryFragment<DB>,
220{
221 fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
222 out.push_sql("reinterpret(");
223 if self.encoding == VectorBytesEncoding::Hex {
224 out.push_sql("unhex(");
225 }
226 self.expr.walk_ast(out.reborrow())?;
227 if self.encoding == VectorBytesEncoding::Hex {
228 out.push_sql(")");
229 }
230 out.push_sql(", '");
231 out.push_sql("Array(");
232 out.push_sql(self.element_type);
233 out.push_sql(")");
234 out.push_sql("')");
235 Ok(())
236 }
237}
238
/// Lower-case hex encoding of `bytes`: two characters per byte, high nibble first.
fn bytes_to_hex(bytes: Vec<u8>) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    bytes
        .iter()
        .flat_map(|&byte| [byte >> 4, byte & 0x0f])
        .map(|nibble| char::from(DIGITS[usize::from(nibble)]))
        .collect()
}