use super::{
manager::ReplicationConfig,
task::{ModelTask, ReplicatorTask, Task},
types::DatabasePair,
};
use crate::replication::task::TaskConfig;
use crate::{Mask, TuxedoResult};
use async_trait::async_trait;
use bson::{Document, RawDocumentBuf};
use indicatif::ProgressBar;
use serde::{de::DeserializeOwned, Serialize};
use std::marker::PhantomData;
use std::sync::Arc;
use tokio::sync::mpsc;
/// Drives replication of a single collection by splitting it into batches and
/// sending one boxed [`Task`] per batch to the worker pool.
///
/// Implementors supply [`Processor::run`] and [`Processor::collection_name`];
/// the remaining methods are shared helpers with default implementations.
#[async_trait]
pub(crate) trait Processor: Send + Sync {
    /// Creates the tasks for this collection and sends them on `task_sender`.
    ///
    /// `default_config` carries the replication-wide settings and
    /// `progress_bar` is the bar used to report progress for this collection.
    async fn run(
        &self,
        dbs: Arc<DatabasePair>,
        task_sender: mpsc::Sender<Box<dyn Task>>,
        default_config: ReplicationConfig,
        progress_bar: ProgressBar,
    );

    /// Counts the documents of this collection that match `query`.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by `read_total_documents` after printing a
    /// message saying that the collection will be skipped.
    async fn get_total_documents(
        &self,
        dbs: &Arc<DatabasePair>,
        query: Document,
    ) -> TuxedoResult<usize> {
        match dbs
            .read_total_documents::<RawDocumentBuf>(self.collection_name(), query)
            .await
        {
            Ok(total_documents) => Ok(total_documents),
            Err(e) => {
                println!(
                    "Could not get total number of documents for collection: `{}`. Collection will be skipped. Encountered error: {e}",
                    self.collection_name(),
                );
                Err(e)
            }
        }
    }

    /// Wraps `progress_bar` in an [`Arc`] so it can be shared between tasks,
    /// sets its length to `total_documents` and labels it
    /// `"<collection> (<entity_name>)"`.
    fn setup_progress_bar(
        &self,
        progress_bar: ProgressBar,
        total_documents: usize,
        entity_name: &str,
    ) -> Arc<ProgressBar> {
        let progress_bar = Arc::new(progress_bar);
        progress_bar.set_length(total_documents as u64);
        progress_bar.set_message(format!("{} ({})", self.collection_name(), entity_name,));
        progress_bar
    }

    /// Derives a batch size (in documents) from the collection's average
    /// document size.
    ///
    /// The byte budget for a batch comes from `calculate_optimal_target_bytes`;
    /// the result is that budget divided by the average size, never below `1`.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by `get_average_document_size`.
    async fn setup_adaptive_batching(&self, dbs: &Arc<DatabasePair>) -> TuxedoResult<u64> {
        let average_document_size = dbs
            .get_average_document_size(self.collection_name())
            .await?;
        let target_bytes = calculate_optimal_target_bytes(average_document_size);
        // A reported average of 0 bytes would panic with a division by zero;
        // treat it as one byte so the computation always succeeds.
        let optimal_document_count = target_bytes / average_document_size.max(1);
        Ok(optimal_document_count.max(1))
    }

    /// Copies the indexes of this collection from source to target.
    ///
    /// Best effort: a failure is printed and otherwise ignored.
    async fn copy_indexes(&self, dbs: &Arc<DatabasePair>) {
        if let Err(e) = dbs.copy_indexes(self.collection_name()).await {
            println!(
                "Error when copying indexes for collection `{}` from source to target - Error: {:?}",
                self.collection_name(),
                e
            )
        }
    }

