#![allow(dead_code)]
use std::fmt::Debug;
use std::marker::PhantomData;
use crate::error::{QueryError, QueryResult};
use crate::filter::{Filter, FilterValue};
use crate::sql::quote_identifier;
use crate::traits::{Model, QueryEngine};
/// A nested write on a relation whose related model is `T`, expressed as plain data.
///
/// Variants carrying `Vec<Filter>` (or a single `Filter`) select existing related rows;
/// variants carrying the `Nested*Data` types carry column/value pairs. How each variant is
/// turned into SQL is decided outside this type — presumably via `NestedWriteBuilder` for the
/// create/connect/disconnect/set/delete cases; confirm at the call site.
///
/// NOTE: `#[derive(Debug, Clone)]` on a generic type additionally requires `T: Debug` /
/// `T: Clone` for the derived impls, even though `T` is only used as a marker.
#[derive(Debug, Clone)]
pub enum NestedWrite<T: Model> {
/// Create one or more new related rows.
Create(Vec<NestedCreateData<T>>),
/// Connect rows matching each filter, or create them from the paired data (see `NestedCreateOrConnectData`).
CreateOrConnect(Vec<NestedCreateOrConnectData<T>>),
/// Link existing related rows selected by these filters.
Connect(Vec<Filter>),
/// Unlink related rows selected by these filters.
Disconnect(Vec<Filter>),
/// Replace the whole set of linked rows with those selected by these filters.
Set(Vec<Filter>),
/// Delete related rows selected by these filters.
Delete(Vec<Filter>),
/// Update related rows; each entry pairs a filter with column/value pairs.
Update(Vec<NestedUpdateData<T>>),
/// Update-or-create entries (see `NestedUpsertData`).
Upsert(Vec<NestedUpsertData<T>>),
/// Apply one set of column/value pairs to every row matching a filter.
UpdateMany(NestedUpdateManyData<T>),
/// Delete every related row matching the filter.
DeleteMany(Filter),
}
impl<T: Model> NestedWrite<T> {
    /// Converts every element into a [`Filter`], keeping the input order.
    fn collect_filters(filters: Vec<impl Into<Filter>>) -> Vec<Filter> {
        let mut converted = Vec::with_capacity(filters.len());
        for item in filters {
            converted.push(item.into());
        }
        converted
    }

    /// A `Create` write holding exactly one row.
    pub fn create(data: NestedCreateData<T>) -> Self {
        Self::create_many(vec![data])
    }

    /// A `Create` write holding every row in `data`, in order.
    pub fn create_many(data: Vec<NestedCreateData<T>>) -> Self {
        Self::Create(data)
    }

    /// A `Connect` write with a single filter.
    pub fn connect_one(filter: impl Into<Filter>) -> Self {
        Self::connect(vec![filter])
    }

    /// A `Connect` write with one filter per element of `filters`.
    pub fn connect(filters: Vec<impl Into<Filter>>) -> Self {
        Self::Connect(Self::collect_filters(filters))
    }

    /// A `Disconnect` write with a single filter.
    pub fn disconnect_one(filter: impl Into<Filter>) -> Self {
        Self::disconnect(vec![filter])
    }

    /// A `Disconnect` write with one filter per element of `filters`.
    pub fn disconnect(filters: Vec<impl Into<Filter>>) -> Self {
        Self::Disconnect(Self::collect_filters(filters))
    }

    /// A `Set` write with one filter per element of `filters`.
    pub fn set(filters: Vec<impl Into<Filter>>) -> Self {
        Self::Set(Self::collect_filters(filters))
    }

    /// A `Delete` write with one filter per element of `filters`.
    pub fn delete(filters: Vec<impl Into<Filter>>) -> Self {
        Self::Delete(Self::collect_filters(filters))
    }

    /// A `DeleteMany` write for the given filter.
    pub fn delete_many(filter: impl Into<Filter>) -> Self {
        let converted: Filter = filter.into();
        Self::DeleteMany(converted)
    }
}
/// Column/value pairs describing one related row to create in a nested write.
///
/// `T` is only a type-level marker for the related model (stored as `PhantomData`).
#[derive(Debug, Clone)]
pub struct NestedCreateData<T: Model> {
    /// Column name / value pairs, in the order they were supplied.
    pub data: Vec<(String, FilterValue)>,
    _model: PhantomData<T>,
}

impl<T: Model> NestedCreateData<T> {
    /// Wraps already-converted column/value pairs.
    pub fn new(data: Vec<(String, FilterValue)>) -> Self {
        let _model = PhantomData;
        Self { data, _model }
    }

