// prax_query/advanced.rs
1//! Advanced query features.
2//!
3//! This module provides advanced SQL query capabilities including:
4//! - LATERAL joins (correlated subqueries)
5//! - DISTINCT ON
6//! - RETURNING/OUTPUT clauses
7//! - Row locking (FOR UPDATE/SHARE)
8//! - TABLESAMPLE
9//! - Bulk operations
10//!
11//! # Database Support
12//!
13//! | Feature           | PostgreSQL     | MySQL    | SQLite | MSSQL           | MongoDB      |
14//! |-------------------|----------------|----------|--------|-----------------|--------------|
15//! | LATERAL joins     | ✅             | ✅       | ❌     | ✅ CROSS APPLY  | ✅ $lookup   |
16//! | DISTINCT ON       | ✅             | ❌       | ❌     | ❌              | ✅ $first    |
17//! | RETURNING/OUTPUT  | ✅             | ❌       | ✅     | ✅ OUTPUT       | ✅           |
18//! | FOR UPDATE/SHARE  | ✅             | ✅       | ❌     | ✅ WITH UPDLOCK | ❌           |
19//! | TABLESAMPLE       | ✅             | ❌       | ❌     | ✅              | ✅ $sample   |
20//! | Bulk operations   | ✅             | ✅       | ✅     | ✅              | ✅ bulkWrite |
21
22use serde::{Deserialize, Serialize};
23
24use crate::error::{QueryError, QueryResult};
25use crate::sql::DatabaseType;
26
27// ============================================================================
28// LATERAL Joins
29// ============================================================================
30
/// A LATERAL join specification.
///
/// The subquery text is rendered verbatim inside parentheses and may
/// reference columns of tables appearing earlier in the FROM clause.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct LateralJoin {
    /// The subquery or function call (emitted as-is, not validated).
    pub subquery: String,
    /// Alias for the lateral result.
    pub alias: String,
    /// Join type.
    pub join_type: LateralJoinType,
    /// Optional ON condition. Only used for `Left` joins on PostgreSQL and
    /// MySQL; ignored for `Cross` joins and for MSSQL APPLY, which takes no
    /// ON clause.
    pub condition: Option<String>,
}
43
/// LATERAL join type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LateralJoinType {
    /// CROSS JOIN LATERAL (PostgreSQL/MySQL) / CROSS APPLY (MSSQL).
    Cross,
    /// LEFT JOIN LATERAL (PostgreSQL/MySQL) / OUTER APPLY (MSSQL).
    Left,
}
52
53impl LateralJoin {
54    /// Create a new LATERAL join.
55    pub fn new(subquery: impl Into<String>, alias: impl Into<String>) -> LateralJoinBuilder {
56        LateralJoinBuilder::new(subquery, alias)
57    }
58
59    /// Generate PostgreSQL LATERAL join.
60    pub fn to_postgres_sql(&self) -> String {
61        match self.join_type {
62            LateralJoinType::Cross => {
63                format!("CROSS JOIN LATERAL ({}) AS {}", self.subquery, self.alias)
64            }
65            LateralJoinType::Left => {
66                let cond = self.condition.as_deref().unwrap_or("TRUE");
67                format!(
68                    "LEFT JOIN LATERAL ({}) AS {} ON {}",
69                    self.subquery, self.alias, cond
70                )
71            }
72        }
73    }
74
75    /// Generate MySQL LATERAL join.
76    pub fn to_mysql_sql(&self) -> String {
77        match self.join_type {
78            LateralJoinType::Cross => {
79                format!("CROSS JOIN LATERAL ({}) AS {}", self.subquery, self.alias)
80            }
81            LateralJoinType::Left => {
82                let cond = self.condition.as_deref().unwrap_or("TRUE");
83                format!(
84                    "LEFT JOIN LATERAL ({}) AS {} ON {}",
85                    self.subquery, self.alias, cond
86                )
87            }
88        }
89    }
90
91    /// Generate MSSQL APPLY join.
92    pub fn to_mssql_sql(&self) -> String {
93        match self.join_type {
94            LateralJoinType::Cross => {
95                format!("CROSS APPLY ({}) AS {}", self.subquery, self.alias)
96            }
97            LateralJoinType::Left => {
98                format!("OUTER APPLY ({}) AS {}", self.subquery, self.alias)
99            }
100        }
101    }
102
103    /// Generate SQL for database type.
104    pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
105        match db_type {
106            DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
107            DatabaseType::MySQL => Ok(self.to_mysql_sql()),
108            DatabaseType::MSSQL => Ok(self.to_mssql_sql()),
109            DatabaseType::SQLite => Err(QueryError::unsupported(
110                "LATERAL joins are not supported in SQLite",
111            )),
112        }
113    }
114}
115
/// Builder for [`LateralJoin`].
#[derive(Debug, Clone)]
pub struct LateralJoinBuilder {
    // Raw subquery text, copied verbatim into the built join.
    subquery: String,
    // Alias for the lateral result set.
    alias: String,
    // Defaults to `LateralJoinType::Cross` (see `new`).
    join_type: LateralJoinType,
    // Optional ON condition; only meaningful for LEFT lateral joins.
    condition: Option<String>,
}
124
125impl LateralJoinBuilder {
126    /// Create a new builder.
127    pub fn new(subquery: impl Into<String>, alias: impl Into<String>) -> Self {
128        Self {
129            subquery: subquery.into(),
130            alias: alias.into(),
131            join_type: LateralJoinType::Cross,
132            condition: None,
133        }
134    }
135
136    /// Make this a LEFT LATERAL join.
137    pub fn left(mut self) -> Self {
138        self.join_type = LateralJoinType::Left;
139        self
140    }
141
142    /// Make this a CROSS LATERAL join (default).
143    pub fn cross(mut self) -> Self {
144        self.join_type = LateralJoinType::Cross;
145        self
146    }
147
148    /// Set the ON condition.
149    pub fn on(mut self, condition: impl Into<String>) -> Self {
150        self.condition = Some(condition.into());
151        self
152    }
153
154    /// Build the LATERAL join.
155    pub fn build(self) -> LateralJoin {
156        LateralJoin {
157            subquery: self.subquery,
158            alias: self.alias,
159            join_type: self.join_type,
160            condition: self.condition,
161        }
162    }
163}
164
165// ============================================================================
166// DISTINCT ON
167// ============================================================================
168
/// DISTINCT ON clause (PostgreSQL specific).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DistinctOn {
    /// Columns to distinct on (rendered comma-separated, in this order).
    pub columns: Vec<String>,
}
175
176impl DistinctOn {
177    /// Create a new DISTINCT ON clause.
178    pub fn new<I, S>(columns: I) -> Self
179    where
180        I: IntoIterator<Item = S>,
181        S: Into<String>,
182    {
183        Self {
184            columns: columns.into_iter().map(Into::into).collect(),
185        }
186    }
187
188    /// Generate PostgreSQL DISTINCT ON clause.
189    pub fn to_postgres_sql(&self) -> String {
190        format!("DISTINCT ON ({})", self.columns.join(", "))
191    }
192
193    /// Generate MySQL workaround using GROUP BY.
194    /// Note: This is not exactly equivalent to DISTINCT ON.
195    pub fn to_mysql_workaround(&self) -> String {
196        format!(
197            "-- MySQL workaround: Use GROUP BY {} with appropriate aggregates",
198            self.columns.join(", ")
199        )
200    }
201}
202
203/// MongoDB $first aggregation helper for DISTINCT ON behavior.
204pub mod mongodb_distinct {
205    use serde_json::Value as JsonValue;
206
207    /// Generate $group stage that mimics DISTINCT ON.
208    pub fn distinct_on_stage(group_fields: &[&str], first_fields: &[&str]) -> JsonValue {
209        let mut group_id = serde_json::Map::new();
210        for field in group_fields {
211            group_id.insert(field.to_string(), serde_json::json!(format!("${}", field)));
212        }
213
214        let mut group_spec = serde_json::Map::new();
215        group_spec.insert("_id".to_string(), serde_json::json!(group_id));
216
217        for field in first_fields {
218            group_spec.insert(
219                field.to_string(),
220                serde_json::json!({ "$first": format!("${}", field) }),
221            );
222        }
223
224        serde_json::json!({ "$group": group_spec })
225    }
226}
227
228// ============================================================================
229// RETURNING / OUTPUT Clause
230// ============================================================================
231
/// RETURNING clause specification.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Returning {
    /// Columns to return.
    pub columns: Vec<ReturningColumn>,
    /// Operation type; selects the INSERTED/DELETED pseudo-table when
    /// rendering an MSSQL OUTPUT clause.
    pub operation: ReturnOperation,
}
240
/// A column in the RETURNING clause.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReturningColumn {
    /// All columns (*).
    All,
    /// Specific column name.
    Column(String),
    /// Expression with alias, rendered as `<expr> AS <alias>`.
    Expression {
        /// Raw SQL expression (emitted verbatim).
        expr: String,
        /// Alias for the expression.
        alias: String,
    },
    /// MSSQL INSERTED.column (always rendered with the INSERTED prefix,
    /// regardless of operation).
    Inserted(String),
    /// MSSQL DELETED.column (always rendered with the DELETED prefix,
    /// regardless of operation).
    Deleted(String),
}
255
/// Operation type for RETURNING/OUTPUT.
///
/// MSSQL's OUTPUT clause needs the statement kind to choose between the
/// INSERTED and DELETED pseudo-tables; PostgreSQL and SQLite ignore it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReturnOperation {
    /// INSERT statement (MSSQL OUTPUT reads INSERTED).
    Insert,
    /// UPDATE statement (MSSQL OUTPUT reads INSERTED — the new values).
    Update,
    /// DELETE statement (MSSQL OUTPUT reads DELETED).
    Delete,
}
263
264impl Returning {
265    /// Create RETURNING all columns.
266    pub fn all(operation: ReturnOperation) -> Self {
267        Self {
268            columns: vec![ReturningColumn::All],
269            operation,
270        }
271    }
272
273    /// Create RETURNING specific columns.
274    pub fn columns<I, S>(operation: ReturnOperation, columns: I) -> Self
275    where
276        I: IntoIterator<Item = S>,
277        S: Into<String>,
278    {
279        Self {
280            columns: columns
281                .into_iter()
282                .map(|c| ReturningColumn::Column(c.into()))
283                .collect(),
284            operation,
285        }
286    }
287
288    /// Generate PostgreSQL RETURNING clause.
289    pub fn to_postgres_sql(&self) -> String {
290        let cols = self.format_columns(DatabaseType::PostgreSQL);
291        format!("RETURNING {}", cols)
292    }
293
294    /// Generate SQLite RETURNING clause.
295    pub fn to_sqlite_sql(&self) -> String {
296        let cols = self.format_columns(DatabaseType::SQLite);
297        format!("RETURNING {}", cols)
298    }
299
300    /// Generate MSSQL OUTPUT clause.
301    pub fn to_mssql_sql(&self) -> String {
302        let cols = self.format_columns(DatabaseType::MSSQL);
303        format!("OUTPUT {}", cols)
304    }
305
306    /// Format columns for database.
307    fn format_columns(&self, db_type: DatabaseType) -> String {
308        self.columns
309            .iter()
310            .map(|col| match col {
311                ReturningColumn::All => {
312                    if db_type == DatabaseType::MSSQL {
313                        match self.operation {
314                            ReturnOperation::Insert => "INSERTED.*".to_string(),
315                            ReturnOperation::Delete => "DELETED.*".to_string(),
316                            ReturnOperation::Update => "INSERTED.*".to_string(),
317                        }
318                    } else {
319                        "*".to_string()
320                    }
321                }
322                ReturningColumn::Column(name) => {
323                    if db_type == DatabaseType::MSSQL {
324                        match self.operation {
325                            ReturnOperation::Insert => format!("INSERTED.{}", name),
326                            ReturnOperation::Delete => format!("DELETED.{}", name),
327                            ReturnOperation::Update => format!("INSERTED.{}", name),
328                        }
329                    } else {
330                        name.clone()
331                    }
332                }
333                ReturningColumn::Expression { expr, alias } => format!("{} AS {}", expr, alias),
334                ReturningColumn::Inserted(name) => format!("INSERTED.{}", name),
335                ReturningColumn::Deleted(name) => format!("DELETED.{}", name),
336            })
337            .collect::<Vec<_>>()
338            .join(", ")
339    }
340
341    /// Generate SQL for database type.
342    pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
343        match db_type {
344            DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
345            DatabaseType::SQLite => Ok(self.to_sqlite_sql()),
346            DatabaseType::MSSQL => Ok(self.to_mssql_sql()),
347            DatabaseType::MySQL => Err(QueryError::unsupported(
348                "RETURNING clause is not supported in MySQL. Consider using LAST_INSERT_ID() or separate SELECT.",
349            )),
350        }
351    }
352}
353
354// ============================================================================
355// Row Locking (FOR UPDATE / FOR SHARE)
356// ============================================================================
357
/// Row locking mode.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RowLock {
    /// Lock strength.
    pub strength: LockStrength,
    /// Tables to lock; when empty, no `OF ...` clause is emitted.
    pub of_tables: Vec<String>,
    /// Wait behavior.
    pub wait: LockWait,
}
368
/// Lock strength.
///
/// The four strengths map 1:1 onto PostgreSQL's locking clauses; MySQL and
/// MSSQL collapse them into two (see `RowLock::to_mysql_sql` /
/// `RowLock::to_mssql_hint`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LockStrength {
    /// FOR UPDATE - exclusive lock.
    Update,
    /// FOR NO KEY UPDATE - exclusive but allows key reads.
    NoKeyUpdate,
    /// FOR SHARE - shared lock.
    Share,
    /// FOR KEY SHARE - shared key lock.
    KeyShare,
}
381
/// Lock wait behavior.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LockWait {
    /// Wait for the lock (default; emits no extra SQL).
    Wait,
    /// Don't wait; error immediately if locked (`NOWAIT`).
    NoWait,
    /// Skip locked rows (`SKIP LOCKED`, or `READPAST` on MSSQL).
    SkipLocked,
}
392
393impl RowLock {
394    /// Create FOR UPDATE lock.
395    pub fn for_update() -> RowLockBuilder {
396        RowLockBuilder::new(LockStrength::Update)
397    }
398
399    /// Create FOR SHARE lock.
400    pub fn for_share() -> RowLockBuilder {
401        RowLockBuilder::new(LockStrength::Share)
402    }
403
404    /// Create FOR NO KEY UPDATE lock.
405    pub fn for_no_key_update() -> RowLockBuilder {
406        RowLockBuilder::new(LockStrength::NoKeyUpdate)
407    }
408
409    /// Create FOR KEY SHARE lock.
410    pub fn for_key_share() -> RowLockBuilder {
411        RowLockBuilder::new(LockStrength::KeyShare)
412    }
413
414    /// Generate PostgreSQL FOR clause.
415    pub fn to_postgres_sql(&self) -> String {
416        let strength = match self.strength {
417            LockStrength::Update => "FOR UPDATE",
418            LockStrength::NoKeyUpdate => "FOR NO KEY UPDATE",
419            LockStrength::Share => "FOR SHARE",
420            LockStrength::KeyShare => "FOR KEY SHARE",
421        };
422
423        let mut sql = strength.to_string();
424
425        if !self.of_tables.is_empty() {
426            sql.push_str(&format!(" OF {}", self.of_tables.join(", ")));
427        }
428
429        match self.wait {
430            LockWait::Wait => {}
431            LockWait::NoWait => sql.push_str(" NOWAIT"),
432            LockWait::SkipLocked => sql.push_str(" SKIP LOCKED"),
433        }
434
435        sql
436    }
437
438    /// Generate MySQL FOR clause.
439    pub fn to_mysql_sql(&self) -> String {
440        let strength = match self.strength {
441            LockStrength::Update | LockStrength::NoKeyUpdate => "FOR UPDATE",
442            LockStrength::Share | LockStrength::KeyShare => "FOR SHARE",
443        };
444
445        let mut sql = strength.to_string();
446
447        if !self.of_tables.is_empty() {
448            sql.push_str(&format!(" OF {}", self.of_tables.join(", ")));
449        }
450
451        match self.wait {
452            LockWait::Wait => {}
453            LockWait::NoWait => sql.push_str(" NOWAIT"),
454            LockWait::SkipLocked => sql.push_str(" SKIP LOCKED"),
455        }
456
457        sql
458    }
459
460    /// Generate MSSQL table hint.
461    pub fn to_mssql_hint(&self) -> String {
462        let hint = match self.strength {
463            LockStrength::Update | LockStrength::NoKeyUpdate => "UPDLOCK, ROWLOCK",
464            LockStrength::Share | LockStrength::KeyShare => "HOLDLOCK, ROWLOCK",
465        };
466
467        let wait_hint = match self.wait {
468            LockWait::Wait => "",
469            LockWait::NoWait => ", NOWAIT",
470            LockWait::SkipLocked => ", READPAST",
471        };
472
473        format!("WITH ({}{})", hint, wait_hint)
474    }
475
476    /// Generate SQL for database type.
477    pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
478        match db_type {
479            DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
480            DatabaseType::MySQL => Ok(self.to_mysql_sql()),
481            DatabaseType::MSSQL => Ok(self.to_mssql_hint()),
482            DatabaseType::SQLite => Err(QueryError::unsupported(
483                "Row locking is not supported in SQLite",
484            )),
485        }
486    }
487}
488
/// Builder for [`RowLock`].
#[derive(Debug, Clone)]
pub struct RowLockBuilder {
    // Lock strength, fixed at construction.
    strength: LockStrength,
    // Tables for the `OF ...` clause; empty by default.
    of_tables: Vec<String>,
    // Defaults to `LockWait::Wait` (see `new`).
    wait: LockWait,
}
496
497impl RowLockBuilder {
498    /// Create a new builder.
499    pub fn new(strength: LockStrength) -> Self {
500        Self {
501            strength,
502            of_tables: Vec::new(),
503            wait: LockWait::Wait,
504        }
505    }
506
507    /// Lock specific tables.
508    pub fn of<I, S>(mut self, tables: I) -> Self
509    where
510        I: IntoIterator<Item = S>,
511        S: Into<String>,
512    {
513        self.of_tables = tables.into_iter().map(Into::into).collect();
514        self
515    }
516
517    /// NOWAIT - error immediately if locked.
518    pub fn nowait(mut self) -> Self {
519        self.wait = LockWait::NoWait;
520        self
521    }
522
523    /// SKIP LOCKED - skip locked rows.
524    pub fn skip_locked(mut self) -> Self {
525        self.wait = LockWait::SkipLocked;
526        self
527    }
528
529    /// Build the row lock.
530    pub fn build(self) -> RowLock {
531        RowLock {
532            strength: self.strength,
533            of_tables: self.of_tables,
534            wait: self.wait,
535        }
536    }
537}
538
539// ============================================================================
540// TABLESAMPLE
541// ============================================================================
542
/// Table sampling configuration.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TableSample {
    /// Sampling method (used by PostgreSQL only; the MSSQL rendering has no
    /// method keyword and ignores this field).
    pub method: SampleMethod,
    /// Sample size (percentage or rows).
    pub size: SampleSize,
    /// Optional seed for reproducibility, emitted as `REPEATABLE (seed)`.
    pub seed: Option<i64>,
}
553
/// Sampling method.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SampleMethod {
    /// BERNOULLI - row-level random sampling.
    Bernoulli,
    /// SYSTEM - page-level random sampling (faster, less accurate).
    System,
}
562
/// Sample size specification.
///
/// Carries an `f64`, which is why this type (and [`TableSample`]) derives
/// `PartialEq` but not `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum SampleSize {
    /// Percentage of rows (0-100).
    Percent(f64),
    /// Approximate number of rows. Honored by MSSQL; the PostgreSQL
    /// rendering cannot express it and falls back to a fixed 10% sample.
    Rows(usize),
}
571
572impl TableSample {
573    /// Create a percentage sample using BERNOULLI.
574    pub fn percent(percent: f64) -> TableSampleBuilder {
575        TableSampleBuilder::new(SampleMethod::Bernoulli, SampleSize::Percent(percent))
576    }
577
578    /// Create a row count sample.
579    pub fn rows(count: usize) -> TableSampleBuilder {
580        TableSampleBuilder::new(SampleMethod::System, SampleSize::Rows(count))
581    }
582
583    /// Generate PostgreSQL TABLESAMPLE clause.
584    pub fn to_postgres_sql(&self) -> String {
585        let method = match self.method {
586            SampleMethod::Bernoulli => "BERNOULLI",
587            SampleMethod::System => "SYSTEM",
588        };
589
590        let size = match self.size {
591            SampleSize::Percent(p) => format!("{}", p),
592            SampleSize::Rows(_) => {
593                // PostgreSQL doesn't support row counts directly
594                "10".to_string() // Default to 10%
595            }
596        };
597
598        let mut sql = format!("TABLESAMPLE {} ({})", method, size);
599
600        if let Some(seed) = self.seed {
601            sql.push_str(&format!(" REPEATABLE ({})", seed));
602        }
603
604        sql
605    }
606
607    /// Generate MSSQL TABLESAMPLE clause.
608    pub fn to_mssql_sql(&self) -> String {
609        let size_clause = match self.size {
610            SampleSize::Percent(p) => format!("{} PERCENT", p),
611            SampleSize::Rows(n) => format!("{} ROWS", n),
612        };
613
614        let mut sql = format!("TABLESAMPLE ({})", size_clause);
615
616        if let Some(seed) = self.seed {
617            sql.push_str(&format!(" REPEATABLE ({})", seed));
618        }
619
620        sql
621    }
622
623    /// Generate SQL for database type.
624    pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
625        match db_type {
626            DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
627            DatabaseType::MSSQL => Ok(self.to_mssql_sql()),
628            DatabaseType::MySQL | DatabaseType::SQLite => Err(QueryError::unsupported(
629                "TABLESAMPLE is not supported in this database. Use ORDER BY RANDOM() LIMIT instead.",
630            )),
631        }
632    }
633}
634
/// Builder for [`TableSample`].
#[derive(Debug, Clone)]
pub struct TableSampleBuilder {
    // Sampling method; overridable via `bernoulli()` / `system()`.
    method: SampleMethod,
    // Sample size, fixed at construction.
    size: SampleSize,
    // Optional REPEATABLE seed; `None` until `seed()` is called.
    seed: Option<i64>,
}
642
643impl TableSampleBuilder {
644    /// Create a new builder.
645    pub fn new(method: SampleMethod, size: SampleSize) -> Self {
646        Self {
647            method,
648            size,
649            seed: None,
650        }
651    }
652
653    /// Use BERNOULLI sampling.
654    pub fn bernoulli(mut self) -> Self {
655        self.method = SampleMethod::Bernoulli;
656        self
657    }
658
659    /// Use SYSTEM sampling.
660    pub fn system(mut self) -> Self {
661        self.method = SampleMethod::System;
662        self
663    }
664
665    /// Set seed for reproducibility.
666    pub fn seed(mut self, seed: i64) -> Self {
667        self.seed = Some(seed);
668        self
669    }
670
671    /// Build the sample configuration.
672    pub fn build(self) -> TableSample {
673        TableSample {
674            method: self.method,
675            size: self.size,
676            seed: self.seed,
677        }
678    }
679}
680
681/// Random sampling alternatives for unsupported databases.
682pub mod random_sample {
683    use super::*;
684
685    /// Generate ORDER BY RANDOM() LIMIT for databases without TABLESAMPLE.
686    pub fn order_by_random_sql(limit: usize, db_type: DatabaseType) -> String {
687        let random_func = match db_type {
688            DatabaseType::PostgreSQL => "RANDOM()",
689            DatabaseType::MySQL => "RAND()",
690            DatabaseType::SQLite => "RANDOM()",
691            DatabaseType::MSSQL => "NEWID()",
692        };
693
694        format!("ORDER BY {} LIMIT {}", random_func, limit)
695    }
696
697    /// Generate WHERE RANDOM() < threshold for row sampling.
698    pub fn where_random_sql(threshold: f64, db_type: DatabaseType) -> String {
699        match db_type {
700            DatabaseType::PostgreSQL | DatabaseType::SQLite => {
701                format!("WHERE RANDOM() < {}", threshold)
702            }
703            DatabaseType::MySQL => format!("WHERE RAND() < {}", threshold),
704            DatabaseType::MSSQL => {
705                format!(
706                    "WHERE ABS(CHECKSUM(NEWID())) % 100 < {}",
707                    (threshold * 100.0) as i32
708                )
709            }
710        }
711    }
712}
713
714// ============================================================================
715// Bulk Operations
716// ============================================================================
717
/// Bulk operation configuration.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BulkOperation<T> {
    /// Items to process.
    pub items: Vec<T>,
    /// Maximum items per batch; `new` defaults this to 1000.
    pub batch_size: usize,
    /// Ordered-execution flag, cleared by `unordered()`.
    /// NOTE(review): the previous comment read "whether to continue on
    /// error", which is the inverse of the field name — presumably
    /// `ordered == true` means stop at the first error; confirm against the
    /// executor consuming this flag.
    pub ordered: bool,
}
728
729impl<T> BulkOperation<T> {
730    /// Create a new bulk operation.
731    pub fn new(items: Vec<T>) -> Self {
732        Self {
733            items,
734            batch_size: 1000,
735            ordered: true,
736        }
737    }
738
739    /// Set batch size.
740    pub fn batch_size(mut self, size: usize) -> Self {
741        self.batch_size = size;
742        self
743    }
744
745    /// Allow unordered execution (continue on errors).
746    pub fn unordered(mut self) -> Self {
747        self.ordered = false;
748        self
749    }
750
751    /// Get batches.
752    pub fn batches(&self) -> impl Iterator<Item = &[T]> {
753        self.items.chunks(self.batch_size)
754    }
755
756    /// Get number of batches.
757    pub fn batch_count(&self) -> usize {
758        (self.items.len() + self.batch_size - 1) / self.batch_size
759    }
760}
761
762/// MongoDB bulkWrite operations.
763pub mod mongodb {
764    use serde::{Deserialize, Serialize};
765    use serde_json::Value as JsonValue;
766
    /// A single bulk write operation.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub enum BulkWriteOp {
        /// Insert one document.
        InsertOne {
            /// The document to insert.
            document: JsonValue,
        },
        /// Update one document.
        UpdateOne {
            /// Query filter selecting the document.
            filter: JsonValue,
            /// Update document (e.g. `$set` operators).
            update: JsonValue,
            /// Insert when no document matches the filter.
            upsert: bool,
        },
        /// Update many documents.
        UpdateMany {
            /// Query filter selecting the documents.
            filter: JsonValue,
            /// Update document applied to every match.
            update: JsonValue,
            /// Insert when no document matches the filter.
            upsert: bool,
        },
        /// Replace one document.
        ReplaceOne {
            /// Query filter selecting the document.
            filter: JsonValue,
            /// Full replacement document.
            replacement: JsonValue,
            /// Insert when no document matches the filter.
            upsert: bool,
        },
        /// Delete one document.
        DeleteOne {
            /// Query filter selecting the document.
            filter: JsonValue,
        },
        /// Delete many documents.
        DeleteMany {
            /// Query filter selecting the documents.
            filter: JsonValue,
        },
    }
