Skip to main content

qraft_core/builder/
insert.rs

1//! Insert statement builders and conflict helpers.
2
3use std::{borrow::Cow, fmt::Write, marker::PhantomData};
4
5use crate::{
6    Compatible, Dialect, FromRow, HasDialect, Qrafting, Query, QueryOf, RpnInstr, TypeMeta,
7    emitter::{Directive, Emitter},
8    expression::{Binary, Column, op, prepare_sqlite_glob},
9    impl_for_all_tuples,
10    lower::{Data, LowerCtx},
11    param::{Param, encode_param},
12    query::{
13        LockedQuery, LockedQueryOf, LowerProject, Select, Table, TypedCompiled, WithSelect,
14        rewrite_params,
15    },
16    relation::{ModelField, PersistedField},
17    span::TextSource,
18};
19
20/// Provides values for an insert builder.
21pub trait Insertable<M> {
22    /// Visits the field/value pairs that should be inserted.
23    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>);
24
25    /// Starts an insert using this value source.
26    fn insert_into(self, table: Table<M>) -> Insert<M>
27    where
28        M: Qrafting,
29        Self: Sized,
30    {
31        crate::insert_into(table).values(self)
32    }
33}
34
35/// The first stage of an insert builder, before values are attached.
36pub struct InsertInto<M> {
37    table: Table<M>,
38}
39
40/// The insert stage used for `insert into .. (columns) select ..`.
41pub struct InsertColumnsInto<M> {
42    table: Table<M>,
43    columns: Vec<&'static str>,
44}
45
46/// A typed insert statement with its value source.
47pub struct Insert<M> {
48    /// Insert target table.
49    pub table: Table<M>,
50    /// Explicit target columns used by value or select-based inserts.
51    pub columns: Vec<&'static str>,
52    /// Optional select body used instead of `values (...)`.
53    pub select: Option<Query>,
54    /// Returning behavior for the insert.
55    pub returning: InsertReturning,
56    /// Conflict handling mode for the insert.
57    pub conflict: InsertConflict,
58    /// Shared query state used for value inserts with returning projections.
59    pub query: Query,
60}
61
62/// Insert returning behavior.
63pub enum InsertReturning {
64    /// Omit the `returning` clause entirely.
65    None,
66    /// Emit `returning *`.
67    All,
68    /// Emit an explicit returning projection.
69    Projection(Vec<RpnInstr>),
70}
71
72/// An insert statement with a typed `returning` projection.
73pub struct ReturningInsert<M, T> {
74    /// Underlying insert statement.
75    pub insert: Insert<M>,
76    /// Returned row type marker.
77    pub marker: PhantomData<T>,
78}
79
/// Conflict handling applied to an insert.
///
/// Depending on the dialect this changes the statement prefix, the clause
/// written after the insert body, or both.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InsertConflict {
    /// Plain `insert into` without any conflict clause.
    None,
    /// Skip conflicting rows: `insert or ignore into` on SQLite,
    /// `insert ignore into` on MariaDB, `on conflict do nothing` on Postgres.
    Ignore,
    /// Update the existing row when a conflict occurs.
    Upsert {
        /// Conflict target columns; written as `on conflict (..)` on Postgres
        /// and SQLite, not written on MariaDB.
        unique_by: Vec<&'static str>,
        /// Columns overwritten with the incoming values on conflict.
        update: Vec<&'static str>,
    },
}
90
/// Collects distinct field names from insert values.
///
/// Names are stored in the order they are first recorded.
#[derive(Debug, Clone, Default)]
pub struct FieldCollector {
    fields: Vec<&'static str>,
}

impl FieldCollector {
    /// Creates a collector with no recorded fields.
    pub fn new() -> Self {
        Self::default()
    }

