Skip to main content

exarrow_rs/query/
statement.rs

1//! SQL statement handling and execution.
2//!
3//! This module provides the `Statement` type as a pure data container for SQL queries
4//! with parameter binding. Statement execution is handled by Connection.
5
6use crate::error::QueryError;
7
/// Type of SQL statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// SELECT query (also statements starting with `WITH`)
    Select,
    /// INSERT statement
    Insert,
    /// UPDATE statement
    Update,
    /// DELETE statement
    Delete,
    /// DDL statement (CREATE, ALTER, DROP, TRUNCATE)
    Ddl,
    /// Transaction control (BEGIN, COMMIT, ROLLBACK)
    Transaction,
    /// Unknown or other statement type
    Other,
}

impl StatementType {
    /// Detect statement type from SQL text.
    ///
    /// Classification looks at the first keyword of the statement, compared
    /// ASCII-case-insensitively and as a whole word (so `SELECTED` is not
    /// `SELECT`). Leading whitespace, `-- ...` line comments, `/* ... */` block
    /// comments and opening parentheses (e.g. `(SELECT 1) UNION (SELECT 2)`)
    /// are skipped before the keyword is read.
    ///
    /// A statement starting with `WITH` is always classified as
    /// [`StatementType::Select`], even if the CTE is followed by DML.
    pub fn from_sql(sql: &str) -> Self {
        // Only the (short) leading keyword is upper-cased, not the whole SQL text.
        match Self::leading_keyword(sql).to_ascii_uppercase().as_str() {
            "SELECT" | "WITH" => Self::Select,
            "INSERT" => Self::Insert,
            "UPDATE" => Self::Update,
            "DELETE" => Self::Delete,
            "CREATE" | "ALTER" | "DROP" | "TRUNCATE" => Self::Ddl,
            "BEGIN" | "COMMIT" | "ROLLBACK" => Self::Transaction,
            _ => Self::Other,
        }
    }

    /// Return the first word of `sql` after skipping whitespace, SQL comments
    /// and opening parentheses.
    ///
    /// A word is a run of alphanumeric characters or `_`. Returns an empty
    /// string when the text contains no such word at that position.
    fn leading_keyword(sql: &str) -> &str {
        let mut rest = sql;
        loop {
            rest = rest.trim_start();
            if let Some(after) = rest.strip_prefix("--") {
                // Line comment: resume after the next newline (or at end of input).
                rest = after.find('\n').map_or("", |nl| &after[nl + 1..]);
            } else if let Some(after) = rest.strip_prefix("/*") {
                // Block comment: resume after `*/`; an unterminated comment
                // consumes the rest of the input.
                rest = after.find("*/").map_or("", |end| &after[end + 2..]);
            } else if let Some(after) = rest.strip_prefix('(') {
                // Parenthesised query, e.g. `(SELECT ...) UNION (SELECT ...)`.
                rest = after;
            } else {
                break;
            }
        }
        let end = rest
            .find(|c: char| !(c.is_alphanumeric() || c == '_'))
            .unwrap_or(rest.len());
        &rest[..end]
    }

    /// Check if this statement type returns a result set.
    pub fn returns_result_set(&self) -> bool {
        matches!(self, Self::Select)
    }

