use std::{borrow::Cow, fmt::Write, marker::PhantomData};
use crate::{
Compatible, Dialect, FromRow, HasDialect, Qrafting, Query, QueryOf, RpnInstr, TypeMeta,
emitter::{Directive, Emitter},
expression::{Binary, Column, op, prepare_sqlite_glob},
impl_for_all_tuples,
lower::{Data, LowerCtx},
param::{Param, encode_param},
query::{
LockedQuery, LockedQueryOf, LowerProject, Select, Table, TypedCompiled, WithSelect,
rewrite_params,
},
relation::{ModelField, PersistedField},
span::TextSource,
};
/// A set of column/value assignments that can be inserted into table `M`.
pub trait Insertable<M> {
    /// Hands every `(field, value)` pair of this insertable to `visitor`, in order.
    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>);

    /// Convenience for `insert_into(table).values(self)`.
    fn insert_into(self, table: Table<M>) -> Insert<M>
    where
        M: Qrafting,
        Self: Sized,
    {
        let target = crate::insert_into(table);
        target.values(self)
    }
}
/// First stage of the insert builder returned by [`insert_into`]: only the target table
/// is known so far.
pub struct InsertInto<M> {
    table: Table<M>,
}
/// Insert-select builder stage: the target table plus the explicit list of columns the
/// select's output will be written into.
pub struct InsertColumnsInto<M> {
    table: Table<M>,
    columns: Vec<&'static str>,
}
/// A fully configured insert statement for table `M`.
pub struct Insert<M> {
    /// Target table.
    pub table: Table<M>,
    /// Target column names; empty means the statement renders `default values`.
    pub columns: Vec<&'static str>,
    /// Row source of an insert-select; `None` for a `values` insert.
    pub select: Option<Query>,
    /// What to emit in the `returning` clause.
    pub returning: InsertReturning,
    /// How conflicting rows are handled.
    pub conflict: InsertConflict,
    /// Encoded parameters of the `values` rows (plus any `returning` projection lowered on
    /// that path). Left as `Query::default()` for insert-selects, whose parameters live in
    /// `select`.
    pub query: Query,
}
/// The `returning` clause of an insert.
pub enum InsertReturning {
    /// No `returning` clause.
    None,
    /// `returning *`.
    All,
    /// `returning <projection>`, rendered from lowered RPN instructions.
    Projection(Vec<RpnInstr>),
}
/// An [`Insert`] with a `returning` clause whose result rows are typed as `T`.
pub struct ReturningInsert<M, T> {
    pub insert: Insert<M>,
    pub marker: PhantomData<T>,
}
/// Conflict handling for an insert.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InsertConflict {
    /// No conflict clause is emitted.
    None,
    /// Skip conflicting rows: `insert or ignore` (SQLite), `insert ignore` (MariaDB) or
    /// `on conflict do nothing` (Postgres).
    Ignore,
    /// Overwrite `update` columns with the incoming row's values on conflict.
    Upsert {
        /// Conflict target columns (`on conflict (...)`); not emitted for MariaDB.
        unique_by: Vec<&'static str>,
        /// Columns assigned from the incoming row (`excluded.col` / `values(col)`).
        update: Vec<&'static str>,
    },
}
/// Parameter visitor that only records field names, keeping each distinct name once in
/// first-seen order (the deduplication happens in its `VisitParam` impl).
#[derive(Default)]
pub struct FieldCollector {
    fields: Vec<&'static str>,
}
impl FieldCollector {
    /// Creates a collector with no recorded fields.
    pub fn new() -> Self {
        Self::default()
    }

    /// Consumes the collector and returns the recorded field names.
    pub fn into_inner(self) -> Vec<&'static str> {
        let Self { fields } = self;
        fields
    }
}
impl<'a> VisitParam<'a> for FieldCollector {
    /// Records `field` unless it was already seen; the value itself is ignored.
    fn param(&mut self, field: &'static str, _param: impl quex::Encode) {
        let already_seen = self.fields.iter().any(|&seen| seen == field);
        if !already_seen {
            self.fields.push(field);
        }
    }

