use std::marker::PhantomData;
use crate::error::QueryResult;
use crate::filter::FilterValue;
use crate::traits::{Model, QueryEngine};
use crate::types::Select;
/// Builder for a single-row `INSERT … RETURNING …` into `M`'s table.
///
/// Populate it with [`set`](Self::set) / [`set_many`](Self::set_many), optionally narrow the
/// returned columns with [`select`](Self::select), then await [`exec`](Self::exec).
#[must_use = "an insert builder does nothing until `.exec()` is awaited"]
pub struct CreateOperation<E: QueryEngine, M: Model> {
    /// Engine the statement is sent to by `exec`.
    engine: E,
    /// Target column names, in insertion order; written into the SQL verbatim.
    columns: Vec<String>,
    /// Bound values; `values[i]` is the value for `columns[i]` (both are pushed together).
    values: Vec<FilterValue>,
    /// Columns listed in the `RETURNING` clause.
    select: Select,
    _model: PhantomData<M>,
}
impl<E: QueryEngine, M: Model> CreateOperation<E, M> {
    /// Creates an empty insert for `M` that will run on `engine`.
    ///
    /// By default the whole inserted row is returned (`Select::All`).
    pub fn new(engine: E) -> Self {
        Self {
            engine,
            columns: Vec::new(),
            values: Vec::new(),
            select: Select::All,
            _model: PhantomData,
        }
    }

    /// Adds `column = value` to the row being inserted.
    ///
    /// Only the value is bound as a parameter; the column name is written into the SQL
    /// verbatim, so it must never come from untrusted input.
    pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
        self.columns.push(column.into());
        self.values.push(value.into());
        self
    }

    /// Adds several `(column, value)` pairs in iteration order.
    ///
    /// Equivalent to calling [`set`](Self::set) once per pair.
    pub fn set_many(
        mut self,
        values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
    ) -> Self {
        for (col, val) in values {
            self.columns.push(col.into());
            self.values.push(val.into());
        }
        self
    }

    /// Chooses which columns the `RETURNING` clause yields.
    pub fn select(mut self, select: impl Into<Select>) -> Self {
        self.select = select.into();
        self
    }

    /// Renders the statement and returns it together with its bound parameters.
    ///
    /// The result is `INSERT INTO <table> (<cols>) VALUES ($1, …, $n) RETURNING <select>`.
    /// When no column has been set, the result is
    /// `INSERT INTO <table> DEFAULT VALUES RETURNING <select>`, and the parameter list is empty.
    pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
        (self.render_sql(), self.values.clone())
    }

    /// Runs the insert through `QueryEngine::execute_insert` and returns the row it decodes.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the engine returns.
    pub async fn exec(self) -> QueryResult<M>
    where
        M: Send + 'static,
    {
        let sql = self.render_sql();
        // `self` is owned here, so the parameters are moved instead of cloned.
        self.engine.execute_insert::<M>(&sql, self.values).await
    }

    /// Renders only the SQL text; shared by `build_sql` and `exec`.
    fn render_sql(&self) -> String {
        let mut sql = String::from("INSERT INTO ");
        sql.push_str(M::TABLE_NAME);
        if self.columns.is_empty() {
            // `() VALUES ()` is a syntax error in PostgreSQL. DEFAULT VALUES is the standard
            // way to insert a row built entirely from column defaults.
            sql.push_str(" DEFAULT VALUES");
        } else {
            sql.push_str(" (");
            sql.push_str(&self.columns.join(", "));
            sql.push_str(") VALUES (");
            let placeholders: Vec<String> =
                (1..=self.values.len()).map(|i| format!("${}", i)).collect();
            sql.push_str(&placeholders.join(", "));
            sql.push(')');
        }
        sql.push_str(" RETURNING ");
        sql.push_str(&self.select.to_sql());
        sql
    }
}
/// Builder for a multi-row `INSERT INTO … VALUES (…), (…)` into `M`'s table.
///
/// Set the target columns with [`columns`](Self::columns) and add rows with
/// [`row`](Self::row) / [`rows`](Self::rows), then await [`exec`](Self::exec).
#[must_use = "a bulk insert builder does nothing until `.exec()` is awaited"]
pub struct CreateManyOperation<E: QueryEngine, M: Model> {
    /// Engine the statement is sent to by `exec`.
    engine: E,
    /// Target column names; written into the SQL verbatim.
    columns: Vec<String>,
    /// One entry per row; each row's values are positional against `columns`.
    /// The builder does not enforce this: row lengths are not checked against `columns`.
    rows: Vec<Vec<FilterValue>>,
    /// When set, `ON CONFLICT DO NOTHING` is appended to the statement.
    skip_duplicates: bool,
    _model: PhantomData<M>,
}
impl<E: QueryEngine, M: Model> CreateManyOperation<E, M> {
    /// Creates an empty bulk insert for `M` that will run on `engine`.
    pub fn new(engine: E) -> Self {
        Self {
            engine,
            columns: Vec::new(),
            rows: Vec::new(),
            skip_duplicates: false,
            _model: PhantomData,
        }
    }

