1use std::{borrow::Cow, fmt::Write, marker::PhantomData};
4
5use crate::{
6 Compatible, Dialect, FromRow, HasDialect, Qrafting, Query, QueryOf, RpnInstr, TypeMeta,
7 emitter::{Directive, Emitter},
8 expression::{Binary, Column, op, prepare_sqlite_glob},
9 impl_for_all_tuples,
10 lower::{Data, LowerCtx},
11 param::{Param, encode_param},
12 query::{
13 LockedQuery, LockedQueryOf, LowerProject, Select, Table, TypedCompiled, WithSelect,
14 rewrite_params,
15 },
16 relation::{ModelField, PersistedField},
17 span::TextSource,
18};
19
20pub trait Insertable<M> {
22 fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>);
24
25 fn insert_into(self, table: Table<M>) -> Insert<M>
27 where
28 M: Qrafting,
29 Self: Sized,
30 {
31 crate::insert_into(table).values(self)
32 }
33}
34
35pub struct InsertInto<M> {
37 table: Table<M>,
38}
39
40pub struct InsertColumnsInto<M> {
42 table: Table<M>,
43 columns: Vec<&'static str>,
44}
45
46pub struct Insert<M> {
48 pub table: Table<M>,
50 pub columns: Vec<&'static str>,
52 pub select: Option<Query>,
54 pub returning: InsertReturning,
56 pub conflict: InsertConflict,
58 pub query: Query,
60}
61
62pub enum InsertReturning {
64 None,
66 All,
68 Projection(Vec<RpnInstr>),
70}
71
72pub struct ReturningInsert<M, T> {
74 pub insert: Insert<M>,
76 pub marker: PhantomData<T>,
78}
79
/// Conflict-handling strategy of an [`Insert`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InsertConflict {
    /// No conflict clause is written.
    None,
    /// Conflicting rows are skipped, using the dialect's "ignore" form.
    Ignore,
    /// Conflicting rows are updated from the values that were about to be inserted.
    Upsert {
        /// Columns identifying a conflict.
        unique_by: Vec<&'static str>,
        /// Columns overwritten when a conflict occurs.
        update: Vec<&'static str>,
    },
}
90
/// Visitor that records the column names an [`Insertable`] reports, without looking at the
/// values themselves.
#[derive(Default)]
pub struct FieldCollector {
    /// Column names gathered so far.
    fields: Vec<&'static str>,
}

impl FieldCollector {
    /// Creates a collector with no recorded columns.
    pub fn new() -> Self {
        Self { fields: Vec::new() }
    }

