use crate::backends::types::QueryValue;
use crate::orm::Model;
use reinhardt_query::prelude::{
Alias, ColumnRef, Expr, ExprTrait, Func, Query, QueryStatementBuilder, SelectStatement,
};
use rust_decimal::prelude::ToPrimitive;
use std::marker::PhantomData;
/// Result of running a query, tagged by the shape of what came back.
///
/// The variants appear to mirror the retrieval styles of `QueryExecution`
/// (single row, optional row, all rows, scalar); this enum is not
/// constructed anywhere in this module — confirm its intended producers.
#[derive(Debug)]
pub enum ExecutionResult<T> {
    /// Exactly one decoded row.
    One(T),
    /// Zero or one decoded row.
    OneOrNone(Option<T>),
    /// Every decoded row.
    All(Vec<T>),
    /// A single scalar, kept in its textual form.
    Scalar(String),
    /// No payload.
    None,
}
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum ExecutionError {
#[error("Database error: {0}")]
Database(#[from] crate::backends::DatabaseError),
#[error("No result found")]
NoResultFound,
#[error("Multiple results found (expected 1, got {0})")]
MultipleResultsFound(usize),
#[error("Failed to deserialize result: {0}")]
Deserialization(#[from] serde_json::Error),
#[error("Query building error: {0}")]
QueryBuild(String),
#[error("Generic error: {0}")]
Generic(#[from] anyhow::Error),
}
/// Converts one `reinhardt_query` bind value into the [`QueryValue`] passed to
/// the database layer.
///
/// Conversion rules:
/// - every typed `None` becomes `QueryValue::Null`;
/// - integer variants become `QueryValue::Int`. A `BigUnsigned` larger than
///   `i64::MAX` is clamped to `i64::MAX`, with a warning logged;
/// - `Float`/`Double` become `QueryValue::Float`;
/// - strings and chars become `QueryValue::String`; bytes become `QueryValue::Bytes`;
/// - timezone-aware chrono timestamps are converted to UTC `QueryValue::Timestamp`;
/// - naive chrono dates, times and date-times, and JSON documents, become
///   `QueryValue::String` using their `Display` text (e.g. `2024-01-31`,
///   `2024-01-31 13:45:00`, `{"a":1}`);
/// - decimals become `QueryValue::Float` (lossy). If no `f64` can be produced,
///   the result is `0.0` and a warning is logged;
/// - UUIDs become `QueryValue::Uuid`;
/// - arrays are rendered with their `Debug` text as `QueryValue::String`.
fn convert_value_to_query_value(value: reinhardt_query::value::Value) -> QueryValue {
    use reinhardt_query::value::Value as SV;
    match value {
        // Every typed SQL NULL maps to the untyped `Null`.
        SV::Bool(None)
        | SV::TinyInt(None)
        | SV::SmallInt(None)
        | SV::Int(None)
        | SV::BigInt(None)
        | SV::TinyUnsigned(None)
        | SV::SmallUnsigned(None)
        | SV::Unsigned(None)
        | SV::BigUnsigned(None)
        | SV::Float(None)
        | SV::Double(None)
        | SV::String(None)
        | SV::Char(None)
        | SV::Bytes(None)
        | SV::ChronoDateTimeUtc(None)
        | SV::ChronoDateTimeLocal(None)
        | SV::ChronoDateTimeWithTimeZone(None)
        | SV::ChronoDate(None)
        | SV::ChronoTime(None)
        | SV::ChronoDateTime(None)
        | SV::Json(None)
        | SV::Decimal(None)
        | SV::BigDecimal(None)
        | SV::Uuid(None) => QueryValue::Null,
        SV::Bool(Some(b)) => QueryValue::Bool(b),
        SV::TinyInt(Some(v)) => QueryValue::Int(v as i64),
        SV::SmallInt(Some(v)) => QueryValue::Int(v as i64),
        SV::Int(Some(v)) => QueryValue::Int(v as i64),
        SV::BigInt(Some(v)) => QueryValue::Int(v),
        SV::TinyUnsigned(Some(v)) => QueryValue::Int(v as i64),
        SV::SmallUnsigned(Some(v)) => QueryValue::Int(v as i64),
        SV::Unsigned(Some(v)) => QueryValue::Int(v as i64),
        // `QueryValue::Int` is signed, so values above i64::MAX cannot be
        // represented. Clamp them and log a warning instead of wrapping silently.
        SV::BigUnsigned(Some(v)) => QueryValue::Int(i64::try_from(v).unwrap_or_else(|_| {
            tracing::warn!(
                value = v,
                "BigUnsigned value {} exceeds i64::MAX, clamping to i64::MAX",
                v
            );
            i64::MAX
        })),
        SV::Float(Some(v)) => QueryValue::Float(v as f64),
        SV::Double(Some(v)) => QueryValue::Float(v),
        SV::String(Some(s)) => QueryValue::String(s.to_string()),
        SV::Char(Some(c)) => QueryValue::String(c.to_string()),
        SV::Bytes(Some(b)) => QueryValue::Bytes(b.to_vec()),
        SV::ChronoDateTimeUtc(Some(dt)) => QueryValue::Timestamp(*dt),
        SV::ChronoDateTimeLocal(Some(dt)) => {
            QueryValue::Timestamp((*dt).with_timezone(&chrono::Utc))
        }
        SV::ChronoDateTimeWithTimeZone(Some(dt)) => {
            QueryValue::Timestamp((*dt).with_timezone(&chrono::Utc))
        }
        // Naive chrono values have no timezone, so they cannot become a UTC
        // `Timestamp`. Send their `Display` text (`2024-01-31`, `13:45:00`,
        // `2024-01-31 13:45:00`). The previous code formatted the whole enum
        // with `Debug`, producing `ChronoDate(Some(2024-01-31))`, which no
        // database accepts as a date.
        SV::ChronoDate(Some(d)) => QueryValue::String(d.to_string()),
        SV::ChronoTime(Some(t)) => QueryValue::String(t.to_string()),
        SV::ChronoDateTime(Some(dt)) => QueryValue::String(dt.to_string()),
        // Send compact JSON text (`{"a":1}`), not the enum's `Debug` output.
        SV::Json(Some(j)) => QueryValue::String(j.to_string()),
        SV::Decimal(Some(d)) => {
            let f = d.to_f64().unwrap_or_else(|| {
                tracing::warn!(
                    decimal = %d,
                    "Decimal cannot be directly represented as f64, falling back to string parsing"
                );
                d.to_string().parse::<f64>().unwrap_or(0.0)
            });
            QueryValue::Float(f)
        }
        SV::BigDecimal(Some(d)) => {
            let f = d.to_string().parse::<f64>().unwrap_or_else(|_| {
                tracing::warn!(
                    big_decimal = %d,
                    "BigDecimal cannot be represented as f64"
                );
                0.0
            });
            QueryValue::Float(f)
        }
        SV::Uuid(Some(u)) => QueryValue::Uuid(*u),
        // NOTE(review): `Debug` text is not a valid SQL array literal; arrays
        // probably need a dedicated `QueryValue` representation — confirm.
        SV::Array(_, arr) => QueryValue::String(format!("{:?}", arr)),
    }
}
pub fn convert_values(values: reinhardt_query::prelude::Values) -> Vec<QueryValue> {
values
.0
.into_iter()
.map(convert_value_to_query_value)
.collect()
}
#[async_trait::async_trait]
pub trait QueryExecution<T: Model>
where
T: Send + Sync,
T::PrimaryKey: Send + Sync,
{
async fn get_async(
&self,
db: &super::connection::DatabaseConnection,
pk: &T::PrimaryKey,
) -> Result<T, ExecutionError>
where
T: for<'de> serde::Deserialize<'de>;
fn get(&self, pk: &T::PrimaryKey) -> SelectStatement;
async fn all_async(
&self,
db: &super::connection::DatabaseConnection,
) -> Result<Vec<T>, ExecutionError>
where
T: for<'de> serde::Deserialize<'de>;
fn all(&self) -> SelectStatement;
async fn first_async(
&self,
db: &super::connection::DatabaseConnection,
) -> Result<Option<T>, ExecutionError>
where
T: for<'de> serde::Deserialize<'de>;
fn first(&self) -> SelectStatement;
async fn one_async(
&self,
db: &super::connection::DatabaseConnection,
) -> Result<T, ExecutionError>
where
T: for<'de> serde::Deserialize<'de>;
fn one(&self) -> SelectStatement;
async fn one_or_none_async(
&self,
db: &super::connection::DatabaseConnection,
) -> Result<Option<T>, ExecutionError>
where
T: for<'de> serde::Deserialize<'de>;
fn one_or_none(&self) -> SelectStatement;
async fn scalar_async<S>(
&self,
db: &super::connection::DatabaseConnection,
) -> Result<Option<S>, ExecutionError>
where
S: for<'de> serde::Deserialize<'de>;
fn scalar(&self) -> SelectStatement;
async fn count_async(
&self,
db: &super::connection::DatabaseConnection,
) -> Result<i64, ExecutionError>;
fn count(&self) -> SelectStatement;
async fn exists_async(
&self,
db: &super::connection::DatabaseConnection,
) -> Result<bool, ExecutionError>;
fn exists(&self) -> SelectStatement;
}
/// Runs a base `SELECT` statement and its derived variants for model `T`.
pub struct SelectExecution<T: Model> {
    /// Binds the executor to its model type without storing a `T`.
    _phantom: PhantomData<T>,
    /// Base query that the execution methods start from.
    stmt: SelectStatement,
}
impl<T: Model> SelectExecution<T> {
    /// Wraps `stmt` as the base query for model `T`.
    pub fn new(stmt: SelectStatement) -> Self {
        let _phantom = PhantomData;
        Self { _phantom, stmt }
    }

    /// Borrows the base statement this executor was created with.
    pub fn statement(&self) -> &SelectStatement {
        let Self { stmt, .. } = self;
        stmt
    }
}
#[async_trait::async_trait]
impl<T: Model> QueryExecution<T> for SelectExecution<T>
where
T::PrimaryKey: Into<reinhardt_query::value::Value> + Clone + Send + Sync,
T: Send + Sync,
{
fn get(&self, pk: &T::PrimaryKey) -> SelectStatement {
Query::select()
.from(Alias::new(T::table_name()))
.column(ColumnRef::Asterisk)
.and_where(
Expr::col(Alias::new(T::primary_key_field())).eq(Expr::val(pk.clone().into())),
)
.limit(1)
.to_owned()
}
fn all(&self) -> SelectStatement {
self.stmt.clone()
}
fn first(&self) -> SelectStatement {
let mut stmt = self.stmt.clone();
stmt.limit(1);
stmt
}
fn one(&self) -> SelectStatement {
let mut stmt = self.stmt.clone();
stmt.limit(2);
stmt
}
fn one_or_none(&self) -> SelectStatement {
let mut stmt = self.stmt.clone();
stmt.limit(2);
stmt
}
fn scalar(&self) -> SelectStatement {
let mut stmt = self.stmt.clone();
stmt.limit(1);
stmt
}
fn count(&self) -> SelectStatement {
Query::select()
.expr(Func::count(Expr::asterisk().into_simple_expr()))
.from_subquery(self.stmt.clone(), Alias::new("subquery"))
.to_owned()
}
fn exists(&self) -> SelectStatement {
Query::select()
.expr(Expr::exists(self.stmt.clone()))
.to_owned()
}
async fn get_async(
&self,
db: &super::connection::DatabaseConnection,
pk: &T::PrimaryKey,
) -> Result<T, ExecutionError>
where
T: for<'de> serde::Deserialize<'de>,
{
let stmt = self.get(pk);
let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
let query_values = convert_values(values);
let row = db.query_one(&sql, query_values).await?;
let json = serde_json::to_value(&row)?;
let result = serde_json::from_value(json)?;
Ok(result)
}
async fn all_async(
&self,
db: &super::connection::DatabaseConnection,
) -> Result<Vec<T>, ExecutionError>
where
T: for<'de> serde::Deserialize<'de>,
{
let stmt = self.all();
let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
let query_values = convert_values(values);
let rows = db.query(&sql, query_values).await?;
let mut results = Vec::with_capacity(rows.len());
for row in rows {
let json = serde_json::to_value(&row)?;
let result = serde_json::from_value(json)?;
results.push(result);
}
Ok(results)
}
async fn first_async(
&self,
db: &super::connection::DatabaseConnection,
) -> Result<Option<T>, ExecutionError>
where
T: for<'de> serde::Deserialize<'de>,
{
let stmt = self.first();
let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
let query_values = convert_values(values);
let rows = db.query(&sql, query_values).await?;
match rows.first() {
Some(row) => {
let json = serde_json::to_value(row)?;
let result = serde_json::from_value(json)?;
Ok(Some(result))
}
None => Ok(None),
}
}
async fn one_async(
&self,
db: &super::connection::DatabaseConnection,
) -> Result<T, ExecutionError>
where
T: for<'de> serde::Deserialize<'de>,
{
let stmt = self.one();
let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
let query_values = convert_values(values);
let rows = db.query(&sql, query_values).await?;
match rows.len() {
0 => Err(ExecutionError::NoResultFound),
1 => {
let json = serde_json::to_value(&rows[0])?;
let result = serde_json::from_value(json)?;
Ok(result)
}
n => Err(ExecutionError::MultipleResultsFound(n)),
}
}
async fn one_or_none_async(
&self,
db: &super::connection::DatabaseConnection,
) -> Result<Option<T>, ExecutionError>
where
T: for<'de> serde::Deserialize<'de>,
{
let stmt = self.one_or_none();
let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
let query_values = convert_values(values);
let rows = db.query(&sql, query_values).await?;
match rows.len() {
0 => Ok(None),
1 => {
let json = serde_json::to_value(&rows[0])?;
let result = serde_json::from_value(json)?;
Ok(Some(result))
}
n => Err(ExecutionError::MultipleResultsFound(n)),
}
}
async fn scalar_async<S>(
&self,
db: &super::connection::DatabaseConnection,
) -> Result<Option<S>, ExecutionError>
where
S: for<'de> serde::Deserialize<'de>,
{
let stmt = self.scalar();
let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
let query_values = convert_values(values);
let rows = db.query(&sql, query_values).await?;
match rows.first() {
Some(row) => {
let json = serde_json::to_value(row)?;
if let Some(obj) = json.as_object()
&& let Some((_, value)) = obj.iter().next()
{
let result = serde_json::from_value(value.clone())?;
return Ok(Some(result));
}
Ok(None)
}
None => Ok(None),
}
}
async fn count_async(
&self,
db: &super::connection::DatabaseConnection,
) -> Result<i64, ExecutionError> {
let stmt = self.count();
let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
let query_values = convert_values(values);
let row = db.query_one(&sql, query_values).await?;
let json = serde_json::to_value(&row)?;
if let Some(obj) = json.as_object()
&& let Some((_, value)) = obj.iter().next()
{
let count: i64 = serde_json::from_value(value.clone())?;
return Ok(count);
}
Err(ExecutionError::QueryBuild(
"Count query returned unexpected format".to_string(),
))
}
async fn exists_async(
&self,
db: &super::connection::DatabaseConnection,
) -> Result<bool, ExecutionError> {
let stmt = self.exists();
let (sql, values) = stmt.build_any(&reinhardt_query::prelude::PostgresQueryBuilder);
let query_values = convert_values(values);
let row = db.query_one(&sql, query_values).await?;
let json = serde_json::to_value(&row)?;
if let Some(obj) = json.as_object()
&& let Some((_, value)) = obj.iter().next()
{
let exists: bool = serde_json::from_value(value.clone())?;
return Ok(exists);
}
Err(ExecutionError::QueryBuild(
"Exists query returned unexpected format".to_string(),
))
}
}
/// A relationship- or column-loading hint attached to a query.
///
/// In this module a hint is only rendered as an SQL block comment via
/// [`LoadOption::to_sql_comment`].
#[derive(Debug, Clone)]
pub enum LoadOption {
    /// `joinedload(<relation>)`
    JoinedLoad(String),
    /// `selectinload(<relation>)`
    SelectInLoad(String),
    /// `lazyload(<relation>)`
    LazyLoad(String),
    /// `noload(<relation>)`
    NoLoad(String),
    /// `raiseload(<relation>)`
    RaiseLoad(String),
    /// `defer(<column>)`
    Defer(String),
    /// `undefer(<column>)`
    Undefer(String),
    /// `load_only(<column>, <column>, ...)`
    LoadOnly(Vec<String>),
}

impl LoadOption {
    /// Renders the hint as an SQL block comment, e.g. `/* joinedload(profile) */`.
    ///
    /// `LoadOnly` lists its columns separated by `", "`.
    pub fn to_sql_comment(&self) -> String {
        // Storage for the one variant whose argument must be built on the fly.
        let joined_columns: String;
        let (hint, argument): (&str, &str) = match self {
            Self::JoinedLoad(rel) => ("joinedload", rel.as_str()),
            Self::SelectInLoad(rel) => ("selectinload", rel.as_str()),
            Self::LazyLoad(rel) => ("lazyload", rel.as_str()),
            Self::NoLoad(rel) => ("noload", rel.as_str()),
            Self::RaiseLoad(rel) => ("raiseload", rel.as_str()),
            Self::Defer(col) => ("defer", col.as_str()),
            Self::Undefer(col) => ("undefer", col.as_str()),
            Self::LoadOnly(cols) => {
                joined_columns = cols.join(", ");
                ("load_only", joined_columns.as_str())
            }
        };
        format!("/* {hint}({argument}) */")
    }
}
/// An ordered collection of [`LoadOption`] hints for a query.
///
/// Derives `Debug` and `Clone` like `LoadOption`. `Default` yields an empty
/// collection, which is the same as [`QueryOptions::new`].
#[non_exhaustive]
#[derive(Debug, Clone, Default)]
pub struct QueryOptions {
    /// Hints in insertion order; `to_sql_comments` renders them in this order.
    pub load_options: Vec<LoadOption>,
}

impl QueryOptions {
    /// Creates an empty set of options.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `option` and returns the updated options for chaining.
    #[must_use]
    pub fn add_option(mut self, option: LoadOption) -> Self {
        self.load_options.push(option);
        self
    }

    /// Renders every hint as an SQL comment.
    ///
    /// Each comment is preceded by a single space, so the result is either
    /// empty (no hints) or `" /* a */ /* b */"`. It can be appended directly
    /// after SQL text.
    pub fn to_sql_comments(&self) -> String {
        let mut rendered = String::new();
        for option in &self.load_options {
            rendered.push(' ');
            rendered.push_str(&option.to_sql_comment());
        }
        rendered
    }
}
// Unit tests for this module. They check the SQL text produced by the
// statement builders (no database needed), load-option comment rendering, and
// the `BigUnsigned` -> `QueryValue` conversion.
#[cfg(test)]
mod tests {
    use super::*;
    use reinhardt_core::validators::TableName;
    use rstest::rstest;
    use serde::{Deserialize, Serialize};

    // Minimal model used as the `T` parameter of `SelectExecution`.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct User {
        id: Option<i64>,
        name: String,
    }

    // Stateless field selector required by `Model`; aliasing is a no-op here.
    #[derive(Clone)]
    struct UserFields;
    impl crate::orm::model::FieldSelector for UserFields {
        fn with_alias(self, _alias: &str) -> Self {
            self
        }
    }

    const USER_TABLE: TableName = TableName::new_const("users");

    // `User` maps to the `users` table and uses its optional `id` as the i64 primary key.
    impl Model for User {
        type PrimaryKey = i64;
        type Fields = UserFields;
        fn table_name() -> &'static str {
            USER_TABLE.as_str()
        }
        fn new_fields() -> Self::Fields {
            UserFields
        }
        fn primary_key(&self) -> Option<Self::PrimaryKey> {
            self.id
        }
        fn set_primary_key(&mut self, value: Self::PrimaryKey) {
            self.id = Some(value);
        }
    }

    // `get` builds its own statement with a primary-key filter and a limit.
    #[test]
    fn test_execution_get() {
        use reinhardt_query::prelude::{Alias, PostgresQueryBuilder, Query, QueryStatementBuilder};
        let stmt = Query::select()
            .from(Alias::new("users"))
            .column(ColumnRef::Asterisk)
            .to_owned();
        let exec = SelectExecution::<User>::new(stmt);
        let result_stmt = exec.get(&123);
        let sql = result_stmt.to_string(PostgresQueryBuilder);
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("LIMIT"));
    }

    // `all` returns the base statement unchanged.
    #[test]
    fn test_all() {
        use reinhardt_query::prelude::{Alias, PostgresQueryBuilder, Query, QueryStatementBuilder};
        let stmt = Query::select()
            .from(Alias::new("users"))
            .column(ColumnRef::Asterisk)
            .to_owned();
        let exec = SelectExecution::<User>::new(stmt);
        let result_stmt = exec.all();
        let sql = result_stmt.to_string(PostgresQueryBuilder);
        assert!(sql.contains("SELECT"));
        assert!(sql.contains("users"));
    }

    // `first` adds a LIMIT to the filtered base statement.
    #[test]
    fn test_first() {
        use reinhardt_query::prelude::{
            Alias, Expr, PostgresQueryBuilder, Query, QueryStatementBuilder,
        };
        let stmt = Query::select()
            .from(Alias::new("users"))
            .column(ColumnRef::Asterisk)
            .and_where(Expr::col(Alias::new("active")).eq(true))
            .to_owned();
        let exec = SelectExecution::<User>::new(stmt);
        let result_stmt = exec.first();
        let sql = result_stmt.to_string(PostgresQueryBuilder);
        assert!(sql.contains("LIMIT"));
    }

    // `count` wraps the base statement in a COUNT(*) over a subquery.
    #[test]
    fn test_execution_count() {
        use reinhardt_query::prelude::{
            Alias, Expr, PostgresQueryBuilder, Query, QueryStatementBuilder,
        };
        let stmt = Query::select()
            .from(Alias::new("users"))
            .column(ColumnRef::Asterisk)
            .and_where(Expr::col(Alias::new("active")).eq(true))
            .to_owned();
        let exec = SelectExecution::<User>::new(stmt);
        let result_stmt = exec.count();
        let sql = result_stmt.to_string(PostgresQueryBuilder);
        assert!(sql.contains("COUNT"));
    }

    // `exists` wraps the base statement in EXISTS(...).
    #[test]
    fn test_execution_exists() {
        use reinhardt_query::prelude::{
            Alias, Expr, PostgresQueryBuilder, Query, QueryStatementBuilder,
        };
        let stmt = Query::select()
            .from(Alias::new("users"))
            .column(ColumnRef::Asterisk)
            .and_where(Expr::col(Alias::new("name")).eq("Alice"))
            .to_owned();
        let exec = SelectExecution::<User>::new(stmt);
        let result_stmt = exec.exists();
        let sql = result_stmt.to_string(PostgresQueryBuilder);
        assert!(sql.contains("EXISTS"));
    }

    // Multiple hints are all present in the rendered comment string.
    #[test]
    fn test_load_options() {
        let options = QueryOptions::new()
            .add_option(LoadOption::JoinedLoad("profile".to_string()))
            .add_option(LoadOption::Defer("password".to_string()));
        let comments = options.to_sql_comments();
        assert!(comments.contains("joinedload(profile)"));
        assert!(comments.contains("defer(password)"));
    }

    // `LoadOnly` joins its columns with ", ".
    #[test]
    fn test_load_only() {
        let option = LoadOption::LoadOnly(vec!["id".to_string(), "name".to_string()]);
        let comment = option.to_sql_comment();
        assert!(comment.contains("load_only(id, name)"));
    }

    // In-range `BigUnsigned` values (up to and including i64::MAX) convert exactly.
    #[rstest]
    #[case::zero(0u64, 0i64)]
    #[case::one(1u64, 1i64)]
    #[case::i64_max(i64::MAX as u64, i64::MAX)]
    #[test]
    fn test_big_unsigned_to_query_value_within_range(#[case] input: u64, #[case] expected: i64) {
        let value = reinhardt_query::value::Value::BigUnsigned(Some(input));
        let result = convert_value_to_query_value(value);
        assert!(matches!(result, QueryValue::Int(v) if v == expected));
    }

    // Values above i64::MAX are clamped to i64::MAX rather than wrapping.
    #[rstest]
    #[case::i64_max_plus_one(i64::MAX as u64 + 1)]
    #[case::u64_max(u64::MAX)]
    #[test]
    fn test_big_unsigned_overflow_clamps_to_i64_max(#[case] input: u64) {
        let value = reinhardt_query::value::Value::BigUnsigned(Some(input));
        let result = convert_value_to_query_value(value);
        assert!(matches!(result, QueryValue::Int(v) if v == i64::MAX));
    }

    // A typed NULL becomes the untyped `QueryValue::Null`.
    #[rstest]
    #[test]
    fn test_big_unsigned_none_converts_to_null() {
        let value = reinhardt_query::value::Value::BigUnsigned(None);
        let result = convert_value_to_query_value(value);
        assert!(matches!(result, QueryValue::Null));
    }
}