    /// Sets the target column list, replacing any previously set columns.
    ///
    /// Column names are written into the SQL verbatim (only values are bound as
    /// parameters), so they must never come from untrusted input.
    pub fn columns(mut self, columns: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.columns = columns.into_iter().map(Into::into).collect();
        self
    }

    /// Appends one row of values.
    ///
    /// The values are positional: the `i`-th value belongs to the `i`-th column passed to
    /// [`columns`](Self::columns). Row length is not checked here, so a mismatched row
    /// produces SQL that the database has to reject.
    pub fn row(mut self, values: impl IntoIterator<Item = impl Into<FilterValue>>) -> Self {
        self.rows.push(values.into_iter().map(Into::into).collect());
        self
    }

    /// Appends several rows; equivalent to calling [`row`](Self::row) once per row.
    pub fn rows(
        mut self,
        rows: impl IntoIterator<Item = impl IntoIterator<Item = impl Into<FilterValue>>>,
    ) -> Self {
        self.rows
            .extend(rows.into_iter().map(|row| row.into_iter().map(Into::into).collect()));
        self
    }

    /// Appends `ON CONFLICT DO NOTHING` to the statement. In PostgreSQL, rows that would
    /// violate a uniqueness constraint are then skipped instead of raising an error.
    pub fn skip_duplicates(mut self) -> Self {
        self.skip_duplicates = true;
        self
    }

