use crate::column::ColumnMarker;
use crate::error::{Error, Result};
use crate::value::Value;
/// An `INSERT` statement: a target table, its column list, one or more rows of
/// values, and optional `RETURNING` columns.
///
/// Normally obtained from [`InsertBuilder::build`] (start with
/// [`Insert::into_table`]), which checks that the table, columns and rows are
/// non-empty and that every row has one value per column. The fields are
/// public, so a value constructed directly has not been through those checks.
#[derive(Debug, Clone, PartialEq)]
pub struct Insert {
/// Name of the table to insert into.
pub table: String,
/// Target columns, in the order their values appear within each row.
pub columns: Vec<ColumnMarker>,
/// Rows to insert; each inner `Vec` is one row, positionally aligned with `columns`.
pub values: Vec<Vec<Value>>,
/// Columns requested via [`InsertBuilder::returning`]; empty if none were requested.
pub returning: Vec<ColumnMarker>,
}
impl Insert {
    /// Starts building an `INSERT` into `table`.
    ///
    /// The returned builder has no columns, rows or `RETURNING` columns yet.
    /// The table name is not checked here; an empty name is reported by
    /// [`InsertBuilder::build`].
    ///
    /// The builder does nothing until `build` is called. That is why this is
    /// `#[must_use]`, matching the builder's own chaining methods.
    #[must_use]
    pub fn into_table(table: impl Into<String>) -> InsertBuilder {
        InsertBuilder {
            table: table.into(),
            columns: Vec::new(),
            values: Vec::new(),
            returning: Vec::new(),
        }
    }
}
/// Incremental builder for an [`Insert`], created by [`Insert::into_table`].
///
/// The fields are private. Configure the builder through its chaining methods,
/// then call [`InsertBuilder::build`] to validate the accumulated state and get
/// the finished [`Insert`].
#[derive(Debug, Clone)]
pub struct InsertBuilder {
// Target table name, as passed to `Insert::into_table`.
table: String,
// Target columns; `columns` replaces this list, `column` appends to it.
columns: Vec<ColumnMarker>,
// Rows appended one at a time by `values`.
values: Vec<Vec<Value>>,
// Columns set by `returning` (replaced, not appended).
returning: Vec<ColumnMarker>,
}
impl InsertBuilder {
    /// Replaces the target column list with `columns`.
    ///
    /// Any columns added earlier, via [`InsertBuilder::column`] or a previous
    /// call to this method, are discarded.
    #[must_use]
    pub fn columns(mut self, columns: Vec<ColumnMarker>) -> Self {
        self.columns = columns;
        self
    }

    /// Appends a single target column to the column list.
    #[must_use]
    pub fn column(mut self, column: ColumnMarker) -> Self {
        self.columns.push(column);
        self
    }

    /// Appends one row of values.
    ///
    /// Call this repeatedly to build a multi-row (batch) insert. Each row must
    /// hold exactly one value per column; [`InsertBuilder::build`] checks this.
    #[must_use]
    pub fn values(mut self, row: Vec<Value>) -> Self {
        self.values.push(row);
        self
    }

    /// Replaces the `RETURNING` column list with `columns`.
    #[must_use]
    pub fn returning(mut self, columns: Vec<ColumnMarker>) -> Self {
        self.returning = columns;
        self
    }

