1use std::collections::HashMap;
4
5use crate::errors::MdqlError;
6use crate::model::{Row, Value};
7use crate::query_ast::*;
8use crate::schema::Schema;
9
10pub fn execute_join_query(
11 query: &SelectQuery,
12 tables: &HashMap<String, (Schema, Vec<Row>)>,
13) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
14 if query.joins.is_empty() {
15 return Err(MdqlError::QueryExecution("No JOIN clause in query".into()));
16 }
17
18 let left_name = &query.table;
19 let left_alias = query.table_alias.as_deref().unwrap_or(left_name);
20
21 let mut aliases: HashMap<String, String> = HashMap::new();
22 aliases.insert(left_name.clone(), left_name.clone());
23 if let Some(ref a) = query.table_alias {
24 aliases.insert(a.clone(), left_name.clone());
25 }
26 for join in &query.joins {
27 aliases.insert(join.table.clone(), join.table.clone());
28 if let Some(ref a) = join.alias {
29 aliases.insert(a.clone(), join.table.clone());
30 }
31 }
32
33 let (_left_schema, left_rows) = tables.get(left_name.as_str()).ok_or_else(|| {
34 MdqlError::QueryExecution(format!("Unknown table '{}'", left_name))
35 })?;
36
37 let mut current_rows: Vec<Row> = left_rows
38 .iter()
39 .map(|r| {
40 let mut prefixed = Row::new();
41 for (k, v) in r {
42 prefixed.insert(format!("{}.{}", left_alias, k), v.clone());
43 }
44 prefixed
45 })
46 .collect();
47
48 for join in &query.joins {
49 let right_name = &join.table;
50 let right_alias = join.alias.as_deref().unwrap_or(right_name);
51
52 let (_right_schema, right_rows) = tables.get(right_name.as_str()).ok_or_else(|| {
53 MdqlError::QueryExecution(format!("Unknown table '{}'", right_name))
54 })?;
55
56 let (on_left_table, on_left_col) = resolve_dotted(&join.left_col, &aliases);
57 let (on_right_table, on_right_col) = resolve_dotted(&join.right_col, &aliases);
58
59 let (left_key, right_key) = if on_right_table == *right_name {
60 let left_alias_for_col = reverse_alias(&on_left_table, &aliases, query, &query.joins);
61 (format!("{}.{}", left_alias_for_col, on_left_col), on_right_col)
62 } else {
63 let right_alias_for_col = reverse_alias(&on_right_table, &aliases, query, &query.joins);
64 (format!("{}.{}", right_alias_for_col, on_right_col), on_left_col)
65 };
66
67 let mut right_index: HashMap<String, Vec<&Row>> = HashMap::new();
68 for r in right_rows {
69 if let Some(key) = r.get(&right_key) {
70 let key_str = key.to_display_string();
71 right_index.entry(key_str).or_default().push(r);
72 }
73 }
74
75 let right_columns: Vec<String> = right_rows
76 .first()
77 .map(|r| r.keys().cloned().collect())
78 .unwrap_or_default();
79
80 let mut next_rows: Vec<Row> = Vec::new();
81 for lr in ¤t_rows {
82 let key_str = lr.get(&left_key).map(|v| v.to_display_string());
83 let matching = key_str.as_deref().and_then(|k| right_index.get(k));
84
85 if let Some(rows) = matching {
86 for rr in rows {
87 let mut merged = lr.clone();
88 for (k, v) in *rr {
89 merged.insert(format!("{}.{}", right_alias, k), v.clone());
90 }
91 next_rows.push(merged);
92 }
93 } else if join.join_type == JoinType::Left {
94 let mut merged = lr.clone();
95 for col in &right_columns {
96 merged.insert(format!("{}.{}", right_alias, col), Value::Null);
97 }
98 next_rows.push(merged);
99 }
100 }
101 current_rows = next_rows;
102 }
103
104 let (mut result, columns) = super::query_engine::execute_inner(query, ¤t_rows, None)?;
105
106 if !result.is_empty() {
107 let mut base_counts: HashMap<String, usize> = HashMap::new();
108 for key in &columns {
109 if let Some((_prefix, base)) = key.split_once('.') {
110 *base_counts.entry(base.to_string()).or_default() += 1;
111 }
112 }
113 let unique_bases: Vec<String> = base_counts
114 .into_iter()
115 .filter(|(_, count)| *count == 1)
116 .map(|(base, _)| base)
117 .collect();
118
119 if !unique_bases.is_empty() {
120 let unique_set: std::collections::HashSet<&str> =
121 unique_bases.iter().map(|s| s.as_str()).collect();
122 for row in &mut result {
123 let additions: Vec<(String, Value)> = row
124 .iter()
125 .filter_map(|(k, v)| {
126 k.split_once('.').and_then(|(_, base)| {
127 if unique_set.contains(base) {
128 Some((base.to_string(), v.clone()))
129 } else {
130 None
131 }
132 })
133 })
134 .collect();
135 for (k, v) in additions {
136 row.insert(k, v);
137 }
138 }
139 }
140 }
141
142 Ok((result, columns))
143}
144
145fn reverse_alias(
146 table_name: &str,
147 aliases: &HashMap<String, String>,
148 query: &SelectQuery,
149 joins: &[JoinClause],
150) -> String {
151 if query.table == table_name {
152 return query.table_alias.as_deref().unwrap_or(&query.table).to_string();
153 }
154 for j in joins {
155 if j.table == table_name {
156 return j.alias.as_deref().unwrap_or(&j.table).to_string();
157 }
158 }
159 if aliases.contains_key(table_name) {
160 return table_name.to_string();
161 }
162 table_name.to_string()
163}
164
/// Splits a possibly qualified column reference into `(table, column)`.
///
/// For `"q.col"`, the qualifier `q` is looked up in `aliases` (alias or table name -> table
/// name). A qualifier that is not found is returned as-is. Only the first `.` separates
/// the qualifier from the column. An unqualified reference yields an empty table name.
fn resolve_dotted(col: &str, aliases: &HashMap<String, String>) -> (String, String) {
    match col.split_once('.') {
        None => (String::new(), col.to_owned()),
        Some((qualifier, column)) => {
            let table = match aliases.get(qualifier) {
                Some(name) => name.clone(),
                None => qualifier.to_owned(),
            };
            (table, column.to_owned())
        }
    }
}