1use std::time::Duration;
34
35use bson::{Bson, Document, doc};
36use serde::{Deserialize, Serialize};
37
38use crate::client::MongoClient;
39use crate::error::{MongoError, MongoResult};
40
/// Definition of a MongoDB view backed by an aggregation pipeline.
///
/// Convert it into a server command with `AggregationView::to_create_command`,
/// or build one incrementally via `AggregationView::builder`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregationView {
    /// Name of the view; emitted as the `create` field of the command.
    pub name: String,
    /// Collection the view reads from; emitted as `viewOn`.
    pub source_collection: String,
    /// Aggregation stages applied to `source_collection`, in order.
    pub pipeline: Vec<Document>,
    /// Optional collation document; skipped when serializing if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub collation: Option<Document>,
}
57
58impl AggregationView {
59 pub fn new(
61 name: impl Into<String>,
62 source_collection: impl Into<String>,
63 pipeline: Vec<Document>,
64 ) -> Self {
65 Self {
66 name: name.into(),
67 source_collection: source_collection.into(),
68 pipeline,
69 collation: None,
70 }
71 }
72
73 pub fn builder(name: impl Into<String>) -> AggregationViewBuilder {
75 AggregationViewBuilder::new(name)
76 }
77
78 pub fn with_collation(mut self, collation: Document) -> Self {
80 self.collation = Some(collation);
81 self
82 }
83
84 pub fn to_create_command(&self, _database: &str) -> Document {
86 let mut cmd = doc! {
87 "create": &self.name,
88 "viewOn": &self.source_collection,
89 "pipeline": self.pipeline.iter().cloned().map(Bson::Document).collect::<Vec<_>>(),
90 };
91
92 if let Some(ref collation) = self.collation {
93 cmd.insert("collation", collation.clone());
94 }
95
96 cmd
97 }
98}
99
/// Incremental builder for [`AggregationView`].
///
/// `source_collection` stays `None` until set. `build` then substitutes an
/// empty string for it.
#[derive(Debug, Default)]
pub struct AggregationViewBuilder {
    /// Name of the view being built.
    name: String,
    /// Collection the view reads from, if set.
    source_collection: Option<String>,
    /// Stages accumulated so far, in insertion order.
    pipeline: Vec<Document>,
    /// Optional collation for the view.
    collation: Option<Document>,
}
108
109impl AggregationViewBuilder {
110 pub fn new(name: impl Into<String>) -> Self {
112 Self {
113 name: name.into(),
114 ..Default::default()
115 }
116 }
117
118 pub fn source_collection(mut self, collection: impl Into<String>) -> Self {
120 self.source_collection = Some(collection.into());
121 self
122 }
123
124 pub fn pipeline(mut self, pipeline: Vec<Document>) -> Self {
126 self.pipeline = pipeline;
127 self
128 }
129
130 pub fn add_stage(mut self, stage: Document) -> Self {
132 self.pipeline.push(stage);
133 self
134 }
135
136 pub fn match_stage(mut self, filter: Document) -> Self {
138 self.pipeline.push(doc! { "$match": filter });
139 self
140 }
141
142 pub fn project_stage(mut self, projection: Document) -> Self {
144 self.pipeline.push(doc! { "$project": projection });
145 self
146 }
147
148 pub fn group_stage(mut self, group: Document) -> Self {
150 self.pipeline.push(doc! { "$group": group });
151 self
152 }
153
154 pub fn sort_stage(mut self, sort: Document) -> Self {
156 self.pipeline.push(doc! { "$sort": sort });
157 self
158 }
159
160 pub fn limit_stage(mut self, limit: i64) -> Self {
162 self.pipeline.push(doc! { "$limit": limit });
163 self
164 }
165
166 pub fn skip_stage(mut self, skip: i64) -> Self {
168 self.pipeline.push(doc! { "$skip": skip });
169 self
170 }
171
172 pub fn lookup_stage(
174 mut self,
175 from: impl Into<String>,
176 local_field: impl Into<String>,
177 foreign_field: impl Into<String>,
178 as_field: impl Into<String>,
179 ) -> Self {
180 self.pipeline.push(doc! {
181 "$lookup": {
182 "from": from.into(),
183 "localField": local_field.into(),
184 "foreignField": foreign_field.into(),
185 "as": as_field.into(),
186 }
187 });
188 self
189 }
190
191 pub fn unwind_stage(mut self, path: impl Into<String>) -> Self {
193 self.pipeline.push(doc! { "$unwind": path.into() });
194 self
195 }
196
197 pub fn count_stage(mut self, field: impl Into<String>) -> Self {
199 self.pipeline.push(doc! { "$count": field.into() });
200 self
201 }
202
203 pub fn add_fields_stage(mut self, fields: Document) -> Self {
205 self.pipeline.push(doc! { "$addFields": fields });
206 self
207 }
208
209 pub fn collation(mut self, collation: Document) -> Self {
211 self.collation = Some(collation);
212 self
213 }
214
215 pub fn build(self) -> AggregationView {
217 AggregationView {
218 name: self.name,
219 source_collection: self.source_collection.unwrap_or_default(),
220 pipeline: self.pipeline,
221 collation: self.collation,
222 }
223 }
224}
225
/// An aggregation whose results are written into a real collection named
/// `name`, using a terminal `$out` or `$merge` stage.
///
/// Build the runnable pipeline with `MaterializedAggregationView::to_pipeline`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MaterializedAggregationView {
    /// Target collection that receives the aggregation output.
    pub name: String,
    /// Collection the aggregation runs against.
    pub source_collection: String,
    /// User stages, without the terminal `$out`/`$merge` stage.
    pub pipeline: Vec<Document>,
    /// `true` selects a terminal `$merge` stage; `false` selects `$out`.
    pub use_merge: bool,
    /// Options for the `$merge` stage; skipped when serializing if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub merge_options: Option<MergeOptions>,
    /// Desired refresh cadence; skipped when serializing if `None`.
    /// NOTE(review): not read by any code in this file. Presumably a scheduler
    /// elsewhere consumes it; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_interval: Option<Duration>,
}
247
/// Settings for the terminal `$merge` stage of a [`MaterializedAggregationView`].
///
/// Field names serialize in snake_case because the struct has no `rename_all`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MergeOptions {
    /// Field names used to match output documents against the target
    /// collection (the `$merge` `on` field).
    pub on: Vec<String>,
    /// Action taken when an output document matches an existing one
    /// (`whenMatched`).
    pub when_matched: MergeAction,
    /// Action taken when an output document has no match (`whenNotMatched`).
    pub when_not_matched: MergeNotMatchedAction,
}
258
/// `$merge` `whenMatched` behaviour.
///
/// Variant names serialize in camelCase (`replace`, `keepExisting`, `merge`,
/// `fail`, `pipeline`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum MergeAction {
    /// Emitted as `"replace"`.
    Replace,
    /// Emitted as `"keepExisting"`.
    KeepExisting,
    /// Emitted as `"merge"`.
    Merge,
    /// Emitted as `"fail"`.
    Fail,
    /// Custom update pipeline, emitted as an array of stage documents.
    Pipeline(Vec<Document>),
}
274
/// `$merge` `whenNotMatched` behaviour.
///
/// Variant names serialize in camelCase (`insert`, `discard`, `fail`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum MergeNotMatchedAction {
    /// Emitted as `"insert"`.
    Insert,
    /// Emitted as `"discard"`.
    Discard,
    /// Emitted as `"fail"`.
    Fail,
}
286
287impl MaterializedAggregationView {
288 pub fn with_out(
290 name: impl Into<String>,
291 source_collection: impl Into<String>,
292 pipeline: Vec<Document>,
293 ) -> Self {
294 Self {
295 name: name.into(),
296 source_collection: source_collection.into(),
297 pipeline,
298 use_merge: false,
299 merge_options: None,
300 refresh_interval: None,
301 }
302 }
303
304 pub fn with_merge(
306 name: impl Into<String>,
307 source_collection: impl Into<String>,
308 pipeline: Vec<Document>,
309 merge_options: MergeOptions,
310 ) -> Self {
311 Self {
312 name: name.into(),
313 source_collection: source_collection.into(),
314 pipeline,
315 use_merge: true,
316 merge_options: Some(merge_options),
317 refresh_interval: None,
318 }
319 }
320
321 pub fn with_refresh_interval(mut self, interval: Duration) -> Self {
323 self.refresh_interval = Some(interval);
324 self
325 }
326
327 pub fn to_pipeline(&self) -> Vec<Document> {
329 let mut pipeline = self.pipeline.clone();
330
331 if self.use_merge {
332 let merge_opts = self.merge_options.as_ref().unwrap();
333 let when_matched = match &merge_opts.when_matched {
334 MergeAction::Replace => Bson::String("replace".to_string()),
335 MergeAction::KeepExisting => Bson::String("keepExisting".to_string()),
336 MergeAction::Merge => Bson::String("merge".to_string()),
337 MergeAction::Fail => Bson::String("fail".to_string()),
338 MergeAction::Pipeline(p) => {
339 Bson::Array(p.iter().cloned().map(Bson::Document).collect())
340 }
341 };
342
343 let when_not_matched = match merge_opts.when_not_matched {
344 MergeNotMatchedAction::Insert => "insert",
345 MergeNotMatchedAction::Discard => "discard",
346 MergeNotMatchedAction::Fail => "fail",
347 };
348
349 pipeline.push(doc! {
350 "$merge": {
351 "into": &self.name,
352 "on": &merge_opts.on,
353 "whenMatched": when_matched,
354 "whenNotMatched": when_not_matched,
355 }
356 });
357 } else {
358 pipeline.push(doc! { "$out": &self.name });
359 }
360
361 pipeline
362 }
363}
364
365impl MongoClient {
366 pub async fn create_view(&self, view: &AggregationView) -> MongoResult<()> {
368 let cmd = view.to_create_command(&self.config().database);
369 self.run_command(cmd).await?;
370 Ok(())
371 }
372
373 pub async fn drop_view(&self, name: &str) -> MongoResult<()> {
375 self.drop_collection(name).await
376 }
377
378 pub async fn list_views(&self) -> MongoResult<Vec<String>> {
380 let result = self
381 .run_command(doc! {
382 "listCollections": 1,
383 "filter": { "type": "view" }
384 })
385 .await?;
386
387 let cursor = result
388 .get_document("cursor")
389 .map_err(|e| MongoError::query(format!("invalid response: {}", e)))?;
390
391 let first_batch = cursor
392 .get_array("firstBatch")
393 .map_err(|e| MongoError::query(format!("invalid response: {}", e)))?;
394
395 let views = first_batch
396 .iter()
397 .filter_map(|doc| {
398 doc.as_document()
399 .and_then(|d| d.get_str("name").ok())
400 .map(String::from)
401 })
402 .collect();
403
404 Ok(views)
405 }
406
407 pub async fn get_view_definition(&self, name: &str) -> MongoResult<Option<AggregationView>> {
409 let result = self
410 .run_command(doc! {
411 "listCollections": 1,
412 "filter": { "name": name, "type": "view" }
413 })
414 .await?;
415
416 let cursor = result
417 .get_document("cursor")
418 .map_err(|e| MongoError::query(format!("invalid response: {}", e)))?;
419
420 let first_batch = cursor
421 .get_array("firstBatch")
422 .map_err(|e| MongoError::query(format!("invalid response: {}", e)))?;
423
424 if first_batch.is_empty() {
425 return Ok(None);
426 }
427
428 let doc = first_batch[0]
429 .as_document()
430 .ok_or_else(|| MongoError::query("invalid view definition"))?;
431
432 let options = doc
433 .get_document("options")
434 .map_err(|e| MongoError::query(format!("missing options: {}", e)))?;
435
436 let view_on = options
437 .get_str("viewOn")
438 .map_err(|e| MongoError::query(format!("missing viewOn: {}", e)))?;
439
440 let pipeline = options
441 .get_array("pipeline")
442 .map_err(|e| MongoError::query(format!("missing pipeline: {}", e)))?
443 .iter()
444 .filter_map(|b| b.as_document().cloned())
445 .collect();
446
447 Ok(Some(AggregationView {
448 name: name.to_string(),
449 source_collection: view_on.to_string(),
450 pipeline,
451 collation: options.get_document("collation").ok().cloned(),
452 }))
453 }
454
455 pub async fn refresh_materialized_view(
459 &self,
460 view: &MaterializedAggregationView,
461 ) -> MongoResult<u64> {
462 use futures::TryStreamExt;
463
464 let collection = self.collection_doc(&view.source_collection);
465 let pipeline = view.to_pipeline();
466
467 let cursor = collection
468 .aggregate(pipeline, None)
469 .await
470 .map_err(MongoError::from)?;
471
472 let docs: Vec<Document> = cursor.try_collect().await.map_err(MongoError::from)?;
474
475 Ok(docs.len() as u64)
476 }
477}
478
479pub mod stages {
481 use bson::{Bson, Document, doc};
482
483 pub fn match_stage(filter: Document) -> Document {
485 doc! { "$match": filter }
486 }
487
488 pub fn project(fields: Document) -> Document {
490 doc! { "$project": fields }
491 }
492
493 pub fn group(id: impl Into<Bson>, accumulators: Document) -> Document {
495 let mut group_doc = doc! { "_id": id.into() };
496 group_doc.extend(accumulators);
497 doc! { "$group": group_doc }
498 }
499
500 pub fn sort(fields: Document) -> Document {
502 doc! { "$sort": fields }
503 }
504
505 pub fn limit(n: i64) -> Document {
507 doc! { "$limit": n }
508 }
509
510 pub fn skip(n: i64) -> Document {
512 doc! { "$skip": n }
513 }
514
515 pub fn lookup(
517 from: impl Into<String>,
518 local_field: impl Into<String>,
519 foreign_field: impl Into<String>,
520 as_field: impl Into<String>,
521 ) -> Document {
522 doc! {
523 "$lookup": {
524 "from": from.into(),
525 "localField": local_field.into(),
526 "foreignField": foreign_field.into(),
527 "as": as_field.into(),
528 }
529 }
530 }
531
532 pub fn unwind(path: impl Into<String>) -> Document {
534 doc! { "$unwind": path.into() }
535 }
536
537 pub fn unwind_with_options(
539 path: impl Into<String>,
540 preserve_null: bool,
541 include_array_index: Option<&str>,
542 ) -> Document {
543 let mut unwind_doc = doc! { "path": path.into() };
544 unwind_doc.insert("preserveNullAndEmptyArrays", preserve_null);
545 if let Some(index_field) = include_array_index {
546 unwind_doc.insert("includeArrayIndex", index_field);
547 }
548 doc! { "$unwind": unwind_doc }
549 }
550
551 pub fn count(field: impl Into<String>) -> Document {
553 doc! { "$count": field.into() }
554 }
555
556 pub fn add_fields(fields: Document) -> Document {
558 doc! { "$addFields": fields }
559 }
560
561 pub fn set(fields: Document) -> Document {
563 doc! { "$set": fields }
564 }
565
566 pub fn unset(fields: Vec<&str>) -> Document {
568 if fields.len() == 1 {
569 doc! { "$unset": fields[0] }
570 } else {
571 doc! { "$unset": fields }
572 }
573 }
574
575 pub fn replace_root(new_root: impl Into<Bson>) -> Document {
577 doc! { "$replaceRoot": { "newRoot": new_root.into() } }
578 }
579
580 pub fn facet(facets: Document) -> Document {
582 doc! { "$facet": facets }
583 }
584
585 pub fn bucket(
587 group_by: impl Into<Bson>,
588 boundaries: Vec<impl Into<Bson>>,
589 default_bucket: impl Into<Bson>,
590 output: Document,
591 ) -> Document {
592 doc! {
593 "$bucket": {
594 "groupBy": group_by.into(),
595 "boundaries": boundaries.into_iter().map(|b| b.into()).collect::<Vec<_>>(),
596 "default": default_bucket.into(),
597 "output": output,
598 }
599 }
600 }
601
602 pub fn bucket_auto(group_by: impl Into<Bson>, buckets: i32, output: Document) -> Document {
604 doc! {
605 "$bucketAuto": {
606 "groupBy": group_by.into(),
607 "buckets": buckets,
608 "output": output,
609 }
610 }
611 }
612
613 pub fn sample(size: i64) -> Document {
615 doc! { "$sample": { "size": size } }
616 }
617}
618
619pub mod accumulators {
621 use bson::{Bson, doc};
622
623 pub fn sum(expr: impl Into<Bson>) -> Bson {
625 Bson::Document(doc! { "$sum": expr.into() })
626 }
627
628 pub fn avg(expr: impl Into<Bson>) -> Bson {
630 Bson::Document(doc! { "$avg": expr.into() })
631 }
632
633 pub fn min(expr: impl Into<Bson>) -> Bson {
635 Bson::Document(doc! { "$min": expr.into() })
636 }
637
638 pub fn max(expr: impl Into<Bson>) -> Bson {
640 Bson::Document(doc! { "$max": expr.into() })
641 }
642
643 pub fn first(expr: impl Into<Bson>) -> Bson {
645 Bson::Document(doc! { "$first": expr.into() })
646 }
647
648 pub fn last(expr: impl Into<Bson>) -> Bson {
650 Bson::Document(doc! { "$last": expr.into() })
651 }
652
653 pub fn push(expr: impl Into<Bson>) -> Bson {
655 Bson::Document(doc! { "$push": expr.into() })
656 }
657
658 pub fn add_to_set(expr: impl Into<Bson>) -> Bson {
660 Bson::Document(doc! { "$addToSet": expr.into() })
661 }
662
663 pub fn count() -> Bson {
665 Bson::Document(doc! { "$sum": 1 })
666 }
667
668 pub fn std_dev_pop(expr: impl Into<Bson>) -> Bson {
670 Bson::Document(doc! { "$stdDevPop": expr.into() })
671 }
672
673 pub fn std_dev_samp(expr: impl Into<Bson>) -> Bson {
675 Bson::Document(doc! { "$stdDevSamp": expr.into() })
676 }
677}
678
#[cfg(test)]
mod tests {
    //! Unit tests for the pure document-building parts of this module
    //! (no server round-trips).