    /// Name of the collection this processor works on.
    fn collection_name(&self) -> &str;
}
/// Processor that replicates a collection through the model type `T`.
///
/// `T` is never stored; it is only carried as a [`PhantomData`] marker so the
/// tasks created by this processor can be typed as `ModelTask<T>`.
pub(crate) struct ModelProcessor<T>
where
    T: Mask + Serialize + DeserializeOwned + Send + Sync + Unpin,
{
    /// Per-collection overrides for the replication-wide defaults.
    config: ProcessorConfig,
    /// Name of the collection to replicate.
    collection_name: String,
    /// Zero-sized marker tying the processor to `T`.
    _phantom_data: PhantomData<T>,
}
impl<T> ModelProcessor<T>
where
    T: Mask + Serialize + DeserializeOwned + Send + Sync + Unpin,
{
    /// Creates a processor for `collection_name` configured by `config`.
    ///
    /// `collection_name` accepts anything convertible into a `String`.
    pub(crate) fn new(collection_name: impl Into<String>, config: ProcessorConfig) -> Self {
        let collection_name = collection_name.into();
        Self {
            collection_name,
            config,
            _phantom_data: PhantomData::<T>,
        }
    }
}
/// Processor that replicates a collection by creating `ReplicatorTask<T>`
/// instances, passing along the optional closure stored in its config.
///
/// `T` is never stored; it is only carried as a [`PhantomData`] marker.
pub(crate) struct ReplicatorProcessor<T>
where
    T: Send + Sync,
{
    /// Per-collection overrides plus the optional document closure.
    config: ReplicatorConfig,
    /// Name of the collection to replicate.
    collection_name: String,
    /// Zero-sized marker tying the processor to `T`.
    _phantom_data: PhantomData<T>,
}
impl<T> ReplicatorProcessor<T>
where
    T: Send + Sync,
{
    /// Creates a processor for `collection_name` configured by `config`.
    pub(crate) fn new(config: ReplicatorConfig, collection_name: String) -> Self {
        Self {
            collection_name,
            config,
            _phantom_data: PhantomData::<T>,
        }
    }
}
#[async_trait]
impl<T: Mask + Serialize + DeserializeOwned + Send + Sync + Unpin + 'static> Processor
    for ModelProcessor<T>
{
    /// Splits the collection into batches and sends one [`ModelTask`] per
    /// batch on `task_sender`.
    ///
    /// Per-collection settings take precedence over `default_config`. The
    /// method returns early, without sending anything, when the document count
    /// cannot be read or is zero. It stops sending as soon as the channel is
    /// closed.
    async fn run(
        &self,
        dbs: Arc<DatabasePair>,
        task_sender: mpsc::Sender<Box<dyn Task>>,
        default_config: ReplicationConfig,
        progress_bar: ProgressBar,
    ) {
        let mut batch_size = self.config.batch_size.unwrap_or(default_config.batch_size);
        let write_batch_size = self
            .config
            .write_batch_size
            .unwrap_or(default_config.write_batch_size);

        // `get_total_documents` already printed the reason on failure.
        let total_documents = match self
            .get_total_documents(&dbs, self.config.query.clone())
            .await
        {
            Ok(total_documents) => total_documents,
            Err(_) => return,
        };

        // The bar is labelled with the last path segment of `T`'s type name.
        let progress_bar = self.setup_progress_bar(
            progress_bar,
            total_documents,
            std::any::type_name::<T>()
                .split("::")
                .last()
                .expect("Expected to get model name for progress bar"),
        );
        if total_documents == 0 {
            progress_bar.finish_and_clear();
            return;
        }

        // An explicit per-collection choice wins over the global default, so
        // `adaptive_batching(false)` can opt a single collection out, in line
        // with how `batch_size` and `write_batch_size` are resolved above.
        let adaptive_batching = self
            .config
            .adaptive_batching
            .unwrap_or(default_config.adaptive_batching);
        if adaptive_batching {
            // On failure the configured batch size is kept.
            if let Ok(adaptive_batch_size) = self.setup_adaptive_batching(&dbs).await {
                batch_size = adaptive_batch_size;
            }
        }

        // A configured batch size of 0 would make `div_ceil` panic with a
        // division by zero; fall back to one document per batch.
        let batch_size = batch_size.max(1);
        let batch_count = total_documents.div_ceil(batch_size as usize);
        let strategy = default_config.strategy;
        let write_options = default_config.write_options;

        for batch_index in 0..batch_count {
            let skip = batch_index * batch_size as usize;
            let remaining_documents = total_documents.saturating_sub(skip);
            let limit = batch_size.min(remaining_documents as u64) as i64;
            if limit == 0 {
                break;
            }

            let dbs = Arc::clone(&dbs);
            let query = self.config.query.clone();
            let strategy = strategy.clone();
            let progress_bar = Arc::clone(&progress_bar);

            // Each task reads its own window of the collection.
            let mut read_options = default_config.read_options.clone();
            read_options.skip = (skip as u64).into();
            read_options.limit = limit.into();
            // NOTE(review): `as u32` wraps when `limit` exceeds `u32::MAX`.
            read_options.batch_size = Some(limit as u32);

            let task = Box::new(ModelTask::<T>::new(
                dbs,
                self.collection_name.clone(),
                TaskConfig {
                    query,
                    write_batch_size,
                    read_options,
                    write_options: write_options.clone(),
                },
                strategy,
                progress_bar,
            ));
            if task_sender.send(task).await.is_err() {
                println!(
                    "Failed to send task to worker pool for collection '{}' (batch {}/{}). Channel closed, stopping processor.",
                    &self.collection_name,
                    batch_index + 1,
                    batch_count
                );
                break;
            }
        }
    }

    /// Name of the collection this processor works on.
    fn collection_name(&self) -> &str {
        &self.collection_name
    }
}
fn calculate_optimal_target_bytes(average_document_size: u64) -> u64 {
if average_document_size < 1024 {
return to_mb(75);
}
if average_document_size < 10 * 1024 {
return to_mb(50);
}
if average_document_size < 100 * 1024 {
return to_mb(30);
}
if average_document_size < 500 * 1024 {
return to_mb(15);
}
to_mb(5)
}
/// Converts a size expressed in mebibytes into bytes.
fn to_mb(size: u64) -> u64 {
    const BYTES_PER_MIB: u64 = 1024 * 1024;
    size * BYTES_PER_MIB
}
#[async_trait]
impl<T: Send + Sync + 'static> Processor for ReplicatorProcessor<T> {
    /// Splits the collection into batches and sends one [`ReplicatorTask`] per
    /// batch on `task_sender`, handing each task the configured closure.
    ///
    /// Per-collection settings take precedence over `default_config`. The
    /// method returns early, without sending anything, when the document count
    /// cannot be read or is zero. It stops sending as soon as the channel is
    /// closed.
    async fn run(
        &self,
        dbs: Arc<DatabasePair>,
        task_sender: mpsc::Sender<Box<dyn Task>>,
        default_config: ReplicationConfig,
        progress_bar: ProgressBar,
    ) {
        let mut batch_size = self.config.batch_size.unwrap_or(default_config.batch_size);
        let write_batch_size = self
            .config
            .write_batch_size
            .unwrap_or(default_config.write_batch_size);

        // `get_total_documents` already printed the reason on failure.
        let total_documents = match self
            .get_total_documents(&dbs, self.config.query.clone())
            .await
        {
            Ok(total_documents) => total_documents,
            Err(_) => return,
        };

        let progress_bar = self.setup_progress_bar(progress_bar, total_documents, "Document");
        if total_documents == 0 {
            progress_bar.finish_and_clear();
            return;
        }

        // An explicit per-collection choice wins over the global default, so
        // `adaptive_batching(false)` can opt a single collection out, in line
        // with how `batch_size` and `write_batch_size` are resolved above.
        let adaptive_batching = self
            .config
            .adaptive_batching
            .unwrap_or(default_config.adaptive_batching);
        if adaptive_batching {
            // On failure the configured batch size is kept.
            if let Ok(adaptive_batch_size) = self.setup_adaptive_batching(&dbs).await {
                batch_size = adaptive_batch_size;
            }
        }

        // A configured batch size of 0 would make `div_ceil` panic with a
        // division by zero; fall back to one document per batch.
        let batch_size = batch_size.max(1);
        let batch_count = total_documents.div_ceil(batch_size as usize);
        let write_options = default_config.write_options;

        for batch_index in 0..batch_count {
            let skip = batch_index * batch_size as usize;
            let remaining_documents = total_documents.saturating_sub(skip);
            let limit = batch_size.min(remaining_documents as u64) as i64;
            if limit == 0 {
                break;
            }

            let dbs = Arc::clone(&dbs);
            let query = self.config.query.clone();
            let progress_bar = Arc::clone(&progress_bar);

            // Each task reads its own window of the collection.
            let mut read_options = default_config.read_options.clone();
            read_options.skip = (skip as u64).into();
            read_options.limit = limit.into();
            // NOTE(review): `as u32` wraps when `limit` exceeds `u32::MAX`.
            read_options.batch_size = Some(limit as u32);

            let task = Box::new(ReplicatorTask::<T>::new(
                dbs,
                self.collection_name.clone(),
                TaskConfig {
                    query,
                    write_batch_size,
                    read_options,
                    write_options: write_options.clone(),
                },
                self.config.lambda.clone(),
                progress_bar,
            ));
            if task_sender.send(task).await.is_err() {
                println!(
                    "Failed to send task to worker pool for collection '{}' (batch {}/{}). Channel closed, stopping processor.",
                    &self.collection_name,
                    batch_index + 1,
                    batch_count
                );
                break;
            }
        }
    }

    /// Name of the collection this processor works on.
    fn collection_name(&self) -> &str {
        &self.collection_name
    }
}
/// Per-collection settings read by `ModelProcessor` when it builds its tasks.
///
/// Values are set through `ProcessorConfigBuilder`.
#[derive(Debug, Default)]
pub struct ProcessorConfig {
/// `Some(true)` enables adaptive batching for this collection.
adaptive_batching: Option<bool>,
/// Documents per task; `None` uses the replication-wide batch size.
batch_size: Option<u64>,
/// Passed to each task's `TaskConfig`; `None` uses the replication-wide value.
write_batch_size: Option<u64>,
/// Filter used to count the documents and passed on to every task.
query: Document,
}
/// Builder for `ProcessorConfig`; starts from the config's `Default` value.
#[derive(Debug, Default)]
pub struct ProcessorConfigBuilder {
/// Configuration being assembled; returned as-is by `build`.
config: ProcessorConfig,
}
impl ProcessorConfig {
    /// Starts a [`ProcessorConfigBuilder`] with every field at its `Default`
    /// value.
    pub fn builder() -> ProcessorConfigBuilder {
        ProcessorConfigBuilder::default()
    }
}
impl ProcessorConfigBuilder {
pub fn new() -> Self {
Default::default()
}
pub fn batch_size(mut self, size: impl Into<u64>) -> Self {
self.config.batch_size = Some(size.into());
self
}
pub fn write_batch_size(mut self, size: impl Into<u64>) -> Self {
self.config.write_batch_size = Some(size.into());
self
}
pub fn query<Q: Into<Document>>(mut self, query: Q) -> Self {
self.config.query = query.into();
self
}
pub fn adaptive_batching(mut self, enabled: bool) -> Self {
self.config.adaptive_batching = Some(enabled);
self
}
pub fn build(self) -> ProcessorConfig {
self.config
}
}
/// Per-collection settings read by `ReplicatorProcessor` when it builds its
/// tasks.
///
/// Values are set through `ReplicationConfigBuilder`.
#[derive(Default)]
pub struct ReplicatorConfig {
/// `Some(true)` enables adaptive batching for this collection.
adaptive_batching: Option<bool>,
/// Documents per task; `None` uses the replication-wide batch size.
batch_size: Option<u64>,
/// Passed to each task's `TaskConfig`; `None` uses the replication-wide value.
write_batch_size: Option<u64>,
/// Filter used to count the documents and passed on to every task.
query: Document,
/// Optional closure handed to every `ReplicatorTask`.
/// NOTE(review): presumably applied to each document to mask it (the builder
/// method that sets it is named `mask`) - confirm in `ReplicatorTask`.
lambda: Option<Arc<dyn Fn(&mut Document) + Send + Sync>>,
}
impl ReplicatorConfig {
    /// Assembles a configuration from its individual parts.
    ///
    /// Private: external callers go through [`ReplicatorConfig::builder`].
    fn new(
        batch_size: Option<u64>,
        write_batch_size: Option<u64>,
        query: Document,
        adaptive_batching: Option<bool>,
        lambda: Option<Arc<dyn Fn(&mut Document) + Send + Sync>>,
    ) -> Self {
        let assembled = Self {
            adaptive_batching,
            batch_size,
            write_batch_size,
            query,
            lambda,
        };
        assembled
    }

