use crate::types::{AdminError, AdminResult};
use async_trait::async_trait;
use reinhardt_core::macros::injectable;
use reinhardt_db::migrations::FieldType as DbFieldType;
use reinhardt_db::orm::execution::convert_values;
use reinhardt_db::orm::{
DatabaseConnection, Filter, FilterCondition, FilterOperator, FilterValue, Model,
};
use reinhardt_di::{DiResult, Injectable, InjectionContext};
use reinhardt_query::prelude::{
Alias, CaseStatement, ColumnRef, Condition, Expr, ExprTrait, IntoValue, Order,
PostgresQueryBuilder, Query, QueryStatementBuilder, SimpleExpr, Value,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Converts a JSON value into a query [`Value`], promoting recognizable strings to typed values.
///
/// Strings are tried, in order, as an RFC 3339 timestamp, a `Z`-suffixed naive timestamp,
/// a 10-byte `YYYY-MM-DD` date, an 8-byte `HH:MM:SS` time and a 36-byte hyphenated UUID;
/// anything that fails to parse stays a plain string. Numbers become `BigInt` when
/// `as_i64` succeeds, otherwise `Double`. `null` maps to `Value::Int(None)`, and arrays
/// and objects are stored as their JSON text.
fn json_to_sea_value(value: serde_json::Value) -> Value {
    // Everything except strings maps directly; strings need the sniffing below.
    let text = match value {
        serde_json::Value::String(text) => text,
        serde_json::Value::Number(number) => {
            return match (number.as_i64(), number.as_f64()) {
                (Some(int), _) => Value::BigInt(Some(int)),
                (None, Some(float)) => Value::Double(Some(float)),
                (None, None) => Value::String(Some(Box::new(number.to_string()))),
            };
        }
        serde_json::Value::Bool(flag) => return Value::Bool(Some(flag)),
        serde_json::Value::Null => return Value::Int(None),
        other => return Value::String(Some(Box::new(other.to_string()))),
    };

    if let Ok(stamp) = chrono::DateTime::parse_from_rfc3339(&text) {
        return Value::ChronoDateTimeUtc(Some(Box::new(stamp.with_timezone(&chrono::Utc))));
    }
    if let Ok(naive) = chrono::NaiveDateTime::parse_from_str(&text, "%Y-%m-%dT%H:%M:%S%.fZ") {
        return Value::ChronoDateTimeUtc(Some(Box::new(naive.and_utc())));
    }

    // The remaining formats have distinct byte lengths, so at most one parser is attempted.
    let typed = match text.len() {
        10 => chrono::NaiveDate::parse_from_str(&text, "%Y-%m-%d")
            .ok()
            .map(|date| Value::ChronoDate(Some(Box::new(date)))),
        8 if text.chars().filter(|c| *c == ':').count() == 2 => {
            chrono::NaiveTime::parse_from_str(&text, "%H:%M:%S")
                .ok()
                .map(|time| Value::ChronoTime(Some(Box::new(time))))
        }
        36 if text.chars().enumerate().all(|(i, c)| {
            (matches!(i, 8 | 13 | 18 | 23) && c == '-') || c.is_ascii_hexdigit()
        }) =>
        {
            uuid::Uuid::parse_str(&text)
                .ok()
                .map(|uuid| Value::Uuid(Some(Box::new(uuid))))
        }
        _ => None,
    };

    typed.unwrap_or_else(|| Value::String(Some(Box::new(text))))
}
use std::sync::Arc;
/// Minimal record type whose only field is an optional integer primary key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdminRecord {
/// Primary key; `None` until one has been assigned.
pub id: Option<i64>,
}
/// Typed field handles for [`AdminRecord`].
#[derive(Debug, Clone)]
pub struct AdminRecordFields {
/// Handle for the `id` column of [`AdminRecord`].
pub id: reinhardt_db::orm::query_fields::Field<AdminRecord, Option<i64>>,
}
impl Default for AdminRecordFields {
fn default() -> Self {
Self::new()
}
}
impl AdminRecordFields {
    /// Builds the field set with `id` addressed by the single-segment path `["id"]`.
    pub fn new() -> Self {
        let id_path = vec![String::from("id")];
        Self {
            id: reinhardt_db::orm::query_fields::Field::new(id_path),
        }
    }
}
impl reinhardt_db::orm::FieldSelector for AdminRecordFields {
    /// Returns the field set with `alias` applied to its only field, `id`.
    fn with_alias(self, alias: &str) -> Self {
        Self {
            id: self.id.with_alias(alias),
        }
    }
}
/// ORM model glue for [`AdminRecord`]: table name, field handles and primary-key access.
impl Model for AdminRecord {
    type PrimaryKey = i64;
    type Fields = AdminRecordFields;