795
796    impl BulkWriteOp {
797        /// Create an insert operation.
798        pub fn insert_one(document: JsonValue) -> Self {
799            Self::InsertOne { document }
800        }
801
802        /// Create an update one operation.
803        pub fn update_one(filter: JsonValue, update: JsonValue) -> Self {
804            Self::UpdateOne {
805                filter,
806                update,
807                upsert: false,
808            }
809        }
810
811        /// Create an upsert operation.
812        pub fn upsert_one(filter: JsonValue, update: JsonValue) -> Self {
813            Self::UpdateOne {
814                filter,
815                update,
816                upsert: true,
817            }
818        }
819
820        /// Create a delete one operation.
821        pub fn delete_one(filter: JsonValue) -> Self {
822            Self::DeleteOne { filter }
823        }
824
825        /// Convert to MongoDB format.
826        pub fn to_command(&self) -> JsonValue {
827            match self {
828                Self::InsertOne { document } => {
829                    serde_json::json!({ "insertOne": { "document": document } })
830                }
831                Self::UpdateOne {
832                    filter,
833                    update,
834                    upsert,
835                } => {
836                    serde_json::json!({
837                        "updateOne": {
838                            "filter": filter,
839                            "update": update,
840                            "upsert": upsert
841                        }
842                    })
843                }
844                Self::UpdateMany {
845                    filter,
846                    update,
847                    upsert,
848                } => {
849                    serde_json::json!({
850                        "updateMany": {
851                            "filter": filter,
852                            "update": update,
853                            "upsert": upsert
854                        }
855                    })
856                }
857                Self::ReplaceOne {
858                    filter,
859                    replacement,
860                    upsert,
861                } => {
862                    serde_json::json!({
863                        "replaceOne": {
864                            "filter": filter,
865                            "replacement": replacement,
866                            "upsert": upsert
867                        }
868                    })
869                }
870                Self::DeleteOne { filter } => {
871                    serde_json::json!({ "deleteOne": { "filter": filter } })
872                }
873                Self::DeleteMany { filter } => {
874                    serde_json::json!({ "deleteMany": { "filter": filter } })
875                }
876            }
877        }
878    }
879
880    /// Bulk write builder.
881    #[derive(Debug, Clone, Default)]
882    pub struct BulkWriteBuilder {
883        operations: Vec<BulkWriteOp>,
884        ordered: bool,
885        bypass_validation: bool,
886    }
887
888    impl BulkWriteBuilder {
889        /// Create a new builder.
890        pub fn new() -> Self {
891            Self {
892                operations: Vec::new(),
893                ordered: true,
894                bypass_validation: false,
895            }
896        }
897
898        /// Add an operation.
899        pub fn add(mut self, op: BulkWriteOp) -> Self {
900            self.operations.push(op);
901            self
902        }
903
904        /// Add multiple operations.
905        pub fn add_many<I>(mut self, ops: I) -> Self
906        where
907            I: IntoIterator<Item = BulkWriteOp>,
908        {
909            self.operations.extend(ops);
910            self
911        }
912
913        /// Insert one document.
914        pub fn insert_one(self, document: JsonValue) -> Self {
915            self.add(BulkWriteOp::insert_one(document))
916        }
917
918        /// Update one document.
919        pub fn update_one(self, filter: JsonValue, update: JsonValue) -> Self {
920            self.add(BulkWriteOp::update_one(filter, update))
921        }
922
923        /// Upsert one document.
924        pub fn upsert_one(self, filter: JsonValue, update: JsonValue) -> Self {
925            self.add(BulkWriteOp::upsert_one(filter, update))
926        }
927
928        /// Delete one document.
929        pub fn delete_one(self, filter: JsonValue) -> Self {
930            self.add(BulkWriteOp::delete_one(filter))
931        }
932
933        /// Set unordered execution.
934        pub fn unordered(mut self) -> Self {
935            self.ordered = false;
936            self
937        }
938
939        /// Bypass document validation.
940        pub fn bypass_validation(mut self) -> Self {
941            self.bypass_validation = true;
942            self
943        }
944
945        /// Build the bulkWrite command.
946        pub fn build(&self, collection: &str) -> JsonValue {
947            let ops: Vec<JsonValue> = self.operations.iter().map(|op| op.to_command()).collect();
948
949            serde_json::json!({
950                "bulkWrite": collection,
951                "operations": ops,
952                "ordered": self.ordered,
953                "bypassDocumentValidation": self.bypass_validation
954            })
955        }
956    }
957
958    /// MongoDB $sample aggregation stage.
959    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
960    pub struct Sample {
961        /// Number of documents to sample.
962        pub size: usize,
963    }
964
965    impl Sample {
966        /// Create a new sample stage.
967        pub fn new(size: usize) -> Self {
968            Self { size }
969        }
970
971        /// Convert to aggregation stage.
972        pub fn to_stage(&self) -> JsonValue {
973            serde_json::json!({ "$sample": { "size": self.size } })
974        }
975    }
976}
977
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lateral_join_postgres() {
        let join = LateralJoin::new(
            "SELECT * FROM orders WHERE orders.user_id = users.id LIMIT 3",
            "recent_orders",
        )
        .build();
        let rendered = join.to_postgres_sql();

        // Default join type renders as CROSS JOIN LATERAL with the alias attached.
        assert!(rendered.contains("CROSS JOIN LATERAL"));
        assert!(rendered.contains("AS recent_orders"));
    }

    #[test]
    fn test_lateral_join_mssql() {
        let join = LateralJoin::new(
            "SELECT TOP 3 * FROM orders WHERE orders.user_id = users.id",
            "recent_orders",
        )
        .left()
        .build();
        let rendered = join.to_mssql_sql();

        // A left lateral maps to MSSQL's OUTER APPLY.
        assert!(rendered.contains("OUTER APPLY"));
    }

    #[test]
    fn test_distinct_on() {
        let rendered = DistinctOn::new(["department", "date"]).to_postgres_sql();

        assert!(rendered.contains("DISTINCT ON (department, date)"));
    }

    #[test]
    fn test_returning_postgres() {
        let rendered = Returning::all(ReturnOperation::Insert).to_postgres_sql();

        assert_eq!(rendered, "RETURNING *");
    }

    #[test]
    fn test_returning_mssql() {
        let rendered =
            Returning::columns(ReturnOperation::Insert, ["id", "name"]).to_mssql_sql();

        // MSSQL prefixes each returned column with the INSERTED pseudo-table.
        assert!(rendered.contains("OUTPUT INSERTED.id, INSERTED.name"));
    }

    #[test]
    fn test_for_update() {
        let rendered = RowLock::for_update().nowait().build().to_postgres_sql();

        assert!(rendered.contains("FOR UPDATE"));
        assert!(rendered.contains("NOWAIT"));
    }

    #[test]
    fn test_for_share_skip_locked() {
        let rendered = RowLock::for_share().skip_locked().build().to_postgres_sql();

        assert!(rendered.contains("FOR SHARE"));
        assert!(rendered.contains("SKIP LOCKED"));
    }

    #[test]
    fn test_row_lock_mssql() {
        let rendered = RowLock::for_update().nowait().build().to_mssql_hint();

        // The same lock spec surfaces as MSSQL table hints.
        assert!(rendered.contains("UPDLOCK"));
        assert!(rendered.contains("NOWAIT"));
    }

    #[test]
    fn test_tablesample_postgres() {
        let rendered = TableSample::percent(10.0).seed(42).build().to_postgres_sql();

        assert!(rendered.contains("TABLESAMPLE BERNOULLI (10)"));
        assert!(rendered.contains("REPEATABLE (42)"));
    }

    #[test]
    fn test_tablesample_mssql() {
        let rendered = TableSample::rows(1000).build().to_mssql_sql();

        assert!(rendered.contains("TABLESAMPLE (1000 ROWS)"));
    }

    #[test]
    fn test_bulk_operation_batches() {
        let bulk: BulkOperation<i32> = BulkOperation::new(vec![1, 2, 3, 4, 5]).batch_size(2);

        // Five items at batch size 2 -> three batches, last one partial.
        assert_eq!(bulk.batch_count(), 3);
        let collected: Vec<_> = bulk.batches().collect();
        assert_eq!(collected.len(), 3);
        assert_eq!(collected[0], &[1, 2]);
        assert_eq!(collected[1], &[3, 4]);
        assert_eq!(collected[2], &[5]);
    }

    mod mongodb_tests {
        use super::super::mongodb::*;

        #[test]
        fn test_bulk_write_builder() {
            let cmd = BulkWriteBuilder::new()
                .insert_one(serde_json::json!({ "name": "Alice" }))
                .update_one(
                    serde_json::json!({ "_id": 1 }),
                    serde_json::json!({ "$set": { "status": "active" } }),
                )
                .delete_one(serde_json::json!({ "_id": 2 }))
                .unordered()
                .build("users");

            assert_eq!(cmd["bulkWrite"], "users");
            assert_eq!(cmd["ordered"], false);
            assert!(cmd["operations"].is_array());
            assert_eq!(cmd["operations"].as_array().unwrap().len(), 3);
        }

        #[test]
        fn test_sample_stage() {
            let stage = Sample::new(100).to_stage();

            assert_eq!(stage["$sample"]["size"], 100);
        }

        #[test]
        fn test_bulk_write_upsert() {
            let cmd = BulkWriteOp::upsert_one(
                serde_json::json!({ "email": "test@example.com" }),
                serde_json::json!({ "$set": { "name": "Test" } }),
            )
            .to_command();

            // The upsert flag must be carried inside the updateOne document.
            assert!(cmd["updateOne"]["upsert"].as_bool().unwrap());
        }
    }

    mod distinct_on_tests {
        use super::super::mongodb_distinct::*;

        #[test]
        fn test_distinct_on_stage() {
            let stage = distinct_on_stage(&["department"], &["name", "salary"]);

            // DISTINCT ON emulation: group on the key, $first for carried columns.
            assert!(stage["$group"]["_id"]["department"].is_string());
            assert!(stage["$group"]["name"]["$first"].is_string());
            assert!(stage["$group"]["salary"]["$first"].is_string());
        }
    }
}