prax_query/
raw.rs

1//! Raw SQL query execution with type-safe parameter interpolation.
2//!
3//! This module provides a safe way to execute raw SQL queries while still
4//! benefiting from parameterized queries to prevent SQL injection.
5//!
6//! # Creating SQL Queries
7//!
8//! ```rust
9//! use prax_query::Sql;
10//!
11//! // Simple query
12//! let sql = Sql::new("SELECT * FROM users");
13//! assert_eq!(sql.sql(), "SELECT * FROM users");
14//!
15//! // Query with parameters (binding appends placeholder)
16//! let sql = Sql::new("SELECT * FROM users WHERE id = ")
17//!     .bind(42);
18//! assert_eq!(sql.params().len(), 1);
19//! ```
20//!
21//! # Using the raw_query! Macro
22//!
23//! ```rust
24//! use prax_query::raw_query;
25//!
26//! // Simple query
27//! let sql = raw_query!("SELECT 1");
28//!
29//! // Query with one parameter - {} is replaced with $N placeholder
30//! let id = 42;
31//! let sql = raw_query!("SELECT * FROM users WHERE id = {}", id);
32//! assert_eq!(sql.params().len(), 1);
33//! assert!(sql.sql().contains("$1"));
34//!
35//! // Query with multiple parameters
36//! let name = "John";
37//! let age = 25;
38//! let sql = raw_query!("SELECT * FROM users WHERE name = {} AND age > {}", name, age);
39//! assert_eq!(sql.params().len(), 2);
40//! ```
41//!
42//! # Building Queries Incrementally
43//!
44//! ```rust
45//! use prax_query::Sql;
46//!
47//! // Join multiple conditions
48//! let conditions = vec!["active = true", "verified = true"];
49//! let sql = Sql::new("SELECT * FROM users WHERE ")
50//!     .push(conditions.join(" AND "));
51//!
52//! assert!(sql.sql().contains("active = true AND verified = true"));
53//! ```
54//!
55//! # Safety
56//!
57//! All values passed via `raw_query!` are parameterized and never interpolated
58//! directly into the SQL string, preventing SQL injection attacks.
59//!
60//! ```rust
61//! use prax_query::raw_query;
62//!
63//! // This malicious input will NOT cause SQL injection
64//! let malicious = "'; DROP TABLE users; --";
65//! let sql = raw_query!("SELECT * FROM users WHERE name = {}", malicious);
66//!
67//! // The malicious string is safely bound as a parameter
68//! assert_eq!(sql.params().len(), 1);
69//! // The SQL itself doesn't contain the malicious text
70//! assert!(!sql.sql().contains("DROP TABLE"));
71//! ```
72
73use std::marker::PhantomData;
74use tracing::debug;
75
76use crate::error::QueryResult;
77use crate::filter::FilterValue;
78use crate::sql::DatabaseType;
79use crate::traits::{Model, QueryEngine};
80
81/// A raw SQL query with parameterized values.
82#[derive(Debug, Clone)]
83pub struct Sql {
84    /// The SQL string parts (between parameters).
85    parts: Vec<String>,
86    /// The parameter values.
87    params: Vec<FilterValue>,
88    /// The database type for parameter formatting.
89    db_type: DatabaseType,
90}
91
92impl Sql {
93    /// Create a new raw SQL query.
94    pub fn new(sql: impl Into<String>) -> Self {
95        Self {
96            parts: vec![sql.into()],
97            params: Vec::new(),
98            db_type: DatabaseType::PostgreSQL,
99        }
100    }
101
102    /// Create an empty SQL query.
103    pub fn empty() -> Self {
104        Self {
105            parts: Vec::new(),
106            params: Vec::new(),
107            db_type: DatabaseType::PostgreSQL,
108        }
109    }
110
111    /// Set the database type for parameter formatting.
112    pub fn with_db_type(mut self, db_type: DatabaseType) -> Self {
113        self.db_type = db_type;
114        self
115    }
116
117    /// Append a literal SQL string.
118    pub fn push(mut self, sql: impl Into<String>) -> Self {
119        if let Some(last) = self.parts.last_mut() {
120            last.push_str(&sql.into());
121        } else {
122            self.parts.push(sql.into());
123        }
124        self
125    }
126
127    /// Bind a parameter value.
128    pub fn bind(mut self, value: impl Into<FilterValue>) -> Self {
129        let index = self.params.len() + 1;
130        let placeholder = self.db_type.placeholder(index);
131
132        if let Some(last) = self.parts.last_mut() {
133            // push_str accepts &str, which Cow<str> derefs to
134            last.push_str(&placeholder);
135        } else {
136            // Convert to owned string for storage
137            self.parts.push(placeholder.into_owned());
138        }
139
140        self.params.push(value.into());
141        self
142    }
143
144    /// Bind multiple parameter values at once.
145    pub fn bind_many(mut self, values: impl IntoIterator<Item = FilterValue>) -> Self {
146        for value in values {
147            self = self.bind(value);
148        }
149        self
150    }
151
152    /// Append a conditional clause.
153    pub fn push_if(self, condition: bool, sql: impl Into<String>) -> Self {
154        if condition {
155            self.push(sql)
156        } else {
157            self
158        }
159    }
160
161    /// Bind a parameter conditionally.
162    pub fn bind_if(self, condition: bool, value: impl Into<FilterValue>) -> Self {
163        if condition {
164            self.bind(value)
165        } else {
166            self
167        }
168    }
169
170    /// Push SQL and bind a value together.
171    pub fn push_bind(self, sql: impl Into<String>, value: impl Into<FilterValue>) -> Self {
172        self.push(sql).bind(value)
173    }
174
175    /// Push SQL and bind a value conditionally.
176    pub fn push_bind_if(
177        self,
178        condition: bool,
179        sql: impl Into<String>,
180        value: impl Into<FilterValue>,
181    ) -> Self {
182        if condition {
183            self.push(sql).bind(value)
184        } else {
185            self
186        }
187    }
188
189    /// Add a separator between parts if there are previous parts.
190    pub fn separated(self, separator: &str) -> SeparatedSql {
191        SeparatedSql {
192            sql: self,
193            separator: separator.to_string(),
194            first: true,
195        }
196    }
197
198    /// Build the final SQL string and parameters.
199    pub fn build(self) -> (String, Vec<FilterValue>) {
200        let sql = self.parts.join("");
201        debug!(sql_len = sql.len(), param_count = self.params.len(), db_type = ?self.db_type, "Sql::build()");
202        (sql, self.params)
203    }
204
205    /// Get the SQL string (without consuming).
206    pub fn sql(&self) -> String {
207        self.parts.join("")
208    }
209
210    /// Get the parameters (without consuming).
211    pub fn params(&self) -> &[FilterValue] {
212        &self.params
213    }
214
215    /// Get the number of bound parameters.
216    pub fn param_count(&self) -> usize {
217        self.params.len()
218    }
219
220    /// Check if the query is empty.
221    pub fn is_empty(&self) -> bool {
222        self.parts.is_empty() || self.parts.iter().all(|p| p.is_empty())
223    }
224}
225
226impl Default for Sql {
227    fn default() -> Self {
228        Self::empty()
229    }
230}
231
232impl std::fmt::Display for Sql {
233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234        write!(f, "{}", self.parts.join(""))
235    }
236}
237
238/// A helper for building SQL with separators between items.
239#[derive(Debug, Clone)]
240pub struct SeparatedSql {
241    sql: Sql,
242    separator: String,
243    first: bool,
244}
245
246impl SeparatedSql {
247    /// Push a literal SQL string with separator.
248    pub fn push(mut self, sql: impl Into<String>) -> Self {
249        if !self.first {
250            self.sql = self.sql.push(&self.separator);
251        }
252        self.sql = self.sql.push(sql);
253        self.first = false;
254        self
255    }
256
257    /// Push SQL and bind a value with separator.
258    pub fn push_bind(mut self, sql: impl Into<String>, value: impl Into<FilterValue>) -> Self {
259        if !self.first {
260            self.sql = self.sql.push(&self.separator);
261        }
262        self.sql = self.sql.push(sql).bind(value);
263        self.first = false;
264        self
265    }
266
267    /// Push SQL and bind conditionally with separator.
268    pub fn push_bind_if(
269        mut self,
270        condition: bool,
271        sql: impl Into<String>,
272        value: impl Into<FilterValue>,
273    ) -> Self {
274        if condition {
275            if !self.first {
276                self.sql = self.sql.push(&self.separator);
277            }
278            self.sql = self.sql.push(sql).bind(value);
279            self.first = false;
280        }
281        self
282    }
283
284    /// Finish and get the underlying Sql.
285    pub fn finish(self) -> Sql {
286        self.sql
287    }
288
289    /// Build the final SQL string and parameters.
290    pub fn build(self) -> (String, Vec<FilterValue>) {
291        self.sql.build()
292    }
293}
294
295/// Raw query operation for executing typed queries.
296#[derive(Debug)]
297pub struct RawQueryOperation<M, E>
298where
299    M: Model + Send + 'static,
300    E: QueryEngine,
301{
302    _model: PhantomData<M>,
303    engine: E,
304    sql: Sql,
305}
306
307impl<M, E> RawQueryOperation<M, E>
308where
309    M: Model + Send + 'static,
310    E: QueryEngine,
311{
312    /// Create a new raw query operation.
313    pub fn new(engine: E, sql: Sql) -> Self {
314        Self {
315            _model: PhantomData,
316            engine,
317            sql,
318        }
319    }
320
321    /// Execute the query and return all matching records.
322    pub async fn exec(self) -> QueryResult<Vec<M>> {
323        let (sql, params) = self.sql.build();
324        self.engine.query_many(&sql, params).await
325    }
326
327    /// Execute the query and return a single record.
328    pub async fn exec_one(self) -> QueryResult<M> {
329        let (sql, params) = self.sql.build();
330        self.engine.query_one(&sql, params).await
331    }
332
333    /// Execute the query and return an optional record.
334    pub async fn exec_optional(self) -> QueryResult<Option<M>> {
335        let (sql, params) = self.sql.build();
336        self.engine.query_optional(&sql, params).await
337    }
338}
339
340/// Raw execute operation for mutations.
341#[derive(Debug)]
342pub struct RawExecuteOperation<E>
343where
344    E: QueryEngine,
345{
346    engine: E,
347    sql: Sql,
348}
349
350impl<E> RawExecuteOperation<E>
351where
352    E: QueryEngine,
353{
354    /// Create a new raw execute operation.
355    pub fn new(engine: E, sql: Sql) -> Self {
356        Self { engine, sql }
357    }
358
359    /// Execute the mutation and return the number of affected rows.
360    pub async fn exec(self) -> QueryResult<u64> {
361        let (sql, params) = self.sql.build();
362        self.engine.execute_raw(&sql, params).await
363    }
364}
365
366/// Helper function to create a raw SQL query from a string.
367pub fn sql(query: impl Into<String>) -> Sql {
368    Sql::new(query)
369}
370
371/// Helper function to create a raw SQL query from parts.
372///
373/// This is typically used with the `raw_query!` macro.
374pub fn sql_with_params(sql_str: impl Into<String>, params: Vec<FilterValue>) -> Sql {
375    let mut sql = Sql::new(sql_str);
376    sql.params = params;
377    sql
378}
379
/// A macro for creating raw SQL queries with inline parameter binding.
///
/// Each `{}` in the template is replaced, from left to right, by a parameter
/// placeholder, and the matching argument is bound as a parameter instead of
/// being interpolated into the SQL text. The query is built with the default
/// PostgreSQL dialect, so the placeholders are rendered as `$1`, `$2`, and so on.
///
/// # Example
///
/// ```rust,ignore
/// let sql = raw_query!("SELECT * FROM users WHERE id = {} AND active = {}", user_id, true);
/// ```
///
/// # Panics
///
/// In builds with debug assertions enabled, panics when the number of `{}`
/// markers differs from the number of arguments. Without debug assertions a
/// surplus marker is removed without emitting a placeholder and surplus
/// arguments are not bound.
#[macro_export]
macro_rules! raw_query {
    // Base case: just a string, no parameters
    ($sql:expr $(,)?) => {
        $crate::raw::Sql::new($sql)
    };

    // With parameters
    ($sql:expr, $($params:expr),+ $(,)?) => {{
        // Fully qualified std paths keep the expansion independent of what
        // the calling module has in scope.
        let parts: ::std::vec::Vec<&str> = $sql.split("{}").collect();
        let param_values: ::std::vec::Vec<$crate::filter::FilterValue> = ::std::vec![
            $($params.into()),+
        ];

        // `split` always yields at least one piece, so `len() - 1` is the
        // number of `{}` markers. A mismatch would otherwise silently drop a
        // marker or an argument, so surface it during development.
        ::std::debug_assert_eq!(
            parts.len() - 1,
            param_values.len(),
            "raw_query!: number of `{{}}` markers does not match the number of arguments"
        );

        let mut sql = $crate::raw::Sql::empty();
        let mut param_iter = param_values.into_iter();
        let last_index = parts.len() - 1;

        // Interleave parts and parameters: every piece except the last one is
        // followed by the next bound argument.
        for (i, part) in parts.iter().enumerate() {
            if !part.is_empty() {
                sql = sql.push(*part);
            }
            if i < last_index {
                if let Some(param) = param_iter.next() {
                    sql = sql.bind(param);
                }
            }
        }

        sql
    }};
}
422
/// Unit tests for the raw SQL builder, the separator helper and `raw_query!`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sql_new() {
        let sql = Sql::new("SELECT * FROM users");
        assert_eq!(sql.sql(), "SELECT * FROM users");
        assert!(sql.params().is_empty());
    }

    #[test]
    fn test_sql_push() {
        let sql = Sql::new("SELECT * FROM users").push(" WHERE id = 1");
        assert_eq!(sql.sql(), "SELECT * FROM users WHERE id = 1");
    }

    #[test]
    fn test_sql_bind() {
        let sql = Sql::new("SELECT * FROM users WHERE id = ").bind(42i32);
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = $1");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_multiple_binds() {
        let sql = Sql::new("SELECT * FROM users WHERE id = ")
            .bind(42i32)
            .push(" AND name = ")
            .bind("John".to_string());
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = $1 AND name = $2");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_sql_push_bind() {
        let sql = Sql::new("SELECT * FROM users WHERE")
            .push_bind(" id = ", 42i32);
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = $1");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_push_if() {
        let include_active = true;
        let include_deleted = false;

        let sql = Sql::new("SELECT * FROM users")
            .push_if(include_active, " WHERE active = true")
            .push_if(include_deleted, " AND deleted = false");

        assert_eq!(sql.sql(), "SELECT * FROM users WHERE active = true");
    }

    // Despite its name this test covers `push_bind_if`; `bind_if` itself is
    // covered by `test_sql_bind_if_value_only` below.
    #[test]
    fn test_sql_bind_if() {
        let filter_id = Some(42i32);
        let filter_name: Option<String> = None;

        let sql = Sql::new("SELECT * FROM users WHERE 1=1")
            .push_bind_if(filter_id.is_some(), " AND id = ", filter_id.unwrap_or(0))
            .push_bind_if(filter_name.is_some(), " AND name = ", "".to_string());

        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE 1=1 AND id = $1");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_bind_if_value_only() {
        let sql = Sql::new("SELECT * FROM users WHERE id = ")
            .bind_if(true, 42i32)
            .bind_if(false, 7i32);

        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = $1");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_separated() {
        let columns = vec!["id", "name", "email"];

        let mut sep = Sql::new("SELECT ")
            .separated(", ");

        for col in columns {
            sep = sep.push(col);
        }

        let sql = sep.finish().push(" FROM users");
        assert_eq!(sql.sql(), "SELECT id, name, email FROM users");
    }

    #[test]
    fn test_sql_separated_with_binds() {
        let filters = vec![("id", 1i32), ("active", 1i32)];

        let mut sep = Sql::new("SELECT * FROM users WHERE ")
            .separated(" AND ");

        for (col, val) in filters {
            sep = sep.push_bind(format!("{} = ", col), val);
        }

        let (query, params) = sep.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = $1 AND active = $2");
        assert_eq!(params.len(), 2);
    }

    // A skipped conditional item must neither emit a separator nor consume a
    // placeholder index.
    #[test]
    fn test_sql_separated_push_bind_if() {
        let (query, params) = Sql::new("SELECT * FROM users WHERE ")
            .separated(" AND ")
            .push_bind_if(false, "id = ", 1i32)
            .push_bind_if(true, "active = ", 1i32)
            .push_bind_if(true, "age > ", 18i32)
            .build();

        assert_eq!(query, "SELECT * FROM users WHERE active = $1 AND age > $2");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_sql_mysql() {
        let sql = Sql::new("SELECT * FROM users WHERE id = ")
            .with_db_type(DatabaseType::MySQL)
            .bind(42i32);
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = ?");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_sqlite() {
        let sql = Sql::new("SELECT * FROM users WHERE id = ")
            .with_db_type(DatabaseType::SQLite)
            .bind(42i32);
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = ?");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_is_empty() {
        assert!(Sql::empty().is_empty());
        assert!(!Sql::new("SELECT 1").is_empty());
    }

    #[test]
    fn test_sql_is_empty_blank_and_default() {
        // A query made only of empty text counts as empty.
        assert!(Sql::new("").is_empty());
        assert!(Sql::default().is_empty());
        // Binding writes a placeholder, so the query is no longer empty.
        assert!(!Sql::empty().bind(1i32).is_empty());
    }

    #[test]
    fn test_sql_bind_on_empty_query() {
        let sql = Sql::empty().bind(1i32);
        assert_eq!(sql.sql(), "$1");
        assert_eq!(sql.param_count(), 1);
    }

    #[test]
    fn test_sql_param_count() {
        let sql = Sql::new("SELECT ").bind(1i32).push(", ").bind(2i32);
        assert_eq!(sql.param_count(), 2);
        assert_eq!(sql.params().len(), 2);
        assert_eq!(sql.sql(), "SELECT $1, $2");
    }

    #[test]
    fn test_sql_helper_function() {
        let query = super::sql("SELECT 1");
        assert_eq!(query.sql(), "SELECT 1");
        assert!(query.params().is_empty());
    }

    #[test]
    fn test_sql_display() {
        let sql = Sql::new("SELECT * FROM users WHERE id = ").bind(42i32);
        assert_eq!(format!("{}", sql), "SELECT * FROM users WHERE id = $1");
    }

    #[test]
    fn test_raw_query_macro_no_params() {
        let sql = raw_query!("SELECT * FROM users");
        assert_eq!(sql.sql(), "SELECT * FROM users");
        assert!(sql.params().is_empty());
    }

    #[test]
    fn test_raw_query_macro_with_params() {
        let sql = raw_query!("SELECT * FROM users WHERE id = {} AND active = {}", 42i32, true);
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id = $1 AND active = $2");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_raw_query_macro_string_params() {
        let name = "John".to_string();
        let sql = raw_query!("SELECT * FROM users WHERE name = {}", name);
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE name = $1");
        assert_eq!(params.len(), 1);
    }

    // Markers at the very start of the template and directly next to each
    // other produce empty text pieces, which must not disturb the numbering.
    #[test]
    fn test_raw_query_macro_adjacent_placeholders() {
        let sql = raw_query!("{}{} tail", 1i32, 2i32);
        let (query, params) = sql.build();
        assert_eq!(query, "$1$2 tail");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_bind_many() {
        let values: Vec<FilterValue> = vec![
            FilterValue::Int(1),
            FilterValue::Int(2),
            FilterValue::Int(3),
        ];

        let sql = Sql::new("SELECT * FROM users WHERE id IN (")
            .bind_many(values)
            .push(")");

        // `bind_many` does not insert separators between placeholders.
        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id IN ($1$2$3)");
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_build_in_clause() {
        let ids = vec![1, 2, 3];

        let placeholders: Vec<String> = (1..=ids.len())
            .map(|i| format!("${}", i))
            .collect();

        let sql = Sql::new(format!(
            "SELECT * FROM users WHERE id IN ({})",
            placeholders.join(", ")
        ));

        let params: Vec<FilterValue> = ids.into_iter().map(FilterValue::Int).collect();
        let sql = sql_with_params(sql.sql(), params);

        let (query, params) = sql.build();
        assert_eq!(query, "SELECT * FROM users WHERE id IN ($1, $2, $3)");
        assert_eq!(params.len(), 3);
    }
}
619