    /// Consumes the collector and returns the field names in first-seen order.
    pub fn into_inner(self) -> Vec<&'static str> {
        self.fields
    }
}
112
113impl<'a> VisitParam<'a> for FieldCollector {
114    fn param(&mut self, field: &'static str, _param: impl quex::Encode) {
115        if !self.fields.contains(&field) {
116            self.fields.push(field);
117        }
118    }
119
120    fn param_typed<T>(&mut self, field: &'static str, _value: &'a (impl Compatible<T> + ?Sized))
121    where
122        T: TypeMeta,
123    {
124        if !self.fields.contains(&field) {
125            self.fields.push(field);
126        }
127    }
128}
129
130impl<M> Insert<M>
131where
132    M: Qrafting,
133{
134    /// Emits SQL for the requested dialect.
135    pub fn to_sql<D: HasDialect>(&self) -> String {
136        match self.select.clone() {
137            Some(mut query) => self.render_select_sql::<D>(&mut query),
138            None => {
139                let mut writer = String::new();
140                let mut ctx = FormatContext {
141                    writer: &mut writer,
142                    index: 0,
143                    dialect: D::DIALECT,
144                };
145                self.format_writer(&mut ctx)
146                    .expect("cannot fail on a string writer");
147                writer
148            }
149        }
150    }
151
152    /// Switches conflict handling to the dialect's "ignore duplicates" form.
153    pub fn ignore(mut self) -> Self {
154        self.conflict = InsertConflict::Ignore;
155        self
156    }
157
158    /// Adds a `returning` projection and switches into the returning builder.
159    pub fn returning<P>(mut self, project: P) -> ReturningInsert<M, M>
160    where
161        P: LowerProject,
162        M: FromRow,
163    {
164        let mut instrs = Vec::new();
165        match self.select.as_mut() {
166            Some(query) => {
167                let mut ctx = LowerCtx {
168                    instrs: &mut instrs,
169                    params: &mut query.params,
170                    data: &mut query.data,
171                };
172                project.lower_project(&mut ctx);
173            }
174            None => {
175                let mut ctx = LowerCtx {
176                    instrs: &mut instrs,
177                    params: &mut self.query.params,
178                    data: &mut self.query.data,
179                };
180                project.lower_project(&mut ctx);
181            }
182        }
183        self.returning = InsertReturning::Projection(instrs);
184        ReturningInsert {
185            insert: self,
186            marker: PhantomData,
187        }
188    }
189
190    /// Forces `returning *` on dialects that support it.
191    pub fn returning_all(mut self) -> ReturningInsert<M, M>
192    where
193        M: FromRow,
194    {
195        self.returning = InsertReturning::All;
196        ReturningInsert {
197            insert: self,
198            marker: PhantomData,
199        }
200    }
201
202    /// Disables automatic `returning *`.
203    pub fn no_returning(mut self) -> Self {
204        self.returning = InsertReturning::None;
205        self
206    }
207
208    /// Configures an upsert using the given conflict keys and updated columns.
209    pub fn upsert<U, C>(mut self, unique_by: C, update: U) -> Self
210    where
211        C: InsertColumns,
212        U: InsertColumns,
213    {
214        self.conflict = InsertConflict::Upsert {
215            unique_by: unique_by.into_columns(),
216            update: update.into_columns(),
217        };
218        self
219    }
220
221    #[doc(hidden)]
222    pub fn into_select_compiled<D: HasDialect>(self) -> Option<TypedCompiled<M>> {
223        let Insert {
224            table,
225            columns,
226            select,
227            returning,
228            conflict,
229            query: insert_query,
230        } = self;
231        let mut query = select?;
232        let insert = Insert {
233            table,
234            columns,
235            select: None,
236            returning,
237            conflict,
238            query: insert_query,
239        };
240        let sql = insert.render_select_sql::<D>(&mut query);
241        Some(TypedCompiled {
242            sql,
243            params: query.params,
244            data: query.data,
245            marker: PhantomData,
246        })
247    }
248
249    fn render_select_sql<D: HasDialect>(&self, query: &mut Query) -> String {
250        let columns = self.columns.as_slice();
251        assert!(
252            !columns.is_empty(),
253            "insert-select requires at least one target column"
254        );
255
256        let mut writer = String::new();
257        let mut directives = Vec::new();
258        let mut indexes = Vec::new();
259
260        {
261            let mut emitter = Emitter::new(
262                &mut writer,
263                &query.data,
264                D::DIALECT,
265                &mut directives,
266                &mut indexes,
267            );
268            emitter.emit_ctes_for_query(query).unwrap();
269        }
270
271        let mut ctx = FormatContext {
272            writer: &mut writer,
273            index: 0,
274            dialect: D::DIALECT,
275        };
276
277        self.write_insert_prefix(&mut ctx).unwrap();
278        self.write_columns(&mut ctx, columns).unwrap();
279        ctx.writer.write_char(' ').unwrap();
280
281        {
282            let writer = &mut *ctx.writer;
283            let mut emitter = Emitter::new(
284                writer,
285                &query.data,
286                D::DIALECT,
287                &mut directives,
288                &mut indexes,
289            );
290            emitter.emit_query_body(query).unwrap();
291        }
292
293        self.write_conflict(&mut ctx).unwrap();
294        self.write_returning(&mut ctx, query, &mut directives, &mut indexes)
295            .unwrap();
296
297        finalize_query_params(query, directives, &indexes);
298
299        writer
300    }
301
302    fn write_insert_prefix<'w, W: Write>(
303        &self,
304        context: &mut FormatContext<'w, W>,
305    ) -> std::fmt::Result {
306        match (&self.conflict, context.dialect) {
307            (InsertConflict::Ignore, Dialect::Sqlite) => {
308                context.writer.write_str("insert or ignore into ")
309            }
310            (InsertConflict::Ignore, Dialect::MariaDb) => {
311                context.writer.write_str("insert ignore into ")
312            }
313            _ => context.writer.write_str("insert into "),
314        }?;
315        self.table.format_writer(context)
316    }
317
318    fn write_columns<'w, W: Write>(
319        &self,
320        context: &mut FormatContext<'w, W>,
321        fields: &[&'static str],
322    ) -> std::fmt::Result {
323        context.writer.write_str(" (")?;
324        for (i, field) in fields.iter().enumerate() {
325            if i > 0 {
326                context.writer.write_str(", ")?;
327            }
328            context.write_ident(field)?;
329        }
330        context.writer.write_char(')')
331    }
332
333    fn write_conflict<'w, W: Write>(&self, context: &mut FormatContext<'w, W>) -> std::fmt::Result {
334        match (&self.conflict, context.dialect) {
335            (InsertConflict::None, _)
336            | (InsertConflict::Ignore, Dialect::Sqlite | Dialect::MariaDb) => {}
337            (InsertConflict::Ignore, Dialect::Postgres) => {
338                context.writer.write_str(" on conflict do nothing")?;
339            }
340            (InsertConflict::Upsert { unique_by, update }, Dialect::Postgres | Dialect::Sqlite) => {
341                assert!(
342                    !unique_by.is_empty(),
343                    "upsert requires at least one conflict column"
344                );
345                assert!(
346                    !update.is_empty(),
347                    "upsert requires at least one update column"
348                );
349                context.writer.write_str(" on conflict (")?;
350                for (i, field) in unique_by.iter().enumerate() {
351                    if i > 0 {
352                        context.writer.write_str(", ")?;
353                    }
354                    context.write_ident(field)?;
355                }
356                context.writer.write_str(") do update set ")?;
357                for (i, field) in update.iter().enumerate() {
358                    if i > 0 {
359                        context.writer.write_str(", ")?;
360                    }
361                    context.write_ident(field)?;
362                    context.writer.write_str(" = ")?;
363                    context.write_ident("excluded")?;
364                    context.writer.write_char('.')?;
365                    context.write_ident(field)?;
366                }
367            }
368            (InsertConflict::Upsert { update, .. }, Dialect::MariaDb) => {
369                assert!(
370                    !update.is_empty(),
371                    "upsert requires at least one update column"
372                );
373                context.writer.write_str(" on duplicate key update ")?;
374                for (i, field) in update.iter().enumerate() {
375                    if i > 0 {
376                        context.writer.write_str(", ")?;
377                    }
378                    context.write_ident(field)?;
379                    context.writer.write_str(" = values(")?;
380                    context.write_ident(field)?;
381                    context.writer.write_char(')')?;
382                }
383            }
384        }
385
386        Ok(())
387    }
388
389    fn write_returning<'w, W: Write>(
390        &self,
391        context: &mut FormatContext<'w, W>,
392        query: &Query,
393        directives: &mut Vec<Directive>,
394        indexes: &mut Vec<usize>,
395    ) -> std::fmt::Result {
396        match &self.returning {
397            InsertReturning::None => Ok(()),
398            InsertReturning::All => context.writer.write_str(" returning *"),
399            InsertReturning::Projection(instrs) => {
400                context.writer.write_str(" returning ")?;
401                let mut emitter = Emitter::new(
402                    context.writer,
403                    &query.data,
404                    context.dialect,
405                    directives,
406                    indexes,
407                );
408                emitter.emit_instrs(instrs)
409            }
410        }
411    }
412}
413
414impl<M> InsertInto<M> {
415    /// Attaches the values that will be inserted.
416    pub fn values<V>(self, values: V) -> Insert<M>
417    where
418        M: Qrafting,
419        V: Insertable<M>,
420    {
421        let mut query = Query::default();
422        let mut field_collector = FieldCollector::new();
423        values.values(&mut field_collector);
424        let columns = field_collector.into_inner();
425
426        let mut param_collector = QueryParamCollector {
427            params: &mut query.params,
428            data: &mut query.data,
429        };
430        values.values(&mut param_collector);
431
432        Insert {
433            table: self.table,
434            columns,
435            select: None,
436            returning: InsertReturning::All,
437            conflict: InsertConflict::None,
438            query,
439        }
440    }
441
442    /// Declares the target columns used by a later query body.
443    pub fn columns<C>(self, columns: C) -> InsertColumnsInto<M>
444    where
445        C: InsertColumns,
446    {
447        let columns = columns.into_columns();
448        assert!(
449            !columns.is_empty(),
450            "insert-select requires at least one target column"
451        );
452        InsertColumnsInto {
453            table: self.table,
454            columns,
455        }
456    }
457}
458
459impl<M> InsertColumnsInto<M> {
460    /// Attaches the query body used to populate the insert.
461    pub fn query<Q>(self, query: Q) -> Insert<M>
462    where
463        Q: IntoInsertSelectQuery,
464    {
465        Insert {
466            table: self.table,
467            columns: self.columns,
468            select: Some(query.into_insert_select_query()),
469            returning: InsertReturning::All,
470            conflict: InsertConflict::None,
471            query: Query::default(),
472        }
473    }
474}
475
/// Normalizes the accepted column-list inputs (used for upserts and for
/// insert-select target columns) into a list of column names.
pub trait InsertColumns {
    /// Consumes the input and returns its column names as `'static` strings.
    fn into_columns(self) -> Vec<&'static str>;
}
481
/// Turns one column identifier, typed or a plain string, into its SQL field name.
pub trait IntoInsertColumn {
    /// Consumes the identifier and returns the column name.
    fn into_insert_column(self) -> &'static str;
}
486
487/// Converts the accepted `INSERT .. SELECT` inputs into an untyped query.
488pub trait IntoInsertSelectQuery {
489    fn into_insert_select_query(self) -> Query;
490}
491
492impl IntoInsertSelectQuery for Query {
493    fn into_insert_select_query(self) -> Query {
494        self
495    }
496}
497
498impl<M> IntoInsertSelectQuery for QueryOf<M> {
499    fn into_insert_select_query(self) -> Query {
500        self.into()
501    }
502}
503
504impl IntoInsertSelectQuery for LockedQuery {
505    fn into_insert_select_query(self) -> Query {
506        self.into()
507    }
508}
509
510impl<M> IntoInsertSelectQuery for LockedQueryOf<M> {
511    fn into_insert_select_query(self) -> Query {
512        self.into()
513    }
514}
515
516impl<P> IntoInsertSelectQuery for Select<P>
517where
518    P: LowerProject,
519{
520    fn into_insert_select_query(self) -> Query {
521        self.into_query()
522    }
523}
524
525impl<P> IntoInsertSelectQuery for WithSelect<P>
526where
527    P: LowerProject,
528{
529    fn into_insert_select_query(self) -> Query {
530        self.into_query()
531    }
532}
533
534impl IntoInsertColumn for &'static str {
535    fn into_insert_column(self) -> &'static str {
536        self
537    }
538}
539
540impl<M, T> IntoInsertColumn for Column<M, T>
541where
542    T: TypeMeta,
543{
544    fn into_insert_column(self) -> &'static str {
545        self.name
546    }
547}
548
549impl<M, V, T> IntoInsertColumn for ModelField<M, V, T>
550where
551    T: TypeMeta,
552{
553    fn into_insert_column(self) -> &'static str {
554        self.name()
555    }
556}
557
558impl<M, V, T, K> IntoInsertColumn for PersistedField<M, V, T, K>
559where
560    T: TypeMeta,
561{
562    fn into_insert_column(self) -> &'static str {
563        self.name()
564    }
565}
566
567impl InsertColumns for &'static str {
568    fn into_columns(self) -> Vec<&'static str> {
569        vec![self]
570    }
571}
572
573impl<M, T> InsertColumns for Column<M, T>
574where
575    T: TypeMeta,
576{
577    fn into_columns(self) -> Vec<&'static str> {
578        vec![self.name]
579    }
580}
581
582impl<const N: usize> InsertColumns for [&'static str; N] {
583    fn into_columns(self) -> Vec<&'static str> {
584        self.into_iter().collect()
585    }
586}
587
588impl InsertColumns for &[&'static str] {
589    fn into_columns(self) -> Vec<&'static str> {
590        self.to_vec()
591    }
592}
593
594impl InsertColumns for Vec<&'static str> {
595    fn into_columns(self) -> Vec<&'static str> {
596        self
597    }
598}
599
600macro_rules! impl_insert_columns_tuple {
601    ($($T:ident),+) => {
602        impl<$($T,)+> InsertColumns for ($($T,)+)
603        where
604            $($T: IntoInsertColumn,)+
605        {
606            fn into_columns(self) -> Vec<&'static str> {
607                #[allow(non_snake_case)]
608                let ($($T,)+) = self;
609                vec![$($T.into_insert_column(),)+]
610            }
611        }
612    };
613}
614
615impl_for_all_tuples!(impl_insert_columns_tuple);
616
617/// Visits inserted field/value pairs.
618pub trait VisitParam<'v> {
619    /// Records a field and its bound value.
620    fn param(&mut self, field: &'static str, _param: impl quex::Encode);
621
622    /// Records a field using typed compatibility for the SQL column type.
623    fn param_typed<T>(&mut self, field: &'static str, value: &'v (impl Compatible<T> + ?Sized))
624    where
625        T: TypeMeta,
626    {
627        self.param(field, value)
628    }
629}
630
631// all the tuples for this
632impl<M, T, R> Insertable<M> for Binary<op::Eq, Column<M, T>, R>
633where
634    T: TypeMeta,
635    M: Qrafting,
636    R: Compatible<T>,
637{
638    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
639        visitor.param_typed::<T>(self.left.name, &self.right);
640    }
641}
642
643impl<M, V, T, R> Insertable<M> for Binary<op::Eq, ModelField<M, V, T>, R>
644where
645    T: TypeMeta,
646    M: Qrafting,
647    R: Compatible<T>,
648{
649    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
650        visitor.param_typed::<T>(self.left.name(), &self.right);
651    }
652}
653
654impl<M, V, T, K, R> Insertable<M> for Binary<op::Eq, PersistedField<M, V, T, K>, R>
655where
656    T: TypeMeta,
657    M: Qrafting,
658    R: Compatible<T>,
659{
660    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
661        visitor.param_typed::<T>(self.left.name(), &self.right);
662    }
663}
664
665impl<M> Insertable<M> for () {
666    fn values<'v>(&'v self, _visitor: &mut impl VisitParam<'v>) {}
667}
668
669impl<M, T> Insertable<M> for &[T]
670where
671    T: Insertable<M>,
672{
673    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
674        for item in self.iter() {
675            item.values(visitor);
676        }
677    }
678}
679
680impl<M, T> Insertable<M> for Vec<T>
681where
682    T: Insertable<M>,
683{
684    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
685        for item in self.iter() {
686            item.values(visitor);
687        }
688    }
689}
690
691impl<M, T, const N: usize> Insertable<M> for [T; N]
692where
693    T: Insertable<M>,
694{
695    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
696        for item in self.iter() {
697            item.values(visitor);
698        }
699    }
700}
701
702impl<M, T> Insertable<M> for &T
703where
704    T: Insertable<M>,
705{
706    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
707        (**self).values(visitor)
708    }
709}
710
711impl<M, T> Insertable<M> for [T]
712where
713    T: Insertable<M>,
714{
715    fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
716        for item in self.iter() {
717            item.values(visitor);
718        }
719    }
720}
721
722/// Writes placeholder lists for insert rows.
723pub struct FormatValue<'w, W: Write> {
724    writer: &'w mut W,
725    dialect: Dialect,
726    count: usize,
727    field_count: usize,
728    result: std::fmt::Result,
729}
730
731struct QueryParamCollector<'q> {
732    params: &'q mut Vec<Param>,
733    data: &'q mut Vec<u8>,
734}
735
736impl<'q, 'v> VisitParam<'v> for QueryParamCollector<'q> {
737    fn param(&mut self, _field: &'static str, param: impl quex::Encode) {
738        self.params.push(encode_param(&param, self.data));
739    }
740
741    fn param_typed<T>(&mut self, _field: &'static str, value: &'v (impl Compatible<T> + ?Sized))
742    where
743        T: TypeMeta,
744    {
745        self.params.push(encode_param(value, self.data));
746    }
747}
748
749impl<'w, 'v, W: Write> VisitParam<'v> for FormatValue<'w, W> {
750    fn param(&mut self, _field: &'static str, _param: impl quex::Encode) {
751        if self.result.is_err() {
752            return;
753        }
754        if self.count == 0 {
755            // nop
756        } else if self.count.is_multiple_of(self.field_count) {
757            self.result = self.writer.write_str("), (");
758        } else {
759            self.result = self.writer.write_str(", ");
760        }
761        if self.result.is_err() {
762            return;
763        }
764        self.count += 1;
765        self.result = match self.dialect {
766            Dialect::Postgres => write!(self.writer, "${}", self.count),
767            Dialect::MariaDb | Dialect::Sqlite => self.writer.write_char('?'),
768        };
769    }
770
771    fn param_typed<T>(&mut self, _field: &'static str, _value: &'v (impl Compatible<T> + ?Sized))
772    where
773        T: TypeMeta,
774    {
775        self.param("", "")
776    }
777}
778
779/// Shared state used while formatting an insert statement.
780pub struct FormatContext<'w, W: Write> {
781    pub writer: &'w mut W,
782    pub index: usize,
783    pub dialect: Dialect,
784}
785
786impl<'w, W: Write> FormatContext<'w, W> {
787    pub(crate) fn write_ident(&mut self, part: &str) -> std::fmt::Result {
788        if part == "*" {
789            return self.writer.write_char('*');
790        }
791
792        let quote = match self.dialect {
793            Dialect::Postgres | Dialect::Sqlite => '"',
794            Dialect::MariaDb => '`',
795        };
796        self.writer.write_char(quote)?;
797        // duplicate the quote if present
798        let dbl = if quote == '"' { "\"\"" } else { "``" };
799
800        let mut last = 0;
801        for (index, char) in part.char_indices() {
802            if char == quote {
803                if index != last {
804                    self.writer.write_str(&part[last..index])?;
805                }
806                self.writer.write_str(dbl)?;
807                last = index + char.len_utf8();
808            }
809        }
810
811        // write trailing slice
812        if last < part.len() {
813            self.writer.write_str(&part[last..])?;
814        }
815
816        self.writer.write_char(quote)?;
817        Ok(())
818    }
819
820    pub(crate) fn write_table(&mut self, ident: &str) -> std::fmt::Result {
821        for (i, part) in ident.split('.').enumerate() {
822            if i > 0 {
823                self.writer.write_char('.')?;
824            }
825            self.write_ident(part)?;
826        }
827        Ok(())
828    }
829}
830
831fn finalize_query_params(query: &mut Query, directives: Vec<Directive>, indexes: &[usize]) {
832    for directive in directives {
833        match directive {
834            Directive::RewriteGlob { id } => {
835                let maybe_param = query.params.get_mut(id);
836                if let Some(crate::param::Param::Text(Some(text_span))) = maybe_param {
837                    let text = query.data.text(*text_span);
838                    let value = prepare_sqlite_glob(text);
839                    if let Cow::Owned(value) = value {
840                        if let TextSource::Text(span) = text_span.0
841                            && value.len() == text.len()
842                        {
843                            let bytes = &mut query.data
844                                [span.start as usize..span.start as usize + span.len as usize];
845                            bytes.copy_from_slice(value.as_bytes());
846                        } else {
847                            let span = query.data.intern_text(&value);
848                            *text_span = span;
849                        }
850                    }
851                }
852            }
853        }
854    }
855
856    rewrite_params(indexes, &mut query.params);
857}
858
859/// Emits SQL into a `FormatContext`.
860pub trait FormatWriter {
861    /// Writes `self` into the supplied formatting context.
862    fn format_writer<'w, W: Write>(&self, context: &mut FormatContext<'w, W>) -> std::fmt::Result;
863}
864
865impl<M> FormatWriter for Insert<M>
866where
867    M: Qrafting,
868{
869    fn format_writer<'w, W: Write>(&self, context: &mut FormatContext<'w, W>) -> std::fmt::Result {
870        self.write_insert_prefix(context)?;
871        let fields = self.columns.as_slice();
872        if fields.is_empty() {
873            context.writer.write_str(" default values")?;
874        } else {
875            self.write_columns(context, fields)?;
876            context.writer.write_str(" values (")?;
877            let mut value_formatter = FormatValue {
878                writer: context.writer,
879                count: 0,
880                field_count: fields.len(),
881                dialect: context.dialect,
882                result: Ok(()),
883            };
884            for _ in &self.query.params {
885                value_formatter.param("", "");
886            }
887            value_formatter.result?;
888            context.writer.write_char(')')?;
889        }
890        self.write_conflict(context)?;
891        let mut directives = Vec::new();
892        let mut indexes = Vec::new();
893        self.write_returning(context, &self.query, &mut directives, &mut indexes)?;
894        Ok(())
895    }
896}
897
898impl<M, T> ReturningInsert<M, T>
899where
900    M: Qrafting,
901    T: FromRow,
902{
903    /// Changes the returned row type without changing the SQL projection.
904    pub fn typed<R>(self) -> ReturningInsert<M, R>
905    where
906        R: FromRow,
907    {
908        ReturningInsert {
909            insert: self.insert,
910            marker: PhantomData,
911        }
912    }
913
914    /// Emits SQL for the requested dialect.
915    pub fn to_sql<D: HasDialect>(&self) -> String {
916        self.insert.to_sql::<D>()
917    }
918
919    /// Emits SQL with appended debug parameter output.
920    pub fn to_debug_sql<D: HasDialect>(&self) -> String {
921        match self.insert.select.clone() {
922            Some(mut query) => {
923                let mut sql = self.insert.render_select_sql::<D>(&mut query);
924                query.debug_params(&mut sql).unwrap();
925                sql
926            }
927            None => {
928                let mut sql = self.insert.to_sql::<D>();
929                self.insert.query.debug_params(&mut sql).unwrap();
930                sql
931            }
932        }
933    }
934
935    /// Compiles the returning insert into the typed representation used by executors.
936    pub fn into_compiled<D: HasDialect>(self) -> TypedCompiled<T> {
937        let insert = self.insert;
938        match insert.select {
939            Some(mut query) => {
940                let insert = Insert {
941                    table: insert.table,
942                    columns: insert.columns,
943                    select: None,
944                    returning: insert.returning,
945                    conflict: insert.conflict,
946                    query: insert.query,
947                };
948                let sql = insert.render_select_sql::<D>(&mut query);
949                TypedCompiled {
950                    sql,
951                    params: query.params,
952                    data: query.data,
953                    marker: PhantomData,
954                }
955            }
956            None => TypedCompiled {
957                sql: insert.to_sql::<D>(),
958                params: insert.query.params,
959                data: insert.query.data,
960                marker: PhantomData,
961            },
962        }
963    }
964}
965
966/// Starts an insert statement for the given table.
967///
968/// # Examples
969///
970/// ```rust
971/// use qraft_core::{
972///     BigInt, DefaultQueryPolicy, Qrafting, Sqlite, Text, expression::Column, insert_into,
973///     query::Table,
974/// };
975///
976/// struct User;
977///
978/// impl Qrafting for User {
979///     type Schema = ();
980///     type QueryPolicy = DefaultQueryPolicy<Self>;
981///     const FIELD_COUNT: usize = 2;
982///     const TABLE: &'static str = "users";
983/// }
984///
985/// let id = Column::<User, BigInt>::new("id");
986/// let name = Column::<User, Text>::new("name");
987/// let users = Table::<User>::new("users");
988/// let sql = insert_into(users)
989///     .values((id.eq(1_i64), name.eq("lea")))
990///     .to_sql::<Sqlite>();
991///
992/// assert_eq!(
993///     sql,
994///     r#"insert into "users" ("id", "name") values (?, ?) returning *"#
995/// );
996/// ```
997pub fn insert_into<M>(table: Table<M>) -> InsertInto<M> {
998    InsertInto { table }
999}
1000
1001macro_rules! impl_insertable_macro {
1002    ($($T:ident),+) => {
1003        impl<M, $($T,)+> Insertable<M> for ($($T,)+)
1004        where
1005            $($T: Insertable<M>,)+
1006        {
1007            fn values<'v>(&'v self, visitor: &mut impl VisitParam<'v>) {
1008                #[allow(non_snake_case)]
1009                let ($($T,)+) = self;
1010                $(
1011                    $T.values(visitor);
1012                )+
1013            }
1014        }
1015    };
1016}
1017
1018impl_for_all_tuples!(impl_insertable_macro);
1019
#[cfg(test)]
mod tests {
    use super::insert_into;
    use crate::{
        MySql, Postgres, Sqlite,
        query::select,
        tests::{id, name, table, username},
    };

