Skip to main content

spacetimedb_query_builder/
lib.rs

1pub mod expr;
2pub mod join;
3pub mod table;
4
5pub use expr::*;
6pub use join::*;
7use spacetimedb_lib::{sats::impl_st, AlgebraicType, SpacetimeType};
8pub use table::*;
9
/// Common interface of every query-builder type.
///
/// Use `impl Query<T>` as the return type of view functions and helpers so
/// callers do not have to name the concrete builder type.
pub trait Query<T> {
    /// Consumes the builder and returns the SQL text it represents.
    fn into_sql(self) -> String;
}
15
/// The concrete SQL query produced by calling `.build()` on a builder.
///
/// Holds the rendered SQL text; the type parameter `T` is carried only at the
/// type level (presumably the row type the query yields — confirm with callers).
pub struct RawQuery<T> {
    /// Rendered SQL text, readable directly from inside this crate.
    pub(crate) sql: String,
    /// Zero-sized tag tying the query to `T` without storing a value of it.
    _marker: std::marker::PhantomData<T>,
}
21
22impl<T> RawQuery<T> {
23    pub fn new(sql: String) -> Self {
24        Self {
25            sql,
26            _marker: std::marker::PhantomData,
27        }
28    }
29
30    pub fn sql(&self) -> &str {
31        &self.sql
32    }
33}
34
35impl<T> Query<T> for RawQuery<T> {
36    fn into_sql(self) -> String {
37        self.sql
38    }
39}
40
41impl_st!([T: SpacetimeType] RawQuery<T>, ts => AlgebraicType::option(T::make_type(ts)));
42
43#[cfg(test)]
44mod tests {
45    use spacetimedb_lib::{sats::i256, TimeDuration};
46
47    use super::*;
48    struct User;
49    #[derive(Clone)]
50    struct UserCols {
51        pub id: Col<User, i32>,
52        pub name: Col<User, String>,
53        pub age: Col<User, i32>,
54    }
55    impl UserCols {
56        fn new(table_name: &'static str) -> Self {
57            Self {
58                id: Col::new(table_name, "id"),
59                name: Col::new(table_name, "name"),
60                age: Col::new(table_name, "age"),
61            }
62        }
63    }
64    impl HasCols for User {
65        type Cols = UserCols;
66        fn cols(table_name: &'static str) -> Self::Cols {
67            UserCols::new(table_name)
68        }
69    }
70    fn users() -> Table<User> {
71        Table::new("users")
72    }
73    fn other() -> Table<Other> {
74        Table::new("other")
75    }
76    struct OtherCols {
77        pub uid: Col<Other, i32>,
78    }
79
80    impl HasCols for Other {
81        type Cols = OtherCols;
82        fn cols(table: &'static str) -> Self::Cols {
83            OtherCols {
84                uid: Col::new(table, "uid"),
85            }
86        }
87    }
88    struct IxUserCols {
89        pub id: IxCol<User, i32>,
90    }
91    impl HasIxCols for User {
92        type IxCols = IxUserCols;
93        fn ix_cols(table_name: &'static str) -> Self::IxCols {
94            IxUserCols {
95                id: IxCol::new(table_name, "id"),
96            }
97        }
98    }
99    struct Other;
100    #[derive(Clone)]
101    struct IxOtherCols {
102        pub uid: IxCol<Other, i32>,
103    }
104    impl HasIxCols for Other {
105        type IxCols = IxOtherCols;
106        fn ix_cols(table_name: &'static str) -> Self::IxCols {
107            IxOtherCols {
108                uid: IxCol::new(table_name, "uid"),
109            }
110        }
111    }
112    impl CanBeLookupTable for User {}
113    impl CanBeLookupTable for Other {}
114    fn norm(s: &str) -> String {
115        s.split_whitespace().collect::<Vec<_>>().join(" ")
116    }
117    #[test]
118    fn test_simple_select() {
119        let q = users().build();
120        assert_eq!(q.sql(), r#"SELECT * FROM "users""#);
121    }
122    #[test]
123    fn test_where_literal() {
124        let q = users().r#where(|c| c.id.eq(10)).build();
125        let expected = r#"SELECT * FROM "users" WHERE ("users"."id" = 10)"#;
126        assert_eq!(norm(q.sql()), norm(expected));
127    }
128    #[test]
129    fn test_where_multiple_predicates() {
130        let q = users().r#where(|c| c.id.eq(10)).r#where(|c| c.age.gt(18)).build();
131        let expected = r#"SELECT * FROM "users" WHERE (("users"."id" = 10) AND ("users"."age" > 18))"#;
132        assert_eq!(norm(q.sql()), norm(expected));
133    }
134
135    #[test]
136    fn test_where_gte_lte() {
137        let q = users().r#where(|c| c.age.gte(18)).r#where(|c| c.age.lte(30)).build();
138        let expected = r#"SELECT * FROM "users" WHERE (("users"."age" >= 18) AND ("users"."age" <= 30))"#;
139        assert_eq!(norm(q.sql()), norm(expected));
140    }
141
142    #[test]
143    fn test_column_column_comparison() {
144        let q = users().r#where(|c| c.age.gt(c.id)).build();
145        let expected = r#"SELECT * FROM "users" WHERE ("users"."age" > "users"."id")"#;
146        assert_eq!(norm(q.sql()), norm(expected));
147    }
148    #[test]
149    fn test_ne_comparison() {
150        let q = users().r#where(|c| c.name.ne("Shub".to_string())).build();
151        assert!(q.sql().contains("name"), "Expected a name comparison");
152        assert!(q.sql().contains("<>"));
153    }
154
155    #[test]
156    fn test_not_comparison() {
157        let q = users().r#where(|c| c.name.eq("Alice".to_string()).not()).build();
158        let expected = r#"SELECT * FROM "users" WHERE (NOT ("users"."name" = 'Alice'))"#;
159        assert_eq!(norm(q.sql()), norm(expected));
160    }
161
162    #[test]
163    fn test_not_with_and() {
164        let q = users()
165            .r#where(|c| c.name.eq("Alice".to_string()).not().and(c.age.gt(18)))
166            .build();
167        let expected = r#"SELECT * FROM "users" WHERE ((NOT ("users"."name" = 'Alice')) AND ("users"."age" > 18))"#;
168        assert_eq!(norm(q.sql()), norm(expected));
169    }
170
171    #[test]
172    fn test_filter_alias() {
173        let q = users().filter(|c| c.id.eq(5)).filter(|c| c.age.lt(30)).build();
174        let expected = r#"SELECT * FROM "users" WHERE (("users"."id" = 5) AND ("users"."age" < 30))"#;
175        assert_eq!(norm(q.sql()), norm(expected));
176    }
177
178    #[test]
179    fn test_or_comparison() {
180        let q = users()
181            .r#where(|c| c.name.ne("Shub".to_string()).or(c.name.ne("Pop".to_string())))
182            .build();
183
184        let expected = r#"SELECT * FROM "users" WHERE (("users"."name" <> 'Shub') OR ("users"."name" <> 'Pop'))"#;
185        assert_eq!(q.sql, expected);
186    }
187
188    #[test]
189    fn test_format_expr_column_literal() {
190        let expr = BoolExpr::Eq(
191            Operand::Column(ColumnRef::<User>::new("user", "id")),
192            Operand::Literal(LiteralValue::new("42".to_string())),
193        );
194        let sql = format_expr(&expr);
195        assert!(sql.contains("id"), "Missing col");
196        assert!(sql.contains("42"), "Missing literal");
197    }
198
199    #[test]
200    fn test_format_semi_join_expr() {
201        let user = users();
202        let other = other();
203        let sql = user.left_semijoin(other, |u, o| u.id.eq(o.uid)).build().sql;
204        let expected = r#"SELECT "users".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid""#;
205        assert_eq!(sql, expected);
206    }
207
208    #[test]
209    fn test_left_semijoin_with_where_expr() {
210        let user = users();
211        let o = other();
212        let sql = user
213            .left_semijoin(o, |u, o| u.id.eq(o.uid))
214            .r#where(|u| u.id.eq(1i32))
215            .r#where(|u| u.id.gt(10))
216            .build()
217            .sql;
218        let expected = r#"SELECT "users".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid" WHERE (("users"."id" = 1) AND ("users"."id" > 10))"#;
219        assert_eq!(sql, expected);
220        let user = users();
221        let other = other();
222        let sql2 = user
223            .r#where(|u| u.id.eq(1))
224            .r#where(|u| u.id.gt(10))
225            .left_semijoin(other, |u, o| u.id.eq(o.uid))
226            .build()
227            .sql;
228        assert_eq!(sql2, expected);
229    }
230    #[test]
231    fn test_right_semijoin_with_where_expr() {
232        let user = users();
233        let o = other();
234        let sql = user
235            .right_semijoin(o, |u, o| u.id.eq(o.uid))
236            .r#where(|o| o.uid.eq(1))
237            .r#where(|o| o.uid.gt(10))
238            .build()
239            .sql;
240        let expected = r#"SELECT "other".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid" WHERE (("other"."uid" = 1) AND ("other"."uid" > 10))"#;
241        assert_eq!(sql, expected);
242    }
243
244    #[test]
245    fn test_right_semijoin_with_left_and_right_where_expr() {
246        let user = users();
247        let o = other();
248        let sql = user
249            .r#where(|u| u.id.eq(1))
250            .right_semijoin(o, |u, o| u.id.eq(o.uid))
251            .r#where(|o| o.uid.gt(10))
252            .build()
253            .sql;
254        let expected = r#"SELECT "other".* FROM "users" JOIN "other" ON "users"."id" = "other"."uid" WHERE ("users"."id" = 1) AND ("other"."uid" > 10)"#;
255        assert_eq!(sql, expected);
256    }
257
258    #[test]
259    fn test_literals() {
260        use spacetimedb_lib::{ConnectionId, Identity};
261
262        struct Player;
263        struct PlayerCols {
264            score: Col<Player, i32>,
265            name: Col<Player, String>,
266            active: Col<Player, bool>,
267            connection_id: Col<Player, ConnectionId>,
268            cells: Col<Player, i256>,
269            identity: Col<Player, Identity>,
270            ts: Col<Player, spacetimedb_lib::Timestamp>,
271            bytes: Col<Player, Vec<u8>>,
272        }
273
274        impl HasCols for Player {
275            type Cols = PlayerCols;
276            fn cols(table_name: &'static str) -> Self::Cols {
277                PlayerCols {
278                    score: Col::new(table_name, "score"),
279                    name: Col::new(table_name, "name"),
280                    active: Col::new(table_name, "active"),
281                    connection_id: Col::new(table_name, "connection_id"),
282                    cells: Col::new(table_name, "cells"),
283                    identity: Col::new(table_name, "identity"),
284                    ts: Col::new(table_name, "ts"),
285                    bytes: Col::new(table_name, "bytes"),
286                }
287            }
288        }
289
290        let table = Table::<Player>::new("player");
291        let q = table.r#where(|c| c.score.eq(100)).build();
292
293        assert_eq!(q.sql, r#"SELECT * FROM "player" WHERE ("player"."score" = 100)"#);
294
295        let table = Table::<Player>::new("player");
296        let q = table.r#where(|c| c.name.ne("Alice".to_string())).build();
297
298        assert_eq!(q.sql, r#"SELECT * FROM "player" WHERE ("player"."name" <> 'Alice')"#);
299
300        let table = Table::<Player>::new("player");
301        let q = table.r#where(|c| c.active.eq(true)).build();
302
303        assert_eq!(q.sql, r#"SELECT * FROM "player" WHERE ("player"."active" = TRUE)"#);
304
305        let table = Table::<Player>::new("player");
306        let q = table.r#where(|c| c.connection_id.eq(ConnectionId::ZERO)).build();
307
308        assert_eq!(
309            q.sql,
310            r#"SELECT * FROM "player" WHERE ("player"."connection_id" = 0x00000000000000000000000000000000)"#
311        );
312
313        let big_int: i256 = (i256::ONE << 120) * i256::from(-1);
314
315        let table = Table::<Player>::new("player");
316        let q = table.r#where(|c| c.cells.gt(big_int)).build();
317
318        assert_eq!(
319            q.sql,
320            r#"SELECT * FROM "player" WHERE ("player"."cells" > -1329227995784915872903807060280344576)"#,
321        );
322
323        let table = Table::<Player>::new("player");
324        let q = table.r#where(|c| c.identity.ne(Identity::ONE)).build();
325
326        assert_eq!(
327            q.sql,
328            r#"SELECT * FROM "player" WHERE ("player"."identity" <> 0x0000000000000000000000000000000000000000000000000000000000000001)"#
329        );
330
331        let ts = spacetimedb_lib::Timestamp::UNIX_EPOCH + TimeDuration::from_micros(1000);
332
333        let table = Table::<Player>::new("player");
334        let q = table.r#where(|c| c.ts.eq(ts)).build();
335        assert_eq!(
336            q.sql,
337            r#"SELECT * FROM "player" WHERE ("player"."ts" = '1970-01-01T00:00:00.001+00:00')"#
338        );
339
340        let table = Table::<Player>::new("player");
341        let q = table.r#where(|c| c.bytes.eq(vec![1, 2, 3, 4, 255])).build();
342
343        assert_eq!(
344            q.sql,
345            r#"SELECT * FROM "player" WHERE ("player"."bytes" = 0x01020304ff)"#
346        );
347    }
348}