1use serde::{Deserialize, Serialize};
23
24use crate::error::{QueryError, QueryResult};
25use crate::sql::DatabaseType;
26
/// A lateral join: a subquery in `FROM` that may reference columns of the
/// tables joined before it.
///
/// Construct with [`LateralJoin::new`], which returns a [`LateralJoinBuilder`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct LateralJoin {
    /// Raw SQL of the subquery; inserted verbatim between parentheses (not escaped).
    pub subquery: String,
    /// Name the subquery's result is exposed under (`AS <alias>`).
    pub alias: String,
    /// Whether the join drops (`Cross`) or keeps (`Left`) outer rows without matches.
    pub join_type: LateralJoinType,
    /// Optional join condition, as set by `LateralJoinBuilder::on`.
    pub condition: Option<String>,
}

/// Flavor of a [`LateralJoin`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LateralJoinType {
    /// Inner-style lateral join (`CROSS JOIN LATERAL` / `CROSS APPLY` family):
    /// outer rows for which the subquery yields nothing are dropped.
    Cross,
    /// Outer lateral join (`LEFT JOIN LATERAL` / `OUTER APPLY`): outer rows for
    /// which the subquery yields nothing are kept, null-extended.
    Left,
}
52
53impl LateralJoin {
54 pub fn new(subquery: impl Into<String>, alias: impl Into<String>) -> LateralJoinBuilder {
56 LateralJoinBuilder::new(subquery, alias)
57 }
58
59 pub fn to_postgres_sql(&self) -> String {
61 match self.join_type {
62 LateralJoinType::Cross => {
63 format!("CROSS JOIN LATERAL ({}) AS {}", self.subquery, self.alias)
64 }
65 LateralJoinType::Left => {
66 let cond = self.condition.as_deref().unwrap_or("TRUE");
67 format!(
68 "LEFT JOIN LATERAL ({}) AS {} ON {}",
69 self.subquery, self.alias, cond
70 )
71 }
72 }
73 }
74
75 pub fn to_mysql_sql(&self) -> String {
77 match self.join_type {
78 LateralJoinType::Cross => {
79 format!("CROSS JOIN LATERAL ({}) AS {}", self.subquery, self.alias)
80 }
81 LateralJoinType::Left => {
82 let cond = self.condition.as_deref().unwrap_or("TRUE");
83 format!(
84 "LEFT JOIN LATERAL ({}) AS {} ON {}",
85 self.subquery, self.alias, cond
86 )
87 }
88 }
89 }
90
91 pub fn to_mssql_sql(&self) -> String {
93 match self.join_type {
94 LateralJoinType::Cross => {
95 format!("CROSS APPLY ({}) AS {}", self.subquery, self.alias)
96 }
97 LateralJoinType::Left => {
98 format!("OUTER APPLY ({}) AS {}", self.subquery, self.alias)
99 }
100 }
101 }
102
103 pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
105 match db_type {
106 DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
107 DatabaseType::MySQL => Ok(self.to_mysql_sql()),
108 DatabaseType::MSSQL => Ok(self.to_mssql_sql()),
109 DatabaseType::SQLite => Err(QueryError::unsupported(
110 "LATERAL joins are not supported in SQLite",
111 )),
112 }
113 }
114}
115
/// Builder for [`LateralJoin`], created by `LateralJoin::new`.
///
/// Starts as a `Cross` join with no condition; see the methods on its `impl`.
#[derive(Debug, Clone)]
pub struct LateralJoinBuilder {
    subquery: String,
    alias: String,
    join_type: LateralJoinType,
    condition: Option<String>,
}
124
125impl LateralJoinBuilder {
126 pub fn new(subquery: impl Into<String>, alias: impl Into<String>) -> Self {
128 Self {
129 subquery: subquery.into(),
130 alias: alias.into(),
131 join_type: LateralJoinType::Cross,
132 condition: None,
133 }
134 }
135
136 pub fn left(mut self) -> Self {
138 self.join_type = LateralJoinType::Left;
139 self
140 }
141
142 pub fn cross(mut self) -> Self {
144 self.join_type = LateralJoinType::Cross;
145 self
146 }
147
148 pub fn on(mut self, condition: impl Into<String>) -> Self {
150 self.condition = Some(condition.into());
151 self
152 }
153
154 pub fn build(self) -> LateralJoin {
156 LateralJoin {
157 subquery: self.subquery,
158 alias: self.alias,
159 join_type: self.join_type,
160 condition: self.condition,
161 }
162 }
163}
164
/// PostgreSQL `DISTINCT ON (...)`: keeps one row per distinct combination of the
/// listed columns.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DistinctOn {
    /// Column expressions, inserted verbatim and joined with `, `.
    pub columns: Vec<String>,
}
175
176impl DistinctOn {
177 pub fn new<I, S>(columns: I) -> Self
179 where
180 I: IntoIterator<Item = S>,
181 S: Into<String>,
182 {
183 Self {
184 columns: columns.into_iter().map(Into::into).collect(),
185 }
186 }
187
188 pub fn to_postgres_sql(&self) -> String {
190 format!("DISTINCT ON ({})", self.columns.join(", "))
191 }
192
193 pub fn to_mysql_workaround(&self) -> String {
196 format!(
197 "-- MySQL workaround: Use GROUP BY {} with appropriate aggregates",
198 self.columns.join(", ")
199 )
200 }
201}
202
203pub mod mongodb_distinct {
205 use serde_json::Value as JsonValue;
206
207 pub fn distinct_on_stage(group_fields: &[&str], first_fields: &[&str]) -> JsonValue {
209 let mut group_id = serde_json::Map::new();
210 for field in group_fields {
211 group_id.insert(field.to_string(), serde_json::json!(format!("${}", field)));
212 }
213
214 let mut group_spec = serde_json::Map::new();
215 group_spec.insert("_id".to_string(), serde_json::json!(group_id));
216
217 for field in first_fields {
218 group_spec.insert(
219 field.to_string(),
220 serde_json::json!({ "$first": format!("${}", field) }),
221 );
222 }
223
224 serde_json::json!({ "$group": group_spec })
225 }
226}
227
/// A `RETURNING` (PostgreSQL/SQLite) or `OUTPUT` (MSSQL) clause attached to a
/// data-modifying statement.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Returning {
    /// Entries to return, rendered in order and joined with `, `.
    pub columns: Vec<ReturningColumn>,
    /// Kind of statement the clause belongs to. On MSSQL it selects the
    /// `INSERTED`/`DELETED` pseudo-table used for `All` and `Column` entries.
    pub operation: ReturnOperation,
}