    /// Same as [`VisitParam::param`]: only the field name matters here.
    fn param_typed<T>(&mut self, field: &'static str, _value: &'a (impl Compatible<T> + ?Sized))
    where
        T: TypeMeta,
    {
        if self.fields.iter().all(|&seen| seen != field) {
            self.fields.push(field);
        }
    }
}
impl<M> Insert<M>
where
    M: Qrafting,
{
    /// Renders the statement as SQL for dialect `D`.
    ///
    /// An insert-select is rendered from a clone of its select query, because rendering
    /// finalizes (rewrites) that query's parameters in place; a `values` insert is
    /// rendered through [`FormatWriter::format_writer`].
    ///
    /// # Panics
    ///
    /// Panics if an insert-select has no target columns, or if an upsert lacks the
    /// columns its dialect requires (see `write_conflict`).
    pub fn to_sql<D: HasDialect>(&self) -> String {
        match self.select.clone() {
            Some(mut query) => self.render_select_sql::<D>(&mut query),
            None => {
                let mut writer = String::new();
                let mut ctx = FormatContext {
                    writer: &mut writer,
                    index: 0,
                    dialect: D::DIALECT,
                };
                self.format_writer(&mut ctx)
                    .expect("cannot fail on a string writer");
                writer
            }
        }
    }

    /// Skips conflicting rows: `insert or ignore` on SQLite, `insert ignore` on MariaDB
    /// and `on conflict do nothing` on Postgres. Replaces any configured upsert.
    pub fn ignore(mut self) -> Self {
        self.conflict = InsertConflict::Ignore;
        self
    }

    /// Replaces the `returning` clause with `project`, yielding rows typed as `M`
    /// (use [`ReturningInsert::typed`] to choose another row type).
    ///
    /// The projection is lowered into the query that is actually rendered: the
    /// insert-select's own query when there is one, otherwise `self.query`.
    pub fn returning<P>(mut self, project: P) -> ReturningInsert<M, M>
    where
        P: LowerProject,
        M: FromRow,
    {
        let mut instrs = Vec::new();
        {
            let target = match self.select.as_mut() {
                Some(select) => select,
                None => &mut self.query,
            };
            // NOTE(review): on the `values` path, any parameters the projection pushes into
            // `self.query.params` are also counted as VALUES placeholders by
            // `format_writer`; confirm projections never lower parameters here.
            let mut ctx = LowerCtx {
                instrs: &mut instrs,
                params: &mut target.params,
                data: &mut target.data,
            };
            project.lower_project(&mut ctx);
        }
        self.returning = InsertReturning::Projection(instrs);
        ReturningInsert {
            insert: self,
            marker: PhantomData,
        }
    }

    /// Emits `returning *`, yielding rows typed as `M`.
    pub fn returning_all(mut self) -> ReturningInsert<M, M>
    where
        M: FromRow,
    {
        self.returning = InsertReturning::All;
        ReturningInsert {
            insert: self,
            marker: PhantomData,
        }
    }

    /// Removes the `returning` clause.
    pub fn no_returning(mut self) -> Self {
        self.returning = InsertReturning::None;
        self
    }

    /// Turns conflicts into an update of the `update` columns from the incoming row.
    ///
    /// `unique_by` is the conflict target on Postgres/SQLite and is not emitted for
    /// MariaDB, which renders `on duplicate key update`. Replaces any `ignore()`.
    pub fn upsert<U, C>(mut self, unique_by: C, update: U) -> Self
    where
        C: InsertColumns,
        U: InsertColumns,
    {
        self.conflict = InsertConflict::Upsert {
            unique_by: unique_by.into_columns(),
            update: update.into_columns(),
        };
        self
    }

    /// Compiles an insert-select into SQL plus the select's finalized parameters, or
    /// returns `None` for a `values` insert.
    #[doc(hidden)]
    pub fn into_select_compiled<D: HasDialect>(mut self) -> Option<TypedCompiled<M>> {
        // Rendering reads only table, columns, conflict and returning, so taking the
        // select out (leaving `None`) is equivalent to rebuilding the struct without it.
        let mut query = self.select.take()?;
        let sql = self.render_select_sql::<D>(&mut query);
        Some(TypedCompiled {
            sql,
            params: query.params,
            data: query.data,
            marker: PhantomData,
        })
    }

    /// Renders `[with ...] insert into <table> (<columns>) <select body> [conflict]
    /// [returning]`, then finalizes `query`'s parameters in place.
    ///
    /// The CTEs, select body and `returning` projection share one directive/index list,
    /// which `finalize_query_params` applies at the end.
    ///
    /// # Panics
    ///
    /// Panics if no target columns were given.
    fn render_select_sql<D: HasDialect>(&self, query: &mut Query) -> String {
        const RENDER_FAILED: &str = "insert-select rendering failed";
        let columns = self.columns.as_slice();
        assert!(
            !columns.is_empty(),
            "insert-select requires at least one target column"
        );
        let mut writer = String::new();
        let mut directives = Vec::new();
        let mut indexes = Vec::new();
        {
            // CTEs are written before `insert into`.
            let mut emitter = Emitter::new(
                &mut writer,
                &query.data,
                D::DIALECT,
                &mut directives,
                &mut indexes,
            );
            emitter.emit_ctes_for_query(query).expect(RENDER_FAILED);
        }
        let mut ctx = FormatContext {
            writer: &mut writer,
            index: 0,
            dialect: D::DIALECT,
        };
        self.write_insert_prefix(&mut ctx).expect(RENDER_FAILED);
        self.write_columns(&mut ctx, columns).expect(RENDER_FAILED);
        ctx.writer.write_char(' ').expect(RENDER_FAILED);
        {
            let writer = &mut *ctx.writer;
            let mut emitter = Emitter::new(
                writer,
                &query.data,
                D::DIALECT,
                &mut directives,
                &mut indexes,
            );
            emitter.emit_query_body(query).expect(RENDER_FAILED);
        }
        self.write_conflict(&mut ctx).expect(RENDER_FAILED);
        self.write_returning(&mut ctx, query, &mut directives, &mut indexes)
            .expect(RENDER_FAILED);
        finalize_query_params(query, directives, &indexes);
        writer
    }

    /// Writes the dialect's insert keyword followed by the target table.
    ///
    /// `ignore()` is expressed here for SQLite (`insert or ignore`) and MariaDB
    /// (`insert ignore`); Postgres uses plain `insert into` and gets its
    /// `on conflict do nothing` from `write_conflict`.
    fn write_insert_prefix<'w, W: Write>(
        &self,
        context: &mut FormatContext<'w, W>,
    ) -> std::fmt::Result {
        match (&self.conflict, context.dialect) {
            (InsertConflict::Ignore, Dialect::Sqlite) => {
                context.writer.write_str("insert or ignore into ")
            }
            (InsertConflict::Ignore, Dialect::MariaDb) => {
                context.writer.write_str("insert ignore into ")
            }
            _ => context.writer.write_str("insert into "),
        }?;
        self.table.format_writer(context)
    }

    /// Writes ` (<col>, <col>, ...)` with each name quoted for the dialect.
    fn write_columns<'w, W: Write>(
        &self,
        context: &mut FormatContext<'w, W>,
        fields: &[&'static str],
    ) -> std::fmt::Result {
        context.writer.write_str(" (")?;
        for (i, field) in fields.iter().enumerate() {
            if i > 0 {
                context.writer.write_str(", ")?;
            }
            context.write_ident(field)?;
        }
        context.writer.write_char(')')
    }

    /// Writes the trailing conflict clause, if the dialect needs one.
    ///
    /// # Panics
    ///
    /// For an upsert, panics if `update` is empty, or if `unique_by` is empty on
    /// Postgres/SQLite.
    fn write_conflict<'w, W: Write>(&self, context: &mut FormatContext<'w, W>) -> std::fmt::Result {
        match (&self.conflict, context.dialect) {
            (InsertConflict::None, _)
            | (InsertConflict::Ignore, Dialect::Sqlite | Dialect::MariaDb) => {}
            (InsertConflict::Ignore, Dialect::Postgres) => {
                context.writer.write_str(" on conflict do nothing")?;
            }
            (InsertConflict::Upsert { unique_by, update }, Dialect::Postgres | Dialect::Sqlite) => {
                // NOTE: SQLite documents a parse ambiguity between an upsert's `on conflict`
                // and a join's `on` when the upsert follows an insert-select; it recommends
                // the select always carry a `where` clause (even `where true`). Nothing is
                // added here, so insert-select upserts on SQLite should include a filter.
                assert!(
                    !unique_by.is_empty(),
                    "upsert requires at least one conflict column"
                );
                assert!(
                    !update.is_empty(),
                    "upsert requires at least one update column"
                );
                context.writer.write_str(" on conflict (")?;
                for (i, field) in unique_by.iter().enumerate() {
                    if i > 0 {
                        context.writer.write_str(", ")?;
                    }
                    context.write_ident(field)?;
                }
                context.writer.write_str(") do update set ")?;
                for (i, field) in update.iter().enumerate() {
                    if i > 0 {
                        context.writer.write_str(", ")?;
                    }
                    context.write_ident(field)?;
                    context.writer.write_str(" = ")?;
                    context.write_ident("excluded")?;
                    context.writer.write_char('.')?;
                    context.write_ident(field)?;
                }
            }
            (InsertConflict::Upsert { update, .. }, Dialect::MariaDb) => {
                assert!(
                    !update.is_empty(),
                    "upsert requires at least one update column"
                );
                context.writer.write_str(" on duplicate key update ")?;
                for (i, field) in update.iter().enumerate() {
                    if i > 0 {
                        context.writer.write_str(", ")?;
                    }
                    context.write_ident(field)?;
                    context.writer.write_str(" = values(")?;
                    context.write_ident(field)?;
                    context.writer.write_char(')')?;
                }
            }
        }
        Ok(())
    }

    /// Writes the `returning` clause. A projection is emitted against `query`'s data,
    /// appending to the caller's `directives`/`indexes`.
    fn write_returning<'w, W: Write>(
        &self,
        context: &mut FormatContext<'w, W>,
        query: &Query,
        directives: &mut Vec<Directive>,
        indexes: &mut Vec<usize>,
    ) -> std::fmt::Result {
        match &self.returning {
            InsertReturning::None => Ok(()),
            InsertReturning::All => context.writer.write_str(" returning *"),
            InsertReturning::Projection(instrs) => {
                context.writer.write_str(" returning ")?;
                let mut emitter = Emitter::new(
                    context.writer,
                    &query.data,
                    context.dialect,
                    directives,
                    indexes,
                );
                emitter.emit_instrs(instrs)
            }
        }
    }
}
impl<M> InsertInto<M> {
    /// Builds a `values` insert from `values`.
    ///
    /// `values` is visited twice: first to collect the distinct target column names in
    /// first-seen order, then to encode every value as a parameter of a fresh query. The
    /// result defaults to `returning *` with no conflict handling.
    ///
    /// NOTE(review): every row must name the same fields in the same order. Columns are
    /// deduplicated while parameters are pushed in visit order, so rows that differ would
    /// produce placeholders that do not line up with the column list.
    pub fn values<V>(self, values: V) -> Insert<M>
    where
        M: Qrafting,
        V: Insertable<M>,
    {
        let columns = {
            let mut collector = FieldCollector::new();
            values.values(&mut collector);
            collector.into_inner()
        };
        let mut query = Query::default();
        values.values(&mut QueryParamCollector {
            params: &mut query.params,
            data: &mut query.data,
        });
        Insert {
            table: self.table,
            columns,
            select: None,
            returning: InsertReturning::All,
            conflict: InsertConflict::None,
            query,
        }
    }

