1use wasm_dbms_api::prelude::{
7 CandidColumnDef, ColumnDef, DbmsResult, JoinType, OrderDirection, Query, Value,
8};
9use wasm_dbms_memory::prelude::{AccessControl, AccessControlList, MemoryProvider};
10
11use crate::database::WasmDbmsDatabase;
12use crate::schema::DatabaseSchema;
13
/// A row under construction during a join: one `(table name, columns)` group
/// per table joined so far, in join order, each column paired with its value.
type JoinedRow = Vec<(String, Vec<(ColumnDef, Value)>)>;
16
/// Executes join queries over the tables exposed by a [`DatabaseSchema`].
///
/// The engine only borrows the schema; `M` (memory provider) and `A` (access
/// control) appear solely in the `DatabaseSchema<M, A>` bound and are carried
/// through a `PhantomData` marker.
pub struct JoinEngine<'a, Schema: ?Sized, M, A = AccessControlList>
where
    Schema: DatabaseSchema<M, A>,
    M: MemoryProvider,
    A: AccessControl,
{
    /// Schema through which the rows of every table taking part in the join are read.
    schema: &'a Schema,
    /// Ties the otherwise unused `M` and `A` type parameters to the struct.
    _marker: std::marker::PhantomData<(M, A)>,
}
27
28impl<'a, Schema: ?Sized, M, A> JoinEngine<'a, Schema, M, A>
29where
30 Schema: DatabaseSchema<M, A>,
31 M: MemoryProvider,
32 A: AccessControl,
33{
34 pub fn new(schema: &'a Schema) -> Self {
35 Self {
36 schema,
37 _marker: std::marker::PhantomData,
38 }
39 }
40}
41
42impl<Schema: ?Sized, M, A> JoinEngine<'_, Schema, M, A>
43where
44 Schema: DatabaseSchema<M, A>,
45 M: MemoryProvider,
46 A: AccessControl,
47{
48 pub fn join(
50 &self,
51 dbms: &WasmDbmsDatabase<'_, M, A>,
52 from_table: &str,
53 query: Query,
54 ) -> DbmsResult<Vec<Vec<(CandidColumnDef, Value)>>> {
55 let from_rows = self
56 .schema
57 .select(dbms, from_table, Query::builder().all().build())?;
58
59 let mut joined_rows: Vec<JoinedRow> = from_rows
60 .into_iter()
61 .map(|row| vec![(from_table.to_string(), row)])
62 .collect();
63
64 for join in &query.joins {
65 let right_rows =
66 self.schema
67 .select(dbms, &join.table, Query::builder().all().build())?;
68
69 let (left_table, left_col) = self.resolve_column_ref(&join.left_column, from_table);
70 let (_right_table_ref, right_col) =
71 self.resolve_column_ref(&join.right_column, &join.table);
72
73 let (keep_unmatched_left, keep_unmatched_right) = match join.join_type {
74 JoinType::Inner => (false, false),
75 JoinType::Left => (true, false),
76 JoinType::Right => (false, true),
77 JoinType::Full => (true, true),
78 };
79
80 joined_rows = self.nested_loop_join(
81 joined_rows,
82 &right_rows,
83 &join.table,
84 &left_table,
85 left_col,
86 right_col,
87 keep_unmatched_left,
88 keep_unmatched_right,
89 );
90 }
91
92 if let Some(filter) = &query.filter {
93 joined_rows.retain(|row| {
94 let groups: Vec<(&str, Vec<(ColumnDef, Value)>)> = row
95 .iter()
96 .map(|(t, cols)| (t.as_str(), cols.clone()))
97 .collect();
98 filter.matches_joined_row(&groups).unwrap_or(false)
99 });
100 }
101
102 for (column, direction) in query.order_by.iter().rev() {
103 self.sort_joined_rows(&mut joined_rows, column, *direction);
104 }
105
106 let offset = query.offset.unwrap_or_default();
107 if offset > 0 {
108 if offset >= joined_rows.len() {
109 joined_rows.clear();
110 } else {
111 joined_rows = joined_rows.into_iter().skip(offset).collect();
112 }
113 }
114
115 if let Some(limit) = query.limit {
116 joined_rows.truncate(limit);
117 }
118
119 let results = joined_rows
120 .into_iter()
121 .map(|row| self.flatten_joined_row(row, &query))
122 .collect::<DbmsResult<Vec<_>>>()?;
123
124 Ok(results)
125 }
126
127 #[allow(clippy::too_many_arguments)]
129 fn nested_loop_join(
130 &self,
131 left_rows: Vec<JoinedRow>,
132 right_rows: &[Vec<(ColumnDef, Value)>],
133 right_table: &str,
134 left_table: &str,
135 left_col: &str,
136 right_col: &str,
137 keep_unmatched_left: bool,
138 keep_unmatched_right: bool,
139 ) -> Vec<JoinedRow> {
140 let mut results = Vec::new();
141 let mut right_matched = vec![false; right_rows.len()];
142
143 for left_row in &left_rows {
144 let left_value = self.get_column_value(left_row, left_table, left_col);
145 let mut matched = false;
146
147 for (i, right_row) in right_rows.iter().enumerate() {
148 let right_value = right_row
149 .iter()
150 .find(|(c, _)| c.name == right_col)
151 .map(|(_, v)| v);
152
153 if left_value == right_value && left_value.is_some() {
154 let mut new_row = left_row.clone();
155 new_row.push((right_table.to_string(), right_row.clone()));
156 results.push(new_row);
157 right_matched[i] = true;
158 matched = true;
159 }
160 }
161
162 if keep_unmatched_left && !matched {
163 let mut new_row = left_row.clone();
164 let null_cols = right_rows
165 .first()
166 .map(|sample| self.null_pad_columns(sample))
167 .unwrap_or_default();
168 new_row.push((right_table.to_string(), null_cols));
169 results.push(new_row);
170 }
171 }
172
173 if keep_unmatched_right {
174 for (i, right_row) in right_rows.iter().enumerate() {
175 if !right_matched[i] {
176 let mut new_row: JoinedRow = Vec::new();
177 if let Some(sample_left) = left_rows.first() {
178 for (table_name, cols) in sample_left {
179 new_row.push((table_name.clone(), self.null_pad_columns(cols)));
180 }
181 }
182 new_row.push((right_table.to_string(), right_row.clone()));
183 results.push(new_row);
184 }
185 }
186 }
187
188 results
189 }
190
191 fn resolve_column_ref<'a>(&self, field: &'a str, default_table: &'a str) -> (String, &'a str) {
193 if let Some((table, column)) = field.split_once('.') {
194 (table.to_string(), column)
195 } else {
196 (default_table.to_string(), field)
197 }
198 }
199
200 fn get_column_value<'a>(
202 &self,
203 row: &'a JoinedRow,
204 table: &str,
205 column: &str,
206 ) -> Option<&'a Value> {
207 row.iter()
208 .find(|(t, _)| t == table)
209 .and_then(|(_, cols)| cols.iter().find(|(c, _)| c.name == column).map(|(_, v)| v))
210 }
211
212 fn null_pad_columns(&self, sample_row: &[(ColumnDef, Value)]) -> Vec<(ColumnDef, Value)> {
214 sample_row
215 .iter()
216 .map(|(col, _)| (*col, Value::Null))
217 .collect()
218 }
219
220 fn sort_joined_rows(&self, rows: &mut [JoinedRow], column: &str, direction: OrderDirection) {
222 let (table, col) = if let Some((t, c)) = column.split_once('.') {
223 (Some(t), c)
224 } else {
225 (None, column)
226 };
227
228 rows.sort_by(|a, b| {
229 let a_val = self.find_value_in_joined_row(a, table, col);
230 let b_val = self.find_value_in_joined_row(b, table, col);
231
232 crate::database::sort_values_with_direction(a_val, b_val, direction)
233 });
234 }
235
236 fn find_value_in_joined_row<'a>(
238 &self,
239 row: &'a JoinedRow,
240 table: Option<&str>,
241 column: &str,
242 ) -> Option<&'a Value> {
243 if let Some(table) = table {
244 return self.get_column_value(row, table, column);
245 }
246 row.iter()
247 .flat_map(|(_, cols)| cols)
248 .find_map(|(col, value)| {
249 if col.name == column {
250 Some(value)
251 } else {
252 None
253 }
254 })
255 }
256
257 fn flatten_joined_row(
259 &self,
260 row: JoinedRow,
261 query: &Query,
262 ) -> DbmsResult<Vec<(CandidColumnDef, Value)>> {
263 let mut result = Vec::new();
264
265 for (table_name, cols) in row {
266 for (col, val) in cols {
267 let mut candid_col = CandidColumnDef::from(col);
268 candid_col.table = Some(table_name.clone());
269
270 if !query.all_selected() {
271 let selected = query.raw_columns();
272 let qualified_name = format!("{table_name}.{col}", col = candid_col.name);
273 if !selected.contains(&candid_col.name) && !selected.contains(&qualified_name) {
274 continue;
275 }
276 }
277
278 result.push((candid_col, val));
279 }
280 }
281
282 Ok(result)
283 }
284}
285
286#[cfg(test)]
287mod tests {
288
289 use wasm_dbms_api::prelude::{
290 Database as _, Filter, InsertRecord as _, Query, TableSchema as _, Text, Uint32, Value,
291 };
292 use wasm_dbms_macros::{DatabaseSchema, Table};
293 use wasm_dbms_memory::prelude::HeapMemoryProvider;
294
295 use crate::prelude::{DbmsContext, WasmDbmsDatabase};
296
    // Test table "departments": the FROM side of every join in these tests.
    // Employees reference it through `Employee::dept_id`.
    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "departments"]
    pub struct Department {
        #[primary_key]
        pub id: Uint32,
        pub name: Text,
    }
