// wasm_dbms/join.rs

1// Rust guideline compliant 2026-03-01
2// X-WHERE-CLAUSE, M-CANONICAL-DOCS
3
4//! Join execution engine for cross-table queries.
5
6use wasm_dbms_api::prelude::{
7    CandidColumnDef, ColumnDef, DbmsResult, JoinType, OrderDirection, Query, Value,
8};
9use wasm_dbms_memory::prelude::{AccessControl, AccessControlList, MemoryProvider};
10
11use crate::database::WasmDbmsDatabase;
12use crate::schema::DatabaseSchema;
13
/// A row in the joined result, organized by source table.
///
/// Each entry pairs a table name with the `(ColumnDef, Value)` cells that
/// table contributes to the row. Entries are appended in join order, starting
/// with the `FROM` table.
type JoinedRow = Vec<(String, Vec<(ColumnDef, Value)>)>;
16
/// Engine that executes join queries using nested-loop join.
///
/// The engine borrows a schema and fetches the rows of every participating
/// table through it; matching, filtering, ordering and pagination are then
/// performed on the fetched rows.
pub struct JoinEngine<'a, Schema: ?Sized, M, A = AccessControlList>
where
    Schema: DatabaseSchema<M, A>,
    M: MemoryProvider,
    A: AccessControl,
{
    /// Schema used to select the rows of each table taking part in the join.
    schema: &'a Schema,
    /// Zero-sized marker that ties the otherwise unused `M` and `A` type
    /// parameters to the struct.
    _marker: std::marker::PhantomData<(M, A)>,
}
27
28impl<'a, Schema: ?Sized, M, A> JoinEngine<'a, Schema, M, A>
29where
30    Schema: DatabaseSchema<M, A>,
31    M: MemoryProvider,
32    A: AccessControl,
33{
34    pub fn new(schema: &'a Schema) -> Self {
35        Self {
36            schema,
37            _marker: std::marker::PhantomData,
38        }
39    }
40}
41
42impl<Schema: ?Sized, M, A> JoinEngine<'_, Schema, M, A>
43where
44    Schema: DatabaseSchema<M, A>,
45    M: MemoryProvider,
46    A: AccessControl,
47{
48    /// Executes a join query using nested-loop join.
49    pub fn join(
50        &self,
51        dbms: &WasmDbmsDatabase<'_, M, A>,
52        from_table: &str,
53        query: Query,
54    ) -> DbmsResult<Vec<Vec<(CandidColumnDef, Value)>>> {
55        let from_rows = self
56            .schema
57            .select(dbms, from_table, Query::builder().all().build())?;
58
59        let mut joined_rows: Vec<JoinedRow> = from_rows
60            .into_iter()
61            .map(|row| vec![(from_table.to_string(), row)])
62            .collect();
63
64        for join in &query.joins {
65            let right_rows =
66                self.schema
67                    .select(dbms, &join.table, Query::builder().all().build())?;
68
69            let (left_table, left_col) = self.resolve_column_ref(&join.left_column, from_table);
70            let (_right_table_ref, right_col) =
71                self.resolve_column_ref(&join.right_column, &join.table);
72
73            let (keep_unmatched_left, keep_unmatched_right) = match join.join_type {
74                JoinType::Inner => (false, false),
75                JoinType::Left => (true, false),
76                JoinType::Right => (false, true),
77                JoinType::Full => (true, true),
78            };
79
80            joined_rows = self.nested_loop_join(
81                joined_rows,
82                &right_rows,
83                &join.table,
84                &left_table,
85                left_col,
86                right_col,
87                keep_unmatched_left,
88                keep_unmatched_right,
89            );
90        }
91
92        if let Some(filter) = &query.filter {
93            joined_rows.retain(|row| {
94                let groups: Vec<(&str, Vec<(ColumnDef, Value)>)> = row
95                    .iter()
96                    .map(|(t, cols)| (t.as_str(), cols.clone()))
97                    .collect();
98                filter.matches_joined_row(&groups).unwrap_or(false)
99            });
100        }
101
102        for (column, direction) in query.order_by.iter().rev() {
103            self.sort_joined_rows(&mut joined_rows, column, *direction);
104        }
105
106        let offset = query.offset.unwrap_or_default();
107        if offset > 0 {
108            if offset >= joined_rows.len() {
109                joined_rows.clear();
110            } else {
111                joined_rows = joined_rows.into_iter().skip(offset).collect();
112            }
113        }
114
115        if let Some(limit) = query.limit {
116            joined_rows.truncate(limit);
117        }
118
119        let results = joined_rows
120            .into_iter()
121            .map(|row| self.flatten_joined_row(row, &query))
122            .collect::<DbmsResult<Vec<_>>>()?;
123
124        Ok(results)
125    }
126
127    /// Unified nested-loop join.
128    #[allow(clippy::too_many_arguments)]
129    fn nested_loop_join(
130        &self,
131        left_rows: Vec<JoinedRow>,
132        right_rows: &[Vec<(ColumnDef, Value)>],
133        right_table: &str,
134        left_table: &str,
135        left_col: &str,
136        right_col: &str,
137        keep_unmatched_left: bool,
138        keep_unmatched_right: bool,
139    ) -> Vec<JoinedRow> {
140        let mut results = Vec::new();
141        let mut right_matched = vec![false; right_rows.len()];
142
143        for left_row in &left_rows {
144            let left_value = self.get_column_value(left_row, left_table, left_col);
145            let mut matched = false;
146
147            for (i, right_row) in right_rows.iter().enumerate() {
148                let right_value = right_row
149                    .iter()
150                    .find(|(c, _)| c.name == right_col)
151                    .map(|(_, v)| v);
152
153                if left_value == right_value && left_value.is_some() {
154                    let mut new_row = left_row.clone();
155                    new_row.push((right_table.to_string(), right_row.clone()));
156                    results.push(new_row);
157                    right_matched[i] = true;
158                    matched = true;
159                }
160            }
161
162            if keep_unmatched_left && !matched {
163                let mut new_row = left_row.clone();
164                let null_cols = right_rows
165                    .first()
166                    .map(|sample| self.null_pad_columns(sample))
167                    .unwrap_or_default();
168                new_row.push((right_table.to_string(), null_cols));
169                results.push(new_row);
170            }
171        }
172
173        if keep_unmatched_right {
174            for (i, right_row) in right_rows.iter().enumerate() {
175                if !right_matched[i] {
176                    let mut new_row: JoinedRow = Vec::new();
177                    if let Some(sample_left) = left_rows.first() {
178                        for (table_name, cols) in sample_left {
179                            new_row.push((table_name.clone(), self.null_pad_columns(cols)));
180                        }
181                    }
182                    new_row.push((right_table.to_string(), right_row.clone()));
183                    results.push(new_row);
184                }
185            }
186        }
187
188        results
189    }
190
191    /// Resolves a column reference to (table_name, column_name).
192    fn resolve_column_ref<'a>(&self, field: &'a str, default_table: &'a str) -> (String, &'a str) {
193        if let Some((table, column)) = field.split_once('.') {
194            (table.to_string(), column)
195        } else {
196            (default_table.to_string(), field)
197        }
198    }
199
200    /// Finds a column value in a joined row.
201    fn get_column_value<'a>(
202        &self,
203        row: &'a JoinedRow,
204        table: &str,
205        column: &str,
206    ) -> Option<&'a Value> {
207        row.iter()
208            .find(|(t, _)| t == table)
209            .and_then(|(_, cols)| cols.iter().find(|(c, _)| c.name == column).map(|(_, v)| v))
210    }
211
212    /// Creates a NULL-padded row.
213    fn null_pad_columns(&self, sample_row: &[(ColumnDef, Value)]) -> Vec<(ColumnDef, Value)> {
214        sample_row
215            .iter()
216            .map(|(col, _)| (*col, Value::Null))
217            .collect()
218    }
219
220    /// Sorts joined rows by a column.
221    fn sort_joined_rows(&self, rows: &mut [JoinedRow], column: &str, direction: OrderDirection) {
222        let (table, col) = if let Some((t, c)) = column.split_once('.') {
223            (Some(t), c)
224        } else {
225            (None, column)
226        };
227
228        rows.sort_by(|a, b| {
229            let a_val = self.find_value_in_joined_row(a, table, col);
230            let b_val = self.find_value_in_joined_row(b, table, col);
231
232            crate::database::sort_values_with_direction(a_val, b_val, direction)
233        });
234    }
235
236    /// Finds a column value in a joined row, optionally scoped to a table.
237    fn find_value_in_joined_row<'a>(
238        &self,
239        row: &'a JoinedRow,
240        table: Option<&str>,
241        column: &str,
242    ) -> Option<&'a Value> {
243        if let Some(table) = table {
244            return self.get_column_value(row, table, column);
245        }
246        row.iter()
247            .flat_map(|(_, cols)| cols)
248            .find_map(|(col, value)| {
249                if col.name == column {
250                    Some(value)
251                } else {
252                    None
253                }
254            })
255    }
256
257    /// Flattens a joined row into the output format.
258    fn flatten_joined_row(
259        &self,
260        row: JoinedRow,
261        query: &Query,
262    ) -> DbmsResult<Vec<(CandidColumnDef, Value)>> {
263        let mut result = Vec::new();
264
265        for (table_name, cols) in row {
266            for (col, val) in cols {
267                let mut candid_col = CandidColumnDef::from(col);
268                candid_col.table = Some(table_name.clone());
269
270                if !query.all_selected() {
271                    let selected = query.raw_columns();
272                    let qualified_name = format!("{table_name}.{col}", col = candid_col.name);
273                    if !selected.contains(&candid_col.name) && !selected.contains(&qualified_name) {
274                        continue;
275                    }
276                }
277
278                result.push((candid_col, val));
279            }
280        }
281
282        Ok(result)
283    }
284}
285
286#[cfg(test)]
287mod tests {
288
289    use wasm_dbms_api::prelude::{
290        Database as _, Filter, InsertRecord as _, Query, TableSchema as _, Text, Uint32, Value,
291    };
292    use wasm_dbms_macros::{DatabaseSchema, Table};
293    use wasm_dbms_memory::prelude::HeapMemoryProvider;
294
295    use crate::prelude::{DbmsContext, WasmDbmsDatabase};
296
297    // Use tables WITHOUT foreign key constraints so we can test all join
298    // types including unmatched rows without FK validation failures.
299
    // Table used as the FROM (left-hand) side of the join tests; the joins
    // match its `id` column against `employees.dept_id`.
    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "departments"]
    pub struct Department {
        #[primary_key]
        pub id: Uint32,
        pub name: Text,
    }