    use super::*;

    #[test]
    fn test_aggregation_view_builder() {
        let view = AggregationView::builder("active_users")
            .source_collection("users")
            .match_stage(doc! { "status": "active" })
            .project_stage(doc! { "name": 1, "email": 1 })
            .build();

        assert_eq!(view.name, "active_users");
        assert_eq!(view.source_collection, "users");
        assert_eq!(view.pipeline.len(), 2);
        // Stages keep insertion order.
        assert!(view.pipeline[0].contains_key("$match"));
        assert!(view.pipeline[1].contains_key("$project"));
        assert!(view.collation.is_none());
    }

    #[test]
    fn test_view_create_command() {
        let view = AggregationView::new(
            "test_view",
            "source_col",
            vec![doc! { "$match": { "active": true } }],
        );

        let cmd = view.to_create_command("testdb");
        assert_eq!(cmd.get_str("create").unwrap(), "test_view");
        assert_eq!(cmd.get_str("viewOn").unwrap(), "source_col");

        let stages = cmd.get_array("pipeline").unwrap();
        assert_eq!(stages.len(), 1);
        assert!(stages[0].as_document().unwrap().contains_key("$match"));
        assert!(!cmd.contains_key("collation"));
    }

    #[test]
    fn test_view_create_command_with_collation() {
        let view =
            AggregationView::new("v", "c", Vec::new()).with_collation(doc! { "locale": "en" });

        let cmd = view.to_create_command("testdb");
        let collation = cmd.get_document("collation").unwrap();
        assert_eq!(collation.get_str("locale").unwrap(), "en");
    }

