use super::types::{DatabasePair, ReplicationStrategy};
use crate::Mask;
use async_trait::async_trait;
use bson::Document;
use indicatif::ProgressBar;
use mongodb::options::{FindOptions, InsertManyOptions};
use serde::{de::DeserializeOwned, Serialize};
use std::marker::PhantomData;
use std::sync::Arc;
/// A unit of replication work that can be driven asynchronously.
#[async_trait]
pub(crate) trait Task: Send + Sync {
    /// Runs the task to completion. Nothing is returned; implementations in
    /// this module print any failures they encounter.
    async fn run(&self);

    /// Advances `progress_bar` by `num_records`.
    ///
    /// If the bar then reports itself as finished, it is cleared and its
    /// message is set to "Complete".
    ///
    /// NOTE(review): as far as I know, indicatif's `is_finished` reflects an
    /// explicit finish/abandon call. `inc` alone may not flip it even when the
    /// position reaches the length, so this branch may never fire. Confirm
    /// against the indicatif version in use.
    fn update_progress_bar(&self, progress_bar: &ProgressBar, num_records: usize) {
        let delta = num_records as u64;
        progress_bar.inc(delta);
        if !progress_bar.is_finished() {
            return;
        }
        progress_bar.finish_and_clear();
        progress_bar.set_message("Complete");
    }
}
/// Replicates one collection through the typed model `T`.
///
/// Records are deserialized into `T`, optionally masked, and written in
/// batches through the shared [`DatabasePair`].
#[derive(Debug)]
pub(crate) struct ModelTask<T>
where
    T: Mask + Serialize + DeserializeOwned + Send + Sync + 'static,
{
    /// Shared database handle used for the read and for every batch write.
    dbs: Arc<DatabasePair>,
    /// Collection name used for both the read and the writes.
    collection_name: String,
    /// Query, batch size and read/write options for this task.
    config: TaskConfig,
    /// Progress indicator advanced after each successful batch write.
    progress_bar: Arc<ProgressBar>,
    /// Records are masked only when this is `ReplicationStrategy::Mask`.
    strategy: ReplicationStrategy,
    /// Ties the task to the model type `T` without storing a value of it.
    _phantom_data: PhantomData<T>,
}
/// Replicates one collection as raw BSON [`Document`]s.
///
/// Each document may pass through an optional masking closure before it is
/// buffered for a batch write.
pub(crate) struct ReplicatorTask<T>
where
    T: Send,
{
    /// Shared database handle used for the read and for every batch write.
    dbs: Arc<DatabasePair>,
    /// Collection name used for both the read and the writes.
    collection_name: String,
    /// Query, batch size and read/write options for this task.
    config: TaskConfig,
    /// Optional in-place transform applied to each document before buffering.
    masking_lambda: Option<Arc<dyn Fn(&mut Document) + Send + Sync>>,
    /// Progress indicator advanced after each successful batch write.
    progress_bar: Arc<ProgressBar>,
    /// Type marker only; in this module `T` is never stored or deserialized.
    _phantom_data: PhantomData<T>,
}
/// Read/write settings carried by both [`ModelTask`] and [`ReplicatorTask`].
#[derive(Debug)]
pub(crate) struct TaskConfig {
    /// Query document passed (cloned) to the `DatabasePair` read call.
    pub(crate) query: Document,
    /// Number of buffered documents that triggers a batch write. It is also
    /// the initial capacity of the write buffer.
    /// NOTE(review): because the flush check is `len >= size`, a value of 0
    /// makes every document flush on its own. Confirm that is intended.
    pub(crate) write_batch_size: u64,
    /// Find options, cloned and converted with `.into()` for the read call.
    pub(crate) read_options: FindOptions,
    /// Insert options, cloned and converted with `.into()` for every batch write.
    pub(crate) write_options: InsertManyOptions,
}
impl<T> ModelTask<T>
where
    T: Mask + Serialize + DeserializeOwned + Send + Sync + 'static,
{
    /// Builds a typed replication task for `collection_name`.
    ///
    /// `strategy` decides whether each record is masked before it is written.
    /// `progress_bar` is shared with the caller and advanced per written batch.
    pub(crate) fn new(
        dbs: Arc<DatabasePair>,
        collection_name: impl Into<String>,
        config: TaskConfig,
        strategy: ReplicationStrategy,
        progress_bar: Arc<ProgressBar>,
    ) -> Self {
        let collection_name: String = collection_name.into();
        Self {
            _phantom_data: PhantomData,
            progress_bar,
            strategy,
            config,
            collection_name,
            dbs,
        }
    }
}
impl<T> ReplicatorTask<T>
where
    T: Send,
{
    /// Builds a raw-document replication task for `collection_name`.
    ///
    /// When `masking_lambda` is `Some`, it is invoked on every document before
    /// that document is buffered for writing.
    pub(crate) fn new(
        dbs: Arc<DatabasePair>,
        collection_name: impl Into<String>,
        config: TaskConfig,
        masking_lambda: Option<Arc<dyn Fn(&mut Document) + Send + Sync>>,
        progress_bar: Arc<ProgressBar>,
    ) -> Self {
        let collection_name: String = collection_name.into();
        Self {
            _phantom_data: PhantomData,
            progress_bar,
            masking_lambda,
            config,
            collection_name,
            dbs,
        }
    }
}
impl<T: Send + Sync> ReplicatorTask<T> {
    /// Inserts `batch` into `collection_name` through the `DatabasePair`.
    ///
    /// On success the progress bar advances by `batch.len()`. On failure the
    /// error is printed and the batch is not counted. `is_final` only selects
    /// the wording of the failure message (used for the trailing batch).
    // Keep handing `write` the exact `&Vec<_>` type the inline code passed.
    #[allow(clippy::ptr_arg)]
    async fn write_batch_and_report(&self, batch: &Vec<Document>, is_final: bool) {
        match self
            .dbs
            .write::<Document>(
                &self.collection_name,
                batch,
                self.config.write_options.clone().into(),
            )
            .await
        {
            Ok(_) => self.update_progress_bar(&self.progress_bar, batch.len()),
            Err(err) => {
                if is_final {
                    println!(
                        "Failed to insert final batch of {} records into collection: `{}`. Error: {}",
                        batch.len(),
                        &self.collection_name,
                        err
                    );
                } else {
                    println!(
                        "Failed to insert batch of {} records into collection: `{}`. Error: {}",
                        batch.len(),
                        &self.collection_name,
                        err
                    );
                }
            }
        }
    }
}

