1use crate::error::QueryError;
7
/// Coarse classification of a SQL statement, decided from its leading keyword.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StatementType {
    /// A statement that starts with `SELECT` or `WITH`.
    Select,
    /// A statement that starts with `INSERT`.
    Insert,
    /// A statement that starts with `UPDATE`.
    Update,
    /// A statement that starts with `DELETE`.
    Delete,
    /// A schema statement: `CREATE`, `ALTER`, `DROP` or `TRUNCATE`.
    Ddl,
    /// Transaction control: `BEGIN`, `COMMIT` or `ROLLBACK`.
    Transaction,
    /// Anything the classifier does not recognise.
    Other,
}
26
27impl StatementType {
28 pub fn from_sql(sql: &str) -> Self {
30 let trimmed = sql.trim_start().to_uppercase();
31
32 if trimmed.starts_with("SELECT") || trimmed.starts_with("WITH") {
33 Self::Select
34 } else if trimmed.starts_with("INSERT") {
35 Self::Insert
36 } else if trimmed.starts_with("UPDATE") {
37 Self::Update
38 } else if trimmed.starts_with("DELETE") {
39 Self::Delete
40 } else if trimmed.starts_with("CREATE")
41 || trimmed.starts_with("ALTER")
42 || trimmed.starts_with("DROP")
43 || trimmed.starts_with("TRUNCATE")
44 {
45 Self::Ddl
46 } else if trimmed.starts_with("BEGIN")
47 || trimmed.starts_with("COMMIT")
48 || trimmed.starts_with("ROLLBACK")
49 {
50 Self::Transaction
51 } else {
52 Self::Other
53 }
54 }
55
56 pub fn returns_result_set(&self) -> bool {
58 matches!(self, Self::Select)
59 }
60
61 pub fn returns_row_count(&self) -> bool {
63 matches!(self, Self::Insert | Self::Update | Self::Delete)
64 }
65}
66
/// A value that can be bound to a positional `?` placeholder.
#[derive(Clone, Debug)]
pub enum Parameter {
    /// SQL `NULL`.
    Null,
    /// A boolean, rendered as `TRUE` or `FALSE`.
    Boolean(bool),
    /// A 64-bit signed integer.
    Integer(i64),
    /// A 64-bit float; NaN and infinities are rejected when rendered.
    Float(f64),
    /// Text, rendered as a single-quoted literal.
    String(String),
    /// Raw bytes, rendered as a quoted hex string.
    Binary(Vec<u8>),
}
83
84impl Parameter {
85 pub fn to_sql_literal(&self) -> Result<String, QueryError> {
90 match self {
91 Parameter::Null => Ok("NULL".to_string()),
92 Parameter::Boolean(b) => Ok(if *b { "TRUE" } else { "FALSE" }.to_string()),
93 Parameter::Integer(i) => Ok(i.to_string()),
94 Parameter::Float(f) => {
95 if f.is_nan() || f.is_infinite() {
96 Err(QueryError::ParameterBindingError {
97 index: 0,
98 message: "NaN and Infinity are not supported".to_string(),
99 })
100 } else {
101 Ok(f.to_string())
102 }
103 }
104 Parameter::String(s) => {
105 if Self::contains_sql_injection_pattern(s) {
107 return Err(QueryError::SqlInjectionDetected);
108 }
109
110 let escaped = s.replace('\'', "''");
112
113 Ok(format!("'{}'", escaped))
114 }
115 Parameter::Binary(b) => {
116 Ok(format!("'{}'", hex::encode(b)))
118 }
119 }
120 }
121
122 fn contains_sql_injection_pattern(s: &str) -> bool {
124 let upper = s.to_uppercase();
125
126 let patterns = [
128 "'; DROP",
129 "'; DELETE",
130 "'; UPDATE",
131 "'; INSERT",
132 "' OR '1'='1",
133 "' OR 1=1",
134 "' OR TRUE",
135 "UNION SELECT",
136 "EXEC(",
137 "EXECUTE(",
138 ];
139
140 patterns.iter().any(|pattern| upper.contains(pattern))
141 }
142}
143
144impl From<bool> for Parameter {
145 fn from(value: bool) -> Self {
146 Parameter::Boolean(value)
147 }
148}
149
150impl From<i32> for Parameter {
151 fn from(value: i32) -> Self {
152 Parameter::Integer(value as i64)
153 }
154}
155
156impl From<i64> for Parameter {
157 fn from(value: i64) -> Self {
158 Parameter::Integer(value)
159 }
160}
161
162impl From<f64> for Parameter {
163 fn from(value: f64) -> Self {
164 Parameter::Float(value)
165 }
166}
167
168impl From<String> for Parameter {
169 fn from(value: String) -> Self {
170 Parameter::String(value)
171 }
172}
173
174impl From<&str> for Parameter {
175 fn from(value: &str) -> Self {
176 Parameter::String(value.to_string())
177 }
178}
179
180impl From<Vec<u8>> for Parameter {
181 fn from(value: Vec<u8>) -> Self {
182 Parameter::Binary(value)
183 }
184}
185
186pub struct Statement {
194 sql: String,
196 parameters: Vec<Option<Parameter>>,
198 timeout_ms: u64,
200 statement_type: StatementType,
202}
203
204impl Statement {
205 pub fn new(sql: impl Into<String>) -> Self {
207 let sql = sql.into();
208 let statement_type = StatementType::from_sql(&sql);
209
210 Self {
211 sql,
212 parameters: Vec::new(),
213 timeout_ms: 120_000, statement_type,
215 }
216 }
217
218 pub fn sql(&self) -> &str {
220 &self.sql
221 }
222
223 pub fn statement_type(&self) -> StatementType {
225 self.statement_type
226 }
227
228 pub fn timeout_ms(&self) -> u64 {
230 self.timeout_ms
231 }
232
233 pub fn set_timeout(&mut self, timeout_ms: u64) {
235 self.timeout_ms = timeout_ms;
236 }
237
238 pub fn bind<T: Into<Parameter>>(&mut self, index: usize, value: T) -> Result<(), QueryError> {
247 if index >= self.parameters.len() {
249 self.parameters.resize(index + 1, None);
250 }
251
252 self.parameters[index] = Some(value.into());
253 Ok(())
254 }
255
256 pub fn bind_all<T: Into<Parameter> + Clone>(&mut self, params: &[T]) -> Result<(), QueryError> {
258 for (index, param) in params.iter().enumerate() {
259 self.bind(index, param.clone())?;
260 }
261 Ok(())
262 }
263
264 pub fn clear_parameters(&mut self) {
266 self.parameters.clear();
267 }
268
269 pub fn parameters(&self) -> &[Option<Parameter>] {
271 &self.parameters
272 }
273
274 pub fn build_sql(&self) -> Result<String, QueryError> {
278 let mut sql = self.sql.clone();
279 let mut param_index = 0;
280
281 while let Some(pos) = sql.find('?') {
283 if param_index >= self.parameters.len() {
284 return Err(QueryError::ParameterBindingError {
285 index: param_index,
286 message: "Not enough parameters bound".to_string(),
287 });
288 }
289
290 let param = self.parameters[param_index].as_ref().ok_or_else(|| {
291 QueryError::ParameterBindingError {
292 index: param_index,
293 message: "Parameter not bound".to_string(),
294 }
295 })?;
296
297 let literal = param.to_sql_literal()?;
298 sql.replace_range(pos..pos + 1, &literal);
299 param_index += 1;
300 }
301
302 Ok(sql)
303 }
304}
305
306impl std::fmt::Debug for Statement {
307 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 f.debug_struct("Statement")
309 .field("sql", &self.sql)
310 .field("statement_type", &self.statement_type)
311 .field("timeout_ms", &self.timeout_ms)
312 .finish()
313 }
314}
315
316impl std::fmt::Display for Statement {
317 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318 write!(f, "Statement({})", self.sql)
319 }
320}
321
#[cfg(test)]
// The `3.14` literals below are sample data, not an approximation of PI.
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;

    /// Shorthand for building a `Parameter::String` from a slice.
    fn text(value: &str) -> Parameter {
        Parameter::String(value.to_string())
    }

    #[test]
    fn test_statement_type_detection() {
        let cases = [
            ("SELECT * FROM users", StatementType::Select),
            (" select * from users", StatementType::Select),
            (
                "WITH cte AS (SELECT 1) SELECT * FROM cte",
                StatementType::Select,
            ),
            ("INSERT INTO users VALUES (1)", StatementType::Insert),
            ("UPDATE users SET name = 'John'", StatementType::Update),
            ("DELETE FROM users WHERE id = 1", StatementType::Delete),
            ("CREATE TABLE test (id INT)", StatementType::Ddl),
            ("DROP TABLE test", StatementType::Ddl),
            ("BEGIN", StatementType::Transaction),
            ("COMMIT", StatementType::Transaction),
            ("ROLLBACK", StatementType::Transaction),
        ];

        for (sql, expected) in cases.iter().copied() {
            assert_eq!(StatementType::from_sql(sql), expected, "classifying {:?}", sql);
        }
    }

    #[test]
    fn test_statement_type_returns_result_set() {
        assert!(StatementType::Select.returns_result_set());

        let non_queries = [
            StatementType::Insert,
            StatementType::Update,
            StatementType::Delete,
        ];
        for kind in non_queries.iter() {
            assert!(!kind.returns_result_set(), "{:?} must not return rows", kind);
        }
    }

    #[test]
    fn test_parameter_to_sql_literal() {
        let cases = vec![
            (Parameter::Null, "NULL"),
            (Parameter::Boolean(true), "TRUE"),
            (Parameter::Boolean(false), "FALSE"),
            (Parameter::Integer(42), "42"),
            (Parameter::Float(3.14), "3.14"),
            (text("hello"), "'hello'"),
        ];

        for (param, expected) in cases {
            assert_eq!(param.to_sql_literal().unwrap(), expected);
        }
    }

    #[test]
    fn test_parameter_string_escaping() {
        assert_eq!(text("O'Reilly").to_sql_literal().unwrap(), "'O''Reilly'");
    }

    #[test]
    fn test_parameter_sql_injection_detection() {
        let attacks = ["'; DROP TABLE users; --", "' OR '1'='1"];
        for attack in attacks.iter() {
            assert!(text(attack).to_sql_literal().is_err(), "should reject {:?}", attack);
        }

        assert!(text("It's a nice day").to_sql_literal().is_ok());
    }

    #[test]
    fn test_parameter_conversions() {
        let converted: Vec<Parameter> = vec![
            true.into(),
            42i32.into(),
            42i64.into(),
            3.14f64.into(),
            "test".into(),
            String::from("test").into(),
            vec![1u8, 2, 3].into(),
        ];
        assert_eq!(converted.len(), 7);
    }

    #[test]
    fn test_statement_creation() {
        let stmt = Statement::new("SELECT * FROM users");

        assert_eq!(
            (stmt.sql(), stmt.statement_type(), stmt.timeout_ms()),
            ("SELECT * FROM users", StatementType::Select, 120_000)
        );
    }

    #[test]
    fn test_statement_parameter_binding() {
        let mut stmt = Statement::new("SELECT * FROM users WHERE id = ?");
        stmt.bind(0, 42).expect("binding index 0");

        let rendered = stmt.build_sql().expect("rendering SQL");
        assert_eq!(rendered, "SELECT * FROM users WHERE id = 42");
    }

    #[test]
    fn test_statement_multiple_parameters() {
        let mut stmt = Statement::new("SELECT * FROM users WHERE age > ? AND name = ?");
        stmt.bind(0, 18).expect("binding age");
        stmt.bind(1, "John").expect("binding name");

        let rendered = stmt.build_sql().expect("rendering SQL");
        assert_eq!(
            rendered,
            "SELECT * FROM users WHERE age > 18 AND name = 'John'"
        );
    }

    #[test]
    fn test_statement_set_timeout() {
        let mut stmt = Statement::new("SELECT * FROM users");
        stmt.set_timeout(30_000);
        assert_eq!(stmt.timeout_ms(), 30_000);
    }

    #[test]
    fn test_statement_clear_parameters() {
        let mut stmt = Statement::new("SELECT * FROM users WHERE id = ?");
        stmt.bind(0, 42).expect("binding index 0");
        stmt.clear_parameters();
        assert!(stmt.parameters().is_empty());
    }

    #[test]
    fn test_statement_display() {
        let rendered = Statement::new("SELECT 1").to_string();
        assert!(rendered.contains("SELECT 1"));
    }
}