Skip to main content

premix_core/
query.rs

1use crate::dialect::SqlDialect;
2use crate::executor::Executor;
3use crate::model::Model;
4use futures_util::StreamExt;
5use smallvec::{SmallVec, smallvec};
6use sqlx::{Database, IntoArguments};
7use std::time::{Duration, Instant};
8
9#[doc(hidden)]
10#[derive(Debug, Clone)]
11pub enum BindValue {
12    String(String),
13    I64(i64),
14    F64(f64),
15    Bool(bool),
16    Uuid(uuid::Uuid),
17    DateTime(chrono::DateTime<chrono::Utc>),
18    NaiveDateTime(chrono::NaiveDateTime),
19    NaiveDate(chrono::NaiveDate),
20    Json(serde_json::Value),
21    Null,
22}
23
24impl BindValue {
25    fn to_log_string(&self) -> String {
26        match self {
27            BindValue::String(v) => v.clone(),
28            BindValue::I64(v) => v.to_string(),
29            BindValue::F64(v) => v.to_string(),
30            BindValue::Bool(v) => v.to_string(),
31            BindValue::Uuid(v) => v.to_string(),
32            BindValue::DateTime(v) => v.to_rfc3339(),
33            BindValue::NaiveDateTime(v) => v.to_string(),
34            BindValue::NaiveDate(v) => v.to_string(),
35            BindValue::Json(v) => v.to_string(),
36            BindValue::Null => "NULL".to_string(),
37        }
38    }
39}
40
/// Records timing and count metrics for one executed query.
///
/// Emits the `premix.query.duration_ms` histogram (milliseconds) and the
/// `premix.query.count` counter, both labelled with `operation` and `table`.
#[cfg(feature = "metrics")]
fn record_query_metrics(operation: &str, table: &str, elapsed: Duration) {
    let labels = [
        ("operation", operation.to_owned()),
        ("table", table.to_owned()),
    ];
    let millis = elapsed.as_secs_f64() * 1000.0;
    metrics::histogram!("premix.query.duration_ms", &labels).record(millis);
    metrics::counter!("premix.query.count", &labels).increment(1);
}

/// No-op stand-in compiled when the `metrics` feature is disabled.
#[cfg(not(feature = "metrics"))]
fn record_query_metrics(_operation: &str, _table: &str, _elapsed: Duration) {}
54
55#[inline(always)]
56fn bind_value_query<'q, DB>(
57    query: sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>,
58    value: BindValue,
59) -> sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>
60where
61    DB: Database,
62    String: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
63    i64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
64    f64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
65    bool: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
66    uuid::Uuid: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
67    chrono::DateTime<chrono::Utc>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
68    chrono::NaiveDateTime: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
69    chrono::NaiveDate: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
70    sqlx::types::Json<serde_json::Value>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
71    Option<String>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
72{
73    match value {
74        BindValue::String(v) => query.bind(v),
75        BindValue::I64(v) => query.bind(v),
76        BindValue::F64(v) => query.bind(v),
77        BindValue::Bool(v) => query.bind(v),
78        BindValue::Uuid(v) => query.bind(v),
79        BindValue::DateTime(v) => query.bind(v),
80        BindValue::NaiveDateTime(v) => query.bind(v),
81        BindValue::NaiveDate(v) => query.bind(v),
82        BindValue::Json(v) => query.bind(sqlx::types::Json(v)),
83        BindValue::Null => query.bind(Option::<String>::None),
84    }
85}
86
87#[inline(always)]
88fn bind_value_query_as<'q, DB, T>(
89    query: sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>,
90    value: BindValue,
91) -> sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>
92where
93    DB: Database,
94    String: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
95    i64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
96    f64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
97    bool: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
98    uuid::Uuid: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
99    chrono::DateTime<chrono::Utc>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
100    chrono::NaiveDateTime: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
101    chrono::NaiveDate: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
102    sqlx::types::Json<serde_json::Value>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
103    Option<String>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
104{
105    match value {
106        BindValue::String(v) => query.bind(v),
107        BindValue::I64(v) => query.bind(v),
108        BindValue::F64(v) => query.bind(v),
109        BindValue::Bool(v) => query.bind(v),
110        BindValue::Uuid(v) => query.bind(v),
111        BindValue::DateTime(v) => query.bind(v),
112        BindValue::NaiveDateTime(v) => query.bind(v),
113        BindValue::NaiveDate(v) => query.bind(v),
114        BindValue::Json(v) => query.bind(sqlx::types::Json(v)),
115        BindValue::Null => query.bind(Option::<String>::None),
116    }
117}
118
119#[inline]
120fn apply_persistent_query_as<'q, DB, T>(
121    query: sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>,
122    prepared: bool,
123) -> sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>
124where
125    DB: Database + sqlx::database::HasStatementCache,
126{
127    if prepared {
128        query.persistent(true)
129    } else {
130        query
131    }
132}
133
134#[inline]
135fn apply_persistent_query<'q, DB>(
136    query: sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>,
137    prepared: bool,
138) -> sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>
139where
140    DB: Database + sqlx::database::HasStatementCache,
141{
142    if prepared {
143        query.persistent(true)
144    } else {
145        query
146    }
147}
148
149impl From<String> for BindValue {
150    fn from(value: String) -> Self {
151        Self::String(value)
152    }
153}
154
155impl From<&str> for BindValue {
156    fn from(value: &str) -> Self {
157        Self::String(value.to_string())
158    }
159}
160
161impl From<i32> for BindValue {
162    fn from(value: i32) -> Self {
163        Self::I64(value as i64)
164    }
165}
166
167impl From<i64> for BindValue {
168    fn from(value: i64) -> Self {
169        Self::I64(value)
170    }
171}
172
173impl From<f64> for BindValue {
174    fn from(value: f64) -> Self {
175        Self::F64(value)
176    }
177}
178
179impl From<bool> for BindValue {
180    fn from(value: bool) -> Self {
181        Self::Bool(value)
182    }
183}
184
185impl From<Option<String>> for BindValue {
186    fn from(value: Option<String>) -> Self {
187        match value {
188            Some(v) => Self::String(v),
189            None => Self::Null,
190        }
191    }
192}
193
194impl From<uuid::Uuid> for BindValue {
195    fn from(value: uuid::Uuid) -> Self {
196        Self::Uuid(value)
197    }
198}
199
200impl From<chrono::DateTime<chrono::Utc>> for BindValue {
201    fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
202        Self::DateTime(value)
203    }
204}
205
206impl From<chrono::NaiveDateTime> for BindValue {
207    fn from(value: chrono::NaiveDateTime) -> Self {
208        Self::NaiveDateTime(value)
209    }
210}
211
212impl From<chrono::NaiveDate> for BindValue {
213    fn from(value: chrono::NaiveDate) -> Self {
214        Self::NaiveDate(value)
215    }
216}
217
218impl From<serde_json::Value> for BindValue {
219    fn from(value: serde_json::Value) -> Self {
220        Self::Json(value)
221    }
222}
223
/// A single `WHERE`-clause predicate recorded by `QueryBuilder`.
#[derive(Debug, Clone)]
pub(crate) enum FilterExpr {
    /// Caller-supplied SQL condition (not parameterized); executing a query that
    /// contains one requires `allow_unsafe` unless the unsafe fast path is on.
    Raw(String),
    /// `column <op> $n`, or `column IN ($n, ...)` when `op` is `FilterOp::In`.
    Compare {
        column: ColumnRef,
        op: FilterOp,
        // The `filter_*` scalar builders always store exactly one value here;
        // `filter_in` stores the whole list (possibly empty).
        values: SmallVec<[BindValue; 2]>,
    },
    /// `column IS NULL` when `is_null` is true, otherwise `column IS NOT NULL`.
    NullCheck {
        column: ColumnRef,
        is_null: bool,
    },
}
237
/// Comparison operator of a `FilterExpr::Compare` predicate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FilterOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Lte,
    /// `>`
    Gt,
    /// `>=`
    Gte,
    /// `LIKE`
    Like,
    /// `IN (...)`
    In,
}

