Skip to main content

rust_rel8/
lib.rs

1#![feature(unboxed_closures)]
2
3use std::{
4    borrow::Cow,
5    marker::PhantomData,
6    sync::{Arc, atomic::AtomicU32},
7};
8
9use bytemuck::TransparentWrapper as _;
10use sea_query::ExprTrait;
11
/// A process-wide unique id, used to build collision-free table and column aliases.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq)]
struct Binder(u32);

/// The next id that [Binder::new] will hand out.
static BINDER_COUNT: AtomicU32 = AtomicU32::new(0);

impl Binder {
    /// Allocate a fresh binder by atomically bumping [BINDER_COUNT].
    fn new() -> Self {
        use std::sync::atomic::Ordering;

        let id = BINDER_COUNT.fetch_add(1, Ordering::SeqCst);
        Binder(id)
    }

    /// Restart numbering, so the next [Binder::new] returns `Binder(0)`.
    // Not called from the code visible here, hence the lint allowance.
    #[allow(unused)]
    fn reset() {
        use std::sync::atomic::Ordering;

        BINDER_COUNT.store(0, Ordering::SeqCst);
    }
}
26
27#[derive(Clone)]
28struct TableName {
29    binder: Binder,
30    name: String,
31}
32
33impl std::fmt::Debug for TableName {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        write!(f, "{}", self.name)
36    }
37}
38
39impl<'a> From<TableName> for Cow<'a, str> {
40    fn from(val: TableName) -> Self {
41        format!("t{}", val.binder.0).into()
42    }
43}
44
45impl TableName {
46    fn new(binder: Binder) -> Self {
47        Self {
48            binder,
49            name: format!("t{}", binder.0),
50        }
51    }
52}
53
54impl sea_query::Iden for TableName {
55    fn unquoted(&self) -> &str {
56        &self.name
57    }
58}
59
60#[derive(Clone)]
61struct ColumnName {
62    name: String,
63    rendered: String,
64}
65
66impl std::fmt::Display for ColumnName {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        write!(f, "{}", self.rendered)
69    }
70}
71
72impl std::fmt::Debug for ColumnName {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        write!(f, "{}", self.rendered)
75    }
76}
77
78impl sea_query::Iden for ColumnName {
79    fn unquoted(&self) -> &str {
80        &self.rendered
81    }
82}
83
84impl ColumnName {
85    fn new(binder: Binder, name: String) -> Self {
86        Self {
87            rendered: format!("{}{}", name, binder.0),
88            name,
89        }
90    }
91}
92
93/// A trait abstracting the mode of a user defined table, which allows us to
94/// talk about the same table in two different modes.
95///
96/// ```rs
97/// impl<'scope, T: Column<'scope>> TableHKT<'scope> for MyTable<'scope, T> {
98///     type Of<Mode: Column<'scope>> = MyTable<'scope, Mode>;
99///
100///     type Mode = T;
101/// }
102/// ```
103pub trait TableHKT {
104    /// The current mode of this table.
105    type Mode: TableMode;
106
107    /// Replace the mode with another.
108    type Of<Mode: TableMode>;
109}
110
#[cfg(feature = "sqlx")]
/// Proxy module allowing us to have a supertrait conditional on feature flags.
mod value_sqlx {
    /// Proxy trait for [sqlx::Decode]; active because the `sqlx` feature is enabled.
    pub trait Value: for<'r> sqlx::Decode<'r, sqlx::Any> {}
    impl<T> Value for T where T: for<'r> sqlx::Decode<'r, sqlx::Any> {}
}

#[cfg(not(feature = "sqlx"))]
/// Proxy module allowing us to have a supertrait conditional on feature flags.
mod value_sqlx {
    // `sqlx` is not a dependency in this configuration, so the docs below name
    // `sqlx::Decode` as plain code instead of an intra-doc link (which could not resolve).

    /// Stand-in for the `sqlx::Decode` proxy trait when the `sqlx` feature is
    /// disabled: it imposes no requirements and is implemented for every type.
    pub trait Value {}
    impl<T> Value for T where T: ?Sized {}
}

