Skip to main content

premix_core/
query.rs

1use crate::dialect::SqlDialect;
2use crate::executor::Executor;
3use crate::model::Model;
4use futures_util::StreamExt;
5use smallvec::{SmallVec, smallvec};
6use sqlx::{Database, IntoArguments};
7use std::time::{Duration, Instant};
8
/// A dynamically-typed value bound to a query placeholder.
///
/// Each variant maps to one concrete Rust type that is handed to sqlx's `bind`;
/// `Json` is wrapped in `sqlx::types::Json` and `Null` is bound as `Option::<String>::None`.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub enum BindValue {
    /// Text value.
    String(String),
    /// 64-bit integer (`i32` inputs are widened into this variant).
    I64(i64),
    /// 64-bit float.
    F64(f64),
    /// Boolean value.
    Bool(bool),
    /// UUID value.
    Uuid(uuid::Uuid),
    /// Timezone-aware UTC timestamp.
    DateTime(chrono::DateTime<chrono::Utc>),
    /// Timestamp without timezone.
    NaiveDateTime(chrono::NaiveDateTime),
    /// Calendar date without time.
    NaiveDate(chrono::NaiveDate),
    /// Arbitrary JSON document.
    Json(serde_json::Value),
    /// SQL `NULL`.
    Null,
}
23
24impl BindValue {
25    fn to_log_string(&self) -> String {
26        match self {
27            BindValue::String(v) => v.clone(),
28            BindValue::I64(v) => v.to_string(),
29            BindValue::F64(v) => v.to_string(),
30            BindValue::Bool(v) => v.to_string(),
31            BindValue::Uuid(v) => v.to_string(),
32            BindValue::DateTime(v) => v.to_rfc3339(),
33            BindValue::NaiveDateTime(v) => v.to_string(),
34            BindValue::NaiveDate(v) => v.to_string(),
35            BindValue::Json(v) => v.to_string(),
36            BindValue::Null => "NULL".to_string(),
37        }
38    }
39}
40
/// Emits a latency histogram sample (milliseconds) and increments a call counter for
/// one executed query, both labelled with the operation name and table.
///
/// Only compiled when the `metrics` feature is enabled.
#[cfg(feature = "metrics")]
fn record_query_metrics(operation: &str, table: &str, elapsed: Duration) {
    let labels = [
        ("operation", String::from(operation)),
        ("table", String::from(table)),
    ];
    let duration_ms = elapsed.as_secs_f64() * 1000.0;
    metrics::histogram!("premix.query.duration_ms", &labels).record(duration_ms);
    metrics::counter!("premix.query.count", &labels).increment(1);
}
51
/// No-op stand-in compiled when the `metrics` feature is disabled, so call sites need no
/// `cfg` guards of their own.
#[cfg(not(feature = "metrics"))]
fn record_query_metrics(_operation: &str, _table: &str, _elapsed: Duration) {
    // Intentionally empty: metrics collection is compiled out.
}
54
55#[inline(always)]
56fn bind_value_query<'q, DB>(
57    query: sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>,
58    value: BindValue,
59) -> sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>
60where
61    DB: Database,
62    String: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
63    i64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
64    f64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
65    bool: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
66    uuid::Uuid: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
67    chrono::DateTime<chrono::Utc>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
68    chrono::NaiveDateTime: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
69    chrono::NaiveDate: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
70    sqlx::types::Json<serde_json::Value>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
71    Option<String>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
72{
73    match value {
74        BindValue::String(v) => query.bind(v),
75        BindValue::I64(v) => query.bind(v),
76        BindValue::F64(v) => query.bind(v),
77        BindValue::Bool(v) => query.bind(v),
78        BindValue::Uuid(v) => query.bind(v),
79        BindValue::DateTime(v) => query.bind(v),
80        BindValue::NaiveDateTime(v) => query.bind(v),
81        BindValue::NaiveDate(v) => query.bind(v),
82        BindValue::Json(v) => query.bind(sqlx::types::Json(v)),
83        BindValue::Null => query.bind(Option::<String>::None),
84    }
85}
86
87#[inline(always)]
88fn bind_value_query_as<'q, DB, T>(
89    query: sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>,
90    value: BindValue,
91) -> sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>
92where
93    DB: Database,
94    String: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
95    i64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
96    f64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
97    bool: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
98    uuid::Uuid: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
99    chrono::DateTime<chrono::Utc>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
100    chrono::NaiveDateTime: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
101    chrono::NaiveDate: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
102    sqlx::types::Json<serde_json::Value>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
103    Option<String>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
104{
105    match value {
106        BindValue::String(v) => query.bind(v),
107        BindValue::I64(v) => query.bind(v),
108        BindValue::F64(v) => query.bind(v),
109        BindValue::Bool(v) => query.bind(v),
110        BindValue::Uuid(v) => query.bind(v),
111        BindValue::DateTime(v) => query.bind(v),
112        BindValue::NaiveDateTime(v) => query.bind(v),
113        BindValue::NaiveDate(v) => query.bind(v),
114        BindValue::Json(v) => query.bind(sqlx::types::Json(v)),
115        BindValue::Null => query.bind(Option::<String>::None),
116    }
117}
118
119#[inline]
120fn apply_persistent_query_as<'q, DB, T>(
121    query: sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>,
122    prepared: bool,
123) -> sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>
124where
125    DB: Database + sqlx::database::HasStatementCache,
126{
127    if prepared {
128        query.persistent(true)
129    } else {
130        query
131    }
132}
133
134#[inline]
135fn apply_persistent_query<'q, DB>(
136    query: sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>,
137    prepared: bool,
138) -> sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>
139where
140    DB: Database + sqlx::database::HasStatementCache,
141{
142    if prepared {
143        query.persistent(true)
144    } else {
145        query
146    }
147}
148
149impl From<String> for BindValue {
150    fn from(value: String) -> Self {
151        Self::String(value)
152    }
153}
154
155impl From<&str> for BindValue {
156    fn from(value: &str) -> Self {
157        Self::String(value.to_string())
158    }
159}
160
161impl From<i32> for BindValue {
162    fn from(value: i32) -> Self {
163        Self::I64(value as i64)
164    }
165}
166
167impl From<i64> for BindValue {
168    fn from(value: i64) -> Self {
169        Self::I64(value)
170    }
171}
172
173impl From<f64> for BindValue {
174    fn from(value: f64) -> Self {
175        Self::F64(value)
176    }
177}
178
179impl From<bool> for BindValue {
180    fn from(value: bool) -> Self {
181        Self::Bool(value)
182    }
183}
184
185impl From<Option<String>> for BindValue {
186    fn from(value: Option<String>) -> Self {
187        match value {
188            Some(v) => Self::String(v),
189            None => Self::Null,
190        }
191    }
192}
193
194impl From<uuid::Uuid> for BindValue {
195    fn from(value: uuid::Uuid) -> Self {
196        Self::Uuid(value)
197    }
198}
199
200impl From<chrono::DateTime<chrono::Utc>> for BindValue {
201    fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
202        Self::DateTime(value)
203    }
204}
205
206impl From<chrono::NaiveDateTime> for BindValue {
207    fn from(value: chrono::NaiveDateTime) -> Self {
208        Self::NaiveDateTime(value)
209    }
210}
211
212impl From<chrono::NaiveDate> for BindValue {
213    fn from(value: chrono::NaiveDate) -> Self {
214        Self::NaiveDate(value)
215    }
216}
217
218impl From<serde_json::Value> for BindValue {
219    fn from(value: serde_json::Value) -> Self {
220        Self::Json(value)
221    }
222}
223
/// One `WHERE`-clause fragment collected by `QueryBuilder`, rendered in insertion order.
#[derive(Debug, Clone)]
pub(crate) enum FilterExpr {
    /// Caller-supplied SQL inserted into the `WHERE` clause without escaping.
    Raw(String),
    /// `column <op> value(s)`: scalar operators bind only the first value,
    /// `FilterOp::In` binds every value (an empty list renders as `1=0`).
    Compare {
        column: ColumnRef,
        op: FilterOp,
        values: SmallVec<[BindValue; 2]>,
    },
    /// `column IS NULL` when `is_null` is true, otherwise `column IS NOT NULL`.
    NullCheck {
        column: ColumnRef,
        is_null: bool,
    },
}
237
/// Comparison operator used by a `FilterExpr::Compare` clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FilterOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Lte,
    /// `>`
    Gt,
    /// `>=`
    Gte,
    /// `LIKE`
    Like,
    /// `IN (...)`, the only operator that binds a list of values.
    In,
}