    /// Starts a [`ReplicationConfigBuilder`] with every field at its `Default`
    /// value.
    pub fn builder() -> ReplicationConfigBuilder {
        ReplicationConfigBuilder::default()
    }
}
/// Builder for `ReplicatorConfig`; each field mirrors the field of the same
/// name on the config and is passed to `ReplicatorConfig::new` by `build`.
#[derive(Default)]
pub struct ReplicationConfigBuilder {
/// Documents per task, if overridden for this collection.
batch_size: Option<u64>,
/// Write batch size, if overridden for this collection.
write_batch_size: Option<u64>,
/// Filter selecting the documents to replicate.
query: Document,
/// Explicit adaptive-batching choice, if any.
adaptive_batching: Option<bool>,
/// Optional closure set through `mask`.
lambda: Option<Arc<dyn Fn(&mut Document) + Send + Sync>>,
}
impl ReplicationConfigBuilder {
    /// Creates a builder with every field at its `Default` value.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets (or, with `None`, clears) the number of documents per task.
    pub fn batch_size(self, size: impl Into<Option<u64>>) -> Self {
        Self {
            batch_size: size.into(),
            ..self
        }
    }

    /// Sets (or, with `None`, clears) the write batch size passed to tasks.
    pub fn write_batch_size(self, size: impl Into<Option<u64>>) -> Self {
        Self {
            write_batch_size: size.into(),
            ..self
        }
    }

    /// Sets the filter selecting which documents are replicated.
    pub fn query(self, query: impl Into<Document>) -> Self {
        Self {
            query: query.into(),
            ..self
        }
    }

    /// Records an explicit adaptive-batching choice for this collection.
    pub fn adaptive_batching(self, enabled: impl Into<bool>) -> Self {
        Self {
            adaptive_batching: Some(enabled.into()),
            ..self
        }
    }

    /// Stores `lambda`, wrapped in an [`Arc`], as the closure handed to tasks.
    pub fn mask<F>(self, lambda: F) -> Self
    where
        F: Fn(&mut Document) + Send + Sync + 'static,
    {
        let shared: Arc<dyn Fn(&mut Document) + Send + Sync> = Arc::new(lambda);
        Self {
            lambda: Some(shared),
            ..self
        }
    }

    /// Finishes the builder and returns the assembled [`ReplicatorConfig`].
    pub fn build(self) -> ReplicatorConfig {
        let Self {
            batch_size,
            write_batch_size,
            query,
            adaptive_batching,
            lambda,
        } = self;
        ReplicatorConfig::new(batch_size, write_batch_size, query, adaptive_batching, lambda)
    }
}