#[async_trait]
impl<T: Send + Sync> Task for ReplicatorTask<T> {
    /// Streams every document matching `config.query` and writes it in batches.
    ///
    /// Each document is passed through `masking_lambda` when one is set, and
    /// writes happen in batches of `config.write_batch_size`.
    ///
    /// No error propagates out of this method:
    /// - a failed read aborts the run;
    /// - a cursor error stops iteration, but buffered documents are still flushed;
    /// - an undecodable document is skipped;
    /// - a failed write drops that batch.
    ///
    /// Every such case is printed with `println!`.
    async fn run(&self) {
        let mut cursor = match self
            .dbs
            .read_documents(
                &self.collection_name,
                self.config.query.clone(),
                self.config.read_options.clone().into(),
            )
            .await
        {
            Ok(found) => found,
            Err(err) => {
                println!(
                    "Failed to retrieve cursor for collection: `{}` using Query: {:?} with read options: {:?}. Encountered error: {}",
                    &self.collection_name,
                    &self.config.query,
                    &self.config.read_options,
                    err
                );
                return;
            }
        };

        let batch_limit = self.config.write_batch_size as usize;
        let mut pending: Vec<Document> = Vec::with_capacity(batch_limit);
        let mut processed_count = 0;

        loop {
            match cursor.advance().await {
                Ok(true) => {}
                Ok(false) => break,
                Err(err) => {
                    println!(
                        "Error advancing cursor for collection: `{}`. Stopping task. Error: {}",
                        &self.collection_name, err
                    );
                    break;
                }
            }

            let mut doc = match cursor.deserialize_current() {
                Ok(decoded) => decoded,
                Err(err) => {
                    println!(
                        "Failed to deserialize document for collection: `{}`. Skipping document. Error: {}",
                        &self.collection_name, err
                    );
                    continue;
                }
            };

            if let Some(mask) = &self.masking_lambda {
                mask(&mut doc);
            }
            pending.push(doc);
            processed_count += 1;

            if pending.len() >= batch_limit {
                self.write_batch_and_report(&pending, false).await;
                pending.clear();
            }
        }

        // Flush whatever is left after the cursor is exhausted or errored.
        if !pending.is_empty() {
            self.write_batch_and_report(&pending, true).await;
        }

        if processed_count == 0 {
            println!(
                "No records found or processed for batch. Query: {:?} with read options: {:?}",
                &self.config.query, &self.config.read_options,
            );
        }
    }
}
impl<T: Mask + Serialize + DeserializeOwned + Send + Sync + Unpin> ModelTask<T> {
    /// Inserts `batch` of typed records into `collection_name`.
    ///
    /// On success the progress bar advances by `batch.len()`. On failure the
    /// error is printed and the batch is not counted. `is_final` only selects
    /// the wording of the failure message (used for the trailing batch).
    // Keep handing `write` the exact `&Vec<_>` type the inline code passed.
    #[allow(clippy::ptr_arg)]
    async fn write_batch_and_report(&self, batch: &Vec<T>, is_final: bool) {
        match self
            .dbs
            .write::<T>(
                &self.collection_name,
                batch,
                self.config.write_options.clone().into(),
            )
            .await
        {
            Ok(_) => self.update_progress_bar(&self.progress_bar, batch.len()),
            Err(err) => {
                if is_final {
                    println!(
                        "Failed to insert final batch of {} records into collection: `{}`. Error: {}",
                        batch.len(),
                        &self.collection_name,
                        err
                    );
                } else {
                    println!(
                        "Failed to insert batch of {} records into collection: `{}`. Error: {}",
                        batch.len(),
                        &self.collection_name,
                        err
                    );
                }
            }
        }
    }
}