    /// Starts an insert-select that writes into `columns`.
    ///
    /// # Panics
    ///
    /// Panics if `columns` yields no column names.
    pub fn columns<C>(self, columns: C) -> InsertColumnsInto<M>
    where
        C: InsertColumns,
    {
        let target_columns = columns.into_columns();
        assert!(
            !target_columns.is_empty(),
            "insert-select requires at least one target column"
        );
        InsertColumnsInto {
            table: self.table,
            columns: target_columns,
        }
    }
}
impl<M> InsertColumnsInto<M> {
pub fn query<Q>(self, query: Q) -> Insert<M>
where
Q: IntoInsertSelectQuery,
{
Insert {
table: self.table,
columns: self.columns,
select: Some(query.into_insert_select_query()),
returning: InsertReturning::All,
conflict: InsertConflict::None,
query: Query::default(),
}
}
}
/// Anything that names one or more columns (used for insert-select targets and upsert
/// column lists).
pub trait InsertColumns {
    /// Returns the column names in order.
    fn into_columns(self) -> Vec<&'static str>;
}
/// A single value that names one column (used as a tuple element of [`InsertColumns`]).
pub trait IntoInsertColumn {
    /// Returns the column name.
    fn into_insert_column(self) -> &'static str;
}
/// Anything usable as the row source of an insert-select.
pub trait IntoInsertSelectQuery {
    /// Converts into the untyped [`Query`] that the insert renders.
    fn into_insert_select_query(self) -> Query;
}
// A plain `Query` is used as-is.
impl IntoInsertSelectQuery for Query {
    fn into_insert_select_query(self) -> Query {
        self
    }
}
// Typed and locked query wrappers rely on their existing conversions into `Query`.
impl<M> IntoInsertSelectQuery for QueryOf<M> {
    fn into_insert_select_query(self) -> Query {
        self.into()
    }
}
impl IntoInsertSelectQuery for LockedQuery {
    fn into_insert_select_query(self) -> Query {
        self.into()
    }
}
impl<M> IntoInsertSelectQuery for LockedQueryOf<M> {
    fn into_insert_select_query(self) -> Query {
        self.into()
    }
}
// Select builders are finished via their own `into_query`.
impl<P> IntoInsertSelectQuery for Select<P>
where
    P: LowerProject,
{
    fn into_insert_select_query(self) -> Query {
        self.into_query()
    }
}
impl<P> IntoInsertSelectQuery for WithSelect<P>
where
    P: LowerProject,
{
    fn into_insert_select_query(self) -> Query {
        self.into_query()
    }
}
// A raw column name.
impl IntoInsertColumn for &'static str {
    fn into_insert_column(self) -> &'static str {
        self
    }
}
// Typed columns and model fields contribute their (unqualified) name.
impl<M, T> IntoInsertColumn for Column<M, T>
where
    T: TypeMeta,
{
    fn into_insert_column(self) -> &'static str {
        self.name
    }
}
impl<M, V, T> IntoInsertColumn for ModelField<M, V, T>
where
    T: TypeMeta,
{
    fn into_insert_column(self) -> &'static str {
        self.name()
    }
}
impl<M, V, T, K> IntoInsertColumn for PersistedField<M, V, T, K>
where
    T: TypeMeta,
{
    fn into_insert_column(self) -> &'static str {
        self.name()
    }
}
/// A single raw column name.
impl InsertColumns for &'static str {
    fn into_columns(self) -> Vec<&'static str> {
        Vec::from([self])
    }
}
/// A single typed column, by name.
impl<M, T> InsertColumns for Column<M, T>
where
    T: TypeMeta,
{
    fn into_columns(self) -> Vec<&'static str> {
        Vec::from([self.name])
    }
}
/// A fixed-size array of names, in order.
impl<const N: usize> InsertColumns for [&'static str; N] {
    fn into_columns(self) -> Vec<&'static str> {
        Vec::from(self)
    }
}
/// A borrowed slice of names, copied in order.
impl InsertColumns for &[&'static str] {
    fn into_columns(self) -> Vec<&'static str> {
        self.iter().copied().collect()
    }
}
/// An owned list of names, passed through unchanged.
impl InsertColumns for Vec<&'static str> {
    fn into_columns(self) -> Vec<&'static str> {
        self
    }
}
// Implements `InsertColumns` for tuples whose elements are all `IntoInsertColumn`,
// yielding the column names in tuple order.
macro_rules! impl_insert_columns_tuple {
    ($($T:ident),+) => {
        impl<$($T,)+> InsertColumns for ($($T,)+)
        where
            $($T: IntoInsertColumn,)+
        {
            fn into_columns(self) -> Vec<&'static str> {
                // The destructuring bindings reuse the uppercase type-parameter names.
                #[allow(non_snake_case)]
                let ($($T,)+) = self;
                vec![$($T.into_insert_column(),)+]
            }
        }
    };
}
// Expands the impl for every tuple arity supported by `impl_for_all_tuples!`.
impl_for_all_tuples!(impl_insert_columns_tuple);
/// Receives each `(field, value)` pair produced by an [`Insertable`].
pub trait VisitParam<'v> {
    /// Visits one value destined for column `field`.
    fn param(&mut self, field: &'static str, _param: impl quex::Encode);
    /// Visits a value whose type was checked against the column type `T`.
    ///
    /// The default forwards the borrowed value to [`VisitParam::param`].
    fn param_typed<T>(&mut self, field: &'static str, value: &'v (impl Compatible<T> + ?Sized))
    where
        T: TypeMeta,
    {
        self.param(field, value)
    }
}
// `column.eq(value)` used as an insertable contributes exactly one value for that column;
// `R: Compatible<T>` makes the compiler check the value against the column's type.
impl<M, T, R> Insertable<M> for Binary<op::Eq, Column<M, T>, R>
where
    T: TypeMeta,
    M: Qrafting,
    R: Compatible<T>,
{
    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
        visitor.param_typed::<T>(self.left.name, &self.right);
    }
}
// Same as above, for model fields.
impl<M, V, T, R> Insertable<M> for Binary<op::Eq, ModelField<M, V, T>, R>
where
    T: TypeMeta,
    M: Qrafting,
    R: Compatible<T>,
{
    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
        visitor.param_typed::<T>(self.left.name(), &self.right);
    }
}
// Same as above, for persisted fields.
impl<M, V, T, K, R> Insertable<M> for Binary<op::Eq, PersistedField<M, V, T, K>, R>
where
    T: TypeMeta,
    M: Qrafting,
    R: Compatible<T>,
{
    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
        visitor.param_typed::<T>(self.left.name(), &self.right);
    }
}
/// The unit value contributes nothing; inserting `()` alone yields `default values`.
impl<M> Insertable<M> for () {
    fn values<'v>(&'v self, _visitor: &mut impl VisitParam<'v>) {}
}
/// A borrowed slice visits its rows in order (via the `[T]` impl).
impl<M, T> Insertable<M> for &[T]
where
    T: Insertable<M>,
{
    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
        <[T] as Insertable<M>>::values(*self, visitor)
    }
}
/// A vector visits its rows in order (via the `[T]` impl).
impl<M, T> Insertable<M> for Vec<T>
where
    T: Insertable<M>,
{
    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
        <[T] as Insertable<M>>::values(self.as_slice(), visitor)
    }
}
/// An array visits its rows in order (via the `[T]` impl).
impl<M, T, const N: usize> Insertable<M> for [T; N]
where
    T: Insertable<M>,
{
    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
        <[T] as Insertable<M>>::values(self.as_slice(), visitor)
    }
}
/// A reference forwards to the referenced value.
impl<M, T> Insertable<M> for &T
where
    T: Insertable<M>,
{
    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
        <T as Insertable<M>>::values(*self, visitor)
    }
}
/// A slice visits each row in order; every row contributes its own values.
impl<M, T> Insertable<M> for [T]
where
    T: Insertable<M>,
{
    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
        for row in self {
            row.values(visitor);
        }
    }
}
/// Parameter visitor that writes `values` placeholders, closing one row and opening the
/// next every `field_count` placeholders.
pub struct FormatValue<'w, W: Write> {
    writer: &'w mut W,
    dialect: Dialect,
    /// Placeholders written so far; also the index used for the latest Postgres `$n`.
    count: usize,
    /// Placeholders per row.
    field_count: usize,
    /// First write error; once `Err`, further visits write nothing.
    result: std::fmt::Result,
}
/// Parameter visitor that encodes every visited value into `params`, appending the
/// encoded payload to `data`.
struct QueryParamCollector<'q> {
    params: &'q mut Vec<Param>,
    data: &'q mut Vec<u8>,
}
impl<'q, 'v> VisitParam<'v> for QueryParamCollector<'q> {
    /// Encodes `param` and records it; the field name is not needed here.
    fn param(&mut self, _field: &'static str, param: impl quex::Encode) {
        // `encode_param` takes the value by reference (compare `param_typed` below). The
        // previous text here was the mis-encoded `¶m` (an HTML `&para;` mangling of
        // `&param`), which does not compile.
        self.params.push(encode_param(&param, self.data));
    }
    /// Encodes the borrowed, type-checked value and records it.
    fn param_typed<T>(&mut self, _field: &'static str, value: &'v (impl Compatible<T> + ?Sized))
    where
        T: TypeMeta,
    {
        self.params.push(encode_param(value, self.data));
    }
}
impl<'w, 'v, W: Write> VisitParam<'v> for FormatValue<'w, W> {
    /// Appends the next placeholder (`$n` on Postgres, `?` otherwise).
    ///
    /// Within a row the placeholder is preceded by `, `, and at a row boundary by
    /// `), (`. Does nothing once a previous write has failed.
    fn param(&mut self, _field: &'static str, _param: impl quex::Encode) {
        if self.result.is_err() {
            return;
        }
        let separator = match self.count {
            0 => None,
            n if n.is_multiple_of(self.field_count) => Some("), ("),
            _ => Some(", "),
        };
        if let Some(separator) = separator {
            self.result = self.writer.write_str(separator);
            if self.result.is_err() {
                return;
            }
        }
        self.count += 1;
        self.result = match self.dialect {
            Dialect::Postgres => write!(self.writer, "${}", self.count),
            Dialect::MariaDb | Dialect::Sqlite => self.writer.write_char('?'),
        };
    }