/// One entry of a [`Returning`] clause.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReturningColumn {
    /// Every column (`*`, or `INSERTED.*` / `DELETED.*` on MSSQL).
    All,
    /// A single column by name (prefixed with the pseudo-table on MSSQL).
    Column(String),
    /// An arbitrary SQL expression, rendered as `expr AS alias`.
    Expression { expr: String, alias: String },
    /// MSSQL `INSERTED.<name>`. The prefix is emitted for every dialect, so this
    /// variant only produces valid SQL on MSSQL.
    Inserted(String),
    /// MSSQL `DELETED.<name>`. The prefix is emitted for every dialect, so this
    /// variant only produces valid SQL on MSSQL.
    Deleted(String),
}

/// The data-modifying statement a [`Returning`] clause is attached to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReturnOperation {
    Insert,
    Update,
    Delete,
}
263
264impl Returning {
265 pub fn all(operation: ReturnOperation) -> Self {
267 Self {
268 columns: vec![ReturningColumn::All],
269 operation,
270 }
271 }
272
273 pub fn columns<I, S>(operation: ReturnOperation, columns: I) -> Self
275 where
276 I: IntoIterator<Item = S>,
277 S: Into<String>,
278 {
279 Self {
280 columns: columns
281 .into_iter()
282 .map(|c| ReturningColumn::Column(c.into()))
283 .collect(),
284 operation,
285 }
286 }
287
288 pub fn to_postgres_sql(&self) -> String {
290 let cols = self.format_columns(DatabaseType::PostgreSQL);
291 format!("RETURNING {}", cols)
292 }
293
294 pub fn to_sqlite_sql(&self) -> String {
296 let cols = self.format_columns(DatabaseType::SQLite);
297 format!("RETURNING {}", cols)
298 }
299
300 pub fn to_mssql_sql(&self) -> String {
302 let cols = self.format_columns(DatabaseType::MSSQL);
303 format!("OUTPUT {}", cols)
304 }
305
306 fn format_columns(&self, db_type: DatabaseType) -> String {
308 self.columns
309 .iter()
310 .map(|col| match col {
311 ReturningColumn::All => {
312 if db_type == DatabaseType::MSSQL {
313 match self.operation {
314 ReturnOperation::Insert => "INSERTED.*".to_string(),
315 ReturnOperation::Delete => "DELETED.*".to_string(),
316 ReturnOperation::Update => "INSERTED.*".to_string(),
317 }
318 } else {
319 "*".to_string()
320 }
321 }
322 ReturningColumn::Column(name) => {
323 if db_type == DatabaseType::MSSQL {
324 match self.operation {
325 ReturnOperation::Insert => format!("INSERTED.{}", name),
326 ReturnOperation::Delete => format!("DELETED.{}", name),
327 ReturnOperation::Update => format!("INSERTED.{}", name),
328 }
329 } else {
330 name.clone()
331 }
332 }
333 ReturningColumn::Expression { expr, alias } => format!("{} AS {}", expr, alias),
334 ReturningColumn::Inserted(name) => format!("INSERTED.{}", name),
335 ReturningColumn::Deleted(name) => format!("DELETED.{}", name),
336 })
337 .collect::<Vec<_>>()
338 .join(", ")
339 }
340
341 pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
343 match db_type {
344 DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
345 DatabaseType::SQLite => Ok(self.to_sqlite_sql()),
346 DatabaseType::MSSQL => Ok(self.to_mssql_sql()),
347 DatabaseType::MySQL => Err(QueryError::unsupported(
348 "RETURNING clause is not supported in MySQL. Consider using LAST_INSERT_ID() or separate SELECT.",
349 )),
350 }
351 }
352}
353
/// A row-level locking clause (`FOR UPDATE` and friends, or an MSSQL table hint).
///
/// Construct with `RowLock::for_update`, `RowLock::for_share`, ..., which return a
/// [`RowLockBuilder`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RowLock {
    /// How strong a lock to take.
    pub strength: LockStrength,
    /// Tables to restrict the lock to (`OF t1, t2`). Empty means no `OF` clause.
    pub of_tables: Vec<String>,
    /// What to do when a row is already locked.
    pub wait: LockWait,
}

