1use crate::api::FromRow;
8use crate::engine::Database;
9use crate::error::{DbxError, DbxResult};
10use std::marker::PhantomData;
11
/// A single parameter value that can be bound to a query placeholder.
///
/// Values are rendered into SQL text with [`ScalarValue::to_sql_literal`] before the
/// statement is executed.
#[derive(Debug, Clone)]
pub enum ScalarValue {
    /// SQL `NULL`.
    Null,
    /// 32-bit signed integer.
    Int32(i32),
    /// 64-bit signed integer.
    Int64(i64),
    /// 64-bit floating-point number.
    Float64(f64),
    /// UTF-8 text.
    Utf8(String),
    /// Boolean.
    Boolean(bool),
}

impl ScalarValue {
    /// Renders the value as a SQL literal that can be spliced into a statement.
    ///
    /// * `Null` becomes `NULL`; booleans become `TRUE` / `FALSE`.
    /// * Integers use their decimal representation.
    /// * Finite floats always carry a decimal point (`2.0`, never `2`), so the SQL parser
    ///   reads them as floating-point rather than integer literals.
    /// * Strings are wrapped in single quotes with embedded quotes doubled
    ///   (`O'Brien` becomes `'O''Brien'`).
    pub fn to_sql_literal(&self) -> String {
        match self {
            ScalarValue::Null => "NULL".to_string(),
            ScalarValue::Int32(v) => v.to_string(),
            ScalarValue::Int64(v) => v.to_string(),
            ScalarValue::Float64(v) => {
                // `Display` for f64 never uses exponent notation, but it drops the fractional
                // part of integral values (`2.0` -> "2"). SQL would parse that as an integer,
                // so e.g. `7 / 2.0` would turn into integer division. Re-add `.0`.
                let text = v.to_string();
                if v.is_finite() && !text.contains('.') {
                    format!("{text}.0")
                } else {
                    // NOTE(review): NaN / ±inf render as `NaN` / `inf` / `-inf`, which are not
                    // SQL literals; left unchanged pending a decision on the target dialect.
                    text
                }
            }
            ScalarValue::Utf8(v) => format!("'{}'", v.replace('\'', "''")),
            ScalarValue::Boolean(true) => "TRUE".to_string(),
            ScalarValue::Boolean(false) => "FALSE".to_string(),
        }
    }
}
42
43#[derive(Debug, Clone)]
45struct NamedParam {
46 name: String,
47 value: ScalarValue,
48}
49
50fn substitute_params(sql: &str, params: &[ScalarValue]) -> String {
52 let mut result = sql.to_string();
53 for (i, param) in params.iter().enumerate().rev() {
55 let placeholder = format!("${}", i + 1);
56 result = result.replace(&placeholder, ¶m.to_sql_literal());
57 }
58 result
59}
60
61fn resolve_named_params(
64 sql: &str,
65 named: &[NamedParam],
66 positional: &[ScalarValue],
67) -> DbxResult<(String, Vec<ScalarValue>)> {
68 if named.is_empty() {
69 return Ok((sql.to_string(), positional.to_vec()));
71 }
72 if !positional.is_empty() {
73 return Err(DbxError::InvalidOperation {
74 message: "positional과 named 파라미터를 동시에 사용할 수 없습니다".to_string(),
75 context: "resolve_named_params".to_string(),
76 });
77 }
78
79 let mut result_sql = sql.to_string();
80 let mut ordered_params = Vec::new();
81 let mut idx = 1;
82
83 for np in named {
84 let named_placeholder = format!(":{}", np.name);
85 if result_sql.contains(&named_placeholder) {
86 let positional_placeholder = format!("${idx}");
87 result_sql = result_sql.replace(&named_placeholder, &positional_placeholder);
88 ordered_params.push(np.value.clone());
89 idx += 1;
90 }
91 }
92
93 Ok((result_sql, ordered_params))
94}
95
96fn apply_params(sql: &str, positional: &[ScalarValue], named: &[NamedParam]) -> DbxResult<String> {
98 let (resolved_sql, params) = resolve_named_params(sql, named, positional)?;
99 Ok(substitute_params(&resolved_sql, ¶ms))
100}
101
102pub struct Query<'a, T> {
104 db: &'a Database,
105 sql: String,
106 params: Vec<ScalarValue>,
107 named_params: Vec<NamedParam>,
108 _marker: PhantomData<T>,
109}
110
111impl<'a, T: FromRow> Query<'a, T> {
112 pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
113 Self {
114 db,
115 sql: sql.into(),
116 params: Vec::new(),
117 named_params: Vec::new(),
118 _marker: PhantomData,
119 }
120 }
121
122 pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
124 self.params.push(value.into_scalar());
125 self
126 }
127
128 pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
130 self.named_params.push(NamedParam {
131 name: name.to_string(),
132 value: value.into_scalar(),
133 });
134 self
135 }
136
137 pub fn fetch_all(self) -> DbxResult<Vec<T>> {
139 let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
140 let batches = self.db.execute_sql(&final_sql)?;
141 let mut rows = Vec::new();
142 for batch in &batches {
143 for row_idx in 0..batch.num_rows() {
144 rows.push(T::from_row(batch, row_idx)?);
145 }
146 }
147 Ok(rows)
148 }
149
150 pub fn fetch_first(self) -> DbxResult<Option<T>> {
152 let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
153 let batches = self.db.execute_sql(&final_sql)?;
154 for batch in &batches {
155 if batch.num_rows() > 0 {
156 return Ok(Some(T::from_row(batch, 0)?));
157 }
158 }
159 Ok(None)
160 }
161}
162
163pub struct QueryOne<'a, T> {
165 db: &'a Database,
166 sql: String,
167 params: Vec<ScalarValue>,
168 named_params: Vec<NamedParam>,
169 _marker: PhantomData<T>,
170}
171
172impl<'a, T: FromRow> QueryOne<'a, T> {
173 pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
174 Self {
175 db,
176 sql: sql.into(),
177 params: Vec::new(),
178 named_params: Vec::new(),
179 _marker: PhantomData,
180 }
181 }
182
183 pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
184 self.params.push(value.into_scalar());
185 self
186 }
187
188 pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
189 self.named_params.push(NamedParam {
190 name: name.to_string(),
191 value: value.into_scalar(),
192 });
193 self
194 }
195
196 pub fn fetch(self) -> DbxResult<T> {
198 let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
199 let batches = self.db.execute_sql(&final_sql)?;
200 let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
201 if total_rows == 0 {
202 return Err(DbxError::KeyNotFound);
203 }
204 if total_rows > 1 {
205 return Err(DbxError::InvalidOperation {
206 message: format!("expected 1 row, got {total_rows}"),
207 context: "QueryOne::fetch".to_string(),
208 });
209 }
210 T::from_row(&batches[0], 0)
211 }
212}
213
214pub struct QueryOptional<'a, T> {
216 db: &'a Database,
217 sql: String,
218 params: Vec<ScalarValue>,
219 named_params: Vec<NamedParam>,
220 _marker: PhantomData<T>,
221}
222
223impl<'a, T: FromRow> QueryOptional<'a, T> {
224 pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
225 Self {
226 db,
227 sql: sql.into(),
228 params: Vec::new(),
229 named_params: Vec::new(),
230 _marker: PhantomData,
231 }
232 }
233
234 pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
235 self.params.push(value.into_scalar());
236 self
237 }
238
239 pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
240 self.named_params.push(NamedParam {
241 name: name.to_string(),
242 value: value.into_scalar(),
243 });
244 self
245 }
246
247 pub fn fetch(self) -> DbxResult<Option<T>> {
248 let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
249 let batches = self.db.execute_sql(&final_sql)?;
250 for batch in &batches {
251 if batch.num_rows() > 0 {
252 return Ok(Some(T::from_row(batch, 0)?));
253 }
254 }
255 Ok(None)
256 }
257}
258
259pub struct QueryScalar<'a, T> {
261 db: &'a Database,
262 sql: String,
263 params: Vec<ScalarValue>,
264 _marker: PhantomData<T>,
265}
266
267impl<'a, T: FromScalar> QueryScalar<'a, T> {
268 pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
269 Self {
270 db,
271 sql: sql.into(),
272 params: Vec::new(),
273 _marker: PhantomData,
274 }
275 }
276
277 pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
278 self.params.push(value.into_scalar());
279 self
280 }
281
282 pub fn fetch(self) -> DbxResult<T> {
283 let final_sql = substitute_params(&self.sql, &self.params);
284 let batches = self.db.execute_sql(&final_sql)?;
285 for batch in &batches {
286 if batch.num_rows() > 0 && batch.num_columns() > 0 {
287 let col = batch.column(0);
288 let sv = crate::storage::columnar::ScalarValue::from_array(col, 0)?;
289 let qsv = scalar_to_query_scalar(&sv);
290 return T::from_scalar(&qsv);
291 }
292 }
293 Err(DbxError::KeyNotFound)
294 }
295}
296
297pub struct Execute<'a> {
299 db: &'a Database,
300 sql: String,
301 params: Vec<ScalarValue>,
302 named_params: Vec<NamedParam>,
303}
304
305impl<'a> Execute<'a> {
306 pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
307 Self {
308 db,
309 sql: sql.into(),
310 params: Vec::new(),
311 named_params: Vec::new(),
312 }
313 }
314
315 pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
316 self.params.push(value.into_scalar());
317 self
318 }
319
320 pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
321 self.named_params.push(NamedParam {
322 name: name.to_string(),
323 value: value.into_scalar(),
324 });
325 self
326 }
327
328 pub fn run(self) -> DbxResult<usize> {
330 let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
331 let batches = self.db.execute_sql(&final_sql)?;
332 Ok(batches.iter().map(|b| b.num_rows()).sum())
333 }
334}
335
/// Conversion of a Rust value into a bindable query parameter.
///
/// Implemented for the types accepted by the builders' `bind` and `param` methods.
pub trait IntoParam {
    /// Consumes `self` and returns the equivalent [`ScalarValue`].
    fn into_scalar(self) -> ScalarValue;
}
340
/// Conversion from a single query result value into a Rust type (used by `QueryScalar`).
pub trait FromScalar: Sized {
    /// Converts `value` into `Self`, returning an error when the variant is not supported.
    fn from_scalar(value: &ScalarValue) -> DbxResult<Self>;
}
345
346fn scalar_to_query_scalar(sv: &crate::storage::columnar::ScalarValue) -> ScalarValue {
348 use crate::storage::columnar::ScalarValue as CSV;
349 match sv {
350 CSV::Null => ScalarValue::Null,
351 CSV::Int32(v) => ScalarValue::Int32(*v),
352 CSV::Int64(v) => ScalarValue::Int64(*v),
353 CSV::Float64(v) => ScalarValue::Float64(*v),
354 CSV::Utf8(v) => ScalarValue::Utf8(v.clone()),
355 CSV::Boolean(v) => ScalarValue::Boolean(*v),
356 CSV::Binary(_) => {
357 ScalarValue::Null
360 }
361 }
362}
363
364impl IntoParam for i32 {
366 fn into_scalar(self) -> ScalarValue {
367 ScalarValue::Int32(self)
368 }
369}
370
371impl IntoParam for i64 {
372 fn into_scalar(self) -> ScalarValue {
373 ScalarValue::Int64(self)
374 }
375}
376
377impl IntoParam for f64 {
378 fn into_scalar(self) -> ScalarValue {
379 ScalarValue::Float64(self)
380 }
381}
382
383impl IntoParam for &str {
384 fn into_scalar(self) -> ScalarValue {
385 ScalarValue::Utf8(self.to_string())
386 }
387}
388
389impl IntoParam for String {
390 fn into_scalar(self) -> ScalarValue {
391 ScalarValue::Utf8(self)
392 }
393}
394
395impl IntoParam for bool {
396 fn into_scalar(self) -> ScalarValue {
397 ScalarValue::Boolean(self)
398 }
399}
400
401impl<T: IntoParam> IntoParam for Option<T> {
402 fn into_scalar(self) -> ScalarValue {
403 match self {
404 Some(v) => v.into_scalar(),
405 None => ScalarValue::Null,
406 }
407 }
408}
409
410impl FromScalar for i64 {
412 fn from_scalar(value: &ScalarValue) -> DbxResult<Self> {
413 match value {
414 ScalarValue::Int64(v) => Ok(*v),
415 _ => Err(crate::error::DbxError::TypeMismatch {
416 expected: "Int64".to_string(),
417 actual: format!("{:?}", value),
418 }),
419 }
420 }
421}
422
423impl FromScalar for i32 {
424 fn from_scalar(value: &ScalarValue) -> DbxResult<Self> {
425 match value {
426 ScalarValue::Int32(v) => Ok(*v),
427 _ => Err(crate::error::DbxError::TypeMismatch {
428 expected: "Int32".to_string(),
429 actual: format!("{:?}", value),
430 }),
431 }
432 }
433}
434
435impl FromScalar for f64 {
436 fn from_scalar(value: &ScalarValue) -> DbxResult<Self> {
437 match value {
438 ScalarValue::Float64(v) => Ok(*v),
439 _ => Err(crate::error::DbxError::TypeMismatch {
440 expected: "Float64".to_string(),
441 actual: format!("{:?}", value),
442 }),
443 }
444 }
445}
446
447impl Database {
449 pub fn query<T: FromRow>(&self, sql: impl Into<String>) -> Query<'_, T> {
451 Query::new(self, sql)
452 }
453
454 pub fn query_one<T: FromRow>(&self, sql: impl Into<String>) -> QueryOne<'_, T> {
456 QueryOne::new(self, sql)
457 }
458
459 pub fn query_optional<T: FromRow>(&self, sql: impl Into<String>) -> QueryOptional<'_, T> {
461 QueryOptional::new(self, sql)
462 }
463
464 pub fn query_scalar<T: FromScalar>(&self, sql: impl Into<String>) -> QueryScalar<'_, T> {
466 QueryScalar::new(self, sql)
467 }
468
469 pub fn execute(&self, sql: impl Into<String>) -> Execute<'_> {
471 Execute::new(self, sql)
472 }
473}
474
475impl crate::traits::DatabaseQuery for Database {
480 }
483
// Unit tests for literal rendering, placeholder substitution, named-parameter resolution,
// and the IntoParam conversions.
#[cfg(test)]
mod tests {
    use super::*;

