Skip to main content

mdql_core/
query_join.rs

//! JOIN query execution for multi-table queries.
2
3use std::collections::HashMap;
4
5use crate::errors::MdqlError;
6use crate::model::{Row, Value};
7use crate::query_ast::*;
8use crate::schema::Schema;
9
10pub fn execute_join_query(
11    query: &SelectQuery,
12    tables: &HashMap<String, (Schema, Vec<Row>)>,
13) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
14    if query.joins.is_empty() {
15        return Err(MdqlError::QueryExecution("No JOIN clause in query".into()));
16    }
17
18    let left_name = &query.table;
19    let left_alias = query.table_alias.as_deref().unwrap_or(left_name);
20
21    let mut aliases: HashMap<String, String> = HashMap::new();
22    aliases.insert(left_name.clone(), left_name.clone());
23    if let Some(ref a) = query.table_alias {
24        aliases.insert(a.clone(), left_name.clone());
25    }
26    for join in &query.joins {
27        aliases.insert(join.table.clone(), join.table.clone());
28        if let Some(ref a) = join.alias {
29            aliases.insert(a.clone(), join.table.clone());
30        }
31    }
32
33    let (_left_schema, left_rows) = tables.get(left_name.as_str()).ok_or_else(|| {
34        MdqlError::QueryExecution(format!("Unknown table '{}'", left_name))
35    })?;
36
37    let mut current_rows: Vec<Row> = left_rows
38        .iter()
39        .map(|r| {
40            let mut prefixed = Row::new();
41            for (k, v) in r {
42                prefixed.insert(format!("{}.{}", left_alias, k), v.clone());
43            }
44            prefixed
45        })
46        .collect();
47
48    for join in &query.joins {
49        let right_name = &join.table;
50        let right_alias = join.alias.as_deref().unwrap_or(right_name);
51
52        let (_right_schema, right_rows) = tables.get(right_name.as_str()).ok_or_else(|| {
53            MdqlError::QueryExecution(format!("Unknown table '{}'", right_name))
54        })?;
55
56        let (on_left_table, on_left_col) = resolve_dotted(&join.left_col, &aliases);
57        let (on_right_table, on_right_col) = resolve_dotted(&join.right_col, &aliases);
58
59        let (left_key, right_key) = if on_right_table == *right_name {
60            let left_alias_for_col = reverse_alias(&on_left_table, &aliases, query, &query.joins);
61            (format!("{}.{}", left_alias_for_col, on_left_col), on_right_col)
62        } else {
63            let right_alias_for_col = reverse_alias(&on_right_table, &aliases, query, &query.joins);
64            (format!("{}.{}", right_alias_for_col, on_right_col), on_left_col)
65        };
66
67        let mut right_index: HashMap<String, Vec<&Row>> = HashMap::new();
68        for r in right_rows {
69            if let Some(key) = r.get(&right_key) {
70                let key_str = key.to_display_string();
71                right_index.entry(key_str).or_default().push(r);
72            }
73        }
74
75        let right_columns: Vec<String> = right_rows
76            .first()
77            .map(|r| r.keys().cloned().collect())
78            .unwrap_or_default();
79
80        let mut next_rows: Vec<Row> = Vec::new();
81        for lr in &current_rows {
82            let key_str = lr.get(&left_key).map(|v| v.to_display_string());
83            let matching = key_str.as_deref().and_then(|k| right_index.get(k));
84
85            if let Some(rows) = matching {
86                for rr in rows {
87                    let mut merged = lr.clone();
88                    for (k, v) in *rr {
89                        merged.insert(format!("{}.{}", right_alias, k), v.clone());
90                    }
91                    next_rows.push(merged);
92                }
93            } else if join.join_type == JoinType::Left {
94                let mut merged = lr.clone();
95                for col in &right_columns {
96                    merged.insert(format!("{}.{}", right_alias, col), Value::Null);
97                }
98                next_rows.push(merged);
99            }
100        }
101        current_rows = next_rows;
102    }
103
104    let (mut result, columns) = super::query_engine::execute_inner(query, &current_rows, None)?;
105
106    if !result.is_empty() {
107        let mut base_counts: HashMap<String, usize> = HashMap::new();
108        for key in &columns {
109            if let Some((_prefix, base)) = key.split_once('.') {
110                *base_counts.entry(base.to_string()).or_default() += 1;
111            }
112        }
113        let unique_bases: Vec<String> = base_counts
114            .into_iter()
115            .filter(|(_, count)| *count == 1)
116            .map(|(base, _)| base)
117            .collect();
118
119        if !unique_bases.is_empty() {
120            let unique_set: std::collections::HashSet<&str> =
121                unique_bases.iter().map(|s| s.as_str()).collect();
122            for row in &mut result {
123                let additions: Vec<(String, Value)> = row
124                    .iter()
125                    .filter_map(|(k, v)| {
126                        k.split_once('.').and_then(|(_, base)| {
127                            if unique_set.contains(base) {
128                                Some((base.to_string(), v.clone()))
129                            } else {
130                                None
131                            }
132                        })
133                    })
134                    .collect();
135                for (k, v) in additions {
136                    row.insert(k, v);
137                }
138            }
139        }
140    }
141
142    Ok((result, columns))
143}
144
145fn reverse_alias(
146    table_name: &str,
147    aliases: &HashMap<String, String>,
148    query: &SelectQuery,
149    joins: &[JoinClause],
150) -> String {
151    if query.table == table_name {
152        return query.table_alias.as_deref().unwrap_or(&query.table).to_string();
153    }
154    for j in joins {
155        if j.table == table_name {
156            return j.alias.as_deref().unwrap_or(&j.table).to_string();
157        }
158    }
159    if aliases.contains_key(table_name) {
160        return table_name.to_string();
161    }
162    table_name.to_string()
163}
164
/// Splits a possibly-qualified column reference into `(table, column)`.
///
/// The text before the first `.` is the qualifier. It is translated through
/// `aliases` when present there, and otherwise used verbatim as the table
/// name. Everything after the first `.` is the column. An unqualified
/// reference yields an empty table name.
fn resolve_dotted(col: &str, aliases: &HashMap<String, String>) -> (String, String) {
    match col.split_once('.') {
        None => (String::new(), col.to_string()),
        Some((qualifier, name)) => {
            let table = match aliases.get(qualifier) {
                Some(resolved) => resolved.clone(),
                None => qualifier.to_string(),
            };
            (table, name.to_string())
        }
    }
}
172}