1use crate::dialect::SqlDialect;
2use crate::executor::Executor;
3use crate::model::Model;
4use futures_util::StreamExt;
5use smallvec::{SmallVec, smallvec};
6use sqlx::{Database, IntoArguments};
7use std::time::{Duration, Instant};
8
9#[doc(hidden)]
10#[derive(Debug, Clone)]
11pub enum BindValue {
12 String(String),
13 I64(i64),
14 F64(f64),
15 Bool(bool),
16 Uuid(uuid::Uuid),
17 DateTime(chrono::DateTime<chrono::Utc>),
18 Null,
19}
20
21impl BindValue {
22 fn to_log_string(&self) -> String {
23 match self {
24 BindValue::String(v) => v.clone(),
25 BindValue::I64(v) => v.to_string(),
26 BindValue::F64(v) => v.to_string(),
27 BindValue::Bool(v) => v.to_string(),
28 BindValue::Uuid(v) => v.to_string(),
29 BindValue::DateTime(v) => v.to_rfc3339(),
30 BindValue::Null => "NULL".to_string(),
31 }
32 }
33}
34
/// Records one query's latency (milliseconds) and bumps the query counter,
/// both labelled with `operation` and `table`.
#[cfg(feature = "metrics")]
fn record_query_metrics(operation: &str, table: &str, elapsed: Duration) {
    let labels = [
        ("operation", operation.to_string()),
        ("table", table.to_string()),
    ];
    // The histogram is named `duration_ms`, hence the seconds -> ms scaling.
    let millis = elapsed.as_secs_f64() * 1000.0;
    metrics::histogram!("premix.query.duration_ms", &labels).record(millis);
    metrics::counter!("premix.query.count", &labels).increment(1);
}

/// No-op stand-in used when the `metrics` feature is disabled.
#[cfg(not(feature = "metrics"))]
fn record_query_metrics(_operation: &str, _table: &str, _elapsed: Duration) {}
48
49#[inline(always)]
50fn bind_value_query<'q, DB>(
51 query: sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>,
52 value: BindValue,
53) -> sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>
54where
55 DB: Database,
56 String: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
57 i64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
58 f64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
59 bool: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
60 uuid::Uuid: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
61 chrono::DateTime<chrono::Utc>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
62 Option<String>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
63{
64 match value {
65 BindValue::String(v) => query.bind(v),
66 BindValue::I64(v) => query.bind(v),
67 BindValue::F64(v) => query.bind(v),
68 BindValue::Bool(v) => query.bind(v),
69 BindValue::Uuid(v) => query.bind(v),
70 BindValue::DateTime(v) => query.bind(v),
71 BindValue::Null => query.bind(Option::<String>::None),
72 }
73}
74
75#[inline(always)]
76fn bind_value_query_as<'q, DB, T>(
77 query: sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>,
78 value: BindValue,
79) -> sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>
80where
81 DB: Database,
82 String: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
83 i64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
84 f64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
85 bool: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
86 uuid::Uuid: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
87 chrono::DateTime<chrono::Utc>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
88 Option<String>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
89{
90 match value {
91 BindValue::String(v) => query.bind(v),
92 BindValue::I64(v) => query.bind(v),
93 BindValue::F64(v) => query.bind(v),
94 BindValue::Bool(v) => query.bind(v),
95 BindValue::Uuid(v) => query.bind(v),
96 BindValue::DateTime(v) => query.bind(v),
97 BindValue::Null => query.bind(Option::<String>::None),
98 }
99}
100
101#[inline]
102fn apply_persistent_query_as<'q, DB, T>(
103 query: sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>,
104 prepared: bool,
105) -> sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>
106where
107 DB: Database + sqlx::database::HasStatementCache,
108{
109 if prepared {
110 query.persistent(true)
111 } else {
112 query
113 }
114}
115
116#[inline]
117fn apply_persistent_query<'q, DB>(
118 query: sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>,
119 prepared: bool,
120) -> sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>
121where
122 DB: Database + sqlx::database::HasStatementCache,
123{
124 if prepared {
125 query.persistent(true)
126 } else {
127 query
128 }
129}
130
131impl From<String> for BindValue {
132 fn from(value: String) -> Self {
133 Self::String(value)
134 }
135}
136
137impl From<&str> for BindValue {
138 fn from(value: &str) -> Self {
139 Self::String(value.to_string())
140 }
141}
142
143impl From<i32> for BindValue {
144 fn from(value: i32) -> Self {
145 Self::I64(value as i64)
146 }
147}
148
149impl From<i64> for BindValue {
150 fn from(value: i64) -> Self {
151 Self::I64(value)
152 }
153}
154
155impl From<f64> for BindValue {
156 fn from(value: f64) -> Self {
157 Self::F64(value)
158 }
159}
160
161impl From<bool> for BindValue {
162 fn from(value: bool) -> Self {
163 Self::Bool(value)
164 }
165}
166
167impl From<Option<String>> for BindValue {
168 fn from(value: Option<String>) -> Self {
169 match value {
170 Some(v) => Self::String(v),
171 None => Self::Null,
172 }
173 }
174}
175
176impl From<uuid::Uuid> for BindValue {
177 fn from(value: uuid::Uuid) -> Self {
178 Self::Uuid(value)
179 }
180}
181
182impl From<chrono::DateTime<chrono::Utc>> for BindValue {
183 fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
184 Self::DateTime(value)
185 }
186}
187
188#[derive(Debug, Clone)]
189pub(crate) enum FilterExpr {
190 Raw(String),
191 Compare {
192 column: ColumnRef,
193 op: FilterOp,
194 values: SmallVec<[BindValue; 2]>,
195 },
196 NullCheck {
197 column: ColumnRef,
198 is_null: bool,
199 },
200}
201
/// Comparison operators supported by typed filters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FilterOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Lte,
    /// `>`
    Gt,
    /// `>=`
    Gte,
    /// `LIKE`
    Like,
    /// `IN`
    In,
}