/// Lock strength, named after PostgreSQL's four row-lock modes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LockStrength {
    /// `FOR UPDATE`.
    Update,
    /// `FOR NO KEY UPDATE` (PostgreSQL). Other dialects fall back to an update lock.
    NoKeyUpdate,
    /// `FOR SHARE`.
    Share,
    /// `FOR KEY SHARE` (PostgreSQL). Other dialects fall back to a share lock.
    KeyShare,
}

/// Behavior when a requested row is already locked by another transaction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LockWait {
    /// Block until the lock is available; no modifier is emitted.
    Wait,
    /// Fail immediately (`NOWAIT`).
    NoWait,
    /// Skip rows that are locked (`SKIP LOCKED`, or `READPAST` on MSSQL).
    SkipLocked,
}
392
393impl RowLock {
394 pub fn for_update() -> RowLockBuilder {
396 RowLockBuilder::new(LockStrength::Update)
397 }
398
399 pub fn for_share() -> RowLockBuilder {
401 RowLockBuilder::new(LockStrength::Share)
402 }
403
404 pub fn for_no_key_update() -> RowLockBuilder {
406 RowLockBuilder::new(LockStrength::NoKeyUpdate)
407 }
408
409 pub fn for_key_share() -> RowLockBuilder {
411 RowLockBuilder::new(LockStrength::KeyShare)
412 }
413
414 pub fn to_postgres_sql(&self) -> String {
416 let strength = match self.strength {
417 LockStrength::Update => "FOR UPDATE",
418 LockStrength::NoKeyUpdate => "FOR NO KEY UPDATE",
419 LockStrength::Share => "FOR SHARE",
420 LockStrength::KeyShare => "FOR KEY SHARE",
421 };
422
423 let mut sql = strength.to_string();
424
425 if !self.of_tables.is_empty() {
426 sql.push_str(&format!(" OF {}", self.of_tables.join(", ")));
427 }
428
429 match self.wait {
430 LockWait::Wait => {}
431 LockWait::NoWait => sql.push_str(" NOWAIT"),
432 LockWait::SkipLocked => sql.push_str(" SKIP LOCKED"),
433 }
434
435 sql
436 }
437
438 pub fn to_mysql_sql(&self) -> String {
440 let strength = match self.strength {
441 LockStrength::Update | LockStrength::NoKeyUpdate => "FOR UPDATE",
442 LockStrength::Share | LockStrength::KeyShare => "FOR SHARE",
443 };
444
445 let mut sql = strength.to_string();
446
447 if !self.of_tables.is_empty() {
448 sql.push_str(&format!(" OF {}", self.of_tables.join(", ")));
449 }
450
451 match self.wait {
452 LockWait::Wait => {}
453 LockWait::NoWait => sql.push_str(" NOWAIT"),
454 LockWait::SkipLocked => sql.push_str(" SKIP LOCKED"),
455 }
456
457 sql
458 }
459
460 pub fn to_mssql_hint(&self) -> String {
462 let hint = match self.strength {
463 LockStrength::Update | LockStrength::NoKeyUpdate => "UPDLOCK, ROWLOCK",
464 LockStrength::Share | LockStrength::KeyShare => "HOLDLOCK, ROWLOCK",
465 };
466
467 let wait_hint = match self.wait {
468 LockWait::Wait => "",
469 LockWait::NoWait => ", NOWAIT",
470 LockWait::SkipLocked => ", READPAST",
471 };
472
473 format!("WITH ({}{})", hint, wait_hint)
474 }
475
476 pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
478 match db_type {
479 DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
480 DatabaseType::MySQL => Ok(self.to_mysql_sql()),
481 DatabaseType::MSSQL => Ok(self.to_mssql_hint()),
482 DatabaseType::SQLite => Err(QueryError::unsupported(
483 "Row locking is not supported in SQLite",
484 )),
485 }
486 }
487}
488
/// Builder for [`RowLock`], created by `RowLock::for_update` and its siblings.
///
/// Starts with no `OF` tables and [`LockWait::Wait`].
#[derive(Debug, Clone)]
pub struct RowLockBuilder {
    strength: LockStrength,
    of_tables: Vec<String>,
    wait: LockWait,
}
496
497impl RowLockBuilder {
498 pub fn new(strength: LockStrength) -> Self {
500 Self {
501 strength,
502 of_tables: Vec::new(),
503 wait: LockWait::Wait,
504 }
505 }
506
507 pub fn of<I, S>(mut self, tables: I) -> Self
509 where
510 I: IntoIterator<Item = S>,
511 S: Into<String>,
512 {
513 self.of_tables = tables.into_iter().map(Into::into).collect();
514 self
515 }
516
517 pub fn nowait(mut self) -> Self {
519 self.wait = LockWait::NoWait;
520 self
521 }
522
523 pub fn skip_locked(mut self) -> Self {
525 self.wait = LockWait::SkipLocked;
526 self
527 }
528
529 pub fn build(self) -> RowLock {
531 RowLock {
532 strength: self.strength,
533 of_tables: self.of_tables,
534 wait: self.wait,
535 }
536 }
537}
538
/// A `TABLESAMPLE` clause selecting a random subset of a table's rows.
///
/// Construct with `TableSample::percent` or `TableSample::rows`, which return a
/// [`TableSampleBuilder`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TableSample {
    /// Sampling method (used by PostgreSQL percentage samples).
    pub method: SampleMethod,
    /// How much of the table to sample.
    pub size: SampleSize,
    /// Optional seed, rendered as `REPEATABLE (<seed>)`.
    pub seed: Option<i64>,
}

