Skip to main content

dbx_core/sql/
builder.rs

//! Query Builder — fluent-style API.
//!
//! Supports Dapper-style parameter binding:
//! - Positional: `$1, $2, ...`
//! - Named: `:name, :age, ...`

7use crate::api::FromRow;
8use crate::engine::Database;
9use crate::error::{DbxError, DbxResult};
10use std::marker::PhantomData;
11
/// A query parameter value.
///
/// Bound values are inlined into the SQL text as literals (see
/// [`ScalarValue::to_sql_literal`]) before the statement is executed.
#[derive(Debug, Clone, PartialEq)]
pub enum ScalarValue {
    /// SQL `NULL`.
    Null,
    /// 32-bit signed integer.
    Int32(i32),
    /// 64-bit signed integer.
    Int64(i64),
    /// 64-bit floating point number.
    Float64(f64),
    /// UTF-8 text.
    Utf8(String),
    /// Boolean.
    Boolean(bool),
}

impl ScalarValue {
    /// Renders this value as a SQL literal (used for placeholder substitution).
    ///
    /// - `Null` → `NULL`
    /// - `Int32` / `Int64` → decimal digits
    /// - `Float64` → decimal notation that always contains a `.` for finite values
    ///   (`2.0` → `"2.0"`, not `"2"`), so SQL parses it as a floating-point literal
    ///   rather than an integer. Non-finite values are printed as Rust's `Display`
    ///   prints them (`NaN`, `inf`, `-inf`).
    /// - `Utf8` → single-quoted, with embedded `'` doubled (`O'Brien` → `'O''Brien'`)
    /// - `Boolean` → `TRUE` / `FALSE`
    pub fn to_sql_literal(&self) -> String {
        match self {
            ScalarValue::Null => "NULL".to_string(),
            ScalarValue::Int32(v) => v.to_string(),
            ScalarValue::Int64(v) => v.to_string(),
            ScalarValue::Float64(v) => {
                let mut literal = v.to_string();
                // `Display` drops the fraction of whole numbers (2.0 → "2"), which SQL
                // would read as an integer literal (e.g. making `/` integer division).
                // Finite `Display` output never uses an exponent, so a missing '.'
                // means a whole number.
                if v.is_finite() && !literal.contains('.') {
                    literal.push_str(".0");
                }
                literal
            }
            ScalarValue::Utf8(v) => format!("'{}'", v.replace('\'', "''")),
            ScalarValue::Boolean(true) => "TRUE".to_string(),
            ScalarValue::Boolean(false) => "FALSE".to_string(),
        }
    }
}
42
43/// Named 파라미터 엔트리
44#[derive(Debug, Clone)]
45struct NamedParam {
46    name: String,
47    value: ScalarValue,
48}
49
50/// SQL에 파라미터를 적용하여 최종 실행 가능한 SQL 생성
51/// 문자열 리터럴('...') 및 식별자("...") 내부에 있는 placeholder는 무시합니다.
52fn apply_params(sql: &str, positional: &[ScalarValue], named: &[NamedParam]) -> DbxResult<String> {
53    if !positional.is_empty() && !named.is_empty() {
54        return Err(DbxError::InvalidOperation {
55            message: "positional과 named 파라미터를 동시에 사용할 수 없습니다".to_string(),
56            context: "apply_params".to_string(),
57        });
58    }
59
60    let mut result = String::with_capacity(sql.len() + 64);
61    let mut chars = sql.chars().peekable();
62    let mut in_single_quote = false;
63    let mut in_double_quote = false;
64
65    while let Some(c) = chars.next() {
66        if in_single_quote {
67            result.push(c);
68            if c == '\'' {
69                in_single_quote = false;
70            }
71            continue;
72        } else if in_double_quote {
73            result.push(c);
74            if c == '"' {
75                in_double_quote = false;
76            }
77            continue;
78        }
79
80        match c {
81            '\'' => {
82                in_single_quote = true;
83                result.push(c);
84            }
85            '"' => {
86                in_double_quote = true;
87                result.push(c);
88            }
89            '$' if !positional.is_empty() => {
90                let mut num_str = String::new();
91                while let Some(&next_c) = chars.peek() {
92                    if next_c.is_ascii_digit() {
93                        num_str.push(next_c);
94                        chars.next();
95                    } else {
96                        break;
97                    }
98                }
99                if let Ok(idx) = num_str.parse::<usize>() {
100                    if idx > 0 && idx <= positional.len() {
101                        result.push_str(&positional[idx - 1].to_sql_literal());
102                    } else {
103                        result.push('$');
104                        result.push_str(&num_str);
105                    }
106                } else {
107                    result.push('$');
108                    result.push_str(&num_str);
109                }
110            }
111            ':' if !named.is_empty() => {
112                let mut name = String::new();
113                while let Some(&next_c) = chars.peek() {
114                    if next_c.is_ascii_alphanumeric() || next_c == '_' {
115                        name.push(next_c);
116                        chars.next();
117                    } else {
118                        break;
119                    }
120                }
121                if !name.is_empty() {
122                    let mut found = false;
123                    for np in named {
124                        if np.name == name {
125                            result.push_str(&np.value.to_sql_literal());
126                            found = true;
127                            break;
128                        }
129                    }
130                    if !found {
131                        result.push(':');
132                        result.push_str(&name);
133                    }
134                } else {
135                    result.push(':');
136                }
137            }
138            _ => {
139                result.push(c);
140            }
141        }
142    }
143
144    Ok(result)
145}
146
147/// Query Builder — 여러 행 반환
148pub struct Query<'a, T> {
149    db: &'a Database,
150    sql: String,
151    params: Vec<ScalarValue>,
152    named_params: Vec<NamedParam>,
153    _marker: PhantomData<T>,
154}
155
156impl<'a, T: FromRow> Query<'a, T> {
157    pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
158        Self {
159            db,
160            sql: sql.into(),
161            params: Vec::new(),
162            named_params: Vec::new(),
163            _marker: PhantomData,
164        }
165    }
166
167    /// Positional 파라미터 바인딩 ($1, $2, ...)
168    pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
169        self.params.push(value.into_scalar());
170        self
171    }
172
173    /// Named 파라미터 바인딩 (:name, :age, ...)
174    pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
175        self.named_params.push(NamedParam {
176            name: name.to_string(),
177            value: value.into_scalar(),
178        });
179        self
180    }
181
182    /// 모든 행 반환
183    pub fn fetch_all(self) -> DbxResult<Vec<T>> {
184        let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
185        let batches = self.db.execute_sql(&final_sql)?;
186        let mut rows = Vec::new();
187        for batch in &batches {
188            for row_idx in 0..batch.num_rows() {
189                rows.push(T::from_row(batch, row_idx)?);
190            }
191        }
192        Ok(rows)
193    }
194
195    /// 첫 번째 행만 반환 (나머지 무시)
196    pub fn fetch_first(self) -> DbxResult<Option<T>> {
197        let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
198        let batches = self.db.execute_sql(&final_sql)?;
199        for batch in &batches {
200            if batch.num_rows() > 0 {
201                return Ok(Some(T::from_row(batch, 0)?));
202            }
203        }
204        Ok(None)
205    }
206}
207
208/// Query Builder — 단일 행 반환 (없으면 에러)
209pub struct QueryOne<'a, T> {
210    db: &'a Database,
211    sql: String,
212    params: Vec<ScalarValue>,
213    named_params: Vec<NamedParam>,
214    _marker: PhantomData<T>,
215}
216
217impl<'a, T: FromRow> QueryOne<'a, T> {
218    pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
219        Self {
220            db,
221            sql: sql.into(),
222            params: Vec::new(),
223            named_params: Vec::new(),
224            _marker: PhantomData,
225        }
226    }
227
228    pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
229        self.params.push(value.into_scalar());
230        self
231    }
232
233    pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
234        self.named_params.push(NamedParam {
235            name: name.to_string(),
236            value: value.into_scalar(),
237        });
238        self
239    }
240
241    /// 정확히 1개 행 반환 (0개 또는 2개 이상이면 에러)
242    pub fn fetch(self) -> DbxResult<T> {
243        let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
244        let batches = self.db.execute_sql(&final_sql)?;
245        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
246        if total_rows == 0 {
247            return Err(DbxError::KeyNotFound);
248        }
249        if total_rows > 1 {
250            return Err(DbxError::InvalidOperation {
251                message: format!("expected 1 row, got {total_rows}"),
252                context: "QueryOne::fetch".to_string(),
253            });
254        }
255        T::from_row(&batches[0], 0)
256    }
257}
258
259/// Query Builder — 단일 행 반환 (없으면 None)
260pub struct QueryOptional<'a, T> {
261    db: &'a Database,
262    sql: String,
263    params: Vec<ScalarValue>,
264    named_params: Vec<NamedParam>,
265    _marker: PhantomData<T>,
266}
267
268impl<'a, T: FromRow> QueryOptional<'a, T> {
269    pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
270        Self {
271            db,
272            sql: sql.into(),
273            params: Vec::new(),
274            named_params: Vec::new(),
275            _marker: PhantomData,
276        }
277    }
278
279    pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
280        self.params.push(value.into_scalar());
281        self
282    }
283
284    pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
285        self.named_params.push(NamedParam {
286            name: name.to_string(),
287            value: value.into_scalar(),
288        });
289        self
290    }
291
292    pub fn fetch(self) -> DbxResult<Option<T>> {
293        let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
294        let batches = self.db.execute_sql(&final_sql)?;
295        for batch in &batches {
296            if batch.num_rows() > 0 {
297                return Ok(Some(T::from_row(batch, 0)?));
298            }
299        }
300        Ok(None)
301    }
302}
303
304/// Query Builder — Scalar 값 반환
305pub struct QueryScalar<'a, T> {
306    db: &'a Database,
307    sql: String,
308    params: Vec<ScalarValue>,
309    _marker: PhantomData<T>,
310}
311
312impl<'a, T: FromScalar> QueryScalar<'a, T> {
313    pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
314        Self {
315            db,
316            sql: sql.into(),
317            params: Vec::new(),
318            _marker: PhantomData,
319        }
320    }
321
322    pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
323        self.params.push(value.into_scalar());
324        self
325    }
326
327    pub fn fetch(self) -> DbxResult<T> {
328        let final_sql = apply_params(&self.sql, &self.params, &[])?;
329        let batches = self.db.execute_sql(&final_sql)?;
330        for batch in &batches {
331            if batch.num_rows() > 0 && batch.num_columns() > 0 {
332                let col = batch.column(0);
333                let sv = crate::storage::columnar::ScalarValue::from_array(col, 0)?;
334                let qsv = scalar_to_query_scalar(&sv);
335                return T::from_scalar(&qsv);
336            }
337        }
338        Err(DbxError::KeyNotFound)
339    }
340}
341
342/// Execute Builder — INSERT/UPDATE/DELETE
343pub struct Execute<'a> {
344    db: &'a Database,
345    sql: String,
346    params: Vec<ScalarValue>,
347    named_params: Vec<NamedParam>,
348}
349
350impl<'a> Execute<'a> {
351    pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
352        Self {
353            db,
354            sql: sql.into(),
355            params: Vec::new(),
356            named_params: Vec::new(),
357        }
358    }
359
360    pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
361        self.params.push(value.into_scalar());
362        self
363    }
364
365    pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
366        self.named_params.push(NamedParam {
367            name: name.to_string(),
368            value: value.into_scalar(),
369        });
370        self
371    }
372
373    /// INSERT/UPDATE/DELETE 실행 → 영향받은 행 수
374    pub fn run(self) -> DbxResult<usize> {
375        let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
376        let batches = self.db.execute_sql(&final_sql)?;
377        Ok(batches.iter().map(|b| b.num_rows()).sum())
378    }
379}
380
381/// 파라미터 변환 트레이트
382pub trait IntoParam {
383    fn into_scalar(self) -> ScalarValue;
384}
385
386/// Scalar 값 추출 트레이트
387pub trait FromScalar: Sized {
388    fn from_scalar(value: &ScalarValue) -> DbxResult<Self>;
389}
390
391/// Convert columnar::ScalarValue to query::ScalarValue
392fn scalar_to_query_scalar(sv: &crate::storage::columnar::ScalarValue) -> ScalarValue {
393    use crate::storage::columnar::ScalarValue as CSV;
394    match sv {
395        CSV::Null => ScalarValue::Null,
396        CSV::Int32(v) => ScalarValue::Int32(*v),
397        CSV::Int64(v) => ScalarValue::Int64(*v),
398        CSV::Float64(v) => ScalarValue::Float64(*v),
399        CSV::Utf8(v) => ScalarValue::Utf8(v.clone()),
400        CSV::Boolean(v) => ScalarValue::Boolean(*v),
401        CSV::Binary(_) => {
402            // Binary type is not supported in query builder ScalarValue
403            // This should not happen in normal query operations
404            ScalarValue::Null
405        }
406    }
407}
408
409// 기본 타입 구현
410impl IntoParam for i32 {
411    fn into_scalar(self) -> ScalarValue {
412        ScalarValue::Int32(self)
413    }
414}
415
416impl IntoParam for i64 {
417    fn into_scalar(self) -> ScalarValue {
418        ScalarValue::Int64(self)
419    }
420}
421
422impl IntoParam for f64 {
423    fn into_scalar(self) -> ScalarValue {
424        ScalarValue::Float64(self)
425    }
426}
427
428impl IntoParam for &str {
429    fn into_scalar(self) -> ScalarValue {
430        ScalarValue::Utf8(self.to_string())
431    }
432}
433
434impl IntoParam for String {
435    fn into_scalar(self) -> ScalarValue {
436        ScalarValue::Utf8(self)
437    }
438}
439
440impl IntoParam for bool {
441    fn into_scalar(self) -> ScalarValue {
442        ScalarValue::Boolean(self)
443    }
444}
445
446impl<T: IntoParam> IntoParam for Option<T> {
447    fn into_scalar(self) -> ScalarValue {
448        match self {
449            Some(v) => v.into_scalar(),
450            None => ScalarValue::Null,
451        }
452    }
453}
454
455// FromScalar 구현
456impl FromScalar for i64 {
457    fn from_scalar(value: &ScalarValue) -> DbxResult<Self> {
458        match value {
459            ScalarValue::Int64(v) => Ok(*v),
460            _ => Err(crate::error::DbxError::TypeMismatch {
461                expected: "Int64".to_string(),
462                actual: format!("{:?}", value),
463            }),
464        }
465    }
466}
467
468impl FromScalar for i32 {
469    fn from_scalar(value: &ScalarValue) -> DbxResult<Self> {
470        match value {
471            ScalarValue::Int32(v) => Ok(*v),
472            _ => Err(crate::error::DbxError::TypeMismatch {
473                expected: "Int32".to_string(),
474                actual: format!("{:?}", value),
475            }),
476        }
477    }
478}
479
480impl FromScalar for f64 {
481    fn from_scalar(value: &ScalarValue) -> DbxResult<Self> {
482        match value {
483            ScalarValue::Float64(v) => Ok(*v),
484            _ => Err(crate::error::DbxError::TypeMismatch {
485                expected: "Float64".to_string(),
486                actual: format!("{:?}", value),
487            }),
488        }
489    }
490}
491
492// Database에 Query Builder 메서드 추가
493impl Database {
494    /// SELECT 쿼리 — 여러 행 반환
495    pub fn query<T: FromRow>(&self, sql: impl Into<String>) -> Query<'_, T> {
496        Query::new(self, sql)
497    }
498
499    /// SELECT 쿼리 — 단일 행 반환 (없으면 에러)
500    pub fn query_one<T: FromRow>(&self, sql: impl Into<String>) -> QueryOne<'_, T> {
501        QueryOne::new(self, sql)
502    }
503
504    /// SELECT 쿼리 — 단일 행 반환 (없으면 None)
505    pub fn query_optional<T: FromRow>(&self, sql: impl Into<String>) -> QueryOptional<'_, T> {
506        QueryOptional::new(self, sql)
507    }
508
509    /// SELECT 쿼리 — 단일 스칼라 값 반환
510    pub fn query_scalar<T: FromScalar>(&self, sql: impl Into<String>) -> QueryScalar<'_, T> {
511        QueryScalar::new(self, sql)
512    }
513
514    /// INSERT/UPDATE/DELETE — 영향받은 행 수 반환
515    pub fn execute(&self, sql: impl Into<String>) -> Execute<'_> {
516        Execute::new(self, sql)
517    }
518}
519
520// ════════════════════════════════════════════
521// DatabaseQuery Trait Implementation
522// ════════════════════════════════════════════
523
524impl crate::traits::DatabaseQuery for Database {
525    // Query Builder methods are already implemented in impl Database block above
526    // No additional implementation needed - Trait is satisfied by existing methods
527}
528
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_to_sql_literal() {
        assert_eq!(ScalarValue::Null.to_sql_literal(), "NULL");
        assert_eq!(ScalarValue::Int32(42).to_sql_literal(), "42");
        assert_eq!(ScalarValue::Int64(100).to_sql_literal(), "100");
        assert_eq!(ScalarValue::Float64(3.1).to_sql_literal(), "3.1");
        assert_eq!(
            ScalarValue::Utf8("hello".into()).to_sql_literal(),
            "'hello'"
        );
        assert_eq!(ScalarValue::Boolean(true).to_sql_literal(), "TRUE");
        assert_eq!(ScalarValue::Boolean(false).to_sql_literal(), "FALSE");
    }

    #[test]
    fn test_sql_literal_single_quote_escape() {
        // SQL-injection defence: an embedded single quote is escaped by doubling it.
        assert_eq!(
            ScalarValue::Utf8("O'Brien".into()).to_sql_literal(),
            "'O''Brien'"
        );
    }

    #[test]
    fn test_apply_params_positional() {
        let sql = "SELECT * FROM users WHERE id = $1 AND age > $2";
        let params = vec![ScalarValue::Int32(42), ScalarValue::Int64(18)];
        let result = apply_params(sql, &params, &[]).unwrap();
        assert_eq!(result, "SELECT * FROM users WHERE id = 42 AND age > 18");
    }

    #[test]
    fn test_apply_params_string() {
        let sql = "SELECT * FROM users WHERE name = $1";
        let params = vec![ScalarValue::Utf8("Alice".into())];
        let result = apply_params(sql, &params, &[]).unwrap();
        assert_eq!(result, "SELECT * FROM users WHERE name = 'Alice'");
    }

    #[test]
    fn test_apply_params_null_bool() {
        let sql = "SELECT * FROM t WHERE a = $1 AND b = $2 AND c = $3";
        let params = vec![
            ScalarValue::Null,
            ScalarValue::Boolean(true),
            ScalarValue::Boolean(false),
        ];
        let result = apply_params(sql, &params, &[]).unwrap();
        assert_eq!(
            result,
            "SELECT * FROM t WHERE a = NULL AND b = TRUE AND c = FALSE"
        );
    }

    #[test]
    fn test_apply_params_reverse_order_safety() {
        // `$10` must resolve to the 10th parameter, not `$1` followed by a literal `0`.
        let sql = "SELECT $1, $10";
        let mut params = vec![ScalarValue::Int32(1)];
        for i in 2..=10 {
            params.push(ScalarValue::Int32(i));
        }
        let result = apply_params(sql, &params, &[]).unwrap();
        assert_eq!(result, "SELECT 1, 10");
    }

    #[test]
    fn test_apply_named_params() {
        let sql = "SELECT * FROM users WHERE name = :name AND age > :age";
        let named = vec![
            NamedParam {
                name: "name".into(),
                value: ScalarValue::Utf8("Alice".into()),
            },
            NamedParam {
                name: "age".into(),
                value: ScalarValue::Int32(18),
            },
        ];
        let result = apply_params(sql, &[], &named).unwrap();
        assert_eq!(
            result,
            "SELECT * FROM users WHERE name = 'Alice' AND age > 18"
        );
    }

    #[test]
    fn test_apply_named_params_ignores_strings() {
        let sql = "SELECT * FROM users WHERE txt = 'cost: $1, name: :name' AND id = $1";
        let positional = vec![ScalarValue::Int32(5)];
        let result = apply_params(sql, &positional, &[]).unwrap();
        assert_eq!(
            result,
            "SELECT * FROM users WHERE txt = 'cost: $1, name: :name' AND id = 5"
        );
    }

    #[test]
    fn test_mixed_params_error() {
        let sql = "SELECT * FROM users";
        let positional = vec![ScalarValue::Int32(1)];
        let named = vec![NamedParam {
            name: "a".into(),
            value: ScalarValue::Int32(2),
        }];
        let result = apply_params(sql, &positional, &named);
        assert!(result.is_err());
    }

    #[test]
    fn test_apply_params_named_full() {
        let sql = "SELECT * FROM t WHERE x = :x AND y = :y";
        let named = vec![
            NamedParam {
                name: "x".into(),
                value: ScalarValue::Int32(10),
            },
            NamedParam {
                name: "y".into(),
                value: ScalarValue::Utf8("hello".into()),
            },
        ];
        let result = apply_params(sql, &[], &named).unwrap();
        assert_eq!(result, "SELECT * FROM t WHERE x = 10 AND y = 'hello'");
    }

    #[test]
    fn test_apply_params_positional_full() {
        let sql = "INSERT INTO t VALUES ($1, $2, $3)";
        let params = vec![
            ScalarValue::Int32(1),
            ScalarValue::Utf8("test".into()),
            ScalarValue::Float64(3.1),
        ];
        let result = apply_params(sql, &params, &[]).unwrap();
        assert_eq!(result, "INSERT INTO t VALUES (1, 'test', 3.1)");
    }

    #[test]
    fn test_apply_params_escaped_and_unterminated_quotes() {
        // A doubled quote ('') stays inside the literal, so `$1` there is untouched.
        let params = vec![ScalarValue::Int32(7)];
        let result = apply_params("SELECT 'it''s $1', $1", &params, &[]).unwrap();
        assert_eq!(result, "SELECT 'it''s $1', 7");

        // An unterminated literal is copied verbatim to the end of the input.
        let result = apply_params("SELECT $1, 'oops $1", &params, &[]).unwrap();
        assert_eq!(result, "SELECT 7, 'oops $1");
    }

    #[test]
    fn test_apply_params_unmatched_positional_kept() {
        // `$0`, out-of-range indices and `$` without digits are copied through.
        let params = vec![ScalarValue::Int32(1)];
        let result = apply_params("SELECT $0, $1, $3, $x", &params, &[]).unwrap();
        assert_eq!(result, "SELECT $0, 1, $3, $x");
    }

    #[test]
    fn test_apply_named_params_unknown_name_and_quoted_identifier() {
        let named = vec![NamedParam {
            name: "known".into(),
            value: ScalarValue::Int32(3),
        }];
        let sql = r#"SELECT "a:known", :known, :unknown"#;
        let result = apply_params(sql, &[], &named).unwrap();
        assert_eq!(result, r#"SELECT "a:known", 3, :unknown"#);
    }

    #[test]
    fn test_apply_named_params_preserves_cast() {
        let named = vec![NamedParam {
            name: "y".into(),
            value: ScalarValue::Int32(1),
        }];
        let result = apply_params("SELECT x::INT FROM t WHERE y = :y", &[], &named).unwrap();
        assert_eq!(result, "SELECT x::INT FROM t WHERE y = 1");
    }

    #[test]
    fn test_apply_params_without_bindings_is_identity() {
        let sql = "SELECT 'a', \"b\", $1, :c";
        assert_eq!(apply_params(sql, &[], &[]).unwrap(), sql);
    }

    #[test]
    fn test_into_param_trait() {
        assert!(matches!(42i32.into_scalar(), ScalarValue::Int32(42)));
        assert!(matches!(100i64.into_scalar(), ScalarValue::Int64(100)));
        assert!(matches!(3.1f64.into_scalar(), ScalarValue::Float64(_)));
        assert!(matches!("hello".into_scalar(), ScalarValue::Utf8(_)));
        assert!(matches!(true.into_scalar(), ScalarValue::Boolean(true)));
        assert!(matches!(
            Option::<i32>::None.into_scalar(),
            ScalarValue::Null
        ));
        assert!(matches!(Some(10i32).into_scalar(), ScalarValue::Int32(10)));
    }
}