use mongodb::{Client, ClientSession, Database, options::ClientOptions};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
use crate::config::ConnectionConfig;
use crate::error::{ConnectionError, Result};
/// Owns the MongoDB driver client for one connection string and tracks its lifecycle state.
pub struct ConnectionManager {
    /// Driver client: `None` until `connect` succeeds, and reset to `None` by `disconnect`.
    client: Option<Client>,
    /// Pool sizing, timeout and retry settings applied while connecting.
    config: ConnectionConfig,
    /// Current lifecycle state, guarded by an async (tokio) read/write lock.
    state: Arc<RwLock<ConnectionState>>,
    /// Raw connection string; it may embed credentials, so log it only via `sanitize_uri`.
    uri: String,
}
/// Lifecycle state reported by a `ConnectionManager`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No client exists: the initial state, and the state after a successful `disconnect`.
    Disconnected,
    /// A `connect` call is in progress.
    Connecting,
    /// A driver client was created successfully.
    Connected,
    /// A `connect` call failed; carries a human-readable error message.
    Failed(String),
    /// Set at the start of `reconnect`. That method may be unused by the binary, hence the allow.
    #[allow(dead_code)]
    Reconnecting,
}
/// Connection-pool tuning values, typically derived from a `ConnectionConfig`.
///
/// NOTE(review): `ConnectionManager::configure_pool` reads `ConnectionConfig` directly
/// rather than going through this struct.
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct PoolConfig {
    /// Maximum pool size (mirrors `ConnectionConfig::max_pool_size`).
    pub max_size: u32,
    /// Minimum pool size (mirrors `ConnectionConfig::min_pool_size`).
    pub min_idle: u32,
    /// Connection timeout (built from `ConnectionConfig::timeout`, in seconds).
    pub connection_timeout: Duration,
    /// Idle timeout (built from `ConnectionConfig::idle_timeout`, in seconds).
    pub idle_timeout: Duration,
}
/// Snapshot produced by `ConnectionManager::health_check`.
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct HealthStatus {
    /// Whether the connection is considered healthy.
    pub is_healthy: bool,
    /// Elapsed time in milliseconds, as measured by `health_check`.
    pub response_time_ms: u64,
    /// Server version string, when one could be obtained.
    pub server_version: Option<String>,
    /// Optional free-form diagnostic note.
    pub diagnostics: Option<String>,
}
impl ConnectionManager {
    /// Creates a manager for `uri` that starts in [`ConnectionState::Disconnected`].
    ///
    /// No client is built until [`Self::connect`] is called.
    pub fn new(uri: String, config: ConnectionConfig) -> Self {
        Self {
            client: None,
            config,
            state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
            uri,
        }
    }

    /// Parses the URI, applies the pool configuration and creates the driver client.
    ///
    /// The state becomes `Connecting` immediately. It then becomes `Connected` on success,
    /// or `Failed(..)` on *any* error. That includes a URI that cannot be parsed, so the
    /// state is never left stuck at `Connecting`.
    ///
    /// # Errors
    ///
    /// Returns the URI-parsing error (built from `ConnectionError::InvalidUri`) or the error
    /// from the last client-creation attempt (built from `ConnectionError::ConnectionFailed`).
    pub async fn connect(&mut self) -> Result<()> {
        info!("Connecting to MongoDB: {}", self.sanitize_uri(&self.uri));
        self.set_state(ConnectionState::Connecting).await;
        match self.build_client().await {
            Ok(client) => {
                self.client = Some(client);
                self.set_state(ConnectionState::Connected).await;
                info!("Successfully connected to MongoDB");
                Ok(())
            }
            Err(e) => {
                let msg = format!("Failed to connect: {}", e);
                error!("{}", msg);
                self.set_state(ConnectionState::Failed(msg)).await;
                Err(e)
            }
        }
    }