impl FilterOp {
    /// Returns the SQL token for this operator.
    fn as_str(self) -> &'static str {
        match self {
            Self::In => "IN",
            Self::Like => "LIKE",
            Self::Gte => ">=",
            Self::Gt => ">",
            Self::Lte => "<=",
            Self::Lt => "<",
            Self::Ne => "!=",
            Self::Eq => "=",
        }
    }

    /// Whether this operator is `IN`, which takes a list of values.
    fn is_in(self) -> bool {
        self == Self::In
    }
}
232
/// Column name used by typed filters, either borrowed for `'static` or owned.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColumnRef {
    /// Name backed by a `'static` string; needs no allocation.
    Static(&'static str),
    /// Name held in an owned `String`.
    Owned(String),
}

impl ColumnRef {
    /// Wraps a `'static` name without allocating; usable in `const` contexts.
    pub const fn static_str(value: &'static str) -> Self {
        Self::Static(value)
    }

    /// Borrows the column name regardless of how it is stored.
    fn as_str(&self) -> &str {
        match self {
            Self::Owned(name) => name.as_str(),
            Self::Static(name) => name,
        }
    }
}

/// Copies the borrowed name into [`ColumnRef::Owned`].
impl From<&str> for ColumnRef {
    fn from(value: &str) -> Self {
        Self::Owned(value.to_owned())
    }
}

/// Takes ownership of the name as [`ColumnRef::Owned`].
impl From<String> for ColumnRef {
    fn from(value: String) -> Self {
        Self::Owned(value)
    }
}

