1use crate::dialect::SqlDialect;
2use crate::executor::Executor;
3use crate::model::Model;
4use futures_util::StreamExt;
5use smallvec::{SmallVec, smallvec};
6use sqlx::{Database, IntoArguments};
7use std::time::{Duration, Instant};
8
/// A dynamically typed value that can be bound to a query placeholder.
///
/// Instances are normally created through the `From` conversions implemented
/// for this type and are consumed by the bind helpers in this file.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub enum BindValue {
    /// Owned text.
    String(String),
    /// 64-bit signed integer (also the target of the `i32` conversion).
    I64(i64),
    /// 64-bit floating point number.
    F64(f64),
    /// Boolean flag.
    Bool(bool),
    /// UUID value.
    Uuid(uuid::Uuid),
    /// UTC timestamp.
    DateTime(chrono::DateTime<chrono::Utc>),
    /// Timestamp without a time zone.
    NaiveDateTime(chrono::NaiveDateTime),
    /// Calendar date without a time zone.
    NaiveDate(chrono::NaiveDate),
    /// Arbitrary JSON document.
    Json(serde_json::Value),
    /// Absence of a value.
    Null,
}
23
24impl BindValue {
25 fn to_log_string(&self) -> String {
26 match self {
27 BindValue::String(v) => v.clone(),
28 BindValue::I64(v) => v.to_string(),
29 BindValue::F64(v) => v.to_string(),
30 BindValue::Bool(v) => v.to_string(),
31 BindValue::Uuid(v) => v.to_string(),
32 BindValue::DateTime(v) => v.to_rfc3339(),
33 BindValue::NaiveDateTime(v) => v.to_string(),
34 BindValue::NaiveDate(v) => v.to_string(),
35 BindValue::Json(v) => v.to_string(),
36 BindValue::Null => "NULL".to_string(),
37 }
38 }
39}
40
/// Records one executed query in the `premix.query.duration_ms` histogram
/// (milliseconds) and the `premix.query.count` counter, labelled with the
/// operation and table names.
#[cfg(feature = "metrics")]
fn record_query_metrics(operation: &str, table: &str, elapsed: Duration) {
    let labels = [
        ("operation", operation.to_owned()),
        ("table", table.to_owned()),
    ];
    // The histogram is expressed in milliseconds, not seconds.
    let millis = elapsed.as_secs_f64() * 1000.0;
    metrics::histogram!("premix.query.duration_ms", &labels).record(millis);
    metrics::counter!("premix.query.count", &labels).increment(1);
}
51
/// No-op stand-in used when the `metrics` feature is disabled, so call sites
/// do not need their own feature gates.
#[cfg(not(feature = "metrics"))]
fn record_query_metrics(_operation: &str, _table: &str, _elapsed: Duration) {}
54
55#[inline(always)]
56fn bind_value_query<'q, DB>(
57 query: sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>,
58 value: BindValue,
59) -> sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>
60where
61 DB: Database,
62 String: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
63 i64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
64 f64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
65 bool: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
66 uuid::Uuid: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
67 chrono::DateTime<chrono::Utc>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
68 chrono::NaiveDateTime: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
69 chrono::NaiveDate: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
70 sqlx::types::Json<serde_json::Value>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
71 Option<String>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
72{
73 match value {
74 BindValue::String(v) => query.bind(v),
75 BindValue::I64(v) => query.bind(v),
76 BindValue::F64(v) => query.bind(v),
77 BindValue::Bool(v) => query.bind(v),
78 BindValue::Uuid(v) => query.bind(v),
79 BindValue::DateTime(v) => query.bind(v),
80 BindValue::NaiveDateTime(v) => query.bind(v),
81 BindValue::NaiveDate(v) => query.bind(v),
82 BindValue::Json(v) => query.bind(sqlx::types::Json(v)),
83 BindValue::Null => query.bind(Option::<String>::None),
84 }
85}
86
87#[inline(always)]
88fn bind_value_query_as<'q, DB, T>(
89 query: sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>,
90 value: BindValue,
91) -> sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>
92where
93 DB: Database,
94 String: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
95 i64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
96 f64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
97 bool: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
98 uuid::Uuid: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
99 chrono::DateTime<chrono::Utc>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
100 chrono::NaiveDateTime: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
101 chrono::NaiveDate: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
102 sqlx::types::Json<serde_json::Value>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
103 Option<String>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
104{
105 match value {
106 BindValue::String(v) => query.bind(v),
107 BindValue::I64(v) => query.bind(v),
108 BindValue::F64(v) => query.bind(v),
109 BindValue::Bool(v) => query.bind(v),
110 BindValue::Uuid(v) => query.bind(v),
111 BindValue::DateTime(v) => query.bind(v),
112 BindValue::NaiveDateTime(v) => query.bind(v),
113 BindValue::NaiveDate(v) => query.bind(v),
114 BindValue::Json(v) => query.bind(sqlx::types::Json(v)),
115 BindValue::Null => query.bind(Option::<String>::None),
116 }
117}
118
119#[inline]
120fn apply_persistent_query_as<'q, DB, T>(
121 query: sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>,
122 prepared: bool,
123) -> sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>
124where
125 DB: Database + sqlx::database::HasStatementCache,
126{
127 if prepared {
128 query.persistent(true)
129 } else {
130 query
131 }
132}
133
134#[inline]
135fn apply_persistent_query<'q, DB>(
136 query: sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>,
137 prepared: bool,
138) -> sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>
139where
140 DB: Database + sqlx::database::HasStatementCache,
141{
142 if prepared {
143 query.persistent(true)
144 } else {
145 query
146 }
147}
148
149impl From<String> for BindValue {
150 fn from(value: String) -> Self {
151 Self::String(value)
152 }
153}
154
155impl From<&str> for BindValue {
156 fn from(value: &str) -> Self {
157 Self::String(value.to_string())
158 }
159}
160
161impl From<i32> for BindValue {
162 fn from(value: i32) -> Self {
163 Self::I64(value as i64)
164 }
165}
166
167impl From<i64> for BindValue {
168 fn from(value: i64) -> Self {
169 Self::I64(value)
170 }
171}
172
173impl From<f64> for BindValue {
174 fn from(value: f64) -> Self {
175 Self::F64(value)
176 }
177}
178
179impl From<bool> for BindValue {
180 fn from(value: bool) -> Self {
181 Self::Bool(value)
182 }
183}
184
185impl From<Option<String>> for BindValue {
186 fn from(value: Option<String>) -> Self {
187 match value {
188 Some(v) => Self::String(v),
189 None => Self::Null,
190 }
191 }
192}
193
194impl From<uuid::Uuid> for BindValue {
195 fn from(value: uuid::Uuid) -> Self {
196 Self::Uuid(value)
197 }
198}
199
200impl From<chrono::DateTime<chrono::Utc>> for BindValue {
201 fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
202 Self::DateTime(value)
203 }
204}
205
206impl From<chrono::NaiveDateTime> for BindValue {
207 fn from(value: chrono::NaiveDateTime) -> Self {
208 Self::NaiveDateTime(value)
209 }
210}
211
212impl From<chrono::NaiveDate> for BindValue {
213 fn from(value: chrono::NaiveDate) -> Self {
214 Self::NaiveDate(value)
215 }
216}
217
218impl From<serde_json::Value> for BindValue {
219 fn from(value: serde_json::Value) -> Self {
220 Self::Json(value)
221 }
222}
223
/// One condition of a query's `WHERE` clause.
#[derive(Debug, Clone)]
pub(crate) enum FilterExpr {
    /// SQL text that is appended to the statement verbatim.
    Raw(String),
    /// A comparison between a column and one or more bound values.
    Compare {
        /// Column being compared.
        column: ColumnRef,
        /// Comparison operator.
        op: FilterOp,
        /// Bound values: all of them for `IN`, only the first otherwise.
        values: SmallVec<[BindValue; 2]>,
    },
    /// An `IS NULL` / `IS NOT NULL` test on a column.
    NullCheck {
        /// Column being tested.
        column: ColumnRef,
        /// `true` renders `IS NULL`, `false` renders `IS NOT NULL`.
        is_null: bool,
    },
}
237
/// Comparison operators available to structured filters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FilterOp {
    Eq,
    Ne,
    Lt,
    Lte,
    Gt,
    Gte,
    Like,
    In,
}

