Skip to main content

mdql_core/
query_join.rs

1//! JOIN query execution for multi-table queries.
2
3use std::collections::HashMap;
4
5use crate::errors::MdqlError;
6use crate::model::{Row, Value};
7use crate::query_ast::*;
8use crate::schema::Schema;
9
10pub fn execute_join_query(
11    query: &SelectQuery,
12    tables: &HashMap<String, (Schema, Vec<Row>)>,
13) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
14    if query.joins.is_empty() {
15        return Err(MdqlError::QueryExecution("No JOIN clause in query".into()));
16    }
17
18    let left_name = &query.table;
19    let left_alias = query.table_alias.as_deref().unwrap_or(left_name);
20
21    let mut aliases: HashMap<String, String> = HashMap::new();
22    aliases.insert(left_name.clone(), left_name.clone());
23    if let Some(ref a) = query.table_alias {
24        aliases.insert(a.clone(), left_name.clone());
25    }
26    for join in &query.joins {
27        aliases.insert(join.table.clone(), join.table.clone());
28        if let Some(ref a) = join.alias {
29            aliases.insert(a.clone(), join.table.clone());
30        }
31    }
32
33    let (_left_schema, left_rows) = tables.get(left_name.as_str()).ok_or_else(|| {
34        MdqlError::QueryExecution(format!("Unknown table '{}'", left_name))
35    })?;
36
37    let mut current_rows: Vec<Row> = left_rows
38        .iter()
39        .map(|r| {
40            let mut prefixed = Row::new();
41            for (k, v) in r {
42                prefixed.insert(format!("{}.{}", left_alias, k), v.clone());
43            }
44            prefixed
45        })
46        .collect();
47
48    for join in &query.joins {
49        let right_name = &join.table;
50        let right_alias = join.alias.as_deref().unwrap_or(right_name);
51
52        let (_right_schema, right_rows) = tables.get(right_name.as_str()).ok_or_else(|| {
53            MdqlError::QueryExecution(format!("Unknown table '{}'", right_name))
54        })?;
55
56        let (on_left_table, on_left_col) = resolve_dotted(&join.left_col, &aliases);
57        let (on_right_table, on_right_col) = resolve_dotted(&join.right_col, &aliases);
58
59        let (left_key, right_key) = if on_right_table == *right_name {
60            let left_alias_for_col = reverse_alias(&on_left_table, &aliases, query, &query.joins);
61            (format!("{}.{}", left_alias_for_col, on_left_col), on_right_col)
62        } else {
63            let right_alias_for_col = reverse_alias(&on_right_table, &aliases, query, &query.joins);
64            (format!("{}.{}", right_alias_for_col, on_right_col), on_left_col)
65        };
66
67        let mut right_index: HashMap<String, Vec<&Row>> = HashMap::new();
68        for r in right_rows {
69            if let Some(key) = r.get(&right_key) {
70                let key_str = key.to_display_string();
71                right_index.entry(key_str).or_default().push(r);
72            }
73        }
74
75        let mut next_rows: Vec<Row> = Vec::new();
76        for lr in &current_rows {
77            if let Some(key) = lr.get(&left_key) {
78                let key_str = key.to_display_string();
79                if let Some(matching) = right_index.get(&key_str) {
80                    for rr in matching {
81                        let mut merged = lr.clone();
82                        for (k, v) in *rr {
83                            merged.insert(format!("{}.{}", right_alias, k), v.clone());
84                        }
85                        next_rows.push(merged);
86                    }
87                }
88            }
89        }
90        current_rows = next_rows;
91    }
92
93    let (mut result, columns) = super::query_engine::execute_inner(query, &current_rows, None)?;
94
95    if !result.is_empty() {
96        let mut base_counts: HashMap<String, usize> = HashMap::new();
97        for key in &columns {
98            if let Some((_prefix, base)) = key.split_once('.') {
99                *base_counts.entry(base.to_string()).or_default() += 1;
100            }
101        }
102        let unique_bases: Vec<String> = base_counts
103            .into_iter()
104            .filter(|(_, count)| *count == 1)
105            .map(|(base, _)| base)
106            .collect();
107
108        if !unique_bases.is_empty() {
109            let unique_set: std::collections::HashSet<&str> =
110                unique_bases.iter().map(|s| s.as_str()).collect();
111            for row in &mut result {
112                let additions: Vec<(String, Value)> = row
113                    .iter()
114                    .filter_map(|(k, v)| {
115                        k.split_once('.').and_then(|(_, base)| {
116                            if unique_set.contains(base) {
117                                Some((base.to_string(), v.clone()))
118                            } else {
119                                None
120                            }
121                        })
122                    })
123                    .collect();
124                for (k, v) in additions {
125                    row.insert(k, v);
126                }
127            }
128        }
129    }
130
131    Ok((result, columns))
132}
133
134fn reverse_alias(
135    table_name: &str,
136    aliases: &HashMap<String, String>,
137    query: &SelectQuery,
138    joins: &[JoinClause],
139) -> String {
140    if query.table == table_name {
141        return query.table_alias.as_deref().unwrap_or(&query.table).to_string();
142    }
143    for j in joins {
144        if j.table == table_name {
145            return j.alias.as_deref().unwrap_or(&j.table).to_string();
146        }
147    }
148    if aliases.contains_key(table_name) {
149        return table_name.to_string();
150    }
151    table_name.to_string()
152}
153
/// Splits a column reference at its first `.` into `(table, column)`.
///
/// The text before the dot is looked up in `aliases` and replaced by the mapped table name
/// when present; an unknown prefix is kept verbatim. A reference without any dot yields an
/// empty table name together with the whole input as the column.
fn resolve_dotted(col: &str, aliases: &HashMap<String, String>) -> (String, String) {
    match col.split_once('.') {
        Some((prefix, column)) => {
            let table = match aliases.get(prefix) {
                Some(real_name) => real_name.clone(),
                None => prefix.to_owned(),
            };
            (table, column.to_owned())
        }
        None => (String::new(), col.to_owned()),
    }
}
161}