/// Whichever `value_sqlx::Value` proxy trait is active for the enabled features.
pub use value_sqlx::Value as SqlxValueIfEnabled;
128
129/// This trait represents values we know to encode and decode from their database type.
130///
131/// Depending on features, this will have supertraits of the encode/decode
132/// traits of the backends.
133pub trait Value: Into<sea_query::Value> + SqlxValueIfEnabled {}
134
135impl<T> Value for T where T: Into<sea_query::Value> + value_sqlx::Value {}
136
137/// This trait allows us to write a mapping function between two column modes.
138///
139/// We use a trait as we need this to work for all types of the type parameter `V`.
140pub trait ModeMapper<'scope, SrcMode: TableMode, DestMode: TableMode> {
141    fn map_mode<V>(&mut self, src: SrcMode::T<'scope, V>) -> DestMode::T<'scope, V>
142    where
143        V: Value;
144}
145
146/// This trait allows us to write a mapping function between two column modes.
147///
148/// We use a trait as we need this to work for all types of the type parameter `V`.
149pub trait ModeMapperRef<'scope, SrcMode: TableMode, DestMode: TableMode> {
150    fn map_mode_ref<V>(&mut self, src: &SrcMode::T<'scope, V>) -> DestMode::T<'scope, V>
151    where
152        V: Value;
153}
154
155/// This trait allows us to write a mapping function between two column modes.
156///
157/// We use a trait as we need this to work for all types of the type parameter `V`.
158pub trait ModeMapperMut<'scope, SrcMode: TableMode, DestMode: TableMode> {
159    fn map_mode_mut<V>(&mut self, src: &mut SrcMode::T<'scope, V>) -> DestMode::T<'scope, V>
160    where
161        V: Value;
162}
163
164/// This trait allows us to change the mode of a table by mapping all the fields
165/// with [ModeMapper], [ModeMapperRef], or [ModeMapperMut].
166pub trait MapTable<'scope>: TableHKT {
167    /// Map each field of the table
168    ///
169    /// The order and number of fields visited must always remain the same,
170    /// across: [Table::visit], [Table::visit_mut], and all methods of [MapTable].
171    fn map_modes<Mapper, DestMode>(self, mapper: &mut Mapper) -> Self::Of<DestMode>
172    where
173        Mapper: ModeMapper<'scope, Self::Mode, DestMode>,
174        DestMode: TableMode;
175
176    /// Map each field of the table, with a reference
177    ///
178    /// The order and number of fields visited must always remain the same,
179    /// across: [Table::visit], [Table::visit_mut], and all methods of [MapTable].
180    fn map_modes_ref<Mapper, DestMode>(&self, mapper: &mut Mapper) -> Self::Of<DestMode>
181    where
182        Mapper: ModeMapperRef<'scope, Self::Mode, DestMode>,
183        DestMode: TableMode;
184
185    /// Map each field of the table, with a mutable reference
186    ///
187    /// The order and number of fields visited must always remain the same,
188    /// across: [Table::visit], [Table::visit_mut], and all methods of [MapTable].
189    fn map_modes_mut<Mapper, DestMode>(&mut self, mapper: &mut Mapper) -> Self::Of<DestMode>
190    where
191        Mapper: ModeMapperMut<'scope, Self::Mode, DestMode>,
192        DestMode: TableMode;
193}
194
195/// A mapper which reads the column names of a table in [NameMode] and adds them to a select,
196/// then yields an expression referencing the column.
197struct NameToExprMapper {
198    binder: Binder,
199    query: sea_query::SelectStatement,
200}
201
202impl<'scope> ModeMapperRef<'scope, NameMode, ExprMode> for NameToExprMapper {
203    fn map_mode_ref<V>(
204        &mut self,
205        src: &<NameMode as TableMode>::T<'scope, V>,
206    ) -> <ExprMode as TableMode>::T<'scope, V> {
207        let col_name = ColumnName::new(self.binder, src.to_string());
208        self.query
209            .expr_as(sea_query::Expr::column(*src), col_name.clone());
210
211        Expr::new(ExprInner::Column(TableName::new(self.binder), col_name))
212    }
213}
214
215struct ExprCollectorMapper {
216    idx: usize,
217    table_binder: Binder,
218    columns: Vec<ColumnName>,
219    values: Vec<sea_query::Value>,
220}
221
222impl<'scope> ModeMapper<'scope, ValueMode, ExprMode> for ExprCollectorMapper {
223    fn map_mode<V>(
224        &mut self,
225        src: <ValueMode as TableMode>::T<'scope, V>,
226    ) -> <ExprMode as TableMode>::T<'scope, V>
227    where
228        V: Value,
229    {
230        let idx = self.idx;
231        self.idx += 1;
232
233        self.values.push(src.into());
234
235        let column_name = ColumnName::new(self.table_binder, format!("values_{idx}_"));
236
237        self.columns.push(column_name.clone());
238
239        Expr::new(ExprInner::Column(
240            TableName::new(self.table_binder),
241            column_name,
242        ))
243    }
244}
245
246struct ExprCollectorRemainingMapper {
247    values: Vec<sea_query::Value>,
248}
249
250impl<'scope> ModeMapper<'scope, ValueMode, EmptyMode> for ExprCollectorRemainingMapper {
251    fn map_mode<V>(
252        &mut self,
253        src: <ValueMode as TableMode>::T<'scope, V>,
254    ) -> <EmptyMode as TableMode>::T<'scope, V>
255    where
256        V: Value,
257    {
258        self.values.push(src.into());
259    }
260}
261
262struct NameCollectorMapper {
263    names: Vec<&'static str>,
264}
265
266impl<'scope> ModeMapperRef<'scope, NameMode, EmptyMode> for NameCollectorMapper {
267    fn map_mode_ref<V>(
268        &mut self,
269        src: &<NameMode as TableMode>::T<'scope, V>,
270    ) -> <EmptyMode as TableMode>::T<'scope, V> {
271        self.names.push(*src);
272    }
273}
274
275/// A mapper which visits each expression of a table in [ExprMode].
276struct VisitTableMapper<'a, F> {
277    f: &'a mut F,
278}
279
280impl<'a, 'scope, F> ModeMapperRef<'scope, ExprMode, EmptyMode> for VisitTableMapper<'a, F>
281where
282    F: FnMut(&ErasedExpr),
283{
284    fn map_mode_ref<V>(
285        &mut self,
286        src: &<ExprMode as TableMode>::T<'scope, V>,
287    ) -> <EmptyMode as TableMode>::T<'scope, V> {
288        (self.f)(src.as_erased());
289    }
290}
291
292impl<'a, 'scope, F> ModeMapperRef<'scope, ExprNullifiedMode, EmptyMode> for VisitTableMapper<'a, F>
293where
294    F: FnMut(&ErasedExpr),
295{
296    fn map_mode_ref<V>(
297        &mut self,
298        src: &<ExprNullifiedMode as TableMode>::T<'scope, V>,
299    ) -> <EmptyMode as TableMode>::T<'scope, V> {
300        (self.f)(src.as_erased());
301    }
302}
303
304/// A mapper which mutably visits each expression of a table in [ExprMode].
305struct VisitTableMapperMut<'a, F> {
306    f: &'a mut F,
307}
308
309impl<'a, 'scope, F> ModeMapperMut<'scope, ExprMode, EmptyMode> for VisitTableMapperMut<'a, F>
310where
311    F: FnMut(&mut ErasedExpr),
312{
313    fn map_mode_mut<V>(
314        &mut self,
315        src: &mut <ExprMode as TableMode>::T<'scope, V>,
316    ) -> <EmptyMode as TableMode>::T<'scope, V> {
317        (self.f)(src.as_erased_mut());
318    }
319}
320
321impl<'a, 'scope, F> ModeMapperMut<'scope, ExprNullifiedMode, EmptyMode>
322    for VisitTableMapperMut<'a, F>
323where
324    F: FnMut(&mut ErasedExpr),
325{
326    fn map_mode_mut<V>(
327        &mut self,
328        src: &mut <ExprNullifiedMode as TableMode>::T<'scope, V>,
329    ) -> <EmptyMode as TableMode>::T<'scope, V> {
330        (self.f)(src.as_erased_mut());
331    }
332}
333
334/// A mapper which nullifies a table, turning it from [ExprMode] to [ExprNullifiedMode].
335struct NullifyMapper {}
336
337impl<'scope> ModeMapper<'scope, ExprMode, ExprNullifiedMode> for NullifyMapper {
338    fn map_mode<V>(
339        &mut self,
340        src: <ExprMode as TableMode>::T<'scope, V>,
341    ) -> <ExprNullifiedMode as TableMode>::T<'scope, V> {
342        src.nullify()
343    }
344}
345
/// A mapper which decodes a table from a result row, pulling one value from
/// `it` per field: [ExprMode] becomes [ValueMode], and [ExprNullifiedMode]
/// becomes [ValueNullifiedMode].
///
/// Panics if the row runs out of values or a value fails to decode.
#[cfg(feature = "sqlx")]
struct LoadingMapper<'a, IT> {
    /// The remaining values of the row being decoded.
    it: &'a mut IT,
}

#[cfg(feature = "sqlx")]
impl<'a, 'b, 'scope, IT> ModeMapperRef<'scope, ExprMode, ValueMode> for LoadingMapper<'a, IT>
where
    IT: Iterator<Item = sqlx::any::AnyValueRef<'b>>,
{
    fn map_mode_ref<V>(
        &mut self,
        _src: &<ExprMode as TableMode>::T<'scope, V>,
    ) -> <ValueMode as TableMode>::T<'scope, V>
    where
        V: Value,
    {
        let raw = self.it.next().unwrap();
        <V as sqlx::Decode<sqlx::Any>>::decode(raw).unwrap()
    }
}

#[cfg(feature = "sqlx")]
impl<'a, 'b, 'scope, IT> ModeMapperRef<'scope, ExprNullifiedMode, ValueNullifiedMode>
    for LoadingMapper<'a, IT>
where
    IT: Iterator<Item = sqlx::any::AnyValueRef<'b>>,
{
    fn map_mode_ref<V>(
        &mut self,
        _src: &<ExprNullifiedMode as TableMode>::T<'scope, V>,
    ) -> <ValueNullifiedMode as TableMode>::T<'scope, V>
    where
        V: Value,
    {
        use sqlx::ValueRef as _;

        let raw = self.it.next().unwrap();
        // A NULL column decodes to `None` without consulting `V`'s decoder.
        if raw.is_null() {
            return None;
        }
        Some(<V as sqlx::Decode<sqlx::Any>>::decode(raw).unwrap())
    }
}
391
/// A mapper which discards one value of a result row per field of an
/// [ExprMode] table, without decoding anything.
///
/// Panics if the row runs out of values.
#[cfg(feature = "sqlx")]
struct SkippingMapper<'a, IT> {
    /// The remaining values of the row being skipped over.
    it: &'a mut IT,
}

