use crate::orm::{DatabaseConnection, QueryRow};
use reinhardt_core::exception::Result;
use reinhardt_query::prelude::{
Alias, Expr, ExprTrait, Func, OnConflict, PostgresQueryBuilder, Query, QueryBuilder,
};
use std::fmt::Display;
use std::marker::PhantomData;
/// Manages the rows of a many-to-many through (join) table for a single source row.
///
/// `S` and `T` are carried only as [`PhantomData`] markers; nothing here reads them
/// (presumably they are the source and target model types — confirm at call sites).
#[derive(Debug)]
pub struct ManyToManyManager<S, T, PK> {
    /// Primary key of the source row whose relations this manager operates on.
    source_pk: PK,
    /// Name of the through table holding the `(source, target)` pairs.
    through_table: String,
    /// Column of `through_table` that holds the source primary key.
    source_field: String,
    /// Column of `through_table` that holds the target primary key.
    target_field: String,
    _phantom_s: PhantomData<S>,
    _phantom_t: PhantomData<T>,
}
impl<S, T, PK> ManyToManyManager<S, T, PK>
where
PK: Display + Clone,
{
pub fn new(
source_pk: PK,
through_table: String,
source_field: String,
target_field: String,
) -> Self {
Self {
source_pk,
through_table,
source_field,
target_field,
_phantom_s: PhantomData,
_phantom_t: PhantomData,
}
}
pub async fn add_with_db<TPK>(&self, conn: &DatabaseConnection, target_pk: TPK) -> Result<()>
where
TPK: Display,
{
let mut stmt = Query::insert();
stmt.into_table(Alias::new(&self.through_table))
.columns([
Alias::new(&self.source_field),
Alias::new(&self.target_field),
])
.values_panic([self.source_pk.to_string(), target_pk.to_string()])
.on_conflict(
OnConflict::columns([
Alias::new(&self.source_field),
Alias::new(&self.target_field),
])
.do_nothing()
.to_owned(),
);
let pg = PostgresQueryBuilder::new();
let (sql, values) = pg.build_insert(&stmt);
let params = crate::orm::execution::convert_values(values);
conn.execute(&sql, params).await?;
Ok(())
}
pub async fn remove_with_db<TPK>(&self, conn: &DatabaseConnection, target_pk: TPK) -> Result<()>
where
TPK: Display,
{
let mut stmt = Query::delete();
stmt.from_table(Alias::new(&self.through_table))
.and_where(Expr::col(Alias::new(&self.source_field)).eq(self.source_pk.to_string()))
.and_where(Expr::col(Alias::new(&self.target_field)).eq(target_pk.to_string()));
let pg = PostgresQueryBuilder::new();
let (sql, _) = pg.build_delete(&stmt);
conn.execute(&sql, vec![]).await?;
Ok(())
}
pub async fn contains_with_db<TPK>(
&self,
conn: &DatabaseConnection,
target_pk: TPK,
) -> Result<bool>
where
TPK: Display,
{
let mut stmt = Query::select();
stmt.from(Alias::new(&self.through_table))
.expr(Expr::asterisk())
.and_where(Expr::col(Alias::new(&self.source_field)).eq(self.source_pk.to_string()))
.and_where(Expr::col(Alias::new(&self.target_field)).eq(target_pk.to_string()));
let pg = PostgresQueryBuilder::new();
let (sql, _) = pg.build_select(&stmt);
let rows = conn.query(&sql, vec![]).await?;
Ok(!rows.is_empty())
}
pub async fn all_with_db(
&self,
conn: &DatabaseConnection,
target_table: &str,
target_pk_field: &str,
) -> Result<Vec<QueryRow>> {
let mut stmt = Query::select();
stmt.from(Alias::new(&self.through_table))
.inner_join(
Alias::new(target_table),
Expr::col((Alias::new(&self.through_table), Alias::new(&self.target_field)))
.equals((Alias::new(target_table), Alias::new(target_pk_field))),
)
.expr(Expr::asterisk())
.and_where(
Expr::col((Alias::new(&self.through_table), Alias::new(&self.source_field)))
.eq(self.source_pk.to_string()),
);
let pg = PostgresQueryBuilder::new();
let (sql, _) = pg.build_select(&stmt);
conn.query(&sql, vec![])
.await
.map_err(|e| reinhardt_core::exception::Error::Database(e.to_string()))
}
pub async fn clear_with_db(&self, conn: &DatabaseConnection) -> Result<()> {
let mut stmt = Query::delete();
stmt.from_table(Alias::new(&self.through_table))
.and_where(Expr::col(Alias::new(&self.source_field)).eq(self.source_pk.to_string()));
let pg = PostgresQueryBuilder::new();
let (sql, _) = pg.build_delete(&stmt);
conn.execute(&sql, vec![]).await?;
Ok(())
}
pub async fn count_with_db(&self, conn: &DatabaseConnection) -> Result<usize> {
let mut stmt = Query::select();
stmt.from(Alias::new(&self.through_table))
.expr_as(
Func::count(Expr::asterisk().into_simple_expr()),
Alias::new("count"),
)
.and_where(Expr::col(Alias::new(&self.source_field)).eq(self.source_pk.to_string()));
let pg = PostgresQueryBuilder::new();
let (sql, _) = pg.build_select(&stmt);
let row = conn.query_one(&sql, vec![]).await?;
let count_value = row
.get::<i64>("count")
.or_else(|| row.get::<i64>("COUNT"))
.ok_or_else(|| {
reinhardt_core::exception::Error::Database(
"Failed to extract count value from query result".to_string(),
)
})?;
Ok(count_value as usize)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores every constructor argument unchanged, including the source key.
    #[test]
    fn test_manager_creation() {
        let manager: ManyToManyManager<(), (), i64> = ManyToManyManager::new(
            42,
            "user_groups".to_string(),
            "user_id".to_string(),
            "group_id".to_string(),
        );
        assert_eq!(manager.source_pk, 42);
        assert_eq!(manager.through_table, "user_groups");
        assert_eq!(manager.source_field, "user_id");
        assert_eq!(manager.target_field, "group_id");
    }
}