polars_mongo/
lib.rs

1//! Polars mongo is a connector to read from a mongodb collection into a Polars dataframe.
2//! Usage:
3//! ```rust
4//! use polars::prelude::*;
5//! use polars_mongo::prelude::*;
6//!
7//! pub fn main() -> PolarsResult<()> {
8//!     let connection_str = std::env::var("POLARS_MONGO_CONNECTION_URI").unwrap();
9//!     let db = std::env::var("POLARS_MONGO_DB").unwrap();
10//!     let collection = std::env::var("POLARS_MONGO_COLLECTION").unwrap();
11//!
12//!     let df = LazyFrame::scan_mongo_collection(MongoScanOptions {
13//!         batch_size: None,
14//!         connection_str,
15//!         db,
16//!         collection,
17//!         infer_schema_length: Some(1000),
18//!         n_rows: None,
19//!     })?
20//!     .collect()?;
21//!
22//!     dbg!(df);
23//!     Ok(())
24//! }
//! ```
26#![deny(clippy::all)]
27#[cfg(feature = "serde")]
28use serde::{Deserialize, Serialize};
29mod buffer;
30mod conversion;
31pub mod prelude;
32
33use crate::buffer::*;
34
35use conversion::Wrap;
36use polars::export::rayon::prelude::*;
37use polars::{frame::row::*, prelude::*};
38use polars_core::POOL;
39
40use mongodb::{
41    bson::{Bson, Document},
42    options::{ClientOptions, FindOptions},
43    sync::{Client, Collection, Cursor},
44};
45use polars_core::utils::accumulate_dataframes_vertical;
46
47pub struct MongoScan {
48    client_options: ClientOptions,
49    db: String,
50    collection_name: String,
51    pub collection: Option<Collection<Document>>,
52    pub n_threads: Option<usize>,
53    pub batch_size: Option<usize>,
54    pub rechunk: bool,
55}
56
57impl MongoScan {
58    pub fn with_rechunk(mut self, rechunk: bool) -> Self {
59        self.rechunk = rechunk;
60        self
61    }
62    pub fn with_batch_size(mut self, batch_size: Option<usize>) -> Self {
63        self.batch_size = batch_size;
64        self
65    }
66
67    pub fn new(connection_str: String, db: String, collection: String) -> PolarsResult<Self> {
68        let client_options = ClientOptions::parse(connection_str).map_err(|e| {
69            PolarsError::InvalidOperation(format!("unable to connect to mongodb: {}", e).into())
70        })?;
71
72        Ok(MongoScan {
73            client_options,
74            db,
75            collection_name: collection,
76            collection: None,
77            n_threads: None,
78            rechunk: false,
79            batch_size: None,
80        })
81    }
82
83    fn get_collection(&self) -> Collection<Document> {
84        let client = Client::with_options(self.client_options.clone()).unwrap();
85
86        let database = client.database(&self.db);
87        database.collection::<Document>(&self.collection_name)
88    }
89
90    fn parse_lines<'a>(
91        &self,
92        mut cursor: Cursor<Document>,
93        buffers: &mut PlIndexMap<String, Buffer<'a>>,
94    ) -> mongodb::error::Result<()> {
95        while let Some(Ok(doc)) = cursor.next() {
96            buffers.iter_mut().for_each(|(s, inner)| match doc.get(s) {
97                Some(v) => inner.add(v).expect("was not able to add to buffer."),
98                None => inner.add_null(),
99            });
100        }
101        Ok(())
102    }
103}
104
105impl AnonymousScan for MongoScan {
106    fn scan(&self, scan_opts: AnonymousScanOptions) -> PolarsResult<DataFrame> {
107        let collection = &self.get_collection();
108
109        let projection = scan_opts.output_schema.clone().map(|schema| {
110            let prj = schema
111                .iter_names()
112                .map(|name| (name.clone(), Bson::Int64(1)));
113
114            Document::from_iter(prj)
115        });
116
117        let mut find_options = FindOptions::default();
118        find_options.projection = projection;
119        find_options.batch_size = self.batch_size.map(|b| b as u32);
120
121        let schema = scan_opts.output_schema.unwrap_or(scan_opts.schema);
122
123        // if no n_rows we need to get the count from mongo.
124        let n_rows = scan_opts
125            .n_rows
126            .unwrap_or_else(|| collection.estimated_document_count(None).unwrap() as usize);
127
128        let mut n_threads = self.n_threads.unwrap_or_else(|| POOL.current_num_threads());
129
130        if n_rows < 128 {
131            n_threads = 1
132        }
133
134        let rows_per_thread = n_rows / n_threads;
135
136        let dfs = POOL.install(|| {
137            (0..n_threads)
138                .into_par_iter()
139                .map(|idx| {
140                    let mut find_options = find_options.clone();
141
142                    let start = idx * rows_per_thread;
143
144                    find_options.skip = Some(start as u64);
145                    find_options.limit = Some(rows_per_thread as i64);
146                    let cursor = collection.find(None, Some(find_options));
147                    let mut buffers = init_buffers(schema.as_ref(), rows_per_thread)?;
148
149                    self.parse_lines(cursor.unwrap(), &mut buffers)
150                        .map_err(|err| PolarsError::ComputeError(format!("{:#?}", err).into()))?;
151
152                    DataFrame::new(
153                        buffers
154                            .into_values()
155                            .map(|buf| buf.into_series())
156                            .collect::<PolarsResult<_>>()?,
157                    )
158                })
159                .collect::<PolarsResult<Vec<_>>>()
160        })?;
161        let mut df = accumulate_dataframes_vertical(dfs)?;
162
163        if self.rechunk {
164            df.rechunk();
165        }
166        Ok(df)
167    }
168
169    fn schema(&self, infer_schema_length: Option<usize>) -> PolarsResult<Schema> {
170        let collection = self.get_collection();
171
172        let infer_options = FindOptions::builder()
173            .limit(infer_schema_length.map(|i| i as i64))
174            .build();
175
176        let res = collection
177            .find(None, Some(infer_options))
178            .map_err(|err| PolarsError::ComputeError(format!("{:#?}", err).into()))?;
179        let iter = res.map(|doc| {
180            let val = doc.unwrap();
181            val.into_iter()
182                .map(|(key, value)| {
183                    let dtype = Wrap::<DataType>::from(&value);
184                    (key, dtype.0)
185                })
186                .collect()
187        });
188        let schema = infer_schema(iter, infer_schema_length.unwrap_or(100));
189        Ok(schema)
190    }
191
192    fn allows_predicate_pushdown(&self) -> bool {
193        true
194    }
195    fn allows_projection_pushdown(&self) -> bool {
196        true
197    }
198    fn allows_slice_pushdown(&self) -> bool {
199        true
200    }
201}
202
/// Options accepted by `scan_mongo_collection`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct MongoScanOptions {
    /// mongodb style connection string. `mongodb://<user>:<password>@host.domain`
    pub connection_str: String,
    /// the name of the mongodb database
    pub db: String,
    /// the name of the mongodb collection
    pub collection: String,
    /// Number of rows used to infer the schema. Defaults to `100` if not provided.
    pub infer_schema_length: Option<usize>,
    /// Number of rows to return from mongodb collection. If not provided, it will fetch all rows from collection.
    pub n_rows: Option<usize>,
    /// determines the number of records to return from a single request to mongodb
    pub batch_size: Option<usize>,
}
219
220pub trait MongoLazyReader {
221    fn scan_mongo_collection(options: MongoScanOptions) -> PolarsResult<LazyFrame> {
222        let f = MongoScan::new(options.connection_str, options.db, options.collection)?;
223
224        let args = ScanArgsAnonymous {
225            name: "MONGO SCAN",
226            infer_schema_length: options.infer_schema_length,
227            n_rows: options.n_rows,
228            ..ScanArgsAnonymous::default()
229        };
230
231        LazyFrame::anonymous_scan(Arc::new(f), args)
232    }
233}
234
235impl MongoLazyReader for LazyFrame {}