    // Every ScalarValue variant renders to the expected SQL literal text.
    #[test]
    fn test_scalar_to_sql_literal() {
        assert_eq!(ScalarValue::Null.to_sql_literal(), "NULL");
        assert_eq!(ScalarValue::Int32(42).to_sql_literal(), "42");
        assert_eq!(ScalarValue::Int64(100).to_sql_literal(), "100");
        assert_eq!(ScalarValue::Float64(3.1).to_sql_literal(), "3.1");
        assert_eq!(
            ScalarValue::Utf8("hello".into()).to_sql_literal(),
            "'hello'"
        );
        assert_eq!(ScalarValue::Boolean(true).to_sql_literal(), "TRUE");
        assert_eq!(ScalarValue::Boolean(false).to_sql_literal(), "FALSE");
    }

    #[test]
    fn test_sql_literal_single_quote_escape() {
        // A single quote inside a string value is escaped by doubling it.
        assert_eq!(
            ScalarValue::Utf8("O'Brien".into()).to_sql_literal(),
            "'O''Brien'"
        );
    }

    // `$N` is replaced by the literal of the N-th bound value (1-based).
    #[test]
    fn test_substitute_params_positional() {
        let sql = "SELECT * FROM users WHERE id = $1 AND age > $2";
        let params = vec![ScalarValue::Int32(42), ScalarValue::Int64(18)];
        let result = substitute_params(sql, &params);
        assert_eq!(result, "SELECT * FROM users WHERE id = 42 AND age > 18");
    }