#[cfg(feature = "sqlx")]
impl<'a, 'b, 'scope, IT> ModeMapperRef<'scope, ExprMode, EmptyMode> for SkippingMapper<'a, IT>
where
    IT: Iterator<Item = sqlx::any::AnyValueRef<'b>>,
{
    fn map_mode_ref<V>(
        &mut self,
        _src: &<ExprMode as TableMode>::T<'scope, V>,
    ) -> <EmptyMode as TableMode>::T<'scope, V> {
        let _ = self.it.next().unwrap();
    }
}
410
/// Describes a database table: its name plus its column names.
pub struct TableSchema<Table> {
    /// The table's name in the database.
    pub name: &'static str,

    /// The table's columns; expected to be a table in [NameMode], i.e. each
    /// field holds a column name.
    pub columns: Table,
}
419
/// The modes a table can be in.
pub mod table_modes {
    /// Name mode, where every column is a `&'static str` holding the column's name.
    pub enum NameMode {}

    /// Value mode, representing a table row that has been loaded from the query.
    ///
    /// This enum implements [Debug] and [PartialEq] so that your types can
    /// derive them. Otherwise the derive would fail spuriously, because the
    /// mode type parameter does not implement the trait even though it never
    /// appears in the data.
    #[derive(Debug, PartialEq)]
    pub enum ValueMode {}

    /// Value mode, but each value might be null.
    ///
    /// Like [ValueMode], this implements [Debug] and [PartialEq] so that tables
    /// loaded in this mode can derive those traits too.
    #[derive(Debug, PartialEq)]
    pub enum ValueNullifiedMode {}

    /// Expr mode: the columns are [crate::Expr]s.
    pub enum ExprMode {}

    /// Expr mode, but each value might be null; the columns are [crate::Expr]s.
    pub enum ExprNullifiedMode {}