/// Clones the referenced name into [`ColumnRef::Owned`].
impl From<&String> for ColumnRef {
    fn from(value: &String) -> Self {
        Self::Owned(value.to_owned())
    }
}
273
274pub struct QueryBuilder<'a, T, DB: Database> {
279 executor: Executor<'a, DB>,
280 filters: Vec<FilterExpr>,
281 limit: Option<i32>,
282 offset: Option<i32>,
283 includes: SmallVec<[String; 2]>,
284 include_deleted: bool,
285 allow_unsafe: bool,
286 has_raw_filter: bool,
287 fast_path: bool,
288 unsafe_fast: bool,
289 ultra_fast: bool,
290 prepared: bool,
291 _marker: std::marker::PhantomData<T>,
292}
293
294impl<'a, T, DB: Database> std::fmt::Debug for QueryBuilder<'a, T, DB> {
295 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
296 f.debug_struct("QueryBuilder")
297 .field("filters", &self.filters)
298 .field("limit", &self.limit)
299 .field("offset", &self.offset)
300 .field("includes", &self.includes)
301 .field("include_deleted", &self.include_deleted)
302 .field("allow_unsafe", &self.allow_unsafe)
303 .field("fast_path", &self.fast_path)
304 .field("unsafe_fast", &self.unsafe_fast)
305 .field("ultra_fast", &self.ultra_fast)
306 .field("prepared", &self.prepared)
307 .finish()
308 }
309}
310
311impl<'a, T, DB> QueryBuilder<'a, T, DB>
312where
313 DB: SqlDialect + sqlx::database::HasStatementCache,
314 T: Model<DB>,
315{
316 pub fn new(executor: Executor<'a, DB>) -> Self {
318 Self {
319 executor,
320 filters: Vec::with_capacity(4), limit: None,
322 offset: None,
323 includes: SmallVec::with_capacity(2), include_deleted: false,
325 allow_unsafe: false,
326 has_raw_filter: false,
327 fast_path: false,
328 unsafe_fast: false,
329 ultra_fast: false,
330 prepared: true,
331 _marker: std::marker::PhantomData,
332 }
333 }
334
335 pub fn filter(mut self, condition: impl Into<String>) -> Self {
340 self.filters.push(FilterExpr::Raw(condition.into()));
341 self.has_raw_filter = true;
342 self
343 }
344
345 pub fn filter_raw(self, condition: impl Into<String>) -> Self {
347 self.filter(condition)
348 }
349
350 pub fn filter_eq(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
352 self.filters.push(FilterExpr::Compare {
353 column: column.into(),
354 op: FilterOp::Eq,
355 values: smallvec![value.into()],
356 });
357 self
358 }
359
360 pub fn filter_ne(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
362 self.filters.push(FilterExpr::Compare {
363 column: column.into(),
364 op: FilterOp::Ne,
365 values: smallvec![value.into()],
366 });
367 self
368 }
369
370 pub fn filter_lt(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
372 self.filters.push(FilterExpr::Compare {
373 column: column.into(),
374 op: FilterOp::Lt,
375 values: smallvec![value.into()],
376 });
377 self
378 }
379
380 pub fn filter_lte(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
382 self.filters.push(FilterExpr::Compare {
383 column: column.into(),
384 op: FilterOp::Lte,
385 values: smallvec![value.into()],
386 });
387 self
388 }
389
390 pub fn filter_gt(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
392 self.filters.push(FilterExpr::Compare {
393 column: column.into(),
394 op: FilterOp::Gt,
395 values: smallvec![value.into()],
396 });
397 self
398 }
399
400 pub fn filter_gte(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
402 self.filters.push(FilterExpr::Compare {
403 column: column.into(),
404 op: FilterOp::Gte,
405 values: smallvec![value.into()],
406 });
407 self
408 }
409
410 pub fn filter_like(
412 mut self,
413 column: impl Into<ColumnRef>,
414 value: impl Into<BindValue>,
415 ) -> Self {
416 self.filters.push(FilterExpr::Compare {
417 column: column.into(),
418 op: FilterOp::Like,
419 values: smallvec![value.into()],
420 });
421 self
422 }
423
424 pub fn filter_is_null(mut self, column: impl Into<ColumnRef>) -> Self {
426 self.filters.push(FilterExpr::NullCheck {
427 column: column.into(),
428 is_null: true,
429 });
430 self
431 }
432
433 pub fn filter_is_not_null(mut self, column: impl Into<ColumnRef>) -> Self {
435 self.filters.push(FilterExpr::NullCheck {
436 column: column.into(),
437 is_null: false,
438 });
439 self
440 }
441
442 pub fn filter_in<I, V>(mut self, column: impl Into<ColumnRef>, values: I) -> Self
444 where
445 I: IntoIterator<Item = V>,
446 V: Into<BindValue>,
447 {
448 let values: SmallVec<[BindValue; 2]> = values.into_iter().map(Into::into).collect();
449 self.filters.push(FilterExpr::Compare {
450 column: column.into(),
451 op: FilterOp::In,
452 values,
453 });
454 self
455 }
456
457 fn format_filters_for_log(&self) -> String {
458 let sensitive_fields = T::sensitive_fields();
459 let mut rendered = String::with_capacity(128);
460 let mut first_clause = true;
461 let mut append_and = |buf: &mut String| {
462 if first_clause {
463 first_clause = false;
464 } else {
465 buf.push_str(" AND ");
466 }
467 };
468 use std::fmt::Write;
469
470 for filter in &self.filters {
471 match filter {
472 FilterExpr::Raw(_) => {
473 append_and(&mut rendered);
474 rendered.push_str("RAW(<redacted>)");
475 }
476 FilterExpr::Compare { column, op, values } => {
477 let column_name = column.as_str();
478 let is_sensitive = sensitive_fields.contains(&column_name);
479 if op.is_in() {
480 if values.is_empty() {
481 append_and(&mut rendered);
482 rendered.push_str("1=0");
483 continue;
484 }
485 append_and(&mut rendered);
486 let _ = write!(rendered, "{} IN (", column_name);
487 for (idx, value) in values.iter().enumerate() {
488 if idx > 0 {
489 rendered.push_str(", ");
490 }
491 if is_sensitive {
492 rendered.push_str("***");
493 } else {
494 rendered.push_str(&value.to_log_string());
495 }
496 }
497 rendered.push(')');
498 } else {
499 append_and(&mut rendered);
500 let _ = write!(rendered, "{} {} ", column_name, op.as_str());
501 if is_sensitive {
502 rendered.push_str("***");
503 } else if let Some(value) = values.first() {
504 rendered.push_str(&value.to_log_string());
505 } else {
506 rendered.push_str("NULL");
507 }
508 }
509 }
510 FilterExpr::NullCheck { column, is_null } => {
511 append_and(&mut rendered);
512 if *is_null {
513 let _ = write!(rendered, "{} IS NULL", column.as_str());
514 } else {
515 let _ = write!(rendered, "{} IS NOT NULL", column.as_str());
516 }
517 }
518 }
519 }
520
521 if T::has_soft_delete() && !self.include_deleted {
522 append_and(&mut rendered);
523 rendered.push_str("deleted_at IS NULL");
524 }
525
526 rendered
527 }
528
529 fn estimate_bind_count(&self) -> usize {
530 let mut count = 0usize;
531 for filter in &self.filters {
532 if let FilterExpr::Compare { op, values, .. } = filter {
533 if op.is_in() {
534 count = count.saturating_add(values.len());
535 } else {
536 count = count.saturating_add(1);
537 }
538 }
539 }
540 count
541 }
542
543 pub fn limit(mut self, limit: i32) -> Self {
546 self.limit = Some(limit);
547 self
548 }
549
550 pub fn offset(mut self, offset: i32) -> Self {
553 self.offset = Some(offset);
554 self
555 }
556
557 pub fn include(mut self, relation: impl Into<String>) -> Self {
559 self.includes.push(relation.into());
560 self
561 }
562
563 pub fn with_deleted(mut self) -> Self {
565 self.include_deleted = true;
566 self
567 }
568
569 pub fn allow_unsafe(mut self) -> Self {
572 self.allow_unsafe = true;
573 self
574 }
575
576 pub fn fast(mut self) -> Self {
578 self.fast_path = true;
579 self
580 }
581
582 pub fn unsafe_fast(mut self) -> Self {
584 self.fast_path = true;
585 self.unsafe_fast = true;
586 self.allow_unsafe = true;
587 self
588 }
589
590 pub fn ultra_fast(mut self) -> Self {
593 self.fast_path = true;
594 self.unsafe_fast = true;
595 self.allow_unsafe = true;
596 self.ultra_fast = true;
597 self
598 }
599
600 pub fn prepared(mut self) -> Self {
602 self.prepared = true;
603 self
604 }
605
606 pub fn unprepared(mut self) -> Self {
608 self.prepared = false;
609 self
610 }
611
612 pub fn to_sql(&self) -> String {
614 let mut sql = String::with_capacity(128); use std::fmt::Write;
616
617 sql.push_str("SELECT * FROM ");
618 sql.push_str(T::table_name());
619
620 let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
621 self.render_where_clause_into(&mut sql, &mut dummy_binds, 1);
622
623 if let Some(limit) = self.limit {
624 let _ = write!(sql, " LIMIT {}", limit);
625 }
626
627 if let Some(offset) = self.offset {
628 let _ = write!(sql, " OFFSET {}", offset);
629 }
630
631 sql
632 }
633
634 pub fn to_update_sql(&self, values: &serde_json::Value) -> Result<String, sqlx::Error> {
636 let obj = values.as_object().ok_or_else(|| {
637 sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
638 })?;
639
640 let mut sql = String::with_capacity(256);
641 use std::fmt::Write;
642
643 let _ = write!(sql, "UPDATE {} SET ", T::table_name());
644
645 let mut i = 1;
646 let mut first = true;
647
648 for k in obj.keys() {
649 if !first {
650 sql.push_str(", ");
651 }
652 let p = DB::placeholder(i);
653 let _ = write!(sql, "{} = {}", DB::quote_identifier(k), p);
654 i += 1;
655 first = false;
656 }
657
658 let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
659 self.render_where_clause_into(&mut sql, &mut dummy_binds, obj.len() + 1);
660 Ok(sql)
661 }
662
663 pub fn to_delete_sql(&self) -> String {
665 let mut sql = String::with_capacity(128);
666 use std::fmt::Write;
667
668 if T::has_soft_delete() {
669 let _ = write!(
670 sql,
671 "UPDATE {} SET {} = {}",
672 T::table_name(),
673 DB::quote_identifier("deleted_at"),
674 DB::current_timestamp_fn()
675 );
676 } else {
677 let _ = write!(sql, "DELETE FROM {}", T::table_name());
678 }
679
680 let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
681 self.render_where_clause_into(&mut sql, &mut dummy_binds, 1);
682 sql
683 }
684
685 #[inline(always)]
687 fn render_where_clause_into(
688 &self,
689 sql: &mut String,
690 binds: &mut SmallVec<[BindValue; 8]>,
691 start_index: usize,
692 ) {
693 let mut idx = start_index;
694 let mut first_clause = true;
695 use std::fmt::Write;
696
697 let mut append_and = |sql: &mut String| {
699 if first_clause {
700 sql.push_str(" WHERE ");
701 first_clause = false;
702 } else {
703 sql.push_str(" AND ");
704 }
705 };
706
707 for filter in &self.filters {
708 match filter {
709 FilterExpr::Raw(condition) => {
710 append_and(sql);
711 sql.push_str(condition);
712 }
713 FilterExpr::Compare { column, op, values } => {
714 if op.is_in() {
715 if values.is_empty() {
716 append_and(sql);
717 sql.push_str("1=0");
718 continue;
719 }
720 append_and(sql);
721 let _ = write!(sql, "{} IN (", DB::quote_identifier(column.as_str()));
722 let placeholders = crate::cached_placeholders_from::<DB>(idx, values.len());
723 sql.push_str(placeholders);
724 sql.push(')');
725 idx = idx.saturating_add(values.len());
726 for v in values {
727 binds.push(v.clone());
728 }
729 } else {
730 append_and(sql);
731 let _ = write!(
732 sql,
733 "{} {} {}",
734 DB::quote_identifier(column.as_str()),
735 op.as_str(),
736 DB::placeholder(idx)
737 );
738 idx += 1;
739 if let Some(v) = values.first() {
740 binds.push(v.clone());
741 }
742 }
743 }
744 FilterExpr::NullCheck { column, is_null } => {
745 append_and(sql);
746 if *is_null {
747 let _ = write!(sql, "{} IS NULL", DB::quote_identifier(column.as_str()));
748 } else {
749 let _ =
750 write!(sql, "{} IS NOT NULL", DB::quote_identifier(column.as_str()));
751 }
752 }
753 }
754 }
755
756 if T::has_soft_delete() && !self.include_deleted {
757 append_and(sql);
758 sql.push_str("deleted_at IS NULL");
759 }
760 }
761}
762
763impl<'a, T, DB> QueryBuilder<'a, T, DB>
764where
765 DB: SqlDialect,
766 T: Model<DB>,
767 for<'q> <DB as Database>::Arguments<'q>: IntoArguments<'q, DB>,
768 for<'c> &'c mut <DB as Database>::Connection: sqlx::Executor<'c, Database = DB>,
769 for<'c> &'c str: sqlx::ColumnIndex<DB::Row>,
770 DB::Connection: Send,
771 T: Send,
772 String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
773 i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
774 f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
775 bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
776 Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
777 uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
778 chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
779{
780 fn ensure_safe_filters(&self) -> Result<(), sqlx::Error> {
781 if self.unsafe_fast {
782 return Ok(());
783 }
784 if self.has_raw_filter && !self.allow_unsafe {
785 return Err(sqlx::Error::Protocol(
786 "Refusing raw filter without allow_unsafe".to_string(),
787 ));
788 }
789 Ok(())
790 }
791
792 #[tracing::instrument(skip(self), fields(table = T::table_name()))]
797 pub async fn all(mut self) -> Result<Vec<T>, sqlx::Error> {
798 self.ensure_safe_filters()?;
799
800 let mut sql = String::with_capacity(128);
801 sql.push_str("SELECT * FROM ");
802 sql.push_str(T::table_name());
803
804 let mut where_binds: SmallVec<[BindValue; 8]> =
805 SmallVec::with_capacity(self.estimate_bind_count());
806 self.render_where_clause_into(&mut sql, &mut where_binds, 1);
807
808 if let Some(limit) = self.limit {
809 use std::fmt::Write;
810 let _ = write!(sql, " LIMIT {}", limit);
811 }
812
813 if let Some(offset) = self.offset {
814 use std::fmt::Write;
815 let _ = write!(sql, " OFFSET {}", offset);
816 }
817
818 #[cfg(debug_assertions)]
819 if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
820 let filters = self.format_filters_for_log();
821 tracing::debug!(
822 operation = "select",
823 sql = %sql,
824 filters = %filters,
825 "premix query"
826 );
827 }
828
829 let start = Instant::now();
830 let mut results: Vec<T> = match &mut self.executor {
831 Executor::Pool(pool) => {
832 let base = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
833 let query = where_binds.into_iter().fold(base, bind_value_query_as);
834 query.fetch_all(*pool).await?
835 }
836 Executor::Conn(conn) => {
837 let base = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
838 let query = where_binds.into_iter().fold(base, bind_value_query_as);
839 query.fetch_all(&mut **conn).await?
840 }
841 };
842 if !self.fast_path {
843 record_query_metrics("select", T::table_name(), start.elapsed());
844 }
845
846 if self.ultra_fast {
847 return Ok(results);
848 }
849
850 for relation in self.includes {
851 match &mut self.executor {
852 Executor::Pool(pool) => {
853 T::eager_load(&mut results, &relation, Executor::Pool(*pool)).await?;
854 }
855 Executor::Conn(conn) => {
856 T::eager_load(&mut results, &relation, Executor::Conn(&mut **conn)).await?;
857 }
858 }
859 }
860
861 Ok(results)
862 }
863
864 #[tracing::instrument(skip(self), fields(table = T::table_name()))]
868 pub fn stream(
869 self,
870 ) -> Result<futures_util::stream::BoxStream<'a, Result<T, sqlx::Error>>, sqlx::Error>
871 where
872 T: 'a,
873 {
874 self.ensure_safe_filters()?;
875
876 let mut sql = String::with_capacity(128);
877 sql.push_str("SELECT * FROM ");
878 sql.push_str(T::table_name());
879
880 let mut where_binds: SmallVec<[BindValue; 8]> =
881 SmallVec::with_capacity(self.estimate_bind_count());
882 self.render_where_clause_into(&mut sql, &mut where_binds, 1);
883
884 if let Some(limit) = self.limit {
885 use std::fmt::Write;
886 let _ = write!(sql, " LIMIT {}", limit);
887 }
888
889 if let Some(offset) = self.offset {
890 use std::fmt::Write;
891 let _ = write!(sql, " OFFSET {}", offset);
892 }
893
894 #[cfg(debug_assertions)]
895 if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
896 let filters = self.format_filters_for_log();
897 tracing::debug!(
898 operation = "stream",
899 sql = %sql,
900 filters = %filters,
901 "premix query"
902 );
903 }
904
905 let executor = self.executor;
906 Ok(Box::pin(async_stream::try_stream! {
907 let mut query = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
908 for bind in where_binds {
909 query = bind_value_query_as(query, bind);
910 }
911 let mut s = executor.fetch_stream(query);
912 while let Some(row) = s.next().await {
913 yield row?;
914 }
915 }))
916 }
917
918 #[inline(never)]
924 #[tracing::instrument(skip(self, values), fields(table = T::table_name()))]
925 pub async fn update(mut self, values: serde_json::Value) -> Result<u64, sqlx::Error>
926 where
927 String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
928 i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
929 f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
930 bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
931 Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
932 uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
933 chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
934 {
935 self.ensure_safe_filters()?;
936 if self.filters.is_empty() && !self.allow_unsafe && !self.unsafe_fast {
937 return Err(sqlx::Error::Protocol(
938 "Refusing bulk update without filters".to_string(),
939 ));
940 }
941 let obj = values.as_object().ok_or_else(|| {
942 sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
943 })?;
944
945 let mut sql = String::with_capacity(256);
946 use std::fmt::Write;
947 let _ = write!(sql, "UPDATE {} SET ", T::table_name());
948
949 let mut i = 1;
950 let mut first = true;
951 for k in obj.keys() {
952 if !first {
953 sql.push_str(", ");
954 }
955 let p = DB::placeholder(i);
956 let _ = write!(sql, "{} = {}", DB::quote_identifier(k), p);
957 i += 1;
958 first = false;
959 }
960
961 let mut where_binds: SmallVec<[BindValue; 8]> =
962 SmallVec::with_capacity(self.estimate_bind_count());
963 self.render_where_clause_into(&mut sql, &mut where_binds, obj.len() + 1);
964
965 if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
966 let filters = self.format_filters_for_log();
967 tracing::debug!(
968 operation = "bulk_update",
969 sql = %sql,
970 filters = %filters,
971 "premix query"
972 );
973 }
974 let mut query = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
975 for val in obj.values() {
976 match val {
977 serde_json::Value::String(s) => query = query.bind(s.clone()),
978 serde_json::Value::Number(n) => {
979 if let Some(v) = n.as_i64() {
980 query = query.bind(v);
981 } else if let Some(v) = n.as_f64() {
982 query = query.bind(v);
983 }
984 }
985 serde_json::Value::Bool(b) => query = query.bind(*b),
986 serde_json::Value::Null => query = query.bind(Option::<String>::None),
987 _ => {
988 return Err(sqlx::Error::Protocol(
989 "Unsupported type in bulk update".to_string(),
990 ));
991 }
992 }
993 }
994 for bind in where_binds {
995 query = bind_value_query(query, bind);
996 }
997
998 let start = Instant::now();
999 let result = match &mut self.executor {
1000 Executor::Pool(pool) => {
1001 let res = query.execute(*pool).await?;
1002 Ok(DB::rows_affected(&res))
1003 }
1004 Executor::Conn(conn) => {
1005 let res = query.execute(&mut **conn).await?;
1006 Ok(DB::rows_affected(&res))
1007 }
1008 };
1009 if !self.fast_path {
1010 record_query_metrics("bulk_update", T::table_name(), start.elapsed());
1011 }
1012 result
1013 }
1014
1015 pub async fn update_all(self, values: serde_json::Value) -> Result<u64, sqlx::Error>
1017 where
1018 String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1019 i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1020 f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1021 bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1022 Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1023 {
1024 self.update(values).await
1025 }
1026
1027 #[tracing::instrument(skip(self), fields(table = T::table_name()))]
1029 pub async fn delete(mut self) -> Result<u64, sqlx::Error> {
1030 self.ensure_safe_filters()?;
1031 if self.filters.is_empty() && !self.allow_unsafe && !self.unsafe_fast {
1032 return Err(sqlx::Error::Protocol(
1033 "Refusing bulk delete without filters".to_string(),
1034 ));
1035 }
1036
1037 let mut sql = String::with_capacity(128);
1038 use std::fmt::Write;
1039
1040 if T::has_soft_delete() {
1041 let _ = write!(
1042 sql,
1043 "UPDATE {} SET {} = {}",
1044 T::table_name(),
1045 DB::quote_identifier("deleted_at"),
1046 DB::current_timestamp_fn()
1047 );
1048 } else {
1049 let _ = write!(sql, "DELETE FROM {}", T::table_name());
1050 }
1051
1052 let mut where_binds: SmallVec<[BindValue; 8]> =
1053 SmallVec::with_capacity(self.estimate_bind_count());
1054 self.render_where_clause_into(&mut sql, &mut where_binds, 1);
1055
1056 if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
1057 let filters = self.format_filters_for_log();
1058 tracing::debug!(
1059 operation = "bulk_delete",
1060 sql = %sql,
1061 filters = %filters,
1062 "premix query"
1063 );
1064 }
1065 let start = Instant::now();
1066 let result = match &mut self.executor {
1067 Executor::Pool(pool) => {
1068 let base = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1069 let query = where_binds.into_iter().fold(base, bind_value_query);
1070 let res = query.execute(*pool).await?;
1071 Ok(DB::rows_affected(&res))
1072 }
1073 Executor::Conn(conn) => {
1074 let base = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1075 let query = where_binds.into_iter().fold(base, bind_value_query);
1076 let res = query.execute(&mut **conn).await?;
1077 Ok(DB::rows_affected(&res))
1078 }
1079 };
1080 if !self.fast_path {
1081 record_query_metrics("bulk_delete", T::table_name(), start.elapsed());
1082 }
1083 result
1084 }
1085
1086 pub async fn delete_all(self) -> Result<u64, sqlx::Error> {
1088 self.delete().await
1089 }
1090}
1091
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::Sqlite;

    /// Field-less model mapped to the `users` table; every persistence hook is a no-op.
    struct DummyModel;

    impl Model<Sqlite> for DummyModel {
        fn table_name() -> &'static str {
            "users"
        }

        fn create_table_sql() -> String {
            String::new()
        }

        fn list_columns() -> Vec<String> {
            vec!["id".to_string()]
        }

        async fn save<'a, E>(&'a mut self, _executor: E) -> Result<(), sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(())
        }

        async fn update<'a, E>(
            &'a mut self,
            _executor: E,
        ) -> Result<crate::model::UpdateResult, sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(crate::model::UpdateResult::Success)
        }

        async fn delete<'a, E>(&'a mut self, _executor: E) -> Result<(), sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(())
        }

        fn has_soft_delete() -> bool {
            false
        }

        async fn find_by_id<'a, E>(_executor: E, _id: i32) -> Result<Option<Self>, sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(None)
        }
    }

    // Rows carry no data the tests care about, so any row decodes to the unit model.
    impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for DummyModel {
        fn from_row(_row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
            Ok(DummyModel)
        }
    }

    /// A hostile column name must end up inside the quoted identifier.
    #[tokio::test]
    async fn test_sql_injection_mitigation() {
        let pool = sqlx::Pool::<Sqlite>::connect_lazy("sqlite::memory:").unwrap();
        let builder = DummyModel::find_in_pool(&pool).filter_eq("id; DROP TABLE users; --", 1);

        let rendered = builder.to_sql();
        println!("SQL select: {}", rendered);

        assert!(rendered.contains("`id; DROP TABLE users; --` = ?"));
        assert!(rendered.contains("SELECT * FROM users WHERE"));
    }

    /// Keys of the update payload must be quoted as identifiers too.
    #[tokio::test]
    async fn test_to_update_sql_quoting() {
        let pool = sqlx::Pool::<Sqlite>::connect_lazy("sqlite::memory:").unwrap();
        let builder = DummyModel::find_in_pool(&pool).filter_eq("id", 1);

        let payload = serde_json::json!({
            "name; DROP TABLE users; --": "admin"
        });

        let rendered = builder.to_update_sql(&payload).unwrap();
        println!("SQL update: {}", rendered);
        assert!(rendered.contains("`name; DROP TABLE users; --` = ?"));
    }

    /// Streaming over a two-row table yields exactly two items.
    #[tokio::test]
    async fn test_stream_api() {
        use sqlx::Connection;

        let mut conn = sqlx::SqliteConnection::connect("sqlite::memory:")
            .await
            .unwrap();
        sqlx::query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
            .execute(&mut conn)
            .await
            .unwrap();
        sqlx::query("INSERT INTO users (id, name) VALUES (1, 'alice'), (2, 'bob')")
            .execute(&mut conn)
            .await
            .unwrap();

        let builder = DummyModel::find_in_tx(&mut conn);

        let mut rows = builder.stream().unwrap();
        let mut seen = 0;
        while let Some(item) = rows.next().await {
            let _: DummyModel = item.unwrap();
            seen += 1;
        }
        assert_eq!(seen, 2);
    }
}