Skip to main content

mdql_core/
query_join.rs

1//! JOIN query execution for multi-table queries.
2
3use std::collections::HashMap;
4
5use crate::errors::MdqlError;
6use crate::model::{Row, Value};
7use crate::query_ast::*;
8use crate::query_engine::evaluate;
9use crate::schema::Schema;
10
11pub fn execute_join_query(
12    query: &SelectQuery,
13    tables: &HashMap<String, (Schema, Vec<Row>)>,
14) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
15    if query.joins.is_empty() {
16        return Err(MdqlError::QueryExecution("No JOIN clause in query".into()));
17    }
18
19    let left_name = &query.table;
20    let left_alias = query.table_alias.as_deref().unwrap_or(left_name);
21
22    let (_left_schema, left_rows) = tables.get(left_name.as_str()).ok_or_else(|| {
23        MdqlError::QueryExecution(format!("Unknown table '{}'", left_name))
24    })?;
25
26    let mut current_rows: Vec<Row> = left_rows
27        .iter()
28        .map(|r| {
29            let mut prefixed = Row::new();
30            for (k, v) in r {
31                prefixed.insert(format!("{}.{}", left_alias, k), v.clone());
32            }
33            prefixed
34        })
35        .collect();
36
37    for join in &query.joins {
38        let right_name = &join.table;
39        let right_alias = join.alias.as_deref().unwrap_or(right_name);
40
41        let (_right_schema, right_rows) = tables.get(right_name.as_str()).ok_or_else(|| {
42            MdqlError::QueryExecution(format!("Unknown table '{}'", right_name))
43        })?;
44
45        let right_columns: Vec<String> = right_rows
46            .first()
47            .map(|r| r.keys().cloned().collect())
48            .unwrap_or_default();
49
50        let mut next_rows: Vec<Row> = Vec::new();
51        for lr in &current_rows {
52            let mut matched = false;
53            for rr in right_rows {
54                let mut merged = lr.clone();
55                for (k, v) in rr {
56                    merged.insert(format!("{}.{}", right_alias, k), v.clone());
57                }
58                if evaluate(&join.condition, &merged) {
59                    next_rows.push(merged);
60                    matched = true;
61                }
62            }
63            if !matched && join.join_type == JoinType::Left {
64                let mut merged = lr.clone();
65                for col in &right_columns {
66                    merged.insert(format!("{}.{}", right_alias, col), Value::Null);
67                }
68                next_rows.push(merged);
69            }
70        }
71        current_rows = next_rows;
72    }
73
74    let (mut result, columns) = super::query_engine::execute_inner(query, &current_rows, None)?;
75
76    if !result.is_empty() {
77        let mut base_counts: HashMap<String, usize> = HashMap::new();
78        for key in &columns {
79            if let Some((_prefix, base)) = key.split_once('.') {
80                *base_counts.entry(base.to_string()).or_default() += 1;
81            }
82        }
83        let unique_bases: Vec<String> = base_counts
84            .into_iter()
85            .filter(|(_, count)| *count == 1)
86            .map(|(base, _)| base)
87            .collect();
88
89        if !unique_bases.is_empty() {
90            let unique_set: std::collections::HashSet<&str> =
91                unique_bases.iter().map(|s| s.as_str()).collect();
92            for row in &mut result {
93                let additions: Vec<(String, Value)> = row
94                    .iter()
95                    .filter_map(|(k, v)| {
96                        k.split_once('.').and_then(|(_, base)| {
97                            if unique_set.contains(base) {
98                                Some((base.to_string(), v.clone()))
99                            } else {
100                                None
101                            }
102                        })
103                    })
104                    .collect();
105                for (k, v) in additions {
106                    row.insert(k, v);
107                }
108            }
109        }
110    }
111
112    Ok((result, columns))
113}