1use crate::error::{Result, TermError};
26use crate::sources::DataSource;
27use crate::telemetry::TermTelemetry;
28use arrow::datatypes::Schema;
29use async_trait::async_trait;
30use datafusion::prelude::*;
31use serde::{Deserialize, Serialize};
32use std::sync::Arc;
33use tracing::{debug, info, instrument};
34
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
37pub enum JoinType {
38 Inner,
40 Left,
42 Right,
44 Full,
46}
47
48impl JoinType {
49 pub fn to_sql(&self) -> &'static str {
51 match self {
52 JoinType::Inner => "INNER JOIN",
53 JoinType::Left => "LEFT JOIN",
54 JoinType::Right => "RIGHT JOIN",
55 JoinType::Full => "FULL OUTER JOIN",
56 }
57 }
58}
59
60#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
62pub struct JoinCondition {
63 pub left_column: String,
65 pub right_column: String,
67 pub join_type: JoinType,
69}
70
71impl JoinCondition {
72 pub fn new(
74 left_column: impl Into<String>,
75 right_column: impl Into<String>,
76 join_type: JoinType,
77 ) -> Self {
78 Self {
79 left_column: left_column.into(),
80 right_column: right_column.into(),
81 join_type,
82 }
83 }
84
85 pub fn to_sql(&self, left_alias: &str, right_alias: &str) -> String {
87 format!(
88 "{} ON {left_alias}.{} = {right_alias}.{}",
89 self.join_type.to_sql(),
90 self.left_column,
91 self.right_column
92 )
93 }
94}
95
96#[derive(Debug, Clone)]
98pub struct JoinedSource {
99 left_source: Arc<dyn DataSource>,
101 left_alias: String,
103 right_source: Arc<dyn DataSource>,
105 right_alias: String,
107 join_condition: JoinCondition,
109 where_clause: Option<String>,
111 additional_joins: Vec<AdditionalJoin>,
113}
114
115#[derive(Debug, Clone)]
117struct AdditionalJoin {
118 source: Arc<dyn DataSource>,
119 alias: String,
120 condition: JoinCondition,
121}
122
123impl JoinedSource {
124 pub fn builder() -> JoinedSourceBuilder {
126 JoinedSourceBuilder::new()
127 }
128
129 pub fn left_alias(&self) -> &str {
131 &self.left_alias
132 }
133
134 pub fn right_alias(&self) -> &str {
136 &self.right_alias
137 }
138
139 pub fn join_condition(&self) -> &JoinCondition {
141 &self.join_condition
142 }
143
144 #[instrument(skip(self))]
146 pub fn generate_sql(&self, table_name: &str) -> String {
147 let join_type_sql = self.join_condition.join_type.to_sql();
148
149 let left_col = if self.join_condition.left_column.contains('.') {
153 self.join_condition.left_column.clone()
154 } else {
155 format!("{}.{}", self.left_alias, self.join_condition.left_column)
156 };
157
158 let right_col = if self.join_condition.right_column.contains('.') {
159 self.join_condition.right_column.clone()
160 } else {
161 format!("{}.{}", self.right_alias, self.join_condition.right_column)
162 };
163
164 let on_clause = format!("ON {left_col} = {right_col}");
165
166 let mut sql = format!(
167 "CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM {} {join_type_sql} {} {on_clause}",
168 self.left_alias,
169 self.right_alias
170 );
171
172 for additional in &self.additional_joins {
174 sql.push(' ');
175 sql.push_str(
176 &additional
177 .condition
178 .to_sql(&self.left_alias, &additional.alias),
179 );
180 sql.push(' ');
181 sql.push_str(&additional.alias);
182 }
183
184 if let Some(where_clause) = &self.where_clause {
186 sql.push_str(" WHERE ");
187 sql.push_str(where_clause);
188 }
189
190 debug!("Generated SQL for joined source: {}", sql);
191 sql
192 }
193}
194
195#[async_trait]
196impl DataSource for JoinedSource {
197 #[instrument(skip(self, ctx))]
198 async fn register_with_telemetry(
199 &self,
200 ctx: &SessionContext,
201 table_name: &str,
202 telemetry: Option<&Arc<TermTelemetry>>,
203 ) -> Result<()> {
204 info!("Registering joined source as table: {}", table_name);
205
206 self.left_source
208 .register_with_telemetry(ctx, &self.left_alias, telemetry)
209 .await
210 .map_err(|e| {
211 TermError::data_source(
212 "joined",
213 format!("Failed to register left source '{}': {e}", self.left_alias),
214 )
215 })?;
216
217 self.right_source
218 .register_with_telemetry(ctx, &self.right_alias, telemetry)
219 .await
220 .map_err(|e| {
221 TermError::data_source(
222 "joined",
223 format!(
224 "Failed to register right source '{}': {e}",
225 self.right_alias
226 ),
227 )
228 })?;
229
230 for additional in &self.additional_joins {
232 additional
233 .source
234 .register_with_telemetry(ctx, &additional.alias, telemetry)
235 .await
236 .map_err(|e| {
237 TermError::data_source(
238 "joined",
239 format!(
240 "Failed to register additional source '{}': {e}",
241 additional.alias
242 ),
243 )
244 })?;
245 }
246
247 let sql = self.generate_sql(table_name);
249 ctx.sql(&sql).await.map_err(|e| {
250 TermError::data_source("joined", format!("Failed to create joined view: {e}"))
251 })?;
252
253 info!("Successfully registered joined source: {}", table_name);
254 Ok(())
255 }
256
257 fn schema(&self) -> Option<&Arc<Schema>> {
258 None
261 }
262
263 fn description(&self) -> String {
264 format!(
265 "Joined source: {} {} {} ON {}.{} = {}.{}",
266 self.left_alias,
267 self.join_condition.join_type.to_sql(),
268 self.right_alias,
269 self.left_alias,
270 self.join_condition.left_column,
271 self.right_alias,
272 self.join_condition.right_column
273 )
274 }
275}
276
277pub struct JoinedSourceBuilder {
279 left_source: Option<(Arc<dyn DataSource>, String)>,
280 right_source: Option<(Arc<dyn DataSource>, String)>,
281 join_condition: Option<JoinCondition>,
282 where_clause: Option<String>,
283 additional_joins: Vec<AdditionalJoin>,
284}
285
286impl JoinedSourceBuilder {
287 fn new() -> Self {
288 Self {
289 left_source: None,
290 right_source: None,
291 join_condition: None,
292 where_clause: None,
293 additional_joins: Vec::new(),
294 }
295 }
296
297 pub fn left_source<S: DataSource + 'static>(
299 mut self,
300 source: S,
301 alias: impl Into<String>,
302 ) -> Self {
303 self.left_source = Some((Arc::new(source), alias.into()));
304 self
305 }
306
307 pub fn right_source<S: DataSource + 'static>(
309 mut self,
310 source: S,
311 alias: impl Into<String>,
312 ) -> Self {
313 self.right_source = Some((Arc::new(source), alias.into()));
314 self
315 }
316
317 pub fn on(mut self, left_column: impl Into<String>, right_column: impl Into<String>) -> Self {
319 self.join_condition = Some(JoinCondition::new(
320 left_column,
321 right_column,
322 JoinType::Inner,
323 ));
324 self
325 }
326
327 pub fn join_on(
329 mut self,
330 left_column: impl Into<String>,
331 right_column: impl Into<String>,
332 join_type: JoinType,
333 ) -> Self {
334 self.join_condition = Some(JoinCondition::new(left_column, right_column, join_type));
335 self
336 }
337
338 pub fn join_type(mut self, join_type: JoinType) -> Self {
340 if let Some(ref mut condition) = self.join_condition {
341 condition.join_type = join_type;
342 }
343 self
344 }
345
346 pub fn where_clause(mut self, clause: impl Into<String>) -> Self {
348 self.where_clause = Some(clause.into());
349 self
350 }
351
352 pub fn additional_join<S: DataSource + 'static>(
354 mut self,
355 source: S,
356 alias: impl Into<String>,
357 left_column: impl Into<String>,
358 right_column: impl Into<String>,
359 join_type: JoinType,
360 ) -> Self {
361 self.additional_joins.push(AdditionalJoin {
362 source: Arc::new(source),
363 alias: alias.into(),
364 condition: JoinCondition::new(left_column, right_column, join_type),
365 });
366 self
367 }
368
369 pub fn build(self) -> Result<JoinedSource> {
371 let left_source = self
372 .left_source
373 .ok_or_else(|| TermError::data_source("joined", "Left source is required"))?;
374 let right_source = self
375 .right_source
376 .ok_or_else(|| TermError::data_source("joined", "Right source is required"))?;
377 let join_condition = self
378 .join_condition
379 .ok_or_else(|| TermError::data_source("joined", "Join condition is required"))?;
380
381 Ok(JoinedSource {
382 left_source: left_source.0,
383 left_alias: left_source.1,
384 right_source: right_source.0,
385 right_alias: right_source.1,
386 join_condition,
387 where_clause: self.where_clause,
388 additional_joins: self.additional_joins,
389 })
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sources::CsvSource;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `data` into a fresh temporary file with a `.csv` suffix.
    fn create_test_csv(data: &str) -> Result<NamedTempFile> {
        let mut temp_file = NamedTempFile::with_suffix(".csv")?;
        temp_file.write_all(data.as_bytes())?;
        temp_file.flush()?;
        Ok(temp_file)
    }

    /// Creates a temporary CSV file holding `data` and a `CsvSource` reading it.
    ///
    /// The file handle is returned alongside the source so that the caller
    /// keeps the temporary file alive for as long as the source is in use.
    fn csv_source(data: &str) -> Result<(NamedTempFile, CsvSource)> {
        let file = create_test_csv(data)?;
        let source = CsvSource::new(file.path().to_string_lossy().to_string())?;
        Ok((file, source))
    }

    /// Every join type maps to its SQL keyword phrase.
    #[test]
    fn test_join_type_sql() {
        let expectations = [
            (JoinType::Inner, "INNER JOIN"),
            (JoinType::Left, "LEFT JOIN"),
            (JoinType::Right, "RIGHT JOIN"),
            (JoinType::Full, "FULL OUTER JOIN"),
        ];

        for (join_type, expected) in expectations {
            assert_eq!(join_type.to_sql(), expected);
        }
    }

    /// A condition renders as the join keyword followed by the `ON` predicate.
    #[test]
    fn test_join_condition_sql() {
        let condition = JoinCondition::new("customer_id", "id", JoinType::Inner);
        let rendered = condition.to_sql("orders", "customers");

        assert_eq!(rendered, "INNER JOIN ON orders.customer_id = customers.id");
    }

    /// The builder stores the aliases and defaults `on` to an inner join.
    #[tokio::test]
    async fn test_joined_source_builder() -> Result<()> {
        let (_orders_file, orders_source) =
            csv_source("order_id,customer_id,amount\n1,1,100.0\n2,2,200.0")?;
        let (_customers_file, customers_source) = csv_source("id,name\n1,Alice\n2,Bob")?;

        let joined_source = JoinedSource::builder()
            .left_source(orders_source, "orders")
            .right_source(customers_source, "customers")
            .on("customer_id", "id")
            .build()?;

        assert_eq!(joined_source.left_alias(), "orders");
        assert_eq!(joined_source.right_alias(), "customers");
        assert_eq!(joined_source.join_condition().join_type, JoinType::Inner);

        Ok(())
    }

    /// Registering a left join creates a view that can be queried.
    #[tokio::test]
    async fn test_joined_source_registration() -> Result<()> {
        let (_orders_file, orders_source) =
            csv_source("order_id,customer_id,amount\n1,1,100.0\n2,2,200.0\n3,999,300.0")?;
        let (_customers_file, customers_source) = csv_source("id,name\n1,Alice\n2,Bob")?;

        let joined_source = JoinedSource::builder()
            .left_source(orders_source, "orders")
            .right_source(customers_source, "customers")
            .join_on("customer_id", "id", JoinType::Left)
            .build()?;

        let ctx = SessionContext::new();
        joined_source
            .register(&ctx, "orders_with_customers")
            .await?;

        let results = ctx
            .sql("SELECT COUNT(*) as count FROM orders_with_customers")
            .await?
            .collect()
            .await?;

        // Only the number of collected result entries is checked here, not the
        // count value itself.
        assert_eq!(results.len(), 1);

        Ok(())
    }

    /// The generated SQL contains the view name, join, predicate and filter.
    #[test]
    fn test_joined_source_sql_generation() -> Result<()> {
        let (_orders_file, orders_source) = csv_source("id,customer_id,amount\n1,1,100.0")?;
        let (_customers_file, customers_source) = csv_source("id,name\n1,Alice")?;

        let joined_source = JoinedSource::builder()
            .left_source(orders_source, "orders")
            .right_source(customers_source, "customers")
            .on("customer_id", "id")
            .where_clause("orders.amount > 50")
            .build()?;

        let sql = joined_source.generate_sql("test_view");

        let expected_fragments = [
            "CREATE OR REPLACE VIEW test_view",
            "INNER JOIN",
            "orders.customer_id = customers.id",
            "WHERE orders.amount > 50",
        ];
        for fragment in expected_fragments {
            assert!(sql.contains(fragment));
        }

        Ok(())
    }

    /// The description mentions both aliases, the join type and the column.
    #[test]
    fn test_joined_source_description() -> Result<()> {
        let (_orders_file, orders_source) = csv_source("id,customer_id,amount\n1,1,100.0")?;
        let (_customers_file, customers_source) = csv_source("id,name\n1,Alice")?;

        let joined_source = JoinedSource::builder()
            .left_source(orders_source, "orders")
            .right_source(customers_source, "customers")
            .join_on("customer_id", "id", JoinType::Left)
            .build()?;

        let description = joined_source.description();
        for fragment in ["orders", "LEFT JOIN", "customers", "customer_id"] {
            assert!(description.contains(fragment));
        }

        Ok(())
    }
}