impl FilterOp {
    /// Returns the SQL token that represents this operator.
    fn as_str(self) -> &'static str {
        match self {
            Self::Eq => "=",
            Self::Ne => "!=",
            Self::Lt => "<",
            Self::Lte => "<=",
            Self::Gt => ">",
            Self::Gte => ">=",
            Self::Like => "LIKE",
            Self::In => "IN",
        }
    }

    /// Reports whether this is the multi-value `IN` operator.
    fn is_in(self) -> bool {
        self == Self::In
    }
}
268
/// A column name that is either a compile-time literal or an owned string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColumnRef {
    /// Name borrowed from a `'static` string; no allocation involved.
    Static(&'static str),
    /// Name stored in an owned `String`.
    Owned(String),
}

impl ColumnRef {
    /// Builds a reference from a `'static` name; usable in `const` contexts.
    pub const fn static_str(value: &'static str) -> Self {
        Self::Static(value)
    }

    /// Borrows the column name regardless of how it is stored.
    fn as_str(&self) -> &str {
        match self {
            Self::Static(name) => name,
            Self::Owned(name) => name.as_str(),
        }
    }
}

/// Copies a borrowed name into an owned column reference.
impl From<&str> for ColumnRef {
    fn from(value: &str) -> Self {
        Self::Owned(value.to_owned())
    }
}