    /// Consumes the collector and returns the recorded column names.
    pub fn into_inner(self) -> Vec<&'static str> {
        self.fields
    }
}
112
113impl<'a> VisitParam<'a> for FieldCollector {
114 fn param(&mut self, field: &'static str, _param: impl quex::Encode) {
115 if !self.fields.contains(&field) {
116 self.fields.push(field);
117 }
118 }
119
120 fn param_typed<T>(&mut self, field: &'static str, _value: &'a (impl Compatible<T> + ?Sized))
121 where
122 T: TypeMeta,
123 {
124 if !self.fields.contains(&field) {
125 self.fields.push(field);
126 }
127 }
128}
129
130impl<M> Insert<M>
131where
132 M: Qrafting,
133{
134 pub fn to_sql<D: HasDialect>(&self) -> String {
136 match self.select.clone() {
137 Some(mut query) => self.render_select_sql::<D>(&mut query),
138 None => {
139 let mut writer = String::new();
140 let mut ctx = FormatContext {
141 writer: &mut writer,
142 index: 0,
143 dialect: D::DIALECT,
144 };
145 self.format_writer(&mut ctx)
146 .expect("cannot fail on a string writer");
147 writer
148 }
149 }
150 }
151
152 pub fn ignore(mut self) -> Self {
154 self.conflict = InsertConflict::Ignore;
155 self
156 }
157
158 pub fn returning<P>(mut self, project: P) -> ReturningInsert<M, M>
160 where
161 P: LowerProject,
162 M: FromRow,
163 {
164 let mut instrs = Vec::new();
165 match self.select.as_mut() {
166 Some(query) => {
167 let mut ctx = LowerCtx {
168 instrs: &mut instrs,
169 params: &mut query.params,
170 data: &mut query.data,
171 };
172 project.lower_project(&mut ctx);
173 }
174 None => {
175 let mut ctx = LowerCtx {
176 instrs: &mut instrs,
177 params: &mut self.query.params,
178 data: &mut self.query.data,
179 };
180 project.lower_project(&mut ctx);
181 }
182 }
183 self.returning = InsertReturning::Projection(instrs);
184 ReturningInsert {
185 insert: self,
186 marker: PhantomData,
187 }
188 }
189
190 pub fn returning_all(mut self) -> ReturningInsert<M, M>
192 where
193 M: FromRow,
194 {
195 self.returning = InsertReturning::All;
196 ReturningInsert {
197 insert: self,
198 marker: PhantomData,
199 }
200 }
201
202 pub fn no_returning(mut self) -> Self {
204 self.returning = InsertReturning::None;
205 self
206 }
207
208 pub fn upsert<U, C>(mut self, unique_by: C, update: U) -> Self
210 where
211 C: InsertColumns,
212 U: InsertColumns,
213 {
214 self.conflict = InsertConflict::Upsert {
215 unique_by: unique_by.into_columns(),
216 update: update.into_columns(),
217 };
218 self
219 }
220
221 #[doc(hidden)]
222 pub fn into_select_compiled<D: HasDialect>(self) -> Option<TypedCompiled<M>> {
223 let Insert {
224 table,
225 columns,
226 select,
227 returning,
228 conflict,
229 query: insert_query,
230 } = self;
231 let mut query = select?;
232 let insert = Insert {
233 table,
234 columns,
235 select: None,
236 returning,
237 conflict,
238 query: insert_query,
239 };
240 let sql = insert.render_select_sql::<D>(&mut query);
241 Some(TypedCompiled {
242 sql,
243 params: query.params,
244 data: query.data,
245 marker: PhantomData,
246 })
247 }
248
249 fn render_select_sql<D: HasDialect>(&self, query: &mut Query) -> String {
250 let columns = self.columns.as_slice();
251 assert!(
252 !columns.is_empty(),
253 "insert-select requires at least one target column"
254 );
255
256 let mut writer = String::new();
257 let mut directives = Vec::new();
258 let mut indexes = Vec::new();
259
260 {
261 let mut emitter = Emitter::new(
262 &mut writer,
263 &query.data,
264 D::DIALECT,
265 &mut directives,
266 &mut indexes,
267 );
268 emitter.emit_ctes_for_query(query).unwrap();
269 }
270
271 let mut ctx = FormatContext {
272 writer: &mut writer,
273 index: 0,
274 dialect: D::DIALECT,
275 };
276
277 self.write_insert_prefix(&mut ctx).unwrap();
278 self.write_columns(&mut ctx, columns).unwrap();
279 ctx.writer.write_char(' ').unwrap();
280
281 {
282 let writer = &mut *ctx.writer;
283 let mut emitter = Emitter::new(
284 writer,
285 &query.data,
286 D::DIALECT,
287 &mut directives,
288 &mut indexes,
289 );
290 emitter.emit_query_body(query).unwrap();
291 }
292
293 self.write_conflict(&mut ctx).unwrap();
294 self.write_returning(&mut ctx, query, &mut directives, &mut indexes)
295 .unwrap();
296
297 finalize_query_params(query, directives, &indexes);
298
299 writer
300 }
301
302 fn write_insert_prefix<'w, W: Write>(
303 &self,
304 context: &mut FormatContext<'w, W>,
305 ) -> std::fmt::Result {
306 match (&self.conflict, context.dialect) {
307 (InsertConflict::Ignore, Dialect::Sqlite) => {
308 context.writer.write_str("insert or ignore into ")
309 }
310 (InsertConflict::Ignore, Dialect::MariaDb) => {
311 context.writer.write_str("insert ignore into ")
312 }
313 _ => context.writer.write_str("insert into "),
314 }?;
315 self.table.format_writer(context)
316 }
317
318 fn write_columns<'w, W: Write>(
319 &self,
320 context: &mut FormatContext<'w, W>,
321 fields: &[&'static str],
322 ) -> std::fmt::Result {
323 context.writer.write_str(" (")?;
324 for (i, field) in fields.iter().enumerate() {
325 if i > 0 {
326 context.writer.write_str(", ")?;
327 }
328 context.write_ident(field)?;
329 }
330 context.writer.write_char(')')
331 }
332
333 fn write_conflict<'w, W: Write>(&self, context: &mut FormatContext<'w, W>) -> std::fmt::Result {
334 match (&self.conflict, context.dialect) {
335 (InsertConflict::None, _)
336 | (InsertConflict::Ignore, Dialect::Sqlite | Dialect::MariaDb) => {}
337 (InsertConflict::Ignore, Dialect::Postgres) => {
338 context.writer.write_str(" on conflict do nothing")?;
339 }
340 (InsertConflict::Upsert { unique_by, update }, Dialect::Postgres | Dialect::Sqlite) => {
341 assert!(
342 !unique_by.is_empty(),
343 "upsert requires at least one conflict column"
344 );
345 assert!(
346 !update.is_empty(),
347 "upsert requires at least one update column"
348 );
349 context.writer.write_str(" on conflict (")?;
350 for (i, field) in unique_by.iter().enumerate() {
351 if i > 0 {
352 context.writer.write_str(", ")?;
353 }
354 context.write_ident(field)?;
355 }
356 context.writer.write_str(") do update set ")?;
357 for (i, field) in update.iter().enumerate() {
358 if i > 0 {
359 context.writer.write_str(", ")?;
360 }
361 context.write_ident(field)?;
362 context.writer.write_str(" = ")?;
363 context.write_ident("excluded")?;
364 context.writer.write_char('.')?;
365 context.write_ident(field)?;
366 }
367 }
368 (InsertConflict::Upsert { update, .. }, Dialect::MariaDb) => {
369 assert!(
370 !update.is_empty(),
371 "upsert requires at least one update column"
372 );
373 context.writer.write_str(" on duplicate key update ")?;
374 for (i, field) in update.iter().enumerate() {
375 if i > 0 {
376 context.writer.write_str(", ")?;
377 }
378 context.write_ident(field)?;
379 context.writer.write_str(" = values(")?;
380 context.write_ident(field)?;
381 context.writer.write_char(')')?;
382 }
383 }
384 }
385
386 Ok(())
387 }
388
389 fn write_returning<'w, W: Write>(
390 &self,
391 context: &mut FormatContext<'w, W>,
392 query: &Query,
393 directives: &mut Vec<Directive>,
394 indexes: &mut Vec<usize>,
395 ) -> std::fmt::Result {
396 match &self.returning {
397 InsertReturning::None => Ok(()),
398 InsertReturning::All => context.writer.write_str(" returning *"),
399 InsertReturning::Projection(instrs) => {
400 context.writer.write_str(" returning ")?;
401 let mut emitter = Emitter::new(
402 context.writer,
403 &query.data,
404 context.dialect,
405 directives,
406 indexes,
407 );
408 emitter.emit_instrs(instrs)
409 }
410 }
411 }
412}
413
414impl<M> InsertInto<M> {
415 pub fn values<V>(self, values: V) -> Insert<M>
417 where
418 M: Qrafting,
419 V: Insertable<M>,
420 {
421 let mut query = Query::default();
422 let mut field_collector = FieldCollector::new();
423 values.values(&mut field_collector);
424 let columns = field_collector.into_inner();
425
426 let mut param_collector = QueryParamCollector {
427 params: &mut query.params,
428 data: &mut query.data,
429 };
430 values.values(&mut param_collector);
431
432 Insert {
433 table: self.table,
434 columns,
435 select: None,
436 returning: InsertReturning::All,
437 conflict: InsertConflict::None,
438 query,
439 }
440 }
441
442 pub fn columns<C>(self, columns: C) -> InsertColumnsInto<M>
444 where
445 C: InsertColumns,
446 {
447 let columns = columns.into_columns();
448 assert!(
449 !columns.is_empty(),
450 "insert-select requires at least one target column"
451 );
452 InsertColumnsInto {
453 table: self.table,
454 columns,
455 }
456 }
457}
458
459impl<M> InsertColumnsInto<M> {
460 pub fn query<Q>(self, query: Q) -> Insert<M>
462 where
463 Q: IntoInsertSelectQuery,
464 {
465 Insert {
466 table: self.table,
467 columns: self.columns,
468 select: Some(query.into_insert_select_query()),
469 returning: InsertReturning::All,
470 conflict: InsertConflict::None,
471 query: Query::default(),
472 }
473 }
474}
475
/// Anything that can be turned into a list of column names.
pub trait InsertColumns {
    /// Consumes `self` and returns the column names it denotes, in order.
    fn into_columns(self) -> Vec<&'static str>;
}
481
/// Anything that names exactly one column.
pub trait IntoInsertColumn {
    /// Consumes `self` and returns the column name it denotes.
    fn into_insert_column(self) -> &'static str;
}
486
487pub trait IntoInsertSelectQuery {
489 fn into_insert_select_query(self) -> Query;
490}
491
492impl IntoInsertSelectQuery for Query {
493 fn into_insert_select_query(self) -> Query {
494 self
495 }
496}
497
498impl<M> IntoInsertSelectQuery for QueryOf<M> {
499 fn into_insert_select_query(self) -> Query {
500 self.into()
501 }
502}
503
504impl IntoInsertSelectQuery for LockedQuery {
505 fn into_insert_select_query(self) -> Query {
506 self.into()
507 }
508}
509
510impl<M> IntoInsertSelectQuery for LockedQueryOf<M> {
511 fn into_insert_select_query(self) -> Query {
512 self.into()
513 }
514}
515
516impl<P> IntoInsertSelectQuery for Select<P>
517where
518 P: LowerProject,
519{
520 fn into_insert_select_query(self) -> Query {
521 self.into_query()
522 }
523}
524
525impl<P> IntoInsertSelectQuery for WithSelect<P>
526where
527 P: LowerProject,
528{
529 fn into_insert_select_query(self) -> Query {
530 self.into_query()
531 }
532}
533
534impl IntoInsertColumn for &'static str {
535 fn into_insert_column(self) -> &'static str {
536 self
537 }
538}
539
540impl<M, T> IntoInsertColumn for Column<M, T>
541where
542 T: TypeMeta,
543{
544 fn into_insert_column(self) -> &'static str {
545 self.name
546 }
547}
548
549impl<M, V, T> IntoInsertColumn for ModelField<M, V, T>
550where
551 T: TypeMeta,
552{
553 fn into_insert_column(self) -> &'static str {
554 self.name()
555 }
556}
557
558impl<M, V, T, K> IntoInsertColumn for PersistedField<M, V, T, K>
559where
560 T: TypeMeta,
561{
562 fn into_insert_column(self) -> &'static str {
563 self.name()
564 }
565}
566
567impl InsertColumns for &'static str {
568 fn into_columns(self) -> Vec<&'static str> {
569 vec![self]
570 }
571}
572
573impl<M, T> InsertColumns for Column<M, T>
574where
575 T: TypeMeta,
576{
577 fn into_columns(self) -> Vec<&'static str> {
578 vec![self.name]
579 }
580}
581
582impl<const N: usize> InsertColumns for [&'static str; N] {
583 fn into_columns(self) -> Vec<&'static str> {
584 self.into_iter().collect()
585 }
586}
587
588impl InsertColumns for &[&'static str] {
589 fn into_columns(self) -> Vec<&'static str> {
590 self.to_vec()
591 }
592}
593
594impl InsertColumns for Vec<&'static str> {
595 fn into_columns(self) -> Vec<&'static str> {
596 self
597 }
598}
599
600macro_rules! impl_insert_columns_tuple {
601 ($($T:ident),+) => {
602 impl<$($T,)+> InsertColumns for ($($T,)+)
603 where
604 $($T: IntoInsertColumn,)+
605 {
606 fn into_columns(self) -> Vec<&'static str> {
607 #[allow(non_snake_case)]
608 let ($($T,)+) = self;
609 vec![$($T.into_insert_column(),)+]
610 }
611 }
612 };
613}
614
615impl_for_all_tuples!(impl_insert_columns_tuple);
616
617pub trait VisitParam<'v> {
619 fn param(&mut self, field: &'static str, _param: impl quex::Encode);
621
622 fn param_typed<T>(&mut self, field: &'static str, value: &'v (impl Compatible<T> + ?Sized))
624 where
625 T: TypeMeta,
626 {
627 self.param(field, value)
628 }
629}
630
631impl<M, T, R> Insertable<M> for Binary<op::Eq, Column<M, T>, R>
633where
634 T: TypeMeta,
635 M: Qrafting,
636 R: Compatible<T>,
637{
638 fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
639 visitor.param_typed::<T>(self.left.name, &self.right);
640 }
641}
642
643impl<M, V, T, R> Insertable<M> for Binary<op::Eq, ModelField<M, V, T>, R>
644where
645 T: TypeMeta,
646 M: Qrafting,
647 R: Compatible<T>,
648{
649 fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
650 visitor.param_typed::<T>(self.left.name(), &self.right);
651 }
652}
653
654impl<M, V, T, K, R> Insertable<M> for Binary<op::Eq, PersistedField<M, V, T, K>, R>
655where
656 T: TypeMeta,
657 M: Qrafting,
658 R: Compatible<T>,
659{
660 fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
661 visitor.param_typed::<T>(self.left.name(), &self.right);
662 }
663}
664
665impl<M> Insertable<M> for () {
666 fn values<'v>(&'v self, _visitor: &mut impl VisitParam<'v>) {}
667}
668
669impl<M, T> Insertable<M> for &[T]
670where
671 T: Insertable<M>,
672{
673 fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
674 for item in self.iter() {
675 item.values(visitor);
676 }
677 }
678}
679
680impl<M, T> Insertable<M> for Vec<T>
681where
682 T: Insertable<M>,
683{
684 fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
685 for item in self.iter() {
686 item.values(visitor);
687 }
688 }
689}
690
691impl<M, T, const N: usize> Insertable<M> for [T; N]
692where
693 T: Insertable<M>,
694{
695 fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
696 for item in self.iter() {
697 item.values(visitor);
698 }
699 }
700}
701
702impl<M, T> Insertable<M> for &T
703where
704 T: Insertable<M>,
705{
706 fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
707 (**self).values(visitor)
708 }
709}
710
711impl<M, T> Insertable<M> for [T]
712where
713 T: Insertable<M>,
714{
715 fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
716 for item in self.iter() {
717 item.values(visitor);
718 }
719 }
720}
721
722pub struct FormatValue<'w, W: Write> {
724 writer: &'w mut W,
725 dialect: Dialect,
726 count: usize,
727 field_count: usize,
728 result: std::fmt::Result,
729}
730
731struct QueryParamCollector<'q> {
732 params: &'q mut Vec<Param>,
733 data: &'q mut Vec<u8>,
734}
735
736impl<'q, 'v> VisitParam<'v> for QueryParamCollector<'q> {
737 fn param(&mut self, _field: &'static str, param: impl quex::Encode) {
738 self.params.push(encode_param(¶m, self.data));
739 }
740
741 fn param_typed<T>(&mut self, _field: &'static str, value: &'v (impl Compatible<T> + ?Sized))
742 where
743 T: TypeMeta,
744 {
745 self.params.push(encode_param(value, self.data));
746 }
747}
748
749impl<'w, 'v, W: Write> VisitParam<'v> for FormatValue<'w, W> {
750 fn param(&mut self, _field: &'static str, _param: impl quex::Encode) {
751 if self.result.is_err() {
752 return;
753 }
754 if self.count == 0 {
755 } else if self.count.is_multiple_of(self.field_count) {
757 self.result = self.writer.write_str("), (");
758 } else {
759 self.result = self.writer.write_str(", ");
760 }
761 if self.result.is_err() {
762 return;
763 }
764 self.count += 1;
765 self.result = match self.dialect {
766 Dialect::Postgres => write!(self.writer, "${}", self.count),
767 Dialect::MariaDb | Dialect::Sqlite => self.writer.write_char('?'),
768 };
769 }
770
771 fn param_typed<T>(&mut self, _field: &'static str, _value: &'v (impl Compatible<T> + ?Sized))
772 where
773 T: TypeMeta,
774 {
775 self.param("", "")
776 }
777}
778
779pub struct FormatContext<'w, W: Write> {
781 pub writer: &'w mut W,
782 pub index: usize,
783 pub dialect: Dialect,
784}
785
786impl<'w, W: Write> FormatContext<'w, W> {
787 pub(crate) fn write_ident(&mut self, part: &str) -> std::fmt::Result {
788 if part == "*" {
789 return self.writer.write_char('*');
790 }
791
792 let quote = match self.dialect {
793 Dialect::Postgres | Dialect::Sqlite => '"',
794 Dialect::MariaDb => '`',
795 };
796 self.writer.write_char(quote)?;
797 let dbl = if quote == '"' { "\"\"" } else { "``" };
799
800 let mut last = 0;
801 for (index, char) in part.char_indices() {
802 if char == quote {
803 if index != last {
804 self.writer.write_str(&part[last..index])?;
805 }
806 self.writer.write_str(dbl)?;
807 last = index + char.len_utf8();
808 }
809 }
810
811 if last < part.len() {
813 self.writer.write_str(&part[last..])?;
814 }
815
816 self.writer.write_char(quote)?;
817 Ok(())
818 }
819
820 pub(crate) fn write_table(&mut self, ident: &str) -> std::fmt::Result {
821 for (i, part) in ident.split('.').enumerate() {
822 if i > 0 {
823 self.writer.write_char('.')?;
824 }
825 self.write_ident(part)?;
826 }
827 Ok(())
828 }
829}
830
831fn finalize_query_params(query: &mut Query, directives: Vec<Directive>, indexes: &[usize]) {
832 for directive in directives {
833 match directive {
834 Directive::RewriteGlob { id } => {
835 let maybe_param = query.params.get_mut(id);
836 if let Some(crate::param::Param::Text(Some(text_span))) = maybe_param {
837 let text = query.data.text(*text_span);
838 let value = prepare_sqlite_glob(text);
839 if let Cow::Owned(value) = value {
840 if let TextSource::Text(span) = text_span.0
841 && value.len() == text.len()
842 {
843 let bytes = &mut query.data
844 [span.start as usize..span.start as usize + span.len as usize];
845 bytes.copy_from_slice(value.as_bytes());
846 } else {
847 let span = query.data.intern_text(&value);
848 *text_span = span;
849 }
850 }
851 }
852 }
853 }
854 }
855
856 rewrite_params(indexes, &mut query.params);
857}
858
859pub trait FormatWriter {
861 fn format_writer<'w, W: Write>(&self, context: &mut FormatContext<'w, W>) -> std::fmt::Result;
863}
864
865impl<M> FormatWriter for Insert<M>
866where
867 M: Qrafting,
868{
869 fn format_writer<'w, W: Write>(&self, context: &mut FormatContext<'w, W>) -> std::fmt::Result {
870 self.write_insert_prefix(context)?;
871 let fields = self.columns.as_slice();
872 if fields.is_empty() {
873 context.writer.write_str(" default values")?;
874 } else {
875 self.write_columns(context, fields)?;
876 context.writer.write_str(" values (")?;
877 let mut value_formatter = FormatValue {
878 writer: context.writer,
879 count: 0,
880 field_count: fields.len(),
881 dialect: context.dialect,
882 result: Ok(()),
883 };
884 for _ in &self.query.params {
885 value_formatter.param("", "");
886 }
887 value_formatter.result?;
888 context.writer.write_char(')')?;
889 }
890 self.write_conflict(context)?;
891 let mut directives = Vec::new();
892 let mut indexes = Vec::new();
893 self.write_returning(context, &self.query, &mut directives, &mut indexes)?;
894 Ok(())
895 }
896}
897
898impl<M, T> ReturningInsert<M, T>
899where
900 M: Qrafting,
901 T: FromRow,
902{
903 pub fn typed<R>(self) -> ReturningInsert<M, R>
905 where
906 R: FromRow,
907 {
908 ReturningInsert {
909 insert: self.insert,
910 marker: PhantomData,
911 }
912 }
913
914 pub fn to_sql<D: HasDialect>(&self) -> String {
916 self.insert.to_sql::<D>()
917 }
918
919 pub fn to_debug_sql<D: HasDialect>(&self) -> String {
921 match self.insert.select.clone() {
922 Some(mut query) => {
923 let mut sql = self.insert.render_select_sql::<D>(&mut query);
924 query.debug_params(&mut sql).unwrap();
925 sql
926 }
927 None => {
928 let mut sql = self.insert.to_sql::<D>();
929 self.insert.query.debug_params(&mut sql).unwrap();
930 sql
931 }
932 }
933 }
934
935 pub fn into_compiled<D: HasDialect>(self) -> TypedCompiled<T> {
937 let insert = self.insert;
938 match insert.select {
939 Some(mut query) => {
940 let insert = Insert {
941 table: insert.table,
942 columns: insert.columns,
943 select: None,
944 returning: insert.returning,
945 conflict: insert.conflict,
946 query: insert.query,
947 };
948 let sql = insert.render_select_sql::<D>(&mut query);
949 TypedCompiled {
950 sql,
951 params: query.params,
952 data: query.data,
953 marker: PhantomData,
954 }
955 }
956 None => TypedCompiled {
957 sql: insert.to_sql::<D>(),
958 params: insert.query.params,
959 data: insert.query.data,
960 marker: PhantomData,
961 },
962 }
963 }
964}
965
966pub fn insert_into<M>(table: Table<M>) -> InsertInto<M> {
998 InsertInto { table }
999}
1000
1001macro_rules! impl_insertable_macro {
1002 ($($T:ident),+) => {
1003 impl<M, $($T,)+> Insertable<M> for ($($T,)+)
1004 where
1005 $($T: Insertable<M>,)+
1006 {
1007 fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
1008 #[allow(non_snake_case)]
1009 let ($($T,)+) = self;
1010 $(
1011 $T.values(visitor);
1012 )+
1013 }
1014 }
1015 };
1016}
1017
1018impl_for_all_tuples!(impl_insertable_macro);
1019
#[cfg(test)]
mod tests {
    use super::insert_into;
    use crate::{
        MySql, Postgres, Sqlite,
        query::select,
        tests::{id, name, table, username},
    };

