Skip to main content

prax_mongodb/
client.rs

1//! MongoDB client wrapper with built-in connection pooling.
2
3use std::sync::Arc;
4
5use bson::{Document, doc};
6use mongodb::{Client, Collection, Database};
7use tracing::{debug, info};
8
9use crate::config::MongoConfig;
10use crate::error::{MongoError, MongoResult};
11
12/// A MongoDB client with connection pooling.
13///
14/// The MongoDB driver handles connection pooling internally,
15/// so this client wraps the driver's Client with additional
16/// Prax-specific functionality.
17#[derive(Clone)]
18pub struct MongoClient {
19    client: Client,
20    database: Database,
21    config: Arc<MongoConfig>,
22}
23
24impl MongoClient {
25    /// Create a new client from configuration.
26    pub async fn new(config: MongoConfig) -> MongoResult<Self> {
27        let options = config.to_client_options().await?;
28
29        let client = Client::with_options(options)
30            .map_err(|e| MongoError::connection(format!("failed to create client: {}", e)))?;
31
32        let database = client.database(&config.database);
33
34        info!(
35            uri = %config.uri,
36            database = %config.database,
37            "MongoDB client created"
38        );
39
40        Ok(Self {
41            client,
42            database,
43            config: Arc::new(config),
44        })
45    }
46
47    /// Create a builder for the client.
48    pub fn builder() -> MongoClientBuilder {
49        MongoClientBuilder::new()
50    }
51
52    /// Get a typed collection.
53    pub fn collection<T>(&self, name: &str) -> Collection<T>
54    where
55        T: Send + Sync,
56    {
57        self.database.collection(name)
58    }
59
60    /// Get a collection with BSON documents.
61    pub fn collection_doc(&self, name: &str) -> Collection<Document> {
62        self.database.collection(name)
63    }
64
65    /// Get the underlying database.
66    pub fn database(&self) -> &Database {
67        &self.database
68    }
69
70    /// Get a different database from the same client.
71    pub fn get_database(&self, name: &str) -> Database {
72        self.client.database(name)
73    }
74
75    /// Get the underlying MongoDB client.
76    pub fn inner(&self) -> &Client {
77        &self.client
78    }
79
80    /// Get the configuration.
81    pub fn config(&self) -> &MongoConfig {
82        &self.config
83    }
84
85    /// Check if the client is healthy by pinging the server.
86    pub async fn is_healthy(&self) -> bool {
87        self.database
88            .run_command(doc! { "ping": 1 }, None)
89            .await
90            .is_ok()
91    }
92
93    /// List all collection names in the database.
94    pub async fn list_collections(&self) -> MongoResult<Vec<String>> {
95        let names = self
96            .database
97            .list_collection_names(None)
98            .await
99            .map_err(MongoError::from)?;
100        Ok(names)
101    }
102
103    /// Drop a collection.
104    pub async fn drop_collection(&self, name: &str) -> MongoResult<()> {
105        debug!(collection = %name, "Dropping collection");
106        self.database
107            .collection::<Document>(name)
108            .drop(None)
109            .await
110            .map_err(MongoError::from)?;
111        Ok(())
112    }
113
114    /// Create an index on a collection.
115    pub async fn create_index(
116        &self,
117        collection: &str,
118        keys: Document,
119        unique: bool,
120    ) -> MongoResult<String> {
121        use mongodb::IndexModel;
122        use mongodb::options::IndexOptions;
123
124        let options = IndexOptions::builder().unique(unique).build();
125        let model = IndexModel::builder().keys(keys).options(options).build();
126
127        let result = self
128            .database
129            .collection::<Document>(collection)
130            .create_index(model, None)
131            .await
132            .map_err(MongoError::from)?;
133
134        Ok(result.index_name)
135    }
136
137    /// Run a database command.
138    pub async fn run_command(&self, command: Document) -> MongoResult<Document> {
139        let result = self
140            .database
141            .run_command(command, None)
142            .await
143            .map_err(MongoError::from)?;
144        Ok(result)
145    }
146
147    /// Start a client session for transactions.
148    pub async fn start_session(&self) -> MongoResult<mongodb::ClientSession> {
149        let session = self
150            .client
151            .start_session(None)
152            .await
153            .map_err(MongoError::from)?;
154        Ok(session)
155    }
156}
157
/// Builder for [`MongoClient`].
///
/// Each option is `None` until its setter is called; unset options are not
/// forwarded to the underlying configuration builder.
#[derive(Default)]
pub struct MongoClientBuilder {
    uri: Option<String>,
    database: Option<String>,
    app_name: Option<String>,
    max_pool_size: Option<u32>,
    min_pool_size: Option<u32>,
    connect_timeout: Option<std::time::Duration>,
    direct_connection: Option<bool>,
}