/// PostgreSQL's built-in sampling methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SampleMethod {
    /// Row-level sampling (`BERNOULLI`).
    Bernoulli,
    /// Block-level sampling (`SYSTEM`).
    System,
}

/// Requested sample size.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum SampleSize {
    /// Percentage of the table, e.g. `10.0` for ten percent.
    Percent(f64),
    /// Number of rows.
    Rows(usize),
}
571
572impl TableSample {
573 pub fn percent(percent: f64) -> TableSampleBuilder {
575 TableSampleBuilder::new(SampleMethod::Bernoulli, SampleSize::Percent(percent))
576 }
577
578 pub fn rows(count: usize) -> TableSampleBuilder {
580 TableSampleBuilder::new(SampleMethod::System, SampleSize::Rows(count))
581 }
582
583 pub fn to_postgres_sql(&self) -> String {
585 let method = match self.method {
586 SampleMethod::Bernoulli => "BERNOULLI",
587 SampleMethod::System => "SYSTEM",
588 };
589
590 let size = match self.size {
591 SampleSize::Percent(p) => format!("{}", p),
592 SampleSize::Rows(_) => {
593 "10".to_string() }
596 };
597
598 let mut sql = format!("TABLESAMPLE {} ({})", method, size);
599
600 if let Some(seed) = self.seed {
601 sql.push_str(&format!(" REPEATABLE ({})", seed));
602 }
603
604 sql
605 }
606
607 pub fn to_mssql_sql(&self) -> String {
609 let size_clause = match self.size {
610 SampleSize::Percent(p) => format!("{} PERCENT", p),
611 SampleSize::Rows(n) => format!("{} ROWS", n),
612 };
613
614 let mut sql = format!("TABLESAMPLE ({})", size_clause);
615
616 if let Some(seed) = self.seed {
617 sql.push_str(&format!(" REPEATABLE ({})", seed));
618 }
619
620 sql
621 }
622
623 pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
625 match db_type {
626 DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
627 DatabaseType::MSSQL => Ok(self.to_mssql_sql()),
628 DatabaseType::MySQL | DatabaseType::SQLite => Err(QueryError::unsupported(
629 "TABLESAMPLE is not supported in this database. Use ORDER BY RANDOM() LIMIT instead.",
630 )),
631 }
632 }
633}
634
/// Builder for [`TableSample`], created by `TableSample::percent` or `TableSample::rows`.
#[derive(Debug, Clone)]
pub struct TableSampleBuilder {
    method: SampleMethod,
    size: SampleSize,
    seed: Option<i64>,
}
642
643impl TableSampleBuilder {
644 pub fn new(method: SampleMethod, size: SampleSize) -> Self {
646 Self {
647 method,
648 size,
649 seed: None,
650 }
651 }
652
653 pub fn bernoulli(mut self) -> Self {
655 self.method = SampleMethod::Bernoulli;
656 self
657 }
658
659 pub fn system(mut self) -> Self {
661 self.method = SampleMethod::System;
662 self
663 }
664
665 pub fn seed(mut self, seed: i64) -> Self {
667 self.seed = Some(seed);
668 self
669 }
670
671 pub fn build(self) -> TableSample {
673 TableSample {
674 method: self.method,
675 size: self.size,
676 seed: self.seed,
677 }
678 }
679}
680
681pub mod random_sample {
683 use super::*;
684
685 pub fn order_by_random_sql(limit: usize, db_type: DatabaseType) -> String {
687 let random_func = match db_type {
688 DatabaseType::PostgreSQL => "RANDOM()",
689 DatabaseType::MySQL => "RAND()",
690 DatabaseType::SQLite => "RANDOM()",
691 DatabaseType::MSSQL => "NEWID()",
692 };
693
694 format!("ORDER BY {} LIMIT {}", random_func, limit)
695 }
696
697 pub fn where_random_sql(threshold: f64, db_type: DatabaseType) -> String {
699 match db_type {
700 DatabaseType::PostgreSQL | DatabaseType::SQLite => {
701 format!("WHERE RANDOM() < {}", threshold)
702 }
703 DatabaseType::MySQL => format!("WHERE RAND() < {}", threshold),
704 DatabaseType::MSSQL => {
705 format!(
706 "WHERE ABS(CHECKSUM(NEWID())) % 100 < {}",
707 (threshold * 100.0) as i32
708 )
709 }
710 }
711 }
712}
713
/// A list of items to be processed in fixed-size batches.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BulkOperation<T> {
    /// Items to process, in submission order.
    pub items: Vec<T>,
    /// Maximum number of items per batch, as used by `BulkOperation::batches`.
    pub batch_size: usize,
    /// Whether the batches should be executed in order. Only stored here; presumably
    /// honored by whatever executes the batches (TODO confirm).
    pub ordered: bool,
}
728
729impl<T> BulkOperation<T> {
730 pub fn new(items: Vec<T>) -> Self {
732 Self {
733 items,
734 batch_size: 1000,
735 ordered: true,
736 }
737 }
738
739 pub fn batch_size(mut self, size: usize) -> Self {
741 self.batch_size = size;
742 self
743 }
744
745 pub fn unordered(mut self) -> Self {
747 self.ordered = false;
748 self
749 }
750
751 pub fn batches(&self) -> impl Iterator<Item = &[T]> {
753 self.items.chunks(self.batch_size)
754 }
755
756 pub fn batch_count(&self) -> usize {
758 (self.items.len() + self.batch_size - 1) / self.batch_size
759 }
760}
761
762pub mod mongodb {
764 use serde::{Deserialize, Serialize};
765 use serde_json::Value as JsonValue;
766
767 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
769 pub enum BulkWriteOp {
770 InsertOne { document: JsonValue },
772 UpdateOne {
774 filter: JsonValue,
775 update: JsonValue,
776 upsert: bool,
777 },
778 UpdateMany {
780 filter: JsonValue,
781 update: JsonValue,
782 upsert: bool,
783 },
784 ReplaceOne {
786 filter: JsonValue,
787 replacement: JsonValue,
788 upsert: bool,
789 },
790 DeleteOne { filter: JsonValue },
792 DeleteMany { filter: JsonValue },
794 }
795
796 impl BulkWriteOp {
797 pub fn insert_one(document: JsonValue) -> Self {
799 Self::InsertOne { document }
800 }
801
802 pub fn update_one(filter: JsonValue, update: JsonValue) -> Self {
804 Self::UpdateOne {
805 filter,
806 update,
807 upsert: false,
808 }
809 }
810
811 pub fn upsert_one(filter: JsonValue, update: JsonValue) -> Self {
813 Self::UpdateOne {
814 filter,
815 update,
816 upsert: true,
817 }
818 }
819
820 pub fn delete_one(filter: JsonValue) -> Self {
822 Self::DeleteOne { filter }
823 }
824
825 pub fn to_command(&self) -> JsonValue {
827 match self {
828 Self::InsertOne { document } => {
829 serde_json::json!({ "insertOne": { "document": document } })
830 }
831 Self::UpdateOne {
832 filter,
833 update,
834 upsert,
835 } => {
836 serde_json::json!({
837 "updateOne": {
838 "filter": filter,
839 "update": update,
840 "upsert": upsert
841 }
842 })
843 }
844 Self::UpdateMany {
845 filter,
846 update,
847 upsert,
848 } => {
849 serde_json::json!({
850 "updateMany": {
851 "filter": filter,
852 "update": update,
853 "upsert": upsert
854 }
855 })
856 }
857 Self::ReplaceOne {
858 filter,
859 replacement,
860 upsert,
861 } => {
862 serde_json::json!({
863 "replaceOne": {
864 "filter": filter,
865 "replacement": replacement,
866 "upsert": upsert
867 }
868 })
869 }
870 Self::DeleteOne { filter } => {
871 serde_json::json!({ "deleteOne": { "filter": filter } })
872 }
873 Self::DeleteMany { filter } => {
874 serde_json::json!({ "deleteMany": { "filter": filter } })
875 }
876 }
877 }
878 }
879
880 #[derive(Debug, Clone, Default)]
882 pub struct BulkWriteBuilder {
883 operations: Vec<BulkWriteOp>,
884 ordered: bool,
885 bypass_validation: bool,
886 }
887
888 impl BulkWriteBuilder {
889 pub fn new() -> Self {
891 Self {
892 operations: Vec::new(),
893 ordered: true,
894 bypass_validation: false,
895 }
896 }
897
898 pub fn add(mut self, op: BulkWriteOp) -> Self {
900 self.operations.push(op);
901 self
902 }
903
904 pub fn add_many<I>(mut self, ops: I) -> Self
906 where
907 I: IntoIterator<Item = BulkWriteOp>,
908 {
909 self.operations.extend(ops);
910 self
911 }
912
913 pub fn insert_one(self, document: JsonValue) -> Self {
915 self.add(BulkWriteOp::insert_one(document))
916 }
917
918 pub fn update_one(self, filter: JsonValue, update: JsonValue) -> Self {
920 self.add(BulkWriteOp::update_one(filter, update))
921 }
922
923 pub fn upsert_one(self, filter: JsonValue, update: JsonValue) -> Self {
925 self.add(BulkWriteOp::upsert_one(filter, update))
926 }
927
928 pub fn delete_one(self, filter: JsonValue) -> Self {
930 self.add(BulkWriteOp::delete_one(filter))
931 }
932
933 pub fn unordered(mut self) -> Self {
935 self.ordered = false;
936 self
937 }
938
939 pub fn bypass_validation(mut self) -> Self {
941 self.bypass_validation = true;
942 self
943 }
944
945 pub fn build(&self, collection: &str) -> JsonValue {
947 let ops: Vec<JsonValue> = self.operations.iter().map(|op| op.to_command()).collect();
948
949 serde_json::json!({
950 "bulkWrite": collection,
951 "operations": ops,
952 "ordered": self.ordered,
953 "bypassDocumentValidation": self.bypass_validation
954 })
955 }
956 }
957
958 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
960 pub struct Sample {
961 pub size: usize,
963 }
964
965 impl Sample {
966 pub fn new(size: usize) -> Self {
968 Self { size }
969 }
970
971 pub fn to_stage(&self) -> JsonValue {
973 serde_json::json!({ "$sample": { "size": self.size } })
974 }
975 }
976}
977
#[cfg(test)]
mod tests {
    // Unit tests for the SQL renderers above; the MongoDB helpers are covered by
    // the nested modules at the end.
    use super::*;

