use super::manager::{ReplicationConfig, ReplicationManager};
use super::processor::{Processor, ProcessorConfig, ReplicatorConfig};
use crate::replication::processor::{ModelProcessor, ReplicatorProcessor};
use crate::replication::types::{DatabasePair, ReplicationStrategy};
use crate::{Mask, TuxedoError, TuxedoResult};
use bson::Document;
use mongodb::options::FindOptions;
use mongodb::{
options::{ClientOptions, Compressor, InsertManyOptions, ReadConcern},
Client,
};
use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc;
use tokio::sync::mpsc;
use url::Url;
pub struct ReplicationManagerBuilder {
source_uri: Option<String>,
target_uri: Option<String>,
source_db: Option<String>,
target_db: Option<String>,
thread_count: usize,
config: ReplicationConfig,
compressors: Option<Vec<Compressor>>,
processors: Vec<Box<dyn Processor>>,
}
impl Default for ReplicationManagerBuilder {
fn default() -> Self {
Self::new()
}
}
impl ReplicationManagerBuilder {
pub fn new() -> Self {
let available_threads = num_cpus::get();
Self {
source_uri: None,
target_uri: None,
source_db: None,
target_db: None,
thread_count: available_threads,
config: ReplicationConfig::default(),
compressors: None,
processors: Vec::new(),
}
}
pub fn source_uri<S: Into<String>>(mut self, uri: S) -> Self {
self.source_uri = Some(uri.into());
self
}
pub fn target_uri<S: Into<String>>(mut self, uri: S) -> Self {
self.target_uri = Some(uri.into());
self
}
pub fn source_db<S: Into<String>>(mut self, db_name: S) -> Self {
self.source_db = Some(db_name.into());
self
}
pub fn target_db<S: Into<String>>(mut self, db_name: S) -> Self {
self.target_db = Some(db_name.into());
self
}
pub fn thread_count<S: Into<usize>>(mut self, count: S) -> Self {
self.thread_count = count.into();
self
}
pub fn strategy<S: Into<ReplicationStrategy>>(mut self, strategy: S) -> Self {
self.config.strategy = strategy.into();
self
}
pub fn batch_size(mut self, size: impl Into<u64>) -> Self {
self.config.batch_size = size.into();
self
}
pub fn write_batch_size(mut self, size: impl Into<u64>) -> Self {
self.config.write_batch_size = size.into();
self
}
pub fn write_options(mut self, options: impl Into<InsertManyOptions>) -> Self {
self.config.write_options = options.into();
self
}
pub fn read_options(mut self, options: impl Into<FindOptions>) -> Self {
self.config.read_options = options.into();
self
}
pub fn add_compression(mut self) -> Self {
self.compressors = Some(vec![
Compressor::Zstd { level: None },
Compressor::Zlib { level: None },
Compressor::Snappy,
]);
self
}
pub fn adaptive_batching(mut self) -> Self {
self.config.adaptive_batching = true;
self
}
pub fn copy_views(mut self, enabled: bool) -> Self {
self.config.copy_views = enabled;
self
}
pub fn optimize_for_performance(self, compression: bool) -> Self {
let mut builder = self;
builder.config.write_options.ordered = false.into();
builder.config.write_options.bypass_document_validation = true.into();
if compression {
builder = builder.add_compression();
}
builder
}
pub fn add_processor<T: Mask + Serialize + DeserializeOwned + Send + Sync + Unpin + 'static>(
self,
collection_name: impl Into<String>,
) -> Self {
let config = ProcessorConfig::default();
self.add_processor_with_config::<T>(collection_name, config)
}
pub fn add_processor_with_config<
T: Mask + Serialize + DeserializeOwned + Send + Sync + Unpin + 'static,
>(
mut self,
collection_name: impl Into<String>,
config: ProcessorConfig,
) -> Self {
self.processors
.push(Box::new(ModelProcessor::<T>::new(collection_name, config)));
self
}
pub fn add_replicator(self, collection_name: impl Into<String>) -> Self {
let config = ReplicatorConfig::default();
self.add_replicator_with_config(collection_name.into(), config)
}
pub fn add_replicator_with_config(
mut self,
collection_name: impl Into<String>,
config: ReplicatorConfig,
) -> Self {
self.processors
.push(Box::new(ReplicatorProcessor::<Document>::new(
config,
collection_name.into(),
)));
self
}
pub async fn build(self) -> TuxedoResult<ReplicationManager> {
let source_uri = self
.source_uri
.clone()
.ok_or(TuxedoError::ConfigError("No source_uri provided.".into()))?;
let target_uri = self
.target_uri
.clone()
.ok_or(TuxedoError::ConfigError("No target_uri provided.".into()))?;
let compressors = self.compressors.clone();
let max_pool_size = (self.config.thread_count * 2) as u32;
let min_pool_size = self.config.thread_count as u32;
let max_connecting = self.config.thread_count as u32;
let mut source_client_options = ClientOptions::parse(&source_uri).await?;
source_client_options.max_pool_size = max_pool_size.into();
source_client_options.min_pool_size = min_pool_size.into();
source_client_options.max_connecting = max_connecting.into();
source_client_options.compressors = compressors.clone();
source_client_options.read_concern = ReadConcern::majority().into();
let source_client = Client::with_options(source_client_options)?;
let mut target_client_options = ClientOptions::parse(&target_uri).await?;
target_client_options.max_pool_size = max_pool_size.into();
target_client_options.min_pool_size = min_pool_size.into();
target_client_options.max_connecting = max_connecting.into();
target_client_options.compressors = compressors;
let target_client = Client::with_options(target_client_options)?;
let source_db_name = self.get_db_name(&source_uri, self.source_db.clone())?;
let target_db_name = self.get_db_name(&target_uri, self.target_db.clone())?;
let dbs = Arc::new(DatabasePair::new(
source_client.database(&source_db_name),
target_client.database(&target_db_name),
));
dbs.test_database_collection_source()
.await
.expect("Could not create test connection to source database");
dbs.test_database_collection_target()
.await
.expect("Could not create test connection to target database");
println!("Dropping collections and views from target database before beginning...");
let mut items_to_drop: Vec<String> = self
.processors
.iter()
.map(|p| p.collection_name().to_string())
.collect();
if self.config.copy_views {
let view_names = dbs
.get_source_view_names()
.await
.expect("Expected to successfully get source view names");
items_to_drop.extend(view_names);
}
dbs.clear_target_collections(&items_to_drop)
.await
.expect("Expected to successfully drop target database collections and views before replication");
let (task_sender, task_receiver) = mpsc::channel(self.config.thread_count);
let manager = ReplicationManager {
dbs,
processors: self.processors.into_iter().map(|p| p.into()).collect(),
config: self.config,
task_receiver,
task_sender,
};
Ok(manager)
}
fn get_db_name(&self, uri: &str, db_name: Option<String>) -> TuxedoResult<String> {
let parsed_db_name = self.parse_db_name_from_uri(uri)?;
match parsed_db_name {
Some(db_name) => Ok(db_name),
None => match db_name {
Some(db_name) => Ok(db_name),
None => Err(TuxedoError::ConfigError(
"Could not parse database name from URI and no database name provided.".into(),
)),
},
}
}
fn parse_db_name_from_uri(&self, url: &str) -> TuxedoResult<Option<String>> {
let url = Url::parse(url).map_err(|e| TuxedoError::ConfigError(e.to_string()))?;
let path = url.path();
if path.is_empty() || path == "/" {
return Ok(None);
}
let db_name = path.trim_start_matches('/').split('?').next().unwrap();
if db_name.is_empty() {
Ok(None)
} else {
Ok(Some(db_name.to_string()))
}
}
}