    /// An empty value source renders a `default values` insert.
    #[test]
    fn test_insert_default_values() {
        let sql = insert_into(table).values(()).to_sql::<Sqlite>();
        let expected = r#"insert into "users" default values returning *"#;

        assert_eq!(sql, expected);
    }

    /// A tuple of assignments renders one row with one placeholder per column.
    #[test]
    fn test_insert_single_row_tuple_values() {
        let row = (id.eq(1), name.eq("alice"), username.eq("alice1"));
        let sql = insert_into(table).values(row).to_sql::<Sqlite>();
        let expected =
            r#"insert into "users" ("id", "name", "username") values (?, ?, ?) returning *"#;

        assert_eq!(sql, expected);
    }

    /// An array of row tuples renders one `values` group per row.
    #[test]
    fn test_insert_multiple_rows_from_array() {
        let rows = [
            (id.eq(1), username.eq("alpha")),
            (id.eq(2), username.eq("beta")),
        ];
        let sql = insert_into(table).values(rows).to_sql::<Sqlite>();
        let expected =
            r#"insert into "users" ("id", "username") values (?, ?), (?, ?) returning *"#;

        assert_eq!(sql, expected);
    }

    /// Postgres output uses numbered `$n` placeholders.
    #[test]
    fn test_insert_postgres_uses_dollar_placeholders() {
        let row = (id.eq(10), username.eq("hello"));
        let sql = insert_into(table).values(row).to_sql::<Postgres>();
        let expected = r#"insert into "users" ("id", "username") values ($1, $2) returning *"#;

        assert_eq!(sql, expected);
    }

