Skip to main content

reinhardt_db/orm/
execution.rs

//! # Query Execution
//!
//! SQLAlchemy-inspired query execution methods.
//!
//! This module provides execution methods similar to those of SQLAlchemy's `Query` class.
6
7use crate::backends::types::QueryValue;
8use crate::orm::Model;
9use reinhardt_query::prelude::{
10	Alias, ColumnRef, Expr, ExprTrait, Func, Query, QueryStatementBuilder, SelectStatement,
11};
12use rust_decimal::prelude::ToPrimitive;
13use std::marker::PhantomData;
14
/// Outcome of executing a query, one variant per result shape.
#[derive(Debug)]
pub enum ExecutionResult<T> {
	/// Exactly one decoded row.
	One(T),
	/// Zero or one decoded row.
	OneOrNone(Option<T>),
	/// Every decoded row, in result order.
	All(Vec<T>),
	/// A single scalar value rendered as text (e.g. an aggregate).
	Scalar(String),
	/// No value at all (e.g. after a mutation).
	None,
}
29
30/// Errors that can occur during query execution
31#[derive(Debug, thiserror::Error)]
32pub enum ExecutionError {
33	/// Database error
34	#[error("Database error: {0}")]
35	Database(#[from] crate::backends::DatabaseError),
36
37	/// No result found (for .one())
38	#[error("No result found")]
39	NoResultFound,
40
41	/// Multiple results found (for .one() and .one_or_none())
42	#[error("Multiple results found (expected 1, got {0})")]
43	MultipleResultsFound(usize),
44
45	/// Deserialization error
46	#[error("Failed to deserialize result: {0}")]
47	Deserialization(#[from] serde_json::Error),
48
49	/// Query building error
50	#[error("Query building error: {0}")]
51	QueryBuild(String),
52
53	/// Generic error from anyhow
54	#[error("Generic error: {0}")]
55	Generic(#[from] anyhow::Error),
56}
57
58/// Convert reinhardt_query Value to QueryValue for parameter binding
59fn convert_value_to_query_value(value: reinhardt_query::value::Value) -> QueryValue {
60	use reinhardt_query::value::Value as SV;
61
62	match value {
63		// Null values
64		SV::Bool(None)
65		| SV::TinyInt(None)
66		| SV::SmallInt(None)
67		| SV::Int(None)
68		| SV::BigInt(None)
69		| SV::TinyUnsigned(None)
70		| SV::SmallUnsigned(None)
71		| SV::Unsigned(None)
72		| SV::BigUnsigned(None)
73		| SV::Float(None)
74		| SV::Double(None)
75		| SV::String(None)
76		| SV::Char(None)
77		| SV::Bytes(None)
78		| SV::ChronoDateTimeUtc(None)
79		| SV::ChronoDateTimeLocal(None)
80		| SV::ChronoDateTimeWithTimeZone(None)
81		| SV::ChronoDate(None)
82		| SV::ChronoTime(None)
83		| SV::ChronoDateTime(None)
84		| SV::Json(None)
85		| SV::Decimal(None)
86		| SV::BigDecimal(None)
87		| SV::Uuid(None) => QueryValue::Null,
88
89		// Boolean
90		SV::Bool(Some(b)) => QueryValue::Bool(b),
91
92		// Signed integers (convert all to i64)
93		SV::TinyInt(Some(v)) => QueryValue::Int(v as i64),
94		SV::SmallInt(Some(v)) => QueryValue::Int(v as i64),
95		SV::Int(Some(v)) => QueryValue::Int(v as i64),
96		SV::BigInt(Some(v)) => QueryValue::Int(v),
97
98		// Unsigned integers (convert to i64 with checked conversion for large values)
99		SV::TinyUnsigned(Some(v)) => QueryValue::Int(v as i64),
100		SV::SmallUnsigned(Some(v)) => QueryValue::Int(v as i64),
101		SV::Unsigned(Some(v)) => QueryValue::Int(v as i64),
102		SV::BigUnsigned(Some(v)) => QueryValue::Int(i64::try_from(v).unwrap_or_else(|_| {
103			tracing::warn!(
104				value = v,
105				"BigUnsigned value {} exceeds i64::MAX, clamping to i64::MAX",
106				v
107			);
108			i64::MAX
109		})),
110
111		// Floating point
112		SV::Float(Some(v)) => QueryValue::Float(v as f64),
113		SV::Double(Some(v)) => QueryValue::Float(v),
114
115		// String and char
116		SV::String(Some(s)) => QueryValue::String(s.to_string()),
117		SV::Char(Some(c)) => QueryValue::String(c.to_string()),
118
119		// Bytes
120		SV::Bytes(Some(b)) => QueryValue::Bytes(b.to_vec()),
121
122		// Chrono datetime types
123		SV::ChronoDateTimeUtc(Some(dt)) => QueryValue::Timestamp(*dt),
124
125		// For other datetime types, convert to UTC if possible
126		SV::ChronoDateTimeLocal(Some(dt)) => {
127			QueryValue::Timestamp((*dt).with_timezone(&chrono::Utc))
128		}
129		SV::ChronoDateTimeWithTimeZone(Some(dt)) => {
130			QueryValue::Timestamp((*dt).with_timezone(&chrono::Utc))
131		}
132
133		// Other datetime types that cannot be easily converted
134		SV::ChronoDate(_) | SV::ChronoTime(_) | SV::ChronoDateTime(_) => {
135			// Convert to string representation as fallback
136			QueryValue::String(format!("{:?}", value))
137		}
138
139		// JSON - convert to string
140		SV::Json(_) => QueryValue::String(format!("{:?}", value)),
141
142		// Decimal - convert to f64 with fallback through string parsing
143		SV::Decimal(Some(d)) => {
144			let f = d.to_f64().unwrap_or_else(|| {
145				tracing::warn!(
146					decimal = %d,
147					"Decimal cannot be directly represented as f64, falling back to string parsing"
148				);
149				d.to_string().parse::<f64>().unwrap_or(0.0)
150			});
151			QueryValue::Float(f)
152		}
153		SV::BigDecimal(Some(d)) => {
154			let f = d.to_string().parse::<f64>().unwrap_or_else(|_| {
155				tracing::warn!(
156					big_decimal = %d,
157					"BigDecimal cannot be represented as f64"
158				);
159				0.0
160			});
161			QueryValue::Float(f)
162		}
163
164		// UUID
165		SV::Uuid(Some(u)) => QueryValue::Uuid(*u),
166
167		// Arrays - convert to string
168		// For reinhardt-query 1.0.0-rc.29+: Array(ArrayType, Option<Box<Vec<Value>>>)
169		SV::Array(_, arr) => QueryValue::String(format!("{:?}", arr)),
170	}
171}
172
173/// Convert reinhardt_query Values (`Vec<Value>`) to `Vec<QueryValue>`
174pub fn convert_values(values: reinhardt_query::prelude::Values) -> Vec<QueryValue> {
175	values
176		.0
177		.into_iter()
178		.map(convert_value_to_query_value)
179		.collect()
180}
181
182/// Query execution methods with both sync builders and async execution
183#[async_trait::async_trait]
184pub trait QueryExecution<T: Model>
185where
186	T: Send + Sync,
187	T::PrimaryKey: Send + Sync,
188{
189	/// Get a single result by primary key (async execution)
190	/// Corresponds to SQLAlchemy's .get()
191	async fn get_async(
192		&self,
193		db: &super::connection::DatabaseConnection,
194		pk: &T::PrimaryKey,
195	) -> Result<T, ExecutionError>
196	where
197		T: for<'de> serde::Deserialize<'de>;
198
199	/// Get a single result by primary key (statement builder)
200	/// Returns a SelectStatement for manual execution
201	fn get(&self, pk: &T::PrimaryKey) -> SelectStatement;
202
203	/// Get all results (async execution)
204	/// Corresponds to SQLAlchemy's .all()
205	async fn all_async(
206		&self,
207		db: &super::connection::DatabaseConnection,
208	) -> Result<Vec<T>, ExecutionError>
209	where
210		T: for<'de> serde::Deserialize<'de>;
211
212	/// Get all results (statement builder)
213	/// Returns a SelectStatement for manual execution
214	fn all(&self) -> SelectStatement;
215
216	/// Get first result or None (async execution)
217	/// Corresponds to SQLAlchemy's .first()
218	async fn first_async(
219		&self,
220		db: &super::connection::DatabaseConnection,
221	) -> Result<Option<T>, ExecutionError>
222	where
223		T: for<'de> serde::Deserialize<'de>;
224
225	/// Get first result or None (statement builder)
226	/// Returns a SelectStatement for manual execution
227	fn first(&self) -> SelectStatement;
228
229	/// Get exactly one result, raise if 0 or >1 (async execution)
230	/// Corresponds to SQLAlchemy's .one()
231	async fn one_async(
232		&self,
233		db: &super::connection::DatabaseConnection,
234	) -> Result<T, ExecutionError>
235	where
236		T: for<'de> serde::Deserialize<'de>;
237
238	/// Get exactly one result (statement builder)
239	/// Returns a SelectStatement for manual execution
240	fn one(&self) -> SelectStatement;
241
242	/// Get one result or None, raise if >1 (async execution)
243	/// Corresponds to SQLAlchemy's .one_or_none()
244	async fn one_or_none_async(
245		&self,
246		db: &super::connection::DatabaseConnection,
247	) -> Result<Option<T>, ExecutionError>
248	where
249		T: for<'de> serde::Deserialize<'de>;
250
251	/// Get one result or None (statement builder)
252	/// Returns a SelectStatement for manual execution
253	fn one_or_none(&self) -> SelectStatement;
254
255	/// Get scalar value (first column of first row) (async execution)
256	/// Corresponds to SQLAlchemy's .scalar()
257	async fn scalar_async<S>(
258		&self,
259		db: &super::connection::DatabaseConnection,
260	) -> Result<Option<S>, ExecutionError>
261	where
262		S: for<'de> serde::Deserialize<'de>;
263
264	/// Get scalar value (statement builder)
265	/// Returns a SelectStatement for manual execution
266	fn scalar(&self) -> SelectStatement;
267
268	/// Count results (async execution)
269	/// Corresponds to SQLAlchemy's .count()
270	async fn count_async(
271		&self,
272		db: &super::connection::DatabaseConnection,
273	) -> Result<i64, ExecutionError>;
274
275	/// Count results (statement builder)
276	/// Returns a SelectStatement for manual execution
277	fn count(&self) -> SelectStatement;
278
279	/// Check if any results exist (async execution)
280	/// Corresponds to SQLAlchemy's .exists()
281	async fn exists_async(
282		&self,
283		db: &super::connection::DatabaseConnection,
284	) -> Result<bool, ExecutionError>;
285
286	/// Check if any results exist (statement builder)
287	/// Returns a SelectStatement for manual execution
288	fn exists(&self) -> SelectStatement;
289}
290
291/// Execution context for SELECT queries
292pub struct SelectExecution<T: Model> {
293	stmt: SelectStatement,
294	_phantom: PhantomData<T>,
295}
296
297impl<T: Model> SelectExecution<T> {
298	/// Create a new query execution context with the given SelectStatement
299	///
300	/// # Examples
301	///
302	/// ```rust,no_run
303	/// use reinhardt_db::orm::execution::SelectExecution;
304	/// use reinhardt_db::orm::Model;
305	/// use reinhardt_query::prelude::{QueryStatementBuilder, Alias, Query};
306	/// use serde::{Serialize, Deserialize};
307	///
308	/// #[derive(Debug, Clone, Serialize, Deserialize)]
309	/// struct User {
310	///     id: Option<i64>,
311	///     name: String,
312	/// }
313	///
314	/// #[derive(Clone)]
315	/// struct UserFields;
316	/// impl reinhardt_db::orm::FieldSelector for UserFields {
317	///     fn with_alias(self, _alias: &str) -> Self { self }
318	/// }
319	///
320	/// impl Model for User {
321	///     type PrimaryKey = i64;
322	///     type Fields = UserFields;
323	///     fn app_label() -> &'static str { "app" }
324	///     fn table_name() -> &'static str { "users" }
325	///     fn new_fields() -> Self::Fields { UserFields }
326	///     fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
327	///     fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
328	///     fn primary_key_field() -> &'static str { "id" }
329	/// }
330	///
331	/// let stmt = Query::select().from(Alias::new("users")).to_owned();
332	/// let exec = SelectExecution::<User>::new(stmt);
333	/// ```
334	pub fn new(stmt: SelectStatement) -> Self {
335		Self {
336			stmt,
337			_phantom: PhantomData,
338		}
339	}
340	/// Get a reference to the underlying SelectStatement
341	///
342	/// # Examples
343	///
344	/// ```rust,ignore
345	/// use reinhardt_db::orm::execution::SelectExecution;
346	/// use reinhardt_db::orm::Model;
347	/// use reinhardt_query::prelude::{QueryStatementBuilder, Alias, Expr, Query};
348	/// use serde::{Serialize, Deserialize};
349	///
350	/// #[derive(Debug, Clone, Serialize, Deserialize)]
351	/// struct User {
352	///     id: Option<i64>,
353	///     name: String,
354	/// }
355	///
356	/// #[derive(Clone)]
357	/// struct UserFields;
358	/// impl reinhardt_db::orm::FieldSelector for UserFields {
359	///     fn with_alias(self, _alias: &str) -> Self { self }
360	/// }
361	///
362	/// impl Model for User {
363	///     type PrimaryKey = i64;
364	///     type Fields = UserFields;
365	///     fn app_label() -> &'static str { "app" }
366	///     fn table_name() -> &'static str { "users" }
367	///     fn new_fields() -> Self::Fields { UserFields }
368	///     fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
369	///     fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
370	///     fn primary_key_field() -> &'static str { "id" }
371	/// }
372	///
373	/// let stmt = Query::select()
374	///     .from(Alias::new("users"))
375	///     .and_where(Expr::col(Alias::new("active")).eq(true))
376	///     .to_owned();
377	/// let exec = SelectExecution::<User>::new(stmt);
378	/// ```
379	pub fn statement(&self) -> &SelectStatement {
380		&self.stmt
381	}
382}
383
384#[async_trait::async_trait]
385impl<T: Model> QueryExecution<T> for SelectExecution<T>
386where
387	T::PrimaryKey: Into<reinhardt_query::value::Value> + Clone + Send + Sync,
388	T: Send + Sync,
389{
390	fn get(&self, pk: &T::PrimaryKey) -> SelectStatement {
391		Query::select()
392			.from(Alias::new(T::table_name()))
393			.column(ColumnRef::Asterisk)
394			.and_where(
395				Expr::col(Alias::new(T::primary_key_field())).eq(Expr::val(pk.clone().into())),
396			)
397			.limit(1)
398			.to_owned()
399	}
400
401	fn all(&self) -> SelectStatement {
402		self.stmt.clone()
403	}
404
405	fn first(&self) -> SelectStatement {
406		let mut stmt = self.stmt.clone();
407		stmt.limit(1);
408		stmt
409	}
410
411	fn one(&self) -> SelectStatement {
412		// Sets LIMIT 2 to detect multiple results
413		// The execution layer should:
414		// - Error if 0 results are returned (NoResultFound)
415		// - Error if 2+ results are returned (MultipleResultsFound)
416		// - Return the single result if exactly 1 is found
417		let mut stmt = self.stmt.clone();
418		stmt.limit(2);
419		stmt
420	}
421
422	fn one_or_none(&self) -> SelectStatement {
423		// Sets LIMIT 2 to detect multiple results
424		// The execution layer should:
425		// - Return None if 0 results
426		// - Error if 2+ results are returned (MultipleResultsFound)
427		// - Return Some(result) if exactly 1 is found
428		let mut stmt = self.stmt.clone();
429		stmt.limit(2);
430		stmt
431	}
432
433	fn scalar(&self) -> SelectStatement {
434		let mut stmt = self.stmt.clone();
435		stmt.limit(1);
436		stmt
437	}
438
439	fn count(&self) -> SelectStatement {
440		// Use the original statement as a subquery and count all rows from it
441		// This preserves all WHERE, JOIN, and other conditions
442		Query::select()
443			.expr(Func::count(Expr::asterisk().into_simple_expr()))
444			.from_subquery(self.stmt.clone(), Alias::new("subquery"))
445			.to_owned()
446	}
447
448	fn exists(&self) -> SelectStatement {
449		Query::select()
450			.expr(Expr::exists(self.stmt.clone()))
451			.to_owned()
452	}
453
454	async fn get_async(
455		&self,
456		db: &super::connection::DatabaseConnection,
457		pk: &T::PrimaryKey,
458	) -> Result<T, ExecutionError>
459	where
460		T: for<'de> serde::Deserialize<'de>,
461	{
462		let stmt = self.get(pk);
463		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
464
465		let query_values = convert_values(values);
466		let row = db.query_one(&sql, query_values).await?;
467		let json = serde_json::to_value(&row)?;
468		let result = serde_json::from_value(json)?;
469		Ok(result)
470	}
471
472	async fn all_async(
473		&self,
474		db: &super::connection::DatabaseConnection,
475	) -> Result<Vec<T>, ExecutionError>
476	where
477		T: for<'de> serde::Deserialize<'de>,
478	{
479		let stmt = self.all();
480		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
481
482		let query_values = convert_values(values);
483		let rows = db.query(&sql, query_values).await?;
484		let mut results = Vec::with_capacity(rows.len());
485		for row in rows {
486			let json = serde_json::to_value(&row)?;
487			let result = serde_json::from_value(json)?;
488			results.push(result);
489		}
490		Ok(results)
491	}
492
493	async fn first_async(
494		&self,
495		db: &super::connection::DatabaseConnection,
496	) -> Result<Option<T>, ExecutionError>
497	where
498		T: for<'de> serde::Deserialize<'de>,
499	{
500		let stmt = self.first();
501		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
502
503		let query_values = convert_values(values);
504		let rows = db.query(&sql, query_values).await?;
505		match rows.first() {
506			Some(row) => {
507				let json = serde_json::to_value(row)?;
508				let result = serde_json::from_value(json)?;
509				Ok(Some(result))
510			}
511			None => Ok(None),
512		}
513	}
514
515	async fn one_async(
516		&self,
517		db: &super::connection::DatabaseConnection,
518	) -> Result<T, ExecutionError>
519	where
520		T: for<'de> serde::Deserialize<'de>,
521	{
522		let stmt = self.one();
523		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
524
525		let query_values = convert_values(values);
526		let rows = db.query(&sql, query_values).await?;
527		match rows.len() {
528			0 => Err(ExecutionError::NoResultFound),
529			1 => {
530				let json = serde_json::to_value(&rows[0])?;
531				let result = serde_json::from_value(json)?;
532				Ok(result)
533			}
534			n => Err(ExecutionError::MultipleResultsFound(n)),
535		}
536	}
537
538	async fn one_or_none_async(
539		&self,
540		db: &super::connection::DatabaseConnection,
541	) -> Result<Option<T>, ExecutionError>
542	where
543		T: for<'de> serde::Deserialize<'de>,
544	{
545		let stmt = self.one_or_none();
546		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
547
548		let query_values = convert_values(values);
549		let rows = db.query(&sql, query_values).await?;
550		match rows.len() {
551			0 => Ok(None),
552			1 => {
553				let json = serde_json::to_value(&rows[0])?;
554				let result = serde_json::from_value(json)?;
555				Ok(Some(result))
556			}
557			n => Err(ExecutionError::MultipleResultsFound(n)),
558		}
559	}
560
561	async fn scalar_async<S>(
562		&self,
563		db: &super::connection::DatabaseConnection,
564	) -> Result<Option<S>, ExecutionError>
565	where
566		S: for<'de> serde::Deserialize<'de>,
567	{
568		let stmt = self.scalar();
569		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
570
571		let query_values = convert_values(values);
572		let rows = db.query(&sql, query_values).await?;
573		match rows.first() {
574			Some(row) => {
575				// Get the first column value
576				let json = serde_json::to_value(row)?;
577				if let Some(obj) = json.as_object()
578					&& let Some((_, value)) = obj.iter().next()
579				{
580					let result = serde_json::from_value(value.clone())?;
581					return Ok(Some(result));
582				}
583				Ok(None)
584			}
585			None => Ok(None),
586		}
587	}
588
589	async fn count_async(
590		&self,
591		db: &super::connection::DatabaseConnection,
592	) -> Result<i64, ExecutionError> {
593		let stmt = self.count();
594		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
595
596		let query_values = convert_values(values);
597		let row = db.query_one(&sql, query_values).await?;
598		let json = serde_json::to_value(&row)?;
599
600		// Extract count from the result (usually the first column)
601		if let Some(obj) = json.as_object()
602			&& let Some((_, value)) = obj.iter().next()
603		{
604			let count: i64 = serde_json::from_value(value.clone())?;
605			return Ok(count);
606		}
607
608		Err(ExecutionError::QueryBuild(
609			"Count query returned unexpected format".to_string(),
610		))
611	}
612
613	async fn exists_async(
614		&self,
615		db: &super::connection::DatabaseConnection,
616	) -> Result<bool, ExecutionError> {
617		let stmt = self.exists();
618		let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
619
620		let query_values = convert_values(values);
621		let row = db.query_one(&sql, query_values).await?;
622		let json = serde_json::to_value(&row)?;
623
624		// Extract exists from the result (usually the first column)
625		if let Some(obj) = json.as_object()
626			&& let Some((_, value)) = obj.iter().next()
627		{
628			let exists: bool = serde_json::from_value(value.clone())?;
629			return Ok(exists);
630		}
631
632		Err(ExecutionError::QueryBuild(
633			"Exists query returned unexpected format".to_string(),
634		))
635	}
636}
637
/// How a relationship or column should be loaded, mirroring SQLAlchemy's loader options.
///
/// The payload names the relationship (or, for the column options, the column or columns)
/// the option applies to. `PartialEq`/`Eq`/`Hash` allow options to be compared and
/// de-duplicated.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LoadOption {
	/// Eager-load the relationship with a JOIN (SQLAlchemy's `joinedload()`).
	JoinedLoad(String),

	/// Eager-load the relationship with a separate SELECT (SQLAlchemy's `selectinload()`).
	SelectInLoad(String),

	/// Load the relationship lazily on first access (SQLAlchemy's `lazyload()`).
	LazyLoad(String),

	/// Never load the relationship (SQLAlchemy's `noload()`).
	NoLoad(String),

	/// Raise an error if the relationship is accessed (SQLAlchemy's `raiseload()`).
	RaiseLoad(String),

	/// Defer loading of the column (SQLAlchemy's `defer()`).
	Defer(String),

	/// Undo a deferral of the column (SQLAlchemy's `undefer()`).
	Undefer(String),

	/// Load only the listed columns (SQLAlchemy's `load_only()`).
	LoadOnly(Vec<String>),
}
674
675impl LoadOption {
676	/// Convert load option to SQL comment for debugging
677	///
678	/// # Examples
679	///
680	/// ```
681	/// use reinhardt_db::orm::execution::LoadOption;
682	///
683	/// let option = LoadOption::JoinedLoad("profile".to_string());
684	/// assert_eq!(option.to_sql_comment(), "/* joinedload(profile) */");
685	///
686	/// let option = LoadOption::Defer("password".to_string());
687	/// assert_eq!(option.to_sql_comment(), "/* defer(password) */");
688	///
689	/// let option = LoadOption::LoadOnly(vec!["id".to_string(), "name".to_string()]);
690	/// assert_eq!(option.to_sql_comment(), "/* load_only(id, name) */");
691	/// ```
692	pub fn to_sql_comment(&self) -> String {
693		match self {
694			LoadOption::JoinedLoad(rel) => format!("/* joinedload({}) */", rel),
695			LoadOption::SelectInLoad(rel) => format!("/* selectinload({}) */", rel),
696			LoadOption::LazyLoad(rel) => format!("/* lazyload({}) */", rel),
697			LoadOption::NoLoad(rel) => format!("/* noload({}) */", rel),
698			LoadOption::RaiseLoad(rel) => format!("/* raiseload({}) */", rel),
699			LoadOption::Defer(col) => format!("/* defer({}) */", col),
700			LoadOption::Undefer(col) => format!("/* undefer({}) */", col),
701			LoadOption::LoadOnly(cols) => format!("/* load_only({}) */", cols.join(", ")),
702		}
703	}
704}
705
706/// Query options container
707pub struct QueryOptions {
708	pub load_options: Vec<LoadOption>,
709}
710
711impl QueryOptions {
712	/// Create a new empty query options container
713	///
714	/// # Examples
715	///
716	/// ```
717	/// use reinhardt_db::orm::execution::QueryOptions;
718	///
719	/// let options = QueryOptions::new();
720	/// assert_eq!(options.to_sql_comments(), "");
721	/// ```
722	pub fn new() -> Self {
723		Self {
724			load_options: Vec::new(),
725		}
726	}
727	/// Add a load option to the query
728	///
729	/// # Examples
730	///
731	/// ```
732	/// use reinhardt_db::orm::execution::{QueryOptions, LoadOption};
733	///
734	/// let options = QueryOptions::new()
735	///     .add_option(LoadOption::JoinedLoad("profile".to_string()))
736	///     .add_option(LoadOption::Defer("password".to_string()));
737	///
738	/// let comments = options.to_sql_comments();
739	/// assert!(comments.contains("joinedload(profile)"));
740	/// assert!(comments.contains("defer(password)"));
741	/// ```
742	pub fn add_option(mut self, option: LoadOption) -> Self {
743		self.load_options.push(option);
744		self
745	}
746	/// Convert all options to SQL comments
747	///
748	/// # Examples
749	///
750	/// ```
751	/// use reinhardt_db::orm::execution::{QueryOptions, LoadOption};
752	///
753	/// let options = QueryOptions::new()
754	///     .add_option(LoadOption::SelectInLoad("posts".to_string()));
755	///
756	/// assert!(options.to_sql_comments().contains("selectinload(posts)"));
757	/// ```
758	pub fn to_sql_comments(&self) -> String {
759		if self.load_options.is_empty() {
760			String::new()
761		} else {
762			format!(
763				" {}",
764				self.load_options
765					.iter()
766					.map(|o| o.to_sql_comment())
767					.collect::<Vec<_>>()
768					.join(" ")
769			)
770		}
771	}
772}
773
774impl Default for QueryOptions {
775	fn default() -> Self {
776		Self::new()
777	}
778}
779
#[cfg(test)]
mod tests {
	use super::*;
	use reinhardt_core::validators::TableName;
	use reinhardt_query::prelude::PostgresQueryBuilder;
	use rstest::rstest;
	use serde::{Deserialize, Serialize};

	#[derive(Debug, Clone, Serialize, Deserialize)]
	struct User {
		id: Option<i64>,
		name: String,
	}

	#[derive(Clone)]
	struct UserFields;

	impl crate::orm::model::FieldSelector for UserFields {
		fn with_alias(self, _alias: &str) -> Self {
			self
		}
	}

	const USER_TABLE: TableName = TableName::new_const("users");

	impl Model for User {
		type PrimaryKey = i64;
		type Fields = UserFields;

		fn table_name() -> &'static str {
			USER_TABLE.as_str()
		}

		fn new_fields() -> Self::Fields {
			UserFields
		}

		fn primary_key(&self) -> Option<Self::PrimaryKey> {
			self.id
		}

		fn set_primary_key(&mut self, value: Self::PrimaryKey) {
			self.id = Some(value);
		}
	}

	/// `SELECT * FROM users`, the base statement shared by the builder tests.
	fn select_all_users() -> SelectStatement {
		Query::select()
			.from(Alias::new("users"))
			.column(ColumnRef::Asterisk)
			.to_owned()
	}

	/// Wrap `stmt` in a `SelectExecution<User>`, apply the builder `op`, and render the
	/// resulting statement as PostgreSQL SQL.
	fn render_with<F>(stmt: SelectStatement, op: F) -> String
	where
		F: FnOnce(&SelectExecution<User>) -> SelectStatement,
	{
		let exec = SelectExecution::<User>::new(stmt);
		op(&exec).to_string(PostgresQueryBuilder)
	}

	#[test]
	fn test_execution_get() {
		let sql = render_with(select_all_users(), |exec| exec.get(&123));
		assert!(sql.contains("WHERE"));
		assert!(sql.contains("LIMIT"));
	}

	#[test]
	fn test_all() {
		let sql = render_with(select_all_users(), |exec| exec.all());
		assert!(sql.contains("SELECT"));
		assert!(sql.contains("users"));
	}

	#[test]
	fn test_first() {
		let mut stmt = select_all_users();
		stmt.and_where(Expr::col(Alias::new("active")).eq(true));
		let sql = render_with(stmt, |exec| exec.first());
		assert!(sql.contains("LIMIT"));
	}

	#[test]
	fn test_execution_count() {
		let mut stmt = select_all_users();
		stmt.and_where(Expr::col(Alias::new("active")).eq(true));
		let sql = render_with(stmt, |exec| exec.count());
		assert!(sql.contains("COUNT"));
	}

	#[test]
	fn test_execution_exists() {
		let mut stmt = select_all_users();
		stmt.and_where(Expr::col(Alias::new("name")).eq("Alice"));
		let sql = render_with(stmt, |exec| exec.exists());
		assert!(sql.contains("EXISTS"));
	}

	#[test]
	fn test_load_options() {
		let comments = QueryOptions::new()
			.add_option(LoadOption::JoinedLoad("profile".to_string()))
			.add_option(LoadOption::Defer("password".to_string()))
			.to_sql_comments();

		assert!(comments.contains("joinedload(profile)"));
		assert!(comments.contains("defer(password)"));
	}

	#[test]
	fn test_load_only() {
		let columns = vec!["id".to_string(), "name".to_string()];
		let comment = LoadOption::LoadOnly(columns).to_sql_comment();
		assert!(comment.contains("load_only(id, name)"));
	}

	#[rstest]
	#[case::zero(0u64, 0i64)]
	#[case::one(1u64, 1i64)]
	#[case::i64_max(i64::MAX as u64, i64::MAX)]
	#[test]
	fn test_big_unsigned_to_query_value_within_range(#[case] input: u64, #[case] expected: i64) {
		// Arrange
		let value = reinhardt_query::value::Value::BigUnsigned(Some(input));

		// Act
		let result = convert_value_to_query_value(value);

		// Assert
		assert!(matches!(result, QueryValue::Int(v) if v == expected));
	}

	#[rstest]
	#[case::i64_max_plus_one(i64::MAX as u64 + 1)]
	#[case::u64_max(u64::MAX)]
	#[test]
	fn test_big_unsigned_overflow_clamps_to_i64_max(#[case] input: u64) {
		// Arrange
		let value = reinhardt_query::value::Value::BigUnsigned(Some(input));

		// Act
		let result = convert_value_to_query_value(value);

		// Assert: out-of-range values clamp to i64::MAX rather than wrapping negative
		assert!(matches!(result, QueryValue::Int(v) if v == i64::MAX));
	}

	#[rstest]
	#[test]
	fn test_big_unsigned_none_converts_to_null() {
		// Arrange
		let value = reinhardt_query::value::Value::BigUnsigned(None);

		// Act
		let result = convert_value_to_query_value(value);

		// Assert
		assert!(matches!(result, QueryValue::Null));
	}
}