impl std::fmt::Debug for MongoClientBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MongoClientBuilder")
            // The URI may embed credentials (`user:password@host`), so only
            // report whether one has been set, never its contents.
            .field("uri", &self.uri.as_ref().map(|_| "<redacted>"))
            .field("database", &self.database)
            .field("app_name", &self.app_name)
            .field("max_pool_size", &self.max_pool_size)
            .field("min_pool_size", &self.min_pool_size)
            .field("connect_timeout", &self.connect_timeout)
            .field("direct_connection", &self.direct_connection)
            .finish()
    }
}
169
170impl MongoClientBuilder {
171    /// Create a new builder.
172    pub fn new() -> Self {
173        Self::default()
174    }
175
176    /// Set the MongoDB URI.
177    pub fn uri(mut self, uri: impl Into<String>) -> Self {
178        self.uri = Some(uri.into());
179        self
180    }
181
182    /// Set the database name.
183    pub fn database(mut self, database: impl Into<String>) -> Self {
184        self.database = Some(database.into());
185        self
186    }
187
188    /// Set the application name.
189    pub fn app_name(mut self, name: impl Into<String>) -> Self {
190        self.app_name = Some(name.into());
191        self
192    }
193
194    /// Set the maximum pool size.
195    pub fn max_pool_size(mut self, size: u32) -> Self {
196        self.max_pool_size = Some(size);
197        self
198    }
199
200    /// Set the minimum pool size.
201    pub fn min_pool_size(mut self, size: u32) -> Self {
202        self.min_pool_size = Some(size);
203        self
204    }
205
206    /// Set the connection timeout.
207    pub fn connect_timeout(mut self, duration: std::time::Duration) -> Self {
208        self.connect_timeout = Some(duration);
209        self
210    }
211
212    /// Enable direct connection (bypass replica set discovery).
213    pub fn direct_connection(mut self, enabled: bool) -> Self {
214        self.direct_connection = Some(enabled);
215        self
216    }
217
218    /// Build the client.
219    pub async fn build(self) -> MongoResult<MongoClient> {
220        let mut config_builder = MongoConfig::builder();
221
222        if let Some(uri) = self.uri {
223            config_builder = config_builder.uri(uri);
224        }
225
226        if let Some(database) = self.database {
227            config_builder = config_builder.database(database);
228        }
229
230        if let Some(app_name) = self.app_name {
231            config_builder = config_builder.app_name(app_name);
232        }
233
234        if let Some(max_pool) = self.max_pool_size {
235            config_builder = config_builder.max_pool_size(max_pool);
236        }
237
238        if let Some(min_pool) = self.min_pool_size {
239            config_builder = config_builder.min_pool_size(min_pool);
240        }
241
242        if let Some(timeout) = self.connect_timeout {
243            config_builder = config_builder.connect_timeout(timeout);
244        }
245
246        if let Some(direct) = self.direct_connection {
247            config_builder = config_builder.direct_connection(direct);
248        }
249
250        let config = config_builder.build()?;
251        MongoClient::new(config).await
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Setters on the builder store exactly the values they were given.
    #[test]
    fn test_client_builder() {
        let expected_uri = "mongodb://localhost:27017";

        let builder = MongoClientBuilder::new()
            .uri(expected_uri)
            .database("test")
            .max_pool_size(20);

        assert_eq!(builder.uri.as_deref(), Some(expected_uri));
        assert_eq!(builder.database.as_deref(), Some("test"));
        assert_eq!(builder.max_pool_size, Some(20));
    }
}