    /// `ignore()` renders as `insert or ignore` for SQLite.
    #[test]
    fn test_insert_ignore_sqlite() {
        let row = (id.eq(10), username.eq("hello"));
        let sql = insert_into(table).values(row).ignore().to_sql::<Sqlite>();
        let expected =
            r#"insert or ignore into "users" ("id", "username") values (?, ?) returning *"#;

        assert_eq!(sql, expected);
    }

    /// `ignore()` renders as `on conflict do nothing` for Postgres.
    #[test]
    fn test_insert_ignore_postgres() {
        let row = (id.eq(10), username.eq("hello"));
        let sql = insert_into(table)
            .values(row)
            .ignore()
            .to_sql::<Postgres>();
        let expected = r#"insert into "users" ("id", "username") values ($1, $2) on conflict do nothing returning *"#;

        assert_eq!(sql, expected);
    }

    /// `ignore()` renders as `insert ignore` for the `MySql` dialect.
    #[test]
    fn test_insert_ignore_mariadb() {
        let row = (id.eq(10), username.eq("hello"));
        let sql = insert_into(table).values(row).ignore().to_sql::<MySql>();
        let expected = "insert ignore into `users` (`id`, `username`) values (?, ?) returning *";

        assert_eq!(sql, expected);
    }

    /// Postgres upsert renders `on conflict (..) do update set ..` using `excluded`.
    #[test]
    fn test_insert_upsert_postgres() {
        let row = (id.eq(10), username.eq("hello"), name.eq("lea"));
        let sql = insert_into(table)
            .values(row)
            .upsert(["id"], ["username", "name"])
            .to_sql::<Postgres>();
        let expected = r#"insert into "users" ("id", "username", "name") values ($1, $2, $3) on conflict ("id") do update set "username" = "excluded"."username", "name" = "excluded"."name" returning *"#;

        assert_eq!(sql, expected);
    }