impl FilterOp {
    /// Returns the SQL token emitted between the column and its placeholder(s).
    fn as_str(self) -> &'static str {
        match self {
            Self::Eq => "=",
            Self::Ne => "!=",
            Self::Lt => "<",
            Self::Lte => "<=",
            Self::Gt => ">",
            Self::Gte => ">=",
            Self::Like => "LIKE",
            Self::In => "IN",
        }
    }

    /// True only for the list-valued `IN` operator.
    fn is_in(self) -> bool {
        self == Self::In
    }
}
268
/// Column reference used in filters (static literals or owned names).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColumnRef {
    /// Static column name known at compile time.
    Static(&'static str),
    /// Owned column name created at runtime.
    Owned(String),
}

impl ColumnRef {
    /// Wraps a `'static` column name without allocating.
    pub const fn static_str(value: &'static str) -> Self {
        Self::Static(value)
    }

    /// Borrows the column name, whichever way it is stored.
    fn as_str(&self) -> &str {
        match self {
            Self::Static(name) => name,
            Self::Owned(name) => name.as_str(),
        }
    }
}

/// Borrowed names are copied into an owned reference.
impl From<&str> for ColumnRef {
    fn from(value: &str) -> Self {
        Self::Owned(String::from(value))
    }
}

/// Owned names are moved in as-is.
impl From<String> for ColumnRef {
    fn from(value: String) -> Self {
        Self::Owned(value)
    }
}

