1use std::collections::HashMap;
4
5use crate::errors::MdqlError;
6use crate::model::{Row, Value};
7use crate::query_ast::*;
8use crate::schema::Schema;
9
10pub fn execute_join_query(
11 query: &SelectQuery,
12 tables: &HashMap<String, (Schema, Vec<Row>)>,
13) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
14 if query.joins.is_empty() {
15 return Err(MdqlError::QueryExecution("No JOIN clause in query".into()));
16 }
17
18 let left_name = &query.table;
19 let left_alias = query.table_alias.as_deref().unwrap_or(left_name);
20
21 let mut aliases: HashMap<String, String> = HashMap::new();
22 aliases.insert(left_name.clone(), left_name.clone());
23 if let Some(ref a) = query.table_alias {
24 aliases.insert(a.clone(), left_name.clone());
25 }
26 for join in &query.joins {
27 aliases.insert(join.table.clone(), join.table.clone());
28 if let Some(ref a) = join.alias {
29 aliases.insert(a.clone(), join.table.clone());
30 }
31 }
32
33 let (_left_schema, left_rows) = tables.get(left_name.as_str()).ok_or_else(|| {
34 MdqlError::QueryExecution(format!("Unknown table '{}'", left_name))
35 })?;
36
37 let mut current_rows: Vec<Row> = left_rows
38 .iter()
39 .map(|r| {
40 let mut prefixed = Row::new();
41 for (k, v) in r {
42 prefixed.insert(format!("{}.{}", left_alias, k), v.clone());
43 }
44 prefixed
45 })
46 .collect();
47
48 for join in &query.joins {
49 let right_name = &join.table;
50 let right_alias = join.alias.as_deref().unwrap_or(right_name);
51
52 let (_right_schema, right_rows) = tables.get(right_name.as_str()).ok_or_else(|| {
53 MdqlError::QueryExecution(format!("Unknown table '{}'", right_name))
54 })?;
55
56 let (on_left_table, on_left_col) = resolve_dotted(&join.left_col, &aliases);
57 let (on_right_table, on_right_col) = resolve_dotted(&join.right_col, &aliases);
58
59 let (left_key, right_key) = if on_right_table == *right_name {
60 let left_alias_for_col = reverse_alias(&on_left_table, &aliases, query, &query.joins);
61 (format!("{}.{}", left_alias_for_col, on_left_col), on_right_col)
62 } else {
63 let right_alias_for_col = reverse_alias(&on_right_table, &aliases, query, &query.joins);
64 (format!("{}.{}", right_alias_for_col, on_right_col), on_left_col)
65 };
66
67 let mut right_index: HashMap<String, Vec<&Row>> = HashMap::new();
68 for r in right_rows {
69 if let Some(key) = r.get(&right_key) {
70 let key_str = key.to_display_string();
71 right_index.entry(key_str).or_default().push(r);
72 }
73 }
74
75 let mut next_rows: Vec<Row> = Vec::new();
76 for lr in ¤t_rows {
77 if let Some(key) = lr.get(&left_key) {
78 let key_str = key.to_display_string();
79 if let Some(matching) = right_index.get(&key_str) {
80 for rr in matching {
81 let mut merged = lr.clone();
82 for (k, v) in *rr {
83 merged.insert(format!("{}.{}", right_alias, k), v.clone());
84 }
85 next_rows.push(merged);
86 }
87 }
88 }
89 }
90 current_rows = next_rows;
91 }
92
93 let (mut result, columns) = super::query_engine::execute_inner(query, ¤t_rows, None)?;
94
95 if !result.is_empty() {
96 let mut base_counts: HashMap<String, usize> = HashMap::new();
97 for key in &columns {
98 if let Some((_prefix, base)) = key.split_once('.') {
99 *base_counts.entry(base.to_string()).or_default() += 1;
100 }
101 }
102 let unique_bases: Vec<String> = base_counts
103 .into_iter()
104 .filter(|(_, count)| *count == 1)
105 .map(|(base, _)| base)
106 .collect();
107
108 if !unique_bases.is_empty() {
109 let unique_set: std::collections::HashSet<&str> =
110 unique_bases.iter().map(|s| s.as_str()).collect();
111 for row in &mut result {
112 let additions: Vec<(String, Value)> = row
113 .iter()
114 .filter_map(|(k, v)| {
115 k.split_once('.').and_then(|(_, base)| {
116 if unique_set.contains(base) {
117 Some((base.to_string(), v.clone()))
118 } else {
119 None
120 }
121 })
122 })
123 .collect();
124 for (k, v) in additions {
125 row.insert(k, v);
126 }
127 }
128 }
129 }
130
131 Ok((result, columns))
132}
133
134fn reverse_alias(
135 table_name: &str,
136 aliases: &HashMap<String, String>,
137 query: &SelectQuery,
138 joins: &[JoinClause],
139) -> String {
140 if query.table == table_name {
141 return query.table_alias.as_deref().unwrap_or(&query.table).to_string();
142 }
143 for j in joins {
144 if j.table == table_name {
145 return j.alias.as_deref().unwrap_or(&j.table).to_string();
146 }
147 }
148 if aliases.contains_key(table_name) {
149 return table_name.to_string();
150 }
151 table_name.to_string()
152}
153
/// Splits a possibly-qualified column reference such as `u.id` into `(table, column)`.
///
/// Behaviour:
/// - The qualifier before the first `.` is translated through `aliases`.
/// - A qualifier with no entry in `aliases` is returned as written.
/// - An unqualified reference yields an empty table string and the whole input as the column.
fn resolve_dotted(col: &str, aliases: &HashMap<String, String>) -> (String, String) {
    match col.split_once('.') {
        None => (String::new(), col.to_string()),
        Some((qualifier, column)) => {
            let table = match aliases.get(qualifier) {
                Some(resolved) => resolved.clone(),
                None => qualifier.to_owned(),
            };
            (table, column.to_owned())
        }
    }
}