Skip to main content

dbx_core/sql/
builder.rs

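//! Query Builder — fluent-style API.
//!
//! Supports Dapper-style parameter binding:
//! - Positional: `$1, $2, ...`
//! - Named: `:name, :age, ...`
//!
//! Bound values are not sent to the engine as separate parameters: each one is
//! rendered as a SQL literal (strings single-quoted, with `'` doubled) and
//! inlined into the SQL text before execution.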
6
7use crate::api::FromRow;
8use crate::engine::Database;
9use crate::error::{DbxError, DbxResult};
10use std::marker::PhantomData;
11
/// A query parameter value.
///
/// The builders in this module inline each value into the SQL text using
/// [`ScalarValue::to_sql_literal`].
#[derive(Debug, Clone)]
pub enum ScalarValue {
    /// SQL `NULL`.
    Null,
    Int32(i32),
    Int64(i64),
    Float64(f64),
    Utf8(String),
    Boolean(bool),
}

impl ScalarValue {
    /// Renders the value as a SQL literal (used for placeholder substitution).
    ///
    /// * `Null` → `NULL`; `Boolean` → `TRUE` / `FALSE`.
    /// * Integers → plain decimal digits.
    /// * `Float64` → decimal notation that always contains a `.` for finite
    ///   values (`2.0` renders as `2.0`, not `2`), so the literal is not parsed
    ///   as an integer. NaN and infinities render as `NaN` / `inf` / `-inf`,
    ///   which are not standard SQL numeric literals.
    /// * `Utf8` → single-quoted, with embedded `'` doubled
    ///   (`O'Brien` → `'O''Brien'`).
    pub fn to_sql_literal(&self) -> String {
        match self {
            ScalarValue::Null => "NULL".to_string(),
            ScalarValue::Int32(v) => v.to_string(),
            ScalarValue::Int64(v) => v.to_string(),
            ScalarValue::Float64(v) => {
                let text = v.to_string();
                // `Display` prints integral floats without a fraction ("2" for 2.0)
                // and never uses exponent notation, so appending ".0" is enough to
                // keep a finite value a floating-point literal.
                if v.is_finite() && !text.contains('.') {
                    text + ".0"
                } else {
                    text
                }
            }
            ScalarValue::Utf8(v) => format!("'{}'", v.replace('\'', "''")),
            ScalarValue::Boolean(true) => "TRUE".to_string(),
            ScalarValue::Boolean(false) => "FALSE".to_string(),
        }
    }
}
42
/// A named parameter binding: the placeholder name and the value bound to it.
#[derive(Debug, Clone)]
struct NamedParam {
    /// Placeholder name without the leading `:` (e.g. `"age"` for `:age`).
    name: String,
    /// Value to substitute for the placeholder.
    value: ScalarValue,
}
49
50/// Positional `$N` placeholder를 리터럴 값으로 치환
51fn substitute_params(sql: &str, params: &[ScalarValue]) -> String {
52    let mut result = sql.to_string();
53    // 큰 번호부터 치환해야 $10이 $1로 잘못 매치되지 않음
54    for (i, param) in params.iter().enumerate().rev() {
55        let placeholder = format!("${}", i + 1);
56        result = result.replace(&placeholder, &param.to_sql_literal());
57    }
58    result
59}
60
61/// Named `:name` placeholder를 positional `$N`으로 변환하고
62/// 파라미터 순서를 재배열
63fn resolve_named_params(
64    sql: &str,
65    named: &[NamedParam],
66    positional: &[ScalarValue],
67) -> DbxResult<(String, Vec<ScalarValue>)> {
68    if named.is_empty() {
69        // Positional only — 그대로 사용
70        return Ok((sql.to_string(), positional.to_vec()));
71    }
72    if !positional.is_empty() {
73        return Err(DbxError::InvalidOperation {
74            message: "positional과 named 파라미터를 동시에 사용할 수 없습니다".to_string(),
75            context: "resolve_named_params".to_string(),
76        });
77    }
78
79    let mut result_sql = sql.to_string();
80    let mut ordered_params = Vec::new();
81    let mut idx = 1;
82
83    for np in named {
84        let named_placeholder = format!(":{}", np.name);
85        if result_sql.contains(&named_placeholder) {
86            let positional_placeholder = format!("${idx}");
87            result_sql = result_sql.replace(&named_placeholder, &positional_placeholder);
88            ordered_params.push(np.value.clone());
89            idx += 1;
90        }
91    }
92
93    Ok((result_sql, ordered_params))
94}
95
96/// SQL에 파라미터를 적용하여 최종 실행 가능한 SQL 생성
97fn apply_params(sql: &str, positional: &[ScalarValue], named: &[NamedParam]) -> DbxResult<String> {
98    let (resolved_sql, params) = resolve_named_params(sql, named, positional)?;
99    Ok(substitute_params(&resolved_sql, &params))
100}
101
102/// Query Builder — 여러 행 반환
103pub struct Query<'a, T> {
104    db: &'a Database,
105    sql: String,
106    params: Vec<ScalarValue>,
107    named_params: Vec<NamedParam>,
108    _marker: PhantomData<T>,
109}
110
111impl<'a, T: FromRow> Query<'a, T> {
112    pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
113        Self {
114            db,
115            sql: sql.into(),
116            params: Vec::new(),
117            named_params: Vec::new(),
118            _marker: PhantomData,
119        }
120    }
121
122    /// Positional 파라미터 바인딩 ($1, $2, ...)
123    pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
124        self.params.push(value.into_scalar());
125        self
126    }
127
128    /// Named 파라미터 바인딩 (:name, :age, ...)
129    pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
130        self.named_params.push(NamedParam {
131            name: name.to_string(),
132            value: value.into_scalar(),
133        });
134        self
135    }
136
137    /// 모든 행 반환
138    pub fn fetch_all(self) -> DbxResult<Vec<T>> {
139        let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
140        let batches = self.db.execute_sql(&final_sql)?;
141        let mut rows = Vec::new();
142        for batch in &batches {
143            for row_idx in 0..batch.num_rows() {
144                rows.push(T::from_row(batch, row_idx)?);
145            }
146        }
147        Ok(rows)
148    }
149
150    /// 첫 번째 행만 반환 (나머지 무시)
151    pub fn fetch_first(self) -> DbxResult<Option<T>> {
152        let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
153        let batches = self.db.execute_sql(&final_sql)?;
154        for batch in &batches {
155            if batch.num_rows() > 0 {
156                return Ok(Some(T::from_row(batch, 0)?));
157            }
158        }
159        Ok(None)
160    }
161}
162
163/// Query Builder — 단일 행 반환 (없으면 에러)
164pub struct QueryOne<'a, T> {
165    db: &'a Database,
166    sql: String,
167    params: Vec<ScalarValue>,
168    named_params: Vec<NamedParam>,
169    _marker: PhantomData<T>,
170}
171
172impl<'a, T: FromRow> QueryOne<'a, T> {
173    pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
174        Self {
175            db,
176            sql: sql.into(),
177            params: Vec::new(),
178            named_params: Vec::new(),
179            _marker: PhantomData,
180        }
181    }
182
183    pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
184        self.params.push(value.into_scalar());
185        self
186    }
187
188    pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
189        self.named_params.push(NamedParam {
190            name: name.to_string(),
191            value: value.into_scalar(),
192        });
193        self
194    }
195
196    /// 정확히 1개 행 반환 (0개 또는 2개 이상이면 에러)
197    pub fn fetch(self) -> DbxResult<T> {
198        let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
199        let batches = self.db.execute_sql(&final_sql)?;
200        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
201        if total_rows == 0 {
202            return Err(DbxError::KeyNotFound);
203        }
204        if total_rows > 1 {
205            return Err(DbxError::InvalidOperation {
206                message: format!("expected 1 row, got {total_rows}"),
207                context: "QueryOne::fetch".to_string(),
208            });
209        }
210        T::from_row(&batches[0], 0)
211    }
212}
213
214/// Query Builder — 단일 행 반환 (없으면 None)
215pub struct QueryOptional<'a, T> {
216    db: &'a Database,
217    sql: String,
218    params: Vec<ScalarValue>,
219    named_params: Vec<NamedParam>,
220    _marker: PhantomData<T>,
221}
222
223impl<'a, T: FromRow> QueryOptional<'a, T> {
224    pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
225        Self {
226            db,
227            sql: sql.into(),
228            params: Vec::new(),
229            named_params: Vec::new(),
230            _marker: PhantomData,
231        }
232    }
233
234    pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
235        self.params.push(value.into_scalar());
236        self
237    }
238
239    pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
240        self.named_params.push(NamedParam {
241            name: name.to_string(),
242            value: value.into_scalar(),
243        });
244        self
245    }
246
247    pub fn fetch(self) -> DbxResult<Option<T>> {
248        let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
249        let batches = self.db.execute_sql(&final_sql)?;
250        for batch in &batches {
251            if batch.num_rows() > 0 {
252                return Ok(Some(T::from_row(batch, 0)?));
253            }
254        }
255        Ok(None)
256    }
257}
258
259/// Query Builder — Scalar 값 반환
260pub struct QueryScalar<'a, T> {
261    db: &'a Database,
262    sql: String,
263    params: Vec<ScalarValue>,
264    _marker: PhantomData<T>,
265}
266
267impl<'a, T: FromScalar> QueryScalar<'a, T> {
268    pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
269        Self {
270            db,
271            sql: sql.into(),
272            params: Vec::new(),
273            _marker: PhantomData,
274        }
275    }
276
277    pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
278        self.params.push(value.into_scalar());
279        self
280    }
281
282    pub fn fetch(self) -> DbxResult<T> {
283        let final_sql = substitute_params(&self.sql, &self.params);
284        let batches = self.db.execute_sql(&final_sql)?;
285        for batch in &batches {
286            if batch.num_rows() > 0 && batch.num_columns() > 0 {
287                let col = batch.column(0);
288                let sv = crate::storage::columnar::ScalarValue::from_array(col, 0)?;
289                let qsv = scalar_to_query_scalar(&sv);
290                return T::from_scalar(&qsv);
291            }
292        }
293        Err(DbxError::KeyNotFound)
294    }
295}
296
297/// Execute Builder — INSERT/UPDATE/DELETE
298pub struct Execute<'a> {
299    db: &'a Database,
300    sql: String,
301    params: Vec<ScalarValue>,
302    named_params: Vec<NamedParam>,
303}
304
305impl<'a> Execute<'a> {
306    pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
307        Self {
308            db,
309            sql: sql.into(),
310            params: Vec::new(),
311            named_params: Vec::new(),
312        }
313    }
314
315    pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
316        self.params.push(value.into_scalar());
317        self
318    }
319
320    pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
321        self.named_params.push(NamedParam {
322            name: name.to_string(),
323            value: value.into_scalar(),
324        });
325        self
326    }
327
328    /// INSERT/UPDATE/DELETE 실행 → 영향받은 행 수
329    pub fn run(self) -> DbxResult<usize> {
330        let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
331        let batches = self.db.execute_sql(&final_sql)?;
332        Ok(batches.iter().map(|b| b.num_rows()).sum())
333    }
334}
335
/// Conversion of a Rust value into a query parameter [`ScalarValue`].
///
/// Every builder's `bind` and `param` method accepts any `V: IntoParam`.
pub trait IntoParam {
    /// Consumes `self` and returns its parameter representation.
    fn into_scalar(self) -> ScalarValue;
}
340
/// Extraction of a Rust value from a result [`ScalarValue`].
///
/// `QueryScalar::fetch` uses it to decode the first column of the first row.
pub trait FromScalar: Sized {
    /// Converts `value` into `Self`, returning an error when the variant
    /// cannot be represented as `Self`. The implementations in this module
    /// return `DbxError::TypeMismatch` in that case.
    fn from_scalar(value: &ScalarValue) -> DbxResult<Self>;
}
345
346/// Convert columnar::ScalarValue to query::ScalarValue
347fn scalar_to_query_scalar(sv: &crate::storage::columnar::ScalarValue) -> ScalarValue {
348    use crate::storage::columnar::ScalarValue as CSV;
349    match sv {
350        CSV::Null => ScalarValue::Null,
351        CSV::Int32(v) => ScalarValue::Int32(*v),
352        CSV::Int64(v) => ScalarValue::Int64(*v),
353        CSV::Float64(v) => ScalarValue::Float64(*v),
354        CSV::Utf8(v) => ScalarValue::Utf8(v.clone()),
355        CSV::Boolean(v) => ScalarValue::Boolean(*v),
356        CSV::Binary(_) => {
357            // Binary type is not supported in query builder ScalarValue
358            // This should not happen in normal query operations
359            ScalarValue::Null
360        }
361    }
362}
363
364// 기본 타입 구현
365impl IntoParam for i32 {
366    fn into_scalar(self) -> ScalarValue {
367        ScalarValue::Int32(self)
368    }
369}
370
371impl IntoParam for i64 {
372    fn into_scalar(self) -> ScalarValue {
373        ScalarValue::Int64(self)
374    }
375}
376
377impl IntoParam for f64 {
378    fn into_scalar(self) -> ScalarValue {
379        ScalarValue::Float64(self)
380    }
381}
382
383impl IntoParam for &str {
384    fn into_scalar(self) -> ScalarValue {
385        ScalarValue::Utf8(self.to_string())
386    }
387}
388
389impl IntoParam for String {
390    fn into_scalar(self) -> ScalarValue {
391        ScalarValue::Utf8(self)
392    }
393}
394
395impl IntoParam for bool {
396    fn into_scalar(self) -> ScalarValue {
397        ScalarValue::Boolean(self)
398    }
399}
400
401impl<T: IntoParam> IntoParam for Option<T> {
402    fn into_scalar(self) -> ScalarValue {
403        match self {
404            Some(v) => v.into_scalar(),
405            None => ScalarValue::Null,
406        }
407    }
408}
409
410// FromScalar 구현
411impl FromScalar for i64 {
412    fn from_scalar(value: &ScalarValue) -> DbxResult<Self> {
413        match value {
414            ScalarValue::Int64(v) => Ok(*v),
415            _ => Err(crate::error::DbxError::TypeMismatch {
416                expected: "Int64".to_string(),
417                actual: format!("{:?}", value),
418            }),
419        }
420    }
421}
422
423impl FromScalar for i32 {
424    fn from_scalar(value: &ScalarValue) -> DbxResult<Self> {
425        match value {
426            ScalarValue::Int32(v) => Ok(*v),
427            _ => Err(crate::error::DbxError::TypeMismatch {
428                expected: "Int32".to_string(),
429                actual: format!("{:?}", value),
430            }),
431        }
432    }
433}
434
435impl FromScalar for f64 {
436    fn from_scalar(value: &ScalarValue) -> DbxResult<Self> {
437        match value {
438            ScalarValue::Float64(v) => Ok(*v),
439            _ => Err(crate::error::DbxError::TypeMismatch {
440                expected: "Float64".to_string(),
441                actual: format!("{:?}", value),
442            }),
443        }
444    }
445}
446
447// Database에 Query Builder 메서드 추가
448impl Database {
449    /// SELECT 쿼리 — 여러 행 반환
450    pub fn query<T: FromRow>(&self, sql: impl Into<String>) -> Query<'_, T> {
451        Query::new(self, sql)
452    }
453
454    /// SELECT 쿼리 — 단일 행 반환 (없으면 에러)
455    pub fn query_one<T: FromRow>(&self, sql: impl Into<String>) -> QueryOne<'_, T> {
456        QueryOne::new(self, sql)
457    }
458
459    /// SELECT 쿼리 — 단일 행 반환 (없으면 None)
460    pub fn query_optional<T: FromRow>(&self, sql: impl Into<String>) -> QueryOptional<'_, T> {
461        QueryOptional::new(self, sql)
462    }
463
464    /// SELECT 쿼리 — 단일 스칼라 값 반환
465    pub fn query_scalar<T: FromScalar>(&self, sql: impl Into<String>) -> QueryScalar<'_, T> {
466        QueryScalar::new(self, sql)
467    }
468
469    /// INSERT/UPDATE/DELETE — 영향받은 행 수 반환
470    pub fn execute(&self, sql: impl Into<String>) -> Execute<'_> {
471        Execute::new(self, sql)
472    }
473}
474
// ════════════════════════════════════════════
// DatabaseQuery Trait Implementation
// ════════════════════════════════════════════

impl crate::traits::DatabaseQuery for Database {
    // This impl defines no items. Inherent methods (such as those in the
    // `impl Database` block above) never implement trait items in Rust, so
    // this empty impl compiles only if every item of `DatabaseQuery` has a
    // default body (or the trait declares no items).
    // NOTE(review): confirm against the trait definition in `crate::traits`.
}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `NamedParam` fixture.
    fn named(name: &str, value: ScalarValue) -> NamedParam {
        NamedParam {
            name: name.into(),
            value,
        }
    }

    #[test]
    fn test_scalar_to_sql_literal() {
        let cases = [
            (ScalarValue::Null, "NULL"),
            (ScalarValue::Int32(42), "42"),
            (ScalarValue::Int64(100), "100"),
            (ScalarValue::Float64(3.1), "3.1"),
            (ScalarValue::Utf8("hello".into()), "'hello'"),
            (ScalarValue::Boolean(true), "TRUE"),
            (ScalarValue::Boolean(false), "FALSE"),
        ];
        for (value, expected) in cases {
            assert_eq!(value.to_sql_literal(), expected, "literal for {value:?}");
        }
    }

    #[test]
    fn test_sql_literal_single_quote_escape() {
        // Injection guard: an embedded single quote must be doubled.
        let escaped = ScalarValue::Utf8("O'Brien".into()).to_sql_literal();
        assert_eq!(escaped, "'O''Brien'");
    }

    #[test]
    fn test_substitute_params_positional() {
        let params = [ScalarValue::Int32(42), ScalarValue::Int64(18)];
        let sql = substitute_params("SELECT * FROM users WHERE id = $1 AND age > $2", &params);
        assert_eq!(sql, "SELECT * FROM users WHERE id = 42 AND age > 18");
    }

    #[test]
    fn test_substitute_params_string() {
        let params = [ScalarValue::Utf8("Alice".into())];
        let sql = substitute_params("SELECT * FROM users WHERE name = $1", &params);
        assert_eq!(sql, "SELECT * FROM users WHERE name = 'Alice'");
    }

    #[test]
    fn test_substitute_params_null_bool() {
        let params = [
            ScalarValue::Null,
            ScalarValue::Boolean(true),
            ScalarValue::Boolean(false),
        ];
        let sql = substitute_params("SELECT * FROM t WHERE a = $1 AND b = $2 AND c = $3", &params);
        assert_eq!(sql, "SELECT * FROM t WHERE a = NULL AND b = TRUE AND c = FALSE");
    }

    #[test]
    fn test_substitute_params_reverse_order_safety() {
        // `$10` must not be treated as `$1` followed by a literal `0`.
        let params: Vec<ScalarValue> = (1..=10).map(ScalarValue::Int32).collect();
        assert_eq!(substitute_params("SELECT $1, $10", &params), "SELECT 1, 10");
    }

    #[test]
    fn test_resolve_named_params() {
        let named_params = [
            named("name", ScalarValue::Utf8("Alice".into())),
            named("age", ScalarValue::Int32(18)),
        ];
        let (resolved_sql, params) = resolve_named_params(
            "SELECT * FROM users WHERE name = :name AND age > :age",
            &named_params,
            &[],
        )
        .unwrap();
        assert_eq!(resolved_sql, "SELECT * FROM users WHERE name = $1 AND age > $2");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_resolve_named_params_empty() {
        let sql = "SELECT * FROM users WHERE id = $1";
        let positional = [ScalarValue::Int32(5)];
        let (resolved_sql, params) = resolve_named_params(sql, &[], &positional).unwrap();
        assert_eq!(resolved_sql, sql);
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_mixed_params_error() {
        let positional = [ScalarValue::Int32(1)];
        let named_params = [named("a", ScalarValue::Int32(2))];
        let outcome = resolve_named_params("SELECT * FROM users", &named_params, &positional);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_apply_params_named_full() {
        let named_params = [
            named("x", ScalarValue::Int32(10)),
            named("y", ScalarValue::Utf8("hello".into())),
        ];
        let sql = apply_params("SELECT * FROM t WHERE x = :x AND y = :y", &[], &named_params)
            .unwrap();
        assert_eq!(sql, "SELECT * FROM t WHERE x = 10 AND y = 'hello'");
    }

    #[test]
    fn test_apply_params_positional_full() {
        let params = [
            ScalarValue::Int32(1),
            ScalarValue::Utf8("test".into()),
            ScalarValue::Float64(3.1),
        ];
        let sql = apply_params("INSERT INTO t VALUES ($1, $2, $3)", &params, &[]).unwrap();
        assert_eq!(sql, "INSERT INTO t VALUES (1, 'test', 3.1)");
    }

    #[test]
    fn test_into_param_trait() {
        assert!(matches!(42i32.into_scalar(), ScalarValue::Int32(42)));
        assert!(matches!(100i64.into_scalar(), ScalarValue::Int64(100)));
        assert!(matches!(3.1f64.into_scalar(), ScalarValue::Float64(_)));
        assert!(matches!("hello".into_scalar(), ScalarValue::Utf8(_)));
        assert!(matches!(true.into_scalar(), ScalarValue::Boolean(true)));
        assert!(matches!(None::<i32>.into_scalar(), ScalarValue::Null));
        assert!(matches!(Some(10i32).into_scalar(), ScalarValue::Int32(10)));
    }
}