    /// SQLite upsert uses the same `on conflict` shape with `?` placeholders.
    #[test]
    fn test_insert_upsert_sqlite() {
        let row = (id.eq(10), username.eq("hello"));
        let sql = insert_into(table)
            .values(row)
            .upsert(["id"], ["username"])
            .to_sql::<Sqlite>();
        let expected = r#"insert into "users" ("id", "username") values (?, ?) on conflict ("id") do update set "username" = "excluded"."username" returning *"#;

        assert_eq!(sql, expected);
    }

    /// The `MySql` dialect renders upsert as `on duplicate key update`.
    #[test]
    fn test_insert_upsert_mariadb() {
        let row = (id.eq(10), username.eq("hello"));
        let sql = insert_into(table)
            .values(row)
            .upsert(["id"], ["username"])
            .to_sql::<MySql>();
        let expected = "insert into `users` (`id`, `username`) values (?, ?) on duplicate key update `username` = values(`username`) returning *";

        assert_eq!(sql, expected);
    }

    /// `upsert` accepts a tuple for the unique columns and a `Vec` for the updates.
    #[test]
    fn test_insert_upsert_accepts_tuple_and_vec_columns() {
        let conflict_columns = ("id", "username");
        let update_columns = vec!["name", "username"];
        let row = (id.eq(10), username.eq("hello"), name.eq("lea"));

        let sql = insert_into(table)
            .values(row)
            .upsert(conflict_columns, update_columns)
            .to_sql::<Postgres>();
        let expected = r#"insert into "users" ("id", "username", "name") values ($1, $2, $3) on conflict ("id", "username") do update set "name" = "excluded"."name", "username" = "excluded"."username" returning *"#;

        assert_eq!(sql, expected);
    }