    // The builder defaults to a Cross join, which PostgreSQL renders as CROSS JOIN LATERAL.
    #[test]
    fn test_lateral_join_postgres() {
        let lateral = LateralJoin::new(
            "SELECT * FROM orders WHERE orders.user_id = users.id LIMIT 3",
            "recent_orders",
        )
        .build();

        let sql = lateral.to_postgres_sql();
        assert!(sql.contains("CROSS JOIN LATERAL"));
        assert!(sql.contains("AS recent_orders"));
    }

    // A Left lateral join maps to SQL Server's OUTER APPLY.
    #[test]
    fn test_lateral_join_mssql() {
        let lateral = LateralJoin::new(
            "SELECT TOP 3 * FROM orders WHERE orders.user_id = users.id",
            "recent_orders",
        )
        .left()
        .build();

        let sql = lateral.to_mssql_sql();
        assert!(sql.contains("OUTER APPLY"));
    }

    // Columns are joined with ", " inside the DISTINCT ON parentheses.
    #[test]
    fn test_distinct_on() {
        let distinct = DistinctOn::new(["department", "date"]);
        let sql = distinct.to_postgres_sql();

        assert!(sql.contains("DISTINCT ON (department, date)"));
    }

    #[test]
    fn test_returning_postgres() {
        let ret = Returning::all(ReturnOperation::Insert);
        let sql = ret.to_postgres_sql();

        assert_eq!(sql, "RETURNING *");
    }