    // String values are substituted as quoted literals.
    #[test]
    fn test_substitute_params_string() {
        let sql = "SELECT * FROM users WHERE name = $1";
        let params = vec![ScalarValue::Utf8("Alice".into())];
        let result = substitute_params(sql, &params);
        assert_eq!(result, "SELECT * FROM users WHERE name = 'Alice'");
    }

    // NULL and boolean values are substituted as SQL keywords.
    #[test]
    fn test_substitute_params_null_bool() {
        let sql = "SELECT * FROM t WHERE a = $1 AND b = $2 AND c = $3";
        let params = vec![
            ScalarValue::Null,
            ScalarValue::Boolean(true),
            ScalarValue::Boolean(false),
        ];
        let result = substitute_params(sql, &params);
        assert_eq!(
            result,
            "SELECT * FROM t WHERE a = NULL AND b = TRUE AND c = FALSE"
        );
    }

    #[test]
    fn test_substitute_params_reverse_order_safety() {
        let sql = "SELECT $1, $10";
        // Bind ten values so `$10` is in range; `$1` must not rewrite the prefix of `$10`.
        let mut params = vec![ScalarValue::Int32(1)];
        for i in 2..=10 {
            params.push(ScalarValue::Int32(i));
        }
        let result = substitute_params(sql, &params);
        assert_eq!(result, "SELECT 1, 10");
    }