/// Takes ownership of the name without copying it.
impl From<String> for ColumnRef {
    fn from(value: String) -> Self {
        Self::Owned(value)
    }
}

/// Clones a borrowed `String` into an owned column reference.
impl From<&String> for ColumnRef {
    fn from(value: &String) -> Self {
        Self::Owned(value.to_owned())
    }
}
309
/// Fluent builder for SELECT / UPDATE / DELETE statements over the model `T`.
pub struct QueryBuilder<'a, T, DB: Database> {
    /// Pool or connection the statement is executed on.
    executor: Executor<'a, DB>,
    /// Conditions combined with `AND` in the `WHERE` clause.
    filters: Vec<FilterExpr>,
    /// Value rendered after `LIMIT`, when set.
    limit: Option<i32>,
    /// Value rendered after `OFFSET`, when set.
    offset: Option<i32>,
    /// Relation names handed to `T::eager_load` by `all`.
    includes: SmallVec<[String; 2]>,
    /// When `false` and the model has soft delete, `deleted_at IS NULL` is appended.
    include_deleted: bool,
    /// Permits raw filters and filterless bulk update/delete.
    allow_unsafe: bool,
    /// Set once a raw filter has been added through `filter`.
    has_raw_filter: bool,
    /// Skips debug logging and metrics recording.
    fast_path: bool,
    /// Bypasses the raw-filter and missing-filter safety checks.
    unsafe_fast: bool,
    /// Makes `all` return before eager loading runs.
    ultra_fast: bool,
    /// Flag forwarded to the statement-persistence helpers.
    prepared: bool,
    /// Ties the builder to the model type without storing a value of it.
    _marker: std::marker::PhantomData<T>,
}
329
330impl<'a, T, DB: Database> std::fmt::Debug for QueryBuilder<'a, T, DB> {
331 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332 f.debug_struct("QueryBuilder")
333 .field("filters", &self.filters)
334 .field("limit", &self.limit)
335 .field("offset", &self.offset)
336 .field("includes", &self.includes)
337 .field("include_deleted", &self.include_deleted)
338 .field("allow_unsafe", &self.allow_unsafe)
339 .field("fast_path", &self.fast_path)
340 .field("unsafe_fast", &self.unsafe_fast)
341 .field("ultra_fast", &self.ultra_fast)
342 .field("prepared", &self.prepared)
343 .finish()
344 }
345}
346
347impl<'a, T, DB> QueryBuilder<'a, T, DB>
348where
349 DB: SqlDialect + sqlx::database::HasStatementCache,
350 T: Model<DB>,
351{
352 pub fn new(executor: Executor<'a, DB>) -> Self {
354 Self {
355 executor,
356 filters: Vec::with_capacity(4), limit: None,
358 offset: None,
359 includes: SmallVec::with_capacity(2), include_deleted: false,
361 allow_unsafe: false,
362 has_raw_filter: false,
363 fast_path: false,
364 unsafe_fast: false,
365 ultra_fast: false,
366 prepared: true,
367 _marker: std::marker::PhantomData,
368 }
369 }
370
371 pub fn filter(mut self, condition: impl Into<String>) -> Self {
376 self.filters.push(FilterExpr::Raw(condition.into()));
377 self.has_raw_filter = true;
378 self
379 }
380
381 pub fn filter_raw(self, condition: impl Into<String>) -> Self {
383 self.filter(condition)
384 }
385
386 pub fn filter_eq(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
388 self.filters.push(FilterExpr::Compare {
389 column: column.into(),
390 op: FilterOp::Eq,
391 values: smallvec![value.into()],
392 });
393 self
394 }
395
396 pub fn filter_ne(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
398 self.filters.push(FilterExpr::Compare {
399 column: column.into(),
400 op: FilterOp::Ne,
401 values: smallvec![value.into()],
402 });
403 self
404 }
405
406 pub fn filter_lt(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
408 self.filters.push(FilterExpr::Compare {
409 column: column.into(),
410 op: FilterOp::Lt,
411 values: smallvec![value.into()],
412 });
413 self
414 }
415
416 pub fn filter_lte(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
418 self.filters.push(FilterExpr::Compare {
419 column: column.into(),
420 op: FilterOp::Lte,
421 values: smallvec![value.into()],
422 });
423 self
424 }
425
426 pub fn filter_gt(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
428 self.filters.push(FilterExpr::Compare {
429 column: column.into(),
430 op: FilterOp::Gt,
431 values: smallvec![value.into()],
432 });
433 self
434 }
435
436 pub fn filter_gte(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
438 self.filters.push(FilterExpr::Compare {
439 column: column.into(),
440 op: FilterOp::Gte,
441 values: smallvec![value.into()],
442 });
443 self
444 }
445
446 pub fn filter_like(
448 mut self,
449 column: impl Into<ColumnRef>,
450 value: impl Into<BindValue>,
451 ) -> Self {
452 self.filters.push(FilterExpr::Compare {
453 column: column.into(),
454 op: FilterOp::Like,
455 values: smallvec![value.into()],
456 });
457 self
458 }
459
460 pub fn filter_is_null(mut self, column: impl Into<ColumnRef>) -> Self {
462 self.filters.push(FilterExpr::NullCheck {
463 column: column.into(),
464 is_null: true,
465 });
466 self
467 }
468
469 pub fn filter_is_not_null(mut self, column: impl Into<ColumnRef>) -> Self {
471 self.filters.push(FilterExpr::NullCheck {
472 column: column.into(),
473 is_null: false,
474 });
475 self
476 }
477
478 pub fn filter_in<I, V>(mut self, column: impl Into<ColumnRef>, values: I) -> Self
480 where
481 I: IntoIterator<Item = V>,
482 V: Into<BindValue>,
483 {
484 let values: SmallVec<[BindValue; 2]> = values.into_iter().map(Into::into).collect();
485 self.filters.push(FilterExpr::Compare {
486 column: column.into(),
487 op: FilterOp::In,
488 values,
489 });
490 self
491 }
492
493 fn format_filters_for_log(&self) -> String {
494 let sensitive_fields = T::sensitive_fields();
495 let mut rendered = String::with_capacity(128);
496 let mut first_clause = true;
497 let mut append_and = |buf: &mut String| {
498 if first_clause {
499 first_clause = false;
500 } else {
501 buf.push_str(" AND ");
502 }
503 };
504 use std::fmt::Write;
505
506 for filter in &self.filters {
507 match filter {
508 FilterExpr::Raw(_) => {
509 append_and(&mut rendered);
510 rendered.push_str("RAW(<redacted>)");
511 }
512 FilterExpr::Compare { column, op, values } => {
513 let column_name = column.as_str();
514 let is_sensitive = sensitive_fields.contains(&column_name);
515 if op.is_in() {
516 if values.is_empty() {
517 append_and(&mut rendered);
518 rendered.push_str("1=0");
519 continue;
520 }
521 append_and(&mut rendered);
522 let _ = write!(rendered, "{} IN (", column_name);
523 for (idx, value) in values.iter().enumerate() {
524 if idx > 0 {
525 rendered.push_str(", ");
526 }
527 if is_sensitive {
528 rendered.push_str("***");
529 } else {
530 rendered.push_str(&value.to_log_string());
531 }
532 }
533 rendered.push(')');
534 } else {
535 append_and(&mut rendered);
536 let _ = write!(rendered, "{} {} ", column_name, op.as_str());
537 if is_sensitive {
538 rendered.push_str("***");
539 } else if let Some(value) = values.first() {
540 rendered.push_str(&value.to_log_string());
541 } else {
542 rendered.push_str("NULL");
543 }
544 }
545 }
546 FilterExpr::NullCheck { column, is_null } => {
547 append_and(&mut rendered);
548 if *is_null {
549 let _ = write!(rendered, "{} IS NULL", column.as_str());
550 } else {
551 let _ = write!(rendered, "{} IS NOT NULL", column.as_str());
552 }
553 }
554 }
555 }
556
557 if T::has_soft_delete() && !self.include_deleted {
558 append_and(&mut rendered);
559 rendered.push_str("deleted_at IS NULL");
560 }
561
562 rendered
563 }
564
565 fn estimate_bind_count(&self) -> usize {
566 let mut count = 0usize;
567 for filter in &self.filters {
568 if let FilterExpr::Compare { op, values, .. } = filter {
569 if op.is_in() {
570 count = count.saturating_add(values.len());
571 } else {
572 count = count.saturating_add(1);
573 }
574 }
575 }
576 count
577 }
578
579 pub fn limit(mut self, limit: i32) -> Self {
582 self.limit = Some(limit);
583 self
584 }
585
586 pub fn offset(mut self, offset: i32) -> Self {
589 self.offset = Some(offset);
590 self
591 }
592
593 pub fn include(mut self, relation: impl Into<String>) -> Self {
595 self.includes.push(relation.into());
596 self
597 }
598
599 pub fn with_deleted(mut self) -> Self {
601 self.include_deleted = true;
602 self
603 }
604
605 pub fn allow_unsafe(mut self) -> Self {
608 self.allow_unsafe = true;
609 self
610 }
611
612 pub fn fast(mut self) -> Self {
614 self.fast_path = true;
615 self
616 }
617
618 pub fn unsafe_fast(mut self) -> Self {
620 self.fast_path = true;
621 self.unsafe_fast = true;
622 self
623 }
624
625 pub fn ultra_fast(mut self) -> Self {
628 self.fast_path = true;
629 self.unsafe_fast = true;
630 self.ultra_fast = true;
631 self
632 }
633
634 pub fn prepared(mut self) -> Self {
636 self.prepared = true;
637 self
638 }
639
640 pub fn unprepared(mut self) -> Self {
642 self.prepared = false;
643 self
644 }
645
646 pub fn to_sql(&self) -> String {
648 let mut sql = String::with_capacity(128); use std::fmt::Write;
650
651 sql.push_str("SELECT * FROM ");
652 sql.push_str(T::table_name());
653
654 let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
655 self.render_where_clause_into(&mut sql, &mut dummy_binds, 1);
656
657 if let Some(limit) = self.limit {
658 let _ = write!(sql, " LIMIT {}", limit);
659 }
660
661 if let Some(offset) = self.offset {
662 let _ = write!(sql, " OFFSET {}", offset);
663 }
664
665 sql
666 }
667
668 pub fn to_update_sql(&self, values: &serde_json::Value) -> Result<String, sqlx::Error> {
670 let obj = values.as_object().ok_or_else(|| {
671 sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
672 })?;
673
674 let mut sql = String::with_capacity(256);
675 use std::fmt::Write;
676
677 let _ = write!(sql, "UPDATE {} SET ", T::table_name());
678
679 let mut i = 1;
680 let mut first = true;
681
682 for k in obj.keys() {
683 if !first {
684 sql.push_str(", ");
685 }
686 let p = DB::placeholder(i);
687 let _ = write!(sql, "{} = {}", DB::quote_identifier(k), p);
688 i += 1;
689 first = false;
690 }
691
692 let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
693 self.render_where_clause_into(&mut sql, &mut dummy_binds, obj.len() + 1);
694 Ok(sql)
695 }
696
697 pub fn to_delete_sql(&self) -> String {
699 let mut sql = String::with_capacity(128);
700 use std::fmt::Write;
701
702 if T::has_soft_delete() {
703 let _ = write!(
704 sql,
705 "UPDATE {} SET {} = {}",
706 T::table_name(),
707 DB::quote_identifier("deleted_at"),
708 DB::current_timestamp_fn()
709 );
710 } else {
711 let _ = write!(sql, "DELETE FROM {}", T::table_name());
712 }
713
714 let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
715 self.render_where_clause_into(&mut sql, &mut dummy_binds, 1);
716 sql
717 }
718
719 #[inline(always)]
721 fn render_where_clause_into(
722 &self,
723 sql: &mut String,
724 binds: &mut SmallVec<[BindValue; 8]>,
725 start_index: usize,
726 ) {
727 let mut idx = start_index;
728 let mut first_clause = true;
729 use std::fmt::Write;
730
731 let mut append_and = |sql: &mut String| {
733 if first_clause {
734 sql.push_str(" WHERE ");
735 first_clause = false;
736 } else {
737 sql.push_str(" AND ");
738 }
739 };
740
741 for filter in &self.filters {
742 match filter {
743 FilterExpr::Raw(condition) => {
744 append_and(sql);
745 sql.push_str(condition);
746 }
747 FilterExpr::Compare { column, op, values } => {
748 if op.is_in() {
749 if values.is_empty() {
750 append_and(sql);
751 sql.push_str("1=0");
752 continue;
753 }
754 append_and(sql);
755 let _ = write!(sql, "{} IN (", DB::quote_identifier(column.as_str()));
756 let placeholders = crate::cached_placeholders_from::<DB>(idx, values.len());
757 sql.push_str(placeholders);
758 sql.push(')');
759 idx = idx.saturating_add(values.len());
760 for v in values {
761 binds.push(v.clone());
762 }
763 } else {
764 append_and(sql);
765 let _ = write!(
766 sql,
767 "{} {} {}",
768 DB::quote_identifier(column.as_str()),
769 op.as_str(),
770 DB::placeholder(idx)
771 );
772 idx += 1;
773 if let Some(v) = values.first() {
774 binds.push(v.clone());
775 }
776 }
777 }
778 FilterExpr::NullCheck { column, is_null } => {
779 append_and(sql);
780 if *is_null {
781 let _ = write!(sql, "{} IS NULL", DB::quote_identifier(column.as_str()));
782 } else {
783 let _ =
784 write!(sql, "{} IS NOT NULL", DB::quote_identifier(column.as_str()));
785 }
786 }
787 }
788 }
789
790 if T::has_soft_delete() && !self.include_deleted {
791 append_and(sql);
792 sql.push_str("deleted_at IS NULL");
793 }
794 }
795}
796
797impl<'a, T, DB> QueryBuilder<'a, T, DB>
798where
799 DB: SqlDialect,
800 T: Model<DB>,
801 for<'q> <DB as Database>::Arguments<'q>: IntoArguments<'q, DB>,
802 for<'c> &'c mut <DB as Database>::Connection: sqlx::Executor<'c, Database = DB>,
803 for<'c> &'c str: sqlx::ColumnIndex<DB::Row>,
804 DB::Connection: Send,
805 T: Send,
806 String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
807 i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
808 f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
809 bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
810 Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
811 uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
812 chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
813 chrono::NaiveDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
814 chrono::NaiveDate: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
815 sqlx::types::Json<serde_json::Value>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
816{
817 fn ensure_safe_filters(&self) -> Result<(), sqlx::Error> {
818 if self.unsafe_fast {
819 return Ok(());
820 }
821 if self.has_raw_filter && !self.allow_unsafe {
822 return Err(sqlx::Error::Protocol(
823 "Refusing raw filter without allow_unsafe".to_string(),
824 ));
825 }
826 Ok(())
827 }
828
829 #[tracing::instrument(skip(self), fields(table = T::table_name()))]
834 pub async fn all(mut self) -> Result<Vec<T>, sqlx::Error> {
835 self.ensure_safe_filters()?;
836
837 let mut sql = String::with_capacity(128);
838 sql.push_str("SELECT * FROM ");
839 sql.push_str(T::table_name());
840
841 let mut where_binds: SmallVec<[BindValue; 8]> =
842 SmallVec::with_capacity(self.estimate_bind_count());
843 self.render_where_clause_into(&mut sql, &mut where_binds, 1);
844
845 if let Some(limit) = self.limit {
846 use std::fmt::Write;
847 let _ = write!(sql, " LIMIT {}", limit);
848 }
849
850 if let Some(offset) = self.offset {
851 use std::fmt::Write;
852 let _ = write!(sql, " OFFSET {}", offset);
853 }
854
855 #[cfg(debug_assertions)]
856 if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
857 let filters = self.format_filters_for_log();
858 tracing::debug!(
859 operation = "select",
860 sql = %sql,
861 filters = %filters,
862 "premix query"
863 );
864 }
865
866 let start = Instant::now();
867 let mut results: Vec<T> = match &mut self.executor {
868 Executor::Pool(pool) => {
869 let base = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
870 let query = where_binds.into_iter().fold(base, bind_value_query_as);
871 query.fetch_all(*pool).await?
872 }
873 Executor::Conn(conn) => {
874 let base = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
875 let query = where_binds.into_iter().fold(base, bind_value_query_as);
876 query.fetch_all(&mut **conn).await?
877 }
878 };
879 if !self.fast_path {
880 record_query_metrics("select", T::table_name(), start.elapsed());
881 }
882
883 if self.ultra_fast {
884 return Ok(results);
885 }
886
887 for relation in self.includes {
888 match &mut self.executor {
889 Executor::Pool(pool) => {
890 T::eager_load(&mut results, &relation, Executor::Pool(*pool)).await?;
891 }
892 Executor::Conn(conn) => {
893 T::eager_load(&mut results, &relation, Executor::Conn(&mut **conn)).await?;
894 }
895 }
896 }
897
898 Ok(results)
899 }
900
901 #[tracing::instrument(skip(self), fields(table = T::table_name()))]
905 pub fn stream(
906 self,
907 ) -> Result<futures_util::stream::BoxStream<'a, Result<T, sqlx::Error>>, sqlx::Error>
908 where
909 T: 'a,
910 {
911 self.ensure_safe_filters()?;
912
913 let mut sql = String::with_capacity(128);
914 sql.push_str("SELECT * FROM ");
915 sql.push_str(T::table_name());
916
917 let mut where_binds: SmallVec<[BindValue; 8]> =
918 SmallVec::with_capacity(self.estimate_bind_count());
919 self.render_where_clause_into(&mut sql, &mut where_binds, 1);
920
921 if let Some(limit) = self.limit {
922 use std::fmt::Write;
923 let _ = write!(sql, " LIMIT {}", limit);
924 }
925
926 if let Some(offset) = self.offset {
927 use std::fmt::Write;
928 let _ = write!(sql, " OFFSET {}", offset);
929 }
930
931 #[cfg(debug_assertions)]
932 if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
933 let filters = self.format_filters_for_log();
934 tracing::debug!(
935 operation = "stream",
936 sql = %sql,
937 filters = %filters,
938 "premix query"
939 );
940 }
941
942 let executor = self.executor;
943 Ok(Box::pin(async_stream::try_stream! {
944 let mut query = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
945 for bind in where_binds {
946 query = bind_value_query_as(query, bind);
947 }
948 let mut s = executor.fetch_stream(query);
949 while let Some(row) = s.next().await {
950 yield row?;
951 }
952 }))
953 }
954
955 #[inline(never)]
961 #[tracing::instrument(skip(self, values), fields(table = T::table_name()))]
962 pub async fn update(mut self, values: serde_json::Value) -> Result<u64, sqlx::Error>
963 where
964 String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
965 i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
966 f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
967 bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
968 Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
969 uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
970 chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
971 chrono::NaiveDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
972 chrono::NaiveDate: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
973 sqlx::types::Json<serde_json::Value>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
974 {
975 self.ensure_safe_filters()?;
976 if self.filters.is_empty() && !self.allow_unsafe && !self.unsafe_fast {
977 return Err(sqlx::Error::Protocol(
978 "Refusing bulk update without filters".to_string(),
979 ));
980 }
981 let obj = values.as_object().ok_or_else(|| {
982 sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
983 })?;
984
985 let mut sql = String::with_capacity(256);
986 use std::fmt::Write;
987 let _ = write!(sql, "UPDATE {} SET ", T::table_name());
988
989 let mut i = 1;
990 let mut first = true;
991 for k in obj.keys() {
992 if !first {
993 sql.push_str(", ");
994 }
995 let p = DB::placeholder(i);
996 let _ = write!(sql, "{} = {}", DB::quote_identifier(k), p);
997 i += 1;
998 first = false;
999 }
1000
1001 let mut where_binds: SmallVec<[BindValue; 8]> =
1002 SmallVec::with_capacity(self.estimate_bind_count());
1003 self.render_where_clause_into(&mut sql, &mut where_binds, obj.len() + 1);
1004
1005 if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
1006 let filters = self.format_filters_for_log();
1007 tracing::debug!(
1008 operation = "bulk_update",
1009 sql = %sql,
1010 filters = %filters,
1011 "premix query"
1012 );
1013 }
1014 let mut query = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1015 for val in obj.values() {
1016 match val {
1017 serde_json::Value::String(s) => query = query.bind(s.clone()),
1018 serde_json::Value::Number(n) => {
1019 if let Some(v) = n.as_i64() {
1020 query = query.bind(v);
1021 } else if let Some(v) = n.as_f64() {
1022 query = query.bind(v);
1023 }
1024 }
1025 serde_json::Value::Bool(b) => query = query.bind(*b),
1026 serde_json::Value::Null => query = query.bind(Option::<String>::None),
1027 serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
1028 query = query.bind(sqlx::types::Json(val.clone()));
1029 }
1030 }
1031 }
1032 for bind in where_binds {
1033 query = bind_value_query(query, bind);
1034 }
1035
1036 let start = Instant::now();
1037 let result = match &mut self.executor {
1038 Executor::Pool(pool) => {
1039 let res = query.execute(*pool).await?;
1040 Ok(DB::rows_affected(&res))
1041 }
1042 Executor::Conn(conn) => {
1043 let res = query.execute(&mut **conn).await?;
1044 Ok(DB::rows_affected(&res))
1045 }
1046 };
1047 if !self.fast_path {
1048 record_query_metrics("bulk_update", T::table_name(), start.elapsed());
1049 }
1050 result
1051 }
1052
1053 pub async fn update_all(self, values: serde_json::Value) -> Result<u64, sqlx::Error>
1055 where
1056 String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1057 i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1058 f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1059 bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1060 Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1061 uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1062 chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1063 chrono::NaiveDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1064 chrono::NaiveDate: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1065 sqlx::types::Json<serde_json::Value>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1066 {
1067 self.update(values).await
1068 }
1069
1070 #[tracing::instrument(skip(self), fields(table = T::table_name()))]
1072 pub async fn delete(mut self) -> Result<u64, sqlx::Error> {
1073 self.ensure_safe_filters()?;
1074 if self.filters.is_empty() && !self.allow_unsafe && !self.unsafe_fast {
1075 return Err(sqlx::Error::Protocol(
1076 "Refusing bulk delete without filters".to_string(),
1077 ));
1078 }
1079
1080 let mut sql = String::with_capacity(128);
1081 use std::fmt::Write;
1082
1083 if T::has_soft_delete() {
1084 let _ = write!(
1085 sql,
1086 "UPDATE {} SET {} = {}",
1087 T::table_name(),
1088 DB::quote_identifier("deleted_at"),
1089 DB::current_timestamp_fn()
1090 );
1091 } else {
1092 let _ = write!(sql, "DELETE FROM {}", T::table_name());
1093 }
1094
1095 let mut where_binds: SmallVec<[BindValue; 8]> =
1096 SmallVec::with_capacity(self.estimate_bind_count());
1097 self.render_where_clause_into(&mut sql, &mut where_binds, 1);
1098
1099 if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
1100 let filters = self.format_filters_for_log();
1101 tracing::debug!(
1102 operation = "bulk_delete",
1103 sql = %sql,
1104 filters = %filters,
1105 "premix query"
1106 );
1107 }
1108 let start = Instant::now();
1109 let result = match &mut self.executor {
1110 Executor::Pool(pool) => {
1111 let base = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1112 let query = where_binds.into_iter().fold(base, bind_value_query);
1113 let res = query.execute(*pool).await?;
1114 Ok(DB::rows_affected(&res))
1115 }
1116 Executor::Conn(conn) => {
1117 let base = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1118 let query = where_binds.into_iter().fold(base, bind_value_query);
1119 let res = query.execute(&mut **conn).await?;
1120 Ok(DB::rows_affected(&res))
1121 }
1122 };
1123 if !self.fast_path {
1124 record_query_metrics("bulk_delete", T::table_name(), start.elapsed());
1125 }
1126 result
1127 }
1128
1129 pub async fn delete_all(self) -> Result<u64, sqlx::Error> {
1131 self.delete().await
1132 }
1133}
1134
// Tests for SQL rendering (identifier quoting) and the streaming API, run
// against SQLite.
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::Sqlite;

    // Minimal model mapped to the `users` table; every persistence method is a stub.
    struct DummyModel;

    impl Model<Sqlite> for DummyModel {
        fn table_name() -> &'static str {
            "users"
        }
        fn create_table_sql() -> String {
            String::new()
        }
        fn list_columns() -> Vec<String> {
            vec!["id".to_string()]
        }
        async fn save<'a, E>(&'a mut self, _e: E) -> Result<(), sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(())
        }
        async fn update<'a, E>(
            &'a mut self,
            _e: E,
        ) -> Result<crate::model::UpdateResult, sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(crate::model::UpdateResult::Success)
        }
        async fn delete<'a, E>(&'a mut self, _e: E) -> Result<(), sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(())
        }
        // Soft delete is off, so no `deleted_at` condition is appended to the rendered SQL.
        fn has_soft_delete() -> bool {
            false
        }
        async fn find_by_id<'a, E>(_e: E, _id: i32) -> Result<Option<Self>, sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(None)
        }
    }

    // Row decoding ignores the row contents; only the number of rows matters to the tests.
    impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for DummyModel {
        fn from_row(_row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
            Ok(DummyModel)
        }
    }

    // A hostile column name must end up inside a quoted identifier, not as free SQL.
    #[tokio::test]
    async fn test_sql_injection_mitigation() {
        // Lazy pool: no connection is opened because the query is only rendered.
        let pool = sqlx::Pool::<Sqlite>::connect_lazy("sqlite::memory:").unwrap();
        let qb = DummyModel::find_in_pool(&pool);

        let qb = qb.filter_eq("id; DROP TABLE users; --", 1);
        let sql = qb.to_sql();
        println!("SQL select: {}", sql);

        // The whole payload is expected inside backticks, followed by a bind placeholder.
        assert!(sql.contains("`id; DROP TABLE users; --` = ?"));
        assert!(sql.contains("SELECT * FROM users WHERE"));
    }

    // Keys of the update payload must be quoted the same way as filter columns.
    #[tokio::test]
    async fn test_to_update_sql_quoting() {
        let pool = sqlx::Pool::<Sqlite>::connect_lazy("sqlite::memory:").unwrap();
        let qb = DummyModel::find_in_pool(&pool).filter_eq("id", 1);

        let values = serde_json::json!({
            "name; DROP TABLE users; --": "admin"
        });

        let sql = qb.to_update_sql(&values).unwrap();
        println!("SQL update: {}", sql);
        assert!(sql.contains("`name; DROP TABLE users; --` = ?"));
    }

    // Streaming over a real in-memory table yields one item per inserted row.
    #[tokio::test]
    async fn test_stream_api() {
        use sqlx::Connection;
        let mut conn = sqlx::SqliteConnection::connect("sqlite::memory:")
            .await
            .unwrap();
        sqlx::query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
            .execute(&mut conn)
            .await
            .unwrap();
        sqlx::query("INSERT INTO users (id, name) VALUES (1, 'alice'), (2, 'bob')")
            .execute(&mut conn)
            .await
            .unwrap();

        // Build the query on the same connection that holds the in-memory table.
        let qb = DummyModel::find_in_tx(&mut conn);

        let mut stream = qb.stream().unwrap();
        let mut count = 0;
        while let Some(row) = stream.next().await {
            let _: DummyModel = row.unwrap();
            count += 1;
        }
        assert_eq!(count, 2);
    }
}