    /// Typed values render exactly like untyped ones: only a placeholder is written.
    fn param_typed<T>(&mut self, _field: &'static str, _value: &'v (impl Compatible<T> + ?Sized))
    where
        T: TypeMeta,
    {
        self.param("", "")
    }
}
/// Output sink plus dialect used while rendering SQL text.
pub struct FormatContext<'w, W: Write> {
    pub writer: &'w mut W,
    pub index: usize,
    pub dialect: Dialect,
}
impl<'w, W: Write> FormatContext<'w, W> {
    /// Writes `part` as one quoted identifier.
    ///
    /// Postgres and SQLite use `"`, MariaDB uses a backtick. Embedded quote characters are
    /// doubled, and a bare `*` is written unquoted.
    pub(crate) fn write_ident(&mut self, part: &str) -> std::fmt::Result {
        if part == "*" {
            return self.writer.write_char('*');
        }
        let (quote, escaped) = match self.dialect {
            Dialect::Postgres | Dialect::Sqlite => ('"', "\"\""),
            Dialect::MariaDb => ('`', "``"),
        };
        self.writer.write_char(quote)?;
        // Pieces between quote characters are copied verbatim; each quote between them
        // becomes the doubled escape sequence.
        let mut segments = part.split(quote);
        if let Some(head) = segments.next().filter(|s| !s.is_empty()) {
            self.writer.write_str(head)?;
        }
        for segment in segments {
            self.writer.write_str(escaped)?;
            if !segment.is_empty() {
                self.writer.write_str(segment)?;
            }
        }
        self.writer.write_char(quote)
    }