    // On MSSQL, plain columns of an INSERT are qualified with the INSERTED pseudo-table.
    #[test]
    fn test_returning_mssql() {
        let ret = Returning::columns(ReturnOperation::Insert, ["id", "name"]);
        let sql = ret.to_mssql_sql();

        assert!(sql.contains("OUTPUT INSERTED.id, INSERTED.name"));
    }

    #[test]
    fn test_for_update() {
        let lock = RowLock::for_update().nowait().build();
        let sql = lock.to_postgres_sql();

        assert!(sql.contains("FOR UPDATE"));
        assert!(sql.contains("NOWAIT"));
    }

    #[test]
    fn test_for_share_skip_locked() {
        let lock = RowLock::for_share().skip_locked().build();
        let sql = lock.to_postgres_sql();

        assert!(sql.contains("FOR SHARE"));
        assert!(sql.contains("SKIP LOCKED"));
    }

    // SQL Server uses table hints instead of a FOR UPDATE clause.
    #[test]
    fn test_row_lock_mssql() {
        let lock = RowLock::for_update().nowait().build();
        let sql = lock.to_mssql_hint();

        assert!(sql.contains("UPDLOCK"));
        assert!(sql.contains("NOWAIT"));
    }

    // `10.0_f64` is formatted by `Display` as "10", hence "(10)" rather than "(10.0)".
    #[test]
    fn test_tablesample_postgres() {
        let sample = TableSample::percent(10.0).seed(42).build();
        let sql = sample.to_postgres_sql();

        assert!(sql.contains("TABLESAMPLE BERNOULLI (10)"));
        assert!(sql.contains("REPEATABLE (42)"));
    }