    /// Empty mode: every field is `()`. Used when a mapper doesn't want to
    /// produce a value.
    pub enum EmptyMode {}
}
446
447pub use table_modes::*;
448
/// Table modes: this trait switches the types of a Rust struct's fields.
///
/// A table struct is generic over its mode and declares every field through
/// [TableMode::T], like so:
///
/// ```rust
/// use rust_rel8::*;
///
/// struct MyTable<'scope, Mode: TableMode> {
///   id: Mode::T<'scope, i32>,
///   name: Mode::T<'scope, String>,
///   age: Mode::T<'scope, i32>,
/// }
/// ```
pub trait TableMode {
    /// A generic associated type (GAT) giving the type of a field whose
    /// column holds `V`. Depending on the mode, it may or may not mention `V`.
    type T<'scope, V>;
}
466
467impl TableMode for NameMode {
468    /// a string representing the column name.
469    type T<'scope, V> = &'static str;
470}
471
472impl TableMode for ValueMode {
473    type T<'scope, V> = V;
474}
475
476impl TableMode for ValueNullifiedMode {
477    type T<'scope, V> = Option<V>;
478}
479
480impl TableMode for ExprMode {
481    type T<'scope, V> = Expr<'scope, V>;
482}
483
484impl TableMode for ExprNullifiedMode {
485    type T<'scope, V> = Expr<'scope, Option<V>>;
486}
487
488impl TableMode for EmptyMode {
489    type T<'scope, V> = ();
490}
491
492#[derive(bytemuck::TransparentWrapper)]
493#[repr(transparent)]
494/// A wrapper which implements [Table] for any type implementing [MapTable] in [ExprMode].
495pub struct TableUsingMapper<T>(pub T);
496
497impl<T> TableUsingMapper<T> {
498    pub fn wrap(t: T) -> Self {
499        <Self as bytemuck::TransparentWrapper<T>>::wrap(t)
500    }
501
502    pub fn wrap_ref(t: &T) -> &Self {
503        <Self as bytemuck::TransparentWrapper<T>>::wrap_ref(t)
504    }
505
506    pub fn wrap_mut(t: &mut T) -> &mut Self {
507        <Self as bytemuck::TransparentWrapper<T>>::wrap_mut(t)
508    }
509}
510
511impl<'scope, T> Table<'scope> for TableUsingMapper<T>
512where
513    T: Table<'scope> + MapTable<'scope> + TableHKT<Mode = ExprMode>,
514    T::Of<ExprNullifiedMode>: Table<'scope>,
515{
516    type Nullify = T::Of<ExprNullifiedMode>;
517
518    type Result = T::Of<ValueMode>;
519
520    fn visit(&self, f: &mut impl FnMut(&ErasedExpr)) {
521        let mut mapper = VisitTableMapper { f };
522        self.0.map_modes_ref(&mut mapper);
523    }
524
525    fn visit_mut(&mut self, f: &mut impl FnMut(&mut ErasedExpr)) {
526        let mut mapper = VisitTableMapperMut { f };
527        self.0.map_modes_mut(&mut mapper);
528    }
529
530    fn nullify(self) -> Self::Nullify {
531        let mut mapper = NullifyMapper {};
532        self.0.map_modes(&mut mapper)
533    }
534}
535
536#[derive(bytemuck::TransparentWrapper)]
537#[repr(transparent)]
538/// A wrapper which implements [Table] for any type implementing [MapTable] in [ExprNullifiedMode].
539pub struct TableUsingMapperNullified<T>(pub T);
540
541impl<T> TableUsingMapperNullified<T> {
542    pub fn wrap(t: T) -> Self {
543        <Self as bytemuck::TransparentWrapper<T>>::wrap(t)
544    }
545
546    pub fn wrap_ref(t: &T) -> &Self {
547        <Self as bytemuck::TransparentWrapper<T>>::wrap_ref(t)
548    }
549
550    pub fn wrap_mut(t: &mut T) -> &mut Self {
551        <Self as bytemuck::TransparentWrapper<T>>::wrap_mut(t)
552    }
553}
554
555impl<'scope, T> Table<'scope> for TableUsingMapperNullified<T>
556where
557    T: Table<'scope> + MapTable<'scope> + TableHKT<Mode = ExprNullifiedMode>,
558{
559    type Nullify = T;
560
561    type Result = T::Of<ValueNullifiedMode>;
562
563    fn visit(&self, f: &mut impl FnMut(&ErasedExpr)) {
564        let mut mapper = VisitTableMapper { f };
565        self.0.map_modes_ref(&mut mapper);
566    }
567
568    fn visit_mut(&mut self, f: &mut impl FnMut(&mut ErasedExpr)) {
569        let mut mapper = VisitTableMapperMut { f };
570        self.0.map_modes_mut(&mut mapper);
571    }
572
573    fn nullify(self) -> Self::Nullify {
574        self.0
575    }
576}
577
// Loading and skipping are mode mappings driven by the row's value iterator.
#[cfg(feature = "sqlx")]
impl<T> TableLoaderSqlx for TableUsingMapper<T>
where
    T: Table<'static> + MapTable<'static> + TableHKT<Mode = ExprMode>,
    T::Of<ExprNullifiedMode>: Table<'static>,
{
    fn load<'a>(
        &self,
        values: &mut impl Iterator<Item = sqlx::any::AnyValueRef<'a>>,
    ) -> Self::Result {
        self.0.map_modes_ref(&mut LoadingMapper { it: values })
    }

    fn skip<'a>(&self, values: &mut impl Iterator<Item = sqlx::any::AnyValueRef<'a>>) {
        self.0.map_modes_ref(&mut SkippingMapper { it: values });
    }
}
597
598/// A trait that represents a database result row.
599///
600/// If you implement [Table] on your type, you must also implement [ForLifetimeTable].
601pub trait Table<'scope> {
602    type Nullify: Table<'scope>;
603    type Result;
604
605    /// Visit each expr in the table.
606    ///
607    /// The order and number of expressions visited must always remain the same,
608    /// across: [Table::visit], [Table::visit_mut], and all methods of [MapTable].
609    fn visit(&self, f: &mut impl FnMut(&ErasedExpr));
610
611    /// Visit each expr in the table, with a mutable reference.
612    ///
613    /// The order and number of expressions visited must always remain the same,
614    /// across: [Table::visit], [Table::visit_mut], and all methods of [MapTable].
615    fn visit_mut(&mut self, f: &mut impl FnMut(&mut ErasedExpr));
616
617    fn nullify(self) -> Self::Nullify;
618}
619
/// Decodes a [Table] from the values of a sqlx result row.
///
/// Values are consumed from `values` in the order the table's expressions are
/// visited; the implementations in this crate follow [Table::visit]'s order.
#[cfg(feature = "sqlx")]
pub trait TableLoaderSqlx: Table<'static> {
    /// Decode this table from the front of a row's values.
    fn load<'a>(
        &self,
        values: &mut impl Iterator<Item = sqlx::any::AnyValueRef<'a>>,
    ) -> Self::Result;

    /// Advance `values` past this table's columns without decoding them.
    ///
    /// This is used when the value was discarded by a [MaybeTable].
    fn skip<'a>(&self, values: &mut impl Iterator<Item = sqlx::any::AnyValueRef<'a>>);
}
632
633/// attach each of the tables columns to the select, and renames the table to
634/// use the new names.
635///
636/// This embeds expressions, so after this the table contains only Column exprs.
637fn subst_table<'scope, T: Table<'scope>>(
638    table: &mut T,
639    table_name: TableName,
640    dest_select: &mut sea_query::SelectStatement,
641) {
642    table.visit_mut(&mut |ErasedExpr(inner)| {
643        let new_column_name = match inner {
644            ExprInner::Raw(..) => ColumnName::new(Binder::new(), "lit".to_owned()),
645            ExprInner::Column(_table_name, column_name) => {
646                ColumnName::new(Binder::new(), column_name.name.clone())
647            }
648            ExprInner::BinOp(..) => ColumnName::new(Binder::new(), "expr".to_owned()),
649        };
650        let r = inner.render();
651        dest_select.expr_as(r, new_column_name.clone());
652        *inner = ExprInner::Column(table_name.clone(), new_column_name);
653    })
654}
655
656// not sure if we should do this by having `q` wrap the query and `subst_table` everything, or this
657fn insert_table_name<'scope, T: Table<'scope>>(table: &mut T, new_table_name: TableName) {
658    table.visit_mut(&mut |ErasedExpr(inner)| match inner {
659        ExprInner::Raw(_) => {}
660        ExprInner::Column(table_name, _column_name) => {
661            *table_name = new_table_name.clone();
662        }
663        ExprInner::BinOp(_, expr_inner, expr_inner1) => {
664            expr_inner.visit_mut(&mut |table_name, _| *table_name = new_table_name.clone());
665            expr_inner1.visit_mut(&mut |table_name, _| *table_name = new_table_name.clone());
666        }
667    })
668}
669
670/// A helper trait that allows us to talk about a [Table] with different
671/// lifetimes. Conceptually it is a type level function of `lt -> T where T:
672/// Table<'lt>`.
673///
674/// If you implement [Table] on your type, you must also implement [ForLifetimeTable].
675pub trait ForLifetimeTable: Sized {
676    type Of<'lt>: Table<'lt> + Sized
677    where
678        Self: 'lt;
679}
680
681impl<'scope, T> Table<'scope> for Expr<'scope, T> {
682    type Nullify = Expr<'scope, Option<T>>;
683    type Result = T;
684
685    fn visit(&self, f: &mut impl FnMut(&ErasedExpr)) {
686        f(self.as_erased())
687    }
688
689    fn visit_mut(&mut self, f: &mut impl FnMut(&mut ErasedExpr)) {
690        f(self.as_erased_mut())
691    }
692
693    fn nullify(self) -> Self::Nullify {
694        Expr::new(self.expr)
695    }
696}
697
698impl<T> ForLifetimeTable for Expr<'static, T> {
699    type Of<'lt>
700        = Expr<'lt, T>
701    where
702        T: 'lt;
703}
704
705impl<T> ShortenLifetime for Expr<'static, T> {
706    type Shortened<'small>
707        = Expr<'small, T>
708    where
709        Self: 'small;
710
711    fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
712    where
713        Self: 'large,
714    {
715        Expr::new(self.expr)
716    }
717}
718
// A single expression consumes exactly one value of the row.
#[cfg(feature = "sqlx")]
impl<T: for<'r> sqlx::Decode<'r, sqlx::Any>> TableLoaderSqlx for Expr<'static, T> {
    fn load<'a>(
        &self,
        values: &mut impl Iterator<Item = sqlx::any::AnyValueRef<'a>>,
    ) -> Self::Result {
        let raw = values.next().unwrap();
        T::decode(raw).unwrap()
    }

    fn skip<'a>(&self, values: &mut impl Iterator<Item = sqlx::any::AnyValueRef<'a>>) {
        values.next().unwrap();
    }
}
732
733impl<'scope, A: Table<'scope>> Table<'scope> for (A,) {
734    type Nullify = (A::Nullify,);
735    type Result = (A::Result,);
736
737    fn visit(&self, f: &mut impl FnMut(&ErasedExpr)) {
738        self.0.visit(f);
739    }
740
741    fn visit_mut(&mut self, f: &mut impl FnMut(&mut ErasedExpr)) {
742        self.0.visit_mut(f);
743    }
744
745    fn nullify(self) -> Self::Nullify {
746        (self.0.nullify(),)
747    }
748}
749
750impl<A: ShortenLifetime> ShortenLifetime for (A,) {
751    type Shortened<'small>
752        = (A::Shortened<'small>,)
753    where
754        Self: 'small;
755
756    fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
757    where
758        Self: 'large,
759    {
760        (self.0.shorten_lifetime(),)
761    }
762}
763
764impl<A: ForLifetimeTable> ForLifetimeTable for (A,) {
765    type Of<'lt>
766        = (A::Of<'lt>,)
767    where
768        A: 'lt;
769}
770
771#[cfg(feature = "sqlx")]
772impl<A: TableLoaderSqlx> TableLoaderSqlx for (A,) {
773    fn load<'a>(
774        &self,
775        values: &mut impl Iterator<Item = sqlx::any::AnyValueRef<'a>>,
776    ) -> Self::Result {
777        let a = self.0.load(values);
778        (a,)
779    }
780
781    fn skip<'a>(&self, values: &mut impl Iterator<Item = sqlx::any::AnyValueRef<'a>>) {
782        self.0.skip(values);
783    }
784}
785
786impl<'scope, A: Table<'scope>, B: Table<'scope>> Table<'scope> for (A, B) {
787    type Nullify = (A::Nullify, B::Nullify);
788    type Result = (A::Result, B::Result);
789
790    fn visit(&self, f: &mut impl FnMut(&ErasedExpr)) {
791        self.0.visit(f);
792        self.1.visit(f);
793    }
794
795    fn visit_mut(&mut self, f: &mut impl FnMut(&mut ErasedExpr)) {
796        self.0.visit_mut(f);
797        self.1.visit_mut(f);
798    }
799
800    fn nullify(self) -> Self::Nullify {
801        (self.0.nullify(), self.1.nullify())
802    }
803}
804
805impl<A: ForLifetimeTable, B: ForLifetimeTable> ForLifetimeTable for (A, B) {
806    type Of<'lt>
807        = (A::Of<'lt>, B::Of<'lt>)
808    where
809        A: 'lt,
810        B: 'lt;
811}
812
813impl<A: ShortenLifetime, B: ShortenLifetime> ShortenLifetime for (A, B) {
814    type Shortened<'small>
815        = (A::Shortened<'small>, B::Shortened<'small>)
816    where
817        Self: 'small;
818
819    fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
820    where
821        Self: 'large,
822    {
823        (self.0.shorten_lifetime(), self.1.shorten_lifetime())
824    }
825}
826
827#[cfg(feature = "sqlx")]
828impl<A: TableLoaderSqlx, B: TableLoaderSqlx> TableLoaderSqlx for (A, B) {
829    fn load<'a>(
830        &self,
831        values: &mut impl Iterator<Item = sqlx::any::AnyValueRef<'a>>,
832    ) -> Self::Result {
833        let a = self.0.load(values);
834        let b = self.1.load(values);
835        (a, b)
836    }
837
838    fn skip<'a>(&self, values: &mut impl Iterator<Item = sqlx::any::AnyValueRef<'a>>) {
839        self.0.skip(values);
840        self.1.skip(values);
841    }
842}
843
844impl<'scope, A: Table<'scope>, B: Table<'scope>, C: Table<'scope>> Table<'scope> for (A, B, C) {
845    type Nullify = (A::Nullify, B::Nullify, C::Nullify);
846    type Result = (A::Result, B::Result, C::Result);
847
848    fn visit(&self, f: &mut impl FnMut(&ErasedExpr)) {
849        self.0.visit(f);
850        self.1.visit(f);
851        self.2.visit(f);
852    }
853
854    fn visit_mut(&mut self, f: &mut impl FnMut(&mut ErasedExpr)) {
855        self.0.visit_mut(f);
856        self.1.visit_mut(f);
857        self.2.visit_mut(f);
858    }
859
860    fn nullify(self) -> Self::Nullify {
861        (self.0.nullify(), self.1.nullify(), self.2.nullify())
862    }
863}
864
865impl<A: ForLifetimeTable, B: ForLifetimeTable, C: ForLifetimeTable> ForLifetimeTable for (A, B, C) {
866    type Of<'lt>
867        = (A::Of<'lt>, B::Of<'lt>, C::Of<'lt>)
868    where
869        A: 'lt,
870        B: 'lt,
871        C: 'lt;
872}
873
874impl<A: ShortenLifetime, B: ShortenLifetime, C: ShortenLifetime> ShortenLifetime for (A, B, C) {
875    type Shortened<'small>
876        = (
877        A::Shortened<'small>,
878        B::Shortened<'small>,
879        C::Shortened<'small>,
880    )
881    where
882        Self: 'small;
883
884    fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
885    where
886        Self: 'large,
887    {
888        (
889            self.0.shorten_lifetime(),
890            self.1.shorten_lifetime(),
891            self.2.shorten_lifetime(),
892        )
893    }
894}
895
896#[cfg(feature = "sqlx")]
897impl<A: TableLoaderSqlx, B: TableLoaderSqlx, C: TableLoaderSqlx> TableLoaderSqlx for (A, B, C) {
898    fn load<'a>(
899        &self,
900        values: &mut impl Iterator<Item = sqlx::any::AnyValueRef<'a>>,
901    ) -> Self::Result {
902        let a = self.0.load(values);
903        let b = self.1.load(values);
904        let c = self.2.load(values);
905        (a, b, c)
906    }
907
908    fn skip<'a>(&self, values: &mut impl Iterator<Item = sqlx::any::AnyValueRef<'a>>) {
909        self.0.skip(values);
910        self.1.skip(values);
911        self.2.skip(values);
912    }
913}
914
915pub struct MaybeTable<'scope, T> {
916    tag: Expr<'scope, Option<bool>>,
917    inner: T,
918}
919
920impl<'scope, T: Table<'scope>> Table<'scope> for MaybeTable<'scope, T> {
921    type Nullify = Self;
922    type Result = Option<T::Result>;
923    // type Result = <T::Nullify as Table<'scope>>::Result;
924
925    fn visit(&self, f: &mut impl FnMut(&ErasedExpr)) {
926        self.tag.visit(f);
927        self.inner.visit(f);
928    }
929
930    fn visit_mut(&mut self, f: &mut impl FnMut(&mut ErasedExpr)) {
931        self.tag.visit_mut(f);
932        self.inner.visit_mut(f);
933    }
934
935    fn nullify(self) -> Self::Nullify {
936        self
937    }
938}
939
940impl<T: ForLifetimeTable + Table<'static>> ForLifetimeTable for MaybeTable<'static, T>
941where
942    for<'lt> T::Of<'lt>: Table<'lt>,
943{
944    type Of<'lt>
945        = MaybeTable<'lt, T::Of<'lt>>
946    where
947        T: 'lt;
948}
949
950impl<T: ShortenLifetime> ShortenLifetime for MaybeTable<'static, T> {
951    type Shortened<'small>
952        = MaybeTable<'small, T::Shortened<'small>>
953    where
954        Self: 'small;
955
956    fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
957    where
958        Self: 'large,
959    {
960        MaybeTable {
961            tag: self.tag.shorten_lifetime(),
962            inner: self.inner.shorten_lifetime(),
963        }
964    }
965}
966
// The first value of the row is the tag; the wrapped table's values follow.
#[cfg(feature = "sqlx")]
impl<T: TableLoaderSqlx> TableLoaderSqlx for MaybeTable<'static, T> {
    fn load<'a>(
        &self,
        values: &mut impl Iterator<Item = sqlx::any::AnyValueRef<'a>>,
    ) -> Self::Result {
        let raw_tag = values.next().unwrap();
        let tag = <Option<bool> as sqlx::Decode<sqlx::Any>>::decode(raw_tag).unwrap();

        // Absent row: still consume the inner columns so the iterator stays aligned.
        if tag != Some(true) {
            self.inner.skip(values);
            return None;
        }

        Some(self.inner.load(values))
    }

    fn skip<'a>(&self, values: &mut impl Iterator<Item = sqlx::any::AnyValueRef<'a>>) {
        // Discard the tag, then the inner table's columns.
        let _ = values.next().unwrap();
        self.inner.skip(values);
    }
}
990
991/// A value representing a sql select statement which produces rows of type `T`.
992#[derive(Clone)]
993pub struct Query<T> {
994    // Unique ID used to make the table name and columns unique
995    binder: Binder,
996    expr: sea_query::SelectStatement,
997    inner: T,
998    siblings_need_random: bool,
999}
1000
1001impl<'scope, T: Table<'scope>> Query<T> {
1002    fn new(binder: Binder, expr: sea_query::SelectStatement, inner: T) -> Self {
1003        Self {
1004            binder,
1005            expr,
1006            inner,
1007            siblings_need_random: false,
1008        }
1009    }
1010
1011    fn into_volatile(mut self) -> Self {
1012        self.siblings_need_random = true;
1013        self
1014    }
1015
1016    fn erased(self) -> (ErasedQuery, T) {
1017        let r = ErasedQuery {
1018            expr: self.expr,
1019            siblings_need_random: self.siblings_need_random,
1020        };
1021
1022        (r, self.inner)
1023    }
1024
1025    /// Lift a table into a query which ensures side effects happen and are not shared.
1026    pub fn evaluate(mut table: T) -> Self {
1027        let binder = Binder::new();
1028        let mut expr = sea_query::SelectStatement::new();
1029        subst_table(&mut table, TableName::new(binder), &mut expr);
1030        Self::new(binder, expr, table).into_volatile()
1031    }
1032
1033    /// Transform this query into one which produces rows of either [`Some<T>`] or [None].
1034    ///
1035    /// That is, this turns the query into a left join.
1036    pub fn optional(mut self) -> Query<MaybeTable<'scope, T>> {
1037        let binder = Binder::new();
1038
1039        let filler_col = ColumnName::new(binder, "filler".to_owned());
1040
1041        let mut filler = sea_query::Query::select()
1042            .expr_as(sea_query::Expr::column("column1"), filler_col.clone())
1043            .from_values(
1044                vec![sea_query::ValueTuple::One(true.into())],
1045                TableName::new(binder),
1046            )
1047            .to_owned();
1048
1049        let tag = ColumnName::new(binder, "tag".to_owned());
1050
1051        let mut expr = self.expr;
1052        expr.expr_as(sea_query::Value::Bool(Some(true)), tag.clone());
1053
1054        let table_name = TableName::new(self.binder);
1055
1056        filler.join_subquery(
1057            sea_query::JoinType::LeftJoin,
1058            expr,
1059            table_name.clone(),
1060            sea_query::Condition::all(),
1061        );
1062
1063        // important: the tag must be the first column
1064        filler.expr_as(
1065            sea_query::Expr::column((table_name.clone(), tag.clone())),
1066            tag.clone(),
1067        );
1068
1069        // the rest can come later
1070        subst_table(&mut self.inner, table_name, &mut filler);
1071
1072        let maybe_table = MaybeTable {
1073            inner: self.inner,
1074            tag: Expr {
1075                expr: ExprInner::Column(TableName::new(binder), tag),
1076                _phantom: PhantomData,
1077            },
1078        };
1079
1080        Query {
1081            binder,
1082            expr: filler,
1083            inner: maybe_table,
1084            siblings_need_random: self.siblings_need_random,
1085        }
1086    }
1087
1088    /// Add an order by clause to the query, the given function should return a
1089    /// table, the query will be ordered by each column of the table.
1090    pub fn order_by<U, F>(self, f: F) -> Query<T>
1091    where
1092        U: Table<'scope>,
1093        F: FnOnce(&T) -> (U, sea_query::Order),
1094    {
1095        let binder = Binder::new();
1096
1097        let mut outer = sea_query::Query::select();
1098        outer.from_subquery(self.expr, TableName::new(self.binder));
1099
1100        let (order_expr, order) = f(&self.inner);
1101
1102        order_expr.visit(&mut |ErasedExpr(e)| {
1103            outer.order_by_expr(e.render(), order.clone());
1104        });
1105
1106        let mut e = self.inner;
1107        subst_table(&mut e, TableName::new(binder), &mut outer);
1108
1109        Query {
1110            binder,
1111            expr: outer,
1112            inner: e,
1113            siblings_need_random: self.siblings_need_random,
1114        }
1115    }
1116}
1117
1118impl<'scope, T> Query<T>
1119where
1120    T: MapTable<'scope> + TableHKT<Mode = NameMode>,
1121    T::Of<ExprMode>: Table<'scope>,
1122{
1123    /// Given a [TableSchema], build a query that selects all columns of every row.
1124    pub fn each(schema: &TableSchema<T>) -> Query<T::Of<ExprMode>> {
1125        let binder = Binder::new();
1126        let mut query = sea_query::Query::select();
1127        query.from(schema.name);
1128
1129        let mut mapper = NameToExprMapper { binder, query };
1130
1131        let expr = schema.columns.map_modes_ref(&mut mapper);
1132
1133        Query::new(binder, mapper.query, expr)
1134    }
1135}
1136
1137impl<'scope, T: TableHKT<Mode = ValueMode>> Query<T> {
1138    /// Construct a query yielding the given rows
1139    pub fn values(vals: impl IntoIterator<Item = T>) -> Query<T::Of<ExprMode>>
1140    where
1141        T: MapTable<'scope>,
1142        T::Of<ExprMode>: Table<'scope>,
1143    {
1144        let binder = Binder::new();
1145        let mut iter = vals.into_iter();
1146
1147        let Some(first) = iter.next() else {
1148            panic!("Don't do that");
1149        };
1150
1151        let mut mapper = ExprCollectorMapper {
1152            idx: 0,
1153            table_binder: binder,
1154            columns: Vec::new(),
1155            values: Vec::new(),
1156        };
1157
1158        let result_expr = first.map_modes(&mut mapper);
1159
1160        let mut all_values = vec![mapper.values];
1161
1162        for v in iter {
1163            let mut mapper = ExprCollectorRemainingMapper { values: Vec::new() };
1164            v.map_modes(&mut mapper);
1165            all_values.push(mapper.values);
1166        }
1167
1168        let mut select = sea_query::Query::select();
1169        for (idx, col) in mapper.columns.into_iter().enumerate() {
1170            select.expr_as(sea_query::Expr::column(format!("column{}", idx + 1)), col);
1171        }
1172
1173        select.from_values(
1174            all_values
1175                .into_iter()
1176                .map(|v| sea_query::ValueTuple::Many(v)),
1177            TableName::new(binder),
1178        );
1179
1180        Query::new(binder, select, result_expr)
1181    }
1182}
1183
#[cfg(feature = "sqlx")]
impl<T: TableLoaderSqlx> Query<T> {
    /// Run the query on `pool` and load every returned row into a `T::Result`.
    ///
    /// The sql is always rendered with postgres syntax. Each row's columns are
    /// handed to the loader in select-list order.
    ///
    /// # Errors
    ///
    /// Returns any error reported by sqlx while executing the query. Decoding
    /// problems are not reported here; loaders may panic on them (e.g. the
    /// `MaybeTable` loader unwraps its tag decode).
    pub async fn all(&self, pool: &mut sqlx::AnyConnection) -> sqlx::Result<Vec<T::Result>> {
        use sea_query::PostgresQueryBuilder;
        use sea_query_sqlx::SqlxBinder as _;
        use sqlx::Row as _;

        let (sql, values) = self.expr.build_sqlx(PostgresQueryBuilder);
        let rows = sqlx::query_with(&sql, values).fetch_all(pool).await?;

        let mut results = Vec::with_capacity(rows.len());
        for row in &rows {
            let mut columns = (0..row.len()).map(|idx| row.try_get_raw(idx).unwrap());
            results.push(self.inner.load(&mut columns));
        }

        Ok(results)
    }
}
1205
/// The sql half of a [Query], with its rust-side table split off by
/// `Query::erased`.
#[derive(Debug)]
struct ErasedQuery {
    // The select statement producing the rows.
    expr: sea_query::SelectStatement,
    // Copied from the `Query`. If any bound query has it set, `query` adds
    // `random()` dummy columns to every subquery.
    siblings_need_random: bool,
}
1211
/// A publicly exposed opaque type that is used by [Table::visit] and
/// [Table::visit_mut]. Its purpose is to allow you to store [Expr]s in your
/// types which implement the [Table] trait.
#[derive(bytemuck::TransparentWrapper)]
// `repr(transparent)` is required by `TransparentWrapper`. It is what lets
// `Expr::as_erased{,_mut}` view a `&ExprInner` as a `&ErasedExpr` in place.
#[repr(transparent)]
pub struct ErasedExpr(ExprInner);
1218
/// The untyped expression tree behind [Expr].
#[derive(Clone)]
enum ExprInner {
    /// An arbitrary sea_query expression, e.g. a literal or a function call.
    Raw(sea_query::Expr),
    /// A reference to `table.column`.
    Column(TableName, ColumnName),
    /// A binary operation. The callback builds the sql from the rendered left
    /// and right operands.
    BinOp(
        Arc<dyn Fn(sea_query::SimpleExpr, sea_query::SimpleExpr) -> sea_query::SimpleExpr>,
        Box<ExprInner>,
        Box<ExprInner>,
    ),
}
1229
1230impl ExprInner {
1231    fn visit_mut(&mut self, f: &mut impl FnMut(&mut TableName, &mut ColumnName)) {
1232        match self {
1233            ExprInner::Column(table_name, column_name) => f(table_name, column_name),
1234            ExprInner::BinOp(_, expr_inner, expr_inner1) => {
1235                expr_inner.visit_mut(f);
1236                expr_inner1.visit_mut(f);
1237            }
1238            ExprInner::Raw(_) => {}
1239        }
1240    }
1241
1242    fn render(&self) -> sea_query::SimpleExpr {
1243        match self {
1244            ExprInner::Raw(value) => value.clone(),
1245            ExprInner::Column(table_name, column_name) => {
1246                sea_query::Expr::column((table_name.clone(), column_name.clone()))
1247            }
1248            ExprInner::BinOp(cb, expr_inner, expr_inner1) => {
1249                cb(expr_inner.render(), expr_inner1.render())
1250            }
1251        }
1252    }
1253}
1254
/// A type representing an expression in the query, can be passed around on the
/// rust side to wire things up.
///
/// `T` is the rust type of the value the expression produces (e.g.
/// `Expr<bool>` for conditions passed to [Q::where_]). `'scope` ties the
/// expression to the [query] context it belongs to.
#[derive(Clone)]
pub struct Expr<'scope, T> {
    // The untyped expression tree; `T` and `'scope` exist only at the type level.
    expr: ExprInner,
    _phantom: PhantomData<(&'scope (), T)>,
}
1262
1263impl<'scope, T> Expr<'scope, T> {
1264    fn new(expr: ExprInner) -> Self {
1265        Self {
1266            expr,
1267            _phantom: PhantomData,
1268        }
1269    }
1270
1271    /// Construct a literal value from any value from any value that can be encoded.
1272    pub fn lit(value: T) -> Self
1273    where
1274        T: Into<sea_query::Value>,
1275    {
1276        Self::new(ExprInner::Raw(sea_query::Expr::value(value.into())))
1277    }
1278
1279    fn binop<U>(
1280        self,
1281        other: Self,
1282        binop: Arc<dyn Fn(sea_query::SimpleExpr, sea_query::SimpleExpr) -> sea_query::SimpleExpr>,
1283    ) -> Expr<'scope, U> {
1284        Expr::new(ExprInner::BinOp(
1285            binop,
1286            Box::new(self.expr),
1287            Box::new(other.expr),
1288        ))
1289    }
1290
1291    pub fn equals(self, other: Self) -> Expr<'scope, bool> {
1292        self.binop(
1293            other,
1294            Arc::new(|a, b| a.binary(sea_query::BinOper::Equal, b)),
1295        )
1296    }
1297
1298    fn as_erased(&self) -> &ErasedExpr {
1299        ErasedExpr::wrap_ref(&self.expr)
1300    }
1301
1302    fn as_erased_mut(&mut self) -> &mut ErasedExpr {
1303        ErasedExpr::wrap_mut(&mut self.expr)
1304    }
1305}
1306
1307// TODO: num trait
1308impl<'scope> Expr<'scope, i32> {
1309    pub fn add(self, other: Self) -> Self {
1310        self.binop(other, Arc::new(|a, b| a.binary(sea_query::BinOper::Add, b)))
1311    }
1312
1313    /// generate `nextval('name')`, this must be used within [`Query::evaluate`]
1314    /// for it to behave properly.
1315    ///
1316    /// # Example
1317    ///
1318    /// ```rust
1319    /// use rust_rel8::{helper_tables::One, *};
1320    ///
1321    /// query::<(Expr<i32>, Expr<i32>)>(|q| {
1322    ///   let id = q.q(Query::evaluate(Expr::nextval("table_id_seq")));
1323    ///   let v = q.q(Query::values([1, 2, 3].map(|a| One { a })));
1324    ///   (id, v.a)
1325    /// });
1326    /// ```
1327    pub fn nextval(name: &str) -> Self {
1328        Self::new(ExprInner::Raw(
1329            sea_query::Func::cust("nextval").arg(name.to_owned()).into(),
1330        ))
1331    }
1332}
1333
/// An opaque value you can use to compose together queries.
///
/// To get a value of this type, use [query].
pub struct Q<'scope> {
    // Every query bound with `Q::q`, paired with the alias it is joined under.
    queries: Vec<(TableName, ErasedQuery)>,
    // Conditions added with `Q::where_`; `query` ANDs them into the WHERE clause.
    filters: Vec<ExprInner>,
    // Binder of the query that `query` builds from this context.
    binder: Binder,
    // Ties values produced by `q` to the closure passed to `query`.
    _phantom: PhantomData<&'scope ()>,
}
1343
1344impl<'scope> Q<'scope> {
1345    /// Bind a query and give you a value representing each row it produces.
1346    ///
1347    /// The `'scope` lifetime prevents this value leaking out of its context,
1348    /// which would result in invalid queries.
1349    pub fn q<T: Table<'scope>>(&mut self, query: Query<T>) -> T {
1350        let binder = Binder::new();
1351        let name = TableName::new(binder);
1352        let (erased, mut inner) = query.erased();
1353        self.queries.push((name.clone(), erased));
1354        insert_table_name(&mut inner, name);
1355        inner
1356    }
1357
1358    /// Introduce a where clause for this query.
1359    ///
1360    /// If you introduce a clause that looks like `a.id = b.a_id` then you
1361    /// effectively create an inner join.
1362    pub fn where_<'a>(&mut self, expr: Expr<'a, bool>)
1363    where
1364        'scope: 'a,
1365    {
1366        self.filters.push(expr.expr);
1367    }
1368}
1369
1370/// Open a context allowing you to manipulate a query.
1371///
1372/// Inside you can use `q.q(...)` on as many [`Query<T>`] values as you wish, the
1373/// result of each call can be thought of as each value the query yields.
1374///
1375/// You can think of this as cross joining each query together, to create inner
1376/// joins or left joins, simply use [Q::where_] and [`Query<T>::optional`].
1377///
1378/// Unfortunately, rustc isn't able to infer the return type of this function as
1379/// there seems to be no good way to express that `T::Of<'a>` is the same type
1380/// as `T::Of<'b>` modulo lifetimes.
1381pub fn query<'outer, T: ForLifetimeTable>(
1382    f: impl for<'scope> FnOnce(&mut Q<'scope>) -> T::Of<'scope>,
1383) -> Query<T::Of<'outer>> {
1384    let mut q = Q {
1385        binder: Binder::new(),
1386        filters: Vec::new(),
1387        queries: Vec::new(),
1388        _phantom: PhantomData,
1389    };
1390
1391    let mut e = f(&mut q);
1392
1393    // if one of the selects needs the parents to have dummy columns to prevent
1394    // postgres evaluating it only once, we add to each query a `random() as dummy`.
1395    // Then, in those selects needing them,
1396    let needs_random = q.queries.iter().any(|(_, q)| q.siblings_need_random);
1397
1398    let mut random_binders: Vec<sea_query::Expr> = Vec::new();
1399
1400    let mut insert_dummy = |mut stmt: sea_query::SelectStatement, table: &TableName| {
1401        if needs_random {
1402            stmt.expr_as(sea_query::Func::random(), "dummy");
1403
1404            for binder in &random_binders {
1405                stmt.and_where(binder.clone());
1406            }
1407
1408            random_binders
1409                .push(sea_query::Expr::column((table.clone(), "dummy".to_string())).is_not_null())
1410        }
1411        stmt
1412    };
1413
1414    let mut iter = q.queries.into_iter();
1415    let mut table = sea_query::Query::select();
1416
1417    if let Some((first_table_name, first)) = iter.next() {
1418        let expr = insert_dummy(first.expr, &first_table_name);
1419        table.from_subquery(expr, first_table_name);
1420    };
1421
1422    for (table_name, q) in iter {
1423        let expr = insert_dummy(q.expr, &table_name);
1424        table.join_lateral(
1425            // normally a cross join, but sea_query doesn't support omitting the `ON` for cross joins (:
1426            // and CROSS JOIN is INNER JOIN ON TRUE
1427            sea_query::JoinType::InnerJoin,
1428            expr,
1429            table_name,
1430            sea_query::Condition::all(),
1431        );
1432    }
1433
1434    for filter in q.filters {
1435        table.and_where(filter.render());
1436    }
1437
1438    subst_table(&mut e, TableName::new(q.binder), &mut table);
1439
1440    // if we needed to add random calls, so do parents
1441    let mut q = Query::new(q.binder, table.to_owned(), e);
1442    q.siblings_need_random = needs_random;
1443    q
1444}
1445
/// An insert statement: the rows produced by the query `rows` are inserted
/// into the table `into`.
pub struct Insert<T: TableHKT<Mode = NameMode>> {
    // The target table; its column names form the insert's column list.
    pub into: TableSchema<T>,
    // The query whose rows are inserted. Its output columns are matched with
    // the columns of `into` by position (`INSERT INTO t (...) SELECT ...`).
    pub rows: Query<T::Of<ExprMode>>,
}
1451
#[cfg(feature = "sqlx")]
impl<T: TableHKT<Mode = NameMode>> Insert<T>
where
    T: MapTable<'static>, // where T::Of<ExprMode>: TableLoaderSqlx,
{
    /// Run the insert statement, rendered with postgres syntax, and return
    /// sqlx's result for it.
    ///
    /// # Errors
    ///
    /// Returns any error reported by sqlx while executing the statement.
    ///
    /// # Panics
    ///
    /// Panics if sea_query rejects `rows` as the source of the insert
    /// (presumably when the column counts differ — confirm against sea_query).
    pub async fn run(
        &self,
        pool: &mut sqlx::AnyConnection,
    ) -> sqlx::Result<sqlx::any::AnyQueryResult> {
        use sea_query::PostgresQueryBuilder;
        use sea_query_sqlx::SqlxBinder as _;

        // Target column names, in the order the schema visits its columns.
        let mut collector = NameCollectorMapper { names: Vec::new() };
        self.into.columns.map_modes_ref(&mut collector);

        let mut statement = sea_query::Query::insert();
        statement.into_table(self.into.name).columns(collector.names);
        statement.select_from(self.rows.expr.clone()).unwrap();

        let (sql, values) = statement.build_sqlx(PostgresQueryBuilder);
        sqlx::query_with(&sql, values).execute(pool).await
    }
}
1482
/// A set of helper tables that are equivalent to tuples.
///
/// Tuples can be used as tables too (e.g. `(Expr<i32>, Expr<i32>)`), but they
/// are not parameterised by a [TableMode]. These structs are, so they can be
/// used where a mode-generic table is required, e.g. with [`Query::values`].
pub mod helper_tables {
    use super::*;
    use rust_rel8_derive::TableStruct;

