1use crate::dialect::SqlDialect;
2use crate::executor::Executor;
3use crate::model::Model;
4use futures_util::StreamExt;
5use smallvec::{SmallVec, smallvec};
6use sqlx::{Database, IntoArguments};
7use std::time::{Duration, Instant};
8
/// A dynamically-typed value bound as a positional SQL parameter.
///
/// Filters store their values as `BindValue`s. When the query executes, each
/// one is bound via `bind_value_query` / `bind_value_query_as`.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub enum BindValue {
    /// Text value.
    String(String),
    /// Integer value; `i32` inputs are widened into this variant.
    I64(i64),
    /// Floating-point value.
    F64(f64),
    /// Boolean value.
    Bool(bool),
    /// UUID value.
    Uuid(uuid::Uuid),
    /// UTC timestamp.
    DateTime(chrono::DateTime<chrono::Utc>),
    /// Timestamp without timezone.
    NaiveDateTime(chrono::NaiveDateTime),
    /// Calendar date without time.
    NaiveDate(chrono::NaiveDate),
    /// JSON document, bound wrapped in `sqlx::types::Json`.
    Json(serde_json::Value),
    /// SQL NULL, bound as `Option::<String>::None`.
    Null,
}
23
24impl BindValue {
25 fn to_log_string(&self) -> String {
26 match self {
27 BindValue::String(v) => v.clone(),
28 BindValue::I64(v) => v.to_string(),
29 BindValue::F64(v) => v.to_string(),
30 BindValue::Bool(v) => v.to_string(),
31 BindValue::Uuid(v) => v.to_string(),
32 BindValue::DateTime(v) => v.to_rfc3339(),
33 BindValue::NaiveDateTime(v) => v.to_string(),
34 BindValue::NaiveDate(v) => v.to_string(),
35 BindValue::Json(v) => v.to_string(),
36 BindValue::Null => "NULL".to_string(),
37 }
38 }
39}
40
/// Records latency and call-count metrics for one executed query.
///
/// Emits a `premix.query.duration_ms` histogram sample (in milliseconds) and
/// increments `premix.query.count`. Both are labelled by `operation` and `table`.
#[cfg(feature = "metrics")]
fn record_query_metrics(operation: &str, table: &str, elapsed: Duration) {
    let elapsed_ms = 1000.0 * elapsed.as_secs_f64();
    let labels = [
        ("operation", operation.to_owned()),
        ("table", table.to_owned()),
    ];
    metrics::histogram!("premix.query.duration_ms", &labels).record(elapsed_ms);
    metrics::counter!("premix.query.count", &labels).increment(1);
}

