1use crate::expr::{Dialect, Expr};
32use sqlmodel_core::Value;
33
34#[derive(Debug, Clone)]
36pub struct Cte {
37 name: String,
39 columns: Vec<String>,
41 recursive: bool,
43 query_sql: String,
45 query_params: Vec<Value>,
47 union_sql: Option<String>,
49 union_params: Vec<Value>,
51}
52
53impl Cte {
54 pub fn new(name: impl Into<String>) -> Self {
67 Self {
68 name: name.into(),
69 columns: Vec::new(),
70 recursive: false,
71 query_sql: String::new(),
72 query_params: Vec::new(),
73 union_sql: None,
74 union_params: Vec::new(),
75 }
76 }
77
78 pub fn recursive(name: impl Into<String>) -> Self {
93 Self {
94 name: name.into(),
95 columns: Vec::new(),
96 recursive: true,
97 query_sql: String::new(),
98 query_params: Vec::new(),
99 union_sql: None,
100 union_params: Vec::new(),
101 }
102 }
103
104 pub fn columns(mut self, cols: &[&str]) -> Self {
114 self.columns = cols.iter().map(|&s| s.to_string()).collect();
115 self
116 }
117
118 pub fn as_select(mut self, sql: impl Into<String>) -> Self {
124 self.query_sql = sql.into();
125 self
126 }
127
128 pub fn as_select_with_params(mut self, sql: impl Into<String>, params: Vec<Value>) -> Self {
135 self.query_sql = sql.into();
136 self.query_params = params;
137 self
138 }
139
140 pub fn union_all(mut self, sql: impl Into<String>) -> Self {
146 self.union_sql = Some(sql.into());
147 self
148 }
149
150 pub fn union_all_with_params(mut self, sql: impl Into<String>, params: Vec<Value>) -> Self {
152 self.union_sql = Some(sql.into());
153 self.union_params = params;
154 self
155 }
156
157 pub fn name(&self) -> &str {
159 &self.name
160 }
161
162 pub fn is_recursive(&self) -> bool {
164 self.recursive
165 }
166
167 pub fn as_ref(&self) -> CteRef {
179 CteRef {
180 name: self.name.clone(),
181 }
182 }
183
184 pub fn build(&self, dialect: Dialect) -> (String, Vec<Value>) {
188 let mut sql = String::new();
189 let mut params = Vec::new();
190
191 sql.push_str(&dialect.quote_identifier(&self.name));
193
194 if !self.columns.is_empty() {
195 sql.push_str(" (");
196 let quoted_cols: Vec<_> = self
197 .columns
198 .iter()
199 .map(|c| dialect.quote_identifier(c))
200 .collect();
201 sql.push_str("ed_cols.join(", "));
202 sql.push(')');
203 }
204
205 sql.push_str(" AS (");
206
207 sql.push_str(&self.query_sql);
209 params.extend(self.query_params.clone());
210
211 if let Some(union) = &self.union_sql {
213 sql.push_str(" UNION ALL ");
214 sql.push_str(union);
215 params.extend(self.union_params.clone());
216 }
217
218 sql.push(')');
219
220 (sql, params)
221 }
222}
223
/// A lightweight, name-only handle to a CTE.
///
/// Used to build column references qualified by the CTE name. The name is
/// the only field, so equality and hashing compare names. That lets handles
/// be compared and used as set or map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CteRef {
    /// The referenced CTE's unquoted name.
    name: String,
}
229
230impl CteRef {
231 pub fn new(name: impl Into<String>) -> Self {
233 Self { name: name.into() }
234 }
235
236 pub fn col(&self, column: impl Into<String>) -> Expr {
245 Expr::qualified(&self.name, column)
246 }
247
248 pub fn name(&self) -> &str {
250 &self.name
251 }
252}
253
254#[derive(Debug, Clone)]
256pub struct WithQuery {
257 ctes: Vec<Cte>,
259 main_sql: String,
261 main_params: Vec<Value>,
263}
264
265impl WithQuery {
266 pub fn new() -> Self {
268 Self {
269 ctes: Vec::new(),
270 main_sql: String::new(),
271 main_params: Vec::new(),
272 }
273 }
274
275 pub fn with_cte(mut self, cte: Cte) -> Self {
279 self.ctes.push(cte);
280 self
281 }
282
283 pub fn with_ctes(mut self, ctes: Vec<Cte>) -> Self {
285 self.ctes.extend(ctes);
286 self
287 }
288
289 pub fn select(mut self, sql: impl Into<String>) -> Self {
291 self.main_sql = sql.into();
292 self
293 }
294
295 pub fn select_with_params(mut self, sql: impl Into<String>, params: Vec<Value>) -> Self {
297 self.main_sql = sql.into();
298 self.main_params = params;
299 self
300 }
301
302 pub fn build(&self) -> (String, Vec<Value>) {
304 self.build_with_dialect(Dialect::Postgres)
305 }
306
307 pub fn build_with_dialect(&self, dialect: Dialect) -> (String, Vec<Value>) {
309 let mut sql = String::new();
310 let mut params = Vec::new();
311
312 if !self.ctes.is_empty() {
313 let has_recursive = self.ctes.iter().any(|c| c.recursive);
315
316 if has_recursive {
317 sql.push_str("WITH RECURSIVE ");
318 } else {
319 sql.push_str("WITH ");
320 }
321
322 let cte_sqls: Vec<String> = self
324 .ctes
325 .iter()
326 .map(|cte| {
327 let (cte_sql, cte_params) = cte.build(dialect);
328 params.extend(cte_params);
329 cte_sql
330 })
331 .collect();
332
333 sql.push_str(&cte_sqls.join(", "));
334 sql.push(' ');
335 }
336
337 sql.push_str(&self.main_sql);
339 params.extend(self.main_params.clone());
340
341 (sql, params)
342 }
343}
344
345impl Default for WithQuery {
346 fn default() -> Self {
347 Self::new()
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_cte() {
        let cte = Cte::new("active_users").as_select("SELECT * FROM users WHERE active = true");

        let (sql, params) = cte.build(Dialect::Postgres);
        assert_eq!(
            sql,
            "\"active_users\" AS (SELECT * FROM users WHERE active = true)"
        );
        assert!(params.is_empty());
    }

    #[test]
    fn test_cte_with_columns() {
        let cte = Cte::new("user_totals")
            .columns(&["user_id", "total"])
            .as_select("SELECT user_id, SUM(amount) FROM orders GROUP BY user_id");

        let (sql, params) = cte.build(Dialect::Postgres);
        assert_eq!(
            sql,
            "\"user_totals\" (\"user_id\", \"total\") AS (SELECT user_id, SUM(amount) FROM orders GROUP BY user_id)"
        );
        assert!(params.is_empty());
    }

    #[test]
    fn test_cte_with_params() {
        let cte = Cte::new("recent_orders").as_select_with_params(
            "SELECT * FROM orders WHERE amount > $1",
            vec![Value::Int(100)],
        );

        let (sql, params) = cte.build(Dialect::Postgres);
        assert_eq!(
            sql,
            "\"recent_orders\" AS (SELECT * FROM orders WHERE amount > $1)"
        );
        assert_eq!(params, vec![Value::Int(100)]);
    }

    #[test]
    fn test_recursive_cte() {
        let cte = Cte::recursive("hierarchy")
            .columns(&["id", "name", "level"])
            .as_select("SELECT id, name, 0 FROM employees WHERE manager_id IS NULL")
            .union_all("SELECT e.id, e.name, h.level + 1 FROM employees e JOIN hierarchy h ON e.manager_id = h.id");

        let (sql, params) = cte.build(Dialect::Postgres);
        // Pin the exact layout: anchor query, ` UNION ALL `, then the recursive step.
        assert_eq!(
            sql,
            "\"hierarchy\" (\"id\", \"name\", \"level\") AS (SELECT id, name, 0 FROM employees WHERE manager_id IS NULL UNION ALL SELECT e.id, e.name, h.level + 1 FROM employees e JOIN hierarchy h ON e.manager_id = h.id)"
        );
        assert!(params.is_empty());
        assert!(cte.is_recursive());
    }

    #[test]
    fn test_recursive_cte_union_params_follow_query_params() {
        let cte = Cte::recursive("counter")
            .columns(&["n"])
            .as_select_with_params("SELECT $1", vec![Value::Int(1)])
            .union_all_with_params(
                "SELECT n + 1 FROM counter WHERE n < $2",
                vec![Value::Int(10)],
            );

        let (sql, params) = cte.build(Dialect::Postgres);
        assert_eq!(
            sql,
            "\"counter\" (\"n\") AS (SELECT $1 UNION ALL SELECT n + 1 FROM counter WHERE n < $2)"
        );
        assert_eq!(params, vec![Value::Int(1), Value::Int(10)]);
    }

    #[test]
    fn test_non_recursive_cte_accessors() {
        let cte = Cte::new("totals");
        assert_eq!(cte.name(), "totals");
        assert!(!cte.is_recursive());
        assert_eq!(cte.as_ref().name(), "totals");
    }

    #[test]
    fn test_cte_ref_column() {
        let cte_ref = CteRef::new("my_cte");
        let expr = cte_ref.col("name");

        let mut params = Vec::new();
        let sql = expr.build(&mut params, 0);
        assert_eq!(sql, "\"my_cte\".\"name\"");
    }

    #[test]
    fn test_with_query_single_cte() {
        let cte = Cte::new("active_users").as_select("SELECT * FROM users WHERE active = true");

        let query = WithQuery::new()
            .with_cte(cte)
            .select("SELECT * FROM active_users");

        let (sql, params) = query.build();
        assert_eq!(
            sql,
            "WITH \"active_users\" AS (SELECT * FROM users WHERE active = true) SELECT * FROM active_users"
        );
        assert!(params.is_empty());
    }

    #[test]
    fn test_with_query_multiple_ctes() {
        let cte1 = Cte::new("active_users").as_select("SELECT * FROM users WHERE active = true");

        let cte2 = Cte::new("user_orders")
            .as_select("SELECT u.id, COUNT(*) as order_count FROM active_users u JOIN orders o ON u.id = o.user_id GROUP BY u.id");

        let query = WithQuery::new()
            .with_cte(cte1)
            .with_cte(cte2)
            .select("SELECT * FROM user_orders WHERE order_count > 5");

        let (sql, _) = query.build();
        assert!(sql.starts_with("WITH "));
        assert!(sql.contains("\"active_users\" AS"));
        assert!(sql.contains("\"user_orders\" AS"));
        // CTEs are separated by ", " and the main statement follows after a single space.
        assert!(sql.contains("active = true), \"user_orders\" AS ("));
        assert!(sql.ends_with(") SELECT * FROM user_orders WHERE order_count > 5"));
    }

    #[test]
    fn test_with_query_recursive() {
        let cte = Cte::recursive("numbers")
            .columns(&["n"])
            .as_select("SELECT 1")
            .union_all("SELECT n + 1 FROM numbers WHERE n < 10");

        let query = WithQuery::new()
            .with_cte(cte)
            .select("SELECT * FROM numbers");

        let (sql, _) = query.build();
        assert!(sql.starts_with("WITH RECURSIVE "));
    }

    #[test]
    fn test_with_query_any_recursive_cte_makes_clause_recursive() {
        let plain = Cte::new("a").as_select("SELECT 1");
        let rec = Cte::recursive("b")
            .columns(&["n"])
            .as_select("SELECT 1")
            .union_all("SELECT n + 1 FROM b WHERE n < 3");

        let (sql, _) = WithQuery::new()
            .with_ctes(vec![plain, rec])
            .select("SELECT * FROM b")
            .build();
        assert_eq!(
            sql,
            "WITH RECURSIVE \"a\" AS (SELECT 1), \"b\" (\"n\") AS (SELECT 1 UNION ALL SELECT n + 1 FROM b WHERE n < 3) SELECT * FROM b"
        );
    }

    #[test]
    fn test_with_query_without_ctes_is_just_main_statement() {
        let (sql, params) = WithQuery::new()
            .select_with_params("SELECT $1", vec![Value::Int(7)])
            .build();
        assert_eq!(sql, "SELECT $1");
        assert_eq!(params, vec![Value::Int(7)]);
    }

    #[test]
    fn test_with_query_default_is_empty() {
        let (sql, params) = WithQuery::default().build();
        assert_eq!(sql, "");
        assert!(params.is_empty());
    }

    #[test]
    fn test_cte_mysql_dialect() {
        let cte = Cte::new("temp")
            .columns(&["col1", "col2"])
            .as_select("SELECT a, b FROM t");

        let (sql, _) = cte.build(Dialect::Mysql);
        assert_eq!(sql, "`temp` (`col1`, `col2`) AS (SELECT a, b FROM t)");
    }

    #[test]
    fn test_with_query_mysql_dialect() {
        let (sql, _) = WithQuery::new()
            .with_cte(Cte::new("t").as_select("SELECT 1"))
            .select("SELECT * FROM t")
            .build_with_dialect(Dialect::Mysql);
        assert_eq!(sql, "WITH `t` AS (SELECT 1) SELECT * FROM t");
    }

    #[test]
    fn test_cte_sqlite_dialect() {
        let cte = Cte::new("temp").as_select("SELECT 1");

        let (sql, _) = cte.build(Dialect::Sqlite);
        assert_eq!(sql, "\"temp\" AS (SELECT 1)");
    }

    #[test]
    fn test_with_query_params_aggregation() {
        let cte = Cte::new("filtered")
            .as_select_with_params("SELECT * FROM items WHERE price > $1", vec![Value::Int(50)]);

        let query = WithQuery::new().with_cte(cte).select_with_params(
            "SELECT * FROM filtered WHERE category = $2",
            vec![Value::Text("electronics".to_string())],
        );

        let (sql, params) = query.build();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0], Value::Int(50));
        assert_eq!(params[1], Value::Text("electronics".to_string()));
        assert!(sql.contains("$1"));
        assert!(sql.contains("$2"));
    }

    #[test]
    fn test_with_query_params_order_across_ctes() {
        let first = Cte::new("a").as_select_with_params("SELECT $1", vec![Value::Int(1)]);
        let second = Cte::new("b").as_select_with_params("SELECT $2", vec![Value::Int(2)]);

        let (sql, params) = WithQuery::new()
            .with_ctes(vec![first, second])
            .select_with_params("SELECT * FROM a, b WHERE $3 > 0", vec![Value::Int(3)])
            .build();

        assert_eq!(
            sql,
            "WITH \"a\" AS (SELECT $1), \"b\" AS (SELECT $2) SELECT * FROM a, b WHERE $3 > 0"
        );
        assert_eq!(params, vec![Value::Int(1), Value::Int(2), Value::Int(3)]);
    }

    #[test]
    fn test_recursive_cte_hierarchy_example() {
        let cte = Cte::recursive("org_chart")
            .columns(&["id", "name", "manager_id", "level"])
            .as_select("SELECT id, name, manager_id, 0 AS level FROM employees WHERE manager_id IS NULL")
            .union_all("SELECT e.id, e.name, e.manager_id, oc.level + 1 FROM employees e INNER JOIN org_chart oc ON e.manager_id = oc.id");

        let query = WithQuery::new()
            .with_cte(cte)
            .select("SELECT * FROM org_chart ORDER BY level, name");

        let (sql, _) = query.build();

        assert!(sql.starts_with("WITH RECURSIVE "));
        assert!(sql.contains("\"org_chart\""));
        assert!(sql.contains("UNION ALL"));
        assert!(sql.contains("ORDER BY level, name"));
    }

    #[test]
    fn test_cte_chained_references() {
        let cte1 =
            Cte::new("base_data").as_select("SELECT id, value FROM raw_data WHERE valid = true");

        let cte2 = Cte::new("aggregated")
            .as_select("SELECT COUNT(*) as cnt, SUM(value) as total FROM base_data");

        let query = WithQuery::new()
            .with_cte(cte1)
            .with_cte(cte2)
            .select("SELECT * FROM aggregated");

        let (sql, _) = query.build();

        let base_pos = sql.find("\"base_data\"").unwrap();
        let agg_pos = sql.find("\"aggregated\"").unwrap();
        assert!(
            base_pos < agg_pos,
            "base_data should come before aggregated"
        );
    }
}