use super::{SerializerError, ValidatorError};
use async_trait::async_trait;
use reinhardt_db::orm::{
Model,
transaction::{Transaction, TransactionScope, transaction},
};
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value;
use std::collections::HashMap;
use std::marker::PhantomData;
/// State threaded through a nested (multi-level) serializer save.
///
/// Tracks an optional transaction, values shared with nested serializers, and
/// the current nesting depth together with the limit it is checked against.
#[derive(Debug)]
pub struct NestedSaveContext {
    /// Transaction attached via `with_transaction`, if any.
    pub transaction: Option<Transaction>,
    /// Values made available to nested levels, keyed by name (for example a
    /// parent's primary key, looked up by `resolve_parent_fk`).
    pub parent_data: HashMap<String, Value>,
    /// Current nesting level; `0` for a root context created by `new`.
    pub depth: usize,
    /// Deepest nesting level allowed before depth checks fail.
    pub max_depth: usize,
}
impl NestedSaveContext {
    /// Creates a root context: no transaction, no parent data, depth `0`, and
    /// a nesting limit of `10`.
    pub fn new() -> Self {
        Self {
            transaction: None,
            parent_data: HashMap::new(),
            depth: 0,
            max_depth: 10,
        }
    }

    /// Attaches `transaction` to this context (builder style).
    pub fn with_transaction(mut self, transaction: Transaction) -> Self {
        self.transaction = Some(transaction);
        self
    }

    /// Stores `value` under `key` in the data shared with child contexts
    /// (builder style). An existing entry for `key` is overwritten.
    pub fn with_parent_data(mut self, key: String, value: Value) -> Self {
        self.parent_data.insert(key, value);
        self
    }

    /// Sets the deepest nesting level allowed (builder style).
    pub fn with_max_depth(mut self, max_depth: usize) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Descends one nesting level in place.
    ///
    /// # Errors
    ///
    /// Returns a validation error if the new depth would exceed `max_depth`.
    /// On error `depth` is left unchanged, so the context stays consistent.
    pub fn increment_depth(&mut self) -> Result<(), SerializerError> {
        self.depth = self.checked_next_depth()?;
        Ok(())
    }

    /// Builds a context for the next nesting level.
    ///
    /// The child receives a copy of `parent_data`, the same `max_depth`, and
    /// `depth + 1`. It does not inherit `transaction`; that starts as `None`.
    ///
    /// # Errors
    ///
    /// Returns a validation error if the child's depth would exceed
    /// `max_depth`.
    pub fn child_context(&self) -> Result<Self, SerializerError> {
        // Check the limit before cloning `parent_data` so a rejected child
        // costs nothing.
        let depth = self.checked_next_depth()?;
        Ok(Self {
            transaction: None,
            parent_data: self.parent_data.clone(),
            depth,
            max_depth: self.max_depth,
        })
    }

    /// Looks up a value previously stored with `with_parent_data`.
    pub fn get_parent_value(&self, key: &str) -> Option<&Value> {
        self.parent_data.get(key)
    }

    /// Runs `f` transactionally, choosing the mechanism by nesting depth.
    ///
    /// At depth `0` this delegates to `TransactionHelper::with_transaction`.
    /// Deeper contexts delegate to `TransactionHelper::savepoint`, passing
    /// `depth` so the savepoint name is unique per level.
    ///
    /// NOTE(review): `self.transaction` is not consulted here, and both
    /// helpers obtain their own connection. Confirm this is intended for
    /// nested saves.
    pub async fn with_scope<F, Fut, T>(&self, f: F) -> Result<T, SerializerError>
    where
        F: FnOnce(&mut TransactionScope) -> Fut,
        Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>>,
    {
        if self.depth == 0 {
            TransactionHelper::with_transaction(f).await
        } else {
            TransactionHelper::savepoint(self.depth, f).await
        }
    }