    // `:name` placeholders are rewritten to `$N` in binding order.
    #[test]
    fn test_resolve_named_params() {
        let sql = "SELECT * FROM users WHERE name = :name AND age > :age";
        let named = vec![
            NamedParam {
                name: "name".into(),
                value: ScalarValue::Utf8("Alice".into()),
            },
            NamedParam {
                name: "age".into(),
                value: ScalarValue::Int32(18),
            },
        ];
        let (resolved_sql, params) = resolve_named_params(sql, &named, &[]).unwrap();
        assert_eq!(
            resolved_sql,
            "SELECT * FROM users WHERE name = $1 AND age > $2"
        );
        assert_eq!(params.len(), 2);
    }

    // With no named params, the SQL and positional values pass through unchanged.
    #[test]
    fn test_resolve_named_params_empty() {
        let sql = "SELECT * FROM users WHERE id = $1";
        let positional = vec![ScalarValue::Int32(5)];
        let (resolved_sql, params) = resolve_named_params(sql, &[], &positional).unwrap();
        assert_eq!(resolved_sql, sql);
        assert_eq!(params.len(), 1);
    }

    // Supplying both positional and named params is rejected.
    #[test]
    fn test_mixed_params_error() {
        let sql = "SELECT * FROM users";
        let positional = vec![ScalarValue::Int32(1)];
        let named = vec![NamedParam {
            name: "a".into(),
            value: ScalarValue::Int32(2),
        }];
        let result = resolve_named_params(sql, &named, &positional);
        assert!(result.is_err());
    }