impl FilterOp {
    /// Returns the SQL token that represents this operator.
    fn as_str(self) -> &'static str {
        match self {
            Self::In => "IN",
            Self::Like => "LIKE",
            Self::Eq => "=",
            Self::Ne => "!=",
            Self::Lt => "<",
            Self::Lte => "<=",
            Self::Gt => ">",
            Self::Gte => ">=",
        }
    }

    /// Returns `true` only for `IN`, whose right-hand side is a list of values.
    fn is_in(self) -> bool {
        self == Self::In
    }
}
268
/// Column reference used in filters (static literals or owned names).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColumnRef {
    /// Static column name known at compile time.
    Static(&'static str),
    /// Owned column name created at runtime.
    Owned(String),
}

impl ColumnRef {
    /// Builds a static column reference without allocation.
    pub const fn static_str(value: &'static str) -> Self {
        Self::Static(value)
    }

    /// Borrows the column name, however it is stored.
    fn as_str(&self) -> &str {
        match self {
            Self::Static(name) => *name,
            Self::Owned(name) => name.as_str(),
        }
    }
}

/// Copies a borrowed name into an owned reference.
impl From<&str> for ColumnRef {
    fn from(value: &str) -> Self {
        Self::Owned(value.to_owned())
    }
}

/// Takes ownership of the name without copying.
impl From<String> for ColumnRef {
    fn from(value: String) -> Self {
        Self::Owned(value)
    }
}