    /// Returns `depth + 1` if it stays within `max_depth`; otherwise returns
    /// the depth-exceeded validation error shared by `increment_depth` and
    /// `child_context`. Overflow of `depth` is reported the same way instead
    /// of panicking.
    fn checked_next_depth(&self) -> Result<usize, SerializerError> {
        match self.depth.checked_add(1) {
            Some(next) if next <= self.max_depth => Ok(next),
            _ => Err(SerializerError::Validation(ValidatorError::Custom {
                message: format!("Maximum nesting depth {} exceeded", self.max_depth),
            })),
        }
    }
}
impl Default for NestedSaveContext {
fn default() -> Self {
Self::new()
}
}
/// Create/update support for serializers whose model is saved as a child of
/// another model.
#[async_trait]
pub trait NestedSerializerSave
where
    Self: Sized,
{
    /// Model created or updated by this serializer.
    type Model: Model + Serialize + DeserializeOwned + Clone + Send + Sync;
    /// Model on the parent side of the nesting relationship.
    type ParentModel: Model + Send + Sync;

    /// Creates a new `Self::Model` from `data` within `context`.
    async fn create_nested(
        data: Value,
        context: &mut NestedSaveContext,
    ) -> Result<Self::Model, SerializerError>;

    /// Applies `data` to `instance` within `context` and returns the result.
    async fn update_nested(
        instance: Self::Model,
        data: Value,
        context: &mut NestedSaveContext,
    ) -> Result<Self::Model, SerializerError>;

    /// Copies the parent's primary key from `context` into `data[fk_field]`.
    ///
    /// The key is read from the context's parent data under `parent_pk_field`.
    ///
    /// # Errors
    ///
    /// - `ValidatorError::RequiredField` if `parent_pk_field` is not present
    ///   in the context.
    /// - `ValidatorError::Custom` if `data` is not a JSON object, since the
    ///   foreign key could not be set.
    fn resolve_parent_fk(
        data: &mut Value,
        context: &NestedSaveContext,
        fk_field: &str,
        parent_pk_field: &str,
    ) -> Result<(), SerializerError> {
        let parent_pk = match context.get_parent_value(parent_pk_field) {
            Some(pk) => pk,
            None => {
                return Err(SerializerError::Validation(ValidatorError::RequiredField {
                    field_name: parent_pk_field.to_string(),
                    message: format!(
                        "Parent primary key '{}' not found in context",
                        parent_pk_field
                    ),
                }));
            }
        };
        match data {
            Value::Object(map) => {
                map.insert(fk_field.to_string(), parent_pk.clone());
                Ok(())
            }
            // Refuse rather than skip: skipping would leave the foreign key
            // unset while still reporting success.
            _ => Err(SerializerError::Validation(ValidatorError::Custom {
                message: format!(
                    "Cannot set foreign key '{}': nested data must be a JSON object",
                    fk_field
                ),
            })),
        }
    }
}
/// Manages rows of a many-to-many junction table linking a source model `T`
/// to a target model `R`.
///
/// `Clone` and `Debug` are implemented by hand rather than derived. Deriving
/// them would require `T: Clone` / `R: Clone` (respectively `Debug`), even
/// though only `PhantomData` markers of those types are stored.
pub struct ManyToManyManager<T, R>
where
    T: Model,
    R: Model,
{
    _source: PhantomData<T>,
    _target: PhantomData<R>,
    /// Name of the junction table.
    pub junction_table: String,
    /// Junction column that receives the source model's primary key.
    pub source_fk: String,
    /// Junction column that receives the target model's primary key.
    pub target_fk: String,
}

impl<T, R> Clone for ManyToManyManager<T, R>
where
    T: Model,
    R: Model,
{
    fn clone(&self) -> Self {
        Self {
            _source: PhantomData,
            _target: PhantomData,
            junction_table: self.junction_table.clone(),
            source_fk: self.source_fk.clone(),
            target_fk: self.target_fk.clone(),
        }
    }
}