    /// Renders the statement and returns it together with its bound parameters.
    ///
    /// The result is `INSERT INTO <table> (<cols>) VALUES ($1, …), ($k, …)`, with
    /// ` ON CONFLICT DO NOTHING` appended when `skip_duplicates` was called. Placeholders
    /// are numbered continuously across rows, and the parameters are the rows flattened
    /// in row-major order.
    ///
    /// With no rows the `VALUES` list is empty and the SQL is not executable.
    /// [`exec`](Self::exec) never sends it.
    pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
        let params = self.rows.iter().flatten().cloned().collect();
        (self.render_sql(), params)
    }

    /// Runs the insert through `QueryEngine::execute_raw` and returns the count it reports.
    /// That count is presumably the number of inserted rows; confirm against the engine.
    ///
    /// If no rows were added, returns `Ok(0)` without contacting the engine.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the engine returns.
    pub async fn exec(self) -> QueryResult<u64> {
        if self.rows.is_empty() {
            // `INSERT … VALUES ` with no tuples is invalid SQL; inserting nothing touches 0 rows.
            return Ok(0);
        }
        let sql = self.render_sql();
        // `self` is owned here, so the row values are moved instead of cloned.
        let params: Vec<FilterValue> = self.rows.into_iter().flatten().collect();
        self.engine.execute_raw(&sql, params).await
    }

    /// Renders only the SQL text; shared by `build_sql` and `exec`.
    fn render_sql(&self) -> String {
        let mut sql = String::from("INSERT INTO ");
        sql.push_str(M::TABLE_NAME);
        sql.push_str(" (");
        sql.push_str(&self.columns.join(", "));
        sql.push_str(") VALUES ");

        let mut next_param = 1;
        let mut groups = Vec::with_capacity(self.rows.len());
        for row in &self.rows {
            // Each row takes the next `row.len()` placeholder numbers: ($1, $2), ($3, $4), …
            let end = next_param + row.len();
            let placeholders: Vec<String> =
                (next_param..end).map(|i| format!("${}", i)).collect();
            groups.push(format!("({})", placeholders.join(", ")));
            next_param = end;
        }
        sql.push_str(&groups.join(", "));

        if self.skip_duplicates {
            sql.push_str(" ON CONFLICT DO NOTHING");
        }
        sql
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the insert builders. SQL generation is checked through `build_sql`.
    // `exec` is driven through `MockEngine`, which never talks to a database.
    use super::*;
    use crate::error::QueryError;

    /// Minimal model; only `TABLE_NAME` appears in the generated SQL.
    struct TestModel;
    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    /// Engine stub with these behaviours:
    /// - `execute_insert` always fails with `not_found`.
    /// - `execute_raw` reports `insert_count`.
    /// - `query_one` fails with `not_found`.
    /// - Every other method returns an empty or zero result.
    #[derive(Clone)]
    struct MockEngine {
        insert_count: u64,
    }
    impl MockEngine {
        fn new() -> Self {
            Self { insert_count: 0 }
        }
        fn with_count(count: u64) -> Self {
            Self {
                insert_count: count,
            }
        }
    }
    impl QueryEngine for MockEngine {
        fn query_many<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }
        fn query_one<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }
        fn query_optional<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }
        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }
        fn execute_update<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }
        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            let count = self.insert_count;
            Box::pin(async move { Ok(count) })
        }
        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }

    #[test]
    fn test_create_new() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();
        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("RETURNING *"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_create_basic() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Alice")
            .set("email", "alice@example.com");
        let (sql, params) = op.build_sql();
        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("(name, email)"));
        assert!(sql.contains("VALUES ($1, $2)"));
        assert!(sql.contains("RETURNING *"));
        assert_eq!(params.len(), 2);
        // Parameters follow the order of the `set` calls.
        assert_eq!(params[0], FilterValue::String("Alice".to_string()));
        assert_eq!(params[1], FilterValue::String("alice@example.com".to_string()));
    }

    #[test]
    fn test_create_single_field() {
        let op =
            CreateOperation::<MockEngine, TestModel>::new(MockEngine::new()).set("name", "Alice");
        let (sql, params) = op.build_sql();
        assert!(sql.contains("(name)"));
        assert!(sql.contains("VALUES ($1)"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_create_with_set_many() {
        let values = vec![
            ("name", FilterValue::String("Bob".to_string())),
            ("email", FilterValue::String("bob@test.com".to_string())),
            ("age", FilterValue::Int(25)),
        ];
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new()).set_many(values);
        let (sql, params) = op.build_sql();
        assert!(sql.contains("(name, email, age)"));
        assert!(sql.contains("VALUES ($1, $2, $3)"));
        assert_eq!(params.len(), 3);
        assert_eq!(params[2], FilterValue::Int(25));
    }

    #[test]
    fn test_create_with_select() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Alice")
            .select(Select::fields(["id", "name"]));
        let (sql, _) = op.build_sql();
        assert!(sql.contains("RETURNING id, name"));
        assert!(!sql.contains("RETURNING *"));
    }

    #[test]
    fn test_create_with_null_value() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Alice")
            .set("nickname", FilterValue::Null);
        let (_sql, params) = op.build_sql();
        assert_eq!(params.len(), 2);
        assert_eq!(params[1], FilterValue::Null);
    }

    #[test]
    fn test_create_with_boolean_value() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("active", FilterValue::Bool(true));
        let (_, params) = op.build_sql();
        assert_eq!(params[0], FilterValue::Bool(true));
    }

    #[test]
    fn test_create_with_numeric_values() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("count", FilterValue::Int(42))
            .set("price", FilterValue::Float(99.99));
        let (_, params) = op.build_sql();
        assert_eq!(params[0], FilterValue::Int(42));
        assert_eq!(params[1], FilterValue::Float(99.99));
    }

    #[test]
    fn test_create_with_json_value() {
        let json = serde_json::json!({"key": "value", "nested": {"a": 1}});
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("metadata", FilterValue::Json(json.clone()));
        let (_, params) = op.build_sql();
        assert_eq!(params[0], FilterValue::Json(json));
    }

    #[tokio::test]
    async fn test_create_exec() {
        // The mock's `execute_insert` always fails, so this checks that `exec` reaches the
        // engine and propagates its error.
        let op =
            CreateOperation::<MockEngine, TestModel>::new(MockEngine::new()).set("name", "Alice");
        let result = op.exec().await;
        assert!(result.is_err());
    }

    #[test]
    fn test_create_many_new() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();
        assert!(sql.contains("INSERT INTO test_models"));
        assert!(!sql.contains("RETURNING"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_create_many() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["name", "email"])
            .row(["Alice", "alice@example.com"])
            .row(["Bob", "bob@example.com"]);
        let (sql, params) = op.build_sql();
        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("(name, email)"));
        assert!(sql.contains("VALUES ($1, $2), ($3, $4)"));
        // Pin the complete statement so stray fragments (e.g. a RETURNING clause) are caught.
        assert_eq!(
            sql,
            "INSERT INTO test_models (name, email) VALUES ($1, $2), ($3, $4)"
        );
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_create_many_single_row() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["name"])
            .row(["Alice"]);
        let (sql, params) = op.build_sql();
        assert!(sql.contains("VALUES ($1)"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_create_many_skip_duplicates() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["name", "email"])
            .row(["Alice", "alice@example.com"])
            .skip_duplicates();
        let (sql, _) = op.build_sql();
        assert!(sql.contains("ON CONFLICT DO NOTHING"));
        assert_eq!(
            sql,
            "INSERT INTO test_models (name, email) VALUES ($1, $2) ON CONFLICT DO NOTHING"
        );
    }

    #[test]
    fn test_create_many_without_skip_duplicates() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["name"])
            .row(["Alice"]);
        let (sql, _) = op.build_sql();
        assert!(!sql.contains("ON CONFLICT"));
    }

    #[test]
    fn test_create_many_with_rows() {
        let rows = vec![
            vec!["Alice", "alice@test.com"],
            vec!["Bob", "bob@test.com"],
            vec!["Charlie", "charlie@test.com"],
        ];
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["name", "email"])
            .rows(rows);
        let (sql, params) = op.build_sql();
        assert!(sql.contains("VALUES ($1, $2), ($3, $4), ($5, $6)"));
        assert_eq!(params.len(), 6);
    }

    #[test]
    fn test_create_many_param_ordering() {
        // Parameters are the rows flattened in row-major order.
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["a", "b"])
            .row(["1", "2"])
            .row(["3", "4"]);
        let (_, params) = op.build_sql();
        assert_eq!(params[0], FilterValue::String("1".to_string()));
        assert_eq!(params[1], FilterValue::String("2".to_string()));
        assert_eq!(params[2], FilterValue::String("3".to_string()));
        assert_eq!(params[3], FilterValue::String("4".to_string()));
    }

    #[tokio::test]
    async fn test_create_many_exec() {
        // The mock's `execute_raw` reports the configured count verbatim.
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::with_count(3))
            .columns(["name"])
            .row(["Alice"])
            .row(["Bob"])
            .row(["Charlie"]);
        let result = op.exec().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 3);
    }

    #[test]
    fn test_create_sql_structure() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Alice")
            .select(Select::fields(["id"]));
        let (sql, _) = op.build_sql();
        let insert_pos = sql.find("INSERT INTO").unwrap();
        let columns_pos = sql.find("(name)").unwrap();
        let values_pos = sql.find("VALUES").unwrap();
        let returning_pos = sql.find("RETURNING").unwrap();
        assert!(insert_pos < columns_pos);
        assert!(columns_pos < values_pos);
        assert!(values_pos < returning_pos);
    }

    #[test]
    fn test_create_many_sql_structure() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["name", "email"])
            .row(["Alice", "alice@test.com"])
            .skip_duplicates();
        let (sql, _) = op.build_sql();
        let insert_pos = sql.find("INSERT INTO").unwrap();
        let columns_pos = sql.find("(name, email)").unwrap();
        let values_pos = sql.find("VALUES").unwrap();
        let conflict_pos = sql.find("ON CONFLICT").unwrap();
        assert!(insert_pos < columns_pos);
        assert!(columns_pos < values_pos);
        assert!(values_pos < conflict_pos);
    }

    #[test]
    fn test_create_table_name() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, _) = op.build_sql();
        assert!(sql.contains("test_models"));
    }

    #[test]
    fn test_create_method_chaining() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Alice")
            .set("email", "alice@test.com")
            .select(Select::fields(["id", "name"]));
        let (sql, params) = op.build_sql();
        assert!(sql.contains("(name, email)"));
        assert!(sql.contains("VALUES ($1, $2)"));
        assert!(sql.contains("RETURNING id, name"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_create_many_method_chaining() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["a", "b"])
            .row(["1", "2"])
            .row(["3", "4"])
            .skip_duplicates();
        let (sql, params) = op.build_sql();
        assert!(sql.contains("ON CONFLICT DO NOTHING"));
        assert_eq!(params.len(), 4);
    }
}