    /// Check if this statement type returns a row count.
    pub fn returns_row_count(&self) -> bool {
        matches!(self, Self::Insert | Self::Update | Self::Delete)
    }
}
66
/// A single value that can be bound to a statement placeholder.
#[derive(Clone, Debug)]
pub enum Parameter {
    /// SQL `NULL`
    Null,
    /// A boolean flag
    Boolean(bool),
    /// A 64-bit signed integer
    Integer(i64),
    /// A 64-bit floating-point number
    Float(f64),
    /// A UTF-8 text value
    String(String),
    /// Raw bytes
    Binary(Vec<u8>),
}
83
84impl Parameter {
85    /// Convert parameter to SQL literal string.
86    ///
87    /// This is a basic implementation for Phase 1.
88    /// In production, use proper prepared statement protocol.
89    pub fn to_sql_literal(&self) -> Result<String, QueryError> {
90        match self {
91            Parameter::Null => Ok("NULL".to_string()),
92            Parameter::Boolean(b) => Ok(if *b { "TRUE" } else { "FALSE" }.to_string()),
93            Parameter::Integer(i) => Ok(i.to_string()),
94            Parameter::Float(f) => {
95                if f.is_nan() || f.is_infinite() {
96                    Err(QueryError::ParameterBindingError {
97                        index: 0,
98                        message: "NaN and Infinity are not supported".to_string(),
99                    })
100                } else {
101                    Ok(f.to_string())
102                }
103            }
104            Parameter::String(s) => {
105                // Additional check for suspicious patterns (before escaping)
106                if Self::contains_sql_injection_pattern(s) {
107                    return Err(QueryError::SqlInjectionDetected);
108                }
109
110                // Basic SQL injection prevention: escape single quotes
111                let escaped = s.replace('\'', "''");
112
113                Ok(format!("'{}'", escaped))
114            }
115            Parameter::Binary(b) => {
116                // Convert binary to hex string
117                Ok(format!("'{}'", hex::encode(b)))
118            }
119        }
120    }
121
122    /// Basic SQL injection detection.
123    fn contains_sql_injection_pattern(s: &str) -> bool {
124        let upper = s.to_uppercase();
125
126        // Check for common SQL injection patterns
127        let patterns = [
128            "'; DROP",
129            "'; DELETE",
130            "'; UPDATE",
131            "'; INSERT",
132            "' OR '1'='1",
133            "' OR 1=1",
134            "' OR TRUE",
135            "UNION SELECT",
136            "EXEC(",
137            "EXECUTE(",
138        ];
139
140        patterns.iter().any(|pattern| upper.contains(pattern))
141    }
142}
143
144impl From<bool> for Parameter {
145    fn from(value: bool) -> Self {
146        Parameter::Boolean(value)
147    }
148}
149
150impl From<i32> for Parameter {
151    fn from(value: i32) -> Self {
152        Parameter::Integer(value as i64)
153    }
154}
155
156impl From<i64> for Parameter {
157    fn from(value: i64) -> Self {
158        Parameter::Integer(value)
159    }
160}
161
162impl From<f64> for Parameter {
163    fn from(value: f64) -> Self {
164        Parameter::Float(value)
165    }
166}
167
168impl From<String> for Parameter {
169    fn from(value: String) -> Self {
170        Parameter::String(value)
171    }
172}
173
174impl From<&str> for Parameter {
175    fn from(value: &str) -> Self {
176        Parameter::String(value.to_string())
177    }
178}
179
180impl From<Vec<u8>> for Parameter {
181    fn from(value: Vec<u8>) -> Self {
182        Parameter::Binary(value)
183    }
184}
185
186/// SQL statement as a pure data container.
187///
188/// Statement holds SQL text, parameters, timeout, and statement type.
189/// Execution is performed by Connection, not by Statement itself.
190///
191/// # Example
192///
193pub struct Statement {
194    /// SQL text (may contain parameter placeholders)
195    sql: String,
196    /// Bound parameters (indexed by position)
197    parameters: Vec<Option<Parameter>>,
198    /// Query timeout in milliseconds
199    timeout_ms: u64,
200    /// Statement type
201    statement_type: StatementType,
202}
203
204impl Statement {
205    /// Create a new statement.
206    pub fn new(sql: impl Into<String>) -> Self {
207        let sql = sql.into();
208        let statement_type = StatementType::from_sql(&sql);
209
210        Self {
211            sql,
212            parameters: Vec::new(),
213            timeout_ms: 120_000, // 2 minutes default
214            statement_type,
215        }
216    }
217
218    /// Get the SQL text.
219    pub fn sql(&self) -> &str {
220        &self.sql
221    }
222
223    /// Get the statement type.
224    pub fn statement_type(&self) -> StatementType {
225        self.statement_type
226    }
227
228    /// Get the timeout in milliseconds.
229    pub fn timeout_ms(&self) -> u64 {
230        self.timeout_ms
231    }
232
233    /// Set query timeout.
234    pub fn set_timeout(&mut self, timeout_ms: u64) {
235        self.timeout_ms = timeout_ms;
236    }
237
238    /// Bind a parameter at the given index.
239    ///
240    /// # Arguments
241    /// * `index` - Parameter index (0-based)
242    /// * `value` - Parameter value
243    ///
244    /// # Errors
245    /// Returns `QueryError::ParameterBindingError` if binding fails.
246    pub fn bind<T: Into<Parameter>>(&mut self, index: usize, value: T) -> Result<(), QueryError> {
247        // Ensure parameters vector is large enough
248        if index >= self.parameters.len() {
249            self.parameters.resize(index + 1, None);
250        }
251
252        self.parameters[index] = Some(value.into());
253        Ok(())
254    }
255
256    /// Bind multiple parameters.
257    pub fn bind_all<T: Into<Parameter> + Clone>(&mut self, params: &[T]) -> Result<(), QueryError> {
258        for (index, param) in params.iter().enumerate() {
259            self.bind(index, param.clone())?;
260        }
261        Ok(())
262    }
263
264    /// Clear all bound parameters.
265    pub fn clear_parameters(&mut self) {
266        self.parameters.clear();
267    }
268
269    /// Get bound parameters.
270    pub fn parameters(&self) -> &[Option<Parameter>] {
271        &self.parameters
272    }
273
274    /// Build the final SQL with parameters substituted.
275    ///
276    /// This is used internally by Connection when executing statements.
277    pub fn build_sql(&self) -> Result<String, QueryError> {
278        let mut sql = self.sql.clone();
279        let mut param_index = 0;
280
281        // Replace '?' placeholders with parameter values
282        while let Some(pos) = sql.find('?') {
283            if param_index >= self.parameters.len() {
284                return Err(QueryError::ParameterBindingError {
285                    index: param_index,
286                    message: "Not enough parameters bound".to_string(),
287                });
288            }
289
290            let param = self.parameters[param_index].as_ref().ok_or_else(|| {
291                QueryError::ParameterBindingError {
292                    index: param_index,
293                    message: "Parameter not bound".to_string(),
294                }
295            })?;
296
297            let literal = param.to_sql_literal()?;
298            sql.replace_range(pos..pos + 1, &literal);
299            param_index += 1;
300        }
301
302        Ok(sql)
303    }
304}
305
306impl std::fmt::Debug for Statement {
307    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308        f.debug_struct("Statement")
309            .field("sql", &self.sql)
310            .field("statement_type", &self.statement_type)
311            .field("timeout_ms", &self.timeout_ms)
312            .finish()
313    }
314}
315
316impl std::fmt::Display for Statement {
317    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318        write!(f, "Statement({})", self.sql)
319    }
320}
321
#[cfg(test)]
// `3.14` is used below as an arbitrary sample float, not as an approximation of PI.
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;

    #[test]
    fn test_statement_type_detection() {
        assert_eq!(
            StatementType::from_sql("SELECT * FROM users"),
            StatementType::Select
        );
        assert_eq!(
            StatementType::from_sql("  select * from users"),
            StatementType::Select
        );
        assert_eq!(
            StatementType::from_sql("WITH cte AS (SELECT 1) SELECT * FROM cte"),
            StatementType::Select
        );
        assert_eq!(
            StatementType::from_sql("INSERT INTO users VALUES (1)"),
            StatementType::Insert
        );
        assert_eq!(
            StatementType::from_sql("UPDATE users SET name = 'John'"),
            StatementType::Update
        );
        assert_eq!(
            StatementType::from_sql("DELETE FROM users WHERE id = 1"),
            StatementType::Delete
        );
        assert_eq!(
            StatementType::from_sql("CREATE TABLE test (id INT)"),
            StatementType::Ddl
        );
        assert_eq!(
            StatementType::from_sql("DROP TABLE test"),
            StatementType::Ddl
        );
        assert_eq!(StatementType::from_sql("BEGIN"), StatementType::Transaction);
        assert_eq!(
            StatementType::from_sql("COMMIT"),
            StatementType::Transaction
        );
        assert_eq!(
            StatementType::from_sql("ROLLBACK"),
            StatementType::Transaction
        );
    }

    #[test]
    fn test_statement_type_detection_ddl_and_other() {
        assert_eq!(
            StatementType::from_sql("ALTER TABLE t ADD c INT"),
            StatementType::Ddl
        );
        assert_eq!(
            StatementType::from_sql("TRUNCATE TABLE t"),
            StatementType::Ddl
        );
        assert_eq!(
            StatementType::from_sql("EXPLAIN SELECT 1"),
            StatementType::Other
        );
        assert_eq!(StatementType::from_sql(""), StatementType::Other);
    }

    #[test]
    fn test_statement_type_returns_result_set() {
        assert!(StatementType::Select.returns_result_set());
        assert!(!StatementType::Insert.returns_result_set());
        assert!(!StatementType::Update.returns_result_set());
        assert!(!StatementType::Delete.returns_result_set());
    }

    #[test]
    fn test_statement_type_returns_row_count() {
        assert!(StatementType::Insert.returns_row_count());
        assert!(StatementType::Update.returns_row_count());
        assert!(StatementType::Delete.returns_row_count());
        assert!(!StatementType::Select.returns_row_count());
        assert!(!StatementType::Ddl.returns_row_count());
        assert!(!StatementType::Transaction.returns_row_count());
        assert!(!StatementType::Other.returns_row_count());
    }

    #[test]
    fn test_parameter_to_sql_literal() {
        assert_eq!(Parameter::Null.to_sql_literal().unwrap(), "NULL");
        assert_eq!(Parameter::Boolean(true).to_sql_literal().unwrap(), "TRUE");
        assert_eq!(Parameter::Boolean(false).to_sql_literal().unwrap(), "FALSE");
        assert_eq!(Parameter::Integer(42).to_sql_literal().unwrap(), "42");
        assert_eq!(Parameter::Float(3.14).to_sql_literal().unwrap(), "3.14");
        assert_eq!(
            Parameter::String("hello".to_string())
                .to_sql_literal()
                .unwrap(),
            "'hello'"
        );
    }

    #[test]
    fn test_parameter_rejects_non_finite_floats() {
        assert!(Parameter::Float(f64::NAN).to_sql_literal().is_err());
        assert!(Parameter::Float(f64::INFINITY).to_sql_literal().is_err());
        assert!(Parameter::Float(f64::NEG_INFINITY).to_sql_literal().is_err());
    }

    #[test]
    fn test_parameter_binary_is_hex_encoded() {
        let param = Parameter::Binary(vec![0xde, 0xad, 0xbe, 0xef]);
        assert_eq!(param.to_sql_literal().unwrap(), "'deadbeef'");
    }

    #[test]
    fn test_parameter_string_escaping() {
        let param = Parameter::String("O'Reilly".to_string());
        assert_eq!(param.to_sql_literal().unwrap(), "'O''Reilly'");
    }

    #[test]
    fn test_parameter_sql_injection_detection() {
        let dangerous = Parameter::String("'; DROP TABLE users; --".to_string());
        assert!(dangerous.to_sql_literal().is_err());

        let malicious = Parameter::String("' OR '1'='1".to_string());
        assert!(malicious.to_sql_literal().is_err());

        let safe = Parameter::String("It's a nice day".to_string());
        assert!(safe.to_sql_literal().is_ok());
    }

    #[test]
    fn test_parameter_conversions() {
        let _p: Parameter = true.into();
        let _p: Parameter = 42i32.into();
        let _p: Parameter = 42i64.into();
        let _p: Parameter = 3.14f64.into();
        let _p: Parameter = "test".into();
        let _p: Parameter = String::from("test").into();
        let _p: Parameter = vec![1u8, 2, 3].into();
    }

    #[test]
    fn test_statement_creation() {
        let stmt = Statement::new("SELECT * FROM users");

        assert_eq!(stmt.sql(), "SELECT * FROM users");
        assert_eq!(stmt.statement_type(), StatementType::Select);
        assert_eq!(stmt.timeout_ms(), 120_000);
    }

    #[test]
    fn test_statement_parameter_binding() {
        let mut stmt = Statement::new("SELECT * FROM users WHERE id = ?");

        stmt.bind(0, 42).unwrap();

        let final_sql = stmt.build_sql().unwrap();
        assert_eq!(final_sql, "SELECT * FROM users WHERE id = 42");
    }

    #[test]
    fn test_statement_multiple_parameters() {
        let mut stmt = Statement::new("SELECT * FROM users WHERE age > ? AND name = ?");

        stmt.bind(0, 18).unwrap();
        stmt.bind(1, "John").unwrap();

        let final_sql = stmt.build_sql().unwrap();
        assert_eq!(
            final_sql,
            "SELECT * FROM users WHERE age > 18 AND name = 'John'"
        );
    }

    #[test]
    fn test_statement_bind_all() {
        let mut stmt = Statement::new("SELECT ? + ?");
        stmt.bind_all(&[1i64, 2]).unwrap();
        assert_eq!(stmt.build_sql().unwrap(), "SELECT 1 + 2");
    }

    #[test]
    fn test_statement_bind_out_of_order_leaves_gaps() {
        let mut stmt = Statement::new("SELECT ?, ?, ?");
        stmt.bind(2, true).unwrap();
        assert_eq!(stmt.parameters().len(), 3);
        assert!(stmt.parameters()[0].is_none());
        assert!(stmt.parameters()[1].is_none());
        assert!(stmt.parameters()[2].is_some());
    }

    #[test]
    fn test_statement_missing_parameter_is_error() {
        let stmt = Statement::new("SELECT ?");
        assert!(matches!(
            stmt.build_sql(),
            Err(QueryError::ParameterBindingError { index: 0, .. })
        ));

        let mut stmt = Statement::new("SELECT ?, ?");
        stmt.bind(1, 5).unwrap();
        assert!(matches!(
            stmt.build_sql(),
            Err(QueryError::ParameterBindingError { index: 0, .. })
        ));

        let mut stmt = Statement::new("SELECT ?, ?");
        stmt.bind(0, 5).unwrap();
        assert!(matches!(
            stmt.build_sql(),
            Err(QueryError::ParameterBindingError { index: 1, .. })
        ));
    }

    #[test]
    fn test_statement_set_timeout() {
        let mut stmt = Statement::new("SELECT * FROM users");
        stmt.set_timeout(30_000);
        assert_eq!(stmt.timeout_ms(), 30_000);
    }

    #[test]
    fn test_statement_clear_parameters() {
        let mut stmt = Statement::new("SELECT * FROM users WHERE id = ?");
        stmt.bind(0, 42).unwrap();
        stmt.clear_parameters();
        assert!(stmt.parameters().is_empty());
    }

    #[test]
    fn test_statement_display() {
        let stmt = Statement::new("SELECT 1");
        let display = format!("{}", stmt);
        assert!(display.contains("SELECT 1"));
    }
}