    /// Builds the pairs from anything convertible into a column name and a [`FilterValue`],
    /// preserving iteration order.
    pub fn from_pairs(
        pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
    ) -> Self {
        let mut converted: Vec<(String, FilterValue)> = Vec::new();
        for (column, value) in pairs {
            converted.push((column.into(), value.into()));
        }
        Self::new(converted)
    }
}

impl<T: Model> Default for NestedCreateData<T> {
    /// An entry with no column/value pairs.
    fn default() -> Self {
        Self::new(vec![])
    }
}
/// A "connect or create" entry: a filter selecting an existing related row, paired with the
/// data for a new row (presumably used when the filter matches nothing — the execution logic
/// lives outside this type).
#[derive(Debug, Clone)]
pub struct NestedCreateOrConnectData<T: Model> {
    /// Selects the existing row to connect.
    pub filter: Filter,
    /// Column/value pairs for the row to create.
    pub create: NestedCreateData<T>,
}

impl<T: Model> NestedCreateOrConnectData<T> {
    /// Pairs a lookup filter with the create data.
    pub fn new(filter: impl Into<Filter>, create: NestedCreateData<T>) -> Self {
        let filter: Filter = filter.into();
        Self { filter, create }
    }
}
/// An update entry: a filter selecting related rows plus the column/value pairs to write.
#[derive(Debug, Clone)]
pub struct NestedUpdateData<T: Model> {
    /// Selects the rows to update.
    pub filter: Filter,
    /// Column name / value pairs to assign, in the order supplied.
    pub data: Vec<(String, FilterValue)>,
    _model: PhantomData<T>,
}

impl<T: Model> NestedUpdateData<T> {
    /// Builds an entry from a filter and already-converted pairs.
    pub fn new(filter: impl Into<Filter>, data: Vec<(String, FilterValue)>) -> Self {
        let filter: Filter = filter.into();
        Self {
            filter,
            data,
            _model: PhantomData,
        }
    }

    /// Builds an entry, converting each `(column, value)` pair; order is preserved.
    pub fn from_pairs(
        filter: impl Into<Filter>,
        pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
    ) -> Self {
        let mut assignments: Vec<(String, FilterValue)> = Vec::new();
        for (column, value) in pairs {
            assignments.push((column.into(), value.into()));
        }
        Self::new(filter, assignments)
    }
}
/// An upsert entry: a filter, the data for a new row, and the column/value pairs for an
/// existing one (which of the two is applied is decided outside this type).
#[derive(Debug, Clone)]
pub struct NestedUpsertData<T: Model> {
    /// Selects the existing row.
    pub filter: Filter,
    /// Column/value pairs for the row to create.
    pub create: NestedCreateData<T>,
    /// Column/value pairs to assign on the existing row.
    pub update: Vec<(String, FilterValue)>,
    _model: PhantomData<T>,
}

impl<T: Model> NestedUpsertData<T> {
    /// Bundles the filter, create data and update pairs into one entry.
    pub fn new(
        filter: impl Into<Filter>,
        create: NestedCreateData<T>,
        update: Vec<(String, FilterValue)>,
    ) -> Self {
        let filter: Filter = filter.into();
        let _model = PhantomData;
        Self {
            filter,
            create,
            update,
            _model,
        }
    }
}
/// A bulk update: one set of column/value pairs applied to every row matching `filter`.
#[derive(Debug, Clone)]
pub struct NestedUpdateManyData<T: Model> {
    /// Selects the rows to update.
    pub filter: Filter,
    /// Column name / value pairs to assign.
    pub data: Vec<(String, FilterValue)>,
    _model: PhantomData<T>,
}