    // End to end: named params are resolved to `$N`, then substituted as literals.
    #[test]
    fn test_apply_params_named_full() {
        let sql = "SELECT * FROM t WHERE x = :x AND y = :y";
        let named = vec![
            NamedParam {
                name: "x".into(),
                value: ScalarValue::Int32(10),
            },
            NamedParam {
                name: "y".into(),
                value: ScalarValue::Utf8("hello".into()),
            },
        ];
        let result = apply_params(sql, &[], &named).unwrap();
        assert_eq!(result, "SELECT * FROM t WHERE x = 10 AND y = 'hello'");
    }

    // End to end with positional params only.
    #[test]
    fn test_apply_params_positional_full() {
        let sql = "INSERT INTO t VALUES ($1, $2, $3)";
        let params = vec![
            ScalarValue::Int32(1),
            ScalarValue::Utf8("test".into()),
            ScalarValue::Float64(3.1),
        ];
        let result = apply_params(sql, &params, &[]).unwrap();
        assert_eq!(result, "INSERT INTO t VALUES (1, 'test', 3.1)");
    }

    // Each IntoParam impl produces the matching variant; `None` becomes `Null`.
    #[test]
    fn test_into_param_trait() {
        assert!(matches!(42i32.into_scalar(), ScalarValue::Int32(42)));
        assert!(matches!(100i64.into_scalar(), ScalarValue::Int64(100)));
        assert!(matches!(3.1f64.into_scalar(), ScalarValue::Float64(_)));
        assert!(matches!("hello".into_scalar(), ScalarValue::Utf8(_)));
        assert!(matches!(true.into_scalar(), ScalarValue::Boolean(true)));
        assert!(matches!(
            Option::<i32>::None.into_scalar(),
            ScalarValue::Null
        ));
        assert!(matches!(Some(10i32).into_scalar(), ScalarValue::Int32(10)));
    }
}