    #[test]
    fn test_tablesample_mssql() {
        let sample = TableSample::rows(1000).build();
        let sql = sample.to_mssql_sql();

        assert!(sql.contains("TABLESAMPLE (1000 ROWS)"));
    }

    // 5 items in batches of 2 -> [1, 2], [3, 4], [5].
    #[test]
    fn test_bulk_operation_batches() {
        let bulk: BulkOperation<i32> = BulkOperation::new(vec![1, 2, 3, 4, 5]).batch_size(2);

        assert_eq!(bulk.batch_count(), 3);
        let batches: Vec<_> = bulk.batches().collect();
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0], &[1, 2]);
        assert_eq!(batches[1], &[3, 4]);
        assert_eq!(batches[2], &[5]);
    }

    mod mongodb_tests {
        use super::super::mongodb::*;

        // Three operations of different kinds end up in one "operations" array.
        #[test]
        fn test_bulk_write_builder() {
            let bulk = BulkWriteBuilder::new()
                .insert_one(serde_json::json!({ "name": "Alice" }))
                .update_one(
                    serde_json::json!({ "_id": 1 }),
                    serde_json::json!({ "$set": { "status": "active" } }),
                )
                .delete_one(serde_json::json!({ "_id": 2 }))
                .unordered()
                .build("users");

            assert_eq!(bulk["bulkWrite"], "users");
            assert_eq!(bulk["ordered"], false);
            assert!(bulk["operations"].is_array());
            assert_eq!(bulk["operations"].as_array().unwrap().len(), 3);
        }

        #[test]
        fn test_sample_stage() {
            let sample = Sample::new(100);
            let stage = sample.to_stage();

            assert_eq!(stage["$sample"]["size"], 100);
        }

        // upsert_one is rendered as an updateOne with upsert: true.
        #[test]
        fn test_bulk_write_upsert() {
            let op = BulkWriteOp::upsert_one(
                serde_json::json!({ "email": "test@example.com" }),
                serde_json::json!({ "$set": { "name": "Test" } }),
            );

            let cmd = op.to_command();
            assert!(cmd["updateOne"]["upsert"].as_bool().unwrap());
        }
    }

    mod distinct_on_tests {
        use super::super::mongodb_distinct::*;

        // Group keys land under "_id"; every "first" field gets a "$first" accumulator.
        #[test]
        fn test_distinct_on_stage() {
            let stage = distinct_on_stage(&["department"], &["name", "salary"]);

            assert!(stage["$group"]["_id"]["department"].is_string());
            assert!(stage["$group"]["name"]["$first"].is_string());
            assert!(stage["$group"]["salary"]["$first"].is_string());
        }
    }
}