prax_query/
raw.rs

1//! Raw SQL query execution with type-safe parameter interpolation.
2//!
3//! This module provides a safe way to execute raw SQL queries while still
4//! benefiting from parameterized queries to prevent SQL injection.
5//!
6//! # Creating SQL Queries
7//!
8//! ```rust
9//! use prax_query::Sql;
10//!
11//! // Simple query
12//! let sql = Sql::new("SELECT * FROM users");
13//! assert_eq!(sql.sql(), "SELECT * FROM users");
14//!
15//! // Query with parameters (binding appends placeholder)
16//! let sql = Sql::new("SELECT * FROM users WHERE id = ")
17//!     .bind(42);
18//! assert_eq!(sql.params().len(), 1);
19//! ```
20//!
21//! # Using the raw_query! Macro
22//!
23//! ```rust
24//! use prax_query::raw_query;
25//!
26//! // Simple query
27//! let sql = raw_query!("SELECT 1");
28//!
29//! // Query with one parameter - {} is replaced with $N placeholder
30//! let id = 42;
31//! let sql = raw_query!("SELECT * FROM users WHERE id = {}", id);
32//! assert_eq!(sql.params().len(), 1);
33//! assert!(sql.sql().contains("$1"));
34//!
35//! // Query with multiple parameters
36//! let name = "John";
37//! let age = 25;
38//! let sql = raw_query!("SELECT * FROM users WHERE name = {} AND age > {}", name, age);
39//! assert_eq!(sql.params().len(), 2);
40//! ```
41//!
42//! # Building Queries Incrementally
43//!
44//! ```rust
45//! use prax_query::Sql;
46//!
47//! // Join multiple conditions
48//! let conditions = vec!["active = true", "verified = true"];
49//! let sql = Sql::new("SELECT * FROM users WHERE ")
50//!     .push(conditions.join(" AND "));
51//!
52//! assert!(sql.sql().contains("active = true AND verified = true"));
53//! ```
54//!
55//! # Safety
56//!
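//! All values passed as arguments to `raw_query!` (or via `Sql::bind`) are sent as
//! bound parameters and never interpolated into the SQL string, preventing SQL
//! injection through those values. Text passed to `Sql::new`, `Sql::push`, or used as
//! the `raw_query!` template itself is inserted verbatim, so it must never be built
//! from untrusted input.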
59//!
60//! ```rust
61//! use prax_query::raw_query;
62//!
63//! // This malicious input will NOT cause SQL injection
64//! let malicious = "'; DROP TABLE users; --";
65//! let sql = raw_query!("SELECT * FROM users WHERE name = {}", malicious);
66//!
67//! // The malicious string is safely bound as a parameter
68//! assert_eq!(sql.params().len(), 1);
69//! // The SQL itself doesn't contain the malicious text
70//! assert!(!sql.sql().contains("DROP TABLE"));
71//! ```
72
73use std::marker::PhantomData;
74use tracing::debug;
75
76use crate::error::QueryResult;
77use crate::filter::FilterValue;
78use crate::sql::DatabaseType;
79use crate::traits::{Model, QueryEngine};
80
81/// A raw SQL query with parameterized values.
82#[derive(Debug, Clone)]
83pub struct Sql {
84    /// The SQL string parts (between parameters).
85    parts: Vec<String>,
86    /// The parameter values.
87    params: Vec<FilterValue>,
88    /// The database type for parameter formatting.
89    db_type: DatabaseType,
90}
91
92impl Sql {
93    /// Create a new raw SQL query.
94    pub fn new(sql: impl Into<String>) -> Self {
95        Self {
96            parts: vec![sql.into()],
97            params: Vec::new(),
98            db_type: DatabaseType::PostgreSQL,
99        }
100    }
101
102    /// Create an empty SQL query.
103    pub fn empty() -> Self {
104        Self {
105            parts: Vec::new(),
106            params: Vec::new(),
107            db_type: DatabaseType::PostgreSQL,
108        }
109    }
110
111    /// Set the database type for parameter formatting.
112    pub fn with_db_type(mut self, db_type: DatabaseType) -> Self {
113        self.db_type = db_type;
114        self
115    }
116
117    /// Append a literal SQL string.
118    pub fn push(mut self, sql: impl Into<String>) -> Self {
119        if let Some(last) = self.parts.last_mut() {
120            last.push_str(&sql.into());
121        } else {
122            self.parts.push(sql.into());
123        }
124        self
125    }
126
127    /// Bind a parameter value.
128    pub fn bind(mut self, value: impl Into<FilterValue>) -> Self {
129        let index = self.params.len() + 1;
130        let placeholder = self.db_type.placeholder(index);
131
132        if let Some(last) = self.parts.last_mut() {
133            // push_str accepts &str, which Cow<str> derefs to
134            last.push_str(&placeholder);
135        } else {
136            // Convert to owned string for storage
137            self.parts.push(placeholder.into_owned());
138        }
139
140        self.params.push(value.into());
141        self
142    }
143
144    /// Bind multiple parameter values at once.
145    pub fn bind_many(mut self, values: impl IntoIterator<Item = FilterValue>) -> Self {
146        for value in values {
147            self = self.bind(value);
148        }
149        self
150    }
151
152    /// Append a conditional clause.
153    pub fn push_if(self, condition: bool, sql: impl Into<String>) -> Self {
154        if condition { self.push(sql) } else { self }
155    }
156
157    /// Bind a parameter conditionally.
158    pub fn bind_if(self, condition: bool, value: impl Into<FilterValue>) -> Self {
159        if condition { self.bind(value) } else { self }
160    }
161
162    /// Push SQL and bind a value together.
163    pub fn push_bind(self, sql: impl Into<String>, value: impl Into<FilterValue>) -> Self {
164        self.push(sql).bind(value)
165    }
166
167    /// Push SQL and bind a value conditionally.
168    pub fn push_bind_if(
169        self,
170        condition: bool,
171        sql: impl Into<String>,
172        value: impl Into<FilterValue>,
173    ) -> Self {
174        if condition {
175            self.push(sql).bind(value)
176        } else {
177            self
178        }
179    }
180
181    /// Add a separator between parts if there are previous parts.
182    pub fn separated(self, separator: &str) -> SeparatedSql {
183        SeparatedSql {
184            sql: self,
185            separator: separator.to_string(),
186            first: true,
187        }
188    }
189
190    /// Build the final SQL string and parameters.
191    pub fn build(self) -> (String, Vec<FilterValue>) {
192        let sql = self.parts.join("");
193        debug!(sql_len = sql.len(), param_count = self.params.len(), db_type = ?self.db_type, "Sql::build()");
194        (sql, self.params)
195    }
196
197    /// Get the SQL string (without consuming).
198    pub fn sql(&self) -> String {
199        self.parts.join("")
200    }
201
202    /// Get the parameters (without consuming).
203    pub fn params(&self) -> &[FilterValue] {
204        &self.params
205    }
206
207    /// Get the number of bound parameters.
208    pub fn param_count(&self) -> usize {
209        self.params.len()
210    }
211
212    /// Check if the query is empty.
213    pub fn is_empty(&self) -> bool {
214        self.parts.is_empty() || self.parts.iter().all(|p| p.is_empty())
215    }
216}
217
218impl Default for Sql {
219    fn default() -> Self {
220        Self::empty()
221    }
222}
223
224impl std::fmt::Display for Sql {
225    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        write!(f, "{}", self.parts.join(""))
227    }
228}
229
230/// A helper for building SQL with separators between items.
231#[derive(Debug, Clone)]
232pub struct SeparatedSql {
233    sql: Sql,
234    separator: String,
235    first: bool,
236}
237
238impl SeparatedSql {
239    /// Push a literal SQL string with separator.
240    pub fn push(mut self, sql: impl Into<String>) -> Self {
241        if !self.first {
242            self.sql = self.sql.push(&self.separator);
243        }
244        self.sql = self.sql.push(sql);
245        self.first = false;
246        self
247    }
248
249    /// Push SQL and bind a value with separator.
250    pub fn push_bind(mut self, sql: impl Into<String>, value: impl Into<FilterValue>) -> Self {
251        if !self.first {
252            self.sql = self.sql.push(&self.separator);
253        }
254        self.sql = self.sql.push(sql).bind(value);
255        self.first = false;
256        self
257    }
258
259    /// Push SQL and bind conditionally with separator.
260    pub fn push_bind_if(
261        mut self,
262        condition: bool,
263        sql: impl Into<String>,
264        value: impl Into<FilterValue>,
265    ) -> Self {
266        if condition {
267            if !self.first {
268                self.sql = self.sql.push(&self.separator);
269            }
270            self.sql = self.sql.push(sql).bind(value);
271            self.first = false;
272        }
273        self
274    }
275
276    /// Finish and get the underlying Sql.
277    pub fn finish(self) -> Sql {
278        self.sql
279    }
280
281    /// Build the final SQL string and parameters.
282    pub fn build(self) -> (String, Vec<FilterValue>) {
283        self.sql.build()
284    }
285}
286
287/// Raw query operation for executing typed queries.
288#[derive(Debug)]
289pub struct RawQueryOperation<M, E>
290where
291    M: Model + Send + 'static,
292    E: QueryEngine,
293{
294    _model: PhantomData<M>,
295    engine: E,
296    sql: Sql,
297}
298
299impl<M, E> RawQueryOperation<M, E>
300where
301    M: Model + Send + 'static,
302    E: QueryEngine,
303{
304    /// Create a new raw query operation.
305    pub fn new(engine: E, sql: Sql) -> Self {
306        Self {
307            _model: PhantomData,
308            engine,
309            sql,
310        }
311    }
312
313    /// Execute the query and return all matching records.
314    pub async fn exec(self) -> QueryResult<Vec<M>> {
315        let (sql, params) = self.sql.build();
316        self.engine.query_many(&sql, params).await
317    }
318
319    /// Execute the query and return a single record.
320    pub async fn exec_one(self) -> QueryResult<M> {
321        let (sql, params) = self.sql.build();
322        self.engine.query_one(&sql, params).await
323    }
324
325    /// Execute the query and return an optional record.
326    pub async fn exec_optional(self) -> QueryResult<Option<M>> {
327        let (sql, params) = self.sql.build();
328        self.engine.query_optional(&sql, params).await
329    }
330}
331
332/// Raw execute operation for mutations.
333#[derive(Debug)]
334pub struct RawExecuteOperation<E>
335where
336    E: QueryEngine,
337{
338    engine: E,
339    sql: Sql,
340}
341
342impl<E> RawExecuteOperation<E>
343where
344    E: QueryEngine,
345{
346    /// Create a new raw execute operation.
347    pub fn new(engine: E, sql: Sql) -> Self {
348        Self { engine, sql }
349    }
350
351    /// Execute the mutation and return the number of affected rows.
352    pub async fn exec(self) -> QueryResult<u64> {
353        let (sql, params) = self.sql.build();
354        self.engine.execute_raw(&sql, params).await
355    }
356}
357
358/// Helper function to create a raw SQL query from a string.
359pub fn sql(query: impl Into<String>) -> Sql {
360    Sql::new(query)
361}
362
363/// Helper function to create a raw SQL query from parts.
364///
365/// This is typically used with the `raw_query!` macro.
366pub fn sql_with_params(sql_str: impl Into<String>, params: Vec<FilterValue>) -> Sql {
367    let mut sql = Sql::new(sql_str);
368    sql.params = params;
369    sql
370}
371
/// Builds a raw SQL query from a template, binding each argument as a parameter.
///
/// Every `{}` in the template is replaced, in order, by a placeholder for the
/// matching argument. Each argument is converted with `Into<FilterValue>` and bound.
/// The query starts from `Sql::empty()`, whose database type is PostgreSQL, so the
/// placeholders are `$1`, `$2`, ... . To get `?` markers for MySQL/SQLite, use the
/// `Sql` builder and call `with_db_type` before binding.
///
/// With only a template and no values, the template is used verbatim and `{}` is
/// not interpreted. Note that every literal `{}` in a template that does have
/// values is treated as a placeholder, including one inside a quoted SQL string.
///
/// # Panics
///
/// Panics if the number of `{}` markers in the template differs from the number
/// of values. Such a mismatch would otherwise yield a silently broken query:
/// either a placeholder vanishes from the SQL, or an argument is dropped.
///
/// # Example
///
/// ```rust,ignore
/// let sql = raw_query!("SELECT * FROM users WHERE id = {} AND active = {}", user_id, true);
/// ```
#[macro_export]
macro_rules! raw_query {
    // Template only: used verbatim, no placeholder substitution.
    ($sql:expr) => {
        $crate::raw::Sql::new($sql)
    };

    // Template plus one or more values.
    ($sql:expr, $($params:expr),+ $(,)?) => {{
        // Borrowing through `&` also extends the lifetime of a temporary template
        // (e.g. `format!(..)`), so the pieces below may borrow from it.
        let template: &str = &$sql;
        let values: Vec<$crate::filter::FilterValue> = vec![
            $($params.into()),+
        ];

        let marker_count = template.matches("{}").count();
        assert_eq!(
            marker_count,
            values.len(),
            "raw_query!: template has {} `{{}}` placeholder(s) but {} value(s) were supplied",
            marker_count,
            values.len(),
        );

        let mut pieces = template.split("{}");
        let mut sql = $crate::raw::Sql::empty();
        for value in values {
            // The count check guarantees one text piece precedes every value.
            if let Some(piece) = pieces.next() {
                if !piece.is_empty() {
                    sql = sql.push(piece);
                }
            }
            sql = sql.bind(value);
        }
        // Text following the last placeholder.
        for piece in pieces {
            if !piece.is_empty() {
                sql = sql.push(piece);
            }
        }

        sql
    }};
}
414
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sql_new() {
        let sql = Sql::new("SELECT * FROM users");
        assert_eq!(sql.sql(), "SELECT * FROM users");
        assert!(sql.params().is_empty());
    }

    #[test]
    fn test_sql_push() {
        let sql = Sql::new("SELECT * FROM users").push(" WHERE id = 1");
        assert_eq!(sql.sql(), "SELECT * FROM users WHERE id = 1");
    }

    #[test]
    fn test_sql_bind() {
        let sql = Sql::new("SELECT * FROM users WHERE id = ").bind(42i32);
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = $1");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_multiple_binds() {
        let sql = Sql::new("SELECT * FROM users WHERE id = ")
            .bind(42i32)
            .push(" AND name = ")
            .bind("John".to_string());
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = $1 AND name = $2");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_sql_push_bind() {
        let sql = Sql::new("SELECT * FROM users WHERE").push_bind(" id = ", 42i32);
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = $1");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_push_if() {
        let include_active = true;
        let include_deleted = false;

        let sql = Sql::new("SELECT * FROM users")
            .push_if(include_active, " WHERE active = true")
            .push_if(include_deleted, " AND deleted = false");

        assert_eq!(sql.sql(), "SELECT * FROM users WHERE active = true");
    }

    /// Exercises `push_bind_if` (kept under its historical name).
    #[test]
    fn test_sql_bind_if() {
        let filter_id = Some(42i32);
        let filter_name: Option<String> = None;

        let sql = Sql::new("SELECT * FROM users WHERE 1=1")
            .push_bind_if(filter_id.is_some(), " AND id = ", filter_id.unwrap_or(0))
            .push_bind_if(filter_name.is_some(), " AND name = ", "".to_string());

        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE 1=1 AND id = $1");
        assert_eq!(params.len(), 1);
    }

    /// Exercises `bind_if` itself: a skipped bind adds neither text nor a value.
    #[test]
    fn test_sql_bind_if_only() {
        let sql = Sql::new("SELECT * FROM users WHERE id = ")
            .bind_if(true, 7i32)
            .push(" AND deleted = false")
            .bind_if(false, 8i32);
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = $1 AND deleted = false");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_separated() {
        let columns = vec!["id", "name", "email"];

        let mut sep = Sql::new("SELECT ").separated(", ");

        for col in columns {
            sep = sep.push(col);
        }

        let sql = sep.finish().push(" FROM users");
        assert_eq!(sql.sql(), "SELECT id, name, email FROM users");
    }

    #[test]
    fn test_sql_separated_with_binds() {
        let filters = vec![("id", 1i32), ("active", 1i32)];

        let mut sep = Sql::new("SELECT * FROM users WHERE ").separated(" AND ");

        for (col, val) in filters {
            sep = sep.push_bind(format!("{} = ", col), val);
        }

        let (query, params) = sep.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = $1 AND active = $2");
        assert_eq!(params.len(), 2);
    }

    /// A skipped conditional item must not emit a separator.
    #[test]
    fn test_separated_push_bind_if_skips_separator() {
        let (query, params) = Sql::new("UPDATE users SET ")
            .separated(", ")
            .push_bind_if(false, "name = ", "x".to_string())
            .push_bind_if(true, "age = ", 30i32)
            .push_bind_if(true, "active = ", true)
            .build();
        assert_eq!(query, "UPDATE users SET age = $1, active = $2");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_sql_mysql() {
        let sql = Sql::new("SELECT * FROM users WHERE id = ")
            .with_db_type(DatabaseType::MySQL)
            .bind(42i32);
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = ?");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_sqlite() {
        let sql = Sql::new("SELECT * FROM users WHERE id = ")
            .with_db_type(DatabaseType::SQLite)
            .bind(42i32);
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = ?");
        assert_eq!(params.len(), 1);
    }

    /// Placeholders are rendered at bind time, so changing the database type only
    /// affects values bound afterwards.
    #[test]
    fn test_with_db_type_applies_to_later_binds_only() {
        let sql = Sql::new("SELECT * FROM users WHERE id = ")
            .bind(1i32)
            .with_db_type(DatabaseType::MySQL)
            .push(" AND age = ")
            .bind(2i32);
        assert_eq!(sql.sql(), "SELECT * FROM users WHERE id = $1 AND age = ?");
    }

    #[test]
    fn test_sql_is_empty() {
        assert!(Sql::empty().is_empty());
        assert!(!Sql::new("SELECT 1").is_empty());
    }

    #[test]
    fn test_sql_is_empty_with_blank_text() {
        assert!(Sql::new("").is_empty());
        assert!(Sql::default().is_empty());
    }

    #[test]
    fn test_sql_display() {
        let sql = Sql::new("SELECT * FROM users WHERE id = ").bind(42i32);
        assert_eq!(format!("{}", sql), "SELECT * FROM users WHERE id = $1");
    }

    #[test]
    fn test_raw_query_macro_no_params() {
        let sql = raw_query!("SELECT * FROM users");
        assert_eq!(sql.sql(), "SELECT * FROM users");
        assert!(sql.params().is_empty());
    }

    #[test]
    fn test_raw_query_macro_with_params() {
        let sql = raw_query!(
            "SELECT * FROM users WHERE id = {} AND active = {}",
            42i32,
            true
        );
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = $1 AND active = $2");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_raw_query_macro_string_params() {
        let name = "John".to_string();
        let sql = raw_query!("SELECT * FROM users WHERE name = {}", name);
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE name = $1");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_raw_query_macro_leading_and_trailing_text() {
        let sql = raw_query!("{} AS value", 1i32);
        assert_eq!(sql.sql(), "$1 AS value");

        let sql = raw_query!("SELECT * FROM users WHERE id = {} LIMIT 1", 5i32);
        assert_eq!(sql.sql(), "SELECT * FROM users WHERE id = $1 LIMIT 1");
    }

    #[test]
    fn test_bind_many() {
        let values: Vec<FilterValue> = vec![
            FilterValue::Int(1),
            FilterValue::Int(2),
            FilterValue::Int(3),
        ];

        let sql = Sql::new("SELECT * FROM users WHERE id IN (")
            .bind_many(values)
            .push(")");

        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id IN ($1$2$3)");
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_build_in_clause() {
        let ids = vec![1, 2, 3];

        let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("${}", i)).collect();

        let sql = Sql::new(format!(
            "SELECT * FROM users WHERE id IN ({})",
            placeholders.join(", ")
        ));

        let params: Vec<FilterValue> = ids.into_iter().map(FilterValue::Int).collect();
        let sql = sql_with_params(sql.sql(), params);

        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id IN ($1, $2, $3)");
        assert_eq!(params.len(), 3);
    }

    /// `bind` after `sql_with_params` numbers its placeholder after the
    /// pre-supplied parameters.
    #[test]
    fn test_sql_with_params_then_bind_continues_numbering() {
        let sql = sql_with_params(
            "SELECT * FROM users WHERE id = $1",
            vec![FilterValue::Int(1)],
        )
        .push(" AND age > ")
        .bind(18i32);
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = $1 AND age > $2");
        assert_eq!(params.len(), 2);
    }
}