Skip to main content

premix_core/
query.rs

1use crate::dialect::SqlDialect;
2use crate::executor::Executor;
3use crate::model::Model;
4use futures_util::StreamExt;
5use smallvec::{SmallVec, smallvec};
6use sqlx::{Database, IntoArguments};
7use std::time::{Duration, Instant};
8
9#[doc(hidden)]
10#[derive(Debug, Clone)]
11pub enum BindValue {
12    String(String),
13    I64(i64),
14    F64(f64),
15    Bool(bool),
16    Uuid(uuid::Uuid),
17    DateTime(chrono::DateTime<chrono::Utc>),
18    Null,
19}
20
21impl BindValue {
22    fn to_log_string(&self) -> String {
23        match self {
24            BindValue::String(v) => v.clone(),
25            BindValue::I64(v) => v.to_string(),
26            BindValue::F64(v) => v.to_string(),
27            BindValue::Bool(v) => v.to_string(),
28            BindValue::Uuid(v) => v.to_string(),
29            BindValue::DateTime(v) => v.to_rfc3339(),
30            BindValue::Null => "NULL".to_string(),
31        }
32    }
33}
34
/// Records timing and call-count metrics for one executed query.
///
/// Emits the `premix.query.duration_ms` histogram (elapsed time in milliseconds)
/// and increments the `premix.query.count` counter, both labelled with the
/// operation and table name.
#[cfg(feature = "metrics")]
fn record_query_metrics(operation: &str, table: &str, elapsed: Duration) {
    let labels = [
        ("operation", operation.to_owned()),
        ("table", table.to_owned()),
    ];
    let duration_ms = elapsed.as_secs_f64() * 1000.0;

    metrics::histogram!("premix.query.duration_ms", &labels).record(duration_ms);
    metrics::counter!("premix.query.count", &labels).increment(1);
}
45
/// No-op stand-in used when the `metrics` feature is disabled.
#[cfg(not(feature = "metrics"))]
fn record_query_metrics(_operation: &str, _table: &str, _elapsed: Duration) {
    // Intentionally empty: metrics are compiled out without the feature.
}
48
49#[inline(always)]
50fn bind_value_query<'q, DB>(
51    query: sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>,
52    value: BindValue,
53) -> sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>
54where
55    DB: Database,
56    String: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
57    i64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
58    f64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
59    bool: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
60    uuid::Uuid: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
61    chrono::DateTime<chrono::Utc>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
62    Option<String>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
63{
64    match value {
65        BindValue::String(v) => query.bind(v),
66        BindValue::I64(v) => query.bind(v),
67        BindValue::F64(v) => query.bind(v),
68        BindValue::Bool(v) => query.bind(v),
69        BindValue::Uuid(v) => query.bind(v),
70        BindValue::DateTime(v) => query.bind(v),
71        BindValue::Null => query.bind(Option::<String>::None),
72    }
73}
74
75#[inline(always)]
76fn bind_value_query_as<'q, DB, T>(
77    query: sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>,
78    value: BindValue,
79) -> sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>
80where
81    DB: Database,
82    String: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
83    i64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
84    f64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
85    bool: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
86    uuid::Uuid: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
87    chrono::DateTime<chrono::Utc>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
88    Option<String>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
89{
90    match value {
91        BindValue::String(v) => query.bind(v),
92        BindValue::I64(v) => query.bind(v),
93        BindValue::F64(v) => query.bind(v),
94        BindValue::Bool(v) => query.bind(v),
95        BindValue::Uuid(v) => query.bind(v),
96        BindValue::DateTime(v) => query.bind(v),
97        BindValue::Null => query.bind(Option::<String>::None),
98    }
99}
100
101#[inline]
102fn apply_persistent_query_as<'q, DB, T>(
103    query: sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>,
104    prepared: bool,
105) -> sqlx::query::QueryAs<'q, DB, T, <DB as Database>::Arguments<'q>>
106where
107    DB: Database + sqlx::database::HasStatementCache,
108{
109    if prepared {
110        query.persistent(true)
111    } else {
112        query
113    }
114}
115
116#[inline]
117fn apply_persistent_query<'q, DB>(
118    query: sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>,
119    prepared: bool,
120) -> sqlx::query::Query<'q, DB, <DB as Database>::Arguments<'q>>
121where
122    DB: Database + sqlx::database::HasStatementCache,
123{
124    if prepared {
125        query.persistent(true)
126    } else {
127        query
128    }
129}
130
131impl From<String> for BindValue {
132    fn from(value: String) -> Self {
133        Self::String(value)
134    }
135}
136
137impl From<&str> for BindValue {
138    fn from(value: &str) -> Self {
139        Self::String(value.to_string())
140    }
141}
142
143impl From<i32> for BindValue {
144    fn from(value: i32) -> Self {
145        Self::I64(value as i64)
146    }
147}
148
149impl From<i64> for BindValue {
150    fn from(value: i64) -> Self {
151        Self::I64(value)
152    }
153}
154
155impl From<f64> for BindValue {
156    fn from(value: f64) -> Self {
157        Self::F64(value)
158    }
159}
160
161impl From<bool> for BindValue {
162    fn from(value: bool) -> Self {
163        Self::Bool(value)
164    }
165}
166
167impl From<Option<String>> for BindValue {
168    fn from(value: Option<String>) -> Self {
169        match value {
170            Some(v) => Self::String(v),
171            None => Self::Null,
172        }
173    }
174}
175
176impl From<uuid::Uuid> for BindValue {
177    fn from(value: uuid::Uuid) -> Self {
178        Self::Uuid(value)
179    }
180}
181
182impl From<chrono::DateTime<chrono::Utc>> for BindValue {
183    fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
184        Self::DateTime(value)
185    }
186}
187
188#[derive(Debug, Clone)]
189pub(crate) enum FilterExpr {
190    Raw(String),
191    Compare {
192        column: ColumnRef,
193        op: FilterOp,
194        values: SmallVec<[BindValue; 2]>,
195    },
196    NullCheck {
197        column: ColumnRef,
198        is_null: bool,
199    },
200}
201
/// Comparison operator used by a `Compare` filter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FilterOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Lte,
    /// `>`
    Gt,
    /// `>=`
    Gte,
    /// `LIKE`
    Like,
    /// `IN`
    In,
}
213
214impl FilterOp {
215    fn as_str(self) -> &'static str {
216        match self {
217            FilterOp::Eq => "=",
218            FilterOp::Ne => "!=",
219            FilterOp::Lt => "<",
220            FilterOp::Lte => "<=",
221            FilterOp::Gt => ">",
222            FilterOp::Gte => ">=",
223            FilterOp::Like => "LIKE",
224            FilterOp::In => "IN",
225        }
226    }
227
228    fn is_in(self) -> bool {
229        matches!(self, FilterOp::In)
230    }
231}
232
/// Name of a column referenced by a filter.
///
/// The two variants distinguish names available as `'static` literals from
/// names that had to be allocated at runtime.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColumnRef {
    /// Borrowed `'static` column name; holding it needs no allocation.
    Static(&'static str),
    /// Heap-allocated column name built at runtime.
    Owned(String),
}
241
242impl ColumnRef {
243    /// Builds a static column reference without allocation.
244    pub const fn static_str(value: &'static str) -> Self {
245        ColumnRef::Static(value)
246    }
247
248    fn as_str(&self) -> &str {
249        match self {
250            ColumnRef::Static(value) => value,
251            ColumnRef::Owned(value) => value,
252        }
253    }
254}
255
256impl From<&str> for ColumnRef {
257    fn from(value: &str) -> Self {
258        ColumnRef::Owned(value.to_string())
259    }
260}
261
262impl From<String> for ColumnRef {
263    fn from(value: String) -> Self {
264        ColumnRef::Owned(value)
265    }
266}
267
268impl From<&String> for ColumnRef {
269    fn from(value: &String) -> Self {
270        ColumnRef::Owned(value.clone())
271    }
272}
273
274/// A type-safe SQL query builder.
275///
276/// `QueryBuilder` provides a fluent interface for building SELECT, UPDATE, and DELETE queries
277/// with support for filtering, pagination, eager loading, and soft deletes.
278pub struct QueryBuilder<'a, T, DB: Database> {
279    executor: Executor<'a, DB>,
280    filters: Vec<FilterExpr>,
281    limit: Option<i32>,
282    offset: Option<i32>,
283    includes: SmallVec<[String; 2]>,
284    include_deleted: bool,
285    allow_unsafe: bool,
286    has_raw_filter: bool,
287    fast_path: bool,
288    unsafe_fast: bool,
289    ultra_fast: bool,
290    prepared: bool,
291    _marker: std::marker::PhantomData<T>,
292}
293
294impl<'a, T, DB: Database> std::fmt::Debug for QueryBuilder<'a, T, DB> {
295    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
296        f.debug_struct("QueryBuilder")
297            .field("filters", &self.filters)
298            .field("limit", &self.limit)
299            .field("offset", &self.offset)
300            .field("includes", &self.includes)
301            .field("include_deleted", &self.include_deleted)
302            .field("allow_unsafe", &self.allow_unsafe)
303            .field("fast_path", &self.fast_path)
304            .field("unsafe_fast", &self.unsafe_fast)
305            .field("ultra_fast", &self.ultra_fast)
306            .field("prepared", &self.prepared)
307            .finish()
308    }
309}
310
311impl<'a, T, DB> QueryBuilder<'a, T, DB>
312where
313    DB: SqlDialect + sqlx::database::HasStatementCache,
314    T: Model<DB>,
315{
316    /// Creates a new `QueryBuilder` using the provided [`Executor`].
317    pub fn new(executor: Executor<'a, DB>) -> Self {
318        Self {
319            executor,
320            filters: Vec::with_capacity(4), // Pre-allocate for typical queries (1-4 filters)
321            limit: None,
322            offset: None,
323            includes: SmallVec::with_capacity(2), // Pre-allocate for typical queries (1-2 includes)
324            include_deleted: false,
325            allow_unsafe: false,
326            has_raw_filter: false,
327            fast_path: false,
328            unsafe_fast: false,
329            ultra_fast: false,
330            prepared: true,
331            _marker: std::marker::PhantomData,
332        }
333    }
334
335    /// Adds a raw SQL filter condition to the query.
336    ///
337    /// # Safety
338    /// This method is potentially unsafe and requires calling [`allow_unsafe`] for the query to execute.
339    pub fn filter(mut self, condition: impl Into<String>) -> Self {
340        self.filters.push(FilterExpr::Raw(condition.into()));
341        self.has_raw_filter = true;
342        self
343    }
344
345    /// Adds a raw SQL filter condition to the query.
346    pub fn filter_raw(self, condition: impl Into<String>) -> Self {
347        self.filter(condition)
348    }
349
350    /// Adds an equality filter (`column = value`).
351    pub fn filter_eq(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
352        self.filters.push(FilterExpr::Compare {
353            column: column.into(),
354            op: FilterOp::Eq,
355            values: smallvec![value.into()],
356        });
357        self
358    }
359
360    /// Adds a not-equal filter (`column != value`).
361    pub fn filter_ne(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
362        self.filters.push(FilterExpr::Compare {
363            column: column.into(),
364            op: FilterOp::Ne,
365            values: smallvec![value.into()],
366        });
367        self
368    }
369
370    /// Adds a less-than filter (`column < value`).
371    pub fn filter_lt(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
372        self.filters.push(FilterExpr::Compare {
373            column: column.into(),
374            op: FilterOp::Lt,
375            values: smallvec![value.into()],
376        });
377        self
378    }
379
380    /// Adds a less-than-or-equal filter (`column <= value`).
381    pub fn filter_lte(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
382        self.filters.push(FilterExpr::Compare {
383            column: column.into(),
384            op: FilterOp::Lte,
385            values: smallvec![value.into()],
386        });
387        self
388    }
389
390    /// Adds a greater-than filter (`column > value`).
391    pub fn filter_gt(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
392        self.filters.push(FilterExpr::Compare {
393            column: column.into(),
394            op: FilterOp::Gt,
395            values: smallvec![value.into()],
396        });
397        self
398    }
399
400    /// Adds a greater-than-or-equal filter (`column >= value`).
401    pub fn filter_gte(mut self, column: impl Into<ColumnRef>, value: impl Into<BindValue>) -> Self {
402        self.filters.push(FilterExpr::Compare {
403            column: column.into(),
404            op: FilterOp::Gte,
405            values: smallvec![value.into()],
406        });
407        self
408    }
409
410    /// Adds a LIKE filter (`column LIKE value`).
411    pub fn filter_like(
412        mut self,
413        column: impl Into<ColumnRef>,
414        value: impl Into<BindValue>,
415    ) -> Self {
416        self.filters.push(FilterExpr::Compare {
417            column: column.into(),
418            op: FilterOp::Like,
419            values: smallvec![value.into()],
420        });
421        self
422    }
423
424    /// Filters rows where the column IS NULL.
425    pub fn filter_is_null(mut self, column: impl Into<ColumnRef>) -> Self {
426        self.filters.push(FilterExpr::NullCheck {
427            column: column.into(),
428            is_null: true,
429        });
430        self
431    }
432
433    /// Filters rows where the column IS NOT NULL.
434    pub fn filter_is_not_null(mut self, column: impl Into<ColumnRef>) -> Self {
435        self.filters.push(FilterExpr::NullCheck {
436            column: column.into(),
437            is_null: false,
438        });
439        self
440    }
441
442    /// Adds an IN filter (`column IN (values...)`).
443    pub fn filter_in<I, V>(mut self, column: impl Into<ColumnRef>, values: I) -> Self
444    where
445        I: IntoIterator<Item = V>,
446        V: Into<BindValue>,
447    {
448        let values: SmallVec<[BindValue; 2]> = values.into_iter().map(Into::into).collect();
449        self.filters.push(FilterExpr::Compare {
450            column: column.into(),
451            op: FilterOp::In,
452            values,
453        });
454        self
455    }
456
457    fn format_filters_for_log(&self) -> String {
458        let sensitive_fields = T::sensitive_fields();
459        let mut rendered = String::with_capacity(128);
460        let mut first_clause = true;
461        let mut append_and = |buf: &mut String| {
462            if first_clause {
463                first_clause = false;
464            } else {
465                buf.push_str(" AND ");
466            }
467        };
468        use std::fmt::Write;
469
470        for filter in &self.filters {
471            match filter {
472                FilterExpr::Raw(_) => {
473                    append_and(&mut rendered);
474                    rendered.push_str("RAW(<redacted>)");
475                }
476                FilterExpr::Compare { column, op, values } => {
477                    let column_name = column.as_str();
478                    let is_sensitive = sensitive_fields.contains(&column_name);
479                    if op.is_in() {
480                        if values.is_empty() {
481                            append_and(&mut rendered);
482                            rendered.push_str("1=0");
483                            continue;
484                        }
485                        append_and(&mut rendered);
486                        let _ = write!(rendered, "{} IN (", column_name);
487                        for (idx, value) in values.iter().enumerate() {
488                            if idx > 0 {
489                                rendered.push_str(", ");
490                            }
491                            if is_sensitive {
492                                rendered.push_str("***");
493                            } else {
494                                rendered.push_str(&value.to_log_string());
495                            }
496                        }
497                        rendered.push(')');
498                    } else {
499                        append_and(&mut rendered);
500                        let _ = write!(rendered, "{} {} ", column_name, op.as_str());
501                        if is_sensitive {
502                            rendered.push_str("***");
503                        } else if let Some(value) = values.first() {
504                            rendered.push_str(&value.to_log_string());
505                        } else {
506                            rendered.push_str("NULL");
507                        }
508                    }
509                }
510                FilterExpr::NullCheck { column, is_null } => {
511                    append_and(&mut rendered);
512                    if *is_null {
513                        let _ = write!(rendered, "{} IS NULL", column.as_str());
514                    } else {
515                        let _ = write!(rendered, "{} IS NOT NULL", column.as_str());
516                    }
517                }
518            }
519        }
520
521        if T::has_soft_delete() && !self.include_deleted {
522            append_and(&mut rendered);
523            rendered.push_str("deleted_at IS NULL");
524        }
525
526        rendered
527    }
528
529    fn estimate_bind_count(&self) -> usize {
530        let mut count = 0usize;
531        for filter in &self.filters {
532            if let FilterExpr::Compare { op, values, .. } = filter {
533                if op.is_in() {
534                    count = count.saturating_add(values.len());
535                } else {
536                    count = count.saturating_add(1);
537                }
538            }
539        }
540        count
541    }
542
543    /// Limits the number of rows returned by the query.
544    /// Sets the maximum number of rows to return.
545    pub fn limit(mut self, limit: i32) -> Self {
546        self.limit = Some(limit);
547        self
548    }
549
550    /// Skips the specified number of rows.
551    /// Sets the number of rows to skip.
552    pub fn offset(mut self, offset: i32) -> Self {
553        self.offset = Some(offset);
554        self
555    }
556
557    /// Eager loads a related model.
558    pub fn include(mut self, relation: impl Into<String>) -> Self {
559        self.includes.push(relation.into());
560        self
561    }
562
563    /// Includes soft-deleted records in the results.
564    pub fn with_deleted(mut self) -> Self {
565        self.include_deleted = true;
566        self
567    }
568
569    /// Explicitly allows potentially unsafe raw filters.
570    /// Enables execution of queries with raw SQL filters.
571    pub fn allow_unsafe(mut self) -> Self {
572        self.allow_unsafe = true;
573        self
574    }
575
576    /// Enables a fast path that skips logging and metrics for hot queries.
577    pub fn fast(mut self) -> Self {
578        self.fast_path = true;
579        self
580    }
581
582    /// Enables an unsafe fast path that skips logging, metrics, and safety guards.
583    pub fn unsafe_fast(mut self) -> Self {
584        self.fast_path = true;
585        self.unsafe_fast = true;
586        self.allow_unsafe = true;
587        self
588    }
589
590    /// Enables the ultra-fast path: skips logging, metrics, safety guards, and eager loading.
591    /// Note: Any configured includes will be ignored.
592    pub fn ultra_fast(mut self) -> Self {
593        self.fast_path = true;
594        self.unsafe_fast = true;
595        self.allow_unsafe = true;
596        self.ultra_fast = true;
597        self
598    }
599
600    /// Enable prepared statement caching for this query (default: enabled).
601    pub fn prepared(mut self) -> Self {
602        self.prepared = true;
603        self
604    }
605
606    /// Disable prepared statement caching for this query.
607    pub fn unprepared(mut self) -> Self {
608        self.prepared = false;
609        self
610    }
611
612    /// Returns the SELECT SQL that would be executed for this query.
613    pub fn to_sql(&self) -> String {
614        let mut sql = String::with_capacity(128); // Pre-allocate reasonable size
615        use std::fmt::Write;
616
617        sql.push_str("SELECT * FROM ");
618        sql.push_str(T::table_name());
619
620        let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
621        self.render_where_clause_into(&mut sql, &mut dummy_binds, 1);
622
623        if let Some(limit) = self.limit {
624            let _ = write!(sql, " LIMIT {}", limit);
625        }
626
627        if let Some(offset) = self.offset {
628            let _ = write!(sql, " OFFSET {}", offset);
629        }
630
631        sql
632    }
633
634    /// Returns the UPDATE SQL that would be executed for this query.
635    pub fn to_update_sql(&self, values: &serde_json::Value) -> Result<String, sqlx::Error> {
636        let obj = values.as_object().ok_or_else(|| {
637            sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
638        })?;
639
640        let mut sql = String::with_capacity(256);
641        use std::fmt::Write;
642
643        let _ = write!(sql, "UPDATE {} SET ", T::table_name());
644
645        let mut i = 1;
646        let mut first = true;
647
648        for k in obj.keys() {
649            if !first {
650                sql.push_str(", ");
651            }
652            let p = DB::placeholder(i);
653            let _ = write!(sql, "{} = {}", DB::quote_identifier(k), p);
654            i += 1;
655            first = false;
656        }
657
658        let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
659        self.render_where_clause_into(&mut sql, &mut dummy_binds, obj.len() + 1);
660        Ok(sql)
661    }
662
663    /// Returns the DELETE (or soft delete) SQL that would be executed for this query.
664    pub fn to_delete_sql(&self) -> String {
665        let mut sql = String::with_capacity(128);
666        use std::fmt::Write;
667
668        if T::has_soft_delete() {
669            let _ = write!(
670                sql,
671                "UPDATE {} SET {} = {}",
672                T::table_name(),
673                DB::quote_identifier("deleted_at"),
674                DB::current_timestamp_fn()
675            );
676        } else {
677            let _ = write!(sql, "DELETE FROM {}", T::table_name());
678        }
679
680        let mut dummy_binds: SmallVec<[BindValue; 8]> = SmallVec::new();
681        self.render_where_clause_into(&mut sql, &mut dummy_binds, 1);
682        sql
683    }
684
685    // Optimized version that writes to buffer
686    #[inline(always)]
687    fn render_where_clause_into(
688        &self,
689        sql: &mut String,
690        binds: &mut SmallVec<[BindValue; 8]>,
691        start_index: usize,
692    ) {
693        let mut idx = start_index;
694        let mut first_clause = true;
695        use std::fmt::Write;
696
697        // Helper to handle AND prefix
698        let mut append_and = |sql: &mut String| {
699            if first_clause {
700                sql.push_str(" WHERE ");
701                first_clause = false;
702            } else {
703                sql.push_str(" AND ");
704            }
705        };
706
707        for filter in &self.filters {
708            match filter {
709                FilterExpr::Raw(condition) => {
710                    append_and(sql);
711                    sql.push_str(condition);
712                }
713                FilterExpr::Compare { column, op, values } => {
714                    if op.is_in() {
715                        if values.is_empty() {
716                            append_and(sql);
717                            sql.push_str("1=0");
718                            continue;
719                        }
720                        append_and(sql);
721                        let _ = write!(sql, "{} IN (", DB::quote_identifier(column.as_str()));
722                        let placeholders = crate::cached_placeholders_from::<DB>(idx, values.len());
723                        sql.push_str(placeholders);
724                        sql.push(')');
725                        idx = idx.saturating_add(values.len());
726                        for v in values {
727                            binds.push(v.clone());
728                        }
729                    } else {
730                        append_and(sql);
731                        let _ = write!(
732                            sql,
733                            "{} {} {}",
734                            DB::quote_identifier(column.as_str()),
735                            op.as_str(),
736                            DB::placeholder(idx)
737                        );
738                        idx += 1;
739                        if let Some(v) = values.first() {
740                            binds.push(v.clone());
741                        }
742                    }
743                }
744                FilterExpr::NullCheck { column, is_null } => {
745                    append_and(sql);
746                    if *is_null {
747                        let _ = write!(sql, "{} IS NULL", DB::quote_identifier(column.as_str()));
748                    } else {
749                        let _ =
750                            write!(sql, "{} IS NOT NULL", DB::quote_identifier(column.as_str()));
751                    }
752                }
753            }
754        }
755
756        if T::has_soft_delete() && !self.include_deleted {
757            append_and(sql);
758            sql.push_str("deleted_at IS NULL");
759        }
760    }
761}
762
763impl<'a, T, DB> QueryBuilder<'a, T, DB>
764where
765    DB: SqlDialect,
766    T: Model<DB>,
767    for<'q> <DB as Database>::Arguments<'q>: IntoArguments<'q, DB>,
768    for<'c> &'c mut <DB as Database>::Connection: sqlx::Executor<'c, Database = DB>,
769    for<'c> &'c str: sqlx::ColumnIndex<DB::Row>,
770    DB::Connection: Send,
771    T: Send,
772    String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
773    i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
774    f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
775    bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
776    Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
777    uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
778    chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
779{
780    fn ensure_safe_filters(&self) -> Result<(), sqlx::Error> {
781        if self.unsafe_fast {
782            return Ok(());
783        }
784        if self.has_raw_filter && !self.allow_unsafe {
785            return Err(sqlx::Error::Protocol(
786                "Refusing raw filter without allow_unsafe".to_string(),
787            ));
788        }
789        Ok(())
790    }
791
792    /// Executes the query and returns a vector of results.
793    ///
794    /// This method will fetch all rows matching the criteria and then perform
795    /// eager loading for any included relations.
796    #[tracing::instrument(skip(self), fields(table = T::table_name()))]
797    pub async fn all(mut self) -> Result<Vec<T>, sqlx::Error> {
798        self.ensure_safe_filters()?;
799
800        let mut sql = String::with_capacity(128);
801        sql.push_str("SELECT * FROM ");
802        sql.push_str(T::table_name());
803
804        let mut where_binds: SmallVec<[BindValue; 8]> =
805            SmallVec::with_capacity(self.estimate_bind_count());
806        self.render_where_clause_into(&mut sql, &mut where_binds, 1);
807
808        if let Some(limit) = self.limit {
809            use std::fmt::Write;
810            let _ = write!(sql, " LIMIT {}", limit);
811        }
812
813        if let Some(offset) = self.offset {
814            use std::fmt::Write;
815            let _ = write!(sql, " OFFSET {}", offset);
816        }
817
818        #[cfg(debug_assertions)]
819        if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
820            let filters = self.format_filters_for_log();
821            tracing::debug!(
822                operation = "select",
823                sql = %sql,
824                filters = %filters,
825                "premix query"
826            );
827        }
828
829        let start = Instant::now();
830        let mut results: Vec<T> = match &mut self.executor {
831            Executor::Pool(pool) => {
832                let base = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
833                let query = where_binds.into_iter().fold(base, bind_value_query_as);
834                query.fetch_all(*pool).await?
835            }
836            Executor::Conn(conn) => {
837                let base = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
838                let query = where_binds.into_iter().fold(base, bind_value_query_as);
839                query.fetch_all(&mut **conn).await?
840            }
841        };
842        if !self.fast_path {
843            record_query_metrics("select", T::table_name(), start.elapsed());
844        }
845
846        if self.ultra_fast {
847            return Ok(results);
848        }
849
850        for relation in self.includes {
851            match &mut self.executor {
852                Executor::Pool(pool) => {
853                    T::eager_load(&mut results, &relation, Executor::Pool(*pool)).await?;
854                }
855                Executor::Conn(conn) => {
856                    T::eager_load(&mut results, &relation, Executor::Conn(&mut **conn)).await?;
857                }
858            }
859        }
860
861        Ok(results)
862    }
863
    /// Executes the query and returns a stream of results.
    ///
    /// This is useful for processing large result sets without loading them all into memory.
    ///
    /// The same SELECT text as `all` is built (filters, `LIMIT`, `OFFSET`), but
    /// this method does not call `T::eager_load` and does not record query
    /// metrics, so relations registered with `include` are not loaded here.
    ///
    /// # Errors
    /// Returns an error immediately when a raw filter is present without
    /// `allow_unsafe`. Errors produced while fetching are yielded as items of
    /// the returned stream.
    #[tracing::instrument(skip(self), fields(table = T::table_name()))]
    pub fn stream(
        self,
    ) -> Result<futures_util::stream::BoxStream<'a, Result<T, sqlx::Error>>, sqlx::Error>
    where
        T: 'a,
    {
        // Reject raw filters up front, before any stream is constructed.
        self.ensure_safe_filters()?;

        let mut sql = String::with_capacity(128);
        sql.push_str("SELECT * FROM ");
        sql.push_str(T::table_name());

        // Render the WHERE clause and collect its bind values, numbering
        // placeholders from 1.
        let mut where_binds: SmallVec<[BindValue; 8]> =
            SmallVec::with_capacity(self.estimate_bind_count());
        self.render_where_clause_into(&mut sql, &mut where_binds, 1);

        if let Some(limit) = self.limit {
            use std::fmt::Write;
            let _ = write!(sql, " LIMIT {}", limit);
        }

        if let Some(offset) = self.offset {
            use std::fmt::Write;
            let _ = write!(sql, " OFFSET {}", offset);
        }

        // Query logging is compiled into debug builds only and is skipped on
        // the fast path.
        #[cfg(debug_assertions)]
        if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
            let filters = self.format_filters_for_log();
            tracing::debug!(
                operation = "stream",
                sql = %sql,
                filters = %filters,
                "premix query"
            );
        }

        // Move the executor out of the builder so the stream body can own it.
        let executor = self.executor;
        Ok(Box::pin(async_stream::try_stream! {
            // The SQL text and bind values are used inside the stream body:
            // the query is built and its values bound here, then rows are
            // forwarded from the executor one at a time.
            let mut query = apply_persistent_query_as(sqlx::query_as::<DB, T>(&sql), self.prepared);
            for bind in where_binds {
                query = bind_value_query_as(query, bind);
            }
            let mut s = executor.fetch_stream(query);
            while let Some(row) = s.next().await {
                // A fetch error ends the stream via `?`.
                yield row?;
            }
        }))
    }
