Skip to main content

prax_query/
raw.rs

1//! Raw SQL query execution with type-safe parameter interpolation.
2//!
3//! This module provides a safe way to execute raw SQL queries while still
4//! benefiting from parameterized queries to prevent SQL injection.
5//!
6//! # Creating SQL Queries
7//!
8//! ```rust
9//! use prax_query::Sql;
10//!
11//! // Simple query
12//! let sql = Sql::new("SELECT * FROM users");
13//! assert_eq!(sql.sql(), "SELECT * FROM users");
14//!
15//! // Query with parameters (binding appends placeholder)
16//! let sql = Sql::new("SELECT * FROM users WHERE id = ")
17//!     .bind(42);
18//! assert_eq!(sql.params().len(), 1);
19//! ```
20//!
21//! # Using the raw_query! Macro
22//!
23//! ```rust
24//! use prax_query::raw_query;
25//!
26//! // Simple query
27//! let sql = raw_query!("SELECT 1");
28//!
29//! // Query with one parameter - {} is replaced with $N placeholder
30//! let id = 42;
31//! let sql = raw_query!("SELECT * FROM users WHERE id = {}", id);
32//! assert_eq!(sql.params().len(), 1);
33//! assert!(sql.sql().contains("$1"));
34//!
35//! // Query with multiple parameters
36//! let name = "John";
37//! let age = 25;
38//! let sql = raw_query!("SELECT * FROM users WHERE name = {} AND age > {}", name, age);
39//! assert_eq!(sql.params().len(), 2);
40//! ```
41//!
42//! # Building Queries Incrementally
43//!
44//! ```rust
45//! use prax_query::Sql;
46//!
47//! // Join multiple conditions
48//! let conditions = vec!["active = true", "verified = true"];
49//! let sql = Sql::new("SELECT * FROM users WHERE ")
50//!     .push(conditions.join(" AND "));
51//!
52//! assert!(sql.sql().contains("active = true AND verified = true"));
53//! ```
54//!
55//! # Safety
56//!
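//! All values passed via `raw_query!` (or bound with `Sql::bind`) are parameterized and
//! never interpolated directly into the SQL string, preventing SQL injection attacks. The
//! template text itself, and any text appended with `Sql::push`, is used verbatim, so it
//! must never contain untrusted input.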
59//!
60//! ```rust
61//! use prax_query::raw_query;
62//!
63//! // This malicious input will NOT cause SQL injection
64//! let malicious = "'; DROP TABLE users; --";
65//! let sql = raw_query!("SELECT * FROM users WHERE name = {}", malicious);
66//!
67//! // The malicious string is safely bound as a parameter
68//! assert_eq!(sql.params().len(), 1);
69//! // The SQL itself doesn't contain the malicious text
70//! assert!(!sql.sql().contains("DROP TABLE"));
71//! ```
72
73use std::marker::PhantomData;
74use tracing::debug;
75
76use crate::error::QueryResult;
77use crate::filter::FilterValue;
78use crate::sql::DatabaseType;
79use crate::traits::{Model, QueryEngine};
80
81/// A raw SQL query with parameterized values.
82#[derive(Debug, Clone)]
83pub struct Sql {
84    /// The SQL string parts (between parameters).
85    parts: Vec<String>,
86    /// The parameter values.
87    params: Vec<FilterValue>,
88    /// The database type for parameter formatting.
89    db_type: DatabaseType,
90}
91
92impl Sql {
93    /// Create a new raw SQL query.
94    pub fn new(sql: impl Into<String>) -> Self {
95        Self {
96            parts: vec![sql.into()],
97            params: Vec::new(),
98            db_type: DatabaseType::PostgreSQL,
99        }
100    }
101
102    /// Create an empty SQL query.
103    pub fn empty() -> Self {
104        Self {
105            parts: Vec::new(),
106            params: Vec::new(),
107            db_type: DatabaseType::PostgreSQL,
108        }
109    }
110
111    /// Set the database type for parameter formatting.
112    pub fn with_db_type(mut self, db_type: DatabaseType) -> Self {
113        self.db_type = db_type;
114        self
115    }
116
117    /// Append a literal SQL string.
118    pub fn push(mut self, sql: impl Into<String>) -> Self {
119        if let Some(last) = self.parts.last_mut() {
120            last.push_str(&sql.into());
121        } else {
122            self.parts.push(sql.into());
123        }
124        self
125    }
126
127    /// Bind a parameter value.
128    pub fn bind(mut self, value: impl Into<FilterValue>) -> Self {
129        let index = self.params.len() + 1;
130        let placeholder = self.db_type.placeholder(index);
131
132        if let Some(last) = self.parts.last_mut() {
133            // push_str accepts &str, which Cow<str> derefs to
134            last.push_str(&placeholder);
135        } else {
136            // Convert to owned string for storage
137            self.parts.push(placeholder.into_owned());
138        }
139
140        self.params.push(value.into());
141        self
142    }
143
144    /// Bind multiple parameter values at once.
145    pub fn bind_many(mut self, values: impl IntoIterator<Item = FilterValue>) -> Self {
146        for value in values {
147            self = self.bind(value);
148        }
149        self
150    }
151
152    /// Append a conditional clause.
153    pub fn push_if(self, condition: bool, sql: impl Into<String>) -> Self {
154        if condition { self.push(sql) } else { self }
155    }
156
157    /// Bind a parameter conditionally.
158    pub fn bind_if(self, condition: bool, value: impl Into<FilterValue>) -> Self {
159        if condition { self.bind(value) } else { self }
160    }
161
162    /// Push SQL and bind a value together.
163    pub fn push_bind(self, sql: impl Into<String>, value: impl Into<FilterValue>) -> Self {
164        self.push(sql).bind(value)
165    }
166
167    /// Push SQL and bind a value conditionally.
168    pub fn push_bind_if(
169        self,
170        condition: bool,
171        sql: impl Into<String>,
172        value: impl Into<FilterValue>,
173    ) -> Self {
174        if condition {
175            self.push(sql).bind(value)
176        } else {
177            self
178        }
179    }
180
181    /// Add a separator between parts if there are previous parts.
182    pub fn separated(self, separator: &str) -> SeparatedSql {
183        SeparatedSql {
184            sql: self,
185            separator: separator.to_string(),
186            first: true,
187        }
188    }
189
190    /// Build the final SQL string and parameters.
191    pub fn build(self) -> (String, Vec<FilterValue>) {
192        let sql = self.parts.join("");
193        debug!(sql_len = sql.len(), param_count = self.params.len(), db_type = ?self.db_type, "Sql::build()");
194        (sql, self.params)
195    }
196
197    /// Get the SQL string (without consuming).
198    pub fn sql(&self) -> String {
199        self.parts.join("")
200    }
201
202    /// Get the parameters (without consuming).
203    pub fn params(&self) -> &[FilterValue] {
204        &self.params
205    }
206
207    /// Get the number of bound parameters.
208    pub fn param_count(&self) -> usize {
209        self.params.len()
210    }
211
212    /// Check if the query is empty.
213    pub fn is_empty(&self) -> bool {
214        self.parts.is_empty() || self.parts.iter().all(|p| p.is_empty())
215    }
216}
217
218impl Default for Sql {
219    fn default() -> Self {
220        Self::empty()
221    }
222}
223
224impl std::fmt::Display for Sql {
225    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        write!(f, "{}", self.parts.join(""))
227    }
228}
229
230/// A helper for building SQL with separators between items.
231#[derive(Debug, Clone)]
232pub struct SeparatedSql {
233    sql: Sql,
234    separator: String,
235    first: bool,
236}
237
238impl SeparatedSql {
239    /// Push a literal SQL string with separator.
240    pub fn push(mut self, sql: impl Into<String>) -> Self {
241        if !self.first {
242            self.sql = self.sql.push(&self.separator);
243        }
244        self.sql = self.sql.push(sql);
245        self.first = false;
246        self
247    }
248
249    /// Push SQL and bind a value with separator.
250    pub fn push_bind(mut self, sql: impl Into<String>, value: impl Into<FilterValue>) -> Self {
251        if !self.first {
252            self.sql = self.sql.push(&self.separator);
253        }
254        self.sql = self.sql.push(sql).bind(value);
255        self.first = false;
256        self
257    }
258
259    /// Push SQL and bind conditionally with separator.
260    pub fn push_bind_if(
261        mut self,
262        condition: bool,
263        sql: impl Into<String>,
264        value: impl Into<FilterValue>,
265    ) -> Self {
266        if condition {
267            if !self.first {
268                self.sql = self.sql.push(&self.separator);
269            }
270            self.sql = self.sql.push(sql).bind(value);
271            self.first = false;
272        }
273        self
274    }
275
276    /// Finish and get the underlying Sql.
277    pub fn finish(self) -> Sql {
278        self.sql
279    }
280
281    /// Build the final SQL string and parameters.
282    pub fn build(self) -> (String, Vec<FilterValue>) {
283        self.sql.build()
284    }
285}
286
287/// Raw query operation for executing typed queries.
288#[derive(Debug)]
289pub struct RawQueryOperation<M, E>
290where
291    M: Model + Send + 'static,
292    E: QueryEngine,
293{
294    _model: PhantomData<M>,
295    engine: E,
296    sql: Sql,
297}
298
299impl<M, E> RawQueryOperation<M, E>
300where
301    M: Model + crate::row::FromRow + Send + 'static,
302    E: QueryEngine,
303{
304    /// Create a new raw query operation.
305    pub fn new(engine: E, sql: Sql) -> Self {
306        Self {
307            _model: PhantomData,
308            engine,
309            sql,
310        }
311    }
312
313    /// Execute the query and return all matching records.
314    pub async fn exec(self) -> QueryResult<Vec<M>> {
315        let (sql, params) = self.sql.build();
316        self.engine.query_many(&sql, params).await
317    }
318
319    /// Execute the query and return a single record.
320    pub async fn exec_one(self) -> QueryResult<M> {
321        let (sql, params) = self.sql.build();
322        self.engine.query_one(&sql, params).await
323    }
324
325    /// Execute the query and return an optional record.
326    pub async fn exec_optional(self) -> QueryResult<Option<M>> {
327        let (sql, params) = self.sql.build();
328        self.engine.query_optional(&sql, params).await
329    }
330}
331
332/// Raw execute operation for mutations.
333#[derive(Debug)]
334pub struct RawExecuteOperation<E>
335where
336    E: QueryEngine,
337{
338    engine: E,
339    sql: Sql,
340}
341
342impl<E> RawExecuteOperation<E>
343where
344    E: QueryEngine,
345{
346    /// Create a new raw execute operation.
347    pub fn new(engine: E, sql: Sql) -> Self {
348        Self { engine, sql }
349    }
350
351    /// Execute the mutation and return the number of affected rows.
352    pub async fn exec(self) -> QueryResult<u64> {
353        let (sql, params) = self.sql.build();
354        self.engine.execute_raw(&sql, params).await
355    }
356}
357
358/// Helper function to create a raw SQL query from a string.
359pub fn sql(query: impl Into<String>) -> Sql {
360    Sql::new(query)
361}
362
363/// Helper function to create a raw SQL query from parts.
364///
365/// This is typically used with the `raw_query!` macro.
366pub fn sql_with_params(sql_str: impl Into<String>, params: Vec<FilterValue>) -> Sql {
367    let mut sql = Sql::new(sql_str);
368    sql.params = params;
369    sql
370}
371
372/// A macro for creating raw SQL queries with inline parameter binding.
373///
374/// # Example
375///
376/// ```rust,ignore
377/// let sql = raw_query!("SELECT * FROM users WHERE id = {} AND active = {}", user_id, true);
378/// ```
379///
380/// The `{}` placeholders are replaced with database-specific parameter markers ($1, $2, etc.
381/// for PostgreSQL, ? for MySQL/SQLite) and the values are bound as parameters.
382#[macro_export]
383macro_rules! raw_query {
384    // Base case: just a string, no parameters
385    ($sql:expr) => {
386        $crate::raw::Sql::new($sql)
387    };
388
389    // With parameters
390    ($sql:expr, $($params:expr),+ $(,)?) => {{
391        let parts: Vec<&str> = $sql.split("{}").collect();
392        let param_values: Vec<$crate::filter::FilterValue> = vec![
393            $($params.into()),+
394        ];
395
396        let mut sql = $crate::raw::Sql::empty();
397        let mut param_iter = param_values.into_iter();
398
399        // Interleave parts and parameters
400        for (i, part) in parts.iter().enumerate() {
401            if !part.is_empty() {
402                sql = sql.push(*part);
403            }
404            if i < parts.len() - 1 {
405                if let Some(param) = param_iter.next() {
406                    sql = sql.bind(param);
407                }
408            }
409        }
410
411        sql
412    }};
413}
414
415#[cfg(test)]
416mod tests {
417    use super::*;
418
419    #[test]
420    fn test_sql_new() {
421        let sql = Sql::new("SELECT * FROM users");
422        assert_eq!(sql.sql(), "SELECT * FROM users");
423        assert!(sql.params().is_empty());
424    }
425
426    #[test]
427    fn test_sql_push() {
428        let sql = Sql::new("SELECT * FROM users").push(" WHERE id = 1");
429        assert_eq!(sql.sql(), "SELECT * FROM users WHERE id = 1");
430    }
431
432    #[test]
433    fn test_sql_bind() {
434        let sql = Sql::new("SELECT * FROM users WHERE id = ").bind(42i32);
435        let (query, params) = sql.build();
436        assert_eq!(query, "SELECT * FROM users WHERE id = $1");
437        assert_eq!(params.len(), 1);
438    }
439
440    #[test]
441    fn test_sql_multiple_binds() {
442        let sql = Sql::new("SELECT * FROM users WHERE id = ")
443            .bind(42i32)
444            .push(" AND name = ")
445            .bind("John".to_string());
446        let (query, params) = sql.build();
447        assert_eq!(query, "SELECT * FROM users WHERE id = $1 AND name = $2");
448        assert_eq!(params.len(), 2);
449    }
450
451    #[test]
452    fn test_sql_push_bind() {
453        let sql = Sql::new("SELECT * FROM users WHERE").push_bind(" id = ", 42i32);
454        let (query, params) = sql.build();
455        assert_eq!(query, "SELECT * FROM users WHERE id = $1");
456        assert_eq!(params.len(), 1);
457    }
458
459    #[test]
460    fn test_sql_push_if() {
461        let include_active = true;
462        let include_deleted = false;
463
464        let sql = Sql::new("SELECT * FROM users")
465            .push_if(include_active, " WHERE active = true")
466            .push_if(include_deleted, " AND deleted = false");
467
468        assert_eq!(sql.sql(), "SELECT * FROM users WHERE active = true");
469    }
470
471    #[test]
472    #[allow(clippy::unnecessary_literal_unwrap)]
473    fn test_sql_bind_if() {
474        let filter_id = Some(42i32);
475        let filter_name: Option<String> = None;
476
477        let sql = Sql::new("SELECT * FROM users WHERE 1=1")
478            .push_bind_if(filter_id.is_some(), " AND id = ", filter_id.unwrap_or(0))
479            .push_bind_if(filter_name.is_some(), " AND name = ", String::new());
480
481        let (query, params) = sql.build();
482        assert_eq!(query, "SELECT * FROM users WHERE 1=1 AND id = $1");
483        assert_eq!(params.len(), 1);
484    }
485
486    #[test]
487    fn test_sql_separated() {
488        let columns = vec!["id", "name", "email"];
489
490        let mut sep = Sql::new("SELECT ").separated(", ");
491
492        for col in columns {
493            sep = sep.push(col);
494        }
495
496        let sql = sep.finish().push(" FROM users");
497        assert_eq!(sql.sql(), "SELECT id, name, email FROM users");
498    }
499
500    #[test]
501    fn test_sql_separated_with_binds() {
502        let filters = vec![("id", 1i32), ("active", 1i32)];
503
504        let mut sep = Sql::new("SELECT * FROM users WHERE ").separated(" AND ");
505
506        for (col, val) in filters {
507            sep = sep.push_bind(format!("{} = ", col), val);
508        }
509
510        let (query, params) = sep.build();
511        assert_eq!(query, "SELECT * FROM users WHERE id = $1 AND active = $2");
512        assert_eq!(params.len(), 2);
513    }
514
515    #[test]
516    fn test_sql_mysql() {
517        let sql = Sql::new("SELECT * FROM users WHERE id = ")
518            .with_db_type(DatabaseType::MySQL)
519            .bind(42i32);
520        let (query, params) = sql.build();
521        assert_eq!(query, "SELECT * FROM users WHERE id = ?");
522        assert_eq!(params.len(), 1);
523    }
524
525    #[test]
526    fn test_sql_sqlite() {
527        let sql = Sql::new("SELECT * FROM users WHERE id = ")
528            .with_db_type(DatabaseType::SQLite)
529            .bind(42i32);
530        let (query, params) = sql.build();
531        assert_eq!(query, "SELECT * FROM users WHERE id = ?");
532        assert_eq!(params.len(), 1);
533    }
534
535    #[test]
536    fn test_sql_is_empty() {
537        assert!(Sql::empty().is_empty());
538        assert!(!Sql::new("SELECT 1").is_empty());
539    }
540
541    #[test]
542    fn test_sql_display() {
543        let sql = Sql::new("SELECT * FROM users WHERE id = ").bind(42i32);
544        assert_eq!(format!("{}", sql), "SELECT * FROM users WHERE id = $1");
545    }
546
547    #[test]
548    fn test_raw_query_macro_no_params() {
549        let sql = raw_query!("SELECT * FROM users");
550        assert_eq!(sql.sql(), "SELECT * FROM users");
551        assert!(sql.params().is_empty());
552    }
553
554    #[test]
555    fn test_raw_query_macro_with_params() {
556        let sql = raw_query!(
557            "SELECT * FROM users WHERE id = {} AND active = {}",
558            42i32,
559            true
560        );
561        let (query, params) = sql.build();
562        assert_eq!(query, "SELECT * FROM users WHERE id = $1 AND active = $2");
563        assert_eq!(params.len(), 2);
564    }
565
566    #[test]
567    fn test_raw_query_macro_string_params() {
568        let name = "John".to_string();
569        let sql = raw_query!("SELECT * FROM users WHERE name = {}", name);
570        let (query, params) = sql.build();
571        assert_eq!(query, "SELECT * FROM users WHERE name = $1");
572        assert_eq!(params.len(), 1);
573    }
574
575    #[test]
576    fn test_bind_many() {
577        let values: Vec<FilterValue> = vec![
578            FilterValue::Int(1),
579            FilterValue::Int(2),
580            FilterValue::Int(3),
581        ];
582
583        let sql = Sql::new("SELECT * FROM users WHERE id IN (")
584            .bind_many(values)
585            .push(")");
586
587        let (query, params) = sql.build();
588        assert_eq!(query, "SELECT * FROM users WHERE id IN ($1$2$3)");
589        assert_eq!(params.len(), 3);
590    }
591
592    #[test]
593    fn test_build_in_clause() {
594        let ids = vec![1, 2, 3];
595
596        let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("${}", i)).collect();
597
598        let sql = Sql::new(format!(
599            "SELECT * FROM users WHERE id IN ({})",
600            placeholders.join(", ")
601        ));
602
603        let params: Vec<FilterValue> = ids.into_iter().map(FilterValue::Int).collect();
604        let sql = sql_with_params(sql.sql(), params);
605
606        let (query, params) = sql.build();
607        assert_eq!(query, "SELECT * FROM users WHERE id IN ($1, $2, $3)");
608        assert_eq!(params.len(), 3);
609    }
610}