1use std::collections::HashSet;
7
8use wasm_dbms_api::prelude::{
9 ColumnDef, DbmsResult, JoinColumnDef, JoinType, OrderDirection, Query, Value,
10};
11use wasm_dbms_memory::prelude::{AccessControl, AccessControlList, MemoryProvider};
12
13use crate::database::WasmDbmsDatabase;
14use crate::schema::DatabaseSchema;
15
16type JoinedRow = Vec<(String, Vec<(ColumnDef, Value)>)>;
18
19pub struct JoinEngine<'a, Schema: ?Sized, M, A = AccessControlList>
21where
22 Schema: DatabaseSchema<M, A>,
23 M: MemoryProvider,
24 A: AccessControl,
25{
26 schema: &'a Schema,
27 _marker: std::marker::PhantomData<(M, A)>,
28}
29
30impl<'a, Schema: ?Sized, M, A> JoinEngine<'a, Schema, M, A>
31where
32 Schema: DatabaseSchema<M, A>,
33 M: MemoryProvider,
34 A: AccessControl,
35{
36 pub fn new(schema: &'a Schema) -> Self {
37 Self {
38 schema,
39 _marker: std::marker::PhantomData,
40 }
41 }
42}
43
44impl<Schema: ?Sized, M, A> JoinEngine<'_, Schema, M, A>
45where
46 Schema: DatabaseSchema<M, A>,
47 M: MemoryProvider,
48 A: AccessControl,
49{
50 pub fn join(
52 &self,
53 dbms: &WasmDbmsDatabase<'_, M, A>,
54 from_table: &str,
55 query: Query,
56 ) -> DbmsResult<Vec<Vec<(JoinColumnDef, Value)>>> {
57 let from_rows = self
58 .schema
59 .select(dbms, from_table, Query::builder().all().build())?;
60
61 let mut joined_rows: Vec<JoinedRow> = from_rows
62 .into_iter()
63 .map(|row| vec![(from_table.to_string(), row)])
64 .collect();
65
66 for join in &query.joins {
67 let (left_table, left_col) = self.resolve_column_ref(&join.left_column, from_table);
68 let (_right_table_ref, right_col) =
69 self.resolve_column_ref(&join.right_column, &join.table);
70
71 let (keep_unmatched_left, keep_unmatched_right) = match join.join_type {
72 JoinType::Inner => (false, false),
73 JoinType::Left => (true, false),
74 JoinType::Right => (false, true),
75 JoinType::Full => (true, true),
76 };
77
78 let right_rows = self.load_join_right_rows(
79 dbms,
80 &joined_rows,
81 &join.table,
82 &left_table,
83 left_col,
84 right_col,
85 keep_unmatched_right,
86 )?;
87
88 joined_rows = self.nested_loop_join(
89 joined_rows,
90 &right_rows,
91 &join.table,
92 &left_table,
93 left_col,
94 right_col,
95 keep_unmatched_left,
96 keep_unmatched_right,
97 );
98 }
99
100 if let Some(filter) = &query.filter {
101 joined_rows.retain(|row| {
102 let groups: Vec<(&str, Vec<(ColumnDef, Value)>)> = row
103 .iter()
104 .map(|(t, cols)| (t.as_str(), cols.clone()))
105 .collect();
106 filter.matches_joined_row(&groups).unwrap_or(false)
107 });
108 }
109
110 for (column, direction) in query.order_by.iter().rev() {
111 self.sort_joined_rows(&mut joined_rows, column, *direction);
112 }
113
114 let offset = query.offset.unwrap_or_default();
115 if offset > 0 {
116 if offset >= joined_rows.len() {
117 joined_rows.clear();
118 } else {
119 joined_rows = joined_rows.into_iter().skip(offset).collect();
120 }
121 }
122
123 if let Some(limit) = query.limit {
124 joined_rows.truncate(limit);
125 }
126
127 let results = joined_rows
128 .into_iter()
129 .map(|row| self.flatten_joined_row(row, &query))
130 .collect::<DbmsResult<Vec<_>>>()?;
131
132 Ok(results)
133 }
134
135 #[expect(
136 clippy::too_many_arguments,
137 reason = "arguments are necessary for loading right table rows based on join conditions"
138 )]
139 fn load_join_right_rows(
140 &self,
141 dbms: &WasmDbmsDatabase<'_, M, A>,
142 left_rows: &[JoinedRow],
143 right_table: &str,
144 left_table: &str,
145 left_col: &str,
146 right_col: &str,
147 keep_unmatched_right: bool,
148 ) -> DbmsResult<Vec<Vec<(ColumnDef, Value)>>> {
149 let unique_join_values: Vec<Value> = {
150 let mut seen = HashSet::new();
151 left_rows
152 .iter()
153 .filter_map(|row| self.get_column_value(row, left_table, left_col).cloned())
154 .filter(|value| seen.insert(value.clone()))
155 .collect()
156 };
157
158 if unique_join_values.is_empty() || keep_unmatched_right {
159 return self
160 .schema
161 .select(dbms, right_table, Query::builder().all().build());
162 }
163
164 self.schema.select(
165 dbms,
166 right_table,
167 Query::builder()
168 .all()
169 .filter(Some(wasm_dbms_api::prelude::Filter::in_list(
170 right_col,
171 unique_join_values,
172 )))
173 .build(),
174 )
175 }
176
177 #[allow(clippy::too_many_arguments)]
179 fn nested_loop_join(
180 &self,
181 left_rows: Vec<JoinedRow>,
182 right_rows: &[Vec<(ColumnDef, Value)>],
183 right_table: &str,
184 left_table: &str,
185 left_col: &str,
186 right_col: &str,
187 keep_unmatched_left: bool,
188 keep_unmatched_right: bool,
189 ) -> Vec<JoinedRow> {
190 let mut results = Vec::new();
191 let mut right_matched = vec![false; right_rows.len()];
192
193 for left_row in &left_rows {
194 let left_value = self.get_column_value(left_row, left_table, left_col);
195 let mut matched = false;
196
197 for (i, right_row) in right_rows.iter().enumerate() {
198 let right_value = right_row
199 .iter()
200 .find(|(c, _)| c.name == right_col)
201 .map(|(_, v)| v);
202
203 if left_value == right_value && left_value.is_some() {
204 let mut new_row = left_row.clone();
205 new_row.push((right_table.to_string(), right_row.clone()));
206 results.push(new_row);
207 right_matched[i] = true;
208 matched = true;
209 }
210 }
211
212 if keep_unmatched_left && !matched {
213 let mut new_row = left_row.clone();
214 let null_cols = right_rows
215 .first()
216 .map(|sample| self.null_pad_columns(sample))
217 .unwrap_or_default();
218 new_row.push((right_table.to_string(), null_cols));
219 results.push(new_row);
220 }
221 }
222
223 if keep_unmatched_right {
224 for (i, right_row) in right_rows.iter().enumerate() {
225 if !right_matched[i] {
226 let mut new_row: JoinedRow = Vec::new();
227 if let Some(sample_left) = left_rows.first() {
228 for (table_name, cols) in sample_left {
229 new_row.push((table_name.clone(), self.null_pad_columns(cols)));
230 }
231 }
232 new_row.push((right_table.to_string(), right_row.clone()));
233 results.push(new_row);
234 }
235 }
236 }
237
238 results
239 }
240
241 fn resolve_column_ref<'a>(&self, field: &'a str, default_table: &'a str) -> (String, &'a str) {
243 if let Some((table, column)) = field.split_once('.') {
244 (table.to_string(), column)
245 } else {
246 (default_table.to_string(), field)
247 }
248 }
249
250 fn get_column_value<'a>(
252 &self,
253 row: &'a JoinedRow,
254 table: &str,
255 column: &str,
256 ) -> Option<&'a Value> {
257 row.iter()
258 .find(|(t, _)| t == table)
259 .and_then(|(_, cols)| cols.iter().find(|(c, _)| c.name == column).map(|(_, v)| v))
260 }
261
262 fn null_pad_columns(&self, sample_row: &[(ColumnDef, Value)]) -> Vec<(ColumnDef, Value)> {
264 sample_row
265 .iter()
266 .map(|(col, _)| (*col, Value::Null))
267 .collect()
268 }
269
270 fn sort_joined_rows(&self, rows: &mut [JoinedRow], column: &str, direction: OrderDirection) {
272 let (table, col) = if let Some((t, c)) = column.split_once('.') {
273 (Some(t), c)
274 } else {
275 (None, column)
276 };
277
278 rows.sort_by(|a, b| {
279 let a_val = self.find_value_in_joined_row(a, table, col);
280 let b_val = self.find_value_in_joined_row(b, table, col);
281
282 crate::database::sort_values_with_direction(a_val, b_val, direction)
283 });
284 }
285
286 fn find_value_in_joined_row<'a>(
288 &self,
289 row: &'a JoinedRow,
290 table: Option<&str>,
291 column: &str,
292 ) -> Option<&'a Value> {
293 if let Some(table) = table {
294 return self.get_column_value(row, table, column);
295 }
296 row.iter()
297 .flat_map(|(_, cols)| cols)
298 .find_map(|(col, value)| {
299 if col.name == column {
300 Some(value)
301 } else {
302 None
303 }
304 })
305 }
306
307 fn flatten_joined_row(
309 &self,
310 row: JoinedRow,
311 query: &Query,
312 ) -> DbmsResult<Vec<(JoinColumnDef, Value)>> {
313 let mut result = Vec::new();
314
315 for (table_name, cols) in row {
316 for (col, val) in cols {
317 let mut candid_col = JoinColumnDef::from(col);
318 candid_col.table = Some(table_name.clone());
319
320 if !query.all_selected() {
321 let selected = query.raw_columns();
322 let qualified_name = format!("{table_name}.{col}", col = candid_col.name);
323 if !selected.contains(&candid_col.name) && !selected.contains(&qualified_name) {
324 continue;
325 }
326 }
327
328 result.push((candid_col, val));
329 }
330 }
331
332 Ok(result)
333 }
334}
335
336#[cfg(test)]
337mod tests {
338
339 use wasm_dbms_api::prelude::{
340 Database as _, Filter, InsertRecord as _, Query, TableSchema as _, Text, Uint32, Value,
341 };
342 use wasm_dbms_macros::{DatabaseSchema, Table};
343 use wasm_dbms_memory::prelude::HeapMemoryProvider;
344
345 use crate::prelude::{DbmsContext, WasmDbmsDatabase};
346
347 #[derive(Debug, Table, Clone, PartialEq, Eq)]
351 #[table = "departments"]
352 pub struct Department {
353 #[primary_key]
354 pub id: Uint32,
355 pub name: Text,
356 }
357
358 #[derive(Debug, Table, Clone, PartialEq, Eq)]
359 #[table = "employees"]
360 pub struct Employee {
361 #[primary_key]
362 pub id: Uint32,
363 pub name: Text,
364 pub dept_id: Uint32,
365 }
366
367 #[derive(DatabaseSchema)]
368 #[tables(Department = "departments", Employee = "employees")]
369 pub struct TestSchema;
370
371 #[derive(Debug, Table, Clone, PartialEq, Eq)]
372 #[table = "indexed_departments"]
373 pub struct IndexedDepartment {
374 #[primary_key]
375 pub id: Uint32,
376 pub name: Text,
377 }
378
379 #[derive(Debug, Table, Clone, PartialEq, Eq)]
380 #[table = "indexed_employees"]
381 pub struct IndexedEmployee {
382 #[primary_key]
383 pub id: Uint32,
384 pub name: Text,
385 #[index]
386 pub dept_id: Uint32,
387 }
388
389 #[derive(DatabaseSchema)]
390 #[tables(
391 IndexedDepartment = "indexed_departments",
392 IndexedEmployee = "indexed_employees"
393 )]
394 pub struct IndexedJoinSchema;
395
396 fn setup() -> DbmsContext<HeapMemoryProvider> {
397 let ctx = DbmsContext::new(HeapMemoryProvider::default());
398 TestSchema::register_tables(&ctx).unwrap();
399 ctx
400 }
401
402 fn setup_indexed() -> DbmsContext<HeapMemoryProvider> {
403 let ctx = DbmsContext::new(HeapMemoryProvider::default());
404 IndexedJoinSchema::register_tables(&ctx).unwrap();
405 ctx
406 }
407
408 fn insert_dept(db: &WasmDbmsDatabase<'_, HeapMemoryProvider>, id: u32, name: &str) {
409 let insert = DepartmentInsertRequest::from_values(&[
410 (Department::columns()[0], Value::Uint32(Uint32(id))),
411 (
412 Department::columns()[1],
413 Value::Text(Text(name.to_string())),
414 ),
415 ])
416 .unwrap();
417 db.insert::<Department>(insert).unwrap();
418 }
419
420 fn insert_emp(
421 db: &WasmDbmsDatabase<'_, HeapMemoryProvider>,
422 id: u32,
423 name: &str,
424 dept_id: u32,
425 ) {
426 let insert = EmployeeInsertRequest::from_values(&[
427 (Employee::columns()[0], Value::Uint32(Uint32(id))),
428 (Employee::columns()[1], Value::Text(Text(name.to_string()))),
429 (Employee::columns()[2], Value::Uint32(Uint32(dept_id))),
430 ])
431 .unwrap();
432 db.insert::<Employee>(insert).unwrap();
433 }
434
435 fn insert_indexed_dept(db: &WasmDbmsDatabase<'_, HeapMemoryProvider>, id: u32, name: &str) {
436 let insert = IndexedDepartmentInsertRequest::from_values(&[
437 (IndexedDepartment::columns()[0], Value::Uint32(Uint32(id))),
438 (
439 IndexedDepartment::columns()[1],
440 Value::Text(Text(name.to_string())),
441 ),
442 ])
443 .unwrap();
444 db.insert::<IndexedDepartment>(insert).unwrap();
445 }
446
447 fn insert_indexed_emp(
448 db: &WasmDbmsDatabase<'_, HeapMemoryProvider>,
449 id: u32,
450 name: &str,
451 dept_id: u32,
452 ) {
453 let insert = IndexedEmployeeInsertRequest::from_values(&[
454 (IndexedEmployee::columns()[0], Value::Uint32(Uint32(id))),
455 (
456 IndexedEmployee::columns()[1],
457 Value::Text(Text(name.to_string())),
458 ),
459 (
460 IndexedEmployee::columns()[2],
461 Value::Uint32(Uint32(dept_id)),
462 ),
463 ])
464 .unwrap();
465 db.insert::<IndexedEmployee>(insert).unwrap();
466 }
467
468 #[test]
469 fn test_inner_join() {
470 let ctx = setup();
471 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
472 insert_dept(&db, 1, "eng");
473 insert_dept(&db, 2, "hr");
474 insert_emp(&db, 10, "alice", 1);
475 insert_emp(&db, 11, "bob", 1);
476
477 let query = Query::builder()
478 .all()
479 .inner_join("employees", "id", "dept_id")
480 .build();
481 let results = db.select_join("departments", query).unwrap();
482 assert_eq!(results.len(), 2);
484 }
485
486 #[test]
487 fn test_left_join() {
488 let ctx = setup();
489 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
490 insert_dept(&db, 1, "eng");
491 insert_dept(&db, 2, "hr");
492 insert_emp(&db, 10, "alice", 1);
493
494 let query = Query::builder()
495 .all()
496 .left_join("employees", "id", "dept_id")
497 .build();
498 let results = db.select_join("departments", query).unwrap();
499 assert_eq!(results.len(), 2);
501
502 let hr_row = results
504 .iter()
505 .find(|row| {
506 row.iter().any(|(col, val)| {
507 col.name == "name"
508 && col.table.as_deref() == Some("departments")
509 && *val == Value::Text(Text("hr".to_string()))
510 })
511 })
512 .expect("hr should be in results");
513
514 let emp_name = hr_row
516 .iter()
517 .find(|(col, _)| col.name == "name" && col.table.as_deref() == Some("employees"))
518 .expect("employee name column should exist for hr");
519 assert_eq!(emp_name.1, Value::Null);
520 }
521
522 #[test]
523 fn test_right_join() {
524 let ctx = setup();
525 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
526 insert_dept(&db, 1, "eng");
527 insert_emp(&db, 10, "alice", 1);
528 insert_emp(&db, 11, "charlie", 999);
530
531 let query = Query::builder()
532 .all()
533 .right_join("employees", "id", "dept_id")
534 .build();
535 let results = db.select_join("departments", query).unwrap();
536 assert_eq!(results.len(), 2);
538
539 let charlie_row = results
541 .iter()
542 .find(|row| {
543 row.iter().any(|(col, val)| {
544 col.name == "name"
545 && col.table.as_deref() == Some("employees")
546 && *val == Value::Text(Text("charlie".to_string()))
547 })
548 })
549 .expect("charlie should be in results");
550
551 let dept_name = charlie_row
552 .iter()
553 .find(|(col, _)| col.name == "name" && col.table.as_deref() == Some("departments"))
554 .expect("department name column should exist for charlie");
555 assert_eq!(dept_name.1, Value::Null);
556 }
557
558 #[test]
559 fn test_full_join() {
560 let ctx = setup();
561 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
562 insert_dept(&db, 1, "eng");
563 insert_dept(&db, 2, "hr");
564 insert_emp(&db, 10, "alice", 1);
565 insert_emp(&db, 11, "charlie", 999);
567
568 let query = Query::builder()
569 .all()
570 .full_join("employees", "id", "dept_id")
571 .build();
572 let results = db.select_join("departments", query).unwrap();
573 assert_eq!(results.len(), 3);
575 }
576
577 #[test]
578 fn test_join_with_filter() {
579 let ctx = setup();
580 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
581 insert_dept(&db, 1, "eng");
582 insert_dept(&db, 2, "hr");
583 insert_emp(&db, 10, "alice", 1);
584 insert_emp(&db, 11, "bob", 2);
585
586 let query = Query::builder()
587 .all()
588 .inner_join("employees", "id", "dept_id")
589 .and_where(Filter::eq(
590 "departments.name",
591 Value::Text(Text("eng".to_string())),
592 ))
593 .build();
594 let results = db.select_join("departments", query).unwrap();
595 assert_eq!(results.len(), 1);
596 }
597
598 #[test]
599 fn test_join_with_order_by() {
600 let ctx = setup();
601 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
602 insert_dept(&db, 1, "eng");
603 insert_dept(&db, 2, "hr");
604 insert_emp(&db, 10, "zzz", 1);
605 insert_emp(&db, 11, "aaa", 2);
606
607 let query = Query::builder()
608 .all()
609 .inner_join("employees", "id", "dept_id")
610 .order_by_asc("employees.name")
611 .build();
612 let results = db.select_join("departments", query).unwrap();
613 assert_eq!(results.len(), 2);
614 let first_name = results[0]
615 .iter()
616 .find(|(col, _)| col.name == "name" && col.table.as_deref() == Some("employees"))
617 .unwrap();
618 assert_eq!(first_name.1, Value::Text(Text("aaa".to_string())));
619 }
620
621 #[test]
622 fn test_join_with_limit() {
623 let ctx = setup();
624 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
625 insert_dept(&db, 1, "eng");
626 insert_dept(&db, 2, "hr");
627 insert_emp(&db, 10, "alice", 1);
628 insert_emp(&db, 11, "bob", 2);
629
630 let query = Query::builder()
631 .all()
632 .inner_join("employees", "id", "dept_id")
633 .limit(1)
634 .build();
635 let results = db.select_join("departments", query).unwrap();
636 assert_eq!(results.len(), 1);
637 }
638
639 #[test]
640 fn test_join_with_offset() {
641 let ctx = setup();
642 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
643 insert_dept(&db, 1, "eng");
644 insert_dept(&db, 2, "hr");
645 insert_emp(&db, 10, "alice", 1);
646 insert_emp(&db, 11, "bob", 2);
647
648 let query = Query::builder()
649 .all()
650 .inner_join("employees", "id", "dept_id")
651 .offset(1)
652 .build();
653 let results = db.select_join("departments", query).unwrap();
654 assert_eq!(results.len(), 1);
655 }
656
657 #[test]
658 fn test_join_with_column_selection() {
659 let ctx = setup();
660 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
661 insert_dept(&db, 1, "eng");
662 insert_emp(&db, 10, "alice", 1);
663
664 let query = Query::builder()
665 .field("departments.name")
666 .field("employees.name")
667 .inner_join("employees", "id", "dept_id")
668 .build();
669 let results = db.select_join("departments", query).unwrap();
670 assert_eq!(results.len(), 1);
671 assert_eq!(results[0].len(), 2);
672 }
673
674 #[test]
675 fn test_inner_join_empty_result() {
676 let ctx = setup();
677 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
678 insert_dept(&db, 1, "eng");
679 let query = Query::builder()
682 .all()
683 .inner_join("employees", "id", "dept_id")
684 .build();
685 let results = db.select_join("departments", query).unwrap();
686 assert!(results.is_empty());
687 }
688
689 #[test]
690 fn test_join_offset_exceeding_results_returns_empty() {
691 let ctx = setup();
692 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
693 insert_dept(&db, 1, "eng");
694 insert_emp(&db, 10, "alice", 1);
695
696 let query = Query::builder()
697 .all()
698 .inner_join("employees", "id", "dept_id")
699 .offset(100)
700 .build();
701 let results = db.select_join("departments", query).unwrap();
702 assert!(results.is_empty());
703 }
704
705 #[test]
706 fn test_join_on_indexed_column() {
707 let ctx = setup_indexed();
708 let db = WasmDbmsDatabase::oneshot(&ctx, IndexedJoinSchema);
709 insert_indexed_dept(&db, 1, "eng");
710 insert_indexed_dept(&db, 2, "hr");
711 insert_indexed_emp(&db, 10, "alice", 1);
712 insert_indexed_emp(&db, 11, "bob", 2);
713
714 let query = Query::builder()
715 .all()
716 .inner_join(
717 "indexed_employees",
718 "indexed_departments.id",
719 "indexed_employees.dept_id",
720 )
721 .build();
722 let results = db.select_join("indexed_departments", query).unwrap();
723
724 assert_eq!(results.len(), 2);
725 assert!(results.iter().any(|row| {
726 row.iter().any(|(column, value)| {
727 column.name == "name"
728 && column.table.as_deref() == Some("indexed_employees")
729 && *value == Value::Text(Text("alice".to_string()))
730 })
731 }));
732 }
733}