307
    // Table joined onto `departments` in the tests. `dept_id` holds a
    // department id but is deliberately not declared as a foreign key, so rows
    // may reference departments that do not exist.
    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "employees"]
    pub struct Employee {
        #[primary_key]
        pub id: Uint32,
        pub name: Text,
        pub dept_id: Uint32,
    }
316
    // Schema exposing the two test tables under the names used by the queries.
    #[derive(DatabaseSchema)]
    #[tables(Department = "departments", Employee = "employees")]
    pub struct TestSchema;
320
321    fn setup() -> DbmsContext<HeapMemoryProvider> {
322        let ctx = DbmsContext::new(HeapMemoryProvider::default());
323        TestSchema::register_tables(&ctx).unwrap();
324        ctx
325    }
326
327    fn insert_dept(db: &WasmDbmsDatabase<'_, HeapMemoryProvider>, id: u32, name: &str) {
328        let insert = DepartmentInsertRequest::from_values(&[
329            (Department::columns()[0], Value::Uint32(Uint32(id))),
330            (
331                Department::columns()[1],
332                Value::Text(Text(name.to_string())),
333            ),
334        ])
335        .unwrap();
336        db.insert::<Department>(insert).unwrap();
337    }
338
339    fn insert_emp(
340        db: &WasmDbmsDatabase<'_, HeapMemoryProvider>,
341        id: u32,
342        name: &str,
343        dept_id: u32,
344    ) {
345        let insert = EmployeeInsertRequest::from_values(&[
346            (Employee::columns()[0], Value::Uint32(Uint32(id))),
347            (Employee::columns()[1], Value::Text(Text(name.to_string()))),
348            (Employee::columns()[2], Value::Uint32(Uint32(dept_id))),
349        ])
350        .unwrap();
351        db.insert::<Employee>(insert).unwrap();
352    }
353
354    #[test]
355    fn test_inner_join() {
356        let ctx = setup();
357        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
358        insert_dept(&db, 1, "eng");
359        insert_dept(&db, 2, "hr");
360        insert_emp(&db, 10, "alice", 1);
361        insert_emp(&db, 11, "bob", 1);
362
363        let query = Query::builder()
364            .all()
365            .inner_join("employees", "id", "dept_id")
366            .build();
367        let results = db.select_join("departments", query).unwrap();
368        // eng has 2 employees, hr has 0 → 2 rows
369        assert_eq!(results.len(), 2);
370    }
371
372    #[test]
373    fn test_left_join() {
374        let ctx = setup();
375        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
376        insert_dept(&db, 1, "eng");
377        insert_dept(&db, 2, "hr");
378        insert_emp(&db, 10, "alice", 1);
379
380        let query = Query::builder()
381            .all()
382            .left_join("employees", "id", "dept_id")
383            .build();
384        let results = db.select_join("departments", query).unwrap();
385        // eng has 1 employee, hr has 0 but LEFT keeps unmatched left → 2 rows
386        assert_eq!(results.len(), 2);
387
388        // Find hr's row: employee columns should be Null
389        let hr_row = results
390            .iter()
391            .find(|row| {
392                row.iter().any(|(col, val)| {
393                    col.name == "name"
394                        && col.table.as_deref() == Some("departments")
395                        && *val == Value::Text(Text("hr".to_string()))
396                })
397            })
398            .expect("hr should be in results");
399
400        // hr's employee name should be Null
401        let emp_name = hr_row
402            .iter()
403            .find(|(col, _)| col.name == "name" && col.table.as_deref() == Some("employees"))
404            .expect("employee name column should exist for hr");
405        assert_eq!(emp_name.1, Value::Null);
406    }
407
408    #[test]
409    fn test_right_join() {
410        let ctx = setup();
411        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
412        insert_dept(&db, 1, "eng");
413        insert_emp(&db, 10, "alice", 1);
414        // charlie references dept 999 which doesn't exist (no FK constraint)
415        insert_emp(&db, 11, "charlie", 999);
416
417        let query = Query::builder()
418            .all()
419            .right_join("employees", "id", "dept_id")
420            .build();
421        let results = db.select_join("departments", query).unwrap();
422        // alice matches eng, charlie (dept_id=999) is unmatched right → 2 rows
423        assert_eq!(results.len(), 2);
424
425        // charlie should have null department columns
426        let charlie_row = results
427            .iter()
428            .find(|row| {
429                row.iter().any(|(col, val)| {
430                    col.name == "name"
431                        && col.table.as_deref() == Some("employees")
432                        && *val == Value::Text(Text("charlie".to_string()))
433                })
434            })
435            .expect("charlie should be in results");
436
437        let dept_name = charlie_row
438            .iter()
439            .find(|(col, _)| col.name == "name" && col.table.as_deref() == Some("departments"))
440            .expect("department name column should exist for charlie");
441        assert_eq!(dept_name.1, Value::Null);
442    }
443
444    #[test]
445    fn test_full_join() {
446        let ctx = setup();
447        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
448        insert_dept(&db, 1, "eng");
449        insert_dept(&db, 2, "hr");
450        insert_emp(&db, 10, "alice", 1);
451        // charlie references dept 999 which doesn't exist
452        insert_emp(&db, 11, "charlie", 999);
453
454        let query = Query::builder()
455            .all()
456            .full_join("employees", "id", "dept_id")
457            .build();
458        let results = db.select_join("departments", query).unwrap();
459        // eng-alice matched (1), hr unmatched left (1), charlie unmatched right (1) = 3
460        assert_eq!(results.len(), 3);
461    }
462
463    #[test]
464    fn test_join_with_filter() {
465        let ctx = setup();
466        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
467        insert_dept(&db, 1, "eng");
468        insert_dept(&db, 2, "hr");
469        insert_emp(&db, 10, "alice", 1);
470        insert_emp(&db, 11, "bob", 2);
471
472        let query = Query::builder()
473            .all()
474            .inner_join("employees", "id", "dept_id")
475            .and_where(Filter::eq(
476                "departments.name",
477                Value::Text(Text("eng".to_string())),
478            ))
479            .build();
480        let results = db.select_join("departments", query).unwrap();
481        assert_eq!(results.len(), 1);
482    }
483
484    #[test]
485    fn test_join_with_order_by() {
486        let ctx = setup();
487        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
488        insert_dept(&db, 1, "eng");
489        insert_dept(&db, 2, "hr");
490        insert_emp(&db, 10, "zzz", 1);
491        insert_emp(&db, 11, "aaa", 2);
492
493        let query = Query::builder()
494            .all()
495            .inner_join("employees", "id", "dept_id")
496            .order_by_asc("employees.name")
497            .build();
498        let results = db.select_join("departments", query).unwrap();
499        assert_eq!(results.len(), 2);
500        let first_name = results[0]
501            .iter()
502            .find(|(col, _)| col.name == "name" && col.table.as_deref() == Some("employees"))
503            .unwrap();
504        assert_eq!(first_name.1, Value::Text(Text("aaa".to_string())));
505    }
506
507    #[test]
508    fn test_join_with_limit() {
509        let ctx = setup();
510        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
511        insert_dept(&db, 1, "eng");
512        insert_dept(&db, 2, "hr");
513        insert_emp(&db, 10, "alice", 1);
514        insert_emp(&db, 11, "bob", 2);
515
516        let query = Query::builder()
517            .all()
518            .inner_join("employees", "id", "dept_id")
519            .limit(1)
520            .build();
521        let results = db.select_join("departments", query).unwrap();
522        assert_eq!(results.len(), 1);
523    }
524
525    #[test]
526    fn test_join_with_offset() {
527        let ctx = setup();
528        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
529        insert_dept(&db, 1, "eng");
530        insert_dept(&db, 2, "hr");
531        insert_emp(&db, 10, "alice", 1);
532        insert_emp(&db, 11, "bob", 2);
533
534        let query = Query::builder()
535            .all()
536            .inner_join("employees", "id", "dept_id")
537            .offset(1)
538            .build();
539        let results = db.select_join("departments", query).unwrap();
540        assert_eq!(results.len(), 1);
541    }
542
543    #[test]
544    fn test_join_with_column_selection() {
545        let ctx = setup();
546        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
547        insert_dept(&db, 1, "eng");
548        insert_emp(&db, 10, "alice", 1);
549
550        let query = Query::builder()
551            .field("departments.name")
552            .field("employees.name")
553            .inner_join("employees", "id", "dept_id")
554            .build();
555        let results = db.select_join("departments", query).unwrap();
556        assert_eq!(results.len(), 1);
557        assert_eq!(results[0].len(), 2);
558    }
559
560    #[test]
561    fn test_inner_join_empty_result() {
562        let ctx = setup();
563        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
564        insert_dept(&db, 1, "eng");
565        // No employees
566
567        let query = Query::builder()
568            .all()
569            .inner_join("employees", "id", "dept_id")
570            .build();
571        let results = db.select_join("departments", query).unwrap();
572        assert!(results.is_empty());
573    }
574
575    #[test]
576    fn test_join_offset_exceeding_results_returns_empty() {
577        let ctx = setup();
578        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
579        insert_dept(&db, 1, "eng");
580        insert_emp(&db, 10, "alice", 1);
581
582        let query = Query::builder()
583            .all()
584            .inner_join("employees", "id", "dept_id")
585            .offset(100)
586            .build();
587        let results = db.select_join("departments", query).unwrap();
588        assert!(results.is_empty());
589    }
590}