impl<T, R> std::fmt::Debug for ManyToManyManager<T, R>
where
    T: Model,
    R: Model,
{
    // Field names and order match what `#[derive(Debug)]` produced.
    // `PhantomData` is `Debug` for every `T`, so no bound on `T`/`R` is needed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ManyToManyManager")
            .field("_source", &self._source)
            .field("_target", &self._target)
            .field("junction_table", &self.junction_table)
            .field("source_fk", &self.source_fk)
            .field("target_fk", &self.target_fk)
            .finish()
    }
}
/// Junction-table operations for [`ManyToManyManager`].
///
/// NOTE(security): primary keys are interpolated into the SQL text through
/// their `Display` impl, unquoted and unescaped, and every statement is
/// executed with an empty parameter list. This is only sound for numeric
/// keys. String keys produce invalid SQL and are an SQL-injection vector if
/// they come from request data. TODO: bind the keys as query parameters.
impl<T, R> ManyToManyManager<T, R>
where
    T: Model,
    R: Model,
{
    /// Creates a manager for `junction_table`. Its `source_fk` column holds the
    /// source (`T`) primary key, and its `target_fk` column holds the target
    /// (`R`) primary key.
    pub fn new(
        junction_table: impl Into<String>,
        source_fk: impl Into<String>,
        target_fk: impl Into<String>,
    ) -> Self {
        Self {
            _source: PhantomData,
            _target: PhantomData,
            junction_table: junction_table.into(),
            source_fk: source_fk.into(),
            target_fk: target_fk.into(),
        }
    }

    /// Links `source_id` to every id in `target_ids` with one multi-row
    /// `INSERT`. An empty `target_ids` is a no-op and opens no connection.
    ///
    /// # Errors
    ///
    /// Returns `SerializerError::Other` if no connection can be obtained or
    /// the insert fails.
    pub async fn add_bulk(
        &self,
        source_id: &T::PrimaryKey,
        target_ids: Vec<R::PrimaryKey>,
    ) -> Result<(), SerializerError>
    where
        T::PrimaryKey: std::fmt::Display,
        R::PrimaryKey: std::fmt::Display,
    {
        if target_ids.is_empty() {
            return Ok(());
        }
        // The source key is identical in every row, so render it only once.
        let source = source_id.to_string();
        let rows: Vec<String> = target_ids
            .iter()
            .map(|target_id| format!("({}, {})", source, target_id))
            .collect();
        let query = format!(
            "INSERT INTO {} ({}, {}) VALUES {}",
            self.junction_table,
            self.source_fk,
            self.target_fk,
            rows.join(", ")
        );
        Self::run_junction_statement(&query, "Failed to add M2M relationships").await
    }

    /// Deletes the links between `source_id` and each id in `target_ids`.
    /// An empty `target_ids` is a no-op and opens no connection.
    ///
    /// # Errors
    ///
    /// Returns `SerializerError::Other` if no connection can be obtained or
    /// the delete fails.
    pub async fn remove_bulk(
        &self,
        source_id: &T::PrimaryKey,
        target_ids: Vec<R::PrimaryKey>,
    ) -> Result<(), SerializerError>
    where
        T::PrimaryKey: std::fmt::Display,
        R::PrimaryKey: std::fmt::Display,
    {
        if target_ids.is_empty() {
            return Ok(());
        }
        let id_list = target_ids
            .iter()
            .map(|id| id.to_string())
            .collect::<Vec<_>>()
            .join(", ");
        let query = format!(
            "DELETE FROM {} WHERE {} = {} AND {} IN ({})",
            self.junction_table, self.source_fk, source_id, self.target_fk, id_list
        );
        Self::run_junction_statement(&query, "Failed to remove M2M relationships").await
    }

    /// Replaces all links of `source_id` with links to exactly `target_ids`.
    /// Calls `clear`, then `add_bulk`.
    ///
    /// NOTE(review): this is not atomic. Each step obtains its own connection
    /// and no transaction spans them. If `add_bulk` fails after `clear`
    /// succeeded, `source_id` is left with no links.
    ///
    /// # Errors
    ///
    /// Propagates the first error from `clear` or `add_bulk`.
    pub async fn set(
        &self,
        source_id: &T::PrimaryKey,
        target_ids: Vec<R::PrimaryKey>,
    ) -> Result<(), SerializerError>
    where
        T::PrimaryKey: std::fmt::Display,
        R::PrimaryKey: std::fmt::Display,
    {
        self.clear(source_id).await?;
        self.add_bulk(source_id, target_ids).await?;
        Ok(())
    }

    /// Deletes every junction row whose source column equals `source_id`.
    ///
    /// # Errors
    ///
    /// Returns `SerializerError::Other` if no connection can be obtained or
    /// the delete fails.
    pub async fn clear(&self, source_id: &T::PrimaryKey) -> Result<(), SerializerError>
    where
        T::PrimaryKey: std::fmt::Display,
    {
        let query = format!(
            "DELETE FROM {} WHERE {} = {}",
            self.junction_table, self.source_fk, source_id
        );
        Self::run_junction_statement(&query, "Failed to clear M2M relationships").await
    }

    /// Executes `query` with no bound parameters on a connection from
    /// `get_connection`.
    ///
    /// Connection failures are reported as `"Failed to get connection: …"`.
    /// Execution failures are reported as `"{failure}: …"`.
    async fn run_junction_statement(query: &str, failure: &str) -> Result<(), SerializerError> {
        use reinhardt_db::orm::manager::get_connection;
        let conn = get_connection().await.map_err(|e| SerializerError::Other {
            message: format!("Failed to get connection: {}", e),
        })?;
        conn.execute(query, vec![])
            .await
            .map_err(|e| SerializerError::Other {
                message: format!("{}: {}", failure, e),
            })?;
        Ok(())
    }
}
/// Stateless namespace for running a closure inside a database transaction
/// (`with_transaction`) or inside a named savepoint (`savepoint`).
pub struct TransactionHelper;
impl TransactionHelper {
    /// Runs `f` through the ORM's `transaction` helper on a connection from
    /// `get_connection`.
    ///
    /// Errors returned by `f` are converted to `anyhow::Error` by their
    /// `Display` text before being handed to `transaction`. Presumably that
    /// helper requires an `anyhow` result from its closure; verify against
    /// the ORM. Commit/rollback handling is done by `transaction` itself.
    ///
    /// # Errors
    ///
    /// `SerializerError::Other` with `"Failed to get connection: …"` or
    /// `"Transaction failed: …"`.
    pub async fn with_transaction<F, Fut, T>(f: F) -> Result<T, SerializerError>
    where
        F: FnOnce(&mut TransactionScope) -> Fut,
        Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>>,
    {
        use reinhardt_db::orm::manager::get_connection;
        let conn = get_connection().await.map_err(|e| SerializerError::Other {
            message: format!("Failed to get connection: {}", e),
        })?;
        let wrapped_f = |tx: &mut TransactionScope| {
            let fut = f(tx);
            async move { fut.await.map_err(|e| anyhow::anyhow!("{}", e)) }
        };
        transaction(&conn, wrapped_f)
            .await
            .map_err(|e| SerializerError::Other {
                message: format!("Transaction failed: {}", e),
            })
    }