    /// `returning_all()` renders `returning *`.
    #[test]
    fn test_insert_returning_all_is_explicit() {
        let row = (id.eq(10), username.eq("hello"));
        let sql = insert_into(table)
            .values(row)
            .returning_all()
            .to_sql::<Postgres>();
        let expected = r#"insert into "users" ("id", "username") values ($1, $2) returning *"#;

        assert_eq!(sql, expected);
    }

    /// `returning(..)` renders the projected columns qualified by table name.
    #[test]
    fn test_insert_returning_projection_matches_update_style() {
        let row = (id.eq(10), username.eq("hello"));
        let sql = insert_into(table)
            .values(row)
            .returning((id, username))
            .to_sql::<Postgres>();
        let expected = r#"insert into "users" ("id", "username") values ($1, $2) returning "users"."id", "users"."username""#;

        assert_eq!(sql, expected);
    }

    /// `columns(..).query(..)` renders an `insert .. select ..` statement.
    #[test]
    fn test_insert_select_sqlite() {
        let source = select((id, username))
            .from(table)
            .filter(username.eq("alice"));
        let sql = insert_into(table)
            .columns((id, username))
            .query(source)
            .to_sql::<Sqlite>();
        let expected = r#"insert into "users" ("id", "username") select "users"."id", "users"."username" from "users" where "users"."username" = ? returning *"#;

        assert_eq!(sql, expected);
    }

    /// A select-based insert combines a filtered source with an upsert clause.
    #[test]
    fn test_insert_select_postgres_with_filter_and_upsert() {
        let source = select((id, username)).from(table).filter(id.eq(10));
        let sql = insert_into(table)
            .columns((id, username))
            .query(source)
            .upsert(["id"], ["username"])
            .to_sql::<Postgres>();
        let expected = r#"insert into "users" ("id", "username") select "users"."id", "users"."username" from "users" where "users"."id" = $1 on conflict ("id") do update set "username" = "excluded"."username" returning *"#;

        assert_eq!(sql, expected);
    }

    /// `no_returning()` drops the `returning` clause from a select-based insert.
    #[test]
    fn test_insert_select_mariadb_ignore_no_returning() {
        let source = select((id, username)).from(table);
        let sql = insert_into(table)
            .columns((id, username))
            .query(source)
            .ignore()
            .no_returning()
            .to_sql::<MySql>();
        let expected = "insert ignore into `users` (`id`, `username`) select `users`.`id`, `users`.`username` from `users`";

        assert_eq!(sql, expected);
    }
}