/// Borrowed `String`s are cloned into an owned reference.
impl From<&String> for ColumnRef {
    fn from(value: &String) -> Self {
        Self::Owned(value.to_owned())
    }
}
309
/// A type-safe SQL query builder.
///
/// `QueryBuilder` provides a fluent interface for building SELECT, UPDATE, and DELETE queries
/// with support for filtering, pagination, eager loading, and soft deletes.
pub struct QueryBuilder<'a, T, DB: Database> {
    /// Pool or connection the query runs against.
    executor: Executor<'a, DB>,
    /// `WHERE` fragments, rendered in insertion order and joined with `AND`.
    filters: Vec<FilterExpr>,
    /// Optional `LIMIT` appended to SELECT queries.
    limit: Option<i32>,
    /// Optional `OFFSET` appended to SELECT queries.
    offset: Option<i32>,
    /// Relation names eager-loaded after `all()` fetches rows.
    includes: SmallVec<[String; 2]>,
    /// When true, the soft-delete `deleted_at IS NULL` clause is not added.
    include_deleted: bool,
    /// Caller opted in to executing raw SQL filters.
    allow_unsafe: bool,
    /// Set once any raw filter has been added.
    has_raw_filter: bool,
    /// Skip debug logging and metrics.
    fast_path: bool,
    /// Additionally skip the raw-filter safety check.
    unsafe_fast: bool,
    /// Additionally skip eager loading in `all()`.
    ultra_fast: bool,
    /// Prepared-statement caching preference forwarded to sqlx.
    prepared: bool,
    _marker: std::marker::PhantomData<T>,
}
329
330impl<'a, T, DB: Database> std::fmt::Debug for QueryBuilder<'a, T, DB> {
331    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332        f.debug_struct("QueryBuilder")
333            .field("filters", &self.filters)
334            .field("limit", &self.limit)
335            .field("offset", &self.offset)
336            .field("includes", &self.includes)
337            .field("include_deleted", &self.include_deleted)
338            .field("allow_unsafe", &self.allow_unsafe)
339            .field("fast_path", &self.fast_path)
340            .field("unsafe_fast", &self.unsafe_fast)
341            .field("ultra_fast", &self.ultra_fast)
342            .field("prepared", &self.prepared)
343            .finish()
344    }
345}
346
347impl<'a, T, DB> QueryBuilder<'a, T, DB>
348where
349    DB: SqlDialect + sqlx::database::HasStatementCache,
350    T: Model<DB>,
351{
352    /// Creates a new `QueryBuilder` using the provided [`Executor`].
353    pub fn new(executor: Executor<'a, DB>) -> Self {
354        Self {
355            executor,
356            filters: Vec::with_capacity(4), // Pre-allocate for typical queries (1-4 filters)
357            limit: None,
358            offset: None,
359            includes: SmallVec::with_capacity(2), // Pre-allocate for typical queries (1-2 includes)
360            include_deleted: false,
361            allow_unsafe: false,
362            has_raw_filter: false,
363            fast_path: false,
364            unsafe_fast: false,
365            ultra_fast: false,
366            prepared: true,
367            _marker: std::marker::PhantomData,
368        }
369    }
370
371    /// Adds a raw SQL filter condition to the query.
372    ///
373    /// # Safety
374    /// This method is potentially unsafe and requires calling [`allow_unsafe`] for the query to execute.
375    pub fn filter(mut self, condition: impl Into<String>) -> Self {
376        self.filters.push(FilterExpr::Raw(condition.into()));
377        self.has_raw_filter = true;
378        self
379    }
380
381    /// Adds a raw SQL filter condition to the query.
382    pub fn filter_raw(self, condition: impl Into<String>) -> Self {
383        self.filter(condition)
384    }
385
386    /// Adds an equality filter (`column = value`).
387    pub fn filter_eq(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
388        self.filters.push(FilterExpr::Compare {
389            column: column.into(),
390            op: FilterOp::Eq,
391            values: smallvec![value.into()],
392        });
393        self
394    }
395
396    /// Adds a not-equal filter (`column != value`).
397    pub fn filter_ne(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
398        self.filters.push(FilterExpr::Compare {
399            column: column.into(),
400            op: FilterOp::Ne,
401            values: smallvec![value.into()],
402        });
403        self
404    }
405
406    /// Adds a less-than filter (`column < value`).
407    pub fn filter_lt(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
408        self.filters.push(FilterExpr::Compare {
409            column: column.into(),
410            op: FilterOp::Lt,
411            values: smallvec![value.into()],
412        });
413        self
414    }
415
416    /// Adds a less-than-or-equal filter (`column <= value`).
417    pub fn filter_lte(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
418        self.filters.push(FilterExpr::Compare {
419            column: column.into(),
420            op: FilterOp::Lte,
421            values: smallvec![value.into()],
422        });
423        self
424    }
425
426    /// Adds a greater-than filter (`column > value`).
427    pub fn filter_gt(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
428        self.filters.push(FilterExpr::Compare {
429            column: column.into(),
430            op: FilterOp::Gt,
431            values: smallvec![value.into()],
432        });
433        self
434    }
435
436    /// Adds a greater-than-or-equal filter (`column >= value`).
437    pub fn filter_gte(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
438        self.filters.push(FilterExpr::Compare {
439            column: column.into(),
440            op: FilterOp::Gte,
441            values: smallvec![value.into()],
442        });
443        self
444    }
445
446    /// Adds a LIKE filter (`column LIKE value`).
447    pub fn filter_like(
448        mut self,
449        column: impl Into<ColumnRef>,
450        value: impl Into<BindValue>,
451    ) -> Self {
452        self.filters.push(FilterExpr::Compare {
453            column: column.into(),
454            op: FilterOp::Like,
455            values: smallvec![value.into()],
456        });
457        self
458    }
459
460    /// Filters rows where the column IS NULL.
461    pub fn filter_is_null(mut self, column: impl Into<ColumnRef>) -> Self {
462        self.filters.push(FilterExpr::NullCheck {
463            column: column.into(),
464            is_null: true,
465        });
466        self
467    }
468
469    /// Filters rows where the column IS NOT NULL.
470    pub fn filter_is_not_null(mut self, column: impl Into<ColumnRef>) -> Self {
471        self.filters.push(FilterExpr::NullCheck {
472            column: column.into(),
473            is_null: false,
474        });
475        self
476    }
477
478    /// Adds an IN filter (`column IN (values...)`).
479    pub fn filter_in<I, V>(mut self, column: impl Into<ColumnRef>, values: I) -> Self
480    where
481        I: IntoIterator<Item = V>,
482        V: Into<BindValue>,
483    {
484        let values: SmallVec<[BindValue; 2]> = values.into_iter().map(Into::into).collect();
485        self.filters.push(FilterExpr::Compare {
486            column: column.into(),
487            op: FilterOp::In,
488            values,
489        });
490        self
491    }
492
493    fn format_filters_for_log(&self) -> String {
494        let sensitive_fields = T::sensitive_fields();
495        let mut rendered = String::with_capacity(128);
496        let mut first_clause = true;
497        let mut append_and = |buf: &mut String| {
498            if first_clause {
499                first_clause = false;
500            } else {
501                buf.push_str(" AND ");
502            }
503        };
504        use std::fmt::Write;
505
506        for filter in &self.filters {
507            match filter {
508                FilterExpr::Raw(_) => {
509                    append_and(&mut rendered);
510                    rendered.push_str("RAW(<redacted>)");
511                }
512                FilterExpr::Compare { column, op, values } => {
513                    let column_name = column.as_str();
514                    let is_sensitive = sensitive_fields.contains(&column_name);
515                    if op.is_in() {
516                        if values.is_empty() {
517                            append_and(&mut rendered);
518                            rendered.push_str("1=0");
519                            continue;
520                        }
521                        append_and(&mut rendered);
522                        let _ = write!(rendered, "{} IN (", column_name);
523                        for (idx, value) in values.iter().enumerate() {
524                            if idx > 0 {
525                                rendered.push_str(", ");
526                            }
527                            if is_sensitive {
528                                rendered.push_str("***");
529                            } else {
530                                rendered.push_str(&value.to_log_string());
531                            }
532                        }
533                        rendered.push(')');
534                    } else {
535                        append_and(&mut rendered);
536                        let _ = write!(rendered, "{} {} ", column_name, op.as_str());
537                        if is_sensitive {
538                            rendered.push_str("***");
539                        } else if let Some(value) = values.first() {
540                            rendered.push_str(&value.to_log_string());
541                        } else {
542                            rendered.push_str("NULL");
543                        }
544                    }
545                }
546                FilterExpr::NullCheck { column, is_null } => {
547                    append_and(&mut rendered);
548                    if *is_null {
549                        let _ = write!(rendered, "{} IS NULL", column.as_str());
550                    } else {
551                        let _ = write!(rendered, "{} IS NOT NULL", column.as_str());
552                    }
553                }
554            }
555        }
556
557        if T::has_soft_delete() && !self.include_deleted {
558            append_and(&mut rendered);
559            rendered.push_str("deleted_at IS NULL");
560        }
561
562        rendered
563    }
564
565    fn estimate_bind_count(&self) -> usize {
566        let mut count = 0usize;
567        for filter in &self.filters {
568            if let FilterExpr::Compare { op, values, .. } = filter {
569                if op.is_in() {
570                    count = count.saturating_add(values.len());
571                } else {
572                    count = count.saturating_add(1);
573                }
574            }
575        }
576        count
577    }
578
579    /// Limits the number of rows returned by the query.
580    /// Sets the maximum number of rows to return.
581    pub fn limit(mut self, limit: i32) -> Self {
582        self.limit = Some(limit);
583        self
584    }
585
586    /// Skips the specified number of rows.
587    /// Sets the number of rows to skip.
588    pub fn offset(mut self, offset: i32) -> Self {
589        self.offset = Some(offset);
590        self
591    }
592
593    /// Eager loads a related model.
594    pub fn include(mut self, relation: impl Into<String>) -> Self {
595        self.includes.push(relation.into());
596        self
597    }
598
599    /// Includes soft-deleted records in the results.
600    pub fn with_deleted(mut self) -> Self {
601        self.include_deleted = true;
602        self
603    }
604
605    /// Explicitly allows potentially unsafe raw filters.
606    /// Enables execution of queries with raw SQL filters.
607    pub fn allow_unsafe(mut self) -> Self {
608        self.allow_unsafe = true;
609        self
610    }
611
612    /// Enables a fast path that skips logging and metrics for hot queries.
613    pub fn fast(mut self) -> Self {
614        self.fast_path = true;
615        self
616    }
617
618    /// Enables an unsafe fast path that skips logging, metrics, and safety guards.
619    pub fn unsafe_fast(mut self) -> Self {
620        self.fast_path = true;
621        self.unsafe_fast = true;
622        self
623    }
624
625    /// Enables the ultra-fast path: skips logging, metrics, safety guards, and eager loading.
626    /// Note: Any configured includes will be ignored.
627    pub fn ultra_fast(mut self) -> Self {
628        self.fast_path = true;
629        self.unsafe_fast = true;
630        self.ultra_fast = true;
631        self
632    }
633
634    /// Enable prepared statement caching for this query (default: enabled).
635    pub fn prepared(mut self) -> Self {
636        self.prepared = true;
637        self
638    }
639
640    /// Disable prepared statement caching for this query.
641    pub fn unprepared(mut self) -> Self {
642        self.prepared = false;
643        self
644    }
645
646    /// Returns the SELECT SQL that would be executed for this query.
647    pub fn to_sql(&self) -> String {
648        let mut sql = String::with_capacity(128); // Pre-allocate reasonable size
649        use std::fmt::Write;
650
651        sql.push_str("SELECT * FROM ");
652        sql.push_str(T::table_name());
653
654        let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
655        self.render_where_clause_into(&mut sql, &mut dummy_binds, 1);
656
657        if let Some(limit) = self.limit {
658            let _ = write!(sql, " LIMIT {}", limit);
659        }
660
661        if let Some(offset) = self.offset {
662            let _ = write!(sql, " OFFSET {}", offset);
663        }
664
665        sql
666    }
667
668    /// Returns the UPDATE SQL that would be executed for this query.
669    pub fn to_update_sql(&self, values: &serde_json::Value) -> Result<String, sqlx::Error> {
670        let obj = values.as_object().ok_or_else(|| {
671            sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
672        })?;
673
674        let mut sql = String::with_capacity(256);
675        use std::fmt::Write;
676
677        let _ = write!(sql, "UPDATE {} SET ", T::table_name());
678
679        let mut i = 1;
680        let mut first = true;
681
682        for k in obj.keys() {
683            if !first {
684                sql.push_str(", ");
685            }
686            let p = DB::placeholder(i);
687            let _ = write!(sql, "{} = {}", DB::quote_identifier(k), p);
688            i += 1;
689            first = false;
690        }
691
692        let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
693        self.render_where_clause_into(&mut sql, &mut dummy_binds, obj.len() + 1);
694        Ok(sql)
695    }
696
697    /// Returns the DELETE (or soft delete) SQL that would be executed for this query.
698    pub fn to_delete_sql(&self) -> String {
699        let mut sql = String::with_capacity(128);
700        use std::fmt::Write;
701
702        if T::has_soft_delete() {
703            let _ = write!(
704                sql,
705                "UPDATE {} SET {} = {}",
706                T::table_name(),
707                DB::quote_identifier("deleted_at"),
708                DB::current_timestamp_fn()
709            );
710        } else {
711            let _ = write!(sql, "DELETE FROM {}", T::table_name());
712        }
713
714        let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
715        self.render_where_clause_into(&mut sql, &mut dummy_binds, 1);
716        sql
717    }
718
719    // Optimized version that writes to buffer
720    #[inline(always)]
721    fn render_where_clause_into(
722        &self,
723        sql: &mut String,
724        binds: &mut SmallVec<[BindValue; 8]>,
725        start_index: usize,
726    ) {
727        let mut idx = start_index;
728        let mut first_clause = true;
729        use std::fmt::Write;
730
731        // Helper to handle AND prefix
732        let mut append_and = |sql: &mut String| {
733            if first_clause {
734                sql.push_str(" WHERE ");
735                first_clause = false;
736            } else {
737                sql.push_str(" AND ");
738            }
739        };
740
741        for filter in &self.filters {
742            match filter {
743                FilterExpr::Raw(condition) => {
744                    append_and(sql);
745                    sql.push_str(condition);
746                }
747                FilterExpr::Compare { column, op, values } => {
748                    if op.is_in() {
749                        if values.is_empty() {
750                            append_and(sql);
751                            sql.push_str("1=0");
752                            continue;
753                        }
754                        append_and(sql);
755                        let _ = write!(sql, "{} IN (", DB::quote_identifier(column.as_str()));
756                        let placeholders = crate::cached_placeholders_from::<DB>(idx, values.len());
757                        sql.push_str(placeholders);
758                        sql.push(')');
759                        idx = idx.saturating_add(values.len());
760                        for v in values {
761                            binds.push(v.clone());
762                        }
763                    } else {
764                        append_and(sql);
765                        let _ = write!(
766                            sql,
767                            "{} {} {}",
768                            DB::quote_identifier(column.as_str()),
769                            op.as_str(),
770                            DB::placeholder(idx)
771                        );
772                        idx += 1;
773                        if let Some(v) = values.first() {
774                            binds.push(v.clone());
775                        }
776                    }
777                }
778                FilterExpr::NullCheck { column, is_null } => {
779                    append_and(sql);
780                    if *is_null {
781                        let _ = write!(sql, "{} IS NULL", DB::quote_identifier(column.as_str()));
782                    } else {
783                        let _ =
784                            write!(sql, "{} IS NOT NULL", DB::quote_identifier(column.as_str()));
785                    }
786                }
787            }
788        }
789
790        if T::has_soft_delete() && !self.include_deleted {
791            append_and(sql);
792            sql.push_str("deleted_at IS NULL");
793        }
794    }
795}
796
797impl<'a, T, DB> QueryBuilder<'a, T, DB>
798where
799    DB: SqlDialect,
800    T: Model<DB>,
801    for<'q> <DB as Database>::Arguments<'q>: IntoArguments<'q, DB>,
802    for<'c> &'c mut <DB as Database>::Connection: sqlx::Executor<'c, Database = DB>,
803    for<'c> &'c str: sqlx::ColumnIndex<DB::Row>,
804    DB::Connection: Send,
805    T: Send,
806    String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
807    i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
808    f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
809    bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
810    Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
811    uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
812    chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
813    chrono::NaiveDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
814    chrono::NaiveDate: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
815    sqlx::types::Json<serde_json::Value>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
816{
817    fn ensure_safe_filters(&self) -> Result<(), sqlx::Error> {
818        if self.unsafe_fast {
819            return Ok(());
820        }
821        if self.has_raw_filter && !self.allow_unsafe {
822            return Err(sqlx::Error::Protocol(
823                "Refusing raw filter without allow_unsafe".to_string(),
824            ));
825        }
826        Ok(())
827    }
828
829    /// Executes the query and returns a vector of results.
830    ///
831    /// This method will fetch all rows matching the criteria and then perform
832    /// eager loading for any included relations.
833    #[tracing::instrument(skip(self), fields(table = T::table_name()))]
834    pub async fn all(mut self) -> Result<Vec<T>, sqlx::Error> {
835        self.ensure_safe_filters()?;
836
837        let mut sql = String::with_capacity(128);
838        sql.push_str("SELECT * FROM ");
839        sql.push_str(T::table_name());
840
841        let mut where_binds: SmallVec<[BindValue; 8]> =
842            SmallVec::with_capacity(self.estimate_bind_count());
843        self.render_where_clause_into(&mut sql, &mut where_binds, 1);
844
845        if let Some(limit) = self.limit {
846            use std::fmt::Write;
847            let _ = write!(sql, " LIMIT {}", limit);
848        }
849
850        if let Some(offset) = self.offset {
851            use std::fmt::Write;
852            let _ = write!(sql, " OFFSET {}", offset);
853        }
854
855        #[cfg(debug_assertions)]
856        if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
857            let filters = self.format_filters_for_log();
858            tracing::debug!(
859                operation = "select",
860                sql = %sql,
861                filters = %filters,
862                "premix query"
863            );
864        }
865
866        let start = Instant::now();
867        let mut results: Vec<T> = match &mut self.executor {
868            Executor::Pool(pool) => {
869                let base = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
870                let query = where_binds.into_iter().fold(base, bind_value_query_as);
871                query.fetch_all(*pool).await?
872            }
873            Executor::Conn(conn) => {
874                let base = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
875                let query = where_binds.into_iter().fold(base, bind_value_query_as);
876                query.fetch_all(&mut **conn).await?
877            }
878        };
879        if !self.fast_path {
880            record_query_metrics("select", T::table_name(), start.elapsed());
881        }
882
883        if self.ultra_fast {
884            return Ok(results);
885        }
886
887        for relation in self.includes {
888            match &mut self.executor {
889                Executor::Pool(pool) => {
890                    T::eager_load(&mut results, &relation, Executor::Pool(*pool)).await?;
891                }
892                Executor::Conn(conn) => {
893                    T::eager_load(&mut results, &relation, Executor::Conn(&mut **conn)).await?;
894                }
895            }
896        }
897
898        Ok(results)
899    }
900
901    /// Executes the query and returns a stream of results.
902    ///
903    /// This is useful for processing large result sets without loading them all into memory.
904    #[tracing::instrument(skip(self), fields(table = T::table_name()))]
905    pub fn stream(
906        self,
907    ) -> Result<futures_util::stream::BoxStream<'a, Result<T, sqlx::Error>>, sqlx::Error>
908    where
909        T: 'a,
910    {
911        self.ensure_safe_filters()?;
912
913        let mut sql = String::with_capacity(128);
914        sql.push_str("SELECT * FROM ");
915        sql.push_str(T::table_name());
916
917        let mut where_binds: SmallVec<[BindValue; 8]> =
918            SmallVec::with_capacity(self.estimate_bind_count());
919        self.render_where_clause_into(&mut sql, &mut where_binds, 1);
920
921        if let Some(limit) = self.limit {
922            use std::fmt::Write;
923            let _ = write!(sql, " LIMIT {}", limit);
924        }
925
926        if let Some(offset) = self.offset {
927            use std::fmt::Write;
928            let _ = write!(sql, " OFFSET {}", offset);
929        }
930
931        #[cfg(debug_assertions)]
932        if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
933            let filters = self.format_filters_for_log();
934            tracing::debug!(
935                operation = "stream",
936                sql = %sql,
937                filters = %filters,
938                "premix query"
939            );
940        }
941
942        let executor = self.executor;
943        Ok(Box::pin(async_stream::try_stream! {
944            let mut query = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
945            for bind in where_binds {
946                query = bind_value_query_as(query, bind);
947            }
948            let mut s = executor.fetch_stream(query);
949            while let Some(row) = s.next().await {
950                yield row?;
951            }
952        }))
953    }
954
955    /// Executes a bulk update based on the current filters.
956    ///
957    /// # Errors
958    /// Returns an error if no filters are provided (unless `allow_unsafe` is used),
959    /// or if the values cannot be mapped to the database.
960    #[inline(never)]
961    #[tracing::instrument(skip(self, values), fields(table = T::table_name()))]
962    pub async fn update(mut self, values: serde_json::Value) -> Result<u64, sqlx::Error>
963    where
964        String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
965        i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
966        f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
967        bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
968        Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
969        uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
970        chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
971        chrono::NaiveDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
972        chrono::NaiveDate: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
973        sqlx::types::Json<serde_json::Value>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
974    {
975        self.ensure_safe_filters()?;
976        if self.filters.is_empty() && !self.allow_unsafe && !self.unsafe_fast {
977            return Err(sqlx::Error::Protocol(
978                "Refusing bulk update without filters".to_string(),
979            ));
980        }
981        let obj = values.as_object().ok_or_else(|| {
982            sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
983        })?;
984
985        let mut sql = String::with_capacity(256);
986        use std::fmt::Write;
987        let _ = write!(sql, "UPDATE {} SET ", T::table_name());
988
989        let mut i = 1;
990        let mut first = true;
991        for k in obj.keys() {
992            if !first {
993                sql.push_str(", ");
994            }
995            let p = DB::placeholder(i);
996            let _ = write!(sql, "{} = {}", DB::quote_identifier(k), p);
997            i += 1;
998            first = false;
999        }
1000
1001        let mut where_binds: SmallVec<[BindValue; 8]> =
1002            SmallVec::with_capacity(self.estimate_bind_count());
1003        self.render_where_clause_into(&mut sql, &mut where_binds, obj.len() + 1);
1004
1005        if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
1006            let filters = self.format_filters_for_log();
1007            tracing::debug!(
1008                operation = "bulk_update",
1009                sql = %sql,
1010                filters = %filters,
1011                "premix query"
1012            );
1013        }
1014        let mut query = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1015        for val in obj.values() {
1016            match val {
1017                serde_json::Value::String(s) => query = query.bind(s.clone()),
1018                serde_json::Value::Number(n) => {
1019                    if let Some(v) = n.as_i64() {
1020                        query = query.bind(v);
1021                    } else if let Some(v) = n.as_f64() {
1022                        query = query.bind(v);
1023                    }
1024                }
1025                serde_json::Value::Bool(b) => query = query.bind(*b),
1026                serde_json::Value::Null => query = query.bind(Option::<String>::None),
1027                serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
1028                    query = query.bind(sqlx::types::Json(val.clone()));
1029                }
1030            }
1031        }
1032        for bind in where_binds {
1033            query = bind_value_query(query, bind);
1034        }
1035
1036        let start = Instant::now();
1037        let result = match &mut self.executor {
1038            Executor::Pool(pool) => {
1039                let res = query.execute(*pool).await?;
1040                Ok(DB::rows_affected(&res))
1041            }
1042            Executor::Conn(conn) => {
1043                let res = query.execute(&mut **conn).await?;
1044                Ok(DB::rows_affected(&res))
1045            }
1046        };
1047        if !self.fast_path {
1048            record_query_metrics("bulk_update", T::table_name(), start.elapsed());
1049        }
1050        result
1051    }
1052
1053    /// Executes a bulk update based on the current filters. Alias for [`update`].
1054    pub async fn update_all(self, values: serde_json::Value) -> Result<u64, sqlx::Error>
1055    where
1056        String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1057        i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1058        f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1059        bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1060        Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1061        uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1062        chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1063        chrono::NaiveDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1064        chrono::NaiveDate: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1065        sqlx::types::Json<serde_json::Value>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1066    {
1067        self.update(values).await
1068    }
1069
1070    /// Executes a bulk delete based on the current filters.
1071    #[tracing::instrument(skip(self), fields(table = T::table_name()))]
1072    pub async fn delete(mut self) -> Result<u64, sqlx::Error> {
1073        self.ensure_safe_filters()?;
1074        if self.filters.is_empty() && !self.allow_unsafe && !self.unsafe_fast {
1075            return Err(sqlx::Error::Protocol(
1076                "Refusing bulk delete without filters".to_string(),
1077            ));
1078        }
1079
1080        let mut sql = String::with_capacity(128);
1081        use std::fmt::Write;
1082
1083        if T::has_soft_delete() {
1084            let _ = write!(
1085                sql,
1086                "UPDATE {} SET {} = {}",
1087                T::table_name(),
1088                DB::quote_identifier("deleted_at"),
1089                DB::current_timestamp_fn()
1090            );
1091        } else {
1092            let _ = write!(sql, "DELETE FROM {}", T::table_name());
1093        }
1094
1095        let mut where_binds: SmallVec<[BindValue; 8]> =
1096            SmallVec::with_capacity(self.estimate_bind_count());
1097        self.render_where_clause_into(&mut sql, &mut where_binds, 1);
1098
1099        if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
1100            let filters = self.format_filters_for_log();
1101            tracing::debug!(
1102                operation = "bulk_delete",
1103                sql = %sql,
1104                filters = %filters,
1105                "premix query"
1106            );
1107        }
1108        let start = Instant::now();
1109        let result = match &mut self.executor {
1110            Executor::Pool(pool) => {
1111                let base = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1112                let query = where_binds.into_iter().fold(base, bind_value_query);
1113                let res = query.execute(*pool).await?;
1114                Ok(DB::rows_affected(&res))
1115            }
1116            Executor::Conn(conn) => {
1117                let base = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1118                let query = where_binds.into_iter().fold(base, bind_value_query);
1119                let res = query.execute(&mut **conn).await?;
1120                Ok(DB::rows_affected(&res))
1121            }
1122        };
1123        if !self.fast_path {
1124            record_query_metrics("bulk_delete", T::table_name(), start.elapsed());
1125        }
1126        result
1127    }
1128
1129    /// Executes a bulk delete based on the current filters. Alias for [`delete`].
1130    pub async fn delete_all(self) -> Result<u64, sqlx::Error> {
1131        self.delete().await
1132    }
1133}
1134
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::Sqlite;

    /// Minimal model mapped to the `users` table. Every persistence hook is a no-op, so the
    /// tests below exercise only the query builder.
    struct DummyModel;

    impl Model<Sqlite> for DummyModel {
        fn table_name() -> &'static str {
            "users"
        }
        fn create_table_sql() -> String {
            String::new()
        }
        fn list_columns() -> Vec<String> {
            vec!["id".to_string()]
        }
        async fn save<'a, E>(&'a mut self, _e: E) -> Result<(), sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(())
        }
        async fn update<'a, E>(
            &'a mut self,
            _e: E,
        ) -> Result<crate::model::UpdateResult, sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(crate::model::UpdateResult::Success)
        }
        async fn delete<'a, E>(&'a mut self, _e: E) -> Result<(), sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(())
        }
        fn has_soft_delete() -> bool {
            false
        }
        async fn find_by_id<'a, E>(_e: E, _id: i32) -> Result<Option<Self>, sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(None)
        }
    }

    // Row contents are never inspected by these tests, so decoding ignores the row entirely.
    impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for DummyModel {
        fn from_row(_row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
            Ok(DummyModel)
        }
    }

    #[tokio::test]
    async fn test_sql_injection_mitigation() {
        let pool = sqlx::Pool::<Sqlite>::connect_lazy("sqlite::memory:").unwrap();
        let qb = DummyModel::find_in_pool(&pool);

        // A column name carrying an injection payload must be rendered as a quoted identifier.
        let qb = qb.filter_eq("id; DROP TABLE users; --", 1);
        let sql = qb.to_sql();

        assert!(
            sql.contains("`id; DROP TABLE users; --` = ?"),
            "unexpected SQL: {}",
            sql
        );
        assert!(
            sql.contains("SELECT * FROM users WHERE"),
            "unexpected SQL: {}",
            sql
        );
    }

    #[tokio::test]
    async fn test_to_update_sql_quoting() {
        let pool = sqlx::Pool::<Sqlite>::connect_lazy("sqlite::memory:").unwrap();
        let qb = DummyModel::find_in_pool(&pool).filter_eq("id", 1);

        let values = serde_json::json!({
            "name; DROP TABLE users; --": "admin"
        });

        let sql = qb.to_update_sql(&values).unwrap();
        assert!(
            sql.contains("`name; DROP TABLE users; --` = ?"),
            "unexpected SQL: {}",
            sql
        );
    }

    #[tokio::test]
    async fn test_bulk_delete_refuses_without_filters() {
        let pool = sqlx::Pool::<Sqlite>::connect_lazy("sqlite::memory:").unwrap();
        // The filter guard returns before any statement is executed, so the lazy pool is
        // never asked for a connection.
        let err = DummyModel::find_in_pool(&pool).delete().await.unwrap_err();
        match err {
            sqlx::Error::Protocol(msg) => {
                assert!(msg.contains("without filters"), "unexpected message: {}", msg)
            }
            other => panic!("expected a protocol error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_stream_api() {
        use sqlx::Connection;
        let mut conn = sqlx::SqliteConnection::connect("sqlite::memory:")
            .await
            .unwrap();
        sqlx::query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
            .execute(&mut conn)
            .await
            .unwrap();
        sqlx::query("INSERT INTO users (id, name) VALUES (1, 'alice'), (2, 'bob')")
            .execute(&mut conn)
            .await
            .unwrap();

        // Drive the stream through a single borrowed connection (the `Executor::Conn` path).
        let qb = DummyModel::find_in_tx(&mut conn);

        let mut stream = qb.stream().unwrap();
        let mut count = 0;
        while let Some(row) = stream.next().await {
            let _: DummyModel = row.unwrap();
            count += 1;
        }
        assert_eq!(count, 2);
    }
}