impl<T: Model> NestedUpdateManyData<T> {
    /// Builds the bulk update from a filter and already-converted pairs.
    pub fn new(filter: impl Into<Filter>, data: Vec<(String, FilterValue)>) -> Self {
        let filter: Filter = filter.into();
        let _model = PhantomData;
        Self {
            filter,
            data,
            _model,
        }
    }
}
/// Generates PostgreSQL statements (with `$n` placeholders) for nested writes on one relation:
/// either one-to-many (a foreign-key column on the related table) or many-to-many (link rows in
/// a join table). Each generated statement is returned as `(sql, params)`; nothing is executed.
#[derive(Debug)]
pub struct NestedWriteBuilder {
/// Parent table name; currently not read by the builder methods.
parent_table: String,
/// Parent primary-key column(s); currently not read by the builder methods.
parent_pk: Vec<String>,
/// Table whose rows are created, connected, disconnected or deleted.
related_table: String,
/// Column on `related_table` referencing the parent; left empty by `many_to_many`.
foreign_key: String,
/// `true` when built with `one_to_many`, `false` when built with `many_to_many`.
is_one_to_many: bool,
/// Join-table metadata; `Some` only when built with `many_to_many`.
join_table: Option<JoinTableInfo>,
}
/// Names of a many-to-many join table and its two link columns.
#[derive(Debug, Clone)]
pub struct JoinTableInfo {
/// Name of the join table.
pub table_name: String,
/// Join-table column that the builder binds to the parent id.
pub parent_column: String,
/// Join-table column that the builder fills with the related row's primary key.
pub related_column: String,
}
impl NestedWriteBuilder {
    /// Primary-key column assumed for the related table when no model type is available
    /// (many-to-many disconnect/delete) or when `T::PRIMARY_KEY` is empty.
    const DEFAULT_PK: &'static str = "id";

    /// Creates a builder for a one-to-many relation: rows of `related_table` point at their
    /// parent through the `foreign_key` column.
    pub fn one_to_many(
        parent_table: impl Into<String>,
        parent_pk: Vec<String>,
        related_table: impl Into<String>,
        foreign_key: impl Into<String>,
    ) -> Self {
        Self {
            parent_table: parent_table.into(),
            parent_pk,
            related_table: related_table.into(),
            foreign_key: foreign_key.into(),
            is_one_to_many: true,
            join_table: None,
        }
    }

    /// Creates a builder for a many-to-many relation: parent and related rows are linked by
    /// rows of `join_table`.
    pub fn many_to_many(
        parent_table: impl Into<String>,
        parent_pk: Vec<String>,
        related_table: impl Into<String>,
        join_table: JoinTableInfo,
    ) -> Self {
        Self {
            parent_table: parent_table.into(),
            parent_pk,
            related_table: related_table.into(),
            // The related table carries no foreign key in this shape; every method must go
            // through `join_table` instead of interpolating this empty name into SQL.
            foreign_key: String::new(),
            is_one_to_many: false,
            join_table: Some(join_table),
        }
    }

    /// Returns the join-table metadata when this builder describes a many-to-many relation.
    fn many_to_many_join(&self) -> Option<&JoinTableInfo> {
        if self.is_one_to_many {
            None
        } else {
            self.join_table.as_ref()
        }
    }

    /// Renders `INSERT INTO <related_table> (<columns>) VALUES ($1, ..., $n)` with one
    /// placeholder per column, or `INSERT INTO <related_table> DEFAULT VALUES` when `columns` is
    /// empty (an empty column list is a syntax error in PostgreSQL).
    fn insert_into_related(&self, columns: &[&str]) -> String {
        let table = quote_identifier(&self.related_table);
        if columns.is_empty() {
            return format!("INSERT INTO {} DEFAULT VALUES", table);
        }
        let column_list = columns
            .iter()
            .map(|c| quote_identifier(c))
            .collect::<Vec<_>>()
            .join(", ");
        let placeholders = (1..=columns.len())
            .map(|i| format!("${}", i))
            .collect::<Vec<_>>()
            .join(", ");
        format!(
            "INSERT INTO {} ({}) VALUES ({})",
            table, column_list, placeholders
        )
    }

    /// Builds statements attaching existing related rows (selected by `filters`) to the parent
    /// identified by `parent_id`; one statement per filter.
    ///
    /// * one-to-many: `UPDATE related SET fk = $n WHERE <filter>`
    /// * many-to-many: `INSERT INTO join (...) SELECT $n, pk FROM related WHERE <filter>
    ///   ON CONFLICT DO NOTHING` (duplicates are only skipped if the join table has a unique
    ///   or primary-key constraint on the link columns).
    ///
    /// Each statement binds the filter's parameters first and `parent_id` last. `T` supplies
    /// the related primary key: the first column of `T::PRIMARY_KEY`, falling back to `id`.
    pub fn build_connect_sql<T: Model>(
        &self,
        parent_id: &FilterValue,
        filters: &[Filter],
    ) -> Vec<(String, Vec<FilterValue>)> {
        let mut statements = Vec::with_capacity(filters.len());
        if self.is_one_to_many {
            for filter in filters {
                let (where_sql, mut params) = filter.to_sql(1, &crate::dialect::Postgres);
                let sql = format!(
                    "UPDATE {} SET {} = ${} WHERE {}",
                    quote_identifier(&self.related_table),
                    quote_identifier(&self.foreign_key),
                    params.len() + 1,
                    where_sql
                );
                params.push(parent_id.clone());
                statements.push((sql, params));
            }
        } else if let Some(join) = &self.join_table {
            // Loop-invariant: quote the related primary key once.
            let related_pk =
                quote_identifier(T::PRIMARY_KEY.first().copied().unwrap_or(Self::DEFAULT_PK));
            for filter in filters {
                let (where_sql, mut params) = filter.to_sql(1, &crate::dialect::Postgres);
                let sql = format!(
                    "INSERT INTO {} ({}, {}) SELECT ${}, {} FROM {} WHERE {} ON CONFLICT DO NOTHING",
                    quote_identifier(&join.table_name),
                    quote_identifier(&join.parent_column),
                    quote_identifier(&join.related_column),
                    params.len() + 1,
                    related_pk,
                    quote_identifier(&self.related_table),
                    where_sql
                );
                params.push(parent_id.clone());
                statements.push((sql, params));
            }
        }
        statements
    }

