Skip to main content

spacetimedb_query_builder/
join.rs

1use crate::TableNameStr;
2
3use super::{
4    expr::{format_expr, BoolExpr},
5    table::{CanBeLookupTable, ColumnRef, HasCols, HasIxCols, Table},
6    Query, RawQuery,
7};
8use std::marker::PhantomData;
9
10/// Indexed columns for joins
11///
12/// Joins are performed on indexed columns, Tables that implement `HasIxCols`
13/// provide access to their indexed columns.
14pub struct IxCol<T, V> {
15    pub(super) col: ColumnRef<T>,
16    _marker: PhantomData<V>,
17}
18
19impl<T, V> IxCol<T, V> {
20    pub fn new(table_name: TableNameStr, column: &'static str) -> Self {
21        Self {
22            col: ColumnRef::new(table_name, column),
23            _marker: PhantomData,
24        }
25    }
26}
27
28impl<T, V> Copy for IxCol<T, V> {}
29impl<T, V> Clone for IxCol<T, V> {
30    fn clone(&self) -> Self {
31        *self
32    }
33}
34
35pub struct IxJoinEq<L, R, V> {
36    pub(super) lhs_col: ColumnRef<L>,
37    pub(super) rhs_col: ColumnRef<R>,
38    _marker: PhantomData<V>,
39}
40
41impl<T, V> IxCol<T, V> {
42    pub fn eq<R: HasIxCols>(self, rhs: IxCol<R, V>) -> IxJoinEq<T, R, V> {
43        IxJoinEq {
44            lhs_col: self.col,
45            rhs_col: rhs.col,
46            _marker: PhantomData,
47        }
48    }
49}
50
51// Left semijoin: filters and returns left table rows
52pub struct LeftSemiJoin<L> {
53    pub(super) left_col: ColumnRef<L>,
54    pub(super) right_table: &'static str,
55    pub(super) right_col: &'static str,
56    pub(super) where_expr: Option<BoolExpr<L>>,
57}
58
59// Right semijoin: returns right table rows, but remembers left conditions
60pub struct RightSemiJoin<R, L> {
61    pub(super) left_col: ColumnRef<L>,
62    pub(super) right_col: ColumnRef<R>,
63    pub(super) left_where_expr: Option<BoolExpr<L>>,
64    pub(super) right_where_expr: Option<BoolExpr<R>>,
65    _left_marker: PhantomData<L>,
66}
67
68impl<L: HasIxCols> Table<L> {
69    pub fn left_semijoin<R: CanBeLookupTable, V>(
70        self,
71        right: Table<R>,
72        on: impl Fn(&L::IxCols, &R::IxCols) -> IxJoinEq<L, R, V>,
73    ) -> LeftSemiJoin<L> {
74        let join = on(&L::ix_cols(self.name()), &R::ix_cols(right.name()));
75        LeftSemiJoin {
76            left_col: join.lhs_col,
77            right_table: right.name(),
78            right_col: join.rhs_col.column_name(),
79            where_expr: None,
80        }
81    }
82
83    pub fn right_semijoin<R: CanBeLookupTable, V>(
84        self,
85        right: Table<R>,
86        on: impl Fn(&L::IxCols, &R::IxCols) -> IxJoinEq<L, R, V>,
87    ) -> RightSemiJoin<R, L> {
88        let join = on(&L::ix_cols(self.name()), &R::ix_cols(right.name()));
89        RightSemiJoin {
90            left_col: join.lhs_col,
91            right_col: join.rhs_col,
92            left_where_expr: None,
93            right_where_expr: None,
94            _left_marker: PhantomData,
95        }
96    }
97}
98
99impl<L: HasIxCols> super::FromWhere<L> {
100    pub fn left_semijoin<R: CanBeLookupTable, V>(
101        self,
102        right: Table<R>,
103        on: impl Fn(&L::IxCols, &R::IxCols) -> IxJoinEq<L, R, V>,
104    ) -> LeftSemiJoin<L> {
105        let join = on(&L::ix_cols(self.table_name), &R::ix_cols(right.name()));
106        LeftSemiJoin {
107            left_col: join.lhs_col,
108            right_table: right.name(),
109            right_col: join.rhs_col.column_name(),
110            where_expr: Some(self.expr),
111        }
112    }
113
114    pub fn right_semijoin<R: CanBeLookupTable, V>(
115        self,
116        right: Table<R>,
117        on: impl Fn(&L::IxCols, &R::IxCols) -> IxJoinEq<L, R, V>,
118    ) -> RightSemiJoin<R, L> {
119        let join = on(&L::ix_cols(self.table_name), &R::ix_cols(right.name()));
120        RightSemiJoin {
121            left_col: join.lhs_col,
122            right_col: join.rhs_col,
123            left_where_expr: Some(self.expr),
124            right_where_expr: None,
125            _left_marker: PhantomData,
126        }
127    }
128}
129
130impl<L: HasCols> Query<L> for LeftSemiJoin<L> {
131    fn into_sql(self) -> String {
132        self.build().into_sql()
133    }
134}
135
136impl<R: HasCols, L: HasCols> Query<R> for RightSemiJoin<R, L> {
137    fn into_sql(self) -> String {
138        self.build().into_sql()
139    }
140}
141
142// LeftSemiJoin where() operates on L
143impl<L: HasCols> LeftSemiJoin<L> {
144    pub fn r#where<F>(self, f: F) -> Self
145    where
146        F: Fn(&L::Cols) -> BoolExpr<L>,
147    {
148        let extra = f(&L::cols(self.left_col.table_name()));
149        let new = match self.where_expr {
150            Some(existing) => Some(existing.and(extra)),
151            None => Some(extra),
152        };
153        Self {
154            left_col: self.left_col,
155            right_table: self.right_table,
156            right_col: self.right_col,
157            where_expr: new,
158        }
159    }
160
161    // Filter is an alias for where
162    pub fn filter<F>(self, f: F) -> Self
163    where
164        F: Fn(&L::Cols) -> BoolExpr<L>,
165    {
166        self.r#where(f)
167    }
168
169    pub fn build(self) -> RawQuery<L> {
170        let where_clause = self
171            .where_expr
172            .map(|e| format!(" WHERE {}", format_expr(&e)))
173            .unwrap_or_default();
174
175        let sql = format!(
176            r#"SELECT "{}".* FROM "{}" JOIN "{}" ON "{}"."{}" = "{}"."{}"{}"#,
177            self.left_col.table_name(),
178            self.left_col.table_name(),
179            self.right_table,
180            self.left_col.table_name(),
181            self.left_col.column_name(),
182            self.right_table,
183            self.right_col,
184            where_clause
185        );
186        RawQuery::new(sql)
187    }
188}
189
190// RightSemiJoin where() operates on R
191impl<R: HasCols, L: HasCols> RightSemiJoin<R, L> {
192    pub fn r#where<F>(self, f: F) -> Self
193    where
194        F: Fn(&R::Cols) -> BoolExpr<R>,
195    {
196        let extra = f(&R::cols(self.right_col.table_name()));
197        let new = match self.right_where_expr {
198            Some(existing) => Some(existing.and(extra)),
199            None => Some(extra),
200        };
201        Self {
202            left_col: self.left_col,
203            right_col: self.right_col,
204            left_where_expr: self.left_where_expr,
205            right_where_expr: new,
206            _left_marker: PhantomData,
207        }
208    }
209
210    // Filter is an alias for where
211    pub fn filter<F>(self, f: F) -> Self
212    where
213        F: Fn(&R::Cols) -> BoolExpr<R>,
214    {
215        self.r#where(f)
216    }
217
218    pub fn build(self) -> RawQuery<R> {
219        let mut where_parts = Vec::new();
220
221        if let Some(left_expr) = self.left_where_expr {
222            where_parts.push(format_expr(&left_expr));
223        }
224
225        if let Some(right_expr) = self.right_where_expr {
226            where_parts.push(format_expr(&right_expr));
227        }
228
229        let where_clause = if !where_parts.is_empty() {
230            format!(" WHERE {}", where_parts.join(" AND "))
231        } else {
232            String::new()
233        };
234
235        let sql = format!(
236            r#"SELECT "{}".* FROM "{}" JOIN "{}" ON "{}"."{}" = "{}"."{}"{}"#,
237            self.right_col.table_name(),
238            self.left_col.table_name(),
239            self.right_col.table_name(),
240            self.left_col.table_name(),
241            self.left_col.column_name(),
242            self.right_col.table_name(),
243            self.right_col.column_name(),
244            where_clause
245        );
246        RawQuery::new(sql)
247    }
248}