    // NOTE(review): `#[table(crate = "crate")]` presumably makes the derive
    // refer to this library as `crate` (we are inside it); `#[table(proxy)]`
    // on the value type parameters is derive-specific. Confirm both against
    // `rust_rel8_derive`.
    #[derive(TableStruct)]
    #[table(crate = "crate")]
    /// A helper table with one field.
    pub struct One<'scope, Mode: TableMode, #[table(proxy)] A: Value> {
        pub a: Mode::T<'scope, A>,
    }

    #[derive(TableStruct)]
    #[table(crate = "crate")]
    /// A helper table with two fields.
    pub struct Two<'scope, Mode: TableMode, #[table(proxy)] A: Value, #[table(proxy)] B: Value> {
        pub a: Mode::T<'scope, A>,
        pub b: Mode::T<'scope, B>,
    }

    #[derive(TableStruct)]
    #[table(crate = "crate")]
    /// A helper table with three fields.
    pub struct Three<
        'scope,
        Mode: TableMode,
        #[table(proxy)] A: Value,
        #[table(proxy)] B: Value,
        #[table(proxy)] C: Value,
    > {
        pub a: Mode::T<'scope, A>,
        pub b: Mode::T<'scope, B>,
        pub c: Mode::T<'scope, C>,
    }
}
1521
/// Utilities for implementing tables, such as [ShortenLifetime] for common
/// container types.
pub mod helper_utilities {
    use std::{collections::HashMap, hash::Hash};