    #[test]
    fn test_materialized_view_out() {
        let view = MaterializedAggregationView::with_out(
            "user_stats",
            "users",
            vec![
                doc! { "$match": { "status": "active" } },
                doc! { "$group": { "_id": "$department", "count": { "$sum": 1 } } },
            ],
        );

        let pipeline = view.to_pipeline();
        assert_eq!(pipeline.len(), 3);
        assert_eq!(pipeline.last().unwrap().get_str("$out").unwrap(), "user_stats");
        // Building the pipeline must not mutate the stored stages.
        assert_eq!(view.pipeline.len(), 2);
    }

    #[test]
    fn test_materialized_view_merge() {
        let view = MaterializedAggregationView::with_merge(
            "user_stats",
            "users",
            vec![doc! { "$group": { "_id": "$department", "count": { "$sum": 1 } } }],
            MergeOptions {
                on: vec!["_id".to_string()],
                when_matched: MergeAction::Replace,
                when_not_matched: MergeNotMatchedAction::Insert,
            },
        );

        let pipeline = view.to_pipeline();
        assert_eq!(pipeline.len(), 2);

        let merge = pipeline.last().unwrap().get_document("$merge").unwrap();
        assert_eq!(merge.get_str("into").unwrap(), "user_stats");
        assert_eq!(
            merge.get_array("on").unwrap(),
            &vec![Bson::String("_id".to_string())]
        );
        assert_eq!(merge.get_str("whenMatched").unwrap(), "replace");
        assert_eq!(merge.get_str("whenNotMatched").unwrap(), "insert");
    }