    /// Validates the accumulated state and produces an [`Insert`].
    ///
    /// # Errors
    ///
    /// - [`Error::MissingField`] with `"table"`, `"columns"` or `"values"` when
    ///   the table name is empty, no columns were given, or no rows were added
    ///   (checked in that order).
    /// - [`Error::InvalidQuery`] when the same column name appears more than
    ///   once in the column list.
    /// - [`Error::InvalidQuery`] when any row's value count differs from the
    ///   number of columns. The message names the first offending row.
    pub fn build(self) -> Result<Insert> {
        if self.table.is_empty() {
            return Err(Error::MissingField("table".to_string()));
        }
        if self.columns.is_empty() {
            return Err(Error::MissingField("columns".to_string()));
        }
        if self.values.is_empty() {
            return Err(Error::MissingField("values".to_string()));
        }

        // A column listed twice is rejected by mainstream SQL engines
        // (e.g. PostgreSQL, MySQL), so report it here rather than at execution
        // time. Column lists are short, so a scan of the preceding columns is
        // simpler than hashing and just as fast in practice.
        for (i, col) in self.columns.iter().enumerate() {
            if self.columns[..i].iter().any(|prev| prev.name == col.name) {
                return Err(Error::InvalidQuery(format!(
                    "column {} is specified more than once",
                    col.name
                )));
            }
        }

        let col_count = self.columns.len();
        if let Some((i, row)) = self
            .values
            .iter()
            .enumerate()
            .find(|(_, row)| row.len() != col_count)
        {
            return Err(Error::InvalidQuery(format!(
                "row {} has {} values but {} columns were specified",
                i,
                row.len(),
                col_count
            )));
        }

        Ok(Insert {
            table: self.table,
            columns: self.columns,
            values: self.values,
            returning: self.returning,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_insert() {
        let insert = Insert::into_table("users")
            .column(ColumnMarker::new("users", "email"))
            .values(vec![Value::String("alice@example.com".to_string())])
            .build()
            .unwrap();
        assert_eq!(insert.table, "users");
        assert_eq!(insert.columns.len(), 1);
        assert_eq!(insert.columns[0].name, "email");
        assert_eq!(insert.values.len(), 1);
        assert!(insert.returning.is_empty());
    }

    #[test]
    fn test_multi_column_insert() {
        let insert = Insert::into_table("users")
            .columns(vec![
                ColumnMarker::new("users", "email"),
                ColumnMarker::new("users", "name"),
            ])
            .values(vec![
                Value::String("alice@example.com".to_string()),
                Value::String("Alice".to_string()),
            ])
            .build()
            .unwrap();
        assert_eq!(insert.columns.len(), 2);
        assert_eq!(insert.values.len(), 1);
        assert_eq!(insert.values[0].len(), 2);
    }

    #[test]
    fn test_batch_insert() {
        let insert = Insert::into_table("users")
            .column(ColumnMarker::new("users", "email"))
            .values(vec![Value::String("alice@example.com".to_string())])
            .values(vec![Value::String("bob@example.com".to_string())])
            .values(vec![Value::String("charlie@example.com".to_string())])
            .build()
            .unwrap();
        assert_eq!(insert.values.len(), 3);
    }

    #[test]
    fn test_insert_with_returning() {
        let insert = Insert::into_table("users")
            .column(ColumnMarker::new("users", "email"))
            .values(vec![Value::String("alice@example.com".to_string())])
            .returning(vec![
                ColumnMarker::new("users", "id"),
                ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();
        assert_eq!(insert.returning.len(), 2);
        assert_eq!(insert.returning[0].name, "id");
        assert_eq!(insert.returning[1].name, "email");
    }

    // The missing-field tests pin the exact field reported, not just `is_err()`.
    // Otherwise a check that reports the wrong field, or fires in the wrong
    // order, would still pass.
    #[test]
    fn test_missing_table() {
        let result = Insert::into_table("")
            .column(ColumnMarker::new("users", "email"))
            .values(vec![Value::String("test".to_string())])
            .build();
        assert!(matches!(&result, Err(Error::MissingField(field)) if field == "table"));
    }

    #[test]
    fn test_missing_columns() {
        let result = Insert::into_table("users")
            .values(vec![Value::String("test".to_string())])
            .build();
        assert!(matches!(&result, Err(Error::MissingField(field)) if field == "columns"));
    }

    #[test]
    fn test_missing_values() {
        let result = Insert::into_table("users")
            .column(ColumnMarker::new("users", "email"))
            .build();
        assert!(matches!(&result, Err(Error::MissingField(field)) if field == "values"));
    }

    #[test]
    fn test_mismatched_values_count() {
        let result = Insert::into_table("users")
            .columns(vec![
                ColumnMarker::new("users", "email"),
                ColumnMarker::new("users", "name"),
            ])
            .values(vec![Value::String("alice@example.com".to_string())])
            .build();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(
            &err,
            Error::InvalidQuery(msg) if msg == "row 0 has 1 values but 2 columns were specified"
        ));
    }

    #[test]
    fn test_mismatched_values_reports_offending_row() {
        // The first row is well-formed; the error must point at row 1.
        let result = Insert::into_table("users")
            .column(ColumnMarker::new("users", "email"))
            .values(vec![Value::String("alice@example.com".to_string())])
            .values(vec![
                Value::String("bob@example.com".to_string()),
                Value::String("Bob".to_string()),
            ])
            .build();
        assert!(matches!(
            &result,
            Err(Error::InvalidQuery(msg)) if msg == "row 1 has 2 values but 1 columns were specified"
        ));
    }
}