    /// Builds statements detaching related rows (selected by `filters`) from the parent
    /// identified by `parent_id`; one statement per filter.
    ///
    /// * one-to-many: sets the foreign key to `NULL`, only on rows currently pointing at
    ///   `parent_id` (filter parameters first, `parent_id` last).
    /// * many-to-many: deletes matching join rows (`parent_id` is `$1`, filter parameters
    ///   follow). The related table's key is assumed to be `id`, since no model type is given.
    pub fn build_disconnect_sql(
        &self,
        parent_id: &FilterValue,
        filters: &[Filter],
    ) -> Vec<(String, Vec<FilterValue>)> {
        let mut statements = Vec::with_capacity(filters.len());
        if self.is_one_to_many {
            for filter in filters {
                let (where_sql, mut params) = filter.to_sql(1, &crate::dialect::Postgres);
                // Parenthesize the filter so a top-level `OR` cannot escape the parent scoping.
                let sql = format!(
                    "UPDATE {} SET {} = NULL WHERE ({}) AND {} = ${}",
                    quote_identifier(&self.related_table),
                    quote_identifier(&self.foreign_key),
                    where_sql,
                    quote_identifier(&self.foreign_key),
                    params.len() + 1
                );
                params.push(parent_id.clone());
                statements.push((sql, params));
            }
        } else if let Some(join) = &self.join_table {
            for filter in filters {
                // `$1` is reserved for the parent id, so the filter's placeholders start at `$2`.
                let (where_sql, filter_params) = filter.to_sql(2, &crate::dialect::Postgres);
                let sql = format!(
                    "DELETE FROM {} WHERE {} = $1 AND {} IN (SELECT {} FROM {} WHERE {})",
                    quote_identifier(&join.table_name),
                    quote_identifier(&join.parent_column),
                    quote_identifier(&join.related_column),
                    quote_identifier(Self::DEFAULT_PK),
                    quote_identifier(&self.related_table),
                    where_sql
                );
                let mut params = Vec::with_capacity(filter_params.len() + 1);
                params.push(parent_id.clone());
                params.extend(filter_params);
                statements.push((sql, params));
            }
        }
        statements
    }

    /// Builds statements that replace the parent's related set with the rows selected by
    /// `filters`: first every existing link of `parent_id` is removed (foreign key set to
    /// `NULL`, or join rows deleted), then the statements of [`Self::build_connect_sql`]
    /// follow.
    pub fn build_set_sql<T: Model>(
        &self,
        parent_id: &FilterValue,
        filters: &[Filter],
    ) -> Vec<(String, Vec<FilterValue>)> {
        let mut statements = Vec::with_capacity(filters.len() + 1);
        if self.is_one_to_many {
            let sql = format!(
                "UPDATE {} SET {} = NULL WHERE {} = $1",
                quote_identifier(&self.related_table),
                quote_identifier(&self.foreign_key),
                quote_identifier(&self.foreign_key)
            );
            statements.push((sql, vec![parent_id.clone()]));
        } else if let Some(join) = &self.join_table {
            let sql = format!(
                "DELETE FROM {} WHERE {} = $1",
                quote_identifier(&join.table_name),
                quote_identifier(&join.parent_column)
            );
            statements.push((sql, vec![parent_id.clone()]));
        }
        statements.extend(self.build_connect_sql::<T>(parent_id, filters));
        statements
    }

