use crate::{
builder::SqlFragment,
identifier::{escape_ident, from_qi, QualifiedIdentifier},
param::SqlParam,
};
/// Builder for a single SQL `INSERT` statement.
///
/// Configure it with the chained methods on this type, then render it with
/// `build`, which produces a [`SqlFragment`].
#[derive(Clone, Debug, Default)]
pub struct InsertBuilder {
/// Target table; when set, rendered right after `INSERT INTO`.
table: Option<SqlFragment>,
/// Column names; escaped with `escape_ident` when rendered.
columns: Vec<String>,
/// One inner `Vec` per row of the `VALUES` list; empty means `DEFAULT VALUES`.
values: Vec<Vec<SqlFragment>>,
/// Optional `ON CONFLICT` action; `None` renders no conflict clause.
on_conflict: Option<OnConflict>,
/// Items of the `RETURNING` list; empty means no `RETURNING` clause.
returning: Vec<SqlFragment>,
}
/// Action an `INSERT` takes when a new row conflicts with an existing one
/// (the statement's `ON CONFLICT` clause).
#[derive(Clone, Debug)]
pub enum OnConflict {
/// Skip the conflicting row (`ON CONFLICT DO NOTHING`).
DoNothing,
/// Update the existing row using the `set` assignments
/// (an `ON CONFLICT (...) DO UPDATE SET ...` clause).
DoUpdate {
/// Conflict-target column names; escaped with `escape_ident` when rendered.
columns: Vec<String>,
/// `(column, expression)` pairs rendered as `column = expression`.
set: Vec<(String, SqlFragment)>,
/// Optional condition appended as `WHERE ...` after the `SET` list.
where_clause: Option<SqlFragment>,
},
}
impl InsertBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn into_table(mut self, qi: &QualifiedIdentifier) -> Self {
self.table = Some(SqlFragment::raw(from_qi(qi)));
self
}
pub fn columns(mut self, cols: Vec<String>) -> Self {
self.columns = cols;
self
}
pub fn values(mut self, vals: Vec<SqlParam>) -> Self {
let row: Vec<SqlFragment> = vals
.into_iter()
.map(|v| {
let mut frag = SqlFragment::new();
frag.push_param(v);
frag
})
.collect();
self.values.push(row);
self
}
pub fn values_raw(mut self, vals: Vec<SqlFragment>) -> Self {
self.values.push(vals);
self
}
pub fn on_conflict_do_nothing(mut self) -> Self {
self.on_conflict = Some(OnConflict::DoNothing);
self
}
pub fn on_conflict_do_update(
mut self,
conflict_columns: Vec<String>,
set: Vec<(String, SqlFragment)>,
) -> Self {
self.on_conflict = Some(OnConflict::DoUpdate {
columns: conflict_columns,
set,
where_clause: None,
});
self
}
pub fn returning(mut self, column: &str) -> Self {
self.returning
.push(SqlFragment::raw(escape_ident(column)));
self
}
pub fn returning_all(mut self) -> Self {
self.returning.push(SqlFragment::raw("*"));
self
}
pub fn build(self) -> SqlFragment {
let mut result = SqlFragment::new();
result.push("INSERT INTO ");
if let Some(table) = self.table {
result.append(table);
}
if !self.columns.is_empty() {
result.push(" (");
for (i, col) in self.columns.iter().enumerate() {
if i > 0 {
result.push(", ");
}
result.push(&escape_ident(col));
}
result.push(")");
}
if !self.values.is_empty() {
result.push(" VALUES ");
for (i, row) in self.values.into_iter().enumerate() {
if i > 0 {
result.push(", ");
}
result.push("(");
for (j, val) in row.into_iter().enumerate() {
if j > 0 {
result.push(", ");
}
result.append(val);
}
result.push(")");
}
} else {
result.push(" DEFAULT VALUES");
}
if let Some(conflict) = self.on_conflict {
match conflict {
OnConflict::DoNothing => {
result.push(" ON CONFLICT DO NOTHING");
}
OnConflict::DoUpdate {
columns,
set,
where_clause,
} => {
result.push(" ON CONFLICT (");
for (i, col) in columns.iter().enumerate() {
if i > 0 {
result.push(", ");
}
result.push(&escape_ident(col));
}
result.push(") DO UPDATE SET ");
for (i, (col, val)) in set.into_iter().enumerate() {
if i > 0 {
result.push(", ");
}
result.push(&escape_ident(&col));
result.push(" = ");
result.append(val);
}
if let Some(where_sql) = where_clause {
result.push(" WHERE ");
result.append(where_sql);
}
}
}
}
if !self.returning.is_empty() {
result.push(" RETURNING ");
for (i, ret) in self.returning.into_iter().enumerate() {
if i > 0 {
result.push(", ");
}
result.append(ret);
}
}
result
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unqualified `users` table shared by the schema-agnostic tests.
    fn users_table() -> QualifiedIdentifier {
        QualifiedIdentifier::unqualified("users")
    }

    /// Asserts that each keyword appears somewhere in the rendered SQL text.
    fn assert_sql_has(fragment: &SqlFragment, keywords: &[&str]) {
        let text = fragment.sql();
        for &keyword in keywords {
            assert!(
                text.contains(keyword),
                "expected {:?} in {:?}",
                keyword,
                text
            );
        }
    }

    #[test]
    fn test_simple_insert() {
        let qualified = QualifiedIdentifier::new("public", "users");
        let row = vec![SqlParam::text("John"), SqlParam::text("john@example.com")];
        let fragment = InsertBuilder::new()
            .into_table(&qualified)
            .columns(vec![String::from("name"), String::from("email")])
            .values(row)
            .build();
        assert_sql_has(&fragment, &["INSERT INTO", "VALUES"]);
        assert_eq!(fragment.params().len(), 2);
    }

    #[test]
    fn test_insert_returning() {
        let fragment = InsertBuilder::new()
            .into_table(&users_table())
            .columns(vec![String::from("name")])
            .values(vec![SqlParam::text("John")])
            .returning("id")
            .build();
        assert_sql_has(&fragment, &["RETURNING"]);
    }

    #[test]
    fn test_insert_on_conflict_nothing() {
        let fragment = InsertBuilder::new()
            .into_table(&users_table())
            .columns(vec![String::from("email")])
            .values(vec![SqlParam::text("john@example.com")])
            .on_conflict_do_nothing()
            .build();
        assert_sql_has(&fragment, &["ON CONFLICT DO NOTHING"]);
    }

    #[test]
    fn test_insert_upsert() {
        let mut excluded_name = SqlFragment::new();
        excluded_name.push("EXCLUDED.\"name\"");
        let assignments = vec![(String::from("name"), excluded_name)];
        let fragment = InsertBuilder::new()
            .into_table(&users_table())
            .columns(vec![String::from("id"), String::from("name")])
            .values(vec![SqlParam::Int(1), SqlParam::text("John")])
            .on_conflict_do_update(vec![String::from("id")], assignments)
            .build();
        assert_sql_has(&fragment, &["ON CONFLICT", "DO UPDATE SET"]);
    }
}