1use crate::backends::types::QueryValue;
8use crate::orm::Model;
9use reinhardt_query::prelude::{
10 Alias, ColumnRef, Expr, ExprTrait, Func, Query, QueryStatementBuilder, SelectStatement,
11};
12use rust_decimal::prelude::ToPrimitive;
13use std::marker::PhantomData;
14
/// The outcome of executing a query, tagged by the kind of retrieval performed.
///
/// `T` is the decoded row type; the scalar case carries its value as text.
#[derive(Debug)]
pub enum ExecutionResult<T> {
	/// A single decoded value.
	One(T),
	/// A decoded value that may be absent.
	OneOrNone(Option<T>),
	/// A collection of decoded values.
	All(Vec<T>),
	/// A scalar result, stored as its string form.
	Scalar(String),
	/// No result payload.
	None,
}
29
30#[derive(Debug, thiserror::Error)]
32pub enum ExecutionError {
33 #[error("Database error: {0}")]
35 Database(#[from] crate::backends::DatabaseError),
36
37 #[error("No result found")]
39 NoResultFound,
40
41 #[error("Multiple results found (expected 1, got {0})")]
43 MultipleResultsFound(usize),
44
45 #[error("Failed to deserialize result: {0}")]
47 Deserialization(#[from] serde_json::Error),
48
49 #[error("Query building error: {0}")]
51 QueryBuild(String),
52
53 #[error("Generic error: {0}")]
55 Generic(#[from] anyhow::Error),
56}
57
58fn convert_value_to_query_value(value: reinhardt_query::value::Value) -> QueryValue {
60 use reinhardt_query::value::Value as SV;
61
62 match value {
63 SV::Bool(None)
65 | SV::TinyInt(None)
66 | SV::SmallInt(None)
67 | SV::Int(None)
68 | SV::BigInt(None)
69 | SV::TinyUnsigned(None)
70 | SV::SmallUnsigned(None)
71 | SV::Unsigned(None)
72 | SV::BigUnsigned(None)
73 | SV::Float(None)
74 | SV::Double(None)
75 | SV::String(None)
76 | SV::Char(None)
77 | SV::Bytes(None)
78 | SV::ChronoDateTimeUtc(None)
79 | SV::ChronoDateTimeLocal(None)
80 | SV::ChronoDateTimeWithTimeZone(None)
81 | SV::ChronoDate(None)
82 | SV::ChronoTime(None)
83 | SV::ChronoDateTime(None)
84 | SV::Json(None)
85 | SV::Decimal(None)
86 | SV::BigDecimal(None)
87 | SV::Uuid(None) => QueryValue::Null,
88
89 SV::Bool(Some(b)) => QueryValue::Bool(b),
91
92 SV::TinyInt(Some(v)) => QueryValue::Int(v as i64),
94 SV::SmallInt(Some(v)) => QueryValue::Int(v as i64),
95 SV::Int(Some(v)) => QueryValue::Int(v as i64),
96 SV::BigInt(Some(v)) => QueryValue::Int(v),
97
98 SV::TinyUnsigned(Some(v)) => QueryValue::Int(v as i64),
100 SV::SmallUnsigned(Some(v)) => QueryValue::Int(v as i64),
101 SV::Unsigned(Some(v)) => QueryValue::Int(v as i64),
102 SV::BigUnsigned(Some(v)) => QueryValue::Int(i64::try_from(v).unwrap_or_else(|_| {
103 tracing::warn!(
104 value = v,
105 "BigUnsigned value {} exceeds i64::MAX, clamping to i64::MAX",
106 v
107 );
108 i64::MAX
109 })),
110
111 SV::Float(Some(v)) => QueryValue::Float(v as f64),
113 SV::Double(Some(v)) => QueryValue::Float(v),
114
115 SV::String(Some(s)) => QueryValue::String(s.to_string()),
117 SV::Char(Some(c)) => QueryValue::String(c.to_string()),
118
119 SV::Bytes(Some(b)) => QueryValue::Bytes(b.to_vec()),
121
122 SV::ChronoDateTimeUtc(Some(dt)) => QueryValue::Timestamp(*dt),
124
125 SV::ChronoDateTimeLocal(Some(dt)) => {
127 QueryValue::Timestamp((*dt).with_timezone(&chrono::Utc))
128 }
129 SV::ChronoDateTimeWithTimeZone(Some(dt)) => {
130 QueryValue::Timestamp((*dt).with_timezone(&chrono::Utc))
131 }
132
133 SV::ChronoDate(_) | SV::ChronoTime(_) | SV::ChronoDateTime(_) => {
135 QueryValue::String(format!("{:?}", value))
137 }
138
139 SV::Json(_) => QueryValue::String(format!("{:?}", value)),
141
142 SV::Decimal(Some(d)) => {
144 let f = d.to_f64().unwrap_or_else(|| {
145 tracing::warn!(
146 decimal = %d,
147 "Decimal cannot be directly represented as f64, falling back to string parsing"
148 );
149 d.to_string().parse::<f64>().unwrap_or(0.0)
150 });
151 QueryValue::Float(f)
152 }
153 SV::BigDecimal(Some(d)) => {
154 let f = d.to_string().parse::<f64>().unwrap_or_else(|_| {
155 tracing::warn!(
156 big_decimal = %d,
157 "BigDecimal cannot be represented as f64"
158 );
159 0.0
160 });
161 QueryValue::Float(f)
162 }
163
164 SV::Uuid(Some(u)) => QueryValue::Uuid(*u),
166
167 SV::Array(_, arr) => QueryValue::String(format!("{:?}", arr)),
170 }
171}
172
173pub fn convert_values(values: reinhardt_query::prelude::Values) -> Vec<QueryValue> {
175 values
176 .0
177 .into_iter()
178 .map(convert_value_to_query_value)
179 .collect()
180}
181
182#[async_trait::async_trait]
184pub trait QueryExecution<T: Model>
185where
186 T: Send + Sync,
187 T::PrimaryKey: Send + Sync,
188{
189 async fn get_async(
192 &self,
193 db: &super::connection::DatabaseConnection,
194 pk: &T::PrimaryKey,
195 ) -> Result<T, ExecutionError>
196 where
197 T: for<'de> serde::Deserialize<'de>;
198
199 fn get(&self, pk: &T::PrimaryKey) -> SelectStatement;
202
203 async fn all_async(
206 &self,
207 db: &super::connection::DatabaseConnection,
208 ) -> Result<Vec<T>, ExecutionError>
209 where
210 T: for<'de> serde::Deserialize<'de>;
211
212 fn all(&self) -> SelectStatement;
215
216 async fn first_async(
219 &self,
220 db: &super::connection::DatabaseConnection,
221 ) -> Result<Option<T>, ExecutionError>
222 where
223 T: for<'de> serde::Deserialize<'de>;
224
225 fn first(&self) -> SelectStatement;
228
229 async fn one_async(
232 &self,
233 db: &super::connection::DatabaseConnection,
234 ) -> Result<T, ExecutionError>
235 where
236 T: for<'de> serde::Deserialize<'de>;
237
238 fn one(&self) -> SelectStatement;
241
242 async fn one_or_none_async(
245 &self,
246 db: &super::connection::DatabaseConnection,
247 ) -> Result<Option<T>, ExecutionError>
248 where
249 T: for<'de> serde::Deserialize<'de>;
250
251 fn one_or_none(&self) -> SelectStatement;
254
255 async fn scalar_async<S>(
258 &self,
259 db: &super::connection::DatabaseConnection,
260 ) -> Result<Option<S>, ExecutionError>
261 where
262 S: for<'de> serde::Deserialize<'de>;
263
264 fn scalar(&self) -> SelectStatement;
267
268 async fn count_async(
271 &self,
272 db: &super::connection::DatabaseConnection,
273 ) -> Result<i64, ExecutionError>;
274
275 fn count(&self) -> SelectStatement;
278
279 async fn exists_async(
282 &self,
283 db: &super::connection::DatabaseConnection,
284 ) -> Result<bool, ExecutionError>;
285
286 fn exists(&self) -> SelectStatement;
289}
290
291pub struct SelectExecution<T: Model> {
293 stmt: SelectStatement,
294 _phantom: PhantomData<T>,
295}
296
297impl<T: Model> SelectExecution<T> {
298 pub fn new(stmt: SelectStatement) -> Self {
335 Self {
336 stmt,
337 _phantom: PhantomData,
338 }
339 }
340 pub fn statement(&self) -> &SelectStatement {
380 &self.stmt
381 }
382}
383
384#[async_trait::async_trait]
385impl<T: Model> QueryExecution<T> for SelectExecution<T>
386where
387 T::PrimaryKey: Into<reinhardt_query::value::Value> + Clone + Send + Sync,
388 T: Send + Sync,
389{
390 fn get(&self, pk: &T::PrimaryKey) -> SelectStatement {
391 Query::select()
392 .from(Alias::new(T::table_name()))
393 .column(ColumnRef::Asterisk)
394 .and_where(
395 Expr::col(Alias::new(T::primary_key_field())).eq(Expr::val(pk.clone().into())),
396 )
397 .limit(1)
398 .to_owned()
399 }
400
401 fn all(&self) -> SelectStatement {
402 self.stmt.clone()
403 }
404
405 fn first(&self) -> SelectStatement {
406 let mut stmt = self.stmt.clone();
407 stmt.limit(1);
408 stmt
409 }
410
411 fn one(&self) -> SelectStatement {
412 let mut stmt = self.stmt.clone();
418 stmt.limit(2);
419 stmt
420 }
421
422 fn one_or_none(&self) -> SelectStatement {
423 let mut stmt = self.stmt.clone();
429 stmt.limit(2);
430 stmt
431 }
432
433 fn scalar(&self) -> SelectStatement {
434 let mut stmt = self.stmt.clone();
435 stmt.limit(1);
436 stmt
437 }
438
439 fn count(&self) -> SelectStatement {
440 Query::select()
443 .expr(Func::count(Expr::asterisk().into_simple_expr()))
444 .from_subquery(self.stmt.clone(), Alias::new("subquery"))
445 .to_owned()
446 }
447
448 fn exists(&self) -> SelectStatement {
449 Query::select()
450 .expr(Expr::exists(self.stmt.clone()))
451 .to_owned()
452 }
453
454 async fn get_async(
455 &self,
456 db: &super::connection::DatabaseConnection,
457 pk: &T::PrimaryKey,
458 ) -> Result<T, ExecutionError>
459 where
460 T: for<'de> serde::Deserialize<'de>,
461 {
462 let stmt = self.get(pk);
463 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
464
465 let query_values = convert_values(values);
466 let row = db.query_one(&sql, query_values).await?;
467 let json = serde_json::to_value(&row)?;
468 let result = serde_json::from_value(json)?;
469 Ok(result)
470 }
471
472 async fn all_async(
473 &self,
474 db: &super::connection::DatabaseConnection,
475 ) -> Result<Vec<T>, ExecutionError>
476 where
477 T: for<'de> serde::Deserialize<'de>,
478 {
479 let stmt = self.all();
480 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
481
482 let query_values = convert_values(values);
483 let rows = db.query(&sql, query_values).await?;
484 let mut results = Vec::with_capacity(rows.len());
485 for row in rows {
486 let json = serde_json::to_value(&row)?;
487 let result = serde_json::from_value(json)?;
488 results.push(result);
489 }
490 Ok(results)
491 }
492
493 async fn first_async(
494 &self,
495 db: &super::connection::DatabaseConnection,
496 ) -> Result<Option<T>, ExecutionError>
497 where
498 T: for<'de> serde::Deserialize<'de>,
499 {
500 let stmt = self.first();
501 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
502
503 let query_values = convert_values(values);
504 let rows = db.query(&sql, query_values).await?;
505 match rows.first() {
506 Some(row) => {
507 let json = serde_json::to_value(row)?;
508 let result = serde_json::from_value(json)?;
509 Ok(Some(result))
510 }
511 None => Ok(None),
512 }
513 }
514
515 async fn one_async(
516 &self,
517 db: &super::connection::DatabaseConnection,
518 ) -> Result<T, ExecutionError>
519 where
520 T: for<'de> serde::Deserialize<'de>,
521 {
522 let stmt = self.one();
523 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
524
525 let query_values = convert_values(values);
526 let rows = db.query(&sql, query_values).await?;
527 match rows.len() {
528 0 => Err(ExecutionError::NoResultFound),
529 1 => {
530 let json = serde_json::to_value(&rows[0])?;
531 let result = serde_json::from_value(json)?;
532 Ok(result)
533 }
534 n => Err(ExecutionError::MultipleResultsFound(n)),
535 }
536 }
537
538 async fn one_or_none_async(
539 &self,
540 db: &super::connection::DatabaseConnection,
541 ) -> Result<Option<T>, ExecutionError>
542 where
543 T: for<'de> serde::Deserialize<'de>,
544 {
545 let stmt = self.one_or_none();
546 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
547
548 let query_values = convert_values(values);
549 let rows = db.query(&sql, query_values).await?;
550 match rows.len() {
551 0 => Ok(None),
552 1 => {
553 let json = serde_json::to_value(&rows[0])?;
554 let result = serde_json::from_value(json)?;
555 Ok(Some(result))
556 }
557 n => Err(ExecutionError::MultipleResultsFound(n)),
558 }
559 }
560
561 async fn scalar_async<S>(
562 &self,
563 db: &super::connection::DatabaseConnection,
564 ) -> Result<Option<S>, ExecutionError>
565 where
566 S: for<'de> serde::Deserialize<'de>,
567 {
568 let stmt = self.scalar();
569 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
570
571 let query_values = convert_values(values);
572 let rows = db.query(&sql, query_values).await?;
573 match rows.first() {
574 Some(row) => {
575 let json = serde_json::to_value(row)?;
577 if let Some(obj) = json.as_object()
578 && let Some((_, value)) = obj.iter().next()
579 {
580 let result = serde_json::from_value(value.clone())?;
581 return Ok(Some(result));
582 }
583 Ok(None)
584 }
585 None => Ok(None),
586 }
587 }
588
589 async fn count_async(
590 &self,
591 db: &super::connection::DatabaseConnection,
592 ) -> Result<i64, ExecutionError> {
593 let stmt = self.count();
594 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
595
596 let query_values = convert_values(values);
597 let row = db.query_one(&sql, query_values).await?;
598 let json = serde_json::to_value(&row)?;
599
600 if let Some(obj) = json.as_object()
602 && let Some((_, value)) = obj.iter().next()
603 {
604 let count: i64 = serde_json::from_value(value.clone())?;
605 return Ok(count);
606 }
607
608 Err(ExecutionError::QueryBuild(
609 "Count query returned unexpected format".to_string(),
610 ))
611 }
612
613 async fn exists_async(
614 &self,
615 db: &super::connection::DatabaseConnection,
616 ) -> Result<bool, ExecutionError> {
617 let stmt = self.exists();
618 let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
619
620 let query_values = convert_values(values);
621 let row = db.query_one(&sql, query_values).await?;
622 let json = serde_json::to_value(&row)?;
623
624 if let Some(obj) = json.as_object()
626 && let Some((_, value)) = obj.iter().next()
627 {
628 let exists: bool = serde_json::from_value(value.clone())?;
629 return Ok(exists);
630 }
631
632 Err(ExecutionError::QueryBuild(
633 "Exists query returned unexpected format".to_string(),
634 ))
635 }
636}
637
/// A loading hint for a relationship or column, rendered as a SQL block comment.
///
/// The hint names mirror SQLAlchemy-style loader options. In the code shown
/// here they are only rendered by [`LoadOption::to_sql_comment`]; whether
/// anything else interprets them is not visible — TODO confirm.
#[derive(Debug, Clone)]
pub enum LoadOption {
	/// Rendered as `joinedload(<relationship>)`.
	JoinedLoad(String),

	/// Rendered as `selectinload(<relationship>)`.
	SelectInLoad(String),

	/// Rendered as `lazyload(<relationship>)`.
	LazyLoad(String),

	/// Rendered as `noload(<relationship>)`.
	NoLoad(String),

	/// Rendered as `raiseload(<relationship>)`.
	RaiseLoad(String),

	/// Rendered as `defer(<column>)`.
	Defer(String),

	/// Rendered as `undefer(<column>)`.
	Undefer(String),

	/// Rendered as `load_only(<column>, <column>, ...)`.
	LoadOnly(Vec<String>),
}

impl LoadOption {
	/// Renders this option as a SQL block comment, e.g. `/* joinedload(profile) */`.
	///
	/// Comment delimiters inside the supplied names are split with a space
	/// (`*/` becomes `* /`, `/*` becomes `/ *`). This stops a name from
	/// closing the comment early, which would inject SQL. It also stops a name
	/// from opening a nested comment: PostgreSQL nests block comments, so a
	/// stray `/*` would leave the statement unterminated.
	pub fn to_sql_comment(&self) -> String {
		match self {
			LoadOption::JoinedLoad(rel) => Self::render_hint("joinedload", rel),
			LoadOption::SelectInLoad(rel) => Self::render_hint("selectinload", rel),
			LoadOption::LazyLoad(rel) => Self::render_hint("lazyload", rel),
			LoadOption::NoLoad(rel) => Self::render_hint("noload", rel),
			LoadOption::RaiseLoad(rel) => Self::render_hint("raiseload", rel),
			LoadOption::Defer(col) => Self::render_hint("defer", col),
			LoadOption::Undefer(col) => Self::render_hint("undefer", col),
			LoadOption::LoadOnly(cols) => Self::render_hint("load_only", &cols.join(", ")),
		}
	}

	/// Formats `/* <hint>(<argument>) */` after neutralizing any comment
	/// delimiter inside `argument`.
	fn render_hint(hint: &str, argument: &str) -> String {
		// Replacing `*/` first cannot create a new `/*`, and the second
		// replacement cannot create a new `*/`, so none survive.
		let safe = argument.replace("*/", "* /").replace("/*", "/ *");
		format!("/* {}({}) */", hint, safe)
	}
}

/// An ordered collection of [`LoadOption`]s attached to a query.
#[derive(Debug, Clone)]
pub struct QueryOptions {
	/// The options, in the order they were added.
	pub load_options: Vec<LoadOption>,
}

impl QueryOptions {
	/// Creates an empty option set.
	pub fn new() -> Self {
		Self {
			load_options: Vec::new(),
		}
	}

	/// Appends `option` and returns `self` so calls can be chained.
	pub fn add_option(mut self, option: LoadOption) -> Self {
		self.load_options.push(option);
		self
	}

	/// Renders every option as a SQL comment.
	///
	/// Each comment is preceded by a single space, e.g.
	/// `" /* joinedload(profile) */ /* defer(password) */"`. Returns an empty
	/// string when no options are set.
	pub fn to_sql_comments(&self) -> String {
		let mut rendered = String::new();
		for option in &self.load_options {
			rendered.push(' ');
			rendered.push_str(&option.to_sql_comment());
		}
		rendered
	}
}

impl Default for QueryOptions {
	fn default() -> Self {
		Self::new()
	}
}
779
// Unit tests for the statement builders, load-option rendering, and
// `BigUnsigned` value conversion. The statement tests only render SQL with
// `PostgresQueryBuilder` and check for keywords; no database is involved.
#[cfg(test)]
mod tests {
	use super::*;
	use reinhardt_core::validators::TableName;
	use rstest::rstest;
	use serde::{Deserialize, Serialize};

	// Minimal model used as the `T` parameter of `SelectExecution`.
	#[derive(Debug, Clone, Serialize, Deserialize)]
	struct User {
		id: Option<i64>,
		name: String,
	}

	// Field-selector stub required by `Model`; aliasing is a no-op here.
	#[derive(Clone)]
	struct UserFields;
	impl crate::orm::model::FieldSelector for UserFields {
		fn with_alias(self, _alias: &str) -> Self {
			self
		}
	}

	const USER_TABLE: TableName = TableName::new_const("users");

	// `primary_key_field` is not overridden, so `get` relies on the trait's
	// default — presumably "id"; confirm against `Model`.
	impl Model for User {
		type PrimaryKey = i64;
		type Fields = UserFields;

		fn table_name() -> &'static str {
			USER_TABLE.as_str()
		}

		fn new_fields() -> Self::Fields {
			UserFields
		}

		fn primary_key(&self) -> Option<Self::PrimaryKey> {
			self.id
		}

		fn set_primary_key(&mut self, value: Self::PrimaryKey) {
			self.id = Some(value);
		}
	}

	// `get` should filter on the primary key and limit the result to one row.
	#[test]
	fn test_execution_get() {
		use reinhardt_query::prelude::{Alias, PostgresQueryBuilder, Query, QueryStatementBuilder};

		let stmt = Query::select()
			.from(Alias::new("users"))
			.column(ColumnRef::Asterisk)
			.to_owned();
		let exec = SelectExecution::<User>::new(stmt);
		let result_stmt = exec.get(&123);
		let sql = result_stmt.to_string(PostgresQueryBuilder);
		assert!(sql.contains("WHERE"));
		assert!(sql.contains("LIMIT"));
	}

	// `all` returns the base statement unchanged.
	#[test]
	fn test_all() {
		use reinhardt_query::prelude::{Alias, PostgresQueryBuilder, Query, QueryStatementBuilder};

		let stmt = Query::select()
			.from(Alias::new("users"))
			.column(ColumnRef::Asterisk)
			.to_owned();
		let exec = SelectExecution::<User>::new(stmt);
		let result_stmt = exec.all();
		let sql = result_stmt.to_string(PostgresQueryBuilder);
		assert!(sql.contains("SELECT"));
		assert!(sql.contains("users"));
	}

	// `first` adds a LIMIT to a filtered base statement.
	#[test]
	fn test_first() {
		use reinhardt_query::prelude::{
			Alias, Expr, PostgresQueryBuilder, Query, QueryStatementBuilder,
		};

		let stmt = Query::select()
			.from(Alias::new("users"))
			.column(ColumnRef::Asterisk)
			.and_where(Expr::col(Alias::new("active")).eq(true))
			.to_owned();
		let exec = SelectExecution::<User>::new(stmt);
		let result_stmt = exec.first();
		let sql = result_stmt.to_string(PostgresQueryBuilder);
		assert!(sql.contains("LIMIT"));
	}

	// `count` wraps the base statement in a COUNT query.
	#[test]
	fn test_execution_count() {
		use reinhardt_query::prelude::{
			Alias, Expr, PostgresQueryBuilder, Query, QueryStatementBuilder,
		};

		let stmt = Query::select()
			.from(Alias::new("users"))
			.column(ColumnRef::Asterisk)
			.and_where(Expr::col(Alias::new("active")).eq(true))
			.to_owned();
		let exec = SelectExecution::<User>::new(stmt);
		let result_stmt = exec.count();
		let sql = result_stmt.to_string(PostgresQueryBuilder);
		assert!(sql.contains("COUNT"));
	}

	// `exists` wraps the base statement in an EXISTS expression.
	#[test]
	fn test_execution_exists() {
		use reinhardt_query::prelude::{
			Alias, Expr, PostgresQueryBuilder, Query, QueryStatementBuilder,
		};

		let stmt = Query::select()
			.from(Alias::new("users"))
			.column(ColumnRef::Asterisk)
			.and_where(Expr::col(Alias::new("name")).eq("Alice"))
			.to_owned();
		let exec = SelectExecution::<User>::new(stmt);
		let result_stmt = exec.exists();
		let sql = result_stmt.to_string(PostgresQueryBuilder);
		assert!(sql.contains("EXISTS"));
	}

	// Each added option contributes its own SQL comment.
	#[test]
	fn test_load_options() {
		let options = QueryOptions::new()
			.add_option(LoadOption::JoinedLoad("profile".to_string()))
			.add_option(LoadOption::Defer("password".to_string()));

		let comments = options.to_sql_comments();
		assert!(comments.contains("joinedload(profile)"));
		assert!(comments.contains("defer(password)"));
	}

	// `LoadOnly` joins its column names with ", ".
	#[test]
	fn test_load_only() {
		let option = LoadOption::LoadOnly(vec!["id".to_string(), "name".to_string()]);
		let comment = option.to_sql_comment();
		assert!(comment.contains("load_only(id, name)"));
	}

	// BigUnsigned values that fit in i64 convert unchanged.
	#[rstest]
	#[case::zero(0u64, 0i64)]
	#[case::one(1u64, 1i64)]
	#[case::i64_max(i64::MAX as u64, i64::MAX)]
	#[test]
	fn test_big_unsigned_to_query_value_within_range(#[case] input: u64, #[case] expected: i64) {
		// Arrange
		let value = reinhardt_query::value::Value::BigUnsigned(Some(input));

		// Act
		let result = convert_value_to_query_value(value);

		// Assert
		assert!(matches!(result, QueryValue::Int(v) if v == expected));
	}

	// BigUnsigned values above i64::MAX clamp to i64::MAX instead of wrapping.
	#[rstest]
	#[case::i64_max_plus_one(i64::MAX as u64 + 1)]
	#[case::u64_max(u64::MAX)]
	#[test]
	fn test_big_unsigned_overflow_clamps_to_i64_max(#[case] input: u64) {
		// Arrange
		let value = reinhardt_query::value::Value::BigUnsigned(Some(input));

		// Act
		let result = convert_value_to_query_value(value);

		// Assert
		assert!(matches!(result, QueryValue::Int(v) if v == i64::MAX));
	}

	// A NULL BigUnsigned maps to QueryValue::Null.
	#[rstest]
	#[test]
	fn test_big_unsigned_none_converts_to_null() {
		// Arrange
		let value = reinhardt_query::value::Value::BigUnsigned(None);

		// Act
		let result = convert_value_to_query_value(value);

		// Assert
		assert!(matches!(result, QueryValue::Null));
	}
}