    /// Writes a possibly schema-qualified name such as `schema.table`, quoting every
    /// `.`-separated part on its own.
    pub(crate) fn write_table(&mut self, ident: &str) -> std::fmt::Result {
        let mut parts = ident.split('.');
        if let Some(first) = parts.next() {
            self.write_ident(first)?;
        }
        for part in parts {
            self.writer.write_char('.')?;
            self.write_ident(part)?;
        }
        Ok(())
    }
}
/// Applies the emitter's post-processing `directives` to `query`'s parameters, then hands
/// the collected placeholder `indexes` to `rewrite_params`.
///
/// `Directive::RewriteGlob { id }` replaces the text parameter at `id` with the output of
/// `prepare_sqlite_glob`. When that yields a new string of the same byte length and the
/// original text is a `TextSource::Text` span, the stored bytes are overwritten in place;
/// otherwise the new text is interned and the parameter is repointed to it. A missing
/// parameter, or one that is not `Param::Text(Some(_))`, is left untouched.
fn finalize_query_params(query: &mut Query, directives: Vec<Directive>, indexes: &[usize]) {
    for directive in directives {
        match directive {
            Directive::RewriteGlob { id } => {
                let maybe_param = query.params.get_mut(id);
                if let Some(crate::param::Param::Text(Some(text_span))) = maybe_param {
                    let text = query.data.text(*text_span);
                    let value = prepare_sqlite_glob(text);
                    // Only an owned result needs writing back (presumably `Borrowed` means
                    // the pattern was already in the right form).
                    if let Cow::Owned(value) = value {
                        if let TextSource::Text(span) = text_span.0
                            && value.len() == text.len()
                        {
                            // Same length: overwrite the existing bytes in `query.data`.
                            let bytes = &mut query.data
                                [span.start as usize..span.start as usize + span.len as usize];
                            bytes.copy_from_slice(value.as_bytes());
                        } else {
                            let span = query.data.intern_text(&value);
                            *text_span = span;
                        }
                    }
                }
            }
        }
    }
    // NOTE(review): presumably reorders/duplicates parameters to follow the placeholder
    // order recorded in `indexes`; confirm against `query::rewrite_params`.
    rewrite_params(indexes, &mut query.params);
}
/// Types that can render themselves as SQL text into a [`FormatContext`].
pub trait FormatWriter {
    /// Writes this value's SQL into `context.writer` using `context.dialect`.
    fn format_writer<'w, W: Write>(&self, context: &mut FormatContext<'w, W>) -> std::fmt::Result;
}
impl<M> FormatWriter for Insert<M>
where
    M: Qrafting,
{
    /// Renders a `values` insert (the insert-select form is handled by `to_sql`).
    ///
    /// With no target columns this writes `default values`. Otherwise it writes one
    /// placeholder per entry in `self.query.params`, grouped into rows of `columns.len()`,
    /// followed by the conflict and `returning` clauses.
    fn format_writer<'w, W: Write>(&self, context: &mut FormatContext<'w, W>) -> std::fmt::Result {
        self.write_insert_prefix(context)?;
        let columns = self.columns.as_slice();
        if columns.is_empty() {
            context.writer.write_str(" default values")?;
        } else {
            self.write_columns(context, columns)?;
            context.writer.write_str(" values (")?;
            let mut placeholders = FormatValue {
                writer: context.writer,
                dialect: context.dialect,
                count: 0,
                field_count: columns.len(),
                result: Ok(()),
            };
            self.query
                .params
                .iter()
                .for_each(|_| placeholders.param("", ""));
            placeholders.result?;
            context.writer.write_char(')')?;
        }
        self.write_conflict(context)?;
        // The directives/indexes gathered while emitting `returning` are discarded on this
        // path; they are not applied to `self.query`.
        let (mut directives, mut indexes) = (Vec::new(), Vec::new());
        self.write_returning(context, &self.query, &mut directives, &mut indexes)
    }
}
impl<M, T> ReturningInsert<M, T>
where
    M: Qrafting,
    T: FromRow,
{
    /// Changes the row type of the `returning` rows to `R`; the SQL is unchanged.
    pub fn typed<R>(self) -> ReturningInsert<M, R>
    where
        R: FromRow,
    {
        let ReturningInsert { insert, .. } = self;
        ReturningInsert {
            insert,
            marker: PhantomData,
        }
    }

    /// Renders the statement for dialect `D` (see [`Insert::to_sql`]).
    pub fn to_sql<D: HasDialect>(&self) -> String {
        self.insert.to_sql::<D>()
    }

    /// Renders the statement and then passes it to `debug_params` of the query that owns
    /// the parameters: the finalized select clone for an insert-select, otherwise
    /// `insert.query`.
    pub fn to_debug_sql<D: HasDialect>(&self) -> String {
        if let Some(mut select) = self.insert.select.clone() {
            let mut sql = self.insert.render_select_sql::<D>(&mut select);
            select.debug_params(&mut sql).unwrap();
            return sql;
        }
        let mut sql = self.insert.to_sql::<D>();
        self.insert.query.debug_params(&mut sql).unwrap();
        sql
    }

    /// Compiles the statement for dialect `D` into SQL plus its parameters and data,
    /// typed by `T`.
    pub fn into_compiled<D: HasDialect>(self) -> TypedCompiled<T> {
        let mut insert = self.insert;
        // Rendering an insert-select does not read `select`, so taking it out is enough.
        let (sql, params, data) = match insert.select.take() {
            Some(mut select) => {
                let sql = insert.render_select_sql::<D>(&mut select);
                (sql, select.params, select.data)
            }
            None => {
                let sql = insert.to_sql::<D>();
                (sql, insert.query.params, insert.query.data)
            }
        };
        TypedCompiled {
            sql,
            params,
            data,
            marker: PhantomData,
        }
    }
}
/// Starts an insert into `table`; follow with `.values(..)` for a `values` insert or
/// `.columns(..).query(..)` for an insert-select.
pub fn insert_into<M>(table: Table<M>) -> InsertInto<M> {
    InsertInto { table }
}
// Implements `Insertable` for tuples whose elements are all `Insertable`: each element is
// visited in tuple order, so a tuple of `column.eq(value)` pairs forms one row.
macro_rules! impl_insertable_macro {
    ($($T:ident),+) => {
        impl<M, $($T,)+> Insertable<M> for ($($T,)+)
        where
            $($T: Insertable<M>,)+
        {
            fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
                // The destructuring bindings reuse the uppercase type-parameter names.
                #[allow(non_snake_case)]
                let ($($T,)+) = self;
                $(
                    $T.values(visitor);
                )+
            }
        }
    };
}
// Expands the impl for every tuple arity supported by `impl_for_all_tuples!`.
impl_for_all_tuples!(impl_insertable_macro);
// Rendering tests for the insert builder. The expected strings pin the exact SQL emitted
// for SQLite, Postgres and the `MySql` marker (whose output here matches the MariaDB
// dialect branches: backtick quoting, `insert ignore`, `on duplicate key update`).
#[cfg(test)]
mod tests {
    use super::insert_into;
    use crate::{
        MySql, Postgres, Sqlite,
        query::select,
        tests::{id, name, table, username},
    };
    #[test]
    fn test_insert_default_values() {
        let stmt = insert_into(table).values(()).to_sql::<Sqlite>();
        assert_eq!(stmt, r#"insert into "users" default values returning *"#);
    }
    #[test]
    fn test_insert_single_row_tuple_values() {
        let stmt = insert_into(table)
            .values((id.eq(1), name.eq("alice"), username.eq("alice1")))
            .to_sql::<Sqlite>();
        assert_eq!(
            stmt,
            r#"insert into "users" ("id", "name", "username") values (?, ?, ?) returning *"#
        );
    }
    // Rows are split every `columns.len()` placeholders.
    #[test]
    fn test_insert_multiple_rows_from_array() {
        let stmt = insert_into(table)
            .values([
                (id.eq(1), username.eq("alpha")),
                (id.eq(2), username.eq("beta")),
            ])
            .to_sql::<Sqlite>();
        assert_eq!(
            stmt,
            r#"insert into "users" ("id", "username") values (?, ?), (?, ?) returning *"#
        );
    }
    #[test]
    fn test_insert_postgres_uses_dollar_placeholders() {
        let stmt = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .to_sql::<Postgres>();
        assert_eq!(
            stmt,
            r#"insert into "users" ("id", "username") values ($1, $2) returning *"#
        );
    }
    #[test]
    fn test_insert_ignore_sqlite() {
        let stmt = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .ignore()
            .to_sql::<Sqlite>();
        assert_eq!(
            stmt,
            r#"insert or ignore into "users" ("id", "username") values (?, ?) returning *"#
        );
    }
    #[test]
    fn test_insert_ignore_postgres() {
        let stmt = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .ignore()
            .to_sql::<Postgres>();
        assert_eq!(
            stmt,
            r#"insert into "users" ("id", "username") values ($1, $2) on conflict do nothing returning *"#
        );
    }
    #[test]
    fn test_insert_ignore_mariadb() {
        let stmt = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .ignore()
            .to_sql::<MySql>();
        assert_eq!(
            stmt,
            "insert ignore into `users` (`id`, `username`) values (?, ?) returning *"
        );
    }
    #[test]
    fn test_insert_upsert_postgres() {
        let stmt = insert_into(table)
            .values((id.eq(10), username.eq("hello"), name.eq("lea")))
            .upsert(["id"], ["username", "name"])
            .to_sql::<Postgres>();
        assert_eq!(
            stmt,
            r#"insert into "users" ("id", "username", "name") values ($1, $2, $3) on conflict ("id") do update set "username" = "excluded"."username", "name" = "excluded"."name" returning *"#
        );
    }
    #[test]
    fn test_insert_upsert_sqlite() {
        let stmt = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .upsert(["id"], ["username"])
            .to_sql::<Sqlite>();
        assert_eq!(
            stmt,
            r#"insert into "users" ("id", "username") values (?, ?) on conflict ("id") do update set "username" = "excluded"."username" returning *"#
        );
    }
    // MariaDB ignores `unique_by` and uses `values(col)` for the incoming value.
    #[test]
    fn test_insert_upsert_mariadb() {
        let stmt = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .upsert(["id"], ["username"])
            .to_sql::<MySql>();
        assert_eq!(
            stmt,
            "insert into `users` (`id`, `username`) values (?, ?) on duplicate key update `username` = values(`username`) returning *"
        );
    }
    #[test]
    fn test_insert_upsert_accepts_tuple_and_vec_columns() {
        let unique_by = ("id", "username");
        let update = vec!["name", "username"];
        let stmt = insert_into(table)
            .values((id.eq(10), username.eq("hello"), name.eq("lea")))
            .upsert(unique_by, update)
            .to_sql::<Postgres>();
        assert_eq!(
            stmt,
            r#"insert into "users" ("id", "username", "name") values ($1, $2, $3) on conflict ("id", "username") do update set "name" = "excluded"."name", "username" = "excluded"."username" returning *"#
        );
    }
    #[test]
    fn test_insert_returning_all_is_explicit() {
        let stmt = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .returning_all()
            .to_sql::<Postgres>();
        assert_eq!(
            stmt,
            r#"insert into "users" ("id", "username") values ($1, $2) returning *"#
        );
    }
    #[test]
    fn test_insert_returning_projection_matches_update_style() {
        let stmt = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .returning((id, username))
            .to_sql::<Postgres>();
        assert_eq!(
            stmt,
            r#"insert into "users" ("id", "username") values ($1, $2) returning "users"."id", "users"."username""#
        );
    }
    #[test]
    fn test_insert_select_sqlite() {
        let stmt = insert_into(table)
            .columns((id, username))
            .query(
                select((id, username))
                    .from(table)
                    .filter(username.eq("alice")),
            )
            .to_sql::<Sqlite>();
        assert_eq!(
            stmt,
            r#"insert into "users" ("id", "username") select "users"."id", "users"."username" from "users" where "users"."username" = ? returning *"#
        );
    }
    #[test]
    fn test_insert_select_postgres_with_filter_and_upsert() {
        let stmt = insert_into(table)
            .columns((id, username))
            .query(select((id, username)).from(table).filter(id.eq(10)))
            .upsert(["id"], ["username"])
            .to_sql::<Postgres>();
        assert_eq!(
            stmt,
            r#"insert into "users" ("id", "username") select "users"."id", "users"."username" from "users" where "users"."id" = $1 on conflict ("id") do update set "username" = "excluded"."username" returning *"#
        );
    }
    #[test]
    fn test_insert_select_mariadb_ignore_no_returning() {
        let stmt = insert_into(table)
            .columns((id, username))
            .query(select((id, username)).from(table))
            .ignore()
            .no_returning()
            .to_sql::<MySql>();
        assert_eq!(
            stmt,
            "insert ignore into `users` (`id`, `username`) select `users`.`id`, `users`.`username` from `users`"
        );
    }
}