#[async_trait]
impl<T: Mask + Serialize + DeserializeOwned + Send + Sync + Unpin> Task for ModelTask<T> {
    /// Streams every record matching `config.query`, deserialized as `T`, and
    /// writes the records in batches of `config.write_batch_size`.
    ///
    /// Each record is masked with `Mask::mask` when `strategy` is
    /// `ReplicationStrategy::Mask`.
    ///
    /// No error propagates out of this method:
    /// - a failed read aborts the run;
    /// - a cursor error stops iteration, but buffered records are still flushed;
    /// - an undecodable record is skipped;
    /// - a failed write drops that batch.
    ///
    /// Every such case is printed with `println!`.
    async fn run(&self) {
        let mut cursor = match self
            .dbs
            .read::<T>(
                &self.collection_name,
                self.config.query.clone(),
                self.config.read_options.clone().into(),
            )
            .await
        {
            Ok(found) => found,
            Err(err) => {
                println!(
                    "Failed to retrieve cursor for collection: `{}` using Query: {:?} with read options: {:?}. Encountered error: {}",
                    &self.collection_name,
                    &self.config.query,
                    &self.config.read_options,
                    err
                );
                return;
            }
        };

        let batch_limit = self.config.write_batch_size as usize;
        let mut pending: Vec<T> = Vec::with_capacity(batch_limit);
        let mut processed_count = 0;
        let should_mask = matches!(self.strategy, ReplicationStrategy::Mask);

        loop {
            match cursor.advance().await {
                Ok(true) => {}
                Ok(false) => break,
                Err(err) => {
                    println!(
                        "Error advancing cursor for collection: `{}`. Stopping task. Error: {}",
                        &self.collection_name, err
                    );
                    break;
                }
            }

            let mut record = match cursor.deserialize_current() {
                Ok(decoded) => decoded,
                Err(err) => {
                    println!(
                        "Failed to deserialize document for collection: `{}`. Skipping document. Error: {}",
                        &self.collection_name, err
                    );
                    continue;
                }
            };

            if should_mask {
                record.mask();
            }
            pending.push(record);
            processed_count += 1;

            if pending.len() >= batch_limit {
                self.write_batch_and_report(&pending, false).await;
                pending.clear();
            }
        }

        // Flush whatever is left after the cursor is exhausted or errored.
        if !pending.is_empty() {
            self.write_batch_and_report(&pending, true).await;
        }

        if processed_count == 0 {
            println!(
                "No records found or processed for batch. Query: {:?} with read options: {:?}",
                &self.config.query, &self.config.read_options,
            );
        }
    }
}