prax_query/advanced.rs

//! Advanced query features.
//!
//! This module provides advanced SQL query capabilities including:
//! - LATERAL joins (correlated subqueries)
//! - DISTINCT ON
//! - RETURNING/OUTPUT clauses
//! - Row locking (FOR UPDATE/SHARE)
//! - TABLESAMPLE
//! - Bulk operations
//!
//! # Database Support
//!
//! | Feature           | PostgreSQL     | MySQL    | SQLite | MSSQL           | MongoDB      |
//! |-------------------|----------------|----------|--------|-----------------|--------------|
//! | LATERAL joins     | ✅             | ✅       | ❌     | ✅ CROSS APPLY  | ✅ $lookup   |
//! | DISTINCT ON       | ✅             | ❌       | ❌     | ❌              | ✅ $first    |
//! | RETURNING/OUTPUT  | ✅             | ❌       | ✅     | ✅ OUTPUT       | ✅           |
//! | FOR UPDATE/SHARE  | ✅             | ✅       | ❌     | ✅ WITH UPDLOCK | ❌           |
//! | TABLESAMPLE       | ✅             | ❌       | ❌     | ✅              | ✅ $sample   |
//! | Bulk operations   | ✅             | ✅       | ✅     | ✅              | ✅ bulkWrite |

use serde::{Deserialize, Serialize};

use crate::error::{QueryError, QueryResult};
use crate::sql::DatabaseType;

// ============================================================================
// LATERAL Joins
// ============================================================================

/// A LATERAL join specification.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct LateralJoin {
    /// The subquery or function call.
    pub subquery: String,
    /// Alias for the lateral result.
    pub alias: String,
    /// Join type.
    pub join_type: LateralJoinType,
    /// Optional ON condition (for LEFT LATERAL).
    pub condition: Option<String>,
}

/// LATERAL join type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LateralJoinType {
    /// CROSS JOIN LATERAL / CROSS APPLY.
    Cross,
    /// LEFT JOIN LATERAL / OUTER APPLY.
    Left,
}

impl LateralJoin {
    /// Create a new LATERAL join.
    pub fn new(subquery: impl Into<String>, alias: impl Into<String>) -> LateralJoinBuilder {
        LateralJoinBuilder::new(subquery, alias)
    }

    /// Generate PostgreSQL LATERAL join.
    pub fn to_postgres_sql(&self) -> String {
        match self.join_type {
            LateralJoinType::Cross => {
                format!("CROSS JOIN LATERAL ({}) AS {}", self.subquery, self.alias)
            }
            LateralJoinType::Left => {
                let cond = self.condition.as_deref().unwrap_or("TRUE");
                format!(
                    "LEFT JOIN LATERAL ({}) AS {} ON {}",
                    self.subquery, self.alias, cond
                )
            }
        }
    }

    /// Generate MySQL LATERAL join.
    pub fn to_mysql_sql(&self) -> String {
        match self.join_type {
            LateralJoinType::Cross => {
                format!("CROSS JOIN LATERAL ({}) AS {}", self.subquery, self.alias)
            }
            LateralJoinType::Left => {
                let cond = self.condition.as_deref().unwrap_or("TRUE");
                format!(
                    "LEFT JOIN LATERAL ({}) AS {} ON {}",
                    self.subquery, self.alias, cond
                )
            }
        }
    }

    /// Generate MSSQL APPLY join.
    pub fn to_mssql_sql(&self) -> String {
        match self.join_type {
            LateralJoinType::Cross => {
                format!("CROSS APPLY ({}) AS {}", self.subquery, self.alias)
            }
            LateralJoinType::Left => {
                format!("OUTER APPLY ({}) AS {}", self.subquery, self.alias)
            }
        }
    }

    /// Generate SQL for database type.
    pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
        match db_type {
            DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
            DatabaseType::MySQL => Ok(self.to_mysql_sql()),
            DatabaseType::MSSQL => Ok(self.to_mssql_sql()),
            DatabaseType::SQLite => Err(QueryError::unsupported(
                "LATERAL joins are not supported in SQLite",
            )),
        }
    }
}
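
// Illustrative usage sketch: a correlated "top N per row" subquery built with
// `LateralJoin::new`, showing how one specification renders on each backend and
// how the SQLite case surfaces as an error from `to_sql`. Assumes `QueryResult`
// is the usual `Result<T, QueryError>` alias from `crate::error`. Compiled only
// under `cfg(test)`.
#[cfg(test)]
mod lateral_join_usage {
    use super::*;

    #[test]
    fn renders_per_database() {
        let lateral = LateralJoin::new(
            "SELECT * FROM orders o WHERE o.user_id = users.id ORDER BY o.created_at DESC LIMIT 3",
            "recent_orders",
        )
        .left()
        .on("TRUE")
        .build();

        // PostgreSQL and MySQL share the LEFT JOIN LATERAL spelling.
        assert!(lateral
            .to_sql(DatabaseType::PostgreSQL)
            .unwrap()
            .starts_with("LEFT JOIN LATERAL"));
        assert!(lateral
            .to_sql(DatabaseType::MySQL)
            .unwrap()
            .starts_with("LEFT JOIN LATERAL"));

        // MSSQL maps the same specification onto OUTER APPLY (no ON clause).
        assert!(lateral
            .to_sql(DatabaseType::MSSQL)
            .unwrap()
            .starts_with("OUTER APPLY"));

        // SQLite has no LATERAL support, so dispatch returns an error.
        assert!(lateral.to_sql(DatabaseType::SQLite).is_err());
    }
}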

/// Builder for LATERAL joins.
#[derive(Debug, Clone)]
pub struct LateralJoinBuilder {
    subquery: String,
    alias: String,
    join_type: LateralJoinType,
    condition: Option<String>,
}

impl LateralJoinBuilder {
    /// Create a new builder.
    pub fn new(subquery: impl Into<String>, alias: impl Into<String>) -> Self {
        Self {
            subquery: subquery.into(),
            alias: alias.into(),
            join_type: LateralJoinType::Cross,
            condition: None,
        }
    }

    /// Make this a LEFT LATERAL join.
    pub fn left(mut self) -> Self {
        self.join_type = LateralJoinType::Left;
        self
    }

    /// Make this a CROSS LATERAL join (default).
    pub fn cross(mut self) -> Self {
        self.join_type = LateralJoinType::Cross;
        self
    }

    /// Set the ON condition.
    pub fn on(mut self, condition: impl Into<String>) -> Self {
        self.condition = Some(condition.into());
        self
    }

    /// Build the LATERAL join.
    pub fn build(self) -> LateralJoin {
        LateralJoin {
            subquery: self.subquery,
            alias: self.alias,
            join_type: self.join_type,
            condition: self.condition,
        }
    }
}

// ============================================================================
// DISTINCT ON
// ============================================================================

/// DISTINCT ON clause (PostgreSQL specific).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DistinctOn {
    /// Columns to distinct on.
    pub columns: Vec<String>,
}

impl DistinctOn {
    /// Create a new DISTINCT ON clause.
    pub fn new<I, S>(columns: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            columns: columns.into_iter().map(Into::into).collect(),
        }
    }

    /// Generate PostgreSQL DISTINCT ON clause.
    pub fn to_postgres_sql(&self) -> String {
        format!("DISTINCT ON ({})", self.columns.join(", "))
    }

    /// Generate MySQL workaround using GROUP BY.
    /// Note: This is not exactly equivalent to DISTINCT ON.
    pub fn to_mysql_workaround(&self) -> String {
        format!(
            "-- MySQL workaround: Use GROUP BY {} with appropriate aggregates",
            self.columns.join(", ")
        )
    }
}

/// MongoDB $first aggregation helper for DISTINCT ON behavior.
pub mod mongodb_distinct {
    use serde_json::Value as JsonValue;

    /// Generate $group stage that mimics DISTINCT ON.
    pub fn distinct_on_stage(group_fields: &[&str], first_fields: &[&str]) -> JsonValue {
        let mut group_id = serde_json::Map::new();
        for field in group_fields {
            group_id.insert(field.to_string(), serde_json::json!(format!("${}", field)));
        }

        let mut group_spec = serde_json::Map::new();
        group_spec.insert("_id".to_string(), serde_json::json!(group_id));

        for field in first_fields {
            group_spec.insert(
                field.to_string(),
                serde_json::json!({ "$first": format!("${}", field) }),
            );
        }

        serde_json::json!({ "$group": group_spec })
    }
}
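
// Illustrative sketch: "latest row per department" expressed two ways with the
// helpers above. How the DISTINCT ON fragment is spliced into a full SELECT is
// up to the caller; the statement shown in the comment is an assumption about
// typical usage, not an API of this module. Compiled only under `cfg(test)`.
#[cfg(test)]
mod distinct_on_usage {
    use super::*;

    #[test]
    fn postgres_fragment_and_mongo_stage() {
        // PostgreSQL: SELECT DISTINCT ON (department) * FROM employees
        //             ORDER BY department, hired_at DESC;
        let distinct = DistinctOn::new(["department"]);
        assert_eq!(distinct.to_postgres_sql(), "DISTINCT ON (department)");

        // MongoDB: the equivalent $group keeps the $first value per department;
        // a preceding $sort stage decides which document counts as "first".
        let stage = mongodb_distinct::distinct_on_stage(&["department"], &["name", "hired_at"]);
        assert_eq!(stage["$group"]["_id"]["department"], "$department");
        assert_eq!(stage["$group"]["name"]["$first"], "$name");
    }
}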

// ============================================================================
// RETURNING / OUTPUT Clause
// ============================================================================

/// RETURNING clause specification.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Returning {
    /// Columns to return.
    pub columns: Vec<ReturningColumn>,
    /// Operation type (for MSSQL OUTPUT).
    pub operation: ReturnOperation,
}

/// A column in the RETURNING clause.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReturningColumn {
    /// All columns (*).
    All,
    /// Specific column name.
    Column(String),
    /// Expression with alias.
    Expression { expr: String, alias: String },
    /// MSSQL INSERTED.column.
    Inserted(String),
    /// MSSQL DELETED.column.
    Deleted(String),
}

/// Operation type for RETURNING/OUTPUT.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReturnOperation {
    /// INSERT statement.
    Insert,
    /// UPDATE statement.
    Update,
    /// DELETE statement.
    Delete,
}

impl Returning {
    /// Create RETURNING all columns.
    pub fn all(operation: ReturnOperation) -> Self {
        Self {
            columns: vec![ReturningColumn::All],
            operation,
        }
    }

    /// Create RETURNING specific columns.
    pub fn columns<I, S>(operation: ReturnOperation, columns: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            columns: columns.into_iter().map(|c| ReturningColumn::Column(c.into())).collect(),
            operation,
        }
    }

    /// Generate PostgreSQL RETURNING clause.
    pub fn to_postgres_sql(&self) -> String {
        let cols = self.format_columns(DatabaseType::PostgreSQL);
        format!("RETURNING {}", cols)
    }

    /// Generate SQLite RETURNING clause.
    pub fn to_sqlite_sql(&self) -> String {
        let cols = self.format_columns(DatabaseType::SQLite);
        format!("RETURNING {}", cols)
    }

    /// Generate MSSQL OUTPUT clause.
    pub fn to_mssql_sql(&self) -> String {
        let cols = self.format_columns(DatabaseType::MSSQL);
        format!("OUTPUT {}", cols)
    }

    /// Format columns for database.
    fn format_columns(&self, db_type: DatabaseType) -> String {
        self.columns
            .iter()
            .map(|col| match col {
                ReturningColumn::All => {
                    if db_type == DatabaseType::MSSQL {
                        match self.operation {
                            ReturnOperation::Insert => "INSERTED.*".to_string(),
                            ReturnOperation::Delete => "DELETED.*".to_string(),
                            ReturnOperation::Update => "INSERTED.*".to_string(),
                        }
                    } else {
                        "*".to_string()
                    }
                }
                ReturningColumn::Column(name) => {
                    if db_type == DatabaseType::MSSQL {
                        match self.operation {
                            ReturnOperation::Insert => format!("INSERTED.{}", name),
                            ReturnOperation::Delete => format!("DELETED.{}", name),
                            ReturnOperation::Update => format!("INSERTED.{}", name),
                        }
                    } else {
                        name.clone()
                    }
                }
                ReturningColumn::Expression { expr, alias } => format!("{} AS {}", expr, alias),
                ReturningColumn::Inserted(name) => format!("INSERTED.{}", name),
                ReturningColumn::Deleted(name) => format!("DELETED.{}", name),
            })
            .collect::<Vec<_>>()
            .join(", ")
    }

    /// Generate SQL for database type.
    pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
        match db_type {
            DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
            DatabaseType::SQLite => Ok(self.to_sqlite_sql()),
            DatabaseType::MSSQL => Ok(self.to_mssql_sql()),
            DatabaseType::MySQL => Err(QueryError::unsupported(
                "RETURNING clause is not supported in MySQL. Consider using LAST_INSERT_ID() or a separate SELECT.",
            )),
        }
    }
}
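
// Illustrative sketch: the same `Returning` value renders as RETURNING on
// PostgreSQL/SQLite and as OUTPUT DELETED.* columns on MSSQL, while MySQL is
// rejected at `to_sql` time. Assumes `QueryResult` is the usual `Result` alias
// from `crate::error`. Compiled only under `cfg(test)`.
#[cfg(test)]
mod returning_usage {
    use super::*;

    #[test]
    fn dispatches_per_database() {
        let ret = Returning::columns(ReturnOperation::Delete, ["id", "deleted_at"]);

        assert_eq!(
            ret.to_sql(DatabaseType::PostgreSQL).unwrap(),
            "RETURNING id, deleted_at"
        );
        assert_eq!(
            ret.to_sql(DatabaseType::MSSQL).unwrap(),
            "OUTPUT DELETED.id, DELETED.deleted_at"
        );
        // MySQL has no RETURNING; callers fall back to LAST_INSERT_ID() or a
        // follow-up SELECT, so this arm is an error by design.
        assert!(ret.to_sql(DatabaseType::MySQL).is_err());
    }
}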

// ============================================================================
// Row Locking (FOR UPDATE / FOR SHARE)
// ============================================================================

/// Row locking specification.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RowLock {
    /// Lock strength.
    pub strength: LockStrength,
    /// Tables to lock (optional).
    pub of_tables: Vec<String>,
    /// Wait behavior.
    pub wait: LockWait,
}

/// Lock strength.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LockStrength {
    /// FOR UPDATE - exclusive lock.
    Update,
    /// FOR NO KEY UPDATE - exclusive, but allows concurrent key reads.
    NoKeyUpdate,
    /// FOR SHARE - shared lock.
    Share,
    /// FOR KEY SHARE - shared key lock.
    KeyShare,
}

/// Lock wait behavior.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LockWait {
    /// Wait for the lock (default).
    Wait,
    /// Don't wait; error immediately if a row is locked.
    NoWait,
    /// Skip locked rows.
    SkipLocked,
}

impl RowLock {
    /// Create a FOR UPDATE lock.
    pub fn for_update() -> RowLockBuilder {
        RowLockBuilder::new(LockStrength::Update)
    }

    /// Create a FOR SHARE lock.
    pub fn for_share() -> RowLockBuilder {
        RowLockBuilder::new(LockStrength::Share)
    }

    /// Create a FOR NO KEY UPDATE lock.
    pub fn for_no_key_update() -> RowLockBuilder {
        RowLockBuilder::new(LockStrength::NoKeyUpdate)
    }

    /// Create a FOR KEY SHARE lock.
    pub fn for_key_share() -> RowLockBuilder {
        RowLockBuilder::new(LockStrength::KeyShare)
    }

    /// Generate PostgreSQL FOR clause.
    pub fn to_postgres_sql(&self) -> String {
        let strength = match self.strength {
            LockStrength::Update => "FOR UPDATE",
            LockStrength::NoKeyUpdate => "FOR NO KEY UPDATE",
            LockStrength::Share => "FOR SHARE",
            LockStrength::KeyShare => "FOR KEY SHARE",
        };

        let mut sql = strength.to_string();

        if !self.of_tables.is_empty() {
            sql.push_str(&format!(" OF {}", self.of_tables.join(", ")));
        }

        match self.wait {
            LockWait::Wait => {}
            LockWait::NoWait => sql.push_str(" NOWAIT"),
            LockWait::SkipLocked => sql.push_str(" SKIP LOCKED"),
        }

        sql
    }

    /// Generate MySQL FOR clause.
    ///
    /// MySQL only distinguishes FOR UPDATE and FOR SHARE, so the finer-grained
    /// PostgreSQL strengths are mapped to the nearest equivalent.
    pub fn to_mysql_sql(&self) -> String {
        let strength = match self.strength {
            LockStrength::Update | LockStrength::NoKeyUpdate => "FOR UPDATE",
            LockStrength::Share | LockStrength::KeyShare => "FOR SHARE",
        };

        let mut sql = strength.to_string();

        if !self.of_tables.is_empty() {
            sql.push_str(&format!(" OF {}", self.of_tables.join(", ")));
        }

        match self.wait {
            LockWait::Wait => {}
            LockWait::NoWait => sql.push_str(" NOWAIT"),
            LockWait::SkipLocked => sql.push_str(" SKIP LOCKED"),
        }

        sql
    }

    /// Generate MSSQL table hint.
    pub fn to_mssql_hint(&self) -> String {
        let hint = match self.strength {
            LockStrength::Update | LockStrength::NoKeyUpdate => "UPDLOCK, ROWLOCK",
            LockStrength::Share | LockStrength::KeyShare => "HOLDLOCK, ROWLOCK",
        };

        let wait_hint = match self.wait {
            LockWait::Wait => "",
            LockWait::NoWait => ", NOWAIT",
            LockWait::SkipLocked => ", READPAST",
        };

        format!("WITH ({}{})", hint, wait_hint)
    }

    /// Generate SQL for database type.
    pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
        match db_type {
            DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
            DatabaseType::MySQL => Ok(self.to_mysql_sql()),
            DatabaseType::MSSQL => Ok(self.to_mssql_hint()),
            DatabaseType::SQLite => Err(QueryError::unsupported(
                "Row locking is not supported in SQLite",
            )),
        }
    }
}
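
// Illustrative sketch: a queue-consumer style lock. One specification gets a
// backend-specific rendering; note that the MSSQL form is a table hint that
// sits next to the table name rather than at the end of the SELECT. The `jobs`
// table in the comment is hypothetical. Compiled only under `cfg(test)`.
#[cfg(test)]
mod row_lock_usage {
    use super::*;

    #[test]
    fn renders_per_database() {
        let lock = RowLock::for_no_key_update().skip_locked().build();

        // PostgreSQL keeps the precise strength.
        assert_eq!(lock.to_postgres_sql(), "FOR NO KEY UPDATE SKIP LOCKED");
        // MySQL collapses NO KEY UPDATE to plain FOR UPDATE.
        assert_eq!(lock.to_mysql_sql(), "FOR UPDATE SKIP LOCKED");
        // MSSQL expresses the same intent as UPDLOCK + READPAST hints,
        // e.g. `SELECT ... FROM jobs WITH (UPDLOCK, ROWLOCK, READPAST)`.
        assert_eq!(lock.to_mssql_hint(), "WITH (UPDLOCK, ROWLOCK, READPAST)");
    }
}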

/// Builder for row locks.
#[derive(Debug, Clone)]
pub struct RowLockBuilder {
    strength: LockStrength,
    of_tables: Vec<String>,
    wait: LockWait,
}

impl RowLockBuilder {
    /// Create a new builder.
    pub fn new(strength: LockStrength) -> Self {
        Self {
            strength,
            of_tables: Vec::new(),
            wait: LockWait::Wait,
        }
    }

    /// Lock specific tables.
    pub fn of<I, S>(mut self, tables: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.of_tables = tables.into_iter().map(Into::into).collect();
        self
    }

    /// NOWAIT - error immediately if locked.
    pub fn nowait(mut self) -> Self {
        self.wait = LockWait::NoWait;
        self
    }

    /// SKIP LOCKED - skip locked rows.
    pub fn skip_locked(mut self) -> Self {
        self.wait = LockWait::SkipLocked;
        self
    }

    /// Build the row lock.
    pub fn build(self) -> RowLock {
        RowLock {
            strength: self.strength,
            of_tables: self.of_tables,
            wait: self.wait,
        }
    }
}

// ============================================================================
// TABLESAMPLE
// ============================================================================

/// Table sampling configuration.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TableSample {
    /// Sampling method.
    pub method: SampleMethod,
    /// Sample size (percentage or rows).
    pub size: SampleSize,
    /// Optional seed for reproducibility.
    pub seed: Option<i64>,
}

/// Sampling method.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SampleMethod {
    /// BERNOULLI - row-level random sampling.
    Bernoulli,
    /// SYSTEM - page-level random sampling (faster, less accurate).
    System,
}

/// Sample size specification.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum SampleSize {
    /// Percentage of rows (0-100).
    Percent(f64),
    /// Approximate number of rows.
    Rows(usize),
}

impl TableSample {
    /// Create a percentage sample using BERNOULLI.
    pub fn percent(percent: f64) -> TableSampleBuilder {
        TableSampleBuilder::new(SampleMethod::Bernoulli, SampleSize::Percent(percent))
    }

    /// Create a row count sample.
    pub fn rows(count: usize) -> TableSampleBuilder {
        TableSampleBuilder::new(SampleMethod::System, SampleSize::Rows(count))
    }

    /// Generate PostgreSQL TABLESAMPLE clause.
    pub fn to_postgres_sql(&self) -> String {
        let method = match self.method {
            SampleMethod::Bernoulli => "BERNOULLI",
            SampleMethod::System => "SYSTEM",
        };

        let size = match self.size {
            SampleSize::Percent(p) => format!("{}", p),
            SampleSize::Rows(_) => {
                // PostgreSQL's TABLESAMPLE only accepts a percentage, not a row
                // count, so fall back to a 10% sample.
                "10".to_string()
            }
        };

        let mut sql = format!("TABLESAMPLE {} ({})", method, size);

        if let Some(seed) = self.seed {
            sql.push_str(&format!(" REPEATABLE ({})", seed));
        }

        sql
    }

    /// Generate MSSQL TABLESAMPLE clause.
    pub fn to_mssql_sql(&self) -> String {
        let size_clause = match self.size {
            SampleSize::Percent(p) => format!("{} PERCENT", p),
            SampleSize::Rows(n) => format!("{} ROWS", n),
        };

        let mut sql = format!("TABLESAMPLE ({})", size_clause);

        if let Some(seed) = self.seed {
            sql.push_str(&format!(" REPEATABLE ({})", seed));
        }

        sql
    }

    /// Generate SQL for database type.
    pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
        match db_type {
            DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
            DatabaseType::MSSQL => Ok(self.to_mssql_sql()),
            DatabaseType::MySQL | DatabaseType::SQLite => Err(QueryError::unsupported(
                "TABLESAMPLE is not supported in this database. Use ORDER BY RANDOM() LIMIT instead.",
            )),
        }
    }
}
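
// Illustrative sketch: the seed makes a sample repeatable across runs, and a
// row-count size only translates literally on MSSQL; the 10% PostgreSQL
// fallback shown here is a limitation of the helper above, not of PostgreSQL.
// Compiled only under `cfg(test)`.
#[cfg(test)]
mod table_sample_usage {
    use super::*;

    #[test]
    fn percent_and_rows() {
        let by_percent = TableSample::percent(1.5).system().seed(7).build();
        assert_eq!(by_percent.to_postgres_sql(), "TABLESAMPLE SYSTEM (1.5) REPEATABLE (7)");
        assert_eq!(by_percent.to_mssql_sql(), "TABLESAMPLE (1.5 PERCENT) REPEATABLE (7)");

        let by_rows = TableSample::rows(500).build();
        assert_eq!(by_rows.to_mssql_sql(), "TABLESAMPLE (500 ROWS)");
        assert_eq!(by_rows.to_postgres_sql(), "TABLESAMPLE SYSTEM (10)");
    }
}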

/// Builder for table sampling.
#[derive(Debug, Clone)]
pub struct TableSampleBuilder {
    method: SampleMethod,
    size: SampleSize,
    seed: Option<i64>,
}

impl TableSampleBuilder {
    /// Create a new builder.
    pub fn new(method: SampleMethod, size: SampleSize) -> Self {
        Self {
            method,
            size,
            seed: None,
        }
    }

    /// Use BERNOULLI sampling.
    pub fn bernoulli(mut self) -> Self {
        self.method = SampleMethod::Bernoulli;
        self
    }

    /// Use SYSTEM sampling.
    pub fn system(mut self) -> Self {
        self.method = SampleMethod::System;
        self
    }

    /// Set seed for reproducibility.
    pub fn seed(mut self, seed: i64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Build the sample configuration.
    pub fn build(self) -> TableSample {
        TableSample {
            method: self.method,
            size: self.size,
            seed: self.seed,
        }
    }
}

/// Random sampling alternatives for unsupported databases.
pub mod random_sample {
    use super::*;

    /// Generate an ORDER BY random-function clause with a row limit for
    /// databases without TABLESAMPLE.
    pub fn order_by_random_sql(limit: usize, db_type: DatabaseType) -> String {
        match db_type {
            DatabaseType::PostgreSQL | DatabaseType::SQLite => {
                format!("ORDER BY RANDOM() LIMIT {}", limit)
            }
            DatabaseType::MySQL => format!("ORDER BY RAND() LIMIT {}", limit),
            // MSSQL has no LIMIT keyword; use OFFSET/FETCH instead.
            DatabaseType::MSSQL => {
                format!("ORDER BY NEWID() OFFSET 0 ROWS FETCH NEXT {} ROWS ONLY", limit)
            }
        }
    }

    /// Generate a WHERE random-function < threshold predicate for row sampling.
    pub fn where_random_sql(threshold: f64, db_type: DatabaseType) -> String {
        match db_type {
            DatabaseType::PostgreSQL | DatabaseType::SQLite => {
                format!("WHERE RANDOM() < {}", threshold)
            }
            DatabaseType::MySQL => format!("WHERE RAND() < {}", threshold),
            DatabaseType::MSSQL => {
                format!("WHERE ABS(CHECKSUM(NEWID())) % 100 < {}", (threshold * 100.0) as i32)
            }
        }
    }
}
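
// Illustrative sketch: the fallback helpers for engines without TABLESAMPLE.
// The full statements in the comments are assumptions about how a caller would
// splice these fragments in; the helpers themselves only return the trailing
// clauses. Compiled only under `cfg(test)`.
#[cfg(test)]
mod random_sample_usage {
    use super::*;

    #[test]
    fn fallback_clauses() {
        // e.g. `SELECT * FROM events ORDER BY RAND() LIMIT 100` on MySQL.
        assert_eq!(
            random_sample::order_by_random_sql(100, DatabaseType::MySQL),
            "ORDER BY RAND() LIMIT 100"
        );
        // MSSQL spells the row limit with OFFSET/FETCH.
        assert_eq!(
            random_sample::order_by_random_sql(100, DatabaseType::MSSQL),
            "ORDER BY NEWID() OFFSET 0 ROWS FETCH NEXT 100 ROWS ONLY"
        );
        // Roughly 5% of rows, e.g. `SELECT * FROM events WHERE RANDOM() < 0.05`.
        assert_eq!(
            random_sample::where_random_sql(0.05, DatabaseType::PostgreSQL),
            "WHERE RANDOM() < 0.05"
        );
    }
}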

// ============================================================================
// Bulk Operations
// ============================================================================

/// Bulk operation configuration.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BulkOperation<T> {
    /// Items to process.
    pub items: Vec<T>,
    /// Batch size.
    pub batch_size: usize,
    /// Whether batches must execute in order, stopping at the first error.
    pub ordered: bool,
}

impl<T> BulkOperation<T> {
    /// Create a new bulk operation.
    pub fn new(items: Vec<T>) -> Self {
        Self {
            items,
            batch_size: 1000,
            ordered: true,
        }
    }

    /// Set the batch size (must be non-zero).
    pub fn batch_size(mut self, size: usize) -> Self {
        self.batch_size = size;
        self
    }

    /// Allow unordered execution (continue past errors).
    pub fn unordered(mut self) -> Self {
        self.ordered = false;
        self
    }

    /// Get batches.
    pub fn batches(&self) -> impl Iterator<Item = &[T]> {
        self.items.chunks(self.batch_size)
    }

    /// Get the number of batches.
    pub fn batch_count(&self) -> usize {
        (self.items.len() + self.batch_size - 1) / self.batch_size
    }
}
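
// Illustrative sketch: driving batched multi-row INSERTs from `batches()`. The
// `users (name)` table, the `$n` placeholder style, and the surrounding
// executor are hypothetical; the point is that each chunk becomes one statement
// with one parameter list. Compiled only under `cfg(test)`.
#[cfg(test)]
mod bulk_operation_usage {
    use super::*;

    #[test]
    fn builds_one_statement_per_batch() {
        let names = vec!["ada", "grace", "edsger", "barbara", "donald"];
        let bulk = BulkOperation::new(names).batch_size(2);

        let statements: Vec<String> = bulk
            .batches()
            .map(|batch| {
                let placeholders: Vec<String> =
                    (1..=batch.len()).map(|i| format!("(${})", i)).collect();
                format!("INSERT INTO users (name) VALUES {}", placeholders.join(", "))
            })
            .collect();

        assert_eq!(statements.len(), bulk.batch_count());
        assert_eq!(statements[0], "INSERT INTO users (name) VALUES ($1), ($2)");
        assert_eq!(statements[2], "INSERT INTO users (name) VALUES ($1)");
    }
}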

/// MongoDB bulkWrite operations.
pub mod mongodb {
    use serde::{Deserialize, Serialize};
    use serde_json::Value as JsonValue;

    /// A single bulk write operation.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub enum BulkWriteOp {
        /// Insert one document.
        InsertOne { document: JsonValue },
        /// Update one document.
        UpdateOne {
            filter: JsonValue,
            update: JsonValue,
            upsert: bool,
        },
        /// Update many documents.
        UpdateMany {
            filter: JsonValue,
            update: JsonValue,
            upsert: bool,
        },
        /// Replace one document.
        ReplaceOne {
            filter: JsonValue,
            replacement: JsonValue,
            upsert: bool,
        },
        /// Delete one document.
        DeleteOne { filter: JsonValue },
        /// Delete many documents.
        DeleteMany { filter: JsonValue },
    }

    impl BulkWriteOp {
        /// Create an insert operation.
        pub fn insert_one(document: JsonValue) -> Self {
            Self::InsertOne { document }
        }

        /// Create an update-one operation.
        pub fn update_one(filter: JsonValue, update: JsonValue) -> Self {
            Self::UpdateOne {
                filter,
                update,
                upsert: false,
            }
        }

        /// Create an upsert operation.
        pub fn upsert_one(filter: JsonValue, update: JsonValue) -> Self {
            Self::UpdateOne {
                filter,
                update,
                upsert: true,
            }
        }

        /// Create a delete-one operation.
        pub fn delete_one(filter: JsonValue) -> Self {
            Self::DeleteOne { filter }
        }

        /// Convert to MongoDB format.
        pub fn to_command(&self) -> JsonValue {
            match self {
                Self::InsertOne { document } => {
                    serde_json::json!({ "insertOne": { "document": document } })
                }
                Self::UpdateOne { filter, update, upsert } => {
                    serde_json::json!({
                        "updateOne": {
                            "filter": filter,
                            "update": update,
                            "upsert": upsert
                        }
                    })
                }
                Self::UpdateMany { filter, update, upsert } => {
                    serde_json::json!({
                        "updateMany": {
                            "filter": filter,
                            "update": update,
                            "upsert": upsert
                        }
                    })
                }
                Self::ReplaceOne { filter, replacement, upsert } => {
                    serde_json::json!({
                        "replaceOne": {
                            "filter": filter,
                            "replacement": replacement,
                            "upsert": upsert
                        }
                    })
                }
                Self::DeleteOne { filter } => {
                    serde_json::json!({ "deleteOne": { "filter": filter } })
                }
                Self::DeleteMany { filter } => {
                    serde_json::json!({ "deleteMany": { "filter": filter } })
                }
            }
        }
    }

    /// Bulk write builder.
    #[derive(Debug, Clone)]
    pub struct BulkWriteBuilder {
        operations: Vec<BulkWriteOp>,
        ordered: bool,
        bypass_validation: bool,
    }

    impl Default for BulkWriteBuilder {
        /// Matches `new()`: ordered execution with document validation enabled.
        fn default() -> Self {
            Self::new()
        }
    }

    impl BulkWriteBuilder {
        /// Create a new builder.
        pub fn new() -> Self {
            Self {
                operations: Vec::new(),
                ordered: true,
                bypass_validation: false,
            }
        }

        /// Add an operation.
        pub fn add(mut self, op: BulkWriteOp) -> Self {
            self.operations.push(op);
            self
        }

        /// Add multiple operations.
        pub fn add_many<I>(mut self, ops: I) -> Self
        where
            I: IntoIterator<Item = BulkWriteOp>,
        {
            self.operations.extend(ops);
            self
        }

        /// Insert one document.
        pub fn insert_one(self, document: JsonValue) -> Self {
            self.add(BulkWriteOp::insert_one(document))
        }

        /// Update one document.
        pub fn update_one(self, filter: JsonValue, update: JsonValue) -> Self {
            self.add(BulkWriteOp::update_one(filter, update))
        }

        /// Upsert one document.
        pub fn upsert_one(self, filter: JsonValue, update: JsonValue) -> Self {
            self.add(BulkWriteOp::upsert_one(filter, update))
        }

        /// Delete one document.
        pub fn delete_one(self, filter: JsonValue) -> Self {
            self.add(BulkWriteOp::delete_one(filter))
        }

        /// Set unordered execution.
        pub fn unordered(mut self) -> Self {
            self.ordered = false;
            self
        }

        /// Bypass document validation.
        pub fn bypass_validation(mut self) -> Self {
            self.bypass_validation = true;
            self
        }

        /// Build the bulkWrite command.
        pub fn build(&self, collection: &str) -> JsonValue {
            let ops: Vec<JsonValue> = self.operations.iter().map(|op| op.to_command()).collect();

            serde_json::json!({
                "bulkWrite": collection,
                "operations": ops,
                "ordered": self.ordered,
                "bypassDocumentValidation": self.bypass_validation
            })
        }
    }

    /// MongoDB $sample aggregation stage.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub struct Sample {
        /// Number of documents to sample.
        pub size: usize,
    }

    impl Sample {
        /// Create a new sample stage.
        pub fn new(size: usize) -> Self {
            Self { size }
        }

        /// Convert to aggregation stage.
        pub fn to_stage(&self) -> JsonValue {
            serde_json::json!({ "$sample": { "size": self.size } })
        }
    }
}
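
// Illustrative sketch: composing the JSON helpers above into an aggregation
// pipeline ($sample followed by a DISTINCT ON-style $group). How the pipeline
// array is handed to a driver is out of scope here and left as an assumption.
// Compiled only under `cfg(test)`.
#[cfg(test)]
mod mongodb_pipeline_usage {
    use super::mongodb::Sample;
    use super::mongodb_distinct::distinct_on_stage;

    #[test]
    fn sample_then_distinct_pipeline() {
        let pipeline = vec![
            Sample::new(1000).to_stage(),
            distinct_on_stage(&["department"], &["name"]),
        ];

        assert_eq!(pipeline[0]["$sample"]["size"], 1000);
        assert_eq!(pipeline[1]["$group"]["_id"]["department"], "$department");
    }
}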

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lateral_join_postgres() {
        let lateral = LateralJoin::new(
            "SELECT * FROM orders WHERE orders.user_id = users.id LIMIT 3",
            "recent_orders",
        )
        .build();

        let sql = lateral.to_postgres_sql();
        assert!(sql.contains("CROSS JOIN LATERAL"));
        assert!(sql.contains("AS recent_orders"));
    }

    #[test]
    fn test_lateral_join_mssql() {
        let lateral = LateralJoin::new(
            "SELECT TOP 3 * FROM orders WHERE orders.user_id = users.id",
            "recent_orders",
        )
        .left()
        .build();

        let sql = lateral.to_mssql_sql();
        assert!(sql.contains("OUTER APPLY"));
    }

    #[test]
    fn test_distinct_on() {
        let distinct = DistinctOn::new(["department", "date"]);
        let sql = distinct.to_postgres_sql();

        assert!(sql.contains("DISTINCT ON (department, date)"));
    }

    #[test]
    fn test_returning_postgres() {
        let ret = Returning::all(ReturnOperation::Insert);
        let sql = ret.to_postgres_sql();

        assert_eq!(sql, "RETURNING *");
    }

    #[test]
    fn test_returning_mssql() {
        let ret = Returning::columns(ReturnOperation::Insert, ["id", "name"]);
        let sql = ret.to_mssql_sql();

        assert!(sql.contains("OUTPUT INSERTED.id, INSERTED.name"));
    }

    #[test]
    fn test_for_update() {
        let lock = RowLock::for_update().nowait().build();
        let sql = lock.to_postgres_sql();

        assert!(sql.contains("FOR UPDATE"));
        assert!(sql.contains("NOWAIT"));
    }

    #[test]
    fn test_for_share_skip_locked() {
        let lock = RowLock::for_share().skip_locked().build();
        let sql = lock.to_postgres_sql();

        assert!(sql.contains("FOR SHARE"));
        assert!(sql.contains("SKIP LOCKED"));
    }

    #[test]
    fn test_row_lock_mssql() {
        let lock = RowLock::for_update().nowait().build();
        let sql = lock.to_mssql_hint();

        assert!(sql.contains("UPDLOCK"));
        assert!(sql.contains("NOWAIT"));
    }

    #[test]
    fn test_tablesample_postgres() {
        let sample = TableSample::percent(10.0).seed(42).build();
        let sql = sample.to_postgres_sql();

        assert!(sql.contains("TABLESAMPLE BERNOULLI (10)"));
        assert!(sql.contains("REPEATABLE (42)"));
    }

    #[test]
    fn test_tablesample_mssql() {
        let sample = TableSample::rows(1000).build();
        let sql = sample.to_mssql_sql();

        assert!(sql.contains("TABLESAMPLE (1000 ROWS)"));
    }

    #[test]
    fn test_bulk_operation_batches() {
        let bulk: BulkOperation<i32> = BulkOperation::new(vec![1, 2, 3, 4, 5]).batch_size(2);

        assert_eq!(bulk.batch_count(), 3);
        let batches: Vec<_> = bulk.batches().collect();
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0], &[1, 2]);
        assert_eq!(batches[1], &[3, 4]);
        assert_eq!(batches[2], &[5]);
    }

    mod mongodb_tests {
        use super::super::mongodb::*;

        #[test]
        fn test_bulk_write_builder() {
            let bulk = BulkWriteBuilder::new()
                .insert_one(serde_json::json!({ "name": "Alice" }))
                .update_one(
                    serde_json::json!({ "_id": 1 }),
                    serde_json::json!({ "$set": { "status": "active" } }),
                )
                .delete_one(serde_json::json!({ "_id": 2 }))
                .unordered()
                .build("users");

            assert_eq!(bulk["bulkWrite"], "users");
            assert_eq!(bulk["ordered"], false);
            assert!(bulk["operations"].is_array());
            assert_eq!(bulk["operations"].as_array().unwrap().len(), 3);
        }

        #[test]
        fn test_sample_stage() {
            let sample = Sample::new(100);
            let stage = sample.to_stage();

            assert_eq!(stage["$sample"]["size"], 100);
        }

        #[test]
        fn test_bulk_write_upsert() {
            let op = BulkWriteOp::upsert_one(
                serde_json::json!({ "email": "test@example.com" }),
                serde_json::json!({ "$set": { "name": "Test" } }),
            );

            let cmd = op.to_command();
            assert!(cmd["updateOne"]["upsert"].as_bool().unwrap());
        }
    }

    mod distinct_on_tests {
        use super::super::mongodb_distinct::*;

        #[test]
        fn test_distinct_on_stage() {
            let stage = distinct_on_stage(&["department"], &["name", "salary"]);

            assert!(stage["$group"]["_id"]["department"].is_string());
            assert!(stage["$group"]["name"]["$first"].is_string());
            assert!(stage["$group"]["salary"]["$first"].is_string());
        }
    }
}