1use crate::backends::types::QueryValue;
8use crate::orm::Model;
9use reinhardt_query::prelude::{
10 Alias, ColumnRef, Expr, ExprTrait, Func, Query, QueryStatementBuilder, SelectStatement,
11};
12use rust_decimal::prelude::ToPrimitive;
13use std::marker::PhantomData;
14
/// Outcome of executing a query, tagged by the shape of the result.
///
/// This module does not construct it itself; NOTE(review): presumably used by
/// callers as a generic result envelope — confirm against its users.
#[derive(Debug)]
pub enum ExecutionResult<T> {
    /// Exactly one decoded row.
    One(T),
    /// Zero or one decoded row.
    OneOrNone(Option<T>),
    /// Every decoded row, in the order returned.
    All(Vec<T>),
    /// A single scalar value carried as text.
    Scalar(String),
    /// No result payload.
    None,
}
29
/// Errors produced while building, executing, or decoding queries in this module.
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum ExecutionError {
    /// An error reported by the database layer, converted via `?` from
    /// `crate::backends::DatabaseError`.
    #[error("Database error: {0}")]
    Database(#[from] crate::backends::DatabaseError),

    /// A query that requires a row (e.g. `one_async`) returned none.
    #[error("No result found")]
    NoResultFound,

    /// A query that allows at most one row returned more than one.
    ///
    /// The `SelectExecution` executors fetch with `LIMIT 2`, so the number carried
    /// here is the count of fetched rows (at most 2), not the total number of
    /// matching rows.
    #[error("Multiple results found (expected 1, got {0})")]
    MultipleResultsFound(usize),

    /// A row could not be serialized to, or deserialized from, `serde_json::Value`.
    #[error("Failed to deserialize result: {0}")]
    Deserialization(#[from] serde_json::Error),

    /// A query could not be built or produced an unexpected shape.
    ///
    /// In this module it is also returned when a `COUNT`/`EXISTS` row has no
    /// readable first column.
    #[error("Query building error: {0}")]
    QueryBuild(String),

    /// Any other error, converted from `anyhow::Error`.
    #[error("Generic error: {0}")]
    Generic(#[from] anyhow::Error),
}
58
59fn convert_value_to_query_value(value: reinhardt_query::value::Value) -> QueryValue {
61 use reinhardt_query::value::Value as SV;
62
63 match value {
64 SV::Bool(None)
66 | SV::TinyInt(None)
67 | SV::SmallInt(None)
68 | SV::Int(None)
69 | SV::BigInt(None)
70 | SV::TinyUnsigned(None)
71 | SV::SmallUnsigned(None)
72 | SV::Unsigned(None)
73 | SV::BigUnsigned(None)
74 | SV::Float(None)
75 | SV::Double(None)
76 | SV::String(None)
77 | SV::Char(None)
78 | SV::Bytes(None)
79 | SV::ChronoDateTimeUtc(None)
80 | SV::ChronoDateTimeLocal(None)
81 | SV::ChronoDateTimeWithTimeZone(None)
82 | SV::ChronoDate(None)
83 | SV::ChronoTime(None)
84 | SV::ChronoDateTime(None)
85 | SV::Json(None)
86 | SV::Decimal(None)
87 | SV::BigDecimal(None)
88 | SV::Uuid(None) => QueryValue::Null,
89
90 SV::Bool(Some(b)) => QueryValue::Bool(b),
92
93 SV::TinyInt(Some(v)) => QueryValue::Int(v as i64),
95 SV::SmallInt(Some(v)) => QueryValue::Int(v as i64),
96 SV::Int(Some(v)) => QueryValue::Int(v as i64),
97 SV::BigInt(Some(v)) => QueryValue::Int(v),
98
99 SV::TinyUnsigned(Some(v)) => QueryValue::Int(v as i64),
101 SV::SmallUnsigned(Some(v)) => QueryValue::Int(v as i64),
102 SV::Unsigned(Some(v)) => QueryValue::Int(v as i64),
103 SV::BigUnsigned(Some(v)) => QueryValue::Int(i64::try_from(v).unwrap_or_else(|_| {
104 tracing::warn!(
105 value = v,
106 "BigUnsigned value {} exceeds i64::MAX, clamping to i64::MAX",
107 v
108 );
109 i64::MAX
110 })),
111
112 SV::Float(Some(v)) => QueryValue::Float(v as f64),
114 SV::Double(Some(v)) => QueryValue::Float(v),
115
116 SV::String(Some(s)) => QueryValue::String(s.to_string()),
118 SV::Char(Some(c)) => QueryValue::String(c.to_string()),
119
120 SV::Bytes(Some(b)) => QueryValue::Bytes(b.to_vec()),
122
123 SV::ChronoDateTimeUtc(Some(dt)) => QueryValue::Timestamp(*dt),
125
126 SV::ChronoDateTimeLocal(Some(dt)) => {
128 QueryValue::Timestamp((*dt).with_timezone(&chrono::Utc))
129 }
130 SV::ChronoDateTimeWithTimeZone(Some(dt)) => {
131 QueryValue::Timestamp((*dt).with_timezone(&chrono::Utc))
132 }
133
134 SV::ChronoDate(_) | SV::ChronoTime(_) | SV::ChronoDateTime(_) => {
136 QueryValue::String(format!("{:?}", value))
138 }
139
140 SV::Json(_) => QueryValue::String(format!("{:?}", value)),
142
143 SV::Decimal(Some(d)) => {
145 let f = d.to_f64().unwrap_or_else(|| {
146 tracing::warn!(
147 decimal = %d,
148 "Decimal cannot be directly represented as f64, falling back to string parsing"
149 );
150 d.to_string().parse::<f64>().unwrap_or(0.0)
151 });
152 QueryValue::Float(f)
153 }
154 SV::BigDecimal(Some(d)) => {
155 let f = d.to_string().parse::<f64>().unwrap_or_else(|_| {
156 tracing::warn!(
157 big_decimal = %d,
158 "BigDecimal cannot be represented as f64"
159 );
160 0.0
161 });
162 QueryValue::Float(f)
163 }
164
165 SV::Uuid(Some(u)) => QueryValue::Uuid(*u),
167
168 SV::Array(_, arr) => QueryValue::String(format!("{:?}", arr)),
171 }
172}
173
174pub fn convert_values(values: reinhardt_query::prelude::Values) -> Vec<QueryValue> {
176 values
177 .0
178 .into_iter()
179 .map(convert_value_to_query_value)
180 .collect()
181}
182
/// Fetch operations over a prepared `SELECT` for model `T`.
///
/// Every operation comes in two forms:
/// - a plain method (`get`, `all`, `first`, `one`, `one_or_none`, `scalar`,
///   `count`, `exists`) that only builds and returns the `SelectStatement`
///   that would be run, and
/// - an `*_async` method that executes that statement against a
///   `DatabaseConnection` and decodes the result.
#[async_trait::async_trait]
pub trait QueryExecution<T: Model>
where
    T: Send + Sync,
    T::PrimaryKey: Send + Sync,
{
    /// Executes [`get`](Self::get) and decodes the single returned row into `T`.
    ///
    /// NOTE(review): what happens when no row matches depends on
    /// `DatabaseConnection::query_one` (presumably an `ExecutionError::Database`) —
    /// confirm.
    async fn get_async(
        &self,
        db: &super::connection::DatabaseConnection,
        pk: &T::PrimaryKey,
    ) -> Result<T, ExecutionError>
    where
        T: for<'de> serde::Deserialize<'de>;

    /// Builds a statement selecting the row whose primary key equals `pk`.
    fn get(&self, pk: &T::PrimaryKey) -> SelectStatement;

    /// Executes [`all`](Self::all) and decodes every returned row into `T`.
    async fn all_async(
        &self,
        db: &super::connection::DatabaseConnection,
    ) -> Result<Vec<T>, ExecutionError>
    where
        T: for<'de> serde::Deserialize<'de>;

    /// Builds the statement returning every matching row.
    fn all(&self) -> SelectStatement;

    /// Executes [`first`](Self::first); returns `Ok(None)` when no row matches.
    async fn first_async(
        &self,
        db: &super::connection::DatabaseConnection,
    ) -> Result<Option<T>, ExecutionError>
    where
        T: for<'de> serde::Deserialize<'de>;

    /// Builds a statement returning at most the first matching row.
    fn first(&self) -> SelectStatement;

    /// Executes [`one`](Self::one), requiring exactly one row.
    ///
    /// # Errors
    /// `ExecutionError::NoResultFound` for zero rows,
    /// `ExecutionError::MultipleResultsFound` for more than one.
    async fn one_async(
        &self,
        db: &super::connection::DatabaseConnection,
    ) -> Result<T, ExecutionError>
    where
        T: for<'de> serde::Deserialize<'de>;

    /// Builds the statement used by [`one_async`](Self::one_async).
    fn one(&self) -> SelectStatement;

    /// Executes [`one_or_none`](Self::one_or_none), allowing zero or one row.
    ///
    /// # Errors
    /// `ExecutionError::MultipleResultsFound` for more than one row.
    async fn one_or_none_async(
        &self,
        db: &super::connection::DatabaseConnection,
    ) -> Result<Option<T>, ExecutionError>
    where
        T: for<'de> serde::Deserialize<'de>;

    /// Builds the statement used by [`one_or_none_async`](Self::one_or_none_async).
    fn one_or_none(&self) -> SelectStatement;

    /// Executes [`scalar`](Self::scalar) and decodes one column of the first row
    /// into `S`; `Ok(None)` when there is no row.
    async fn scalar_async<S>(
        &self,
        db: &super::connection::DatabaseConnection,
    ) -> Result<Option<S>, ExecutionError>
    where
        S: for<'de> serde::Deserialize<'de>;

    /// Builds the statement used by [`scalar_async`](Self::scalar_async).
    fn scalar(&self) -> SelectStatement;

    /// Executes [`count`](Self::count) and returns the row count.
    async fn count_async(
        &self,
        db: &super::connection::DatabaseConnection,
    ) -> Result<i64, ExecutionError>;

    /// Builds a `COUNT(*)` statement over the matching rows.
    fn count(&self) -> SelectStatement;

    /// Executes [`exists`](Self::exists) and returns whether any row matches.
    async fn exists_async(
        &self,
        db: &super::connection::DatabaseConnection,
    ) -> Result<bool, ExecutionError>;

    /// Builds a `SELECT EXISTS(...)` statement over the matching rows.
    fn exists(&self) -> SelectStatement;
}
291
/// Executor that wraps a base `SELECT` statement for model `T` and implements
/// [`QueryExecution`] on top of it.
pub struct SelectExecution<T: Model> {
    /// Base statement; most `QueryExecution` builders start from a clone of it.
    stmt: SelectStatement,
    /// Ties the executor to `T` without storing a `T` value.
    _phantom: PhantomData<T>,
}
297
298impl<T: Model> SelectExecution<T> {
299 pub fn new(stmt: SelectStatement) -> Self {
336 Self {
337 stmt,
338 _phantom: PhantomData,
339 }
340 }
341 pub fn statement(&self) -> &SelectStatement {
381 &self.stmt
382 }
383}
384
385#[async_trait::async_trait]
386impl<T: Model> QueryExecution<T> for SelectExecution<T>
387where
388 T::PrimaryKey: Into<reinhardt_query::value::Value> + Clone + Send + Sync,
389 T: Send + Sync,
390{
391 fn get(&self, pk: &T::PrimaryKey) -> SelectStatement {
392 Query::select()
393 .from(Alias::new(T::table_name()))
394 .column(ColumnRef::Asterisk)
395 .and_where(
396 Expr::col(Alias::new(T::primary_key_field())).eq(Expr::val(pk.clone().into())),
397 )
398 .limit(1)
399 .to_owned()
400 }
401
402 fn all(&self) -> SelectStatement {
403 self.stmt.clone()
404 }
405
406 fn first(&self) -> SelectStatement {
407 let mut stmt = self.stmt.clone();
408 stmt.limit(1);
409 stmt
410 }
411
412 fn one(&self) -> SelectStatement {
413 let mut stmt = self.stmt.clone();
419 stmt.limit(2);
420 stmt
421 }
422
423 fn one_or_none(&self) -> SelectStatement {
424 let mut stmt = self.stmt.clone();
430 stmt.limit(2);
431 stmt
432 }
433
434 fn scalar(&self) -> SelectStatement {
435 let mut stmt = self.stmt.clone();
436 stmt.limit(1);
437 stmt
438 }
439
440 fn count(&self) -> SelectStatement {
441 Query::select()
444 .expr(Func::count(Expr::asterisk().into_simple_expr()))
445 .from_subquery(self.stmt.clone(), Alias::new("subquery"))
446 .to_owned()
447 }
448
449 fn exists(&self) -> SelectStatement {
450 Query::select()
451 .expr(Expr::exists(self.stmt.clone()))
452 .to_owned()
453 }
454
455 async fn get_async(
456 &self,
457 db: &super::connection::DatabaseConnection,
458 pk: &T::PrimaryKey,
459 ) -> Result<T, ExecutionError>
460 where
461 T: for<'de> serde::Deserialize<'de>,
462 {
463 let stmt = self.get(pk);
464 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
465
466 let query_values = convert_values(values);
467 let row = db.query_one(&sql, query_values).await?;
468 let json = serde_json::to_value(&row)?;
469 let result = serde_json::from_value(json)?;
470 Ok(result)
471 }
472
473 async fn all_async(
474 &self,
475 db: &super::connection::DatabaseConnection,
476 ) -> Result<Vec<T>, ExecutionError>
477 where
478 T: for<'de> serde::Deserialize<'de>,
479 {
480 let stmt = self.all();
481 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
482
483 let query_values = convert_values(values);
484 let rows = db.query(&sql, query_values).await?;
485 let mut results = Vec::with_capacity(rows.len());
486 for row in rows {
487 let json = serde_json::to_value(&row)?;
488 let result = serde_json::from_value(json)?;
489 results.push(result);
490 }
491 Ok(results)
492 }
493
494 async fn first_async(
495 &self,
496 db: &super::connection::DatabaseConnection,
497 ) -> Result<Option<T>, ExecutionError>
498 where
499 T: for<'de> serde::Deserialize<'de>,
500 {
501 let stmt = self.first();
502 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
503
504 let query_values = convert_values(values);
505 let rows = db.query(&sql, query_values).await?;
506 match rows.first() {
507 Some(row) => {
508 let json = serde_json::to_value(row)?;
509 let result = serde_json::from_value(json)?;
510 Ok(Some(result))
511 }
512 None => Ok(None),
513 }
514 }
515
516 async fn one_async(
517 &self,
518 db: &super::connection::DatabaseConnection,
519 ) -> Result<T, ExecutionError>
520 where
521 T: for<'de> serde::Deserialize<'de>,
522 {
523 let stmt = self.one();
524 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
525
526 let query_values = convert_values(values);
527 let rows = db.query(&sql, query_values).await?;
528 match rows.len() {
529 0 => Err(ExecutionError::NoResultFound),
530 1 => {
531 let json = serde_json::to_value(&rows[0])?;
532 let result = serde_json::from_value(json)?;
533 Ok(result)
534 }
535 n => Err(ExecutionError::MultipleResultsFound(n)),
536 }
537 }
538
539 async fn one_or_none_async(
540 &self,
541 db: &super::connection::DatabaseConnection,
542 ) -> Result<Option<T>, ExecutionError>
543 where
544 T: for<'de> serde::Deserialize<'de>,
545 {
546 let stmt = self.one_or_none();
547 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
548
549 let query_values = convert_values(values);
550 let rows = db.query(&sql, query_values).await?;
551 match rows.len() {
552 0 => Ok(None),
553 1 => {
554 let json = serde_json::to_value(&rows[0])?;
555 let result = serde_json::from_value(json)?;
556 Ok(Some(result))
557 }
558 n => Err(ExecutionError::MultipleResultsFound(n)),
559 }
560 }
561
562 async fn scalar_async<S>(
563 &self,
564 db: &super::connection::DatabaseConnection,
565 ) -> Result<Option<S>, ExecutionError>
566 where
567 S: for<'de> serde::Deserialize<'de>,
568 {
569 let stmt = self.scalar();
570 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
571
572 let query_values = convert_values(values);
573 let rows = db.query(&sql, query_values).await?;
574 match rows.first() {
575 Some(row) => {
576 let json = serde_json::to_value(row)?;
578 if let Some(obj) = json.as_object()
579 && let Some((_, value)) = obj.iter().next()
580 {
581 let result = serde_json::from_value(value.clone())?;
582 return Ok(Some(result));
583 }
584 Ok(None)
585 }
586 None => Ok(None),
587 }
588 }
589
590 async fn count_async(
591 &self,
592 db: &super::connection::DatabaseConnection,
593 ) -> Result<i64, ExecutionError> {
594 let stmt = self.count();
595 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
596
597 let query_values = convert_values(values);
598 let row = db.query_one(&sql, query_values).await?;
599 let json = serde_json::to_value(&row)?;
600
601 if let Some(obj) = json.as_object()
603 && let Some((_, value)) = obj.iter().next()
604 {
605 let count: i64 = serde_json::from_value(value.clone())?;
606 return Ok(count);
607 }
608
609 Err(ExecutionError::QueryBuild(
610 "Count query returned unexpected format".to_string(),
611 ))
612 }
613
614 async fn exists_async(
615 &self,
616 db: &super::connection::DatabaseConnection,
617 ) -> Result<bool, ExecutionError> {
618 let stmt = self.exists();
619 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
620
621 let query_values = convert_values(values);
622 let row = db.query_one(&sql, query_values).await?;
623 let json = serde_json::to_value(&row)?;
624
625 if let Some(obj) = json.as_object()
627 && let Some((_, value)) = obj.iter().next()
628 {
629 let exists: bool = serde_json::from_value(value.clone())?;
630 return Ok(exists);
631 }
632
633 Err(ExecutionError::QueryBuild(
634 "Exists query returned unexpected format".to_string(),
635 ))
636 }
637}
638
/// A loading hint attached to a query, rendered as an SQL block comment by
/// [`LoadOption::to_sql_comment`].
///
/// The variant names follow SQLAlchemy-style loader options (`joinedload`,
/// `selectinload`, ...). Within this module they are only rendered as comments;
/// NOTE(review): any actual loading behaviour lives in the code that consumes
/// these hints — confirm there.
#[derive(Debug, Clone)]
pub enum LoadOption {
    /// `joinedload(<relation>)` hint.
    JoinedLoad(String),

    /// `selectinload(<relation>)` hint.
    SelectInLoad(String),

    /// `lazyload(<relation>)` hint.
    LazyLoad(String),

    /// `noload(<relation>)` hint.
    NoLoad(String),

    /// `raiseload(<relation>)` hint.
    RaiseLoad(String),

    /// `defer(<column>)` hint.
    Defer(String),

    /// `undefer(<column>)` hint.
    Undefer(String),

    /// `load_only(<column>, <column>, ...)` hint; columns are joined with `", "`.
    LoadOnly(Vec<String>),
}

impl LoadOption {
    /// Renders this option as an SQL block comment, e.g.
    /// `/* joinedload(profile) */` or `/* load_only(id, name) */`.
    ///
    /// Relation/column names are embedded verbatim, except that any `*/` or `/*`
    /// inside them is broken up (`* /`, `/ *`). Without this, a name could close
    /// the comment early and leak the rest of the name into the SQL text, or
    /// open a nested comment on backends with nesting block comments (such as
    /// PostgreSQL).
    pub fn to_sql_comment(&self) -> String {
        let hint = match self {
            LoadOption::JoinedLoad(rel) => format!("joinedload({rel})"),
            LoadOption::SelectInLoad(rel) => format!("selectinload({rel})"),
            LoadOption::LazyLoad(rel) => format!("lazyload({rel})"),
            LoadOption::NoLoad(rel) => format!("noload({rel})"),
            LoadOption::RaiseLoad(rel) => format!("raiseload({rel})"),
            LoadOption::Defer(col) => format!("defer({col})"),
            LoadOption::Undefer(col) => format!("undefer({col})"),
            LoadOption::LoadOnly(cols) => format!("load_only({})", cols.join(", ")),
        };
        format!("/* {} */", neutralize_comment_delimiters(&hint))
    }
}

/// Breaks up SQL block-comment delimiters in `text` by inserting a space, so the
/// text can sit inside `/* ... */` without terminating or nesting the comment.
///
/// `*/` is handled first. Neither replacement can create a new delimiter at its
/// boundaries, so the result contains no `*/` and no `/*`.
fn neutralize_comment_delimiters(text: &str) -> String {
    text.replace("*/", "* /").replace("/*", "/ *")
}
706
707#[non_exhaustive]
709pub struct QueryOptions {
710 pub load_options: Vec<LoadOption>,
712}
713
714impl QueryOptions {
715 pub fn new() -> Self {
726 Self {
727 load_options: Vec::new(),
728 }
729 }
730 pub fn add_option(mut self, option: LoadOption) -> Self {
746 self.load_options.push(option);
747 self
748 }
749 pub fn to_sql_comments(&self) -> String {
762 if self.load_options.is_empty() {
763 String::new()
764 } else {
765 format!(
766 " {}",
767 self.load_options
768 .iter()
769 .map(|o| o.to_sql_comment())
770 .collect::<Vec<_>>()
771 .join(" ")
772 )
773 }
774 }
775}
776
777impl Default for QueryOptions {
778 fn default() -> Self {
779 Self::new()
780 }
781}
782
// Unit tests for statement building in `SelectExecution`, load-option comment
// rendering, and `BigUnsigned` -> `QueryValue` conversion.
#[cfg(test)]
mod tests {
    use super::*;
    use reinhardt_core::validators::TableName;
    use rstest::rstest;
    use serde::{Deserialize, Serialize};

    // Minimal model used to instantiate `SelectExecution<User>`.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct User {
        id: Option<i64>,
        name: String,
    }

    // Field selector stub required by `Model::Fields`; aliasing is a no-op here.
    #[derive(Clone)]
    struct UserFields;
    impl crate::orm::model::FieldSelector for UserFields {
        fn with_alias(self, _alias: &str) -> Self {
            self
        }
    }

    const USER_TABLE: TableName = TableName::new_const("users");

    impl Model for User {
        type PrimaryKey = i64;
        type Fields = UserFields;

        fn table_name() -> &'static str {
            USER_TABLE.as_str()
        }

        fn new_fields() -> Self::Fields {
            UserFields
        }

        fn primary_key(&self) -> Option<Self::PrimaryKey> {
            self.id
        }

        fn set_primary_key(&mut self, value: Self::PrimaryKey) {
            self.id = Some(value);
        }
    }

    // `get` filters on the primary key and limits the result.
    #[test]
    fn test_execution_get() {
        use reinhardt_query::prelude::{Alias, PostgresQueryBuilder, Query, QueryStatementBuilder};

        let stmt = Query::select()
            .from(Alias::new("users"))
            .column(ColumnRef::Asterisk)
            .to_owned();
        let exec = SelectExecution::<User>::new(stmt);
        let result_stmt = exec.get(&123);
        let sql = result_stmt.to_string(PostgresQueryBuilder);
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("LIMIT"));
    }

    // `all` returns the wrapped statement.
    #[test]
    fn test_all() {
        use reinhardt_query::prelude::{Alias, PostgresQueryBuilder, Query, QueryStatementBuilder};

        let stmt = Query::select()
            .from(Alias::new("users"))
            .column(ColumnRef::Asterisk)
            .to_owned();
        let exec = SelectExecution::<User>::new(stmt);
        let result_stmt = exec.all();
        let sql = result_stmt.to_string(PostgresQueryBuilder);
        assert!(sql.contains("SELECT"));
        assert!(sql.contains("users"));
    }

    // `first` adds a LIMIT to the wrapped statement.
    #[test]
    fn test_first() {
        use reinhardt_query::prelude::{
            Alias, Expr, PostgresQueryBuilder, Query, QueryStatementBuilder,
        };

        let stmt = Query::select()
            .from(Alias::new("users"))
            .column(ColumnRef::Asterisk)
            .and_where(Expr::col(Alias::new("active")).eq(true))
            .to_owned();
        let exec = SelectExecution::<User>::new(stmt);
        let result_stmt = exec.first();
        let sql = result_stmt.to_string(PostgresQueryBuilder);
        assert!(sql.contains("LIMIT"));
    }

    // `count` wraps the statement in a COUNT query.
    #[test]
    fn test_execution_count() {
        use reinhardt_query::prelude::{
            Alias, Expr, PostgresQueryBuilder, Query, QueryStatementBuilder,
        };

        let stmt = Query::select()
            .from(Alias::new("users"))
            .column(ColumnRef::Asterisk)
            .and_where(Expr::col(Alias::new("active")).eq(true))
            .to_owned();
        let exec = SelectExecution::<User>::new(stmt);
        let result_stmt = exec.count();
        let sql = result_stmt.to_string(PostgresQueryBuilder);
        assert!(sql.contains("COUNT"));
    }

    // `exists` wraps the statement in EXISTS(...).
    #[test]
    fn test_execution_exists() {
        use reinhardt_query::prelude::{
            Alias, Expr, PostgresQueryBuilder, Query, QueryStatementBuilder,
        };

        let stmt = Query::select()
            .from(Alias::new("users"))
            .column(ColumnRef::Asterisk)
            .and_where(Expr::col(Alias::new("name")).eq("Alice"))
            .to_owned();
        let exec = SelectExecution::<User>::new(stmt);
        let result_stmt = exec.exists();
        let sql = result_stmt.to_string(PostgresQueryBuilder);
        assert!(sql.contains("EXISTS"));
    }

    // Every added load option appears in the rendered comment string.
    #[test]
    fn test_load_options() {
        let options = QueryOptions::new()
            .add_option(LoadOption::JoinedLoad("profile".to_string()))
            .add_option(LoadOption::Defer("password".to_string()));

        let comments = options.to_sql_comments();
        assert!(comments.contains("joinedload(profile)"));
        assert!(comments.contains("defer(password)"));
    }

    // `LoadOnly` joins its columns with ", ".
    #[test]
    fn test_load_only() {
        let option = LoadOption::LoadOnly(vec!["id".to_string(), "name".to_string()]);
        let comment = option.to_sql_comment();
        assert!(comment.contains("load_only(id, name)"));
    }

    // u64 values that fit in i64 convert unchanged.
    #[rstest]
    #[case::zero(0u64, 0i64)]
    #[case::one(1u64, 1i64)]
    #[case::i64_max(i64::MAX as u64, i64::MAX)]
    #[test]
    fn test_big_unsigned_to_query_value_within_range(#[case] input: u64, #[case] expected: i64) {
        // Arrange
        let value = reinhardt_query::value::Value::BigUnsigned(Some(input));

        // Act
        let result = convert_value_to_query_value(value);

        // Assert
        assert!(matches!(result, QueryValue::Int(v) if v == expected));
    }

    // u64 values above i64::MAX are clamped rather than wrapped.
    #[rstest]
    #[case::i64_max_plus_one(i64::MAX as u64 + 1)]
    #[case::u64_max(u64::MAX)]
    #[test]
    fn test_big_unsigned_overflow_clamps_to_i64_max(#[case] input: u64) {
        // Arrange
        let value = reinhardt_query::value::Value::BigUnsigned(Some(input));

        // Act
        let result = convert_value_to_query_value(value);

        // Assert
        assert!(matches!(result, QueryValue::Int(v) if v == i64::MAX));
    }

    // A NULL BigUnsigned converts to QueryValue::Null.
    #[rstest]
    #[test]
    fn test_big_unsigned_none_converts_to_null() {
        // Arrange
        let value = reinhardt_query::value::Value::BigUnsigned(None);

        // Act
        let result = convert_value_to_query_value(value);

        // Assert
        assert!(matches!(result, QueryValue::Null));
    }
}