/// Clones a borrowed `String` into an owned reference.
impl From<&String> for ColumnRef {
    fn from(value: &String) -> Self {
        Self::Owned(String::clone(value))
    }
}
309
/// A type-safe SQL query builder.
///
/// `QueryBuilder` provides a fluent interface for building SELECT, UPDATE, and DELETE queries
/// with support for filtering, pagination, eager loading, and soft deletes.
pub struct QueryBuilder<'a, T, DB: Database> {
    /// Pool or connection the query runs against.
    executor: Executor<'a, DB>,
    /// `WHERE` predicates, rendered in insertion order and joined with `AND`.
    filters: Vec<FilterExpr>,
    /// Rendered as ` LIMIT n` when set.
    limit: Option<i32>,
    /// Rendered as ` OFFSET n` when set.
    offset: Option<i32>,
    /// Relation names to eager-load after the main fetch; seeded from `T::default_includes()`.
    includes: SmallVec<[String; 2]>,
    /// When true, the soft-delete predicate (`deleted_at IS NULL`) is not added.
    include_deleted: bool,
    /// Opt-in that lets queries containing raw filters execute.
    allow_unsafe: bool,
    /// Set as soon as any raw (unparameterized) filter is added.
    has_raw_filter: bool,
    /// Skip debug logging and query metrics.
    fast_path: bool,
    /// Additionally skip the raw-filter safety check.
    unsafe_fast: bool,
    /// Additionally skip eager loading of `includes`.
    ultra_fast: bool,
    /// Passed to `apply_persistent_query_as` / `apply_persistent_query` to control statement caching.
    prepared: bool,
    _marker: std::marker::PhantomData<T>,
}
329
330impl<'a, T, DB: Database> std::fmt::Debug for QueryBuilder<'a, T, DB> {
331    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332        f.debug_struct("QueryBuilder")
333            .field("filters", &self.filters)
334            .field("limit", &self.limit)
335            .field("offset", &self.offset)
336            .field("includes", &self.includes)
337            .field("include_deleted", &self.include_deleted)
338            .field("allow_unsafe", &self.allow_unsafe)
339            .field("fast_path", &self.fast_path)
340            .field("unsafe_fast", &self.unsafe_fast)
341            .field("ultra_fast", &self.ultra_fast)
342            .field("prepared", &self.prepared)
343            .finish()
344    }
345}
346
347impl<'a, T, DB> QueryBuilder<'a, T, DB>
348where
349    DB: SqlDialect + sqlx::database::HasStatementCache,
350    T: Model<DB>,
351{
352    /// Creates a new `QueryBuilder` using the provided [`Executor`].
353    pub fn new(executor: Executor<'a, DB>) -> Self {
354        let default_includes = T::default_includes();
355        let mut includes = SmallVec::with_capacity(default_includes.len().max(2));
356        for relation in default_includes {
357            includes.push((*relation).to_string());
358        }
359        Self {
360            executor,
361            filters: Vec::with_capacity(4), // Pre-allocate for typical queries (1-4 filters)
362            limit: None,
363            offset: None,
364            includes, // Include eager defaults
365            include_deleted: false,
366            allow_unsafe: false,
367            has_raw_filter: false,
368            fast_path: false,
369            unsafe_fast: false,
370            ultra_fast: false,
371            prepared: true,
372            _marker: std::marker::PhantomData,
373        }
374    }
375
376    /// Adds a raw SQL filter condition to the query.
377    ///
378    /// # Safety
379    /// This method is potentially unsafe and requires calling [`allow_unsafe`] for the query to execute.
380    pub fn filter(mut self, condition: impl Into<String>) -> Self {
381        self.filters.push(FilterExpr::Raw(condition.into()));
382        self.has_raw_filter = true;
383        self
384    }
385
386    /// Adds a raw SQL filter condition to the query.
387    pub fn filter_raw(self, condition: impl Into<String>) -> Self {
388        self.filter(condition)
389    }
390
391    /// Adds an equality filter (`column = value`).
392    pub fn filter_eq(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
393        self.filters.push(FilterExpr::Compare {
394            column: column.into(),
395            op: FilterOp::Eq,
396            values: smallvec![value.into()],
397        });
398        self
399    }
400
401    /// Adds a not-equal filter (`column != value`).
402    pub fn filter_ne(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
403        self.filters.push(FilterExpr::Compare {
404            column: column.into(),
405            op: FilterOp::Ne,
406            values: smallvec![value.into()],
407        });
408        self
409    }
410
411    /// Adds a less-than filter (`column < value`).
412    pub fn filter_lt(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
413        self.filters.push(FilterExpr::Compare {
414            column: column.into(),
415            op: FilterOp::Lt,
416            values: smallvec![value.into()],
417        });
418        self
419    }
420
421    /// Adds a less-than-or-equal filter (`column <= value`).
422    pub fn filter_lte(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
423        self.filters.push(FilterExpr::Compare {
424            column: column.into(),
425            op: FilterOp::Lte,
426            values: smallvec![value.into()],
427        });
428        self
429    }
430
431    /// Adds a greater-than filter (`column > value`).
432    pub fn filter_gt(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
433        self.filters.push(FilterExpr::Compare {
434            column: column.into(),
435            op: FilterOp::Gt,
436            values: smallvec![value.into()],
437        });
438        self
439    }
440
441    /// Adds a greater-than-or-equal filter (`column >= value`).
442    pub fn filter_gte(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
443        self.filters.push(FilterExpr::Compare {
444            column: column.into(),
445            op: FilterOp::Gte,
446            values: smallvec![value.into()],
447        });
448        self
449    }
450
451    /// Adds a LIKE filter (`column LIKE value`).
452    pub fn filter_like(
453        mut self,
454        column: impl Into<ColumnRef>,
455        value: impl Into<BindValue>,
456    ) -> Self {
457        self.filters.push(FilterExpr::Compare {
458            column: column.into(),
459            op: FilterOp::Like,
460            values: smallvec![value.into()],
461        });
462        self
463    }
464
465    /// Filters rows where the column IS NULL.
466    pub fn filter_is_null(mut self, column: impl Into<ColumnRef>) -> Self {
467        self.filters.push(FilterExpr::NullCheck {
468            column: column.into(),
469            is_null: true,
470        });
471        self
472    }
473
474    /// Filters rows where the column IS NOT NULL.
475    pub fn filter_is_not_null(mut self, column: impl Into<ColumnRef>) -> Self {
476        self.filters.push(FilterExpr::NullCheck {
477            column: column.into(),
478            is_null: false,
479        });
480        self
481    }
482
483    /// Adds an IN filter (`column IN (values...)`).
484    pub fn filter_in<I, V>(mut self, column: impl Into<ColumnRef>, values: I) -> Self
485    where
486        I: IntoIterator<Item = V>,
487        V: Into<BindValue>,
488    {
489        let values: SmallVec<[BindValue; 2]> = values.into_iter().map(Into::into).collect();
490        self.filters.push(FilterExpr::Compare {
491            column: column.into(),
492            op: FilterOp::In,
493            values,
494        });
495        self
496    }
497
498    fn format_filters_for_log(&self) -> String {
499        let sensitive_fields = T::sensitive_fields();
500        let mut rendered = String::with_capacity(128);
501        let mut first_clause = true;
502        let mut append_and = |buf: &mut String| {
503            if first_clause {
504                first_clause = false;
505            } else {
506                buf.push_str(" AND ");
507            }
508        };
509        use std::fmt::Write;
510
511        for filter in &self.filters {
512            match filter {
513                FilterExpr::Raw(_) => {
514                    append_and(&mut rendered);
515                    rendered.push_str("RAW(<redacted>)");
516                }
517                FilterExpr::Compare { column, op, values } => {
518                    let column_name = column.as_str();
519                    let is_sensitive = sensitive_fields.contains(&column_name);
520                    if op.is_in() {
521                        if values.is_empty() {
522                            append_and(&mut rendered);
523                            rendered.push_str("1=0");
524                            continue;
525                        }
526                        append_and(&mut rendered);
527                        let _ = write!(rendered, "{} IN (", column_name);
528                        for (idx, value) in values.iter().enumerate() {
529                            if idx > 0 {
530                                rendered.push_str(", ");
531                            }
532                            if is_sensitive {
533                                rendered.push_str("***");
534                            } else {
535                                rendered.push_str(&value.to_log_string());
536                            }
537                        }
538                        rendered.push(')');
539                    } else {
540                        append_and(&mut rendered);
541                        let _ = write!(rendered, "{} {} ", column_name, op.as_str());
542                        if is_sensitive {
543                            rendered.push_str("***");
544                        } else if let Some(value) = values.first() {
545                            rendered.push_str(&value.to_log_string());
546                        } else {
547                            rendered.push_str("NULL");
548                        }
549                    }
550                }
551                FilterExpr::NullCheck { column, is_null } => {
552                    append_and(&mut rendered);
553                    if *is_null {
554                        let _ = write!(rendered, "{} IS NULL", column.as_str());
555                    } else {
556                        let _ = write!(rendered, "{} IS NOT NULL", column.as_str());
557                    }
558                }
559            }
560        }
561
562        if T::has_soft_delete() && !self.include_deleted {
563            append_and(&mut rendered);
564            rendered.push_str("deleted_at IS NULL");
565        }
566
567        rendered
568    }
569
570    fn estimate_bind_count(&self) -> usize {
571        let mut count = 0usize;
572        for filter in &self.filters {
573            if let FilterExpr::Compare { op, values, .. } = filter {
574                if op.is_in() {
575                    count = count.saturating_add(values.len());
576                } else {
577                    count = count.saturating_add(1);
578                }
579            }
580        }
581        count
582    }
583
584    /// Limits the number of rows returned by the query.
585    /// Sets the maximum number of rows to return.
586    pub fn limit(mut self, limit: i32) -> Self {
587        self.limit = Some(limit);
588        self
589    }
590
591    /// Skips the specified number of rows.
592    /// Sets the number of rows to skip.
593    pub fn offset(mut self, offset: i32) -> Self {
594        self.offset = Some(offset);
595        self
596    }
597
598    /// Eager loads a related model.
599    pub fn include(mut self, relation: impl Into<String>) -> Self {
600        self.includes.push(relation.into());
601        self
602    }
603
604    /// Includes soft-deleted records in the results.
605    pub fn with_deleted(mut self) -> Self {
606        self.include_deleted = true;
607        self
608    }
609
610    /// Explicitly allows potentially unsafe raw filters.
611    /// Enables execution of queries with raw SQL filters.
612    pub fn allow_unsafe(mut self) -> Self {
613        self.allow_unsafe = true;
614        self
615    }
616
617    /// Enables a fast path that skips logging and metrics for hot queries.
618    pub fn fast(mut self) -> Self {
619        self.fast_path = true;
620        self
621    }
622
623    /// Enables an unsafe fast path that skips logging, metrics, and safety guards.
624    pub fn unsafe_fast(mut self) -> Self {
625        self.fast_path = true;
626        self.unsafe_fast = true;
627        self
628    }
629
630    /// Enables the ultra-fast path: skips logging, metrics, safety guards, and eager loading.
631    /// Note: Any configured includes will be ignored.
632    pub fn ultra_fast(mut self) -> Self {
633        self.fast_path = true;
634        self.unsafe_fast = true;
635        self.ultra_fast = true;
636        self
637    }
638
639    /// Enable prepared statement caching for this query (default: enabled).
640    pub fn prepared(mut self) -> Self {
641        self.prepared = true;
642        self
643    }
644
645    /// Disable prepared statement caching for this query.
646    pub fn unprepared(mut self) -> Self {
647        self.prepared = false;
648        self
649    }
650
651    /// Returns the SELECT SQL that would be executed for this query.
652    pub fn to_sql(&self) -> String {
653        let mut sql = String::with_capacity(128); // Pre-allocate reasonable size
654        use std::fmt::Write;
655
656        sql.push_str("SELECT * FROM ");
657        sql.push_str(T::table_name());
658
659        let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
660        self.render_where_clause_into(&mut sql, &mut dummy_binds, 1);
661
662        if let Some(limit) = self.limit {
663            let _ = write!(sql, " LIMIT {}", limit);
664        }
665
666        if let Some(offset) = self.offset {
667            let _ = write!(sql, " OFFSET {}", offset);
668        }
669
670        sql
671    }
672
673    /// Returns the UPDATE SQL that would be executed for this query.
674    pub fn to_update_sql(&self, values: &serde_json::Value) -> Result<String, sqlx::Error> {
675        let obj = values.as_object().ok_or_else(|| {
676            sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
677        })?;
678
679        let mut sql = String::with_capacity(256);
680        use std::fmt::Write;
681
682        let _ = write!(sql, "UPDATE {} SET ", T::table_name());
683
684        let mut i = 1;
685        let mut first = true;
686
687        for k in obj.keys() {
688            if !first {
689                sql.push_str(", ");
690            }
691            let p = DB::placeholder(i);
692            let _ = write!(sql, "{} = {}", DB::quote_identifier(k), p);
693            i += 1;
694            first = false;
695        }
696
697        let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
698        self.render_where_clause_into(&mut sql, &mut dummy_binds, obj.len() + 1);
699        Ok(sql)
700    }
701
702    /// Returns the DELETE (or soft delete) SQL that would be executed for this query.
703    pub fn to_delete_sql(&self) -> String {
704        let mut sql = String::with_capacity(128);
705        use std::fmt::Write;
706
707        if T::has_soft_delete() {
708            let _ = write!(
709                sql,
710                "UPDATE {} SET {} = {}",
711                T::table_name(),
712                DB::quote_identifier("deleted_at"),
713                DB::current_timestamp_fn()
714            );
715        } else {
716            let _ = write!(sql, "DELETE FROM {}", T::table_name());
717        }
718
719        let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
720        self.render_where_clause_into(&mut sql, &mut dummy_binds, 1);
721        sql
722    }
723
724    // Optimized version that writes to buffer
725    #[inline(always)]
726    fn render_where_clause_into(
727        &self,
728        sql: &mut String,
729        binds: &mut SmallVec<[BindValue; 8]>,
730        start_index: usize,
731    ) {
732        let mut idx = start_index;
733        let mut first_clause = true;
734        use std::fmt::Write;
735
736        // Helper to handle AND prefix
737        let mut append_and = |sql: &mut String| {
738            if first_clause {
739                sql.push_str(" WHERE ");
740                first_clause = false;
741            } else {
742                sql.push_str(" AND ");
743            }
744        };
745
746        for filter in &self.filters {
747            match filter {
748                FilterExpr::Raw(condition) => {
749                    append_and(sql);
750                    sql.push_str(condition);
751                }
752                FilterExpr::Compare { column, op, values } => {
753                    if op.is_in() {
754                        if values.is_empty() {
755                            append_and(sql);
756                            sql.push_str("1=0");
757                            continue;
758                        }
759                        append_and(sql);
760                        let _ = write!(sql, "{} IN (", DB::quote_identifier(column.as_str()));
761                        let placeholders = crate::cached_placeholders_from::<DB>(idx, values.len());
762                        sql.push_str(placeholders);
763                        sql.push(')');
764                        idx = idx.saturating_add(values.len());
765                        for v in values {
766                            binds.push(v.clone());
767                        }
768                    } else {
769                        append_and(sql);
770                        let _ = write!(
771                            sql,
772                            "{} {} {}",
773                            DB::quote_identifier(column.as_str()),
774                            op.as_str(),
775                            DB::placeholder(idx)
776                        );
777                        idx += 1;
778                        if let Some(v) = values.first() {
779                            binds.push(v.clone());
780                        }
781                    }
782                }
783                FilterExpr::NullCheck { column, is_null } => {
784                    append_and(sql);
785                    if *is_null {
786                        let _ = write!(sql, "{} IS NULL", DB::quote_identifier(column.as_str()));
787                    } else {
788                        let _ =
789                            write!(sql, "{} IS NOT NULL", DB::quote_identifier(column.as_str()));
790                    }
791                }
792            }
793        }
794
795        if T::has_soft_delete() && !self.include_deleted {
796            append_and(sql);
797            sql.push_str("deleted_at IS NULL");
798        }
799    }
800}
801
802impl<'a, T, DB> QueryBuilder<'a, T, DB>
803where
804    DB: SqlDialect,
805    T: Model<DB>,
806    for<'q> <DB as Database>::Arguments<'q>: IntoArguments<'q, DB>,
807    for<'c> &'c mut <DB as Database>::Connection: sqlx::Executor<'c, Database = DB>,
808    for<'c> &'c str: sqlx::ColumnIndex<DB::Row>,
809    DB::Connection: Send,
810    T: Send,
811    String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
812    i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
813    f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
814    bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
815    Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
816    uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
817    chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
818    chrono::NaiveDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
819    chrono::NaiveDate: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
820    sqlx::types::Json<serde_json::Value>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
821{
822    fn ensure_safe_filters(&self) -> Result<(), sqlx::Error> {
823        if self.unsafe_fast {
824            return Ok(());
825        }
826        if self.has_raw_filter && !self.allow_unsafe {
827            return Err(sqlx::Error::Protocol(
828                "Refusing raw filter without allow_unsafe".to_string(),
829            ));
830        }
831        Ok(())
832    }
833
834    /// Executes the query and returns a vector of results.
835    ///
836    /// This method will fetch all rows matching the criteria and then perform
837    /// eager loading for any included relations.
838    #[tracing::instrument(skip(self), fields(table = T::table_name()))]
839    pub async fn all(mut self) -> Result<Vec<T>, sqlx::Error> {
840        self.ensure_safe_filters()?;
841
842        let mut sql = String::with_capacity(128);
843        sql.push_str("SELECT * FROM ");
844        sql.push_str(T::table_name());
845
846        let mut where_binds: SmallVec<[BindValue; 8]> =
847            SmallVec::with_capacity(self.estimate_bind_count());
848        self.render_where_clause_into(&mut sql, &mut where_binds, 1);
849
850        if let Some(limit) = self.limit {
851            use std::fmt::Write;
852            let _ = write!(sql, " LIMIT {}", limit);
853        }
854
855        if let Some(offset) = self.offset {
856            use std::fmt::Write;
857            let _ = write!(sql, " OFFSET {}", offset);
858        }
859
860        #[cfg(debug_assertions)]
861        if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
862            let filters = self.format_filters_for_log();
863            tracing::debug!(
864                operation = "select",
865                sql = %sql,
866                filters = %filters,
867                "premix query"
868            );
869        }
870
871        let start = Instant::now();
872        let mut results: Vec<T> = match &mut self.executor {
873            Executor::Pool(pool) => {
874                let base = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
875                let query = where_binds.into_iter().fold(base, bind_value_query_as);
876                query.fetch_all(*pool).await?
877            }
878            Executor::Conn(conn) => {
879                let base = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
880                let query = where_binds.into_iter().fold(base, bind_value_query_as);
881                query.fetch_all(&mut **conn).await?
882            }
883        };
884        if !self.fast_path {
885            record_query_metrics("select", T::table_name(), start.elapsed());
886        }
887
888        if self.ultra_fast {
889            return Ok(results);
890        }
891
892        for relation in self.includes {
893            match &mut self.executor {
894                Executor::Pool(pool) => {
895                    T::eager_load(&mut results, &relation, Executor::Pool(*pool)).await?;
896                }
897                Executor::Conn(conn) => {
898                    T::eager_load(&mut results, &relation, Executor::Conn(&mut **conn)).await?;
899                }
900            }
901        }
902
903        Ok(results)
904    }
905
906    /// Executes the query and returns a stream of results.
907    ///
908    /// This is useful for processing large result sets without loading them all into memory.
909    #[tracing::instrument(skip(self), fields(table = T::table_name()))]
910    pub fn stream(
911        self,
912    ) -> Result<futures_util::stream::BoxStream<'a, Result<T, sqlx::Error>>, sqlx::Error>
913    where
914        T: 'a,
915    {
916        self.ensure_safe_filters()?;
917
918        let mut sql = String::with_capacity(128);
919        sql.push_str("SELECT * FROM ");
920        sql.push_str(T::table_name());
921
922        let mut where_binds: SmallVec<[BindValue; 8]> =
923            SmallVec::with_capacity(self.estimate_bind_count());
924        self.render_where_clause_into(&mut sql, &mut where_binds, 1);
925
926        if let Some(limit) = self.limit {
927            use std::fmt::Write;
928            let _ = write!(sql, " LIMIT {}", limit);
929        }
930
931        if let Some(offset) = self.offset {
932            use std::fmt::Write;
933            let _ = write!(sql, " OFFSET {}", offset);
934        }
935
936        #[cfg(debug_assertions)]
937        if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
938            let filters = self.format_filters_for_log();
939            tracing::debug!(
940                operation = "stream",
941                sql = %sql,
942                filters = %filters,
943                "premix query"
944            );
945        }
946
947        let executor = self.executor;
948        Ok(Box::pin(async_stream::try_stream! {
949            let mut query = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
950            for bind in where_binds {
951                query = bind_value_query_as(query, bind);
952            }
953            let mut s = executor.fetch_stream(query);
954            while let Some(row) = s.next().await {
955                yield row?;
956            }
957        }))
958    }
959
960    /// Executes a bulk update based on the current filters.
961    ///
962    /// # Errors
963    /// Returns an error if no filters are provided (unless `allow_unsafe` is used),
964    /// or if the values cannot be mapped to the database.
965    #[inline(never)]
966    #[tracing::instrument(skip(self, values), fields(table = T::table_name()))]
967    pub async fn update(mut self, values: serde_json::Value) -> Result<u64, sqlx::Error>
968    where
969        String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
970        i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
971        f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
972        bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
973        Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
974        uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
975        chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
976        chrono::NaiveDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
977        chrono::NaiveDate: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
978        sqlx::types::Json<serde_json::Value>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
979    {
980        self.ensure_safe_filters()?;
981        if self.filters.is_empty() && !self.allow_unsafe && !self.unsafe_fast {
982            return Err(sqlx::Error::Protocol(
983                "Refusing bulk update without filters".to_string(),
984            ));
985        }
986        let obj = values.as_object().ok_or_else(|| {
987            sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
988        })?;
989
990        let mut sql = String::with_capacity(256);
991        use std::fmt::Write;
992        let _ = write!(sql, "UPDATE {} SET ", T::table_name());
993
994        let mut i = 1;
995        let mut first = true;
996        for k in obj.keys() {
997            if !first {
998                sql.push_str(", ");
999            }
1000            let p = DB::placeholder(i);
1001            let _ = write!(sql, "{} = {}", DB::quote_identifier(k), p);
1002            i += 1;
1003            first = false;
1004        }
1005
1006        let mut where_binds: SmallVec<[BindValue; 8]> =
1007            SmallVec::with_capacity(self.estimate_bind_count());
1008        self.render_where_clause_into(&mut sql, &mut where_binds, obj.len() + 1);
1009
1010        if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
1011            let filters = self.format_filters_for_log();
1012            tracing::debug!(
1013                operation = "bulk_update",
1014                sql = %sql,
1015                filters = %filters,
1016                "premix query"
1017            );
1018        }
1019        let mut query = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1020        for val in obj.values() {
1021            match val {
1022                serde_json::Value::String(s) => query = query.bind(s.clone()),
1023                serde_json::Value::Number(n) => {
1024                    if let Some(v) = n.as_i64() {
1025                        query = query.bind(v);
1026                    } else if let Some(v) = n.as_f64() {
1027                        query = query.bind(v);
1028                    }
1029                }
1030                serde_json::Value::Bool(b) => query = query.bind(*b),
1031                serde_json::Value::Null => query = query.bind(Option::<String>::None),
1032                serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
1033                    query = query.bind(sqlx::types::Json(val.clone()));
1034                }
1035            }
1036        }
1037        for bind in where_binds {
1038            query = bind_value_query(query, bind);
1039        }
1040
1041        let start = Instant::now();
1042        let result = match &mut self.executor {
1043            Executor::Pool(pool) => {
1044                let res = query.execute(*pool).await?;
1045                Ok(DB::rows_affected(&res))
1046            }
1047            Executor::Conn(conn) => {
1048                let res = query.execute(&mut **conn).await?;
1049                Ok(DB::rows_affected(&res))
1050            }
1051        };
1052        if !self.fast_path {
1053            record_query_metrics("bulk_update", T::table_name(), start.elapsed());
1054        }
1055        result
1056    }
1057
1058    /// Executes a bulk update based on the current filters. Alias for [`update`].
1059    pub async fn update_all(self, values: serde_json::Value) -> Result<u64, sqlx::Error>
1060    where
1061        String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1062        i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1063        f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1064        bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1065        Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1066        uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1067        chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1068        chrono::NaiveDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1069        chrono::NaiveDate: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1070        sqlx::types::Json<serde_json::Value>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1071    {
1072        self.update(values).await
1073    }
1074
1075    /// Executes a bulk delete based on the current filters.
1076    #[tracing::instrument(skip(self), fields(table = T::table_name()))]
1077    pub async fn delete(mut self) -> Result<u64, sqlx::Error> {
1078        self.ensure_safe_filters()?;
1079        if self.filters.is_empty() && !self.allow_unsafe && !self.unsafe_fast {
1080            return Err(sqlx::Error::Protocol(
1081                "Refusing bulk delete without filters".to_string(),
1082            ));
1083        }
1084
1085        let mut sql = String::with_capacity(128);
1086        use std::fmt::Write;
1087
1088        if T::has_soft_delete() {
1089            let _ = write!(
1090                sql,
1091                "UPDATE {} SET {} = {}",
1092                T::table_name(),
1093                DB::quote_identifier("deleted_at"),
1094                DB::current_timestamp_fn()
1095            );
1096        } else {
1097            let _ = write!(sql, "DELETE FROM {}", T::table_name());
1098        }
1099
1100        let mut where_binds: SmallVec<[BindValue; 8]> =
1101            SmallVec::with_capacity(self.estimate_bind_count());
1102        self.render_where_clause_into(&mut sql, &mut where_binds, 1);
1103
1104        if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
1105            let filters = self.format_filters_for_log();
1106            tracing::debug!(
1107                operation = "bulk_delete",
1108                sql = %sql,
1109                filters = %filters,
1110                "premix query"
1111            );
1112        }
1113        let start = Instant::now();
1114        let result = match &mut self.executor {
1115            Executor::Pool(pool) => {
1116                let base = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1117                let query = where_binds.into_iter().fold(base, bind_value_query);
1118                let res = query.execute(*pool).await?;
1119                Ok(DB::rows_affected(&res))
1120            }
1121            Executor::Conn(conn) => {
1122                let base = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1123                let query = where_binds.into_iter().fold(base, bind_value_query);
1124                let res = query.execute(&mut **conn).await?;
1125                Ok(DB::rows_affected(&res))
1126            }
1127        };
1128        if !self.fast_path {
1129            record_query_metrics("bulk_delete", T::table_name(), start.elapsed());
1130        }
1131        result
1132    }
1133
1134    /// Executes a bulk delete based on the current filters. Alias for [`delete`].
1135    pub async fn delete_all(self) -> Result<u64, sqlx::Error> {
1136        self.delete().await
1137    }
1138}
1139
#[cfg(test)]
mod tests {
    // Unit tests for SQL rendering (identifier quoting) and for the streaming API
    // against an in-memory SQLite database.
    use super::*;
    use sqlx::Sqlite;