    /// Builds one statement per entry of `creates`, each returning the inserted related row.
    ///
    /// * one-to-many: `INSERT INTO related (<cols>, fk) VALUES (...) RETURNING *`, with the
    ///   foreign key bound to `parent_id` as the last parameter.
    /// * many-to-many: a data-modifying CTE inserts the related row and a join row linking it
    ///   to `parent_id` (last parameter), then selects the inserted row. `T` supplies the
    ///   related primary key (first column of `T::PRIMARY_KEY`, falling back to `id`).
    pub fn build_create_sql<T: Model>(
        &self,
        parent_id: &FilterValue,
        creates: &[NestedCreateData<T>],
    ) -> Vec<(String, Vec<FilterValue>)> {
        let join = self.many_to_many_join();
        let related_pk =
            quote_identifier(T::PRIMARY_KEY.first().copied().unwrap_or(Self::DEFAULT_PK));
        let mut statements = Vec::with_capacity(creates.len());
        for create in creates {
            let mut columns: Vec<&str> = create.data.iter().map(|(k, _)| k.as_str()).collect();
            let mut values: Vec<FilterValue> =
                create.data.iter().map(|(_, v)| v.clone()).collect();
            let sql = match join {
                None => {
                    columns.push(self.foreign_key.as_str());
                    values.push(parent_id.clone());
                    format!("{} RETURNING *", self.insert_into_related(&columns))
                }
                Some(join) => {
                    // The related table has no foreign key here: link through the join table
                    // in the same statement, binding the parent id after the row's values.
                    values.push(parent_id.clone());
                    format!(
                        "WITH nested_inserted AS ({} RETURNING *), \
                         nested_linked AS (INSERT INTO {} ({}, {}) SELECT ${}, {} FROM nested_inserted) \
                         SELECT * FROM nested_inserted",
                        self.insert_into_related(&columns),
                        quote_identifier(&join.table_name),
                        quote_identifier(&join.parent_column),
                        quote_identifier(&join.related_column),
                        values.len(),
                        related_pk
                    )
                }
            };
            statements.push((sql, values));
        }
        statements
    }

    /// Builds statements deleting related rows selected by `filters`, restricted to rows that
    /// belong to `parent_id`; one statement per filter (filter parameters first, `parent_id`
    /// last).
    ///
    /// * one-to-many: ownership is `fk = parent_id`.
    /// * many-to-many: ownership is membership in the parent's join rows; the related table's
    ///   key is assumed to be `id`, since no model type is given. Whether the join rows
    ///   themselves are removed depends on the schema's foreign-key actions (e.g. cascade).
    pub fn build_delete_sql(
        &self,
        parent_id: &FilterValue,
        filters: &[Filter],
    ) -> Vec<(String, Vec<FilterValue>)> {
        let join = self.many_to_many_join();
        let mut statements = Vec::with_capacity(filters.len());
        for filter in filters {
            let (where_sql, mut params) = filter.to_sql(1, &crate::dialect::Postgres);
            let parent_placeholder = params.len() + 1;
            let ownership = match join {
                None => format!(
                    "{} = ${}",
                    quote_identifier(&self.foreign_key),
                    parent_placeholder
                ),
                Some(join) => format!(
                    "{} IN (SELECT {} FROM {} WHERE {} = ${})",
                    quote_identifier(Self::DEFAULT_PK),
                    quote_identifier(&join.related_column),
                    quote_identifier(&join.table_name),
                    quote_identifier(&join.parent_column),
                    parent_placeholder
                ),
            };
            // Parenthesize the filter so a top-level `OR` cannot escape the parent scoping.
            let sql = format!(
                "DELETE FROM {} WHERE ({}) AND {}",
                quote_identifier(&self.related_table),
                where_sql,
                ownership
            );
            params.push(parent_id.clone());
            statements.push((sql, params));
        }
        statements
    }
}
/// Two ordered lists of `(sql, params)` statements collected from nested writes.
///
/// This type only stores statements; it does not execute them. Presumably callers run
/// `pre_statements` before and `post_statements` after the parent write — confirm at the
/// call site.
#[derive(Debug, Default)]
pub struct NestedWriteOperations {
    /// Statements queued with [`NestedWriteOperations::add_pre`], in insertion order.
    pub pre_statements: Vec<(String, Vec<FilterValue>)>,
    /// Statements queued with [`NestedWriteOperations::add_post`], in insertion order.
    pub post_statements: Vec<(String, Vec<FilterValue>)>,
}

impl NestedWriteOperations {
    /// An empty collection (same as `Default`).
    pub fn new() -> Self {
        Self {
            pre_statements: Vec::new(),
            post_statements: Vec::new(),
        }
    }