307
    // Test table "employees": the joined side of every join in these tests.
    // `dept_id` is compared against `Department::id`; no constraint is declared
    // here, so a test may insert a `dept_id` with no matching department.
    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "employees"]
    pub struct Employee {
        #[primary_key]
        pub id: Uint32,
        pub name: Text,
        pub dept_id: Uint32,
    }
316
    // Schema exposing the two test tables under their table names.
    #[derive(DatabaseSchema)]
    #[tables(Department = "departments", Employee = "employees")]
    pub struct TestSchema;
320
321 fn setup() -> DbmsContext<HeapMemoryProvider> {
322 let ctx = DbmsContext::new(HeapMemoryProvider::default());
323 TestSchema::register_tables(&ctx).unwrap();
324 ctx
325 }
326
327 fn insert_dept(db: &WasmDbmsDatabase<'_, HeapMemoryProvider>, id: u32, name: &str) {
328 let insert = DepartmentInsertRequest::from_values(&[
329 (Department::columns()[0], Value::Uint32(Uint32(id))),
330 (
331 Department::columns()[1],
332 Value::Text(Text(name.to_string())),
333 ),
334 ])
335 .unwrap();
336 db.insert::<Department>(insert).unwrap();
337 }
338
339 fn insert_emp(
340 db: &WasmDbmsDatabase<'_, HeapMemoryProvider>,
341 id: u32,
342 name: &str,
343 dept_id: u32,
344 ) {
345 let insert = EmployeeInsertRequest::from_values(&[
346 (Employee::columns()[0], Value::Uint32(Uint32(id))),
347 (Employee::columns()[1], Value::Text(Text(name.to_string()))),
348 (Employee::columns()[2], Value::Uint32(Uint32(dept_id))),
349 ])
350 .unwrap();
351 db.insert::<Employee>(insert).unwrap();
352 }
353
354 #[test]
355 fn test_inner_join() {
356 let ctx = setup();
357 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
358 insert_dept(&db, 1, "eng");
359 insert_dept(&db, 2, "hr");
360 insert_emp(&db, 10, "alice", 1);
361 insert_emp(&db, 11, "bob", 1);
362
363 let query = Query::builder()
364 .all()
365 .inner_join("employees", "id", "dept_id")
366 .build();
367 let results = db.select_join("departments", query).unwrap();
368 assert_eq!(results.len(), 2);
370 }
371
372 #[test]
373 fn test_left_join() {
374 let ctx = setup();
375 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
376 insert_dept(&db, 1, "eng");
377 insert_dept(&db, 2, "hr");
378 insert_emp(&db, 10, "alice", 1);
379
380 let query = Query::builder()
381 .all()
382 .left_join("employees", "id", "dept_id")
383 .build();
384 let results = db.select_join("departments", query).unwrap();
385 assert_eq!(results.len(), 2);
387
388 let hr_row = results
390 .iter()
391 .find(|row| {
392 row.iter().any(|(col, val)| {
393 col.name == "name"
394 && col.table.as_deref() == Some("departments")
395 && *val == Value::Text(Text("hr".to_string()))
396 })
397 })
398 .expect("hr should be in results");
399
400 let emp_name = hr_row
402 .iter()
403 .find(|(col, _)| col.name == "name" && col.table.as_deref() == Some("employees"))
404 .expect("employee name column should exist for hr");
405 assert_eq!(emp_name.1, Value::Null);
406 }
407
408 #[test]
409 fn test_right_join() {
410 let ctx = setup();
411 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
412 insert_dept(&db, 1, "eng");
413 insert_emp(&db, 10, "alice", 1);
414 insert_emp(&db, 11, "charlie", 999);
416
417 let query = Query::builder()
418 .all()
419 .right_join("employees", "id", "dept_id")
420 .build();
421 let results = db.select_join("departments", query).unwrap();
422 assert_eq!(results.len(), 2);
424
425 let charlie_row = results
427 .iter()
428 .find(|row| {
429 row.iter().any(|(col, val)| {
430 col.name == "name"
431 && col.table.as_deref() == Some("employees")
432 && *val == Value::Text(Text("charlie".to_string()))
433 })
434 })
435 .expect("charlie should be in results");
436
437 let dept_name = charlie_row
438 .iter()
439 .find(|(col, _)| col.name == "name" && col.table.as_deref() == Some("departments"))
440 .expect("department name column should exist for charlie");
441 assert_eq!(dept_name.1, Value::Null);
442 }
443
444 #[test]
445 fn test_full_join() {
446 let ctx = setup();
447 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
448 insert_dept(&db, 1, "eng");
449 insert_dept(&db, 2, "hr");
450 insert_emp(&db, 10, "alice", 1);
451 insert_emp(&db, 11, "charlie", 999);
453
454 let query = Query::builder()
455 .all()
456 .full_join("employees", "id", "dept_id")
457 .build();
458 let results = db.select_join("departments", query).unwrap();
459 assert_eq!(results.len(), 3);
461 }
462
463 #[test]
464 fn test_join_with_filter() {
465 let ctx = setup();
466 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
467 insert_dept(&db, 1, "eng");
468 insert_dept(&db, 2, "hr");
469 insert_emp(&db, 10, "alice", 1);
470 insert_emp(&db, 11, "bob", 2);
471
472 let query = Query::builder()
473 .all()
474 .inner_join("employees", "id", "dept_id")
475 .and_where(Filter::eq(
476 "departments.name",
477 Value::Text(Text("eng".to_string())),
478 ))
479 .build();
480 let results = db.select_join("departments", query).unwrap();
481 assert_eq!(results.len(), 1);
482 }
483
484 #[test]
485 fn test_join_with_order_by() {
486 let ctx = setup();
487 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
488 insert_dept(&db, 1, "eng");
489 insert_dept(&db, 2, "hr");
490 insert_emp(&db, 10, "zzz", 1);
491 insert_emp(&db, 11, "aaa", 2);
492
493 let query = Query::builder()
494 .all()
495 .inner_join("employees", "id", "dept_id")
496 .order_by_asc("employees.name")
497 .build();
498 let results = db.select_join("departments", query).unwrap();
499 assert_eq!(results.len(), 2);
500 let first_name = results[0]
501 .iter()
502 .find(|(col, _)| col.name == "name" && col.table.as_deref() == Some("employees"))
503 .unwrap();
504 assert_eq!(first_name.1, Value::Text(Text("aaa".to_string())));
505 }
506
507 #[test]
508 fn test_join_with_limit() {
509 let ctx = setup();
510 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
511 insert_dept(&db, 1, "eng");
512 insert_dept(&db, 2, "hr");
513 insert_emp(&db, 10, "alice", 1);
514 insert_emp(&db, 11, "bob", 2);
515
516 let query = Query::builder()
517 .all()
518 .inner_join("employees", "id", "dept_id")
519 .limit(1)
520 .build();
521 let results = db.select_join("departments", query).unwrap();
522 assert_eq!(results.len(), 1);
523 }
524
525 #[test]
526 fn test_join_with_offset() {
527 let ctx = setup();
528 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
529 insert_dept(&db, 1, "eng");
530 insert_dept(&db, 2, "hr");
531 insert_emp(&db, 10, "alice", 1);
532 insert_emp(&db, 11, "bob", 2);
533
534 let query = Query::builder()
535 .all()
536 .inner_join("employees", "id", "dept_id")
537 .offset(1)
538 .build();
539 let results = db.select_join("departments", query).unwrap();
540 assert_eq!(results.len(), 1);
541 }
542
543 #[test]
544 fn test_join_with_column_selection() {
545 let ctx = setup();
546 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
547 insert_dept(&db, 1, "eng");
548 insert_emp(&db, 10, "alice", 1);
549
550 let query = Query::builder()
551 .field("departments.name")
552 .field("employees.name")
553 .inner_join("employees", "id", "dept_id")
554 .build();
555 let results = db.select_join("departments", query).unwrap();
556 assert_eq!(results.len(), 1);
557 assert_eq!(results[0].len(), 2);
558 }
559
560 #[test]
561 fn test_inner_join_empty_result() {
562 let ctx = setup();
563 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
564 insert_dept(&db, 1, "eng");
565 let query = Query::builder()
568 .all()
569 .inner_join("employees", "id", "dept_id")
570 .build();
571 let results = db.select_join("departments", query).unwrap();
572 assert!(results.is_empty());
573 }
574
575 #[test]
576 fn test_join_offset_exceeding_results_returns_empty() {
577 let ctx = setup();
578 let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
579 insert_dept(&db, 1, "eng");
580 insert_emp(&db, 10, "alice", 1);
581
582 let query = Query::builder()
583 .all()
584 .inner_join("employees", "id", "dept_id")
585 .offset(100)
586 .build();
587 let results = db.select_join("departments", query).unwrap();
588 assert!(results.is_empty());
589 }
590}