    /// Runs `f` inside savepoint `nested_save_sp_{depth}`.
    ///
    /// The savepoint lives in a transaction begun here on a fresh connection.
    /// On success the savepoint is released and the transaction committed.
    /// If `f` fails, it is rolled back to the savepoint and then entirely.
    /// Rollback failures on that path are ignored so that `f`'s error is the
    /// one reported.
    ///
    /// NOTE(review): because the transaction is opened here, the savepoint is
    /// not nested inside any transaction the caller may already hold.
    ///
    /// # Errors
    ///
    /// `SerializerError::Other` if the connection, transaction, savepoint,
    /// release, or commit step fails, or if `f` fails.
    pub async fn savepoint<F, Fut, T>(depth: usize, f: F) -> Result<T, SerializerError>
    where
        F: FnOnce(&mut TransactionScope) -> Fut,
        Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>>,
    {
        use reinhardt_db::orm::manager::get_connection;
        let conn = get_connection().await.map_err(|e| SerializerError::Other {
            message: format!("Failed to get connection: {}", e),
        })?;
        let mut tx = TransactionScope::begin(&conn)
            .await
            .map_err(|e| SerializerError::Other {
                message: format!("Failed to begin transaction: {}", e),
            })?;
        let savepoint_name = format!("nested_save_sp_{}", depth);

        let created = tx.savepoint(&savepoint_name).await;
        if let Err(e) = created {
            // Roll back explicitly so the transaction begun above is not left
            // open on the connection.
            let _ = tx.rollback().await;
            return Err(SerializerError::Other {
                message: format!("Failed to create savepoint: {}", e),
            });
        }

        match f(&mut tx).await {
            Ok(result) => {
                let released = tx.release_savepoint(&savepoint_name).await;
                if let Err(e) = released {
                    // Same reasoning as above: don't leave the transaction open.
                    let _ = tx.rollback().await;
                    return Err(SerializerError::Other {
                        message: format!("Failed to release savepoint: {}", e),
                    });
                }
                tx.commit().await.map_err(|e| SerializerError::Other {
                    message: format!("Failed to commit transaction: {}", e),
                })?;
                Ok(result)
            }
            Err(e) => {
                // Best effort: the operation already failed, and its error is
                // what gets reported.
                let _ = tx.rollback_to_savepoint(&savepoint_name).await;
                let _ = tx.rollback().await;
                Err(SerializerError::Other {
                    message: format!("Savepoint operation failed: {}", e),
                })
            }
        }
    }
}