    /// An insert without any column falls back to `default values`.
    #[test]
    fn test_insert_default_values() {
        let sql = insert_into(table).values(()).to_sql::<Sqlite>();
        assert_eq!(sql, r#"insert into "users" default values returning *"#);
    }

    /// A tuple of assignments forms a single row.
    #[test]
    fn test_insert_single_row_tuple_values() {
        let row = (id.eq(1), name.eq("alice"), username.eq("alice1"));
        let sql = insert_into(table).values(row).to_sql::<Sqlite>();

        assert_eq!(
            sql,
            r#"insert into "users" ("id", "name", "username") values (?, ?, ?) returning *"#
        );
    }

    /// An array of tuples forms one row per element.
    #[test]
    fn test_insert_multiple_rows_from_array() {
        let rows = [
            (id.eq(1), username.eq("alpha")),
            (id.eq(2), username.eq("beta")),
        ];
        let sql = insert_into(table).values(rows).to_sql::<Sqlite>();

        assert_eq!(
            sql,
            r#"insert into "users" ("id", "username") values (?, ?), (?, ?) returning *"#
        );
    }

    /// PostgreSQL placeholders are numbered.
    #[test]
    fn test_insert_postgres_uses_dollar_placeholders() {
        let sql = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .to_sql::<Postgres>();

        assert_eq!(
            sql,
            r#"insert into "users" ("id", "username") values ($1, $2) returning *"#
        );
    }

    /// SQLite expresses ignored conflicts in the insert prefix.
    #[test]
    fn test_insert_ignore_sqlite() {
        let sql = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .ignore()
            .to_sql::<Sqlite>();

        assert_eq!(
            sql,
            r#"insert or ignore into "users" ("id", "username") values (?, ?) returning *"#
        );
    }

    /// PostgreSQL expresses ignored conflicts as a trailing clause.
    #[test]
    fn test_insert_ignore_postgres() {
        let sql = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .ignore()
            .to_sql::<Postgres>();

        assert_eq!(
            sql,
            r#"insert into "users" ("id", "username") values ($1, $2) on conflict do nothing returning *"#
        );
    }

    /// MariaDB expresses ignored conflicts in the insert prefix.
    #[test]
    fn test_insert_ignore_mariadb() {
        let sql = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .ignore()
            .to_sql::<MySql>();

        assert_eq!(
            sql,
            "insert ignore into `users` (`id`, `username`) values (?, ?) returning *"
        );
    }

    /// PostgreSQL upserts update from the `excluded` row.
    #[test]
    fn test_insert_upsert_postgres() {
        let sql = insert_into(table)
            .values((id.eq(10), username.eq("hello"), name.eq("lea")))
            .upsert(["id"], ["username", "name"])
            .to_sql::<Postgres>();

        assert_eq!(
            sql,
            r#"insert into "users" ("id", "username", "name") values ($1, $2, $3) on conflict ("id") do update set "username" = "excluded"."username", "name" = "excluded"."name" returning *"#
        );
    }

    /// SQLite upserts use the same clause as PostgreSQL.
    #[test]
    fn test_insert_upsert_sqlite() {
        let sql = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .upsert(["id"], ["username"])
            .to_sql::<Sqlite>();

        assert_eq!(
            sql,
            r#"insert into "users" ("id", "username") values (?, ?) on conflict ("id") do update set "username" = "excluded"."username" returning *"#
        );
    }

    /// MariaDB upserts use `on duplicate key update` and ignore the conflict columns.
    #[test]
    fn test_insert_upsert_mariadb() {
        let sql = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .upsert(["id"], ["username"])
            .to_sql::<MySql>();

        assert_eq!(
            sql,
            "insert into `users` (`id`, `username`) values (?, ?) on duplicate key update `username` = values(`username`) returning *"
        );
    }

    /// Column lists may be given as tuples and vectors as well as arrays.
    #[test]
    fn test_insert_upsert_accepts_tuple_and_vec_columns() {
        let unique_by = ("id", "username");
        let update = vec!["name", "username"];

        let sql = insert_into(table)
            .values((id.eq(10), username.eq("hello"), name.eq("lea")))
            .upsert(unique_by, update)
            .to_sql::<Postgres>();

        assert_eq!(
            sql,
            r#"insert into "users" ("id", "username", "name") values ($1, $2, $3) on conflict ("id", "username") do update set "name" = "excluded"."name", "username" = "excluded"."username" returning *"#
        );
    }

    /// Asking for all columns explicitly renders the same clause as the default.
    #[test]
    fn test_insert_returning_all_is_explicit() {
        let sql = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .returning_all()
            .to_sql::<Postgres>();

        assert_eq!(
            sql,
            r#"insert into "users" ("id", "username") values ($1, $2) returning *"#
        );
    }

    /// A projection is rendered as qualified columns.
    #[test]
    fn test_insert_returning_projection_matches_update_style() {
        let sql = insert_into(table)
            .values((id.eq(10), username.eq("hello")))
            .returning((id, username))
            .to_sql::<Postgres>();

        assert_eq!(
            sql,
            r#"insert into "users" ("id", "username") values ($1, $2) returning "users"."id", "users"."username""#
        );
    }

    /// An insert-select renders the select after the column list.
    #[test]
    fn test_insert_select_sqlite() {
        let source = select((id, username))
            .from(table)
            .filter(username.eq("alice"));
        let sql = insert_into(table)
            .columns((id, username))
            .query(source)
            .to_sql::<Sqlite>();

        assert_eq!(
            sql,
            r#"insert into "users" ("id", "username") select "users"."id", "users"."username" from "users" where "users"."username" = ? returning *"#
        );
    }

    /// The conflict clause follows the select of an insert-select.
    #[test]
    fn test_insert_select_postgres_with_filter_and_upsert() {
        let sql = insert_into(table)
            .columns((id, username))
            .query(select((id, username)).from(table).filter(id.eq(10)))
            .upsert(["id"], ["username"])
            .to_sql::<Postgres>();

        assert_eq!(
            sql,
            r#"insert into "users" ("id", "username") select "users"."id", "users"."username" from "users" where "users"."id" = $1 on conflict ("id") do update set "username" = "excluded"."username" returning *"#
        );
    }

    /// Dropping the returning clause leaves the statement ending with the select.
    #[test]
    fn test_insert_select_mariadb_ignore_no_returning() {
        let sql = insert_into(table)
            .columns((id, username))
            .query(select((id, username)).from(table))
            .ignore()
            .no_returning()
            .to_sql::<MySql>();

        assert_eq!(
            sql,
            "insert ignore into `users` (`id`, `username`) select `users`.`id`, `users`.`username` from `users`"
        );
    }
}
1239}