1use crate::api::FromRow;
8use crate::engine::Database;
9use crate::error::{DbxError, DbxResult};
10use std::marker::PhantomData;
11
/// A dynamically typed value that can be bound to a query placeholder or read back
/// from a query result.
#[derive(Debug, Clone)]
pub enum ScalarValue {
    /// SQL `NULL`.
    Null,
    /// 32-bit signed integer.
    Int32(i32),
    /// 64-bit signed integer.
    Int64(i64),
    /// 64-bit floating-point number.
    Float64(f64),
    /// UTF-8 text.
    Utf8(String),
    /// Boolean.
    Boolean(bool),
}

impl ScalarValue {
    /// Renders this value as a SQL literal for textual substitution into a statement.
    ///
    /// - `Null` becomes `NULL`, booleans become `TRUE` / `FALSE`.
    /// - Strings are wrapped in single quotes with every embedded `'` doubled (`''`).
    /// - Finite floats always keep a decimal point, so an integral value such as `1.0`
    ///   is emitted as `1.0` rather than `1` and stays a floating-point literal.
    ///
    /// NOTE: non-finite floats (NaN, ±infinity) have no portable SQL literal; they are
    /// emitted exactly as Rust's `Display` formats them.
    pub fn to_sql_literal(&self) -> String {
        match self {
            ScalarValue::Null => "NULL".to_string(),
            ScalarValue::Int32(v) => v.to_string(),
            ScalarValue::Int64(v) => v.to_string(),
            ScalarValue::Float64(v) => {
                // `Display` for f64 never uses exponent notation, but it drops the
                // fractional part of integral values ("1" for 1.0), which would turn the
                // literal into an integer. Re-add ".0" in that case.
                let text = v.to_string();
                if v.is_finite() && !text.contains('.') {
                    format!("{text}.0")
                } else {
                    text
                }
            }
            ScalarValue::Utf8(v) => format!("'{}'", v.replace('\'', "''")),
            ScalarValue::Boolean(v) => {
                if *v {
                    "TRUE".to_string()
                } else {
                    "FALSE".to_string()
                }
            }
        }
    }
}
42
43#[derive(Debug, Clone)]
45struct NamedParam {
46 name: String,
47 value: ScalarValue,
48}
49
50fn apply_params(sql: &str, positional: &[ScalarValue], named: &[NamedParam]) -> DbxResult<String> {
53 if !positional.is_empty() && !named.is_empty() {
54 return Err(DbxError::InvalidOperation {
55 message: "positional과 named 파라미터를 동시에 사용할 수 없습니다".to_string(),
56 context: "apply_params".to_string(),
57 });
58 }
59
60 let mut result = String::with_capacity(sql.len() + 64);
61 let mut chars = sql.chars().peekable();
62 let mut in_single_quote = false;
63 let mut in_double_quote = false;
64
65 while let Some(c) = chars.next() {
66 if in_single_quote {
67 result.push(c);
68 if c == '\'' {
69 in_single_quote = false;
70 }
71 continue;
72 } else if in_double_quote {
73 result.push(c);
74 if c == '"' {
75 in_double_quote = false;
76 }
77 continue;
78 }
79
80 match c {
81 '\'' => {
82 in_single_quote = true;
83 result.push(c);
84 }
85 '"' => {
86 in_double_quote = true;
87 result.push(c);
88 }
89 '$' if !positional.is_empty() => {
90 let mut num_str = String::new();
91 while let Some(&next_c) = chars.peek() {
92 if next_c.is_ascii_digit() {
93 num_str.push(next_c);
94 chars.next();
95 } else {
96 break;
97 }
98 }
99 if let Ok(idx) = num_str.parse::<usize>() {
100 if idx > 0 && idx <= positional.len() {
101 result.push_str(&positional[idx - 1].to_sql_literal());
102 } else {
103 result.push('$');
104 result.push_str(&num_str);
105 }
106 } else {
107 result.push('$');
108 result.push_str(&num_str);
109 }
110 }
111 ':' if !named.is_empty() => {
112 let mut name = String::new();
113 while let Some(&next_c) = chars.peek() {
114 if next_c.is_ascii_alphanumeric() || next_c == '_' {
115 name.push(next_c);
116 chars.next();
117 } else {
118 break;
119 }
120 }
121 if !name.is_empty() {
122 let mut found = false;
123 for np in named {
124 if np.name == name {
125 result.push_str(&np.value.to_sql_literal());
126 found = true;
127 break;
128 }
129 }
130 if !found {
131 result.push(':');
132 result.push_str(&name);
133 }
134 } else {
135 result.push(':');
136 }
137 }
138 _ => {
139 result.push(c);
140 }
141 }
142 }
143
144 Ok(result)
145}
146
147pub struct Query<'a, T> {
149 db: &'a Database,
150 sql: String,
151 params: Vec<ScalarValue>,
152 named_params: Vec<NamedParam>,
153 _marker: PhantomData<T>,
154}
155
156impl<'a, T: FromRow> Query<'a, T> {
157 pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
158 Self {
159 db,
160 sql: sql.into(),
161 params: Vec::new(),
162 named_params: Vec::new(),
163 _marker: PhantomData,
164 }
165 }
166
167 pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
169 self.params.push(value.into_scalar());
170 self
171 }
172
173 pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
175 self.named_params.push(NamedParam {
176 name: name.to_string(),
177 value: value.into_scalar(),
178 });
179 self
180 }
181
182 pub fn fetch_all(self) -> DbxResult<Vec<T>> {
184 let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
185 let batches = self.db.execute_sql(&final_sql)?;
186 let mut rows = Vec::new();
187 for batch in &batches {
188 for row_idx in 0..batch.num_rows() {
189 rows.push(T::from_row(batch, row_idx)?);
190 }
191 }
192 Ok(rows)
193 }
194
195 pub fn fetch_first(self) -> DbxResult<Option<T>> {
197 let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
198 let batches = self.db.execute_sql(&final_sql)?;
199 for batch in &batches {
200 if batch.num_rows() > 0 {
201 return Ok(Some(T::from_row(batch, 0)?));
202 }
203 }
204 Ok(None)
205 }
206}
207
208pub struct QueryOne<'a, T> {
210 db: &'a Database,
211 sql: String,
212 params: Vec<ScalarValue>,
213 named_params: Vec<NamedParam>,
214 _marker: PhantomData<T>,
215}
216
217impl<'a, T: FromRow> QueryOne<'a, T> {
218 pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
219 Self {
220 db,
221 sql: sql.into(),
222 params: Vec::new(),
223 named_params: Vec::new(),
224 _marker: PhantomData,
225 }
226 }
227
228 pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
229 self.params.push(value.into_scalar());
230 self
231 }
232
233 pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
234 self.named_params.push(NamedParam {
235 name: name.to_string(),
236 value: value.into_scalar(),
237 });
238 self
239 }
240
241 pub fn fetch(self) -> DbxResult<T> {
243 let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
244 let batches = self.db.execute_sql(&final_sql)?;
245 let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
246 if total_rows == 0 {
247 return Err(DbxError::KeyNotFound);
248 }
249 if total_rows > 1 {
250 return Err(DbxError::InvalidOperation {
251 message: format!("expected 1 row, got {total_rows}"),
252 context: "QueryOne::fetch".to_string(),
253 });
254 }
255 T::from_row(&batches[0], 0)
256 }
257}
258
259pub struct QueryOptional<'a, T> {
261 db: &'a Database,
262 sql: String,
263 params: Vec<ScalarValue>,
264 named_params: Vec<NamedParam>,
265 _marker: PhantomData<T>,
266}
267
268impl<'a, T: FromRow> QueryOptional<'a, T> {
269 pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
270 Self {
271 db,
272 sql: sql.into(),
273 params: Vec::new(),
274 named_params: Vec::new(),
275 _marker: PhantomData,
276 }
277 }
278
279 pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
280 self.params.push(value.into_scalar());
281 self
282 }
283
284 pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
285 self.named_params.push(NamedParam {
286 name: name.to_string(),
287 value: value.into_scalar(),
288 });
289 self
290 }
291
292 pub fn fetch(self) -> DbxResult<Option<T>> {
293 let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
294 let batches = self.db.execute_sql(&final_sql)?;
295 for batch in &batches {
296 if batch.num_rows() > 0 {
297 return Ok(Some(T::from_row(batch, 0)?));
298 }
299 }
300 Ok(None)
301 }
302}
303
304pub struct QueryScalar<'a, T> {
306 db: &'a Database,
307 sql: String,
308 params: Vec<ScalarValue>,
309 _marker: PhantomData<T>,
310}
311
312impl<'a, T: FromScalar> QueryScalar<'a, T> {
313 pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
314 Self {
315 db,
316 sql: sql.into(),
317 params: Vec::new(),
318 _marker: PhantomData,
319 }
320 }
321
322 pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
323 self.params.push(value.into_scalar());
324 self
325 }
326
327 pub fn fetch(self) -> DbxResult<T> {
328 let final_sql = apply_params(&self.sql, &self.params, &[])?;
329 let batches = self.db.execute_sql(&final_sql)?;
330 for batch in &batches {
331 if batch.num_rows() > 0 && batch.num_columns() > 0 {
332 let col = batch.column(0);
333 let sv = crate::storage::columnar::ScalarValue::from_array(col, 0)?;
334 let qsv = scalar_to_query_scalar(&sv);
335 return T::from_scalar(&qsv);
336 }
337 }
338 Err(DbxError::KeyNotFound)
339 }
340}
341
342pub struct Execute<'a> {
344 db: &'a Database,
345 sql: String,
346 params: Vec<ScalarValue>,
347 named_params: Vec<NamedParam>,
348}
349
350impl<'a> Execute<'a> {
351 pub fn new(db: &'a Database, sql: impl Into<String>) -> Self {
352 Self {
353 db,
354 sql: sql.into(),
355 params: Vec::new(),
356 named_params: Vec::new(),
357 }
358 }
359
360 pub fn bind<V: IntoParam>(mut self, value: V) -> Self {
361 self.params.push(value.into_scalar());
362 self
363 }
364
365 pub fn param<V: IntoParam>(mut self, name: &str, value: V) -> Self {
366 self.named_params.push(NamedParam {
367 name: name.to_string(),
368 value: value.into_scalar(),
369 });
370 self
371 }
372
373 pub fn run(self) -> DbxResult<usize> {
375 let final_sql = apply_params(&self.sql, &self.params, &self.named_params)?;
376 let batches = self.db.execute_sql(&final_sql)?;
377 Ok(batches.iter().map(|b| b.num_rows()).sum())
378 }
379}
380
381pub trait IntoParam {
383 fn into_scalar(self) -> ScalarValue;
384}
385
386pub trait FromScalar: Sized {
388 fn from_scalar(value: &ScalarValue) -> DbxResult<Self>;
389}
390
391fn scalar_to_query_scalar(sv: &crate::storage::columnar::ScalarValue) -> ScalarValue {
393 use crate::storage::columnar::ScalarValue as CSV;
394 match sv {
395 CSV::Null => ScalarValue::Null,
396 CSV::Int32(v) => ScalarValue::Int32(*v),
397 CSV::Int64(v) => ScalarValue::Int64(*v),
398 CSV::Float64(v) => ScalarValue::Float64(*v),
399 CSV::Utf8(v) => ScalarValue::Utf8(v.clone()),
400 CSV::Boolean(v) => ScalarValue::Boolean(*v),
401 CSV::Binary(_) => {
402 ScalarValue::Null
405 }
406 }
407}
408
409impl IntoParam for i32 {
411 fn into_scalar(self) -> ScalarValue {
412 ScalarValue::Int32(self)
413 }
414}
415
416impl IntoParam for i64 {
417 fn into_scalar(self) -> ScalarValue {
418 ScalarValue::Int64(self)
419 }
420}
421
422impl IntoParam for f64 {
423 fn into_scalar(self) -> ScalarValue {
424 ScalarValue::Float64(self)
425 }
426}
427
428impl IntoParam for &str {
429 fn into_scalar(self) -> ScalarValue {
430 ScalarValue::Utf8(self.to_string())
431 }
432}
433
434impl IntoParam for String {
435 fn into_scalar(self) -> ScalarValue {
436 ScalarValue::Utf8(self)
437 }
438}
439
440impl IntoParam for bool {
441 fn into_scalar(self) -> ScalarValue {
442 ScalarValue::Boolean(self)
443 }
444}
445
446impl<T: IntoParam> IntoParam for Option<T> {
447 fn into_scalar(self) -> ScalarValue {
448 match self {
449 Some(v) => v.into_scalar(),
450 None => ScalarValue::Null,
451 }
452 }
453}
454
455impl FromScalar for i64 {
457 fn from_scalar(value: &ScalarValue) -> DbxResult<Self> {
458 match value {
459 ScalarValue::Int64(v) => Ok(*v),
460 _ => Err(crate::error::DbxError::TypeMismatch {
461 expected: "Int64".to_string(),
462 actual: format!("{:?}", value),
463 }),
464 }
465 }
466}
467
468impl FromScalar for i32 {
469 fn from_scalar(value: &ScalarValue) -> DbxResult<Self> {
470 match value {
471 ScalarValue::Int32(v) => Ok(*v),
472 _ => Err(crate::error::DbxError::TypeMismatch {
473 expected: "Int32".to_string(),
474 actual: format!("{:?}", value),
475 }),
476 }
477 }
478}
479
480impl FromScalar for f64 {
481 fn from_scalar(value: &ScalarValue) -> DbxResult<Self> {
482 match value {
483 ScalarValue::Float64(v) => Ok(*v),
484 _ => Err(crate::error::DbxError::TypeMismatch {
485 expected: "Float64".to_string(),
486 actual: format!("{:?}", value),
487 }),
488 }
489 }
490}
491
492impl Database {
494 pub fn query<T: FromRow>(&self, sql: impl Into<String>) -> Query<'_, T> {
496 Query::new(self, sql)
497 }
498
499 pub fn query_one<T: FromRow>(&self, sql: impl Into<String>) -> QueryOne<'_, T> {
501 QueryOne::new(self, sql)
502 }
503
504 pub fn query_optional<T: FromRow>(&self, sql: impl Into<String>) -> QueryOptional<'_, T> {
506 QueryOptional::new(self, sql)
507 }
508
509 pub fn query_scalar<T: FromScalar>(&self, sql: impl Into<String>) -> QueryScalar<'_, T> {
511 QueryScalar::new(self, sql)
512 }
513
514 pub fn execute(&self, sql: impl Into<String>) -> Execute<'_> {
516 Execute::new(self, sql)
517 }
518}
519
520impl crate::traits::DatabaseQuery for Database {
525 }
528
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_to_sql_literal() {
        assert_eq!(ScalarValue::Null.to_sql_literal(), "NULL");
        assert_eq!(ScalarValue::Int32(42).to_sql_literal(), "42");
        assert_eq!(ScalarValue::Int64(100).to_sql_literal(), "100");
        assert_eq!(ScalarValue::Float64(3.1).to_sql_literal(), "3.1");
        assert_eq!(
            ScalarValue::Utf8("hello".into()).to_sql_literal(),
            "'hello'"
        );
        assert_eq!(ScalarValue::Boolean(true).to_sql_literal(), "TRUE");
        assert_eq!(ScalarValue::Boolean(false).to_sql_literal(), "FALSE");
    }

    #[test]
    fn test_sql_literal_single_quote_escape() {
        // Embedded single quotes must be doubled inside the literal.
        assert_eq!(
            ScalarValue::Utf8("O'Brien".into()).to_sql_literal(),
            "'O''Brien'"
        );
    }

    #[test]
    fn test_apply_params_positional() {
        let sql = "SELECT * FROM users WHERE id = $1 AND age > $2";
        let params = vec![ScalarValue::Int32(42), ScalarValue::Int64(18)];
        let result = apply_params(sql, &params, &[]).unwrap();
        assert_eq!(result, "SELECT * FROM users WHERE id = 42 AND age > 18");
    }

    #[test]
    fn test_apply_params_string() {
        let sql = "SELECT * FROM users WHERE name = $1";
        let params = vec![ScalarValue::Utf8("Alice".into())];
        let result = apply_params(sql, &params, &[]).unwrap();
        assert_eq!(result, "SELECT * FROM users WHERE name = 'Alice'");
    }

    #[test]
    fn test_apply_params_null_bool() {
        let sql = "SELECT * FROM t WHERE a = $1 AND b = $2 AND c = $3";
        let params = vec![
            ScalarValue::Null,
            ScalarValue::Boolean(true),
            ScalarValue::Boolean(false),
        ];
        let result = apply_params(sql, &params, &[]).unwrap();
        assert_eq!(
            result,
            "SELECT * FROM t WHERE a = NULL AND b = TRUE AND c = FALSE"
        );
    }

    #[test]
    fn test_apply_params_reverse_order_safety() {
        // `$10` must be read as placeholder 10, not as `$1` followed by `0`.
        let sql = "SELECT $1, $10";
        let mut params = vec![ScalarValue::Int32(1)];
        for i in 2..=10 {
            params.push(ScalarValue::Int32(i));
        }
        let result = apply_params(sql, &params, &[]).unwrap();
        assert_eq!(result, "SELECT 1, 10");
    }

    #[test]
    fn test_apply_params_unmatched_positional_left_untouched() {
        let sql = "SELECT $0, $3";
        let params = vec![ScalarValue::Int32(1), ScalarValue::Int32(2)];
        let result = apply_params(sql, &params, &[]).unwrap();
        assert_eq!(result, "SELECT $0, $3");
    }

    #[test]
    fn test_apply_named_params() {
        let sql = "SELECT * FROM users WHERE name = :name AND age > :age";
        let named = vec![
            NamedParam {
                name: "name".into(),
                value: ScalarValue::Utf8("Alice".into()),
            },
            NamedParam {
                name: "age".into(),
                value: ScalarValue::Int32(18),
            },
        ];
        let result = apply_params(sql, &[], &named).unwrap();
        assert_eq!(
            result,
            "SELECT * FROM users WHERE name = 'Alice' AND age > 18"
        );
    }

    #[test]
    fn test_apply_named_params_unknown_name_left_untouched() {
        let sql = "SELECT :missing, :a";
        let named = vec![NamedParam {
            name: "a".into(),
            value: ScalarValue::Int32(7),
        }];
        let result = apply_params(sql, &[], &named).unwrap();
        assert_eq!(result, "SELECT :missing, 7");
    }

    #[test]
    fn test_apply_named_params_ignores_strings() {
        let sql = "SELECT * FROM users WHERE txt = 'cost: $1, name: :name' AND id = $1";
        let positional = vec![ScalarValue::Int32(5)];
        let result = apply_params(sql, &positional, &[]).unwrap();
        assert_eq!(
            result,
            "SELECT * FROM users WHERE txt = 'cost: $1, name: :name' AND id = 5"
        );
    }

    #[test]
    fn test_mixed_params_error() {
        let sql = "SELECT * FROM users";
        let positional = vec![ScalarValue::Int32(1)];
        let named = vec![NamedParam {
            name: "a".into(),
            value: ScalarValue::Int32(2),
        }];
        let result = apply_params(sql, &positional, &named);
        assert!(result.is_err());
    }

    #[test]
    fn test_apply_params_named_full() {
        let sql = "SELECT * FROM t WHERE x = :x AND y = :y";
        let named = vec![
            NamedParam {
                name: "x".into(),
                value: ScalarValue::Int32(10),
            },
            NamedParam {
                name: "y".into(),
                value: ScalarValue::Utf8("hello".into()),
            },
        ];
        let result = apply_params(sql, &[], &named).unwrap();
        assert_eq!(result, "SELECT * FROM t WHERE x = 10 AND y = 'hello'");
    }

    #[test]
    fn test_apply_params_positional_full() {
        let sql = "INSERT INTO t VALUES ($1, $2, $3)";
        let params = vec![
            ScalarValue::Int32(1),
            ScalarValue::Utf8("test".into()),
            ScalarValue::Float64(3.1),
        ];
        let result = apply_params(sql, &params, &[]).unwrap();
        assert_eq!(result, "INSERT INTO t VALUES (1, 'test', 3.1)");
    }

    #[test]
    fn test_into_param_trait() {
        assert!(matches!(42i32.into_scalar(), ScalarValue::Int32(42)));
        assert!(matches!(100i64.into_scalar(), ScalarValue::Int64(100)));
        assert!(matches!(3.1f64.into_scalar(), ScalarValue::Float64(_)));
        assert!(matches!("hello".into_scalar(), ScalarValue::Utf8(_)));
        assert!(matches!(true.into_scalar(), ScalarValue::Boolean(true)));
        assert!(matches!(
            Option::<i32>::None.into_scalar(),
            ScalarValue::Null
        ));
        assert!(matches!(Some(10i32).into_scalar(), ScalarValue::Int32(10)));
    }
}