    /// Name of the backing table.
    fn table_name() -> &'static str {
        const TABLE: &str = "admin_records";
        TABLE
    }

    /// Fresh set of field handles.
    fn new_fields() -> Self::Fields {
        AdminRecordFields::default()
    }

    /// Current primary key, if one has been set.
    fn primary_key(&self) -> Option<Self::PrimaryKey> {
        self.id
    }

    /// Stores `pk` as the record's primary key.
    fn set_primary_key(&mut self, pk: Self::PrimaryKey) {
        self.id = Some(pk);
    }
}
/// Converts a textual primary key into a typed query [`Value`].
///
/// When field metadata is available for `table_name`/`pk_field`, the declared column type
/// decides the target (`Uuid`, `BigInt` or `Int`). If there is no metadata, the type is not
/// one of those, or parsing fails, the id is parsed as `i64` and finally kept as a string.
fn parse_pk_value(table_name: &str, pk_field: &str, id: &str) -> Value {
    let typed = crate::server::type_inference::get_field_metadata(table_name, pk_field).and_then(
        |meta| match meta.field_type {
            DbFieldType::Uuid => uuid::Uuid::parse_str(id)
                .ok()
                .map(|uuid| Value::Uuid(Some(Box::new(uuid)))),
            DbFieldType::BigInteger => id.parse::<i64>().ok().map(|n| Value::BigInt(Some(n))),
            DbFieldType::Integer
            | DbFieldType::SmallInteger
            | DbFieldType::TinyInt
            | DbFieldType::MediumInt => id.parse::<i32>().ok().map(|n| Value::Int(Some(n))),
            _ => None,
        },
    );
    if let Some(value) = typed {
        return value;
    }

    // Untyped fallback: prefer a 64-bit integer, otherwise pass the text through.
    match id.parse::<i64>() {
        Ok(number) => Value::BigInt(Some(number)),
        Err(_) => Value::String(Some(Box::new(id.to_string()))),
    }
}
/// Applies [`parse_pk_value`] to every id, preserving order.
fn parse_pk_values(table_name: &str, pk_field: &str, ids: &[String]) -> Vec<Value> {
    let mut parsed = Vec::with_capacity(ids.len());
    for raw in ids {
        parsed.push(parse_pk_value(table_name, pk_field, raw));
    }
    parsed
}
/// Converts a [`FilterValue`] into a bindable query [`Value`].
///
/// Scalars map to their natural value type, `Null` becomes `Value::Int(None)` and arrays
/// become `Value::String(None)`. Field references, outer references and expressions are
/// rendered as plain strings (the field name or the expression's SQL text).
#[doc(hidden)]
pub fn filter_value_to_sea_value(v: &FilterValue) -> Value {
    let text = |s: String| Value::String(Some(Box::new(s)));
    match v {
        FilterValue::Null => Value::Int(None),
        FilterValue::Array(_) => Value::String(None),
        FilterValue::String(s) => s.to_owned().into(),
        FilterValue::Integer(n) | FilterValue::Int(n) => (*n).into(),
        FilterValue::Float(x) => (*x).into(),
        FilterValue::Boolean(flag) | FilterValue::Bool(flag) => (*flag).into(),
        FilterValue::FieldRef(field_ref) => text(field_ref.field.clone()),
        FilterValue::Expression(expression) => text(expression.to_sql()),
        FilterValue::OuterRef(outer) => text(outer.field.clone()),
    }
}
/// Translates an annotation value into a query-builder expression.
///
/// Literals become bound values, field references become column expressions, and nested
/// expressions/aggregates are delegated to their dedicated converters. The remaining
/// variants (subqueries and the `*Agg`/`TsRank` family) are emitted through `Expr::cust`
/// using the value's own `to_sql()` output.
fn annotation_value_to_safe_expr(
    val: &reinhardt_db::orm::annotation::AnnotationValue,
) -> SimpleExpr {
    use reinhardt_db::orm::annotation::{AnnotationValue, Value as AnnotValue};
    match val {
        AnnotationValue::Value(AnnotValue::String(text)) => Expr::val(text.as_str()).into(),
        AnnotationValue::Value(AnnotValue::Int(int)) => Expr::val(*int).into(),
        AnnotationValue::Value(AnnotValue::Float(float)) => Expr::val(*float).into(),
        AnnotationValue::Value(AnnotValue::Bool(flag)) => Expr::val(*flag).into(),
        AnnotationValue::Value(AnnotValue::Null) => Expr::val(Option::<String>::None).into(),
        AnnotationValue::Field(field_ref) => Expr::col(Alias::new(&field_ref.field)).into(),
        AnnotationValue::Expression(expression) => annotation_expr_to_safe_expr(expression),
        AnnotationValue::Aggregate(aggregate) => aggregate_to_safe_expr(aggregate),
        AnnotationValue::Subquery(_)
        | AnnotationValue::ArrayAgg(_)
        | AnnotationValue::StringAgg(_)
        | AnnotationValue::JsonbAgg(_)
        | AnnotationValue::JsonbBuildObject(_)
        | AnnotationValue::TsRank(_) => Expr::cust(val.to_sql()).into(),
    }
}
/// Builds the SQL expression for an aggregate call.
///
/// Without a field the result is `FUNC(*)`. With a field the column is passed as an
/// expression argument to `FUNC(?)`, or `FUNC(DISTINCT ?)` when the aggregate is flagged
/// distinct or is `CountDistinct`.
fn aggregate_to_safe_expr(agg: &reinhardt_db::orm::aggregation::Aggregate) -> SimpleExpr {
    use reinhardt_db::orm::aggregation::AggregateFunc;

    let func_name = match agg.func {
        AggregateFunc::Count | AggregateFunc::CountDistinct => "COUNT",
        AggregateFunc::Sum => "SUM",
        AggregateFunc::Avg => "AVG",
        AggregateFunc::Max => "MAX",
        AggregateFunc::Min => "MIN",
    };

    let Some(field) = &agg.field else {
        return Expr::cust(format!("{func_name}(*)")).into();
    };

    let distinct = matches!(agg.func, AggregateFunc::CountDistinct) || agg.distinct;
    let template = if distinct {
        format!("{func_name}(DISTINCT ?)")
    } else {
        format!("{func_name}(?)")
    };
    let column: SimpleExpr = Expr::col(Alias::new(field)).into();
    Expr::cust_with_values(template, [column]).into()
}
/// Translates an annotation expression tree into a query-builder expression.
///
/// Arithmetic nodes render as a parenthesized binary operation over their converted
/// operands, `Case` becomes a `CaseStatement`, and `Coalesce` becomes `COALESCE(...)`
/// (or a null value when it has no arguments).
fn annotation_expr_to_safe_expr(expr: &reinhardt_db::orm::annotation::Expression) -> SimpleExpr {
    use reinhardt_db::orm::annotation::Expression as AnnotExpr;

    // The four arithmetic operators differ only in their template.
    let (template, left, right) = match expr {
        AnnotExpr::Add(left, right) => ("(? + ?)", left, right),
        AnnotExpr::Subtract(left, right) => ("(? - ?)", left, right),
        AnnotExpr::Multiply(left, right) => ("(? * ?)", left, right),
        AnnotExpr::Divide(left, right) => ("(? / ?)", left, right),
        AnnotExpr::Case { whens, default } => {
            let mut case = whens.iter().fold(CaseStatement::new(), |acc, branch| {
                // Conditions are emitted from their own SQL rendering.
                let cond_expr: SimpleExpr = Expr::cust(branch.condition.to_sql()).into();
                acc.when(cond_expr, annotation_value_to_safe_expr(&branch.then))
            });
            if let Some(default_val) = default {
                case = case.else_result(annotation_value_to_safe_expr(default_val));
            }
            return SimpleExpr::from(case);
        }
        AnnotExpr::Coalesce(values) => {
            let args: Vec<SimpleExpr> = values.iter().map(annotation_value_to_safe_expr).collect();
            if args.is_empty() {
                return Expr::val(Option::<String>::None).into();
            }
            let placeholders = std::iter::repeat("?")
                .take(args.len())
                .collect::<Vec<_>>()
                .join(", ");
            return Expr::cust_with_values(format!("COALESCE({placeholders})"), args).into();
        }
    };

    let operands = [
        annotation_value_to_safe_expr(left),
        annotation_value_to_safe_expr(right),
    ];
    Expr::cust_with_values(template, operands).into()
}
/// Escapes the `LIKE` metacharacters `\`, `%` and `_` by prefixing each with a backslash.
fn escape_like_pattern(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());
    for ch in input.chars() {
        if matches!(ch, '\\' | '%' | '_') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
#[doc(hidden)]
pub fn build_single_filter_expr(filter: &Filter) -> Option<SimpleExpr> {
let col = Expr::col(Alias::new(&filter.field));
let expr = match (&filter.operator, &filter.value) {
(FilterOperator::Eq, FilterValue::Null) => col.is_null(),
(FilterOperator::Ne, FilterValue::Null) => col.is_not_null(),
(FilterOperator::Eq, FilterValue::FieldRef(f)) => col.eq(Expr::col(Alias::new(&f.field))),
(FilterOperator::Ne, FilterValue::FieldRef(f)) => col.ne(Expr::col(Alias::new(&f.field))),
(FilterOperator::Gt, FilterValue::FieldRef(f)) => col.gt(Expr::col(Alias::new(&f.field))),
(FilterOperator::Gte, FilterValue::FieldRef(f)) => col.gte(Expr::col(Alias::new(&f.field))),
(FilterOperator::Lt, FilterValue::FieldRef(f)) => col.lt(Expr::col(Alias::new(&f.field))),
(FilterOperator::Lte, FilterValue::FieldRef(f)) => col.lte(Expr::col(Alias::new(&f.field))),
(FilterOperator::Eq, FilterValue::OuterRef(outer)) => {
col.eq(Expr::col(Alias::new(&outer.field)))
}
(FilterOperator::Ne, FilterValue::OuterRef(outer)) => {
col.ne(Expr::col(Alias::new(&outer.field)))
}
(FilterOperator::Gt, FilterValue::OuterRef(outer)) => {
col.gt(Expr::col(Alias::new(&outer.field)))
}
(FilterOperator::Gte, FilterValue::OuterRef(outer)) => {
col.gte(Expr::col(Alias::new(&outer.field)))
}
(FilterOperator::Lt, FilterValue::OuterRef(outer)) => {
col.lt(Expr::col(Alias::new(&outer.field)))
}
(FilterOperator::Lte, FilterValue::OuterRef(outer)) => {
col.lte(Expr::col(Alias::new(&outer.field)))
}
(FilterOperator::Eq, FilterValue::Expression(expr)) => {
col.eq(annotation_expr_to_safe_expr(expr))
}
(FilterOperator::Ne, FilterValue::Expression(expr)) => {
col.ne(annotation_expr_to_safe_expr(expr))
}
(FilterOperator::Gt, FilterValue::Expression(expr)) => {
col.gt(annotation_expr_to_safe_expr(expr))
}
(FilterOperator::Gte, FilterValue::Expression(expr)) => {
col.gte(annotation_expr_to_safe_expr(expr))
}
(FilterOperator::Lt, FilterValue::Expression(expr)) => {
col.lt(annotation_expr_to_safe_expr(expr))
}
(FilterOperator::Lte, FilterValue::Expression(expr)) => {
col.lte(annotation_expr_to_safe_expr(expr))
}
(FilterOperator::Eq, v) => col.eq(filter_value_to_sea_value(v)),
(FilterOperator::Ne, v) => col.ne(filter_value_to_sea_value(v)),
(FilterOperator::Gt, v) => col.gt(filter_value_to_sea_value(v)),
(FilterOperator::Gte, v) => col.gte(filter_value_to_sea_value(v)),
(FilterOperator::Lt, v) => col.lt(filter_value_to_sea_value(v)),
(FilterOperator::Lte, v) => col.lte(filter_value_to_sea_value(v)),
(FilterOperator::Contains, FilterValue::String(s)) => {
col.like(format!("%{}%", escape_like_pattern(s)))
}
(FilterOperator::StartsWith, FilterValue::String(s)) => {
col.like(format!("{}%", escape_like_pattern(s)))
}
(FilterOperator::EndsWith, FilterValue::String(s)) => {
col.like(format!("%{}", escape_like_pattern(s)))
}
(FilterOperator::In, FilterValue::Array(arr)) => {
if arr.is_empty() {
return None;
}
let values: Vec<Value> = arr.iter().map(|v| v.as_str().into_value()).collect();
col.is_in(values)
}
(FilterOperator::NotIn, FilterValue::Array(arr)) => {
if arr.is_empty() {
return None;
}
let values: Vec<Value> = arr.iter().map(|v| v.as_str().into_value()).collect();
col.is_not_in(values)
}
(FilterOperator::In, FilterValue::String(s)) => {
let values: Vec<Value> = s.split(',').map(|v| v.trim().into_value()).collect();
col.is_in(values)
}
(FilterOperator::NotIn, FilterValue::String(s)) => {
let values: Vec<Value> = s.split(',').map(|v| v.trim().into_value()).collect();
col.is_not_in(values)
}
_ => return None,
};
Some(expr)
}
#[doc(hidden)]
pub fn build_filter_condition(filters: &[Filter]) -> Option<Condition> {
if filters.is_empty() {
return None;
}
let mut condition = Condition::all();
let mut added = false;
for filter in filters {
if let Some(expr) = build_single_filter_expr(filter) {
condition = condition.add(expr);
added = true;
}
}
if added { Some(condition) } else { None }
}
/// Nesting depth at which `build_composite_filter_condition_with_depth` stops recursing
/// and returns a validation error instead.
#[doc(hidden)]
pub const MAX_FILTER_DEPTH: usize = 100;
/// Translates a (possibly nested) [`FilterCondition`] into a query [`Condition`],
/// starting the depth counter at zero.
///
/// # Errors
/// Propagates the validation error raised when nesting reaches [`MAX_FILTER_DEPTH`].
#[doc(hidden)]
pub fn build_composite_filter_condition(
    filter_condition: &FilterCondition,
) -> AdminResult<Option<Condition>> {
    const ROOT_DEPTH: usize = 0;
    build_composite_filter_condition_with_depth(filter_condition, ROOT_DEPTH)
}
/// Recursive worker behind [`build_composite_filter_condition`].
///
/// `depth` is the current nesting level; each child is visited with `depth + 1`.
/// Returns `Ok(None)` when the node (or all of its children) produced no predicate.
///
/// # Errors
/// Returns `AdminError::ValidationError` once `depth` reaches [`MAX_FILTER_DEPTH`].
#[doc(hidden)]
pub fn build_composite_filter_condition_with_depth(
    filter_condition: &FilterCondition,
    depth: usize,
) -> AdminResult<Option<Condition>> {
    if depth >= MAX_FILTER_DEPTH {
        return Err(AdminError::ValidationError(format!(
            "Filter condition exceeded maximum depth of {MAX_FILTER_DEPTH} levels"
        )));
    }

    // Leaves and negations return directly; And/Or only differ in their seed condition.
    let (children, seed) = match filter_condition {
        FilterCondition::Single(filter) => {
            return Ok(build_single_filter_expr(filter).map(|expr| Condition::all().add(expr)));
        }
        FilterCondition::Not(inner) => {
            let negated = build_composite_filter_condition_with_depth(inner, depth + 1)?
                .map(|inner_cond| inner_cond.not());
            return Ok(negated);
        }
        FilterCondition::And(children) => (children, Condition::all()),
        FilterCondition::Or(children) => (children, Condition::any()),
    };

    let mut combined = seed;
    let mut added = false;
    for child in children {
        if let Some(sub_cond) = build_composite_filter_condition_with_depth(child, depth + 1)? {
            combined = combined.add(sub_cond);
            added = true;
        }
    }
    Ok(added.then_some(combined))
}
/// Cloneable handle around a shared [`DatabaseConnection`], used by the admin
/// list/get/create/update/delete/count helpers.
#[injectable(scope = Singleton, prebuilt = true)]
#[derive(Clone)]
pub struct AdminDatabase {
/// Shared connection; cloning the handle clones the `Arc`, not the connection.
connection: Arc<DatabaseConnection>,
}
impl AdminDatabase {
pub fn new(connection: DatabaseConnection) -> Self {
Self {
connection: Arc::new(connection),
}
}
pub fn from_arc(connection: Arc<DatabaseConnection>) -> Self {
Self { connection }
}
pub fn connection(&self) -> &DatabaseConnection {
&self.connection
}
pub fn connection_arc(&self) -> Arc<DatabaseConnection> {
Arc::clone(&self.connection)
}
pub async fn list<M: Model>(
&self,
table_name: &str,
filters: Vec<Filter>,
offset: u64,
limit: u64,
) -> AdminResult<Vec<HashMap<String, serde_json::Value>>> {
let mut query = Query::select()
.from(Alias::new(table_name))
.column(ColumnRef::Asterisk)
.to_owned();
if let Some(condition) = build_filter_condition(&filters) {
query.cond_where(condition);
}
query.limit(limit).offset(offset);
let (sql, values) = query.build(PostgresQueryBuilder);
let params = convert_values(values);
let rows = self
.connection
.query(&sql, params)
.await
.map_err(|e| AdminError::DatabaseError(e.to_string()))?;
Ok(rows
.into_iter()
.filter_map(|row| {
if let serde_json::Value::Object(map) = row.data {
Some(
map.into_iter()
.collect::<HashMap<String, serde_json::Value>>(),
)
} else {
None
}
})
.collect())
}
pub async fn list_with_condition<M: Model>(
&self,
table_name: &str,
filter_condition: Option<&FilterCondition>,
additional_filters: Vec<Filter>,
sort_by: Option<&str>,
offset: u64,
limit: u64,
) -> AdminResult<Vec<HashMap<String, serde_json::Value>>> {
let mut query = Query::select()
.from(Alias::new(table_name))
.column(ColumnRef::Asterisk)
.to_owned();
let mut combined = Condition::all();
if let Some(fc) = filter_condition
&& let Some(cond) = build_composite_filter_condition(fc)?
{
combined = combined.add(cond);
}
if let Some(simple_cond) = build_filter_condition(&additional_filters) {
combined = combined.add(simple_cond);
}
if !additional_filters.is_empty() || filter_condition.is_some() {
query.cond_where(combined);
}
if let Some(sort_str) = sort_by {
let (field, is_desc) = if let Some(stripped) = sort_str.strip_prefix('-') {
(stripped, true)
} else {
(sort_str, false)
};
let col = Alias::new(field);
if is_desc {
query.order_by(col, Order::Desc);
} else {
query.order_by(col, Order::Asc);
}
}
query.limit(limit).offset(offset);
let (sql, values) = query.build(PostgresQueryBuilder);
let params = convert_values(values);
let rows = self
.connection
.query(&sql, params)
.await
.map_err(|e| AdminError::DatabaseError(e.to_string()))?;
const SENSITIVE_FIELDS: &[&str] = &["password_hash", "password_salt"];
Ok(rows
.into_iter()
.filter_map(|row| {
if let serde_json::Value::Object(map) = row.data {
Some(
map.into_iter()
.filter(|(key, _)| !SENSITIVE_FIELDS.contains(&key.as_str()))
.collect::<HashMap<String, serde_json::Value>>(),
)
} else {
None
}
})
.collect())
}
pub async fn count_with_condition<M: Model>(
&self,
table_name: &str,
filter_condition: Option<&FilterCondition>,
additional_filters: Vec<Filter>,
) -> AdminResult<u64> {
let mut query = Query::select()
.from(Alias::new(table_name))
.expr(Expr::cust("COUNT(*) AS count"))
.to_owned();
let mut combined = Condition::all();
if let Some(fc) = filter_condition
&& let Some(cond) = build_composite_filter_condition(fc)?
{
combined = combined.add(cond);
}
if let Some(simple_cond) = build_filter_condition(&additional_filters) {
combined = combined.add(simple_cond);
}
if !additional_filters.is_empty() || filter_condition.is_some() {
query.cond_where(combined);
}
let (sql, values) = query.build(PostgresQueryBuilder);
let params = convert_values(values);
let row = self
.connection
.query_one(&sql, params)
.await
.map_err(|e| AdminError::DatabaseError(e.to_string()))?;
let count = extract_count_from_row(&row.data)?;
Ok(count)
}
pub async fn get<M: Model>(
&self,
table_name: &str,
pk_field: &str,
id: &str,
) -> AdminResult<Option<HashMap<String, serde_json::Value>>> {
let pk_value = parse_pk_value(table_name, pk_field, id);
let query = Query::select()
.from(Alias::new(table_name))
.column(ColumnRef::Asterisk)
.and_where(Expr::col(Alias::new(pk_field)).eq(pk_value))
.to_owned();
let (sql, values) = query.build(PostgresQueryBuilder);
let params = convert_values(values);
let row = self
.connection
.query_optional(&sql, params)
.await
.map_err(|e| AdminError::DatabaseError(e.to_string()))?;
Ok(row.and_then(|r| {
if let serde_json::Value::Object(map) = r.data {
Some(
map.into_iter()
.collect::<HashMap<String, serde_json::Value>>(),
)
} else {
None
}
}))
}
pub async fn create<M: Model>(
&self,
table_name: &str,
pk_field: Option<&str>,
data: HashMap<String, serde_json::Value>,
) -> AdminResult<u64> {
let pk_field = pk_field.unwrap_or("id");
let mut query = Query::insert()
.into_table(Alias::new(table_name))
.to_owned();
let mut sorted_keys: Vec<String> = data.keys().cloned().collect();
sorted_keys.sort();
let mut columns = Vec::new();
let mut values = Vec::new();
for key in sorted_keys {
let value = data.get(&key).cloned().unwrap_or(serde_json::Value::Null);
columns.push(Alias::new(&key));
let sea_value = json_to_sea_value(value);
values.push(sea_value);
}
query.columns(columns).values(values).map_err(|e| {
AdminError::DatabaseError(format!("column/value count mismatch: {}", e))
})?;
query.returning([Alias::new(pk_field)]);
let (sql, values) = query.build(PostgresQueryBuilder);
let params = convert_values(values);
let row = self
.connection
.query_one(&sql, params)
.await
.map_err(|e| AdminError::DatabaseError(e.to_string()))?;
match row.data.get(pk_field) {
Some(serde_json::Value::Number(n)) => n.as_u64().ok_or_else(|| {
AdminError::DatabaseError(format!(
"RETURNING clause for '{}' returned non-unsigned-integer: {}",
pk_field, n
))
}),
Some(serde_json::Value::String(_)) => {
Ok(1)
}
_ => Err(AdminError::DatabaseError(format!(
"RETURNING clause did not return expected primary key field '{}'",
pk_field
))),
}
}
pub async fn update<M: Model>(
&self,
table_name: &str,
pk_field: &str,
id: &str,
data: HashMap<String, serde_json::Value>,
) -> AdminResult<u64> {
let mut query = Query::update().table(Alias::new(table_name)).to_owned();
let mut sorted_keys: Vec<String> = data.keys().cloned().collect();
sorted_keys.sort();
for key in sorted_keys {
let value = data.get(&key).cloned().unwrap_or(serde_json::Value::Null);
let sea_value = json_to_sea_value(value);
query.value(Alias::new(&key), sea_value);
}
let pk_value = parse_pk_value(table_name, pk_field, id);
query.and_where(Expr::col(Alias::new(pk_field)).eq(pk_value));
let (sql, values) = query.build(PostgresQueryBuilder);
let params = convert_values(values);
let affected = self
.connection
.execute(&sql, params)
.await
.map_err(|e| AdminError::DatabaseError(e.to_string()))?;
Ok(affected)
}
pub async fn delete<M: Model>(
&self,
table_name: &str,
pk_field: &str,
id: &str,
) -> AdminResult<u64> {
let pk_value = parse_pk_value(table_name, pk_field, id);
let query = Query::delete()
.from_table(Alias::new(table_name))
.and_where(Expr::col(Alias::new(pk_field)).eq(pk_value))
.to_owned();
let (sql, values) = query.build(PostgresQueryBuilder);
let params = convert_values(values);
let affected = self
.connection
.execute(&sql, params)
.await
.map_err(|e| AdminError::DatabaseError(e.to_string()))?;
Ok(affected)
}
pub async fn bulk_delete<M: Model>(
&self,
table_name: &str,
pk_field: &str,
ids: Vec<String>,
) -> AdminResult<u64> {
self.bulk_delete_by_table(table_name, pk_field, ids).await
}
pub async fn bulk_delete_by_table(
&self,
table_name: &str,
pk_field: &str,
ids: Vec<String>,
) -> AdminResult<u64> {
if ids.is_empty() {
return Ok(0);
}
let pk_values = parse_pk_values(table_name, pk_field, &ids);
let query = Query::delete()
.from_table(Alias::new(table_name))
.and_where(Expr::col(Alias::new(pk_field)).is_in(pk_values))
.to_owned();
let (sql, values) = query.build(PostgresQueryBuilder);
let params = convert_values(values);
let affected = self
.connection
.execute(&sql, params)
.await
.map_err(|e| AdminError::DatabaseError(e.to_string()))?;
Ok(affected)
}
pub async fn count<M: Model>(
&self,
table_name: &str,
filters: Vec<Filter>,
) -> AdminResult<u64> {
let mut query = Query::select()
.from(Alias::new(table_name))
.expr(Expr::cust("COUNT(*) AS count"))
.to_owned();
if let Some(condition) = build_filter_condition(&filters) {
query.cond_where(condition);
}
let (sql, values) = query.build(PostgresQueryBuilder);
let params = convert_values(values);
let row = self
.connection
.query_one(&sql, params)
.await
.map_err(|e| AdminError::DatabaseError(e.to_string()))?;
let count = extract_count_from_row(&row.data)?;
Ok(count)
}
}
/// Reads the `count` column from a `COUNT(*) AS count` result row.
///
/// The value must be a non-negative JSON integer. Negative values are rejected rather
/// than being reinterpreted as a very large `u64` by an `as` cast.
///
/// # Errors
/// Returns `AdminError::DatabaseError` when the `count` key is missing, the row is not a
/// JSON object, or the value is not an unsigned integer.
#[doc(hidden)]
pub fn extract_count_from_row(data: &serde_json::Value) -> AdminResult<u64> {
    let Some(count_value) = data.get("count") else {
        let message = match data.as_object() {
            Some(obj) => {
                let available_keys: Vec<&String> = obj.keys().collect();
                format!(
                    "COUNT query result missing 'count' key, available keys: {:?}",
                    available_keys
                )
            }
            None => format!("COUNT query returned unexpected data format: {}", data),
        };
        return Err(AdminError::DatabaseError(message));
    };

    count_value.as_u64().ok_or_else(|| {
        AdminError::DatabaseError(format!(
            "COUNT query returned non-integer value: {}",
            count_value
        ))
    })
}
/// Resolves [`AdminDatabase`] from the injection context.
///
/// An already registered `AdminDatabase` singleton is cloned and returned. Otherwise a new
/// handle is built from the registered `DatabaseConnection` singleton, stored back into
/// the context, and returned. Missing connection yields `DiError::NotRegistered`.
#[async_trait]
impl Injectable for AdminDatabase {
    async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
        if let Some(existing) = ctx.get_singleton::<Self>() {
            return Ok((*existing).clone());
        }

        let missing_connection = || reinhardt_di::DiError::NotRegistered {
            type_name: "AdminDatabase".into(),
            hint: "DatabaseConnection must be registered as a singleton. \
                   Use InjectionContextBuilder::singleton(db_connection) during setup."
                .into(),
        };
        let connection = ctx
            .get_singleton::<DatabaseConnection>()
            .ok_or_else(missing_connection)?;

        let created = Self::from_arc(connection);
        ctx.set_singleton(created.clone());
        Ok(created)
    }
}
/// Registers [`AdminDatabase`] in `registry` with singleton scope and its injectable factory.
fn __register_admin_database(registry: &reinhardt_di::DependencyRegistry) {
    let scope = reinhardt_di::DependencyScope::Singleton;
    let factory = reinhardt_di::InjectableFactory::<AdminDatabase>::new();
    registry.register::<AdminDatabase>(scope, factory);
}
// Submits `__register_admin_database` to the `reinhardt_di` inventory as an
// `InjectableRegistration`. NOTE(review): presumably collected by the DI registry at
// startup — confirm against `reinhardt_di`.
reinhardt_di::inventory::submit! {
reinhardt_di::InjectableRegistration::new(
__register_admin_database
)
}
#[cfg(test)]
mod tests {
use super::*;
use reinhardt_db::orm::annotation::Expression;
use reinhardt_db::orm::expressions::{F, OuterRef};
use rstest::rstest;
#[rstest]
fn test_escape_like_pattern_percent() {
    // A literal percent sign gains a backslash prefix.
    assert_eq!(escape_like_pattern("100%"), "100\\%");
}
#[rstest]
fn test_escape_like_pattern_underscore() {
    // A literal underscore gains a backslash prefix.
    assert_eq!(escape_like_pattern("user_name"), "user\\_name");
}
#[rstest]
fn test_escape_like_pattern_backslash() {
    // A backslash is doubled.
    assert_eq!(escape_like_pattern("path\\to"), "path\\\\to");
}
#[rstest]
fn test_escape_like_pattern_combined() {
    // Adjacent metacharacters are escaped independently.
    assert_eq!(escape_like_pattern("100%_done"), "100\\%\\_done");
}
#[rstest]
fn test_escape_like_pattern_no_special_chars() {
    // Text without metacharacters passes through unchanged.
    assert_eq!(escape_like_pattern("normal text"), "normal text");
}
/// Table-driven check that every LIKE metacharacter is escaped.
#[rstest]
#[case("%wildcard%", "\\%wildcard\\%")]
#[case("under_score", "under\\_score")]
#[case("back\\slash", "back\\\\slash")]
#[case("%_%", "\\%\\_\\%")]
fn test_escape_like_pattern_sanitizes_special_chars(#[case] input: &str, #[case] expected: &str) {
    let actual = escape_like_pattern(input);
    assert_eq!(
        actual, expected,
        "input={input:?} was not correctly escaped"
    );
}
#[test]
fn test_build_composite_single_condition() {
    // A single equality filter renders the quoted column and the quoted literal.
    let condition = FilterCondition::Single(Filter::new(
        "name".to_string(),
        FilterOperator::Eq,
        FilterValue::String("Alice".to_string()),
    ));

    let built = build_composite_filter_condition(&condition)
        .expect("single filter should build without error")
        .expect("single filter should yield a condition");

    let sql = Query::select()
        .from(Alias::new("users"))
        .column(ColumnRef::Asterisk)
        .cond_where(built)
        .to_string(PostgresQueryBuilder);
    assert!(sql.contains("\"name\""));
    assert!(sql.contains("'Alice'"));
}
#[test]
fn test_build_composite_or_condition() {
    // Two substring filters under OR mention both columns and the OR keyword.
    let contains = |field: &str, needle: &str| {
        FilterCondition::Single(Filter::new(
            field.to_string(),
            FilterOperator::Contains,
            FilterValue::String(needle.to_string()),
        ))
    };
    let condition =
        FilterCondition::Or(vec![contains("name", "Alice"), contains("email", "alice")]);

    let built = build_composite_filter_condition(&condition)
        .expect("OR condition should build without error")
        .expect("OR condition should yield a condition");

    let sql = Query::select()
        .from(Alias::new("users"))
        .column(ColumnRef::Asterisk)
        .cond_where(built)
        .to_string(PostgresQueryBuilder);
    assert!(sql.contains("\"name\""));
    assert!(sql.contains("\"email\""));
    assert!(sql.contains("OR"));
}
#[test]
fn test_build_composite_and_condition() {
    // Two boolean equality filters under AND mention both columns and the AND keyword.
    let is_true = |field: &str| {
        FilterCondition::Single(Filter::new(
            field.to_string(),
            FilterOperator::Eq,
            FilterValue::Boolean(true),
        ))
    };
    let condition = FilterCondition::And(vec![is_true("is_active"), is_true("is_staff")]);

    let built = build_composite_filter_condition(&condition)
        .expect("AND condition should build without error")
        .expect("AND condition should yield a condition");

    let sql = Query::select()
        .from(Alias::new("users"))
        .column(ColumnRef::Asterisk)
        .cond_where(built)
        .to_string(PostgresQueryBuilder);
    assert!(sql.contains("\"is_active\""));
    assert!(sql.contains("\"is_staff\""));
    assert!(sql.contains("AND"));
}
#[test]
fn test_build_composite_nested_condition() {
    // (name LIKE .. OR email LIKE ..) AND is_active = true
    let contains = |field: &str, needle: &str| {
        FilterCondition::Single(Filter::new(
            field.to_string(),
            FilterOperator::Contains,
            FilterValue::String(needle.to_string()),
        ))
    };
    let either_match =
        FilterCondition::Or(vec![contains("name", "Alice"), contains("email", "alice")]);
    let active_only = FilterCondition::Single(Filter::new(
        "is_active".to_string(),
        FilterOperator::Eq,
        FilterValue::Boolean(true),
    ));
    let nested = FilterCondition::And(vec![either_match, active_only]);

    let built = build_composite_filter_condition(&nested)
        .expect("nested condition should build without error")
        .expect("nested condition should yield a condition");

    let sql = Query::select()
        .from(Alias::new("users"))
        .column(ColumnRef::Asterisk)
        .cond_where(built)
        .to_string(PostgresQueryBuilder);
    assert!(sql.contains("\"name\""));
    assert!(sql.contains("\"email\""));
    assert!(sql.contains("\"is_active\""));
    assert!(sql.contains("OR"));
    assert!(sql.contains("AND"));
}
#[test]
fn test_build_composite_empty_or() {
    // An OR with no children builds successfully but yields no condition.
    let empty = FilterCondition::Or(Vec::new());
    let built = build_composite_filter_condition(&empty);
    assert!(built.is_ok());
    assert!(built.unwrap().is_none());
}
#[test]
fn test_build_composite_empty_and() {
    // An AND with no children builds successfully but yields no condition.
    let empty = FilterCondition::And(Vec::new());
    let built = build_composite_filter_condition(&empty);
    assert!(built.is_ok());
    assert!(built.unwrap().is_none());
}
#[test]
fn test_build_composite_depth_overflow_returns_error() {
    // Wrap a leaf in MAX_FILTER_DEPTH + 1 single-child ANDs so the limit is exceeded.
    let leaf = FilterCondition::Single(Filter::new(
        "name".to_string(),
        FilterOperator::Eq,
        FilterValue::String("Alice".to_string()),
    ));
    let too_deep = (0..=MAX_FILTER_DEPTH).fold(leaf, |inner, _| FilterCondition::And(vec![inner]));

    let outcome = build_composite_filter_condition(&too_deep);
    assert!(outcome.is_err());
    let err = outcome.unwrap_err();
    assert!(matches!(err, AdminError::ValidationError(_)));

    let err_msg = err.to_string();
    assert!(
        err_msg.contains("exceeded maximum depth"),
        "Error message should mention exceeded depth, got: {}",
        err_msg
    );
}
#[test]
fn test_build_single_filter_expr_field_ref_eq() {
    // Comparing against a field reference renders both sides as quoted columns.
    let filter = Filter::new(
        "price".to_string(),
        FilterOperator::Eq,
        FilterValue::FieldRef(F::new("discount_price")),
    );
    let predicate = build_single_filter_expr(&filter);
    assert!(predicate.is_some());

    let sql = Query::select()
        .from(Alias::new("products"))
        .column(ColumnRef::Asterisk)
        .cond_where(Condition::all().add(predicate.unwrap()))
        .to_string(PostgresQueryBuilder);
    assert!(sql.contains("\"price\""));
    assert!(sql.contains("\"discount_price\""));
}
#[test]
fn test_build_single_filter_expr_field_ref_gt() {
    // `price > cost` with a field reference on the right-hand side is supported.
    let filter = Filter::new(
        "price".to_string(),
        FilterOperator::Gt,
        FilterValue::FieldRef(F::new("cost")),
    );
    assert!(build_single_filter_expr(&filter).is_some());
}
#[test]
fn test_build_single_filter_expr_field_ref_all_operators() {
    // Every comparison operator must accept a field reference.
    let operators = [
        FilterOperator::Eq,
        FilterOperator::Ne,
        FilterOperator::Gt,
        FilterOperator::Gte,
        FilterOperator::Lt,
        FilterOperator::Lte,
    ];
    for op in &operators {
        let filter = Filter::new(
            "field_a".to_string(),
            op.clone(),
            FilterValue::FieldRef(F::new("field_b")),
        );
        assert!(
            build_single_filter_expr(&filter).is_some(),
            "FieldRef with {:?} should produce Some",
            op
        );
    }
}
#[test]
fn test_build_single_filter_expr_outer_ref() {
    // An outer reference renders the referenced name on the right-hand side.
    let filter = Filter::new(
        "author_id".to_string(),
        FilterOperator::Eq,
        FilterValue::OuterRef(OuterRef::new("authors.id")),
    );
    let predicate = build_single_filter_expr(&filter);
    assert!(predicate.is_some());

    let sql = Query::select()
        .from(Alias::new("books"))
        .column(ColumnRef::Asterisk)
        .cond_where(Condition::all().add(predicate.unwrap()))
        .to_string(PostgresQueryBuilder);
    assert!(sql.contains("author_id"));
    assert!(sql.contains("authors.id"));
}
#[test]
fn test_build_single_filter_expr_outer_ref_all_operators() {
    // Every comparison operator must accept an outer reference.
    let operators = [
        FilterOperator::Eq,
        FilterOperator::Ne,
        FilterOperator::Gt,
        FilterOperator::Gte,
        FilterOperator::Lt,
        FilterOperator::Lte,
    ];
    for op in &operators {
        let filter = Filter::new(
            "child_id".to_string(),
            op.clone(),
            FilterValue::OuterRef(OuterRef::new("parent.id")),
        );
        assert!(
            build_single_filter_expr(&filter).is_some(),
            "OuterRef with {:?} should produce Some",
            op
        );
    }
}
#[test]
fn test_build_single_filter_expr_expression() {
    use reinhardt_db::orm::annotation::{AnnotationValue, Value};

    // price > cost * 2
    let doubled_cost = Expression::Multiply(
        Box::new(AnnotationValue::Field(F::new("cost"))),
        Box::new(AnnotationValue::Value(Value::Int(2))),
    );
    let filter = Filter::new(
        "price".to_string(),
        FilterOperator::Gt,
        FilterValue::Expression(doubled_cost),
    );
    assert!(build_single_filter_expr(&filter).is_some());
}
/// Each of the six comparison operators must yield an expression when the
/// right-hand side is an arithmetic expression.
#[test]
fn test_build_single_filter_expr_expression_all_operators() {
    use reinhardt_db::orm::annotation::{AnnotationValue, Value as OrmValue};

    let comparison_ops = [
        FilterOperator::Eq,
        FilterOperator::Ne,
        FilterOperator::Gt,
        FilterOperator::Gte,
        FilterOperator::Lt,
        FilterOperator::Lte,
    ];

    for op in &comparison_ops {
        // Right-hand side: `base + 10`, rebuilt per iteration because the
        // filter value takes ownership of it.
        let base = Box::new(AnnotationValue::Field(F::new("base")));
        let ten = Box::new(AnnotationValue::Value(OrmValue::Int(10)));
        let rhs = FilterValue::Expression(Expression::Add(base, ten));

        let filter = Filter::new(String::from("total"), op.clone(), rhs);
        let built = build_single_filter_expr(&filter);

        assert!(
            built.is_some(),
            "Expression with {:?} should produce Some",
            op
        );
    }
}
/// Converting a field reference to a plain value must give a string holding
/// the field name.
#[test]
fn test_filter_value_to_sea_value_field_ref_fallback() {
    let field_ref = FilterValue::FieldRef(F::new("test_field"));

    if let Value::String(Some(text)) = filter_value_to_sea_value(&field_ref) {
        assert_eq!(text.as_str(), "test_field");
    } else {
        panic!("Expected String value");
    }
}
/// Converting an outer reference to a plain value must give a string holding
/// the referenced name.
#[test]
fn test_filter_value_to_sea_value_outer_ref_fallback() {
    let outer_ref = FilterValue::OuterRef(OuterRef::new("outer.field"));

    if let Value::String(Some(text)) = filter_value_to_sea_value(&outer_ref) {
        assert_eq!(text.as_str(), "outer.field");
    } else {
        panic!("Expected String value");
    }
}
/// Converting an expression to a plain value must give a string that mentions
/// both operands of the expression.
#[test]
fn test_filter_value_to_sea_value_expression_fallback() {
    use reinhardt_db::orm::annotation::{AnnotationValue, Value as OrmValue};

    // Expression under test: `a + 1`.
    let field_a = Box::new(AnnotationValue::Field(F::new("a")));
    let one = Box::new(AnnotationValue::Value(OrmValue::Int(1)));
    let wrapped = FilterValue::Expression(Expression::Add(field_a, one));

    if let Value::String(Some(rendered)) = filter_value_to_sea_value(&wrapped) {
        assert!(rendered.contains("a"), "SQL should contain field name 'a'");
        assert!(rendered.contains("1"), "SQL should contain value '1'");
    } else {
        panic!("Expected String value");
    }
}
/// Supplying one value for three columns must be reported as an `Err`.
#[rstest]
fn test_insert_values_mismatch_returns_error_not_panic() {
    let three_columns = vec![Alias::new("col1"), Alias::new("col2"), Alias::new("col3")];
    let one_value = vec![Value::String(Some(Box::new(String::from("val1"))))];

    let mut insert = Query::insert()
        .into_table(Alias::new("test_table"))
        .to_owned();

    let outcome = insert.columns(three_columns).values(one_value);
    assert!(outcome.is_err());
}
/// Supplying two values for two columns must be accepted.
#[rstest]
fn test_insert_values_matching_count_succeeds() {
    let two_columns = vec![Alias::new("col1"), Alias::new("col2")];
    let first = Value::String(Some(Box::new(String::from("val1"))));
    let second = Value::String(Some(Box::new(String::from("val2"))));
    let two_values = vec![first, second];

    let mut insert = Query::insert()
        .into_table(Alias::new("test_table"))
        .to_owned();

    let outcome = insert.columns(two_columns).values(two_values);
    assert!(outcome.is_ok());
}
/// The left-hand column of an `OuterRef` filter must appear double-quoted in
/// the rendered SQL.
#[test]
fn test_outer_ref_filter_uses_safe_column_api() {
    let outer = FilterValue::OuterRef(OuterRef::new("users.id"));
    let filter = Filter::new(String::from("author_id"), FilterOperator::Eq, outer);

    let built = build_single_filter_expr(&filter);
    assert!(built.is_some());
    let predicate = built.unwrap();

    let table = Alias::new("books");
    let query = Query::select()
        .from(table)
        .column(ColumnRef::Asterisk)
        .cond_where(Condition::all().add(predicate))
        .to_string(PostgresQueryBuilder);

    assert!(
        query.contains("\"author_id\""),
        "Column should be properly quoted: {}",
        query
    );
}
/// A hostile outer-reference name must be rendered as a single quoted
/// identifier, leaving no semicolon outside double quotes.
#[test]
fn test_outer_ref_injection_attempt_is_safely_quoted() {
    let payload = OuterRef::new("id; DROP TABLE users; --");
    let filter = Filter::new(
        String::from("id"),
        FilterOperator::Eq,
        FilterValue::OuterRef(payload),
    );

    let built = build_single_filter_expr(&filter);
    assert!(built.is_some());
    let predicate = built.unwrap();

    let table = Alias::new("items");
    let query = Query::select()
        .from(table)
        .column(ColumnRef::Asterisk)
        .cond_where(Condition::all().add(predicate))
        .to_string(PostgresQueryBuilder);

    assert!(
        query.contains("\"id; DROP TABLE users; --\""),
        "Injection payload should be enclosed in double quotes as identifier: {}",
        query
    );

    // Splitting on `"` alternates between text outside and inside quotes,
    // starting outside, so every second piece (index 0, 2, 4, ...) is
    // unquoted SQL.
    let unquoted_sql: String = query.split('"').step_by(2).collect();
    assert!(
        !unquoted_sql.contains(';'),
        "No semicolons should appear outside quoted identifiers: {}",
        query
    );
}
/// The left-hand column of an expression filter must appear double-quoted in
/// the rendered SQL.
#[test]
fn test_expression_filter_uses_safe_api() {
    use reinhardt_db::orm::annotation::AnnotationValue;

    // Right-hand side: `unit_price * quantity`.
    let unit_price = Box::new(AnnotationValue::Field(F::new("unit_price")));
    let quantity = Box::new(AnnotationValue::Field(F::new("quantity")));
    let line_total = Expression::Multiply(unit_price, quantity);

    let filter = Filter::new(
        String::from("total"),
        FilterOperator::Eq,
        FilterValue::Expression(line_total),
    );

    let built = build_single_filter_expr(&filter);
    assert!(built.is_some());
    let predicate = built.unwrap();

    let table = Alias::new("orders");
    let query = Query::select()
        .from(table)
        .column(ColumnRef::Asterisk)
        .cond_where(Condition::all().add(predicate))
        .to_string(PostgresQueryBuilder);

    assert!(
        query.contains("\"total\""),
        "Left side should be quoted: {}",
        query
    );
}
/// An expression mixing a field and an integer literal must yield an
/// expression.
#[test]
fn test_expression_filter_with_literal_value() {
    use reinhardt_db::orm::annotation::{AnnotationValue, Value as OrmValue};

    // Right-hand side: `price + 100`.
    let price = Box::new(AnnotationValue::Field(F::new("price")));
    let hundred = Box::new(AnnotationValue::Value(OrmValue::Int(100)));
    let rhs = FilterValue::Expression(Expression::Add(price, hundred));

    let filter = Filter::new(String::from("adjusted_price"), FilterOperator::Gt, rhs);

    let built = build_single_filter_expr(&filter);
    assert!(built.is_some());
}
/// Each of the six comparison operators must yield an expression for an
/// `OuterRef` right-hand side.
#[test]
fn test_outer_ref_all_operators_use_safe_api() {
    let comparison_ops = [
        FilterOperator::Eq,
        FilterOperator::Ne,
        FilterOperator::Gt,
        FilterOperator::Gte,
        FilterOperator::Lt,
        FilterOperator::Lte,
    ];

    for op in &comparison_ops {
        let rhs = FilterValue::OuterRef(OuterRef::new("field_b"));
        let filter = Filter::new(String::from("field_a"), op.clone(), rhs);

        let built = build_single_filter_expr(&filter);

        assert!(
            built.is_some(),
            "OuterRef with {:?} should produce Some",
            op
        );
    }
}
/// A `Coalesce` expression filter must render a `COALESCE` call and a quoted
/// left-hand column.
#[test]
fn test_coalesce_expression_uses_safe_parameterized_api() {
    use reinhardt_db::orm::annotation::{AnnotationValue, Value as OrmValue};

    // Right-hand side: coalesce of `field_a` and the literal 0.
    let candidates = vec![
        AnnotationValue::Field(F::new("field_a")),
        AnnotationValue::Value(OrmValue::Int(0)),
    ];
    let rhs = FilterValue::Expression(Expression::Coalesce(candidates));
    let filter = Filter::new(String::from("result"), FilterOperator::Gt, rhs);

    let built = build_single_filter_expr(&filter);
    assert!(built.is_some());
    let predicate = built.unwrap();

    let table = Alias::new("items");
    let query = Query::select()
        .from(table)
        .column(ColumnRef::Asterisk)
        .cond_where(Condition::all().add(predicate))
        .to_string(PostgresQueryBuilder);

    assert!(
        query.contains("COALESCE"),
        "Should contain COALESCE function: {}",
        query
    );
    assert!(
        query.contains("\"result\""),
        "Left side should be quoted: {}",
        query
    );
}
/// A `Case` expression filter must render the `CASE`, `WHEN` and `ELSE`
/// keywords.
#[test]
fn test_case_expression_uses_safe_api() {
    use reinhardt_db::orm::annotation::{
        AnnotationValue, Value as OrmValue, When as AnnotWhen,
    };
    use reinhardt_db::orm::expressions::Q;

    // One WHEN branch yielding 1, plus a default of 0.
    let active_branch = AnnotWhen::new(
        Q::new("status", "=", "'active'"),
        AnnotationValue::Value(OrmValue::Int(1)),
    );
    let fallback = Box::new(AnnotationValue::Value(OrmValue::Int(0)));
    let case_expr = Expression::Case {
        whens: vec![active_branch],
        default: Some(fallback),
    };

    let filter = Filter::new(
        String::from("priority"),
        FilterOperator::Eq,
        FilterValue::Expression(case_expr),
    );

    let built = build_single_filter_expr(&filter);
    assert!(built.is_some());
    let predicate = built.unwrap();

    let table = Alias::new("tasks");
    let query = Query::select()
        .from(table)
        .column(ColumnRef::Asterisk)
        .cond_where(Condition::all().add(predicate))
        .to_string(PostgresQueryBuilder);

    assert!(
        query.contains("CASE"),
        "Should contain CASE keyword: {}",
        query
    );
    assert!(
        query.contains("WHEN"),
        "Should contain WHEN keyword: {}",
        query
    );
    assert!(
        query.contains("ELSE"),
        "Should contain ELSE keyword: {}",
        query
    );
}
/// A `Coalesce` with no arguments must render SQL containing `NULL`.
#[test]
fn test_empty_coalesce_returns_null() {
    let empty_coalesce = Expression::Coalesce(Vec::new());
    let converted = annotation_expr_to_safe_expr(&empty_coalesce);

    let table = Alias::new("test");
    let query = Query::select()
        .from(table)
        .column(ColumnRef::Asterisk)
        .cond_where(Condition::all().add(converted))
        .to_string(PostgresQueryBuilder);

    assert!(
        query.contains("NULL"),
        "Empty COALESCE should produce NULL: {}",
        query
    );
}
/// A `Count` aggregate with no field must render as `COUNT(*)`.
#[test]
fn test_aggregate_count_uses_safe_api() {
    use reinhardt_db::orm::aggregation::{Aggregate, AggregateFunc};

    let count_all = Aggregate {
        func: AggregateFunc::Count,
        field: None,
        alias: None,
        distinct: false,
    };
    let select_expr = aggregate_to_safe_expr(&count_all);

    let table = Alias::new("items");
    let query = Query::select()
        .from(table)
        .expr(select_expr)
        .to_string(PostgresQueryBuilder);

    assert!(
        query.contains("COUNT(*)"),
        "Should contain COUNT(*): {}",
        query
    );
}
/// A `Sum` aggregate over a field must render `SUM(` with the field name in
/// double quotes.
#[test]
fn test_aggregate_sum_field_uses_quoted_identifier() {
    use reinhardt_db::orm::aggregation::{Aggregate, AggregateFunc};

    let sum_price = Aggregate {
        func: AggregateFunc::Sum,
        field: Some(String::from("price")),
        alias: None,
        distinct: false,
    };
    let select_expr = aggregate_to_safe_expr(&sum_price);

    let table = Alias::new("orders");
    let query = Query::select()
        .from(table)
        .expr(select_expr)
        .to_string(PostgresQueryBuilder);

    assert!(
        query.contains("SUM("),
        "Should contain SUM function: {}",
        query
    );
    assert!(
        query.contains("\"price\""),
        "Field name should be quoted: {}",
        query
    );
}
/// A `CountDistinct` aggregate must render `COUNT(DISTINCT` with the field
/// name in double quotes.
#[test]
fn test_aggregate_count_distinct_uses_distinct_keyword() {
    use reinhardt_db::orm::aggregation::{Aggregate, AggregateFunc};

    // The `distinct` flag is left false: the test expects DISTINCT to come
    // from the `CountDistinct` function kind alone.
    let distinct_categories = Aggregate {
        func: AggregateFunc::CountDistinct,
        field: Some(String::from("category")),
        alias: None,
        distinct: false,
    };
    let select_expr = aggregate_to_safe_expr(&distinct_categories);

    let table = Alias::new("products");
    let query = Query::select()
        .from(table)
        .expr(select_expr)
        .to_string(PostgresQueryBuilder);

    assert!(
        query.contains("COUNT(DISTINCT"),
        "Should contain COUNT(DISTINCT: {}",
        query
    );
    assert!(
        query.contains("\"category\""),
        "Field name should be quoted: {}",
        query
    );
}
/// A hostile aggregate field name must be rendered wholly inside double
/// quotes.
#[test]
fn test_aggregate_injection_attempt_is_quoted() {
    use reinhardt_db::orm::aggregation::{Aggregate, AggregateFunc};

    let hostile_field = String::from("price); DROP TABLE users; --");
    let hostile_sum = Aggregate {
        func: AggregateFunc::Sum,
        field: Some(hostile_field),
        alias: None,
        distinct: false,
    };
    let select_expr = aggregate_to_safe_expr(&hostile_sum);

    let table = Alias::new("orders");
    let query = Query::select()
        .from(table)
        .expr(select_expr)
        .to_string(PostgresQueryBuilder);

    assert!(
        query.contains("\"price); DROP TABLE users; --\""),
        "Injection payload should be enclosed in double quotes: {}",
        query
    );
}
/// An `And` made only of filters this test treats as unsupported must come
/// back as `Ok(None)`.
#[rstest]
fn test_build_composite_and_all_unsupported_returns_none() {
    let children = vec![
        FilterCondition::Single(Filter::new(
            String::from("field1"),
            FilterOperator::Contains,
            FilterValue::Boolean(true),
        )),
        FilterCondition::Single(Filter::new(
            String::from("field2"),
            FilterOperator::StartsWith,
            FilterValue::Integer(5),
        )),
    ];
    let conjunction = FilterCondition::And(children);

    let outcome = build_composite_filter_condition(&conjunction);

    assert!(outcome.is_ok());
    assert!(
        outcome.unwrap().is_none(),
        "And with all unsupported filters should return None"
    );
}
/// An `Or` made only of filters this test treats as unsupported must come
/// back as `Ok(None)`.
#[rstest]
fn test_build_composite_or_all_unsupported_returns_none() {
    let children = vec![
        FilterCondition::Single(Filter::new(
            String::from("field1"),
            FilterOperator::Contains,
            FilterValue::Boolean(true),
        )),
        FilterCondition::Single(Filter::new(
            String::from("field2"),
            FilterOperator::StartsWith,
            FilterValue::Integer(5),
        )),
    ];
    let disjunction = FilterCondition::Or(children);

    let outcome = build_composite_filter_condition(&disjunction);

    assert!(outcome.is_ok());
    assert!(
        outcome.unwrap().is_none(),
        "Or with all unsupported filters should return None"
    );
}
/// An `And` mixing one valid filter with one unsupported filter must still
/// produce a condition that renders the valid filter.
#[rstest]
fn test_build_composite_and_mixed_valid_and_unsupported() {
    let name_is_alice = FilterCondition::Single(Filter::new(
        String::from("name"),
        FilterOperator::Eq,
        FilterValue::String(String::from("Alice")),
    ));
    let unsupported = FilterCondition::Single(Filter::new(
        String::from("field2"),
        FilterOperator::Contains,
        FilterValue::Boolean(true),
    ));
    let conjunction = FilterCondition::And(vec![name_is_alice, unsupported]);

    let outcome = build_composite_filter_condition(&conjunction);
    assert!(outcome.is_ok());
    let maybe_condition = outcome.unwrap();
    assert!(
        maybe_condition.is_some(),
        "And with at least one valid filter should return Some"
    );
    let where_clause = maybe_condition.unwrap();

    let table = Alias::new("t");
    let query = Query::select()
        .from(table)
        .column(ColumnRef::Asterisk)
        .cond_where(where_clause)
        .to_string(PostgresQueryBuilder);

    assert!(
        query.contains("\"name\""),
        "SQL should contain the valid filter field, got: {}",
        query
    );
    assert!(
        query.contains("'Alice'"),
        "SQL should contain the valid filter value, got: {}",
        query
    );
}
/// An `Or` mixing one valid filter with one unsupported filter must still
/// produce a condition that renders the valid filter.
#[rstest]
fn test_build_composite_or_mixed_valid_and_unsupported() {
    let email_matches = FilterCondition::Single(Filter::new(
        String::from("email"),
        FilterOperator::Eq,
        FilterValue::String(String::from("test@example.com")),
    ));
    let unsupported = FilterCondition::Single(Filter::new(
        String::from("field2"),
        FilterOperator::StartsWith,
        FilterValue::Integer(5),
    ));
    let disjunction = FilterCondition::Or(vec![email_matches, unsupported]);

    let outcome = build_composite_filter_condition(&disjunction);
    assert!(outcome.is_ok());
    let maybe_condition = outcome.unwrap();
    assert!(
        maybe_condition.is_some(),
        "Or with at least one valid filter should return Some"
    );
    let where_clause = maybe_condition.unwrap();

    let table = Alias::new("t");
    let query = Query::select()
        .from(table)
        .column(ColumnRef::Asterisk)
        .cond_where(where_clause)
        .to_string(PostgresQueryBuilder);

    assert!(
        query.contains("\"email\""),
        "SQL should contain the valid filter field, got: {}",
        query
    );
    assert!(
        query.contains("'test@example.com'"),
        "SQL should contain the valid filter value, got: {}",
        query
    );
}
/// A flat filter list containing only unsupported filters must produce no
/// condition at all.
#[rstest]
fn test_build_filter_condition_all_unsupported_returns_none() {
    let contains_bool = Filter::new(
        String::from("field1"),
        FilterOperator::Contains,
        FilterValue::Boolean(true),
    );
    let starts_with_int = Filter::new(
        String::from("field2"),
        FilterOperator::StartsWith,
        FilterValue::Integer(5),
    );
    let filters = vec![contains_bool, starts_with_int];

    let condition = build_filter_condition(&filters);

    assert!(
        condition.is_none(),
        "build_filter_condition with all unsupported filters should return None"
    );
}
/// A row object with an integer `count` key yields that integer.
#[rstest]
fn test_extract_count_from_row_with_count_key() {
    let row = serde_json::json!({"count": 42});

    let count = extract_count_from_row(&row).unwrap();

    assert_eq!(count, 42);
}
/// A row object lacking a `count` key must fail with a message naming the
/// missing key.
#[rstest]
fn test_extract_count_from_row_without_count_key() {
    let row = serde_json::json!({"total": 10});

    let failure = extract_count_from_row(&row).unwrap_err();
    let message = failure.to_string();

    assert!(
        message.contains("missing 'count' key"),
        "Error should mention missing 'count' key, got: {}",
        message
    );
}
/// An empty row object must fail with a message naming the missing `count`
/// key.
#[rstest]
fn test_extract_count_from_row_empty_object() {
    let row = serde_json::json!({});

    let failure = extract_count_from_row(&row).unwrap_err();
    let message = failure.to_string();

    assert!(
        message.contains("missing 'count' key"),
        "Error should mention missing 'count' key, got: {}",
        message
    );
}
/// A string stored under `count` must fail with a message mentioning
/// "non-integer".
#[rstest]
fn test_extract_count_from_row_non_integer() {
    let row = serde_json::json!({"count": "abc"});

    let failure = extract_count_from_row(&row).unwrap_err();
    let message = failure.to_string();

    assert!(
        message.contains("non-integer"),
        "Error should mention non-integer value, got: {}",
        message
    );
}
/// A JSON null in place of a row object must fail with a message mentioning
/// "unexpected data format".
#[rstest]
fn test_extract_count_from_row_null_data() {
    let not_a_row = serde_json::Value::Null;

    let failure = extract_count_from_row(&not_a_row).unwrap_err();
    let message = failure.to_string();

    assert!(
        message.contains("unexpected data format"),
        "Error should mention unexpected data format, got: {}",
        message
    );
}
/// A `count` of zero is a valid result and must be returned unchanged.
#[rstest]
fn test_extract_count_from_row_zero() {
    let row = serde_json::json!({"count": 0});

    let count = extract_count_from_row(&row).unwrap();

    assert_eq!(count, 0);
}
/// Injecting `AdminDatabase` from a freshly built context must fail with an
/// error whose text mentions `DatabaseConnection`.
#[rstest]
#[tokio::test]
async fn test_admin_database_inject_error_hint_mentions_connection() {
    // Nothing is registered in the scope before the context is built.
    let scope = Arc::new(reinhardt_di::SingletonScope::new());
    let empty_ctx = reinhardt_di::InjectionContext::builder(scope).build();

    let outcome = AdminDatabase::inject(&empty_ctx).await;
    assert!(outcome.is_err());

    let failure = outcome.err().unwrap();
    let message = failure.to_string();
    assert!(
        message.contains("DatabaseConnection"),
        "Error hint should mention DatabaseConnection, got: {}",
        message
    );
}
/// Injecting `AdminDatabase` from a freshly built context must fail with an
/// error whose text mentions `DatabaseConnection`.
///
/// NOTE(review): the test name suggests a pre-built instance is returned from
/// the singleton scope, yet nothing is registered here and the body asserts a
/// failure — confirm whether a success-path setup is missing.
#[rstest]
#[tokio::test]
async fn test_admin_database_inject_returns_prebuilt_from_singleton() {
    let scope = Arc::new(reinhardt_di::SingletonScope::new());
    let empty_ctx = reinhardt_di::InjectionContext::builder(scope).build();

    let outcome = AdminDatabase::inject(&empty_ctx).await;
    assert!(outcome.is_err());

    let failure = outcome.err().unwrap();
    let message = failure.to_string();
    assert!(
        message.contains("DatabaseConnection"),
        "Error should mention DatabaseConnection, got: {}",
        message
    );
}
/// An `In` filter over three strings must render an `IN` clause containing
/// each value as a quoted literal.
#[rstest]
fn test_build_single_filter_expr_array_in() {
    let candidates = vec![String::from("a"), String::from("b"), String::from("c")];
    let filter = Filter::new(
        String::from("status"),
        FilterOperator::In,
        FilterValue::Array(candidates),
    );

    let built = build_single_filter_expr(&filter);
    assert!(
        built.is_some(),
        "Array In with non-empty values should return Some"
    );
    let predicate = built.unwrap();

    let table = Alias::new("table");
    let query = Query::select()
        .from(table)
        .column(ColumnRef::Asterisk)
        .cond_where(Condition::all().add(predicate))
        .to_string(PostgresQueryBuilder);

    assert!(query.contains("IN"), "SQL should contain IN operator");
    assert!(query.contains("'a'"), "SQL should contain value 'a'");
    assert!(query.contains("'b'"), "SQL should contain value 'b'");
    assert!(query.contains("'c'"), "SQL should contain value 'c'");
}
/// A `NotIn` filter over two strings must render a `NOT IN` clause containing
/// each value as a quoted literal.
#[rstest]
fn test_build_single_filter_expr_array_not_in() {
    let excluded = vec![String::from("x"), String::from("y")];
    let filter = Filter::new(
        String::from("status"),
        FilterOperator::NotIn,
        FilterValue::Array(excluded),
    );

    let built = build_single_filter_expr(&filter);
    assert!(
        built.is_some(),
        "Array NotIn with non-empty values should return Some"
    );
    let predicate = built.unwrap();

    let table = Alias::new("table");
    let query = Query::select()
        .from(table)
        .column(ColumnRef::Asterisk)
        .cond_where(Condition::all().add(predicate))
        .to_string(PostgresQueryBuilder);

    assert!(
        query.contains("NOT IN"),
        "SQL should contain NOT IN operator"
    );
    assert!(query.contains("'x'"), "SQL should contain value 'x'");
    assert!(query.contains("'y'"), "SQL should contain value 'y'");
}
/// An `In` filter over an empty list must produce no expression.
#[rstest]
fn test_build_single_filter_expr_array_in_empty() {
    let no_values = FilterValue::Array(Vec::new());
    let filter = Filter::new(String::from("status"), FilterOperator::In, no_values);

    let built = build_single_filter_expr(&filter);

    assert!(
        built.is_none(),
        "Array In with empty values should return None"
    );
}
/// An `In` filter over a one-element list must still render an `IN` clause
/// with that value.
#[rstest]
fn test_build_single_filter_expr_array_in_single_element() {
    let only_value = vec![String::from("solo")];
    let filter = Filter::new(
        String::from("category"),
        FilterOperator::In,
        FilterValue::Array(only_value),
    );

    let built = build_single_filter_expr(&filter);
    assert!(
        built.is_some(),
        "Array In with single element should return Some"
    );
    let predicate = built.unwrap();

    let table = Alias::new("table");
    let query = Query::select()
        .from(table)
        .column(ColumnRef::Asterisk)
        .cond_where(Condition::all().add(predicate))
        .to_string(PostgresQueryBuilder);

    assert!(query.contains("IN"), "SQL should contain IN operator");
    assert!(query.contains("'solo'"), "SQL should contain value 'solo'");
}
/// `In` values containing a single quote or a semicolon must be rendered as
/// escaped string literals.
#[rstest]
fn test_build_single_filter_expr_array_in_special_chars() {
    let tricky_values = vec![String::from("O'Brien"), String::from("a;DROP TABLE")];
    let filter = Filter::new(
        String::from("name"),
        FilterOperator::In,
        FilterValue::Array(tricky_values),
    );

    let built = build_single_filter_expr(&filter);
    assert!(
        built.is_some(),
        "Array In with special chars should return Some"
    );
    let predicate = built.unwrap();

    let table = Alias::new("table");
    let query = Query::select()
        .from(table)
        .column(ColumnRef::Asterisk)
        .cond_where(Condition::all().add(predicate))
        .to_string(PostgresQueryBuilder);

    assert!(query.contains("IN"), "SQL should contain IN operator");
    // The embedded quote is expected to be doubled inside the literal.
    assert!(
        query.contains("O''Brien"),
        "Single quote in value should be escaped, got: {}",
        query
    );
    assert!(
        query.contains("'a;DROP TABLE'"),
        "SQL injection attempt should be safely quoted as a string literal, got: {}",
        query
    );
}
/// An `And` whose children are all unsupported must come back as `Ok(None)`.
#[rstest]
fn test_and_with_all_unsupported_returns_none() {
    let children = vec![
        FilterCondition::Single(Filter::new(
            "name",
            FilterOperator::Contains,
            FilterValue::Integer(42),
        )),
        FilterCondition::Single(Filter::new(
            "email",
            FilterOperator::StartsWith,
            FilterValue::Integer(99),
        )),
    ];
    let conjunction = FilterCondition::And(children);

    let outcome = build_composite_filter_condition(&conjunction);
    assert!(outcome.is_ok());

    let maybe_condition = outcome.unwrap();
    assert!(
        maybe_condition.is_none(),
        "And with all unsupported sub-conditions should return None"
    );
}
/// An `Or` whose children are all unsupported must come back as `Ok(None)`.
#[rstest]
fn test_or_with_all_unsupported_returns_none() {
    let children = vec![
        FilterCondition::Single(Filter::new(
            "name",
            FilterOperator::Contains,
            FilterValue::Integer(42),
        )),
        FilterCondition::Single(Filter::new(
            "email",
            FilterOperator::StartsWith,
            FilterValue::Integer(99),
        )),
    ];
    let disjunction = FilterCondition::Or(children);

    let outcome = build_composite_filter_condition(&disjunction);
    assert!(outcome.is_ok());

    let maybe_condition = outcome.unwrap();
    assert!(
        maybe_condition.is_none(),
        "Or with all unsupported sub-conditions should return None"
    );
}
/// An `And` with one supported and one unsupported child must keep the
/// supported child in the rendered SQL.
#[rstest]
fn test_and_with_mix_supported_unsupported_keeps_supported() {
    let name_is_alice = FilterCondition::Single(Filter::new(
        "name",
        FilterOperator::Eq,
        FilterValue::String(String::from("Alice")),
    ));
    let contains_int = FilterCondition::Single(Filter::new(
        "email",
        FilterOperator::Contains,
        FilterValue::Integer(42),
    ));
    let conjunction = FilterCondition::And(vec![name_is_alice, contains_int]);

    let outcome = build_composite_filter_condition(&conjunction);
    assert!(outcome.is_ok());
    let maybe_condition = outcome.unwrap();
    assert!(
        maybe_condition.is_some(),
        "And with mix of supported/unsupported should return Some with supported filters"
    );
    let where_clause = maybe_condition.unwrap();

    let table = Alias::new("test");
    let query = Query::select()
        .from(table)
        .column(ColumnRef::Asterisk)
        .cond_where(where_clause)
        .to_string(PostgresQueryBuilder);

    assert!(
        query.contains("\"name\""),
        "SQL should contain the supported filter field 'name': {}",
        query
    );
}
/// An `Or` with one supported and one unsupported child must keep the
/// supported child in the rendered SQL.
#[rstest]
fn test_or_with_one_supported_one_unsupported() {
    let status_is_active = FilterCondition::Single(Filter::new(
        "status",
        FilterOperator::Eq,
        FilterValue::String(String::from("active")),
    ));
    let contains_int = FilterCondition::Single(Filter::new(
        "count",
        FilterOperator::Contains,
        FilterValue::Integer(42),
    ));
    let disjunction = FilterCondition::Or(vec![status_is_active, contains_int]);

    let outcome = build_composite_filter_condition(&disjunction);
    assert!(outcome.is_ok());
    let maybe_condition = outcome.unwrap();
    assert!(
        maybe_condition.is_some(),
        "Or with one supported condition should return Some"
    );
    let where_clause = maybe_condition.unwrap();

    let table = Alias::new("test");
    let query = Query::select()
        .from(table)
        .column(ColumnRef::Asterisk)
        .cond_where(where_clause)
        .to_string(PostgresQueryBuilder);

    assert!(
        query.contains("\"status\""),
        "SQL should contain the supported filter field 'status': {}",
        query
    );
}
/// A row with an integer `count` key succeeds and yields that integer.
#[rstest]
fn test_extract_count_with_count_key() {
    let row = serde_json::json!({"count": 42});

    let outcome = extract_count_from_row(&row);

    assert!(outcome.is_ok());
    let count = outcome.unwrap();
    assert_eq!(count, 42);
}
/// A row whose only key is not `count` must fail with a message naming the
/// missing key.
#[rstest]
fn test_extract_count_without_count_key_returns_error() {
    let row = serde_json::json!({"total": 42});

    let outcome = extract_count_from_row(&row);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(
        message.contains("missing 'count' key"),
        "Error should mention missing 'count' key, got: {}",
        message
    );
}
/// A row with several keys but no `count` must fail with a message that
/// mentions "available keys".
#[rstest]
fn test_extract_count_with_multiple_keys_no_count_returns_error() {
    let row = serde_json::json!({"total": 42, "other": 99});

    let outcome = extract_count_from_row(&row);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(
        message.contains("available keys"),
        "Error should list available keys, got: {}",
        message
    );
}
/// A string stored under `count` must fail with the `DatabaseError` variant.
#[rstest]
fn test_extract_count_non_integer_returns_error() {
    let row = serde_json::json!({"count": "not_a_number"});

    let outcome = extract_count_from_row(&row);
    assert!(outcome.is_err());

    let failure = outcome.unwrap_err();
    let is_database_error = matches!(failure, AdminError::DatabaseError(_));
    assert!(is_database_error);
}
/// A JSON null stored under `count` must be rejected.
#[rstest]
fn test_extract_count_null_returns_error() {
    let row = serde_json::json!({"count": null});

    let outcome = extract_count_from_row(&row);

    assert!(outcome.is_err());
}
/// An empty row object must be rejected.
#[rstest]
fn test_extract_count_empty_object_returns_error() {
    let row = serde_json::json!({});

    let outcome = extract_count_from_row(&row);

    assert!(outcome.is_err());
}
/// A JSON array in place of a row object must be rejected.
#[rstest]
fn test_extract_count_non_object_returns_error() {
    let not_a_row = serde_json::json!([1, 2, 3]);

    let outcome = extract_count_from_row(&not_a_row);

    assert!(outcome.is_err());
}
/// For the table name `nonexistent_table`, a numeric primary-key string is
/// expected to parse as a `BigInt`.
#[rstest]
fn test_parse_pk_value_integer_falls_back_to_bigint() {
    let expected = Value::BigInt(Some(42));

    let parsed = parse_pk_value("nonexistent_table", "id", "42");

    assert_eq!(parsed, expected);
}
/// For the table name `nonexistent_table`, a UUID-shaped primary-key string
/// is expected to come back as a string value.
#[rstest]
fn test_parse_pk_value_uuid_string_without_registry_falls_back_to_string() {
    let uuid_text = "c1a363b1-cc42-4dea-81f0-9dc1cedf0083";

    let parsed = parse_pk_value("nonexistent_table", "id", uuid_text);

    assert!(matches!(parsed, Value::String(Some(_))));
}
/// For the table name `nonexistent_table`, a non-numeric primary-key string
/// is expected to come back as a string value.
#[rstest]
fn test_parse_pk_value_non_numeric_string_falls_back_to_string() {
    let raw_pk = "hello-world";

    let parsed = parse_pk_value("nonexistent_table", "id", raw_pk);

    assert!(matches!(parsed, Value::String(Some(_))));
}
/// A negative numeric primary-key string is expected to parse as a negative
/// `BigInt`.
#[rstest]
fn test_parse_pk_value_negative_integer() {
    let expected = Value::BigInt(Some(-1));

    let parsed = parse_pk_value("nonexistent_table", "id", "-1");

    assert_eq!(parsed, expected);
}
/// The primary-key string `"0"` is expected to parse as `BigInt` zero.
#[rstest]
fn test_parse_pk_value_zero() {
    let expected = Value::BigInt(Some(0));

    let parsed = parse_pk_value("nonexistent_table", "id", "0");

    assert_eq!(parsed, expected);
}
}