    // Minimal model bound to the `users` table. Every persistence method is a no-op
    // returning `Ok`, so these tests only exercise query building and streaming.
    struct DummyModel;

    impl Model<Sqlite> for DummyModel {
        fn table_name() -> &'static str {
            "users"
        }
        fn create_table_sql() -> String {
            String::new()
        }
        fn list_columns() -> Vec<String> {
            vec!["id".to_string()]
        }
        async fn save<'a, E>(&'a mut self, _e: E) -> Result<(), sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(())
        }
        async fn update<'a, E>(
            &'a mut self,
            _e: E,
        ) -> Result<crate::model::UpdateResult, sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(crate::model::UpdateResult::Success)
        }
        async fn delete<'a, E>(&'a mut self, _e: E) -> Result<(), sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(())
        }
        // No soft delete, so bulk deletes on this model render `DELETE FROM`.
        fn has_soft_delete() -> bool {
            false
        }
        async fn find_by_id<'a, E>(_e: E, _id: i32) -> Result<Option<Self>, sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(None)
        }
    }

    // `FromRow` for SQLite rows that ignores the row contents: the tests only count
    // decoded rows and never inspect columns.
    impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for DummyModel {
        fn from_row(_row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
            Ok(DummyModel)
        }
    }

    // A filter column name containing SQL metacharacters must be emitted as one
    // quoted identifier, not spliced in as raw SQL. No statement is executed here;
    // only the output of `to_sql` is checked.
    #[tokio::test]
    async fn test_sql_injection_mitigation() {
        let pool = sqlx::Pool::<Sqlite>::connect_lazy("sqlite::memory:").unwrap();
        let qb = DummyModel::find_in_pool(&pool);

        // Malicious column name
        let qb = qb.filter_eq("id; DROP TABLE users; --", 1);
        let sql = qb.to_sql();
        println!("SQL select: {}", sql);

        // The entire malicious string must sit inside a single backtick-quoted
        // identifier, followed by the bind placeholder.
        assert!(sql.contains("`id; DROP TABLE users; --` = ?"));
        assert!(sql.contains("SELECT * FROM users WHERE"));
    }

    // JSON keys passed to `to_update_sql` become column names and must be quoted
    // the same way as filter columns.
    #[tokio::test]
    async fn test_to_update_sql_quoting() {
        let pool = sqlx::Pool::<Sqlite>::connect_lazy("sqlite::memory:").unwrap();
        let qb = DummyModel::find_in_pool(&pool).filter_eq("id", 1);

        let values = serde_json::json!({
            "name; DROP TABLE users; --": "admin"
        });

        let sql = qb.to_update_sql(&values).unwrap();
        println!("SQL update: {}", sql);
        assert!(sql.contains("`name; DROP TABLE users; --` = ?"));
    }

    // End-to-end check of `stream`. A single dedicated connection is used for the
    // CREATE/INSERT and for the streamed SELECT, so all of them see the same
    // in-memory database.
    #[tokio::test]
    async fn test_stream_api() {
        use sqlx::Connection;
        let mut conn = sqlx::SqliteConnection::connect("sqlite::memory:")
            .await
            .unwrap();
        sqlx::query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
            .execute(&mut conn)
            .await
            .unwrap();
        sqlx::query("INSERT INTO users (id, name) VALUES (1, 'alice'), (2, 'bob')")
            .execute(&mut conn)
            .await
            .unwrap();

        // Build the query against the same connection that holds the test data.
        let qb = DummyModel::find_in_tx(&mut conn);

        let mut stream = qb.stream().unwrap();
        let mut count = 0;
        while let Some(row) = stream.next().await {
            let _: DummyModel = row.unwrap();
            count += 1;
        }
        // Both inserted rows must be yielded by the stream.
        assert_eq!(count, 2);
    }
}