    #[test]
    fn test_materialized_view_merge_pipeline_action() {
        let view = MaterializedAggregationView::with_merge(
            "totals",
            "orders",
            Vec::new(),
            MergeOptions {
                on: vec!["_id".to_string()],
                when_matched: MergeAction::Pipeline(vec![doc! { "$set": { "updated": true } }]),
                when_not_matched: MergeNotMatchedAction::Discard,
            },
        );

        let pipeline = view.to_pipeline();
        assert_eq!(pipeline.len(), 1);

        let merge = pipeline[0].get_document("$merge").unwrap();
        assert_eq!(merge.get_array("whenMatched").unwrap().len(), 1);
        assert_eq!(merge.get_str("whenNotMatched").unwrap(), "discard");
    }

    #[test]
    fn test_stages_helpers() {
        let match_doc = stages::match_stage(doc! { "status": "active" });
        assert!(match_doc.contains_key("$match"));

        let group_doc = stages::group("$department", doc! { "count": accumulators::count() });
        let group_body = group_doc.get_document("$group").unwrap();
        assert_eq!(group_body.get_str("_id").unwrap(), "$department");
        assert!(group_body.contains_key("count"));

        let lookup_doc = stages::lookup("orders", "user_id", "_id", "user_orders");
        let lookup_body = lookup_doc.get_document("$lookup").unwrap();
        assert_eq!(lookup_body.get_str("from").unwrap(), "orders");
        assert_eq!(lookup_body.get_str("as").unwrap(), "user_orders");
    }

    #[test]
    fn test_unset_and_unwind_options() {
        assert_eq!(stages::unset(vec!["a"]).get_str("$unset").unwrap(), "a");
        assert_eq!(
            stages::unset(vec!["a", "b"]).get_array("$unset").unwrap().len(),
            2
        );

        let unwind = stages::unwind_with_options("$items", true, Some("idx"));
        let spec = unwind.get_document("$unwind").unwrap();
        assert_eq!(spec.get_str("path").unwrap(), "$items");
        assert!(spec.get_bool("preserveNullAndEmptyArrays").unwrap());
        assert_eq!(spec.get_str("includeArrayIndex").unwrap(), "idx");
    }

    #[test]
    fn test_accumulators() {
        let sum = accumulators::sum("$amount");
        assert!(sum.as_document().unwrap().contains_key("$sum"));

        let avg = accumulators::avg("$price");
        assert!(avg.as_document().unwrap().contains_key("$avg"));

        let count = accumulators::count();
        assert_eq!(count.as_document().unwrap().get_i32("$sum").unwrap(), 1);
    }
}