    /// Drops the current client (if any) and moves to `Disconnected`.
    ///
    /// When no client exists this only logs, and the state is left unchanged.
    /// Always returns `Ok(())`.
    #[allow(dead_code)]
    pub async fn disconnect(&mut self) -> Result<()> {
        info!("Disconnecting from MongoDB");
        if self.client.take().is_some() {
            self.set_state(ConnectionState::Disconnected).await;
            info!("Disconnected from MongoDB");
        } else {
            debug!("Already disconnected");
        }
        Ok(())
    }

    /// Marks the state `Reconnecting`, drops any existing client, then runs [`Self::connect`].
    ///
    /// `Reconnecting` is transient: `disconnect` and `connect` overwrite it.
    ///
    /// # Errors
    ///
    /// Same as [`Self::connect`].
    #[allow(dead_code)]
    pub async fn reconnect(&mut self) -> Result<()> {
        info!("Attempting to reconnect to MongoDB");
        self.set_state(ConnectionState::Reconnecting).await;
        if self.client.is_some() {
            self.disconnect().await?;
        }
        self.connect().await
    }

    /// Reports the health of the current client.
    ///
    /// Once a client exists, `is_healthy` is always `true`, because no ping is sent (see
    /// `diagnostics`). `response_time_ms` is the round-trip time of the `buildInfo` lookup.
    /// `server_version` is `"unknown"` when that lookup fails.
    ///
    /// # Errors
    ///
    /// Fails with `ConnectionError::NotConnected` when no client exists.
    #[allow(dead_code)]
    pub async fn health_check(&self) -> Result<HealthStatus> {
        let client = self.get_client()?;
        // Time the only server command issued here, so the measurement is meaningful.
        let start = Instant::now();
        let server_version = self.get_server_version(client).await.ok();
        let response_time_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
        Ok(HealthStatus {
            is_healthy: true,
            response_time_ms,
            server_version,
            diagnostics: Some("Connected (ping skipped for secondary readPreference)".to_string()),
        })
    }

    /// Returns a handle to database `name` on the current client.
    ///
    /// # Errors
    ///
    /// Fails with `ConnectionError::NotConnected` when no client exists.
    pub fn get_database(&self, name: &str) -> Result<Database> {
        let client = self.get_client()?;
        Ok(client.database(name))
    }

    /// Borrows the current client.
    ///
    /// # Errors
    ///
    /// Fails with `ConnectionError::NotConnected` when no client exists.
    pub fn get_client(&self) -> Result<&Client> {
        self.client
            .as_ref()
            .ok_or_else(|| ConnectionError::NotConnected.into())
    }

    /// Returns a copy of the current lifecycle state.
    #[allow(dead_code)]
    pub async fn get_state(&self) -> ConnectionState {
        self.state.read().await.clone()
    }

    /// `true` only while the state is exactly `Connected`.
    #[allow(dead_code)]
    pub async fn is_connected(&self) -> bool {
        matches!(*self.state.read().await, ConnectionState::Connected)
    }

    /// Parses a connection string into driver options; failures map to `InvalidUri`.
    async fn parse_uri(uri: &str) -> Result<ClientOptions> {
        ClientOptions::parse(uri)
            .await
            .map_err(|e| ConnectionError::InvalidUri(e.to_string()).into())
    }

    /// Parses the stored URI, applies pool configuration and creates the client.
    async fn build_client(&self) -> Result<Client> {
        let options = Self::parse_uri(&self.uri).await?;
        let options = self.configure_pool(options);
        self.connect_with_retry(options).await
    }