    /// Appends a statement to `pre_statements`.
    pub fn add_pre(&mut self, sql: String, params: Vec<FilterValue>) {
        let entry = (sql, params);
        self.pre_statements.push(entry);
    }

    /// Appends a statement to `post_statements`.
    pub fn add_post(&mut self, sql: String, params: Vec<FilterValue>) {
        let entry = (sql, params);
        self.post_statements.push(entry);
    }

    /// Moves all of `other`'s statements onto the end of the matching lists of `self`.
    pub fn extend(&mut self, other: Self) {
        let Self {
            pre_statements,
            post_statements,
        } = other;
        self.pre_statements.extend(pre_statements);
        self.post_statements.extend(post_statements);
    }

    /// `true` when neither list holds a statement.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Total number of queued statements across both lists.
    pub fn len(&self) -> usize {
        [&self.pre_statements, &self.post_statements]
            .iter()
            .map(|list| list.len())
            .sum()
    }
}
/// A nested write that runs directly against a `QueryEngine` via `NestedWriteOp::execute`.
#[derive(Debug, Clone)]
pub enum NestedWriteOp {
/// Inserts one row of `target_table` per `payload` entry, with `foreign_key` set to the parent key.
Create {
/// Relation name; not read by `execute`.
relation: String,
/// Table the child rows are inserted into.
target_table: String,
/// Column on `target_table` that receives the parent's primary key.
foreign_key: String,
/// One list of column/value pairs per child row.
payload: Vec<Vec<(String, FilterValue)>>,
},
/// Links an existing row (primary key `pk`) to the parent; `execute` currently returns an error for it.
Connect {
relation: String,
pk: FilterValue,
},
}
impl NestedWriteOp {
    /// Runs this nested write against `engine` for the parent row whose primary key is
    /// `parent_pk`.
    ///
    /// `Create` issues one `INSERT` per `payload` entry, appending `foreign_key = parent_pk` to
    /// each row; identifiers and placeholders are rendered by `engine.dialect()`. The inserts
    /// are awaited one at a time and are not wrapped in a transaction by this method.
    ///
    /// # Errors
    ///
    /// * `Connect` always fails with an internal error (not implemented yet).
    /// * `Create` returns the first error from `engine.execute_raw`; rows inserted before that
    ///   point are not rolled back here.
    pub async fn execute<E>(self, engine: &E, parent_pk: &FilterValue) -> QueryResult<()>
    where
        E: QueryEngine,
    {
        let (target_table, foreign_key, payload) = match self {
            NestedWriteOp::Connect { .. } => {
                return Err(QueryError::internal(
                    "nested Connect is not yet implemented (needs child-PK column metadata)",
                ));
            }
            NestedWriteOp::Create {
                target_table,
                foreign_key,
                payload,
                ..
            } => (target_table, foreign_key, payload),
        };

        let dialect = engine.dialect();
        for row in payload {
            // Split the row into parallel column / value lists, then append the parent link.
            let (mut cols, mut binds): (Vec<String>, Vec<FilterValue>) =
                row.into_iter().unzip();
            cols.push(foreign_key.clone());
            binds.push(parent_pk.clone());

            let marks = (1..=binds.len())
                .map(|n| dialect.placeholder(n))
                .collect::<Vec<String>>();
            let col_list = cols
                .iter()
                .map(|name| dialect.quote_ident(name))
                .collect::<Vec<String>>();
            let statement = format!(
                "INSERT INTO {} ({}) VALUES ({})",
                dialect.quote_ident(&target_table),
                col_list.join(", "),
                marks.join(", "),
            );
            engine.execute_raw(&statement, binds).await?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    struct TestModel;
    impl Model for TestModel {
        const MODEL_NAME: &'static str = "Post";
        const TABLE_NAME: &'static str = "posts";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "title", "user_id"];
    }

    struct TagModel;
    impl Model for TagModel {
        const MODEL_NAME: &'static str = "Tag";
        const TABLE_NAME: &'static str = "tags";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name"];
    }

    /// One-to-many `users` -> `posts` relation keyed by `posts.user_id`.
    fn user_posts() -> NestedWriteBuilder {
        NestedWriteBuilder::one_to_many("users", vec!["id".to_string()], "posts", "user_id")
    }

    /// Many-to-many `posts` <-> `tags` relation through `post_tags`.
    fn post_tags() -> NestedWriteBuilder {
        NestedWriteBuilder::many_to_many(
            "posts",
            vec!["id".to_string()],
            "tags",
            JoinTableInfo {
                table_name: "post_tags".to_string(),
                parent_column: "post_id".to_string(),
                related_column: "tag_id".to_string(),
            },
        )
    }

    #[test]
    fn test_nested_create_data() {
        let data: NestedCreateData<TestModel> =
            NestedCreateData::from_pairs([("title", FilterValue::String("Test Post".to_string()))]);
        assert_eq!(data.data.len(), 1);
        assert_eq!(data.data[0].0, "title");
    }

    #[test]
    fn test_nested_write_create() {
        let data: NestedCreateData<TestModel> =
            NestedCreateData::from_pairs([("title", FilterValue::String("Test Post".to_string()))]);
        let write: NestedWrite<TestModel> = NestedWrite::create(data);
        match write {
            NestedWrite::Create(creates) => assert_eq!(creates.len(), 1),
            _ => panic!("Expected Create variant"),
        }
    }

    #[test]
    fn test_nested_write_create_many_and_delete_many() {
        let creates: Vec<NestedCreateData<TestModel>> =
            vec![NestedCreateData::default(), NestedCreateData::default()];
        match NestedWrite::create_many(creates) {
            NestedWrite::Create(creates) => assert_eq!(creates.len(), 2),
            _ => panic!("Expected Create variant"),
        }
        let write: NestedWrite<TestModel> =
            NestedWrite::delete_many(Filter::Equals("id".into(), FilterValue::Int(1)));
        assert!(matches!(write, NestedWrite::DeleteMany(Filter::Equals(..))));
    }

    #[test]
    fn test_nested_write_connect() {
        let write: NestedWrite<TestModel> = NestedWrite::connect(vec![
            Filter::Equals("id".into(), FilterValue::Int(1)),
            Filter::Equals("id".into(), FilterValue::Int(2)),
        ]);
        match write {
            NestedWrite::Connect(filters) => assert_eq!(filters.len(), 2),
            _ => panic!("Expected Connect variant"),
        }
    }

    #[test]
    fn test_nested_write_disconnect() {
        let write: NestedWrite<TestModel> =
            NestedWrite::disconnect_one(Filter::Equals("id".into(), FilterValue::Int(1)));
        match write {
            NestedWrite::Disconnect(filters) => assert_eq!(filters.len(), 1),
            _ => panic!("Expected Disconnect variant"),
        }
    }

    #[test]
    fn test_nested_write_set() {
        let write: NestedWrite<TestModel> =
            NestedWrite::set(vec![Filter::Equals("id".into(), FilterValue::Int(1))]);
        match write {
            NestedWrite::Set(filters) => assert_eq!(filters.len(), 1),
            _ => panic!("Expected Set variant"),
        }
    }

    #[test]
    fn test_builder_one_to_many_connect() {
        let builder = user_posts();
        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];
        let statements = builder.build_connect_sql::<TestModel>(&parent_id, &filters);
        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.contains("UPDATE"));
        assert!(sql.contains("posts"));
        assert!(sql.contains("user_id"));
        // The parent id is bound after the filter's single parameter.
        assert!(sql.contains("$2"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_one_to_many_disconnect() {
        let builder = user_posts();
        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];
        let statements = builder.build_disconnect_sql(&parent_id, &filters);
        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.contains("UPDATE"));
        assert!(sql.contains("SET"));
        assert!(sql.contains("NULL"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_many_to_many_connect() {
        let builder = post_tags();
        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];
        let statements = builder.build_connect_sql::<TagModel>(&parent_id, &filters);
        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.contains("INSERT INTO"));
        assert!(sql.contains("post_tags"));
        assert!(sql.contains("ON CONFLICT DO NOTHING"));
        assert!(sql.contains("$2"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_many_to_many_disconnect() {
        let builder = post_tags();
        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];
        let statements = builder.build_disconnect_sql(&parent_id, &filters);
        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.contains("DELETE FROM"));
        assert!(sql.contains("post_tags"));
        // The parent id is always `$1` here; filter parameters follow it.
        assert!(sql.contains("$1"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_create() {
        let builder = user_posts();
        let parent_id = FilterValue::Int(1);
        let creates = vec![NestedCreateData::<TestModel>::from_pairs([(
            "title",
            FilterValue::String("New Post".to_string()),
        )])];
        let statements = builder.build_create_sql::<TestModel>(&parent_id, &creates);
        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.contains("INSERT INTO"));
        assert!(sql.contains("posts"));
        assert!(sql.contains("RETURNING"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_create_multiple() {
        let builder = user_posts();
        let parent_id = FilterValue::Int(1);
        let creates = vec![
            NestedCreateData::<TestModel>::from_pairs([(
                "title",
                FilterValue::String("First".to_string()),
            )]),
            NestedCreateData::<TestModel>::from_pairs([(
                "title",
                FilterValue::String("Second".to_string()),
            )]),
        ];
        let statements = builder.build_create_sql::<TestModel>(&parent_id, &creates);
        assert_eq!(statements.len(), 2);
        for (sql, params) in &statements {
            // Placeholders restart at `$1` for every statement.
            assert!(sql.contains("$1"));
            assert!(sql.contains("$2"));
            assert_eq!(params.len(), 2);
        }
    }

    #[test]
    fn test_builder_set() {
        let builder = user_posts();
        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];
        let statements = builder.build_set_sql::<TestModel>(&parent_id, &filters);
        // One reset statement plus one connect statement per filter.
        assert_eq!(statements.len(), 2);
        let (first_sql, first_params) = &statements[0];
        assert!(first_sql.contains("UPDATE"));
        assert!(first_sql.contains("NULL"));
        assert_eq!(first_params.len(), 1);
        let (second_sql, second_params) = &statements[1];
        assert!(second_sql.contains("UPDATE"));
        assert!(second_sql.contains("user_id"));
        assert_eq!(second_params.len(), 2);
    }

    #[test]
    fn test_builder_many_to_many_set() {
        let builder = post_tags();
        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];
        let statements = builder.build_set_sql::<TagModel>(&parent_id, &filters);
        assert_eq!(statements.len(), 2);
        let (first_sql, first_params) = &statements[0];
        assert!(first_sql.contains("DELETE FROM"));
        assert!(first_sql.contains("post_tags"));
        assert_eq!(first_params.len(), 1);
        let (second_sql, _) = &statements[1];
        assert!(second_sql.contains("INSERT INTO"));
    }

    #[test]
    fn test_builder_one_to_many_delete() {
        let builder = user_posts();
        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];
        let statements = builder.build_delete_sql(&parent_id, &filters);
        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.contains("DELETE FROM"));
        assert!(sql.contains("posts"));
        assert!(sql.contains("user_id"));
        assert!(sql.contains("$2"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_nested_write_operations() {
        let mut ops = NestedWriteOperations::new();
        assert!(ops.is_empty());
        assert_eq!(ops.len(), 0);
        ops.add_pre("SELECT 1".to_string(), vec![]);
        ops.add_post("SELECT 2".to_string(), vec![]);
        assert!(!ops.is_empty());
        assert_eq!(ops.len(), 2);

        let mut other = NestedWriteOperations::new();
        other.add_post("SELECT 3".to_string(), vec![]);
        ops.extend(other);
        assert_eq!(ops.pre_statements.len(), 1);
        assert_eq!(ops.post_statements.len(), 2);
        assert_eq!(ops.post_statements[1].0, "SELECT 3");
    }

    #[test]
    fn test_nested_create_or_connect() {
        let create_data: NestedCreateData<TestModel> =
            NestedCreateData::from_pairs([("title", FilterValue::String("New Post".to_string()))]);
        let create_or_connect = NestedCreateOrConnectData::new(
            Filter::Equals("title".into(), FilterValue::String("Existing".to_string())),
            create_data,
        );
        assert!(matches!(create_or_connect.filter, Filter::Equals(..)));
        assert_eq!(create_or_connect.create.data.len(), 1);
    }

    #[test]
    fn test_nested_update_data() {
        let update: NestedUpdateData<TestModel> = NestedUpdateData::from_pairs(
            Filter::Equals("id".into(), FilterValue::Int(1)),
            [("title", FilterValue::String("Updated".to_string()))],
        );
        assert!(matches!(update.filter, Filter::Equals(..)));
        assert_eq!(update.data.len(), 1);
        assert_eq!(update.data[0].0, "title");
    }

    #[test]
    fn test_nested_upsert_data() {
        let create: NestedCreateData<TestModel> =
            NestedCreateData::from_pairs([("title", FilterValue::String("New".to_string()))]);
        let upsert: NestedUpsertData<TestModel> = NestedUpsertData::new(
            Filter::Equals("id".into(), FilterValue::Int(1)),
            create,
            vec![(
                "title".to_string(),
                FilterValue::String("Updated".to_string()),
            )],
        );
        assert!(matches!(upsert.filter, Filter::Equals(..)));
        assert_eq!(upsert.create.data.len(), 1);
        assert_eq!(upsert.update.len(), 1);
    }
}