    /// A helper trait which implements `shorten_lifetime` for some common wrapper types.
    pub trait ShortenLifetime {
        /// `Self` with its lifetime shortened to `'small`.
        type Shortened<'small>
        where
            Self: 'small;

        /// Shorten a lifetime, normally rust does this automatically, but if
        /// the lifetime is invariant due to being used in a Gat or trait, we
        /// need to do it manually.
        ///
        /// If rust complains about a lifetime being invariant, you should call
        /// this method at the use site where the lifetime error is generated.
        fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
        where
            Self: 'large;
    }

    /// Arrays shorten element by element, keeping their length.
    impl<T: ShortenLifetime, const N: usize> ShortenLifetime for [T; N] {
        type Shortened<'small>
            = [T::Shortened<'small>; N]
        where
            Self: 'small;

        fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
        where
            Self: 'large,
        {
            self.map(|element| element.shorten_lifetime())
        }
    }

    /// Vectors shorten element by element, preserving order.
    impl<T: ShortenLifetime> ShortenLifetime for Vec<T> {
        type Shortened<'small>
            = Vec<T::Shortened<'small>>
        where
            Self: 'small;

        fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
        where
            Self: 'large,
        {
            let mut shortened = Vec::with_capacity(self.len());
            for element in self {
                shortened.push(element.shorten_lifetime());
            }
            shortened
        }
    }

    /// Maps keep their keys and shorten every value.
    impl<K: Hash + Eq, T: ShortenLifetime> ShortenLifetime for HashMap<K, T> {
        type Shortened<'small>
            = HashMap<K, T::Shortened<'small>>
        where
            Self: 'small;

        fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
        where
            Self: 'large,
        {
            let mut shortened = HashMap::with_capacity(self.len());
            for (key, value) in self {
                shortened.insert(key, value.shorten_lifetime());
            }
            shortened
        }
    }
}
1583                .map(|(k, v)| (k, v.shorten_lifetime()))
1584                .collect::<HashMap<_, _>>()
1585        }
1586    }
1587}
1588
1589pub use helper_utilities::ShortenLifetime;
1590
1591#[cfg(feature = "derive")]
1592pub use rust_rel8_derive::TableStruct;