917
918    /// Executes a bulk update based on the current filters.
919    ///
920    /// # Errors
921    /// Returns an error if no filters are provided (unless `allow_unsafe` is used),
922    /// or if the values cannot be mapped to the database.
923    #[inline(never)]
924    #[tracing::instrument(skip(self, values), fields(table = T::table_name()))]
925    pub async fn update(mut self, values: serde_json::Value) -> Result<u64, sqlx::Error>
926    where
927        String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
928        i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
929        f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
930        bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
931        Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
932        uuid::Uuid: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
933        chrono::DateTime<chrono::Utc>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
934    {
935        self.ensure_safe_filters()?;
936        if self.filters.is_empty() && !self.allow_unsafe && !self.unsafe_fast {
937            return Err(sqlx::Error::Protocol(
938                "Refusing bulk update without filters".to_string(),
939            ));
940        }
941        let obj = values.as_object().ok_or_else(|| {
942            sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
943        })?;
944
945        let mut sql = String::with_capacity(256);
946        use std::fmt::Write;
947        let _ = write!(sql, "UPDATE {} SET ", T::table_name());
948
949        let mut i = 1;
950        let mut first = true;
951        for k in obj.keys() {
952            if !first {
953                sql.push_str(", ");
954            }
955            let p = DB::placeholder(i);
956            let _ = write!(sql, "{} = {}", DB::quote_identifier(k), p);
957            i += 1;
958            first = false;
959        }
960
961        let mut where_binds: SmallVec<[BindValue; 8]> =
962            SmallVec::with_capacity(self.estimate_bind_count());
963        self.render_where_clause_into(&mut sql, &mut where_binds, obj.len() + 1);
964
965        if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
966            let filters = self.format_filters_for_log();
967            tracing::debug!(
968                operation = "bulk_update",
969                sql = %sql,
970                filters = %filters,
971                "premix query"
972            );
973        }
974        let mut query = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
975        for val in obj.values() {
976            match val {
977                serde_json::Value::String(s) => query = query.bind(s.clone()),
978                serde_json::Value::Number(n) => {
979                    if let Some(v) = n.as_i64() {
980                        query = query.bind(v);
981                    } else if let Some(v) = n.as_f64() {
982                        query = query.bind(v);
983                    }
984                }
985                serde_json::Value::Bool(b) => query = query.bind(*b),
986                serde_json::Value::Null => query = query.bind(Option::<String>::None),
987                _ => {
988                    return Err(sqlx::Error::Protocol(
989                        "Unsupported type in bulk update".to_string(),
990                    ));
991                }
992            }
993        }
994        for bind in where_binds {
995            query = bind_value_query(query, bind);
996        }
997
998        let start = Instant::now();
999        let result = match &mut self.executor {
1000            Executor::Pool(pool) => {
1001                let res = query.execute(*pool).await?;
1002                Ok(DB::rows_affected(&res))
1003            }
1004            Executor::Conn(conn) => {
1005                let res = query.execute(&mut **conn).await?;
1006                Ok(DB::rows_affected(&res))
1007            }
1008        };
1009        if !self.fast_path {
1010            record_query_metrics("bulk_update", T::table_name(), start.elapsed());
1011        }
1012        result
1013    }
1014
1015    /// Executes a bulk update based on the current filters. Alias for [`update`].
1016    pub async fn update_all(self, values: serde_json::Value) -> Result<u64, sqlx::Error>
1017    where
1018        String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1019        i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1020        f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1021        bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1022        Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
1023    {
1024        self.update(values).await
1025    }
1026
1027    /// Executes a bulk delete based on the current filters.
1028    #[tracing::instrument(skip(self), fields(table = T::table_name()))]
1029    pub async fn delete(mut self) -> Result<u64, sqlx::Error> {
1030        self.ensure_safe_filters()?;
1031        if self.filters.is_empty() && !self.allow_unsafe && !self.unsafe_fast {
1032            return Err(sqlx::Error::Protocol(
1033                "Refusing bulk delete without filters".to_string(),
1034            ));
1035        }
1036
1037        let mut sql = String::with_capacity(128);
1038        use std::fmt::Write;
1039
1040        if T::has_soft_delete() {
1041            let _ = write!(
1042                sql,
1043                "UPDATE {} SET {} = {}",
1044                T::table_name(),
1045                DB::quote_identifier("deleted_at"),
1046                DB::current_timestamp_fn()
1047            );
1048        } else {
1049            let _ = write!(sql, "DELETE FROM {}", T::table_name());
1050        }
1051
1052        let mut where_binds: SmallVec<[BindValue; 8]> =
1053            SmallVec::with_capacity(self.estimate_bind_count());
1054        self.render_where_clause_into(&mut sql, &mut where_binds, 1);
1055
1056        if !self.fast_path && tracing::enabled!(tracing::Level::DEBUG) {
1057            let filters = self.format_filters_for_log();
1058            tracing::debug!(
1059                operation = "bulk_delete",
1060                sql = %sql,
1061                filters = %filters,
1062                "premix query"
1063            );
1064        }
1065        let start = Instant::now();
1066        let result = match &mut self.executor {
1067            Executor::Pool(pool) => {
1068                let base = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1069                let query = where_binds.into_iter().fold(base, bind_value_query);
1070                let res = query.execute(*pool).await?;
1071                Ok(DB::rows_affected(&res))
1072            }
1073            Executor::Conn(conn) => {
1074                let base = apply_persistent_query(sqlx::query::<DB>(&sql), self.prepared);
1075                let query = where_binds.into_iter().fold(base, bind_value_query);
1076                let res = query.execute(&mut **conn).await?;
1077                Ok(DB::rows_affected(&res))
1078            }
1079        };
1080        if !self.fast_path {
1081            record_query_metrics("bulk_delete", T::table_name(), start.elapsed());
1082        }
1083        result
1084    }
1085
1086    /// Executes a bulk delete based on the current filters. Alias for [`delete`].
1087    pub async fn delete_all(self) -> Result<u64, sqlx::Error> {
1088        self.delete().await
1089    }
1090}
1091
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::Sqlite;

    /// Minimal model bound to the `users` table; every persistence method is a no-op.
    struct DummyModel;

    impl Model<Sqlite> for DummyModel {
        fn table_name() -> &'static str {
            "users"
        }
        fn create_table_sql() -> String {
            String::new()
        }
        fn list_columns() -> Vec<String> {
            vec!["id".to_string()]
        }
        async fn save<'a, E>(&'a mut self, _e: E) -> Result<(), sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(())
        }
        async fn update<'a, E>(
            &'a mut self,
            _e: E,
        ) -> Result<crate::model::UpdateResult, sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(crate::model::UpdateResult::Success)
        }
        async fn delete<'a, E>(&'a mut self, _e: E) -> Result<(), sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(())
        }
        fn has_soft_delete() -> bool {
            false
        }
        async fn find_by_id<'a, E>(_e: E, _id: i32) -> Result<Option<Self>, sqlx::Error>
        where
            E: crate::executor::IntoExecutor<'a, DB = Sqlite>,
        {
            Ok(None)
        }
    }

    // Row decoding ignores the row contents and always yields a `DummyModel`.
    impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for DummyModel {
        fn from_row(_row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
            Ok(DummyModel)
        }
    }

    /// A hostile column name passed to `filter_eq` must come out quoted in the SELECT.
    #[tokio::test]
    async fn test_sql_injection_mitigation() {
        let pool = sqlx::Pool::<Sqlite>::connect_lazy("sqlite::memory:").unwrap();

        // Filter on a column name that carries an injection payload.
        let builder = DummyModel::find_in_pool(&pool).filter_eq("id; DROP TABLE users; --", 1);
        let rendered = builder.to_sql();
        println!("SQL select: {}", rendered);

        // The payload must appear only as a quoted identifier.
        assert!(rendered.contains("`id; DROP TABLE users; --` = ?"));
        assert!(rendered.contains("SELECT * FROM users WHERE"));
    }

    /// A hostile key in the update payload must come out quoted in the UPDATE.
    #[tokio::test]
    async fn test_to_update_sql_quoting() {
        let pool = sqlx::Pool::<Sqlite>::connect_lazy("sqlite::memory:").unwrap();
        let builder = DummyModel::find_in_pool(&pool).filter_eq("id", 1);

        let payload = serde_json::json!({
            "name; DROP TABLE users; --": "admin"
        });

        let rendered = builder.to_update_sql(&payload).unwrap();
        println!("SQL update: {}", rendered);
        assert!(rendered.contains("`name; DROP TABLE users; --` = ?"));
    }

    /// Streaming an unfiltered query over a two-row table yields two items.
    #[tokio::test]
    async fn test_stream_api() {
        use sqlx::Connection;

        let mut conn = sqlx::SqliteConnection::connect("sqlite::memory:")
            .await
            .unwrap();

        // Fixture: create the table, then insert two rows.
        for setup in [
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)",
            "INSERT INTO users (id, name) VALUES (1, 'alice'), (2, 'bob')",
        ] {
            sqlx::query(setup).execute(&mut conn).await.unwrap();
        }

        // Build the query on the connection via `find_in_tx`.
        let builder = DummyModel::find_in_tx(&mut conn);

        let mut rows = builder.stream().unwrap();
        let mut seen = 0;
        while let Some(item) = rows.next().await {
            let _: DummyModel = item.unwrap();
            seen += 1;
        }
        assert_eq!(seen, 2);
    }
}