    /// Applies pool sizing and timeouts from `self.config`, and defaults for unset options.
    ///
    /// Pool sizes and timeouts always come from the config. `appName`, `retryReads`,
    /// `retryWrites` and `directConnection` are filled in only when the URI did not set
    /// them, so explicit URI settings are honored.
    fn configure_pool(&self, mut options: ClientOptions) -> ClientOptions {
        options.max_pool_size = Some(self.config.max_pool_size);
        options.min_pool_size = Some(self.config.min_pool_size);
        options.connect_timeout = Some(Duration::from_secs(self.config.timeout));
        // Server selection never gives up in less than 30 s, even with a short connect timeout.
        let server_selection_timeout = std::cmp::max(self.config.timeout, 30);
        options.server_selection_timeout = Some(Duration::from_secs(server_selection_timeout));
        if options.app_name.is_none() {
            options.app_name = Some("mongosh-rs".to_string());
        }
        // Default to retryable reads/writes, but keep an explicit URI choice
        // (e.g. `retryWrites=false` for servers without retryable-write support).
        options.retry_reads = options.retry_reads.or(Some(true));
        options.retry_writes = options.retry_writes.or(Some(true));
        // A single seed defaults to a direct connection. An explicit `directConnection=false`
        // is kept so the driver can discover the rest of the topology.
        if options.hosts.len() == 1 && options.direct_connection.is_none() {
            options.direct_connection = Some(true);
            debug!("Enabled direct connection for single-host connection");
        }
        debug!(
            "Configured connection pool: max={}, min={}, readPreference={:?}, direct={:?}, server_selection_timeout={:?}s",
            self.config.max_pool_size,
            self.config.min_pool_size,
            options.selection_criteria,
            options.direct_connection,
            server_selection_timeout
        );
        options
    }

    /// Replaces the shared lifecycle state.
    async fn set_state(&self, new_state: ConnectionState) {
        *self.state.write().await = new_state;
    }

    /// Creates a [`Client`] from `options`, retrying with capped exponential back-off.
    ///
    /// At least one attempt is always made, even if `retry_attempts` is configured as 0.
    ///
    /// NOTE(review): the driver connects lazily in the background. `Client::with_options`
    /// therefore presumably fails only on invalid options, which retrying cannot fix.
    /// Confirm this against the driver version in use.
    async fn connect_with_retry(&self, options: ClientOptions) -> Result<Client> {
        let max_retries = self.config.retry_attempts.max(1);
        let mut attempt = 1;
        loop {
            debug!("Connection attempt {}/{}", attempt, max_retries);
            match Client::with_options(options.clone()) {
                Ok(client) => {
                    debug!("Client created successfully on attempt {}", attempt);
                    return Ok(client);
                }
                Err(e) if attempt >= max_retries => {
                    error!("All {} connection attempts failed", max_retries);
                    return Err(ConnectionError::ConnectionFailed(format!(
                        "Failed after {} attempts: {}",
                        max_retries, e
                    ))
                    .into());
                }
                Err(e) => {
                    let delay_ms = Self::backoff_delay_ms(attempt);
                    warn!(
                        "Connection attempt {} failed: {}. Retrying in {}ms",
                        attempt, e, delay_ms
                    );
                    tokio::time::sleep(Duration::from_millis(delay_ms)).await;
                    attempt += 1;
                }
            }
        }
    }

    /// Delay after failed attempt number `attempt` (1-based).
    ///
    /// The delay is 100 ms, doubled for each earlier attempt and capped at 5 s. Saturating
    /// arithmetic means large attempt counts hit the cap instead of overflowing.
    fn backoff_delay_ms(attempt: u32) -> u64 {
        const BASE_DELAY_MS: u64 = 100;
        const MAX_DELAY_MS: u64 = 5000;
        BASE_DELAY_MS
            .saturating_mul(2_u64.saturating_pow(attempt.saturating_sub(1)))
            .min(MAX_DELAY_MS)
    }

    /// Database used for administrative commands.
    ///
    /// This is the URI's default database if one was given, otherwise `admin`.
    fn command_database(client: &Client) -> Database {
        client
            .default_database()
            .unwrap_or_else(|| client.database("admin"))
    }

    /// Pings the server and logs the outcome; returns `Ok(true)` on success.
    #[allow(dead_code)]
    async fn verify_connection(&self, client: &Client) -> Result<bool> {
        debug!("Verifying connection with ping");
        match self.ping_internal(client).await {
            Ok(()) => {
                debug!("Connection verified successfully");
                Ok(true)
            }
            Err(e) => {
                warn!("Connection verification failed: {}", e);
                Err(e)
            }
        }
    }

