1use super::{Field, Value, WhereOperator};
20use crate::Result;
21use std::collections::HashMap;
22
23fn infer_type_cast(value: &Value) -> &'static str {
27 match value {
28 Value::String(_) => "::text",
29 Value::Number(_) => "::numeric", Value::Bool(_) => "::boolean",
31 Value::Null => "", Value::Array(_) => "", Value::FloatArray(_) => "", Value::RawSql(_) => "", }
36}
37
38pub fn generate_where_operator_sql(
61 operator: &WhereOperator,
62 param_index: &mut usize,
63 params: &mut HashMap<usize, Value>,
64) -> Result<String> {
65 operator.validate().map_err(crate::Error::InvalidSchema)?;
66
67 match operator {
68 WhereOperator::Eq(field, value) => {
72 let field_sql = field.to_sql();
73 if value.is_null() {
74 Ok(format!("{} IS NULL", field_sql))
75 } else {
76 let param_num = *param_index + 1;
77 *param_index += 1;
78 params.insert(param_num, value.clone());
79 let cast = match field {
81 Field::JsonbField(_) | Field::JsonbPath(_) => infer_type_cast(value),
82 Field::DirectColumn(_) => "", };
84 Ok(format!("{}{} = ${}", field_sql, cast, param_num))
85 }
86 }
87
88 WhereOperator::Neq(field, value) => {
89 let field_sql = field.to_sql();
90 if value.is_null() {
91 Ok(format!("{} IS NOT NULL", field_sql))
92 } else {
93 let param_num = *param_index + 1;
94 *param_index += 1;
95 params.insert(param_num, value.clone());
96 let cast = match field {
97 Field::JsonbField(_) | Field::JsonbPath(_) => infer_type_cast(value),
98 Field::DirectColumn(_) => "",
99 };
100 Ok(format!("{}{} != ${}", field_sql, cast, param_num))
101 }
102 }
103
104 WhereOperator::Gt(field, value) => {
105 let field_sql = field.to_sql();
106 let param_num = *param_index + 1;
107 *param_index += 1;
108 params.insert(param_num, value.clone());
109 let cast = match field {
110 Field::JsonbField(_) | Field::JsonbPath(_) => infer_type_cast(value),
111 Field::DirectColumn(_) => "",
112 };
113 Ok(format!("{}{} > ${}", field_sql, cast, param_num))
114 }
115
116 WhereOperator::Gte(field, value) => {
117 let field_sql = field.to_sql();
118 let param_num = *param_index + 1;
119 *param_index += 1;
120 params.insert(param_num, value.clone());
121 let cast = match field {
122 Field::JsonbField(_) | Field::JsonbPath(_) => infer_type_cast(value),
123 Field::DirectColumn(_) => "",
124 };
125 Ok(format!("{}{} >= ${}", field_sql, cast, param_num))
126 }
127
128 WhereOperator::Lt(field, value) => {
129 let field_sql = field.to_sql();
130 let param_num = *param_index + 1;
131 *param_index += 1;
132 params.insert(param_num, value.clone());
133 let cast = match field {
134 Field::JsonbField(_) | Field::JsonbPath(_) => infer_type_cast(value),
135 Field::DirectColumn(_) => "",
136 };
137 Ok(format!("{}{} < ${}", field_sql, cast, param_num))
138 }
139
140 WhereOperator::Lte(field, value) => {
141 let field_sql = field.to_sql();
142 let param_num = *param_index + 1;
143 *param_index += 1;
144 params.insert(param_num, value.clone());
145 let cast = match field {
146 Field::JsonbField(_) | Field::JsonbPath(_) => infer_type_cast(value),
147 Field::DirectColumn(_) => "",
148 };
149 Ok(format!("{}{} <= ${}", field_sql, cast, param_num))
150 }
151
152 WhereOperator::In(field, values) => {
154 let field_sql = field.to_sql();
155 let placeholders: Vec<String> = values
156 .iter()
157 .map(|v| {
158 let param_num = *param_index + 1;
159 *param_index += 1;
160 params.insert(param_num, v.clone());
161 format!("${}", param_num)
162 })
163 .collect();
164 Ok(format!("{} IN ({})", field_sql, placeholders.join(", ")))
165 }
166
167 WhereOperator::Nin(field, values) => {
168 let field_sql = field.to_sql();
169 let placeholders: Vec<String> = values
170 .iter()
171 .map(|v| {
172 let param_num = *param_index + 1;
173 *param_index += 1;
174 params.insert(param_num, v.clone());
175 format!("${}", param_num)
176 })
177 .collect();
178 Ok(format!(
179 "{} NOT IN ({})",
180 field_sql,
181 placeholders.join(", ")
182 ))
183 }
184
185 WhereOperator::Contains(field, substring) => {
186 let field_sql = field.to_sql();
187 let param_num = *param_index + 1;
188 *param_index += 1;
189 params.insert(param_num, Value::String(substring.clone()));
190 Ok(format!(
191 "{} LIKE '%' || ${}::text || '%'",
192 field_sql, param_num
193 ))
194 }
195
196 WhereOperator::ArrayContains(field, value) => {
197 let field_sql = field.to_sql();
198 let param_num = *param_index + 1;
199 *param_index += 1;
200 params.insert(param_num, value.clone());
201 Ok(format!("{} @> ARRAY[${}]", field_sql, param_num))
202 }
203
204 WhereOperator::ArrayContainedBy(field, value) => {
205 let field_sql = field.to_sql();
206 let param_num = *param_index + 1;
207 *param_index += 1;
208 params.insert(param_num, value.clone());
209 Ok(format!("{} <@ ARRAY[${}]", field_sql, param_num))
210 }
211
212 WhereOperator::ArrayOverlaps(field, values) => {
213 let field_sql = field.to_sql();
214 let placeholders: Vec<String> = values
215 .iter()
216 .map(|v| {
217 let param_num = *param_index + 1;
218 *param_index += 1;
219 params.insert(param_num, v.clone());
220 format!("${}", param_num)
221 })
222 .collect();
223 Ok(format!(
224 "{} && ARRAY[{}]",
225 field_sql,
226 placeholders.join(", ")
227 ))
228 }
229
230 WhereOperator::LenEq(field, len) => {
232 let field_sql = field.to_sql();
233 Ok(format!("array_length({}, 1) = {}", field_sql, len))
234 }
235
236 WhereOperator::LenGt(field, len) => {
237 let field_sql = field.to_sql();
238 Ok(format!("array_length({}, 1) > {}", field_sql, len))
239 }
240
241 WhereOperator::LenGte(field, len) => {
242 let field_sql = field.to_sql();
243 Ok(format!("array_length({}, 1) >= {}", field_sql, len))
244 }
245
246 WhereOperator::LenLt(field, len) => {
247 let field_sql = field.to_sql();
248 Ok(format!("array_length({}, 1) < {}", field_sql, len))
249 }
250
251 WhereOperator::LenLte(field, len) => {
252 let field_sql = field.to_sql();
253 Ok(format!("array_length({}, 1) <= {}", field_sql, len))
254 }
255
256 WhereOperator::Icontains(field, substring) => {
258 let field_sql = field.to_sql();
259 let param_num = *param_index + 1;
260 *param_index += 1;
261 params.insert(param_num, Value::String(substring.clone()));
262 Ok(format!(
263 "{} ILIKE '%' || ${}::text || '%'",
264 field_sql, param_num
265 ))
266 }
267
268 WhereOperator::Startswith(field, prefix) => {
269 let field_sql = field.to_sql();
270 let param_num = *param_index + 1;
271 *param_index += 1;
272 params.insert(param_num, Value::String(format!("{}%", prefix)));
273 Ok(format!("{} LIKE ${}", field_sql, param_num))
274 }
275
276 WhereOperator::Endswith(field, suffix) => {
277 let field_sql = field.to_sql();
278 let param_num = *param_index + 1;
279 *param_index += 1;
280 params.insert(param_num, Value::String(format!("%{}", suffix)));
281 Ok(format!("{} LIKE ${}", field_sql, param_num))
282 }
283
284 WhereOperator::Like(field, pattern) => {
285 let field_sql = field.to_sql();
286 let param_num = *param_index + 1;
287 *param_index += 1;
288 params.insert(param_num, Value::String(pattern.clone()));
289 Ok(format!("{} LIKE ${}", field_sql, param_num))
290 }
291
292 WhereOperator::Ilike(field, pattern) => {
293 let field_sql = field.to_sql();
294 let param_num = *param_index + 1;
295 *param_index += 1;
296 params.insert(param_num, Value::String(pattern.clone()));
297 Ok(format!("{} ILIKE ${}", field_sql, param_num))
298 }
299
300 WhereOperator::IsNull(field, is_null) => {
302 let field_sql = field.to_sql();
303 if *is_null {
304 Ok(format!("{} IS NULL", field_sql))
305 } else {
306 Ok(format!("{} IS NOT NULL", field_sql))
307 }
308 }
309
310 WhereOperator::L2Distance {
312 field,
313 vector,
314 threshold,
315 } => {
316 let field_sql = field.to_sql();
317 let param_num = *param_index + 1;
318 *param_index += 1;
319 params.insert(param_num, Value::FloatArray(vector.clone()));
320 Ok(format!(
321 "l2_distance({}::vector, ${}::vector) < {}",
322 field_sql, param_num, threshold
323 ))
324 }
325
326 WhereOperator::CosineDistance {
327 field,
328 vector,
329 threshold,
330 } => {
331 let field_sql = field.to_sql();
332 let param_num = *param_index + 1;
333 *param_index += 1;
334 params.insert(param_num, Value::FloatArray(vector.clone()));
335 Ok(format!(
336 "cosine_distance({}::vector, ${}::vector) < {}",
337 field_sql, param_num, threshold
338 ))
339 }
340
341 WhereOperator::InnerProduct {
342 field,
343 vector,
344 threshold,
345 } => {
346 let field_sql = field.to_sql();
347 let param_num = *param_index + 1;
348 *param_index += 1;
349 params.insert(param_num, Value::FloatArray(vector.clone()));
350 Ok(format!(
351 "inner_product({}::vector, ${}::vector) > {}",
352 field_sql, param_num, threshold
353 ))
354 }
355
356 WhereOperator::JaccardDistance {
357 field,
358 set,
359 threshold,
360 } => {
361 let field_sql = field.to_sql();
362 let param_num = *param_index + 1;
363 *param_index += 1;
364 let value_array: Vec<Value> = set.iter().map(|s| Value::String(s.clone())).collect();
365 params.insert(param_num, Value::Array(value_array));
366 Ok(format!(
367 "jaccard_distance({}::text[], ${}::text[]) < {}",
368 field_sql, param_num, threshold
369 ))
370 }
371
372 WhereOperator::Matches {
374 field,
375 query,
376 language,
377 } => {
378 let field_sql = field.to_sql();
379 let param_num = *param_index + 1;
380 *param_index += 1;
381 params.insert(param_num, Value::String(query.clone()));
382 let lang = language.as_deref().unwrap_or("english");
383 Ok(format!(
384 "{} @@ plainto_tsquery('{}', ${})",
385 field_sql, lang, param_num
386 ))
387 }
388
389 WhereOperator::PlainQuery { field, query } => {
390 let field_sql = field.to_sql();
391 let param_num = *param_index + 1;
392 *param_index += 1;
393 params.insert(param_num, Value::String(query.clone()));
394 Ok(format!(
395 "{} @@ plainto_tsquery(${})::tsvector",
396 field_sql, param_num
397 ))
398 }
399
400 WhereOperator::PhraseQuery {
401 field,
402 query,
403 language,
404 } => {
405 let field_sql = field.to_sql();
406 let param_num = *param_index + 1;
407 *param_index += 1;
408 params.insert(param_num, Value::String(query.clone()));
409 let lang = language.as_deref().unwrap_or("english");
410 Ok(format!(
411 "{} @@ phraseto_tsquery('{}', ${})",
412 field_sql, lang, param_num
413 ))
414 }
415
416 WhereOperator::WebsearchQuery {
417 field,
418 query,
419 language,
420 } => {
421 let field_sql = field.to_sql();
422 let param_num = *param_index + 1;
423 *param_index += 1;
424 params.insert(param_num, Value::String(query.clone()));
425 let lang = language.as_deref().unwrap_or("english");
426 Ok(format!(
427 "{} @@ websearch_to_tsquery('{}', ${})",
428 field_sql, lang, param_num
429 ))
430 }
431
432 WhereOperator::IsIPv4(field) => {
434 let field_sql = field.to_sql();
435 Ok(format!("family({}::inet) = 4", field_sql))
436 }
437
438 WhereOperator::IsIPv6(field) => {
439 let field_sql = field.to_sql();
440 Ok(format!("family({}::inet) = 6", field_sql))
441 }
442
443 WhereOperator::IsPrivate(field) => {
444 let field_sql = field.to_sql();
445 Ok(format!(
447 "({}::inet << '10.0.0.0/8'::inet OR {}::inet << '172.16.0.0/12'::inet OR {}::inet << '192.168.0.0/16'::inet OR {}::inet << '169.254.0.0/16'::inet)",
448 field_sql, field_sql, field_sql, field_sql
449 ))
450 }
451
452 WhereOperator::IsLoopback(field) => {
453 let field_sql = field.to_sql();
454 Ok(format!(
455 "(family({}::inet) = 4 AND {}::inet << '127.0.0.0/8'::inet) OR (family({}::inet) = 6 AND {}::inet << '::1/128'::inet)",
456 field_sql, field_sql, field_sql, field_sql
457 ))
458 }
459
460 WhereOperator::InSubnet { field, subnet } => {
461 let field_sql = field.to_sql();
462 let param_num = *param_index + 1;
463 *param_index += 1;
464 params.insert(param_num, Value::String(subnet.clone()));
465 Ok(format!("{}::inet << ${}::inet", field_sql, param_num))
466 }
467
468 WhereOperator::ContainsSubnet { field, subnet } => {
469 let field_sql = field.to_sql();
470 let param_num = *param_index + 1;
471 *param_index += 1;
472 params.insert(param_num, Value::String(subnet.clone()));
473 Ok(format!("{}::inet >> ${}::inet", field_sql, param_num))
474 }
475
476 WhereOperator::ContainsIP { field, ip } => {
477 let field_sql = field.to_sql();
478 let param_num = *param_index + 1;
479 *param_index += 1;
480 params.insert(param_num, Value::String(ip.clone()));
481 Ok(format!("{}::inet >> ${}::inet", field_sql, param_num))
482 }
483
484 WhereOperator::IPRangeOverlap { field, range } => {
485 let field_sql = field.to_sql();
486 let param_num = *param_index + 1;
487 *param_index += 1;
488 params.insert(param_num, Value::String(range.clone()));
489 Ok(format!("{}::inet && ${}::inet", field_sql, param_num))
490 }
491 }
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `op` against a fresh placeholder counter and parameter map and
    /// returns the SQL text plus the final counter value.
    fn render(op: &WhereOperator) -> (String, usize) {
        let mut next_index = 0;
        let mut bound: HashMap<usize, Value> = HashMap::new();
        let sql = generate_where_operator_sql(op, &mut next_index, &mut bound).unwrap();
        (sql, next_index)
    }

    #[test]
    fn test_eq_operator_jsonb_string() {
        let (sql, used) = render(&WhereOperator::Eq(
            Field::JsonbField("name".to_string()),
            Value::String("John".to_string()),
        ));
        assert_eq!(sql, "(data->'name')::text = $1");
        assert_eq!(used, 1);
    }

    #[test]
    fn test_eq_operator_direct_column() {
        let (sql, used) = render(&WhereOperator::Eq(
            Field::DirectColumn("status".to_string()),
            Value::String("active".to_string()),
        ));
        assert_eq!(sql, "status = $1");
        assert_eq!(used, 1);
    }

    #[test]
    fn test_len_eq_operator() {
        let (sql, used) = render(&WhereOperator::LenEq(Field::JsonbField("tags".to_string()), 5));
        assert_eq!(sql, "array_length((data->'tags'), 1) = 5");
        // Length checks inline the number and bind nothing.
        assert_eq!(used, 0);
    }

    #[test]
    fn test_is_ipv4_operator() {
        let (sql, _) = render(&WhereOperator::IsIPv4(Field::JsonbField("ip".to_string())));
        assert_eq!(sql, "family((data->'ip')::inet) = 4");
    }

    #[test]
    fn test_l2_distance_operator() {
        let (sql, used) = render(&WhereOperator::L2Distance {
            field: Field::JsonbField("embedding".to_string()),
            vector: vec![0.1, 0.2, 0.3],
            threshold: 0.5,
        });
        assert_eq!(
            sql,
            "l2_distance((data->'embedding')::vector, $1::vector) < 0.5"
        );
        assert_eq!(used, 1);
    }

    #[test]
    fn test_in_operator() {
        let candidates = vec![
            Value::String("active".to_string()),
            Value::String("pending".to_string()),
        ];
        let (sql, used) = render(&WhereOperator::In(
            Field::JsonbField("status".to_string()),
            candidates,
        ));
        assert_eq!(sql, "(data->'status') IN ($1, $2)");
        assert_eq!(used, 2);
    }
}
577}