use crate::{
backend::QueryBuilder, error::*, prepare::*, types::*, value::*, Expr, Query,
QueryStatementBuilder, QueryStatementWriter, SelectExpr, SelectStatement, SimpleExpr,
SubQueryStatement, WithClause, WithQuery,
};
/// The origin of the rows written by an [`InsertStatement`].
#[derive(Debug, Clone)]
pub(crate) enum InsertValueSource {
/// Explicitly listed rows; each inner `Vec` holds the expressions of one row.
Values(Vec<Vec<SimpleExpr>>),
/// Rows supplied by a select statement.
Select(Box<SelectStatement>),
}
/// Builder state for an insert statement.
///
/// The derived `Default` yields a statement with no table, no columns, no value source and
/// no returning expressions.
#[derive(Debug, Default, Clone)]
pub struct InsertStatement {
/// Target table; `None` until `into_table` has been called.
pub(crate) table: Option<Box<TableRef>>,
/// Columns to insert into, as last set by `columns`.
pub(crate) columns: Vec<DynIden>,
/// Where the inserted rows come from; `None` until `values`, `exprs` or `select_from`
/// has succeeded.
pub(crate) source: Option<InsertValueSource>,
/// Select expressions taken from the statement passed to `returning`.
pub(crate) returning: Vec<SelectExpr>,
}
impl InsertStatement {
pub fn new() -> Self {
Self::default()
}
#[allow(clippy::wrong_self_convention)]
pub fn into_table<T>(&mut self, tbl_ref: T) -> &mut Self
where
T: IntoTableRef,
{
self.table = Some(Box::new(tbl_ref.into_table_ref()));
self
}
pub fn columns<C, I>(&mut self, columns: I) -> &mut Self
where
C: IntoIden,
I: IntoIterator<Item = C>,
{
self.columns = columns.into_iter().map(|c| c.into_iden()).collect();
self
}
pub fn values<I>(&mut self, values: I) -> Result<&mut Self>
where
I: IntoIterator<Item = Value>,
{
let values = values
.into_iter()
.map(|v| Expr::val(v).into())
.collect::<Vec<SimpleExpr>>();
if self.columns.len() != values.len() {
return Err(Error::ColValNumMismatch {
col_len: self.columns.len(),
val_len: values.len(),
});
}
let values_source = if let Some(InsertValueSource::Values(values)) = &mut self.source {
values
} else {
self.source = Some(InsertValueSource::Values(Default::default()));
if let Some(InsertValueSource::Values(values)) = &mut self.source {
values
} else {
unreachable!();
}
};
values_source.push(values);
Ok(self)
}
pub fn select_from<S>(&mut self, select: S) -> Result<&mut Self>
where
S: Into<SelectStatement>,
{
let statement = select.into();
if self.columns.len() != statement.selects.len() {
return Err(Error::ColValNumMismatch {
col_len: self.columns.len(),
val_len: statement.selects.len(),
});
}
self.source = Some(InsertValueSource::Select(Box::new(statement)));
Ok(self)
}
pub fn exprs<I>(&mut self, values: I) -> Result<&mut Self>
where
I: IntoIterator<Item = SimpleExpr>,
{
let values = values.into_iter().collect::<Vec<SimpleExpr>>();
if self.columns.len() != values.len() {
return Err(Error::ColValNumMismatch {
col_len: self.columns.len(),
val_len: values.len(),
});
}
let values_source = if let Some(InsertValueSource::Values(values)) = &mut self.source {
values
} else {
self.source = Some(InsertValueSource::Values(Default::default()));
if let Some(InsertValueSource::Values(values)) = &mut self.source {
values
} else {
unreachable!();
}
};
values_source.push(values);
Ok(self)
}
pub fn values_panic<I>(&mut self, values: I) -> &mut Self
where
I: IntoIterator<Item = Value>,
{
self.values(values).unwrap()
}
pub fn exprs_panic<I>(&mut self, values: I) -> &mut Self
where
I: IntoIterator<Item = SimpleExpr>,
{
self.exprs(values).unwrap()
}
pub fn returning(&mut self, select: SelectStatement) -> &mut Self {
self.returning = select.selects;
self
}
pub fn returning_col<C>(&mut self, col: C) -> &mut Self
where
C: IntoIden,
{
self.returning(Query::select().column(col.into_iden()).take())
}
pub fn with(self, clause: WithClause) -> WithQuery {
clause.query(self)
}
}
impl QueryStatementBuilder for InsertStatement {
/// Write this insert statement through a dynamically dispatched query builder.
///
/// Delegates to `query_builder.prepare_insert_statement`, handing it this statement, the
/// SQL writer and the value collector unchanged.
fn build_collect_any_into(
&self,
query_builder: &dyn QueryBuilder,
sql: &mut SqlWriter,
collector: &mut dyn FnMut(Value),
) {
query_builder.prepare_insert_statement(self, sql, collector);
}
/// Consume this statement and wrap it in `SubQueryStatement::InsertStatement`.
fn into_sub_query_statement(self) -> SubQueryStatement {
SubQueryStatement::InsertStatement(self)
}
}
impl QueryStatementWriter for InsertStatement {
    /// Build this insert statement into a `String` using `query_builder`.
    ///
    /// A fresh [`SqlWriter`] is created, the builder's `prepare_insert_statement` writes
    /// into it, and the writer's result is returned. `collector` is passed to the builder
    /// untouched.
    fn build_collect<T: QueryBuilder>(
        &self,
        query_builder: T,
        collector: &mut dyn FnMut(Value),
    ) -> String {
        let mut writer = SqlWriter::new();
        query_builder.prepare_insert_statement(self, &mut writer, collector);
        writer.result()
    }
}