    /// Runs `{ ping: 1 }` against the command database; failures map to `PingFailed`.
    #[allow(dead_code)]
    async fn ping_internal(&self, client: &Client) -> Result<()> {
        use mongodb::bson::doc;
        let db = Self::command_database(client);
        db.run_command(doc! { "ping": 1 })
            .await
            .map_err(|e| ConnectionError::PingFailed(e.to_string()))?;
        Ok(())
    }

    /// Returns the `version` field of `buildInfo`.
    ///
    /// Best-effort: yields `"unknown"` (never an error) if the command fails or the field
    /// is missing.
    #[allow(dead_code)]
    pub async fn get_server_version(&self, client: &Client) -> Result<String> {
        use mongodb::bson::doc;
        let db = Self::command_database(client);
        let version = match db.run_command(doc! { "buildInfo": 1 }).await {
            Ok(info) => info.get_str("version").ok().map(String::from),
            Err(_) => None,
        };
        Ok(version.unwrap_or_else(|| "unknown".to_string()))
    }

    /// Masks the credentials of a connection string so it can be logged.
    ///
    /// Everything between `://` and the *last* `@` is replaced with `***`. Using the last
    /// `@` means an unescaped `@` inside the password cannot leak the rest of it. A stray
    /// `@` later in the string (e.g. in a query value) causes over-masking, which is safe.
    /// URIs with an `@` but no usable scheme collapse to `mongodb://***`. URIs without any
    /// `@` are returned unchanged.
    fn sanitize_uri(&self, uri: &str) -> String {
        match (uri.find("://"), uri.rfind('@')) {
            (Some(scheme_end), Some(at)) if at > scheme_end => {
                format!("{}***{}", &uri[..scheme_end + 3], &uri[at..])
            }
            (_, Some(_)) => "mongodb://***".to_string(),
            (_, None) => uri.to_string(),
        }
    }
}
impl Default for PoolConfig {
    /// Ten connections at most, two idle at least, a 30 s connection timeout and a
    /// five-minute idle timeout.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(5 * 60);
        Self {
            idle_timeout: five_minutes,
            connection_timeout: Duration::from_secs(30),
            min_idle: 2,
            max_size: 10,
        }
    }
}
impl From<&ConnectionConfig> for PoolConfig {
fn from(config: &ConnectionConfig) -> Self {
Self {
max_size: config.max_pool_size,
min_idle: config.min_pool_size,
connection_timeout: Duration::from_secs(config.timeout),
idle_timeout: Duration::from_secs(config.idle_timeout),
}
}
}
/// Thin wrapper that starts client sessions and drives their transactions.
#[allow(dead_code)]
pub struct SessionManager {
    /// Driver client used to start sessions.
    client: Client,
}
impl SessionManager {
#[allow(dead_code)]
pub fn new(client: Client) -> Self {
Self { client }
}
#[allow(dead_code)]
pub async fn start_session(&self) -> Result<ClientSession> {
self.client
.start_session()
.await
.map_err(|e| ConnectionError::SessionFailed(e.to_string()).into())
}
#[allow(dead_code)]
pub async fn start_transaction(&self, session: &mut ClientSession) -> Result<()> {
session
.start_transaction()
.await
.map_err(|e| ConnectionError::TransactionFailed(e.to_string()).into())
}
#[allow(dead_code)]
pub async fn commit_transaction(&self, session: &mut ClientSession) -> Result<()> {
session
.commit_transaction()
.await
.map_err(|e| ConnectionError::TransactionFailed(e.to_string()).into())
}
#[allow(dead_code)]
pub async fn abort_transaction(&self, session: &mut ClientSession) -> Result<()> {
session
.abort_transaction()
.await
.map_err(|e| ConnectionError::TransactionFailed(e.to_string()).into())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_state() {
        let state = ConnectionState::Disconnected;
        assert_eq!(state, ConnectionState::Disconnected);
    }

    #[test]
    fn test_pool_config_default() {
        let config = PoolConfig::default();
        assert_eq!(config.max_size, 10);
        assert_eq!(config.min_idle, 2);
    }

    #[test]
    fn test_pool_config_from_connection_config() {
        let conn_config = ConnectionConfig::default();
        let pool_config = PoolConfig::from(&conn_config);
        assert_eq!(pool_config.max_size, conn_config.max_pool_size);
    }

    #[tokio::test]
    async fn test_connection_manager_creation() {
        let config = ConnectionConfig::default();
        let manager = ConnectionManager::new("mongodb://localhost:27017".to_string(), config);
        assert!(manager.client.is_none());
        assert!(!manager.is_connected().await);
    }

    #[tokio::test]
    async fn test_connection_state_transitions() {
        let config = ConnectionConfig::default();
        let manager = ConnectionManager::new("mongodb://localhost:27017".to_string(), config);
        assert_eq!(manager.get_state().await, ConnectionState::Disconnected);
        manager.set_state(ConnectionState::Connecting).await;
        assert_eq!(manager.get_state().await, ConnectionState::Connecting);
    }

    /// Operations that need a client must fail cleanly before `connect` has run.
    #[tokio::test]
    async fn test_operations_require_connection() {
        let config = ConnectionConfig::default();
        let mut manager = ConnectionManager::new("mongodb://localhost:27017".to_string(), config);
        assert!(manager.get_client().is_err());
        assert!(manager.get_database("admin").is_err());
        assert!(manager.health_check().await.is_err());
        // Disconnecting without a client is a no-op that still succeeds.
        assert!(manager.disconnect().await.is_ok());
        assert_eq!(manager.get_state().await, ConnectionState::Disconnected);
    }

    #[test]
    fn test_sanitize_uri() {
        let config = ConnectionConfig::default();
        let manager =
            ConnectionManager::new("mongodb://user:pass@localhost:27017".to_string(), config);
        let sanitized = manager.sanitize_uri("mongodb://user:pass@localhost:27017/db");
        assert!(sanitized.contains("***"));
        assert!(!sanitized.contains("pass"));
    }

    #[test]
    fn test_sanitize_uri_no_credentials() {
        let config = ConnectionConfig::default();
        let manager = ConnectionManager::new("mongodb://localhost:27017".to_string(), config);
        let sanitized = manager.sanitize_uri("mongodb://localhost:27017/db");
        assert_eq!(sanitized, "mongodb://localhost:27017/db");
    }

    /// Credentials without a recognizable scheme are masked entirely.
    #[test]
    fn test_sanitize_uri_without_scheme() {
        let config = ConnectionConfig::default();
        let manager = ConnectionManager::new("mongodb://localhost:27017".to_string(), config);
        assert_eq!(manager.sanitize_uri("user:pass@localhost"), "mongodb://***");
    }

    /// Tests that need a MongoDB server on `localhost:27017`.
    /// Run them with `cargo test -- --ignored`.
    mod integration {
        use super::*;

        #[tokio::test]
        #[ignore]
        async fn test_connect_to_mongodb() {
            let config = ConnectionConfig::default();
            let mut manager =
                ConnectionManager::new("mongodb://localhost:27017".to_string(), config);
            match manager.connect().await {
                Ok(()) => {
                    assert!(manager.is_connected().await);
                    assert!(manager.disconnect().await.is_ok());
                    assert!(!manager.is_connected().await);
                }
                // Tolerate a missing server: only the happy path is checked when one exists.
                Err(e) => eprintln!("skipping: could not connect to MongoDB: {}", e),
            }
        }

        #[tokio::test]
        #[ignore]
        async fn test_health_check() {
            let config = ConnectionConfig::default();
            let mut manager =
                ConnectionManager::new("mongodb://localhost:27017".to_string(), config);
            if manager.connect().await.is_err() {
                return;
            }
            let Ok(status) = manager.health_check().await else {
                panic!("health_check failed on a connected manager");
            };
            assert!(status.is_healthy);
            // response_time_ms may legitimately be 0 for a sub-millisecond localhost round
            // trip, so check only that a version string was produced.
            assert!(status.server_version.is_some());
        }
    }
}