/// No-op stand-in used when the `metrics` feature is disabled, so call sites
/// need no `cfg` guards of their own.
#[cfg(not(feature = "metrics"))]
fn record_query_metrics(operation: &str, table: &str, elapsed: Duration) {
    // Metrics are compiled out; discard the arguments explicitly.
    let _ = (operation, table, elapsed);
}
54
55#[inline(always)]
56fn bind_value_query<'q, DB>(
57 query: sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>,
58 value: BindValue,
59) -> sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>
60where
61 DB: Database,
62 String: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
63 i64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
64 f64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
65 bool: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
66 uuid::Uuid: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
67 chrono::DateTime<chrono::Utc>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
68 chrono::NaiveDateTime: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
69 chrono::NaiveDate: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
70 sqlx::types::Json<serde_json::Value>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
71 Option<String>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
72{
73 match value {
74 BindValue::String(v) => query.bind(v),
75 BindValue::I64(v) => query.bind(v),
76 BindValue::F64(v) => query.bind(v),
77 BindValue::Bool(v) => query.bind(v),
78 BindValue::Uuid(v) => query.bind(v),
79 BindValue::DateTime(v) => query.bind(v),
80 BindValue::NaiveDateTime(v) => query.bind(v),
81 BindValue::NaiveDate(v) => query.bind(v),
82 BindValue::Json(v) => query.bind(sqlx::types::Json(v)),
83 BindValue::Null => query.bind(Option::<String>::None),
84 }
85}
86
87#[inline(always)]
88fn bind_value_query_as<'q, DB, T>(
89 query: sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>,
90 value: BindValue,
91) -> sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>
92where
93 DB: Database,
94 String: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
95 i64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
96 f64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
97 bool: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
98 uuid::Uuid: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
99 chrono::DateTime<chrono::Utc>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
100 chrono::NaiveDateTime: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
101 chrono::NaiveDate: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
102 sqlx::types::Json<serde_json::Value>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
103 Option<String>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
104{
105 match value {
106 BindValue::String(v) => query.bind(v),
107 BindValue::I64(v) => query.bind(v),
108 BindValue::F64(v) => query.bind(v),
109 BindValue::Bool(v) => query.bind(v),
110 BindValue::Uuid(v) => query.bind(v),
111 BindValue::DateTime(v) => query.bind(v),
112 BindValue::NaiveDateTime(v) => query.bind(v),
113 BindValue::NaiveDate(v) => query.bind(v),
114 BindValue::Json(v) => query.bind(sqlx::types::Json(v)),
115 BindValue::Null => query.bind(Option::<String>::None),
116 }
117}
118
119#[inline]
120fn apply_persistent_query_as<'q, DB, T>(
121 query: sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>,
122 prepared: bool,
123) -> sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>
124where
125 DB: Database + sqlx::database::HasStatementCache,
126{
127 if prepared {
128 query.persistent(true)
129 } else {
130 query
131 }
132}
133
134#[inline]
135fn apply_persistent_query<'q, DB>(
136 query: sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>,
137 prepared: bool,
138) -> sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>
139where
140 DB: Database + sqlx::database::HasStatementCache,
141{
142 if prepared {
143 query.persistent(true)
144 } else {
145 query
146 }
147}
148
149impl From<String> for BindValue {
150 fn from(value: String) -> Self {
151 Self::String(value)
152 }
153}
154
155impl From<&str> for BindValue {
156 fn from(value: &str) -> Self {
157 Self::String(value.to_string())
158 }
159}
160
161impl From<i32> for BindValue {
162 fn from(value: i32) -> Self {
163 Self::I64(value as i64)
164 }
165}
166
167impl From<i64> for BindValue {
168 fn from(value: i64) -> Self {
169 Self::I64(value)
170 }
171}
172
173impl From<f64> for BindValue {
174 fn from(value: f64) -> Self {
175 Self::F64(value)
176 }
177}
178
179impl From<bool> for BindValue {
180 fn from(value: bool) -> Self {
181 Self::Bool(value)
182 }
183}
184
185impl From<Option<String>> for BindValue {
186 fn from(value: Option<String>) -> Self {
187 match value {
188 Some(v) => Self::String(v),
189 None => Self::Null,
190 }
191 }
192}
193
194impl From<uuid::Uuid> for BindValue {
195 fn from(value: uuid::Uuid) -> Self {
196 Self::Uuid(value)
197 }
198}
199
200impl From<chrono::DateTime<chrono::Utc>> for BindValue {
201 fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
202 Self::DateTime(value)
203 }
204}
205
206impl From<chrono::NaiveDateTime> for BindValue {
207 fn from(value: chrono::NaiveDateTime) -> Self {
208 Self::NaiveDateTime(value)
209 }
210}
211
212impl From<chrono::NaiveDate> for BindValue {
213 fn from(value: chrono::NaiveDate) -> Self {
214 Self::NaiveDate(value)
215 }
216}
217
218impl From<serde_json::Value> for BindValue {
219 fn from(value: serde_json::Value) -> Self {
220 Self::Json(value)
221 }
222}
223
/// One condition in a query's WHERE clause; conditions are joined with `AND`.
#[derive(Debug, Clone)]
pub(crate) enum FilterExpr {
    /// Caller-supplied SQL inserted verbatim (not parameterized or escaped).
    Raw(String),
    /// `column <op> placeholder(s)`.
    Compare {
        /// Column name, quoted with the dialect's identifier quoting when rendered.
        column: ColumnRef,
        /// Comparison operator.
        op: FilterOp,
        /// Exactly one value for scalar operators; zero or more for `IN`.
        values: SmallVec<[BindValue; 2]>,
    },
    /// `column IS NULL` (when `is_null`) or `column IS NOT NULL`.
    NullCheck {
        column: ColumnRef,
        is_null: bool,
    },
}
237
/// Comparison operators supported by `FilterExpr::Compare`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FilterOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Lte,
    /// `>`
    Gt,
    /// `>=`
    Gte,
    /// `LIKE`
    Like,
    /// `IN (...)`; the only operator that takes a list of values.
    In,
}
249
250impl FilterOp {
251 fn as_str(self) -> &'static str {
252 match self {
253 FilterOp::Eq => "=",
254 FilterOp::Ne => "!=",
255 FilterOp::Lt => "<",
256 FilterOp::Lte => "<=",
257 FilterOp::Gt => ">",
258 FilterOp::Gte => ">=",
259 FilterOp::Like => "LIKE",
260 FilterOp::In => "IN",
261 }
262 }
263
264 fn is_in(self) -> bool {
265 matches!(self, FilterOp::In)
266 }
267}
268
/// A column name used in a filter, either borrowed for `'static` or owned.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColumnRef {
    /// Compile-time column name (see `ColumnRef::static_str`).
    Static(&'static str),
    /// Runtime column name; produced by all `From` conversions.
    Owned(String),
}
277
278impl ColumnRef {
279 pub const fn static_str(value: &'static str) -> Self {
281 ColumnRef::Static(value)
282 }
283
284 fn as_str(&self) -> &str {
285 match self {
286 ColumnRef::Static(value) => value,
287 ColumnRef::Owned(value) => value,
288 }
289 }
290}
291
292impl From<&str> for ColumnRef {
293 fn from(value: &str) -> Self {
294 ColumnRef::Owned(value.to_string())
295 }
296}
297
298impl From<String> for ColumnRef {
299 fn from(value: String) -> Self {
300 ColumnRef::Owned(value)
301 }
302}
303
304impl From<&String> for ColumnRef {
305 fn from(value: &String) -> Self {
306 ColumnRef::Owned(value.clone())
307 }
308}
309
/// Fluent builder for filtered `SELECT`, bulk `UPDATE`, and bulk `DELETE`
/// queries against the table of model `T`.
pub struct QueryBuilder<'a, T, DB: Database> {
    /// Pool or connection the query runs on.
    executor: Executor<'a, DB>,
    /// WHERE-clause conditions, joined with `AND` in insertion order.
    filters: Vec<FilterExpr>,
    /// Optional `LIMIT`, rendered literally into the SQL.
    limit: Option<i32>,
    /// Optional `OFFSET`, rendered literally into the SQL.
    offset: Option<i32>,
    /// Relation names eager-loaded by `all()` after the main query.
    includes: SmallVec<[String; 2]>,
    /// When true, the soft-delete `deleted_at IS NULL` condition is not added.
    include_deleted: bool,
    /// Permits raw filters and unfiltered bulk update/delete.
    allow_unsafe: bool,
    /// Set once any raw (unparameterized) filter has been added.
    has_raw_filter: bool,
    /// Skips debug logging and metrics recording.
    fast_path: bool,
    /// Skips the raw-filter and empty-filter safety checks.
    unsafe_fast: bool,
    /// Skips eager loading of `includes` in `all()`.
    ultra_fast: bool,
    /// Passed to `apply_persistent_query` / `apply_persistent_query_as` when
    /// the query executes.
    prepared: bool,
    _marker: std::marker::PhantomData<T>,
}
329
330impl<'a, T, DB: Database> std::fmt::Debug for QueryBuilder<'a, T, DB> {
331 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332 f.debug_struct("QueryBuilder")
333 .field("filters", &self.filters)
334 .field("limit", &self.limit)
335 .field("offset", &self.offset)
336 .field("includes", &self.includes)
337 .field("include_deleted", &self.include_deleted)
338 .field("allow_unsafe", &self.allow_unsafe)
339 .field("fast_path", &self.fast_path)
340 .field("unsafe_fast", &self.unsafe_fast)
341 .field("ultra_fast", &self.ultra_fast)
342 .field("prepared", &self.prepared)
343 .finish()
344 }
345}
346
347impl<'a, T, DB> QueryBuilder<'a, T, DB>
348where
349 DB: SqlDialect + sqlx::database::HasStatementCache,
350 T: Model<DB>,
351{
352 pub fn new(executor: Executor<'a, DB>) -> Self {
354 let default_includes = T::default_includes();
355 let mut includes = SmallVec::with_capacity(default_includes.len().max(2));
356 for relation in default_includes {
357 includes.push((*relation).to_string());
358 }
359 Self {
360 executor,
361 filters: Vec::with_capacity(4), limit: None,
363 offset: None,
364 includes, include_deleted: false,
366 allow_unsafe: false,
367 has_raw_filter: false,
368 fast_path: false,
369 unsafe_fast: false,
370 ultra_fast: false,
371 prepared: true,
372 _marker: std::marker::PhantomData,
373 }
374 }
375
376 pub fn filter(mut self, condition: impl Into<String>) -> Self {
381 self.filters.push(FilterExpr::Raw(condition.into()));
382 self.has_raw_filter = true;
383 self
384 }
385
386 pub fn filter_raw(self, condition: impl Into<String>) -> Self {
388 self.filter(condition)
389 }
390
391 pub fn filter_eq(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
393 self.filters.push(FilterExpr::Compare {
394 column: column.into(),
395 op: FilterOp::Eq,
396 values: smallvec![value.into()],
397 });
398 self
399 }
400
401 pub fn filter_ne(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
403 self.filters.push(FilterExpr::Compare {
404 column: column.into(),
405 op: FilterOp::Ne,
406 values: smallvec![value.into()],
407 });
408 self
409 }
410
411 pub fn filter_lt(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
413 self.filters.push(FilterExpr::Compare {
414 column: column.into(),
415 op: FilterOp::Lt,
416 values: smallvec![value.into()],
417 });
418 self
419 }
420
421 pub fn filter_lte(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
423 self.filters.push(FilterExpr::Compare {
424 column: column.into(),
425 op: FilterOp::Lte,
426 values: smallvec![value.into()],
427 });
428 self
429 }
430
431 pub fn filter_gt(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
433 self.filters.push(FilterExpr::Compare {
434 column: column.into(),
435 op: FilterOp::Gt,
436 values: smallvec![value.into()],
437 });
438 self
439 }
440
441 pub fn filter_gte(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
443 self.filters.push(FilterExpr::Compare {
444 column: column.into(),
445 op: FilterOp::Gte,
446 values: smallvec![value.into()],
447 });
448 self
449 }
450
451 pub fn filter_like(
453 mut self,
454 column: impl Into<ColumnRef>,
455 value: impl Into<BindValue>,
456 ) -> Self {
457 self.filters.push(FilterExpr::Compare {
458 column: column.into(),
459 op: FilterOp::Like,
460 values: smallvec![value.into()],
461 });
462 self
463 }
464
465 pub fn filter_is_null(mut self, column: impl Into<ColumnRef>) -> Self {
467 self.filters.push(FilterExpr::NullCheck {
468 column: column.into(),
469 is_null: true,
470 });
471 self
472 }
473
474 pub fn filter_is_not_null(mut self, column: impl Into<ColumnRef>) -> Self {
476 self.filters.push(FilterExpr::NullCheck {
477 column: column.into(),
478 is_null: false,
479 });
480 self
481 }
482
483 pub fn filter_in<I, V>(mut self, column: impl Into<ColumnRef>, values: I) -> Self
485 where
486 I: IntoIterator<Item = V>,
487 V: Into<BindValue>,
488 {
489 let values: SmallVec<[BindValue; 2]> = values.into_iter().map(Into::into).collect();
490 self.filters.push(FilterExpr::Compare {
491 column: column.into(),
492 op: FilterOp::In,
493 values,
494 });
495 self
496 }
497
498 fn format_filters_for_log(&self) -> String {
499 let sensitive_fields = T::sensitive_fields();
500 let mut rendered = String::with_capacity(128);
501 let mut first_clause = true;
502 let mut append_and = |buf: &mut String| {
503 if first_clause {
504 first_clause = false;
505 } else {
506 buf.push_str(" AND ");
507 }
508 };
509 use std::fmt::Write;
510
511 for filter in &self.filters {
512 match filter {
513 FilterExpr::Raw(_) => {
514 append_and(&mut rendered);
515 rendered.push_str("RAW(<redacted>)");
516 }
517 FilterExpr::Compare { column, op, values } => {
518 let column_name = column.as_str();
519 let is_sensitive = sensitive_fields.contains(&column_name);
520 if op.is_in() {
521 if values.is_empty() {
522 append_and(&mut rendered);
523 rendered.push_str("1=0");
524 continue;
525 }
526 append_and(&mut rendered);
527 let _ = write!(rendered, "{} IN (", column_name);
528 for (idx, value) in values.iter().enumerate() {
529 if idx > 0 {
530 rendered.push_str(", ");
531 }
532 if is_sensitive {
533 rendered.push_str("***");
534 } else {
535 rendered.push_str(&value.to_log_string());
536 }
537 }
538 rendered.push(')');
539 } else {
540 append_and(&mut rendered);
541 let _ = write!(rendered, "{} {} ", column_name, op.as_str());
542 if is_sensitive {
543 rendered.push_str("***");
544 } else if let Some(value) = values.first() {
545 rendered.push_str(&value.to_log_string());
546 } else {
547 rendered.push_str("NULL");
548 }
549 }
550 }
551 FilterExpr::NullCheck { column, is_null } => {
552 append_and(&mut rendered);
553 if *is_null {
554 let _ = write!(rendered, "{} IS NULL", column.as_str());
555 } else {
556 let _ = write!(rendered, "{} IS NOT NULL", column.as_str());
557 }
558 }
559 }
560 }
561
562 if T::has_soft_delete() && !self.include_deleted {
563 append_and(&mut rendered);
564 rendered.push_str("deleted_at IS NULL");
565 }
566
567 rendered
568 }
569
570 fn estimate_bind_count(&self) -> usize {
571 let mut count = 0usize;
572 for filter in &self.filters {
573 if let FilterExpr::Compare { op, values, .. } = filter {
574 if op.is_in() {
575 count = count.saturating_add(values.len());
576 } else {
577 count = count.saturating_add(1);
578 }
579 }
580 }
581 count
582 }
583
584 pub fn limit(mut self, limit: i32) -> Self {
587 self.limit = Some(limit);
588 self
589 }
590
591 pub fn offset(mut self, offset: i32) -> Self {
594 self.offset = Some(offset);
595 self
596 }
597
598 pub fn include(mut self, relation: impl Into<String>) -> Self {
600 self.includes.push(relation.into());
601 self
602 }
603
604 pub fn with_deleted(mut self) -> Self {
606 self.include_deleted = true;
607 self
608 }
609
610 pub fn allow_unsafe(mut self) -> Self {
613 self.allow_unsafe = true;
614 self
615 }
616
617 pub fn fast(mut self) -> Self {
619 self.fast_path = true;
620 self
621 }
622
623 pub fn unsafe_fast(mut self) -> Self {
625 self.fast_path = true;
626 self.unsafe_fast = true;
627 self
628 }
629
630 pub fn ultra_fast(mut self) -> Self {
633 self.fast_path = true;
634 self.unsafe_fast = true;
635 self.ultra_fast = true;
636 self
637 }
638
639 pub fn prepared(mut self) -> Self {
641 self.prepared = true;
642 self
643 }
644
645 pub fn unprepared(mut self) -> Self {
647 self.prepared = false;
648 self
649 }
650
651 pub fn to_sql(&self) -> String {
653 let mut sql = String::with_capacity(128); use std::fmt::Write;
655
656 sql.push_str("SELECT * FROM ");
657 sql.push_str(T::table_name());
658
659 let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
660 self.render_where_clause_into(&mut sql, &mut dummy_binds, 1);
661
662 if let Some(limit) = self.limit {
663 let _ = write!(sql, " LIMIT {}", limit);
664 }
665
666 if let Some(offset) = self.offset {
667 let _ = write!(sql, " OFFSET {}", offset);
668 }
669
670 sql
671 }
672
673 pub fn to_update_sql(&self, values: &serde_json::Value) -> Result<String, sqlx::Error> {
675 let obj = values.as_object().ok_or_else(|| {
676 sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
677 })?;
678
679 let mut sql = String::with_capacity(256);
680 use std::fmt::Write;
681
682 let _ = write!(sql, "UPDATE {} SET ", T::table_name());
683
684 let mut i = 1;
685 let mut first = true;
686
687 for k in obj.keys() {
688 if !first {
689 sql.push_str(", ");
690 }
691 let p = DB::placeholder(i);
692 let _ = write!(sql, "{} = {}", DB::quote_identifier(k), p);
693 i += 1;
694 first = false;
695 }
696
697 let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
698 self.render_where_clause_into(&mut sql, &mut dummy_binds, obj.len() + 1);
699 Ok(sql)
700 }
701
702 pub fn to_delete_sql(&self) -> String {
704 let mut sql = String::with_capacity(128);
705 use std::fmt::Write;
706
707 if T::has_soft_delete() {
708 let _ = write!(
709 sql,
710 "UPDATE {} SET {} = {}",
711 T::table_name(),
712 DB::quote_identifier("deleted_at"),
713 DB::current_timestamp_fn()
714 );
715 } else {
716 let _ = write!(sql, "DELETE FROM {}", T::table_name());
717 }
718
719 let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
720 self.render_where_clause_into(&mut sql, &mut dummy_binds, 1);
721 sql
722 }
723
724 #[inline(always)]
726 fn render_where_clause_into(
727 &self,
728 sql: &mut String,
729 binds: &mut SmallVec<[BindValue; 8]>,
730 start_index: usize,
731 ) {
732 let mut idx = start_index;
733 let mut first_clause = true;
734 use std::fmt::Write;
735
736 let mut append_and = |sql: &mut String| {
738 if first_clause {
739 sql.push_str(" WHERE ");
740 first_clause = false;
741 } else {
742 sql.push_str(" AND ");
743 }
744 };
745
746 for filter in &self.filters {
747 match filter {
748 FilterExpr::Raw(condition) => {
749 append_and(sql);
750 sql.push_str(condition);
751 }
752 FilterExpr::Compare { column, op, values } => {
753 if op.is_in() {
754 if values.is_empty() {
755 append_and(sql);
756 sql.push_str("1=0");
757 continue;
758 }
759 append_and(sql);
760 let _ = write!(sql, "{} IN (", DB::quote_identifier(column.as_str()));
761 let placeholders = crate::cached_placeholders_from::<DB>(idx, values.len());
762 sql.push_str(placeholders);
763 sql.push(')');
764 idx = idx.saturating_add(values.len());
765 for v in values {
766 binds.push(v.clone());
767 }
768 } else {
769 append_and(sql);
770 let _ = write!(
771 sql,
772 "{} {} {}",
773 DB::quote_identifier(column.as_str()),
774 op.as_str(),
775 DB::placeholder(idx)
776 );
777 idx += 1;
778 if let Some(v) = values.first() {
779 binds.push(v.clone());
780 }
781 }
782 }
783 FilterExpr::NullCheck { column, is_null } => {
784 append_and(sql);
785 if *is_null {
786 let _ = write!(sql, "{} IS NULL", DB::quote_identifier(column.as_str()));
787 } else {
788 let _ =
789 write!(sql, "{} IS NOT NULL", DB::quote_identifier(column.as_str()));
790 }
791 }
792 }
793 }
794
795 if T::has_soft_delete() && !self.include_deleted {
796 append_and(sql);
797 sql.push_str("deleted_at IS NULL");
798 }
799 }
800}
801
802impl<'a, T, DB> QueryBuilder<'a, T, DB>
803where
804 DB: SqlDialect,
805 T: Model<DB>,
806 for<'q> <DB as Database>::Arguments<'q>: IntoArguments<'q, DB>,
807 for<'c> &'c mut <DB as Database>::Connection: sqlx::Executor<'c, Database = DB>,
808 for<'c> &'c str: sqlx::ColumnIndex<DB::Row>,
809 DB::Connection: Send,
810 T: Send,
811 String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
812 i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
813 f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
814 bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
815 Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
816 uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
817 chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
818 chrono::NaiveDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
819 chrono::NaiveDate: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
820 sqlx::types::Json<serde_json::Value>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
821{
822 fn ensure_safe_filters(&self) -> Result<(), sqlx::Error> {
823 if self.unsafe_fast {
824 return Ok(());
825 }
826 if self.has_raw_filter && !self.allow_unsafe {
827 return Err(sqlx::Error::Protocol(
828 "Refusing raw filter without allow_unsafe".to_string(),
829 ));
830 }
831 Ok(())
832 }
833
834 #[tracing::instrument(skip(self), fields(table = T::table_name()))]
839 pub async fn all(mut self) -> Result<Vec<T>, sqlx::Error> {
840 self.ensure_safe_filters()?;
841
842 let mut sql = String::with_capacity(128);
843 sql.push_str("SELECT * FROM ");
844 sql.push_str(T::table_name());
845
846 let mut where_binds: SmallVec<[BindValue; 8]> =
847 SmallVec::with_capacity(self.estimate_bind_count());
848 self.render_where_clause_into(&mut sql, &mut where_binds, 1);
849
850 if let Some(limit) = self.limit {
851 use std::fmt::Write;
852 let _ = write!(sql, " LIMIT {}", limit);
853 }
854
855 if let Some(offset) = self.offset {
856 use std::fmt::Write;
857 let _ = write!(sql, " OFFSET {}", offset);
858 }
859
860 #[cfg(debug_assertions)]
861 if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
862 let filters = self.format_filters_for_log();
863 tracing::debug!(
864 operation = "select",
865 sql = %sql,
866 filters = %filters,
867 "premix query"
868 );
869 }
870
871 let start = Instant::now();
872 let mut results: Vec<T> = match &mut self.executor {
873 Executor::Pool(pool) => {
874 let base = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
875 let query = where_binds.into_iter().fold(base, bind_value_query_as);
876 query.fetch_all(*pool).await?
877 }
878 Executor::Conn(conn) => {
879 let base = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
880 let query = where_binds.into_iter().fold(base, bind_value_query_as);
881 query.fetch_all(&mut **conn).await?
882 }
883 };
884 if !self.fast_path {
885 record_query_metrics("select", T::table_name(), start.elapsed());
886 }
887
888 if self.ultra_fast {
889 return Ok(results);
890 }
891
892 for relation in self.includes {
893 match &mut self.executor {
894 Executor::Pool(pool) => {
895 T::eager_load(&mut results, &relation, Executor::Pool(*pool)).await?;
896 }
897 Executor::Conn(conn) => {
898 T::eager_load(&mut results, &relation, Executor::Conn(&mut **conn)).await?;
899 }
900 }
901 }
902
903 Ok(results)
904 }
905
906 #[tracing::instrument(skip(self), fields(table = T::table_name()))]
910 pub fn stream(
911 self,
912 ) -> Result<futures_util::stream::BoxStream<'a, Result<T, sqlx::Error>>, sqlx::Error>
913 where
914 T: 'a,
915 {
916 self.ensure_safe_filters()?;
917
918 let mut sql = String::with_capacity(128);
919 sql.push_str("SELECT * FROM ");
920 sql.push_str(T::table_name());
921
922 let mut where_binds: SmallVec<[BindValue; 8]> =
923 SmallVec::with_capacity(self.estimate_bind_count());
924 self.render_where_clause_into(&mut sql, &mut where_binds, 1);
925
926 if let Some(limit) = self.limit {
927 use std::fmt::Write;
928 let _ = write!(sql, " LIMIT {}", limit);
929 }
930
931 if let Some(offset) = self.offset {
932 use std::fmt::Write;
933 let _ = write!(sql, " OFFSET {}", offset);
934 }
935
936 #[cfg(debug_assertions)]
937 if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
938 let filters = self.format_filters_for_log();
939 tracing::debug!(
940 operation = "stream",
941 sql = %sql,
942 filters = %filters,
943 "premix query"
944 );
945 }
946
947 let executor = self.executor;
948 Ok(Box::pin(async_stream::try_stream! {
949 let mut query = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
950 for bind in where_binds {
951 query = bind_value_query_as(query, bind);
952 }
953 let mut s = executor.fetch_stream(query);
954 while let Some(row) = s.next().await {
955 yield row?;
956 }
957 }))
958 }
959
960 #[inline(never)]
966 #[tracing::instrument(skip(self, values), fields(table = T::table_name()))]
967 pub async fn update(mut self, values: serde_json::Value) -> Result<u64, sqlx::Error>
968 where
969 String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
970 i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
971 f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
972 bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
973 Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
974 uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
975 chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
976 chrono::NaiveDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
977 chrono::NaiveDate: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
978 sqlx::types::Json<serde_json::Value>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
979 {
980 self.ensure_safe_filters()?;
981 if self.filters.is_empty() && !self.allow_unsafe && !self.unsafe_fast {
982 return Err(sqlx::Error::Protocol(
983 "Refusing bulk update without filters".to_string(),
984 ));
985 }
986 let obj = values.as_object().ok_or_else(|| {
987 sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
988 })?;
989
990 let mut sql = String::with_capacity(256);
991 use std::fmt::Write;
992 let _ = write!(sql, "UPDATE {} SET ", T::table_name());
993
994 let mut i = 1;
995 let mut first = true;
996 for k in obj.keys() {
997 if !first {
998 sql.push_str(", ");
999 }
1000 let p = DB::placeholder(i);
1001 let _ = write!(sql, "{} = {}", DB::quote_identifier(k), p);
1002 i += 1;
1003 first = false;
1004 }
1005
1006 let mut where_binds: SmallVec<[BindValue; 8]> =
1007 SmallVec::with_capacity(self.estimate_bind_count());
1008 self.render_where_clause_into(&mut sql, &mut where_binds, obj.len() + 1);
1009
1010 if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
1011 let filters = self.format_filters_for_log();
1012 tracing::debug!(
1013 operation = "bulk_update",
1014 sql = %sql,
1015 filters = %filters,
1016 "premix query"
1017 );
1018 }
1019 let mut query = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1020 for val in obj.values() {
1021 match val {
1022 serde_json::Value::String(s) => query = query.bind(s.clone()),
1023 serde_json::Value::Number(n) => {
1024 if let Some(v) = n.as_i64() {
1025 query = query.bind(v);
1026 } else if let Some(v) = n.as_f64() {
1027 query = query.bind(v);
1028 }
1029 }
1030 serde_json::Value::Bool(b) => query = query.bind(*b),
1031 serde_json::Value::Null => query = query.bind(Option::<String>::None),
1032 serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
1033 query = query.bind(sqlx::types::Json(val.clone()));
1034 }
1035 }
1036 }
1037 for bind in where_binds {
1038 query = bind_value_query(query, bind);
1039 }
1040
1041 let start = Instant::now();
1042 let result = match &mut self.executor {
1043 Executor::Pool(pool) => {
1044 let res = query.execute(*pool).await?;
1045 Ok(DB::rows_affected(&res))
1046 }
1047 Executor::Conn(conn) => {
1048 let res = query.execute(&mut **conn).await?;
1049 Ok(DB::rows_affected(&res))
1050 }
1051 };
1052 if !self.fast_path {
1053 record_query_metrics("bulk_update", T::table_name(), start.elapsed());
1054 }
1055 result
1056 }
1057
1058 pub async fn update_all(self, values: serde_json::Value) -> Result<u64, sqlx::Error>
1060 where
1061 String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1062 i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1063 f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1064 bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1065 Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1066 uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1067 chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1068 chrono::NaiveDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1069 chrono::NaiveDate: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1070 sqlx::types::Json<serde_json::Value>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1071 {
1072 self.update(values).await
1073 }
1074
1075 #[tracing::instrument(skip(self), fields(table = T::table_name()))]
1077 pub async fn delete(mut self) -> Result<u64, sqlx::Error> {
1078 self.ensure_safe_filters()?;
1079 if self.filters.is_empty() && !self.allow_unsafe && !self.unsafe_fast {
1080 return Err(sqlx::Error::Protocol(
1081 "Refusing bulk delete without filters".to_string(),
1082 ));
1083 }
1084
1085 let mut sql = String::with_capacity(128);
1086 use std::fmt::Write;
1087
1088 if T::has_soft_delete() {
1089 let _ = write!(
1090 sql,
1091 "UPDATE {} SET {} = {}",
1092 T::table_name(),
1093 DB::quote_identifier("deleted_at"),
1094 DB::current_timestamp_fn()
1095 );
1096 } else {
1097 let _ = write!(sql, "DELETE FROM {}", T::table_name());
1098 }
1099
1100 let mut where_binds: SmallVec<[BindValue; 8]> =
1101 SmallVec::with_capacity(self.estimate_bind_count());
1102 self.render_where_clause_into(&mut sql, &mut where_binds, 1);
1103
1104 if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
1105 let filters = self.format_filters_for_log();
1106 tracing::debug!(
1107 operation = "bulk_delete",
1108 sql = %sql,
1109 filters = %filters,
1110 "premix query"
1111 );
1112 }
1113 let start = Instant::now();
1114 let result = match &mut self.executor {
1115 Executor::Pool(pool) => {
1116 let base = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1117 let query = where_binds.into_iter().fold(base, bind_value_query);
1118 let res = query.execute(*pool).await?;
1119 Ok(DB::rows_affected(&res))
1120 }
1121 Executor::Conn(conn) => {
1122 let base = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1123 let query = where_binds.into_iter().fold(base, bind_value_query);
1124 let res = query.execute(&mut **conn).await?;
1125 Ok(DB::rows_affected(&res))
1126 }
1127 };
1128 if !self.fast_path {
1129 record_query_metrics("bulk_delete", T::table_name(), start.elapsed());
1130 }
1131 result
1132 }
1133
1134 pub async fn delete_all(self) -> Result<u64, sqlx::Error> {
1136 self.delete().await
1137 }
1138}
1139
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::Sqlite;

    /// Minimal model over the `users` table. Every persistence method is a
    /// no-op and `from_row` ignores the row, so the tests exercise only the
    /// query builder.
    struct DummyModel;

    impl Model<Sqlite> for DummyModel {
        fn table_name() -> &'static str {
            "users"
        }
        fn create_table_sql() -> String {
            String::new()
        }
        fn list_columns() -> Vec<String> {
            vec!["id".to_string()]
        }
        async fn save<'a, E>(&'a mut self, _e: E) -> Result<(), sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(())
        }
        async fn update<'a, E>(
            &'a mut self,
            _e: E,
        ) -> Result<crate::model::UpdateResult, sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(crate::model::UpdateResult::Success)
        }
        async fn delete<'a, E>(&'a mut self, _e: E) -> Result<(), sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(())
        }
        // No soft delete, so no `deleted_at IS NULL` condition is appended.
        fn has_soft_delete() -> bool {
            false
        }
        async fn find_by_id<'a, E>(_e: E, _id: i32) -> Result<Option<Self>, sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(None)
        }
    }

    // Row decoding ignores the row's contents; tests only count rows.
    impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for DummyModel {
        fn from_row(_row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
            Ok(DummyModel)
        }
    }

    /// A hostile column name must be quoted as one identifier, not spliced in as SQL.
    #[tokio::test]
    async fn test_sql_injection_mitigation() {
        // Only SQL rendering is exercised; nothing is executed against the pool.
        let pool = sqlx::Pool::<Sqlite>::connect_lazy("sqlite::memory:").unwrap();
        let qb = DummyModel::find_in_pool(&pool);

        let qb = qb.filter_eq("id; DROP TABLE users; --", 1);
        let sql = qb.to_sql();
        println!("SQL select: {}", sql);

        // Per this assertion, the SQLite dialect quotes identifiers with backticks.
        assert!(sql.contains("`id; DROP TABLE users; --` = ?"));
        assert!(sql.contains("SELECT * FROM users WHERE"));
    }

    /// JSON keys used as SET columns must be quoted the same way as filter columns.
    #[tokio::test]
    async fn test_to_update_sql_quoting() {
        let pool = sqlx::Pool::<Sqlite>::connect_lazy("sqlite::memory:").unwrap();
        let qb = DummyModel::find_in_pool(&pool).filter_eq("id", 1);

        let values = serde_json::json!({
            "name; DROP TABLE users; --": "admin"
        });

        let sql = qb.to_update_sql(&values).unwrap();
        println!("SQL update: {}", sql);
        assert!(sql.contains("`name; DROP TABLE users; --` = ?"));
    }

    /// Streaming an unfiltered query over a two-row table yields both rows.
    #[tokio::test]
    async fn test_stream_api() {
        use sqlx::Connection;
        let mut conn = sqlx::SqliteConnection::connect("sqlite::memory:")
            .await
            .unwrap();
        sqlx::query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
            .execute(&mut conn)
            .await
            .unwrap();
        sqlx::query("INSERT INTO users (id, name) VALUES (1, 'alice'), (2, 'bob')")
            .execute(&mut conn)
            .await
            .unwrap();

        let qb = DummyModel::find_in_tx(&mut conn);

        let mut stream = qb.stream().unwrap();
        let mut count = 0;
        while let Some(row) = stream.next().await {
            let _: DummyModel = row.unwrap();
            count += 1;
        }
        assert_eq!(count, 2);
    }
}
1251}