Skip to main content

reinhardt_db/orm/
execution.rs

1//! # Query Execution
2//!
3//! SQLAlchemy-inspired query execution methods.
4//!
//! This module provides execution methods similar to those of SQLAlchemy's `Query` class.
6
7use crate::backends::types::QueryValue;
8use crate::orm::Model;
9use reinhardt_query::prelude::{
10	Alias, ColumnRef, Expr, ExprTrait, Func, Query, QueryStatementBuilder, SelectStatement,
11};
12use rust_decimal::prelude::ToPrimitive;
13use std::marker::PhantomData;
14
/// Shape of the value produced by running a query.
///
/// Intended to cover the result families of the execution API: a single row,
/// an optional row, a list of rows, a scalar read, or nothing at all.
#[derive(Debug)]
pub enum ExecutionResult<T> {
	/// Exactly one row (the `.one()` family).
	One(T),
	/// Zero or one row (the `.one_or_none()` / `.first()` family).
	OneOrNone(Option<T>),
	/// Every matching row.
	All(Vec<T>),
	/// A single scalar value, such as an aggregate, carried as text.
	Scalar(String),
	/// No rows at all, e.g. the outcome of a mutation.
	None,
}
29
30/// Errors that can occur during query execution
31#[non_exhaustive]
32#[derive(Debug, thiserror::Error)]
33pub enum ExecutionError {
34	/// Database error
35	#[error("Database error: {0}")]
36	Database(#[from] crate::backends::DatabaseError),
37
38	/// No result found (for .one())
39	#[error("No result found")]
40	NoResultFound,
41
42	/// Multiple results found (for .one() and .one_or_none())
43	#[error("Multiple results found (expected 1, got {0})")]
44	MultipleResultsFound(usize),
45
46	/// Deserialization error
47	#[error("Failed to deserialize result: {0}")]
48	Deserialization(#[from] serde_json::Error),
49
50	/// Query building error
51	#[error("Query building error: {0}")]
52	QueryBuild(String),
53
54	/// Generic error from anyhow
55	#[error("Generic error: {0}")]
56	Generic(#[from] anyhow::Error),
57}
58
59/// Convert reinhardt_query Value to QueryValue for parameter binding
60fn convert_value_to_query_value(value: reinhardt_query::value::Value) -> QueryValue {
61	use reinhardt_query::value::Value as SV;
62
63	match value {
64		// Null values
65		SV::Bool(None)
66		| SV::TinyInt(None)
67		| SV::SmallInt(None)
68		| SV::Int(None)
69		| SV::BigInt(None)
70		| SV::TinyUnsigned(None)
71		| SV::SmallUnsigned(None)
72		| SV::Unsigned(None)
73		| SV::BigUnsigned(None)
74		| SV::Float(None)
75		| SV::Double(None)
76		| SV::String(None)
77		| SV::Char(None)
78		| SV::Bytes(None)
79		| SV::ChronoDateTimeUtc(None)
80		| SV::ChronoDateTimeLocal(None)
81		| SV::ChronoDateTimeWithTimeZone(None)
82		| SV::ChronoDate(None)
83		| SV::ChronoTime(None)
84		| SV::ChronoDateTime(None)
85		| SV::Json(None)
86		| SV::Decimal(None)
87		| SV::BigDecimal(None)
88		| SV::Uuid(None) => QueryValue::Null,
89
90		// Boolean
91		SV::Bool(Some(b)) => QueryValue::Bool(b),
92
93		// Signed integers (convert all to i64)
94		SV::TinyInt(Some(v)) => QueryValue::Int(v as i64),
95		SV::SmallInt(Some(v)) => QueryValue::Int(v as i64),
96		SV::Int(Some(v)) => QueryValue::Int(v as i64),
97		SV::BigInt(Some(v)) => QueryValue::Int(v),
98
99		// Unsigned integers (convert to i64 with checked conversion for large values)
100		SV::TinyUnsigned(Some(v)) => QueryValue::Int(v as i64),
101		SV::SmallUnsigned(Some(v)) => QueryValue::Int(v as i64),
102		SV::Unsigned(Some(v)) => QueryValue::Int(v as i64),
103		SV::BigUnsigned(Some(v)) => QueryValue::Int(i64::try_from(v).unwrap_or_else(|_| {
104			tracing::warn!(
105				value = v,
106				"BigUnsigned value {} exceeds i64::MAX, clamping to i64::MAX",
107				v
108			);
109			i64::MAX
110		})),
111
112		// Floating point
113		SV::Float(Some(v)) => QueryValue::Float(v as f64),
114		SV::Double(Some(v)) => QueryValue::Float(v),
115
116		// String and char
117		SV::String(Some(s)) => QueryValue::String(s.to_string()),
118		SV::Char(Some(c)) => QueryValue::String(c.to_string()),
119
120		// Bytes
121		SV::Bytes(Some(b)) => QueryValue::Bytes(b.to_vec()),
122
123		// Chrono datetime types
124		SV::ChronoDateTimeUtc(Some(dt)) => QueryValue::Timestamp(*dt),
125
126		// For other datetime types, convert to UTC if possible
127		SV::ChronoDateTimeLocal(Some(dt)) => {
128			QueryValue::Timestamp((*dt).with_timezone(&chrono::Utc))
129		}
130		SV::ChronoDateTimeWithTimeZone(Some(dt)) => {
131			QueryValue::Timestamp((*dt).with_timezone(&chrono::Utc))
132		}
133
134		// Other datetime types that cannot be easily converted
135		SV::ChronoDate(_) | SV::ChronoTime(_) | SV::ChronoDateTime(_) => {
136			// Convert to string representation as fallback
137			QueryValue::String(format!("{:?}", value))
138		}
139
140		// JSON - convert to string
141		SV::Json(_) => QueryValue::String(format!("{:?}", value)),
142
143		// Decimal - convert to f64 with fallback through string parsing
144		SV::Decimal(Some(d)) => {
145			let f = d.to_f64().unwrap_or_else(|| {
146				tracing::warn!(
147					decimal = %d,
148					"Decimal cannot be directly represented as f64, falling back to string parsing"
149				);
150				d.to_string().parse::<f64>().unwrap_or(0.0)
151			});
152			QueryValue::Float(f)
153		}
154		SV::BigDecimal(Some(d)) => {
155			let f = d.to_string().parse::<f64>().unwrap_or_else(|_| {
156				tracing::warn!(
157					big_decimal = %d,
158					"BigDecimal cannot be represented as f64"
159				);
160				0.0
161			});
162			QueryValue::Float(f)
163		}
164
165		// UUID
166		SV::Uuid(Some(u)) => QueryValue::Uuid(*u),
167
168		// Arrays - convert to string
169		// For reinhardt-query 1.0.0-rc.29+: Array(ArrayType, Option<Box<Vec<Value>>>)
170		SV::Array(_, arr) => QueryValue::String(format!("{:?}", arr)),
171	}
172}
173
174/// Convert reinhardt_query Values (`Vec<Value>`) to `Vec<QueryValue>`
175pub fn convert_values(values: reinhardt_query::prelude::Values) -> Vec<QueryValue> {
176	values
177		.0
178		.into_iter()
179		.map(convert_value_to_query_value)
180		.collect()
181}
182
183/// Query execution methods with both sync builders and async execution
184#[async_trait::async_trait]
185pub trait QueryExecution<T: Model>
186where
187	T: Send + Sync,
188	T::PrimaryKey: Send + Sync,
189{
190	/// Get a single result by primary key (async execution)
191	/// Corresponds to SQLAlchemy's .get()
192	async fn get_async(
193		&self,
194		db: &super::connection::DatabaseConnection,
195		pk: &T::PrimaryKey,
196	) -> Result<T, ExecutionError>
197	where
198		T: for<'de> serde::Deserialize<'de>;
199
200	/// Get a single result by primary key (statement builder)
201	/// Returns a SelectStatement for manual execution
202	fn get(&self, pk: &T::PrimaryKey) -> SelectStatement;
203
204	/// Get all results (async execution)
205	/// Corresponds to SQLAlchemy's .all()
206	async fn all_async(
207		&self,
208		db: &super::connection::DatabaseConnection,
209	) -> Result<Vec<T>, ExecutionError>
210	where
211		T: for<'de> serde::Deserialize<'de>;
212
213	/// Get all results (statement builder)
214	/// Returns a SelectStatement for manual execution
215	fn all(&self) -> SelectStatement;
216
217	/// Get first result or None (async execution)
218	/// Corresponds to SQLAlchemy's .first()
219	async fn first_async(
220		&self,
221		db: &super::connection::DatabaseConnection,
222	) -> Result<Option<T>, ExecutionError>
223	where
224		T: for<'de> serde::Deserialize<'de>;
225
226	/// Get first result or None (statement builder)
227	/// Returns a SelectStatement for manual execution
228	fn first(&self) -> SelectStatement;
229
230	/// Get exactly one result, raise if 0 or >1 (async execution)
231	/// Corresponds to SQLAlchemy's .one()
232	async fn one_async(
233		&self,
234		db: &super::connection::DatabaseConnection,
235	) -> Result<T, ExecutionError>
236	where
237		T: for<'de> serde::Deserialize<'de>;
238
239	/// Get exactly one result (statement builder)
240	/// Returns a SelectStatement for manual execution
241	fn one(&self) -> SelectStatement;
242
243	/// Get one result or None, raise if >1 (async execution)
244	/// Corresponds to SQLAlchemy's .one_or_none()
245	async fn one_or_none_async(
246		&self,
247		db: &super::connection::DatabaseConnection,
248	) -> Result<Option<T>, ExecutionError>
249	where
250		T: for<'de> serde::Deserialize<'de>;
251
252	/// Get one result or None (statement builder)
253	/// Returns a SelectStatement for manual execution
254	fn one_or_none(&self) -> SelectStatement;
255
256	/// Get scalar value (first column of first row) (async execution)
257	/// Corresponds to SQLAlchemy's .scalar()
258	async fn scalar_async<S>(
259		&self,
260		db: &super::connection::DatabaseConnection,
261	) -> Result<Option<S>, ExecutionError>
262	where
263		S: for<'de> serde::Deserialize<'de>;
264
265	/// Get scalar value (statement builder)
266	/// Returns a SelectStatement for manual execution
267	fn scalar(&self) -> SelectStatement;
268
269	/// Count results (async execution)
270	/// Corresponds to SQLAlchemy's .count()
271	async fn count_async(
272		&self,
273		db: &super::connection::DatabaseConnection,
274	) -> Result<i64, ExecutionError>;
275
276	/// Count results (statement builder)
277	/// Returns a SelectStatement for manual execution
278	fn count(&self) -> SelectStatement;
279
280	/// Check if any results exist (async execution)
281	/// Corresponds to SQLAlchemy's .exists()
282	async fn exists_async(
283		&self,
284		db: &super::connection::DatabaseConnection,
285	) -> Result<bool, ExecutionError>;
286
287	/// Check if any results exist (statement builder)
288	/// Returns a SelectStatement for manual execution
289	fn exists(&self) -> SelectStatement;
290}
291
292/// Execution context for SELECT queries
293pub struct SelectExecution<T: Model> {
294	stmt: SelectStatement,
295	_phantom: PhantomData<T>,
296}
297
298impl<T: Model> SelectExecution<T> {
299	/// Create a new query execution context with the given SelectStatement
300	///
301	/// # Examples
302	///
303	/// ```rust,no_run
304	/// use reinhardt_db::orm::execution::SelectExecution;
305	/// use reinhardt_db::orm::Model;
306	/// use reinhardt_query::prelude::{QueryStatementBuilder, Alias, Query};
307	/// use serde::{Serialize, Deserialize};
308	///
309	/// #[derive(Debug, Clone, Serialize, Deserialize)]
310	/// struct User {
311	///     id: Option<i64>,
312	///     name: String,
313	/// }
314	///
315	/// #[derive(Clone)]
316	/// struct UserFields;
317	/// impl reinhardt_db::orm::FieldSelector for UserFields {
318	///     fn with_alias(self, _alias: &str) -> Self { self }
319	/// }
320	///
321	/// impl Model for User {
322	///     type PrimaryKey = i64;
323	///     type Fields = UserFields;
324	///     fn app_label() -> &'static str { "app" }
325	///     fn table_name() -> &'static str { "users" }
326	///     fn new_fields() -> Self::Fields { UserFields }
327	///     fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
328	///     fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
329	///     fn primary_key_field() -> &'static str { "id" }
330	/// }
331	///
332	/// let stmt = Query::select().from(Alias::new("users")).to_owned();
333	/// let exec = SelectExecution::<User>::new(stmt);
334	/// ```
335	pub fn new(stmt: SelectStatement) -> Self {
336		Self {
337			stmt,
338			_phantom: PhantomData,
339		}
340	}
341	/// Get a reference to the underlying SelectStatement
342	///
343	/// # Examples
344	///
345	/// ```rust,ignore
346	/// use reinhardt_db::orm::execution::SelectExecution;
347	/// use reinhardt_db::orm::Model;
348	/// use reinhardt_query::prelude::{QueryStatementBuilder, Alias, Expr, Query};
349	/// use serde::{Serialize, Deserialize};
350	///
351	/// #[derive(Debug, Clone, Serialize, Deserialize)]
352	/// struct User {
353	///     id: Option<i64>,
354	///     name: String,
355	/// }
356	///
357	/// #[derive(Clone)]
358	/// struct UserFields;
359	/// impl reinhardt_db::orm::FieldSelector for UserFields {
360	///     fn with_alias(self, _alias: &str) -> Self { self }
361	/// }
362	///
363	/// impl Model for User {
364	///     type PrimaryKey = i64;
365	///     type Fields = UserFields;
366	///     fn app_label() -> &'static str { "app" }
367	///     fn table_name() -> &'static str { "users" }
368	///     fn new_fields() -> Self::Fields { UserFields }
369	///     fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
370	///     fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
371	///     fn primary_key_field() -> &'static str { "id" }
372	/// }
373	///
374	/// let stmt = Query::select()
375	///     .from(Alias::new("users"))
376	///     .and_where(Expr::col(Alias::new("active")).eq(true))
377	///     .to_owned();
378	/// let exec = SelectExecution::<User>::new(stmt);
379	/// ```
380	pub fn statement(&self) -> &SelectStatement {
381		&self.stmt
382	}
383}
384
385#[async_trait::async_trait]
386impl<T: Model> QueryExecution<T> for SelectExecution<T>
387where
388	T::PrimaryKey: Into<reinhardt_query::value::Value> + Clone + Send + Sync,
389	T: Send + Sync,
390{
391	fn get(&self, pk: &T::PrimaryKey) -> SelectStatement {
392		Query::select()
393			.from(Alias::new(T::table_name()))
394			.column(ColumnRef::Asterisk)
395			.and_where(
396				Expr::col(Alias::new(T::primary_key_field())).eq(Expr::val(pk.clone().into())),
397			)
398			.limit(1)
399			.to_owned()
400	}
401
402	fn all(&self) -> SelectStatement {
403		self.stmt.clone()
404	}
405
406	fn first(&self) -> SelectStatement {
407		let mut stmt = self.stmt.clone();
408		stmt.limit(1);
409		stmt
410	}
411
412	fn one(&self) -> SelectStatement {
413		// Sets LIMIT 2 to detect multiple results
414		// The execution layer should:
415		// - Error if 0 results are returned (NoResultFound)
416		// - Error if 2+ results are returned (MultipleResultsFound)
417		// - Return the single result if exactly 1 is found
418		let mut stmt = self.stmt.clone();
419		stmt.limit(2);
420		stmt
421	}
422
423	fn one_or_none(&self) -> SelectStatement {
424		// Sets LIMIT 2 to detect multiple results
425		// The execution layer should:
426		// - Return None if 0 results
427		// - Error if 2+ results are returned (MultipleResultsFound)
428		// - Return Some(result) if exactly 1 is found
429		let mut stmt = self.stmt.clone();
430		stmt.limit(2);
431		stmt
432	}
433
434	fn scalar(&self) -> SelectStatement {
435		let mut stmt = self.stmt.clone();
436		stmt.limit(1);
437		stmt
438	}
439
440	fn count(&self) -> SelectStatement {
441		// Use the original statement as a subquery and count all rows from it
442		// This preserves all WHERE, JOIN, and other conditions
443		Query::select()
444			.expr(Func::count(Expr::asterisk().into_simple_expr()))
445			.from_subquery(self.stmt.clone(), Alias::new("subquery"))
446			.to_owned()
447	}
448
449	fn exists(&self) -> SelectStatement {
450		Query::select()
451			.expr(Expr::exists(self.stmt.clone()))
452			.to_owned()
453	}
454
455	async fn get_async(
456		&self,
457		db: &super::connection::DatabaseConnection,
458		pk: &T::PrimaryKey,
459	) -> Result<T, ExecutionError>
460	where
461		T: for<'de> serde::Deserialize<'de>,
462	{
463		let stmt = self.get(pk);
464		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
465
466		let query_values = convert_values(values);
467		let row = db.query_one(&sql, query_values).await?;
468		let json = serde_json::to_value(&row)?;
469		let result = serde_json::from_value(json)?;
470		Ok(result)
471	}
472
473	async fn all_async(
474		&self,
475		db: &super::connection::DatabaseConnection,
476	) -> Result<Vec<T>, ExecutionError>
477	where
478		T: for<'de> serde::Deserialize<'de>,
479	{
480		let stmt = self.all();
481		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
482
483		let query_values = convert_values(values);
484		let rows = db.query(&sql, query_values).await?;
485		let mut results = Vec::with_capacity(rows.len());
486		for row in rows {
487			let json = serde_json::to_value(&row)?;
488			let result = serde_json::from_value(json)?;
489			results.push(result);
490		}
491		Ok(results)
492	}
493
494	async fn first_async(
495		&self,
496		db: &super::connection::DatabaseConnection,
497	) -> Result<Option<T>, ExecutionError>
498	where
499		T: for<'de> serde::Deserialize<'de>,
500	{
501		let stmt = self.first();
502		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
503
504		let query_values = convert_values(values);
505		let rows = db.query(&sql, query_values).await?;
506		match rows.first() {
507			Some(row) => {
508				let json = serde_json::to_value(row)?;
509				let result = serde_json::from_value(json)?;
510				Ok(Some(result))
511			}
512			None => Ok(None),
513		}
514	}
515
516	async fn one_async(
517		&self,
518		db: &super::connection::DatabaseConnection,
519	) -> Result<T, ExecutionError>
520	where
521		T: for<'de> serde::Deserialize<'de>,
522	{
523		let stmt = self.one();
524		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
525
526		let query_values = convert_values(values);
527		let rows = db.query(&sql, query_values).await?;
528		match rows.len() {
529			0 => Err(ExecutionError::NoResultFound),
530			1 => {
531				let json = serde_json::to_value(&rows[0])?;
532				let result = serde_json::from_value(json)?;
533				Ok(result)
534			}
535			n => Err(ExecutionError::MultipleResultsFound(n)),
536		}
537	}
538
539	async fn one_or_none_async(
540		&self,
541		db: &super::connection::DatabaseConnection,
542	) -> Result<Option<T>, ExecutionError>
543	where
544		T: for<'de> serde::Deserialize<'de>,
545	{
546		let stmt = self.one_or_none();
547		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
548
549		let query_values = convert_values(values);
550		let rows = db.query(&sql, query_values).await?;
551		match rows.len() {
552			0 => Ok(None),
553			1 => {
554				let json = serde_json::to_value(&rows[0])?;
555				let result = serde_json::from_value(json)?;
556				Ok(Some(result))
557			}
558			n => Err(ExecutionError::MultipleResultsFound(n)),
559		}
560	}
561
562	async fn scalar_async<S>(
563		&self,
564		db: &super::connection::DatabaseConnection,
565	) -> Result<Option<S>, ExecutionError>
566	where
567		S: for<'de> serde::Deserialize<'de>,
568	{
569		let stmt = self.scalar();
570		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
571
572		let query_values = convert_values(values);
573		let rows = db.query(&sql, query_values).await?;
574		match rows.first() {
575			Some(row) => {
576				// Get the first column value
577				let json = serde_json::to_value(row)?;
578				if let Some(obj) = json.as_object()
579					&& let Some((_, value)) = obj.iter().next()
580				{
581					let result = serde_json::from_value(value.clone())?;
582					return Ok(Some(result));
583				}
584				Ok(None)
585			}
586			None => Ok(None),
587		}
588	}
589
590	async fn count_async(
591		&self,
592		db: &super::connection::DatabaseConnection,
593	) -> Result<i64, ExecutionError> {
594		let stmt = self.count();
595		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
596
597		let query_values = convert_values(values);
598		let row = db.query_one(&sql, query_values).await?;
599		let json = serde_json::to_value(&row)?;
600
601		// Extract count from the result (usually the first column)
602		if let Some(obj) = json.as_object()
603			&& let Some((_, value)) = obj.iter().next()
604		{
605			let count: i64 = serde_json::from_value(value.clone())?;
606			return Ok(count);
607		}
608
609		Err(ExecutionError::QueryBuild(
610			"Count query returned unexpected format".to_string(),
611		))
612	}
613
614	async fn exists_async(
615		&self,
616		db: &super::connection::DatabaseConnection,
617	) -> Result<bool, ExecutionError> {
618		let stmt = self.exists();
619		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
620
621		let query_values = convert_values(values);
622		let row = db.query_one(&sql, query_values).await?;
623		let json = serde_json::to_value(&row)?;
624
625		// Extract exists from the result (usually the first column)
626		if let Some(obj) = json.as_object()
627			&& let Some((_, value)) = obj.iter().next()
628		{
629			let exists: bool = serde_json::from_value(value.clone())?;
630			return Ok(exists);
631		}
632
633		Err(ExecutionError::QueryBuild(
634			"Exists query returned unexpected format".to_string(),
635		))
636	}
637}
638
/// Loading options for relationships
/// Corresponds to SQLAlchemy's loader options
///
/// The `String` payload names a relationship (for the `*Load` variants) or a column
/// (for `Defer`/`Undefer`); `LoadOnly` lists columns.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LoadOption {
	/// Eager load with JOIN
	/// Corresponds to joinedload()
	JoinedLoad(String),

	/// Eager load with separate SELECT
	/// Corresponds to selectinload()
	SelectInLoad(String),

	/// Lazy load on access
	/// Corresponds to lazyload()
	LazyLoad(String),

	/// Don't load at all
	/// Corresponds to noload()
	NoLoad(String),

	/// Raise error if accessed
	/// Corresponds to raiseload()
	RaiseLoad(String),

	/// Defer column loading
	/// Corresponds to defer()
	Defer(String),

	/// Undefer column loading
	/// Corresponds to undefer()
	Undefer(String),

	/// Load only specified columns
	/// Corresponds to load_only()
	LoadOnly(Vec<String>),
}

impl LoadOption {
	/// Convert load option to SQL comment for debugging
	///
	/// Block-comment delimiters (`/*` and `*/`) inside names are broken up with a space.
	/// A name can then never close the comment early, nor open a nested comment
	/// (PostgreSQL nests block comments), when the text is appended to a query.
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_db::orm::execution::LoadOption;
	///
	/// let option = LoadOption::JoinedLoad("profile".to_string());
	/// assert_eq!(option.to_sql_comment(), "/* joinedload(profile) */");
	///
	/// let option = LoadOption::Defer("password".to_string());
	/// assert_eq!(option.to_sql_comment(), "/* defer(password) */");
	///
	/// let option = LoadOption::LoadOnly(vec!["id".to_string(), "name".to_string()]);
	/// assert_eq!(option.to_sql_comment(), "/* load_only(id, name) */");
	/// ```
	pub fn to_sql_comment(&self) -> String {
		let (loader, target) = match self {
			LoadOption::JoinedLoad(rel) => ("joinedload", escape_comment_text(rel)),
			LoadOption::SelectInLoad(rel) => ("selectinload", escape_comment_text(rel)),
			LoadOption::LazyLoad(rel) => ("lazyload", escape_comment_text(rel)),
			LoadOption::NoLoad(rel) => ("noload", escape_comment_text(rel)),
			LoadOption::RaiseLoad(rel) => ("raiseload", escape_comment_text(rel)),
			LoadOption::Defer(col) => ("defer", escape_comment_text(col)),
			LoadOption::Undefer(col) => ("undefer", escape_comment_text(col)),
			LoadOption::LoadOnly(cols) => ("load_only", escape_comment_text(&cols.join(", "))),
		};
		format!("/* {}({}) */", loader, target)
	}
}

/// Neutralize SQL block-comment delimiters so `text` can sit safely inside `/* ... */`.
fn escape_comment_text(text: &str) -> String {
	// `/*` is replaced first. Replacing `*/` afterwards cannot recreate either
	// delimiter: the inserted space always separates the `*` and `/`.
	text.replace("/*", "/ *").replace("*/", "* /")
}

/// Query options container
///
/// Holds [`LoadOption`]s in the order they were added.
#[non_exhaustive]
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct QueryOptions {
	/// The load options.
	pub load_options: Vec<LoadOption>,
}

impl QueryOptions {
	/// Create a new empty query options container
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_db::orm::execution::QueryOptions;
	///
	/// let options = QueryOptions::new();
	/// assert_eq!(options.to_sql_comments(), "");
	/// ```
	pub fn new() -> Self {
		Self::default()
	}

	/// Add a load option to the query
	///
	/// Consumes and returns `self`, so the result must be used (builder style).
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_db::orm::execution::{QueryOptions, LoadOption};
	///
	/// let options = QueryOptions::new()
	///     .add_option(LoadOption::JoinedLoad("profile".to_string()))
	///     .add_option(LoadOption::Defer("password".to_string()));
	///
	/// let comments = options.to_sql_comments();
	/// assert!(comments.contains("joinedload(profile)"));
	/// assert!(comments.contains("defer(password)"));
	/// ```
	#[must_use]
	pub fn add_option(mut self, option: LoadOption) -> Self {
		self.load_options.push(option);
		self
	}

	/// Convert all options to SQL comments
	///
	/// Returns an empty string when there are no options. Otherwise each comment is
	/// preceded by a single space, so the result can be appended directly to a query.
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_db::orm::execution::{QueryOptions, LoadOption};
	///
	/// let options = QueryOptions::new()
	///     .add_option(LoadOption::SelectInLoad("posts".to_string()));
	///
	/// assert!(options.to_sql_comments().contains("selectinload(posts)"));
	/// ```
	pub fn to_sql_comments(&self) -> String {
		self.load_options
			.iter()
			.map(|option| format!(" {}", option.to_sql_comment()))
			.collect()
	}
}
782
#[cfg(test)]
mod tests {
	use super::*;
	use reinhardt_core::validators::TableName;
	use reinhardt_query::prelude::PostgresQueryBuilder;
	use rstest::rstest;
	use serde::{Deserialize, Serialize};

	// Minimal model used to instantiate `SelectExecution<User>`.
	#[derive(Debug, Clone, Serialize, Deserialize)]
	struct User {
		id: Option<i64>,
		name: String,
	}

	#[derive(Clone)]
	struct UserFields;
	impl crate::orm::model::FieldSelector for UserFields {
		fn with_alias(self, _alias: &str) -> Self {
			self
		}
	}

	const USER_TABLE: TableName = TableName::new_const("users");

	impl Model for User {
		type PrimaryKey = i64;
		type Fields = UserFields;

		fn table_name() -> &'static str {
			USER_TABLE.as_str()
		}

		fn new_fields() -> Self::Fields {
			UserFields
		}

		fn primary_key(&self) -> Option<Self::PrimaryKey> {
			self.id
		}

		fn set_primary_key(&mut self, value: Self::PrimaryKey) {
			self.id = Some(value);
		}
	}

	/// `SELECT * FROM users` with no filters; the base for the builder tests.
	fn users_query() -> SelectStatement {
		Query::select()
			.from(Alias::new("users"))
			.column(ColumnRef::Asterisk)
			.to_owned()
	}

	/// Wrap `base` in a `SelectExecution<User>`, derive a statement with `build`, and
	/// render that statement as PostgreSQL SQL.
	fn render_derived(
		base: SelectStatement,
		build: impl FnOnce(&SelectExecution<User>) -> SelectStatement,
	) -> String {
		let exec = SelectExecution::<User>::new(base);
		build(&exec).to_string(PostgresQueryBuilder)
	}

	#[test]
	fn test_execution_get() {
		let sql = render_derived(users_query(), |exec| exec.get(&123));
		assert!(sql.contains("WHERE"));
		assert!(sql.contains("LIMIT"));
	}

	#[test]
	fn test_all() {
		let sql = render_derived(users_query(), |exec| exec.all());
		assert!(sql.contains("SELECT"));
		assert!(sql.contains("users"));
	}

	#[test]
	fn test_first() {
		let mut base = users_query();
		base.and_where(Expr::col(Alias::new("active")).eq(true));
		let sql = render_derived(base, |exec| exec.first());
		assert!(sql.contains("LIMIT"));
	}

	#[test]
	fn test_execution_count() {
		let mut base = users_query();
		base.and_where(Expr::col(Alias::new("active")).eq(true));
		let sql = render_derived(base, |exec| exec.count());
		assert!(sql.contains("COUNT"));
	}

	#[test]
	fn test_execution_exists() {
		let mut base = users_query();
		base.and_where(Expr::col(Alias::new("name")).eq("Alice"));
		let sql = render_derived(base, |exec| exec.exists());
		assert!(sql.contains("EXISTS"));
	}

	#[test]
	fn test_load_options() {
		let comments = QueryOptions::new()
			.add_option(LoadOption::JoinedLoad("profile".to_string()))
			.add_option(LoadOption::Defer("password".to_string()))
			.to_sql_comments();

		assert!(comments.contains("joinedload(profile)"));
		assert!(comments.contains("defer(password)"));
	}

	#[test]
	fn test_load_only() {
		let comment =
			LoadOption::LoadOnly(vec!["id".to_string(), "name".to_string()]).to_sql_comment();
		assert!(comment.contains("load_only(id, name)"));
	}

	#[rstest]
	#[case::zero(0u64, 0i64)]
	#[case::one(1u64, 1i64)]
	#[case::i64_max(i64::MAX as u64, i64::MAX)]
	#[test]
	fn test_big_unsigned_to_query_value_within_range(#[case] input: u64, #[case] expected: i64) {
		// Arrange
		let value = reinhardt_query::value::Value::BigUnsigned(Some(input));

		// Act
		let result = convert_value_to_query_value(value);

		// Assert
		assert!(matches!(result, QueryValue::Int(v) if v == expected));
	}

	#[rstest]
	#[case::i64_max_plus_one(i64::MAX as u64 + 1)]
	#[case::u64_max(u64::MAX)]
	#[test]
	fn test_big_unsigned_overflow_clamps_to_i64_max(#[case] input: u64) {
		// Arrange
		let value = reinhardt_query::value::Value::BigUnsigned(Some(input));

		// Act
		let result = convert_value_to_query_value(value);

		// Assert: Should clamp to i64::MAX instead of wrapping to negative
		assert!(matches!(result, QueryValue::Int(v) if v == i64::MAX));
	}

	#[rstest]
	#[test]
	fn test_big_unsigned_none_converts_to_null() {
		// Arrange
		let value = reinhardt_query::value::Value::BigUnsigned(None);

		// Act
		let result = convert_value_to_query_value(value);

		// Assert
		assert!(matches!(result, QueryValue::Null));
	}
}