1use std::fmt;
37
38use indexmap::IndexMap;
39use serde::Deserialize;
40
41use crate::config::{Aggregate, FilterTerm, Scalar, Sort, SortDir, ViewConfig};
42use crate::proto::{ColumnType, ViewPort};
43
/// Errors that can arise while translating a view request into SQL.
///
/// Every variant carries a human-readable detail string that is appended to
/// the variant's fixed prefix by the `Display` implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GenericSQLError {
    /// A referenced column is unknown; the payload is the column name.
    ColumnNotFound(String),
    /// The supplied configuration is invalid; the payload describes why.
    InvalidConfig(String),
    /// The requested operation is not supported; the payload describes it.
    UnsupportedOperation(String),
}
54
55impl fmt::Display for GenericSQLError {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 match self {
58 Self::ColumnNotFound(col) => write!(f, "Column not found: {}", col),
59 Self::InvalidConfig(msg) => write!(f, "Invalid configuration: {}", msg),
60 Self::UnsupportedOperation(msg) => write!(f, "Unsupported operation: {}", msg),
61 }
62 }
63}
64
65impl std::error::Error for GenericSQLError {}
66
67pub type GenericSQLResult<T> = Result<T, GenericSQLError>;
69
70#[derive(Clone, Debug, Deserialize, Default)]
71pub struct GenericSQLVirtualServerModelArgs {
72 create_entity: Option<String>,
73 grouping_fn: Option<String>,
74}
75
76#[derive(Debug, Default, Clone)]
81pub struct GenericSQLVirtualServerModel(GenericSQLVirtualServerModelArgs);
82
83impl GenericSQLVirtualServerModel {
84 pub fn new(args: GenericSQLVirtualServerModelArgs) -> Self {
86 tracing::error!("{:?}", args);
87 Self(args)
88 }
89
90 pub fn get_hosted_tables(&self) -> GenericSQLResult<String> {
95 Ok("SHOW ALL TABLES".to_string())
96 }
97
98 pub fn table_schema(&self, table_id: &str) -> GenericSQLResult<String> {
106 Ok(format!("DESCRIBE {}", table_id))
107 }
108
109 pub fn table_size(&self, table_id: &str) -> GenericSQLResult<String> {
117 Ok(format!("SELECT COUNT(*) FROM {}", table_id))
118 }
119
120 pub fn view_column_size(&self, view_id: &str) -> GenericSQLResult<String> {
128 Ok(format!("SELECT COUNT(*) FROM (DESCRIBE {})", view_id))
129 }
130
131 pub fn table_validate_expression(
140 &self,
141 table_id: &str,
142 expression: &str,
143 ) -> GenericSQLResult<String> {
144 Ok(format!(
145 "DESCRIBE (SELECT {} FROM {})",
146 expression, table_id
147 ))
148 }
149
150 pub fn view_delete(&self, view_id: &str) -> GenericSQLResult<String> {
158 Ok(format!("DROP TABLE IF EXISTS {}", view_id))
159 }
160
161 pub fn table_make_view(
173 &self,
174 table_id: &str,
175 view_id: &str,
176 config: &ViewConfig,
177 ) -> GenericSQLResult<String> {
178 let columns = &config.columns;
179 let group_by = &config.group_by;
180 let split_by = &config.split_by;
181 let aggregates = &config.aggregates;
182 let sort = &config.sort;
183 let expressions = &config.expressions.0;
184 let filter = &config.filter;
185
186 let col_name = |col: &str| -> String {
187 expressions
188 .get(col)
189 .cloned()
190 .unwrap_or_else(|| format!("\"{}\"", col))
191 };
192
193 let get_aggregate = |col: &str| -> Option<&Aggregate> { aggregates.get(col) };
194 let generate_select_clauses = || -> Vec<String> {
195 let mut clauses = Vec::new();
196
197 if !group_by.is_empty() {
198 for col in columns.iter().flatten() {
199 let agg = get_aggregate(col)
200 .map(Self::aggregate_to_string)
201 .unwrap_or_else(|| "any_value".to_string());
202 clauses.push(format!(
203 "{}({}) as \"{}\"",
204 agg,
205 col_name(col),
206 col.replace('"', "\"\"").replace("_", "-")
207 ));
208 }
209
210 if split_by.is_empty() {
211 for (idx, gb_col) in group_by.iter().enumerate() {
212 clauses.push(format!("{} as __ROW_PATH_{}__", col_name(gb_col), idx));
213 }
214
215 let groups = group_by.iter().map(|c| col_name(c)).collect::<Vec<_>>();
216 let grouping_fn = self.0.grouping_fn.as_deref().unwrap_or("GROUPING_ID");
217 clauses.push(format!(
218 "{}({}) AS __GROUPING_ID__",
219 grouping_fn,
220 groups.join(", ")
221 ));
222 }
223 } else if !columns.is_empty() {
224 for col in columns.iter().flatten() {
225 let escaped_col = col.replace('"', "\"\"").replace("_", "-");
226 clauses.push(format!("{} as \"{}\"", col_name(col), escaped_col));
227 }
228 }
229
230 clauses
231 };
232
233 let mut order_by_clauses: Vec<String> = Vec::new();
234 let mut window_clauses: Vec<String> = Vec::new();
235 let mut where_clauses: Vec<String> = Vec::new();
236
237 if !group_by.is_empty() {
238 for gidx in 0..group_by.len() {
239 let groups = group_by[..=gidx]
240 .iter()
241 .map(|c| col_name(c))
242 .collect::<Vec<_>>()
243 .join(", ");
244
245 if split_by.is_empty() {
246 let grouping_fn = self.0.grouping_fn.as_deref().unwrap_or("GROUPING_ID");
247 order_by_clauses.push(format!("{}({}) DESC", grouping_fn, groups));
248 }
249
250 for Sort(sort_col, sort_dir) in sort {
251 if *sort_dir != SortDir::None {
252 let agg = get_aggregate(sort_col)
253 .map(Self::aggregate_to_string)
254 .unwrap_or_else(|| "any_value".to_string());
255 let dir_str = Self::sort_dir_to_string(sort_dir);
256
257 if gidx >= group_by.len() - 1 {
258 order_by_clauses.push(format!(
259 "{}({}) {}",
260 agg,
261 col_name(sort_col),
262 dir_str
263 ));
264 } else {
265 order_by_clauses.push(format!(
266 "first({}({})) OVER __WINDOW_{}__ {}",
267 agg,
268 col_name(sort_col),
269 gidx,
270 dir_str
271 ));
272 }
273 }
274 }
275
276 order_by_clauses.push(format!("__ROW_PATH_{}__ ASC", gidx));
277 }
278 } else {
279 for Sort(sort_col, sort_dir) in sort {
280 if *sort_dir != SortDir::None {
281 let dir_str = Self::sort_dir_to_string(sort_dir);
282 order_by_clauses.push(format!("{} {}", col_name(sort_col), dir_str));
283 }
284 }
285 }
286
287 if !sort.is_empty() && group_by.len() > 1 {
288 for gidx in 0..(group_by.len() - 1) {
289 let partition = (0..=gidx)
290 .map(|i| format!("__ROW_PATH_{}__", i))
291 .collect::<Vec<_>>()
292 .join(", ");
293
294 let sub_groups = group_by[..=gidx]
295 .iter()
296 .map(|c| col_name(c))
297 .collect::<Vec<_>>()
298 .join(", ");
299
300 let groups = group_by.iter().map(|c| col_name(c)).collect::<Vec<_>>();
301 let grouping_fn = self.0.grouping_fn.as_deref().unwrap_or("GROUPING_ID");
302 window_clauses.push(format!(
303 "__WINDOW_{}__ AS (PARTITION BY {}({}), {} ORDER BY {})",
304 gidx,
305 grouping_fn,
306 sub_groups,
307 partition,
308 groups.join(", ")
309 ));
310 }
311 }
312
313 for flt in filter {
314 let term = Self::filter_term_to_sql(flt.term());
315 if let Some(term_lit) = term {
316 where_clauses.push(format!(
317 "{} {} {}",
318 col_name(flt.column()),
319 flt.op(),
320 term_lit
321 ));
322 }
323 }
324
325 let mut query = if !split_by.is_empty() {
326 format!("SELECT * FROM {}", table_id)
327 } else {
328 let select_clauses = generate_select_clauses();
329 format!("SELECT {} FROM {}", select_clauses.join(", "), table_id)
330 };
331
332 if !where_clauses.is_empty() {
333 query = format!("{} WHERE {}", query, where_clauses.join(" AND "));
334 }
335
336 if !split_by.is_empty() {
337 let groups = group_by.iter().map(|c| col_name(c)).collect::<Vec<_>>();
338 let group_aliases = group_by
339 .iter()
340 .enumerate()
341 .map(|(i, c)| format!("{} AS __ROW_PATH_{}__", col_name(c), i))
342 .collect::<Vec<_>>()
343 .join(", ");
344 let pivot_on = split_by
345 .iter()
346 .map(|c| format!("\"{}\"", c))
347 .collect::<Vec<_>>()
348 .join(", ");
349 let pivot_using = generate_select_clauses().join(", ");
350
351 query = format!(
352 "SELECT * EXCLUDE ({}) , {} FROM (PIVOT ({}) ON {} USING {} GROUP BY {})",
353 groups.join(", "),
354 group_aliases,
355 query,
356 pivot_on,
357 pivot_using,
358 groups.join(", ")
359 );
360 } else if !group_by.is_empty() {
361 let groups = group_by.iter().map(|c| col_name(c)).collect::<Vec<_>>();
362 query = format!("{} GROUP BY ROLLUP({})", query, groups.join(", "));
363 }
364
365 if !window_clauses.is_empty() {
366 query = format!("{} WINDOW {}", query, window_clauses.join(", "));
367 }
368
369 if !order_by_clauses.is_empty() {
370 query = format!("{} ORDER BY {}", query, order_by_clauses.join(", "));
371 }
372
373 let template = self.0.create_entity.as_deref().unwrap_or("TABLE");
374 Ok(format!("CREATE {} {} AS ({})", template, view_id, query))
375 }
376
377 pub fn view_get_data(
388 &self,
389 view_id: &str,
390 config: &ViewConfig,
391 viewport: &ViewPort,
392 schema: &IndexMap<String, ColumnType>,
393 ) -> GenericSQLResult<String> {
394 let group_by = &config.group_by;
395 let split_by = &config.split_by;
396 let start_col = viewport.start_col.unwrap_or(0) as usize;
397 let end_col = viewport.end_col.map(|x| x as usize);
398 let start_row = viewport.start_row.unwrap_or(0);
399 let end_row = viewport.end_row;
400 let limit_clause = if let Some(end) = end_row {
401 format!("LIMIT {} OFFSET {}", end - start_row, start_row)
402 } else {
403 String::new()
404 };
405
406 let data_columns: Vec<&String> = schema
407 .keys()
408 .filter(|col_name| !col_name.starts_with("__"))
409 .skip(start_col)
410 .take(end_col.map(|e| e - start_col).unwrap_or(usize::MAX))
411 .collect();
412
413 let mut group_by_cols: Vec<String> = Vec::new();
414 if !group_by.is_empty() {
415 if split_by.is_empty() {
416 group_by_cols.push("\"__GROUPING_ID__\"".to_string());
417 }
418
419 for idx in 0..group_by.len() {
420 group_by_cols.push(format!("\"__ROW_PATH_{}__\"", idx));
421 }
422 }
423
424 let all_columns: Vec<String> = group_by_cols
425 .into_iter()
426 .chain(data_columns.iter().map(|col| format!("\"{}\"", col)))
427 .collect();
428
429 Ok(format!(
430 "SELECT {} FROM {} {}",
431 all_columns.join(", "),
432 view_id,
433 limit_clause
434 )
435 .trim()
436 .to_string())
437 }
438
439 pub fn view_schema(&self, view_id: &str) -> GenericSQLResult<String> {
447 Ok(format!("DESCRIBE {}", view_id))
448 }
449
450 pub fn view_size(&self, view_id: &str) -> GenericSQLResult<String> {
458 Ok(format!("SELECT COUNT(*) FROM {}", view_id))
459 }
460
461 fn aggregate_to_string(agg: &Aggregate) -> String {
462 match agg {
463 Aggregate::SingleAggregate(name) => name.clone(),
464 Aggregate::MultiAggregate(name, _args) => name.clone(),
465 }
466 }
467
468 fn sort_dir_to_string(dir: &SortDir) -> &'static str {
469 match dir {
470 SortDir::None => "",
471 SortDir::Asc | SortDir::ColAsc | SortDir::AscAbs | SortDir::ColAscAbs => "ASC",
472 SortDir::Desc | SortDir::ColDesc | SortDir::DescAbs | SortDir::ColDescAbs => "DESC",
473 }
474 }
475
476 fn filter_term_to_sql(term: &FilterTerm) -> Option<String> {
477 match term {
478 FilterTerm::Scalar(scalar) => Self::scalar_to_sql(scalar),
479 FilterTerm::Array(scalars) => {
480 let values: Vec<String> = scalars.iter().filter_map(Self::scalar_to_sql).collect();
481 if values.is_empty() {
482 None
483 } else {
484 Some(format!("({})", values.join(", ")))
485 }
486 },
487 }
488 }
489
490 fn scalar_to_sql(scalar: &Scalar) -> Option<String> {
491 match scalar {
492 Scalar::Null => None,
493 Scalar::Bool(b) => Some(if *b { "TRUE" } else { "FALSE" }.to_string()),
494 Scalar::Float(f) => Some(f.to_string()),
495 Scalar::String(s) => Some(format!("'{}'", s.replace('\'', "''"))),
496 }
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: a model built from default (all-`None`) arguments.
    fn default_model() -> GenericSQLVirtualServerModel {
        GenericSQLVirtualServerModel::new(GenericSQLVirtualServerModelArgs::default())
    }

    #[test]
    fn test_get_hosted_tables() {
        let sql = default_model().get_hosted_tables().unwrap();
        assert_eq!(sql, "SHOW ALL TABLES");
    }

    #[test]
    fn test_table_schema() {
        let sql = default_model().table_schema("my_table").unwrap();
        assert_eq!(sql, "DESCRIBE my_table");
    }

    #[test]
    fn test_table_size() {
        let sql = default_model().table_size("my_table").unwrap();
        assert_eq!(sql, "SELECT COUNT(*) FROM my_table");
    }

    #[test]
    fn test_view_delete() {
        let sql = default_model().view_delete("my_view").unwrap();
        assert_eq!(sql, "DROP TABLE IF EXISTS my_view");
    }

    #[test]
    fn test_table_make_view_simple() {
        let mut config = ViewConfig::default();
        config.columns = ["col1", "col2"]
            .iter()
            .map(|name| Some(name.to_string()))
            .collect();

        let sql = default_model()
            .table_make_view("source_table", "dest_view", &config)
            .unwrap();

        assert!(sql.starts_with("CREATE TABLE dest_view AS"));
        for quoted in ["\"col1\"", "\"col2\""] {
            assert!(sql.contains(quoted), "missing {} in {}", quoted, sql);
        }
    }

    #[test]
    fn test_table_make_view_with_group_by() {
        let mut config = ViewConfig::default();
        config.columns = vec![Some("value".to_string())];
        config.group_by = vec!["category".to_string()];

        let sql = default_model()
            .table_make_view("source_table", "dest_view", &config)
            .unwrap();

        for fragment in ["GROUP BY ROLLUP", "__ROW_PATH_0__", "__GROUPING_ID__"] {
            assert!(sql.contains(fragment), "missing {} in {}", fragment, sql);
        }
    }

    #[test]
    fn test_view_get_data() {
        let config = ViewConfig::default();
        let viewport = ViewPort {
            start_row: Some(0),
            end_row: Some(100),
            start_col: Some(0),
            end_col: Some(5),
        };

        let mut schema = IndexMap::new();
        schema.insert("col1".to_string(), ColumnType::String);
        schema.insert("col2".to_string(), ColumnType::Integer);

        let sql = default_model()
            .view_get_data("my_view", &config, &viewport, &schema)
            .unwrap();

        for fragment in ["SELECT", "FROM my_view", "LIMIT 100 OFFSET 0"] {
            assert!(sql.contains(fragment), "missing {} in {}", fragment, sql);
        }
    }
}