1#![deny(clippy::all)]
27#[cfg(feature = "serde")]
28use serde::{Deserialize, Serialize};
29mod buffer;
30mod conversion;
31pub mod prelude;
32
33use crate::buffer::*;
34
35use conversion::Wrap;
36use polars::export::rayon::prelude::*;
37use polars::{frame::row::*, prelude::*};
38use polars_core::POOL;
39
40use mongodb::{
41 bson::{Bson, Document},
42 options::{ClientOptions, FindOptions},
43 sync::{Client, Collection, Cursor},
44};
45use polars_core::utils::accumulate_dataframes_vertical;
46
47pub struct MongoScan {
48 client_options: ClientOptions,
49 db: String,
50 collection_name: String,
51 pub collection: Option<Collection<Document>>,
52 pub n_threads: Option<usize>,
53 pub batch_size: Option<usize>,
54 pub rechunk: bool,
55}
56
57impl MongoScan {
58 pub fn with_rechunk(mut self, rechunk: bool) -> Self {
59 self.rechunk = rechunk;
60 self
61 }
62 pub fn with_batch_size(mut self, batch_size: Option<usize>) -> Self {
63 self.batch_size = batch_size;
64 self
65 }
66
67 pub fn new(connection_str: String, db: String, collection: String) -> PolarsResult<Self> {
68 let client_options = ClientOptions::parse(connection_str).map_err(|e| {
69 PolarsError::InvalidOperation(format!("unable to connect to mongodb: {}", e).into())
70 })?;
71
72 Ok(MongoScan {
73 client_options,
74 db,
75 collection_name: collection,
76 collection: None,
77 n_threads: None,
78 rechunk: false,
79 batch_size: None,
80 })
81 }
82
83 fn get_collection(&self) -> Collection<Document> {
84 let client = Client::with_options(self.client_options.clone()).unwrap();
85
86 let database = client.database(&self.db);
87 database.collection::<Document>(&self.collection_name)
88 }
89
90 fn parse_lines<'a>(
91 &self,
92 mut cursor: Cursor<Document>,
93 buffers: &mut PlIndexMap<String, Buffer<'a>>,
94 ) -> mongodb::error::Result<()> {
95 while let Some(Ok(doc)) = cursor.next() {
96 buffers.iter_mut().for_each(|(s, inner)| match doc.get(s) {
97 Some(v) => inner.add(v).expect("was not able to add to buffer."),
98 None => inner.add_null(),
99 });
100 }
101 Ok(())
102 }
103}
104
105impl AnonymousScan for MongoScan {
106 fn scan(&self, scan_opts: AnonymousScanOptions) -> PolarsResult<DataFrame> {
107 let collection = &self.get_collection();
108
109 let projection = scan_opts.output_schema.clone().map(|schema| {
110 let prj = schema
111 .iter_names()
112 .map(|name| (name.clone(), Bson::Int64(1)));
113
114 Document::from_iter(prj)
115 });
116
117 let mut find_options = FindOptions::default();
118 find_options.projection = projection;
119 find_options.batch_size = self.batch_size.map(|b| b as u32);
120
121 let schema = scan_opts.output_schema.unwrap_or(scan_opts.schema);
122
123 let n_rows = scan_opts
125 .n_rows
126 .unwrap_or_else(|| collection.estimated_document_count(None).unwrap() as usize);
127
128 let mut n_threads = self.n_threads.unwrap_or_else(|| POOL.current_num_threads());
129
130 if n_rows < 128 {
131 n_threads = 1
132 }
133
134 let rows_per_thread = n_rows / n_threads;
135
136 let dfs = POOL.install(|| {
137 (0..n_threads)
138 .into_par_iter()
139 .map(|idx| {
140 let mut find_options = find_options.clone();
141
142 let start = idx * rows_per_thread;
143
144 find_options.skip = Some(start as u64);
145 find_options.limit = Some(rows_per_thread as i64);
146 let cursor = collection.find(None, Some(find_options));
147 let mut buffers = init_buffers(schema.as_ref(), rows_per_thread)?;
148
149 self.parse_lines(cursor.unwrap(), &mut buffers)
150 .map_err(|err| PolarsError::ComputeError(format!("{:#?}", err).into()))?;
151
152 DataFrame::new(
153 buffers
154 .into_values()
155 .map(|buf| buf.into_series())
156 .collect::<PolarsResult<_>>()?,
157 )
158 })
159 .collect::<PolarsResult<Vec<_>>>()
160 })?;
161 let mut df = accumulate_dataframes_vertical(dfs)?;
162
163 if self.rechunk {
164 df.rechunk();
165 }
166 Ok(df)
167 }
168
169 fn schema(&self, infer_schema_length: Option<usize>) -> PolarsResult<Schema> {
170 let collection = self.get_collection();
171
172 let infer_options = FindOptions::builder()
173 .limit(infer_schema_length.map(|i| i as i64))
174 .build();
175
176 let res = collection
177 .find(None, Some(infer_options))
178 .map_err(|err| PolarsError::ComputeError(format!("{:#?}", err).into()))?;
179 let iter = res.map(|doc| {
180 let val = doc.unwrap();
181 val.into_iter()
182 .map(|(key, value)| {
183 let dtype = Wrap::<DataType>::from(&value);
184 (key, dtype.0)
185 })
186 .collect()
187 });
188 let schema = infer_schema(iter, infer_schema_length.unwrap_or(100));
189 Ok(schema)
190 }
191
192 fn allows_predicate_pushdown(&self) -> bool {
193 true
194 }
195 fn allows_projection_pushdown(&self) -> bool {
196 true
197 }
198 fn allows_slice_pushdown(&self) -> bool {
199 true
200 }
201}
202
/// Options accepted by `MongoLazyReader::scan_mongo_collection`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct MongoScanOptions {
    /// MongoDB connection string.
    pub connection_str: String,
    /// Name of the database to read from.
    pub db: String,
    /// Name of the collection to read from.
    pub collection: String,
    /// Value passed as `infer_schema_length` to the anonymous scan arguments.
    pub infer_schema_length: Option<usize>,
    /// Value passed as `n_rows` to the anonymous scan arguments.
    pub n_rows: Option<usize>,
    /// Optional batch size for the MongoDB query.
    // NOTE(review): confirm that the reader actually forwards this to the scan.
    pub batch_size: Option<usize>,
}
219
220pub trait MongoLazyReader {
221 fn scan_mongo_collection(options: MongoScanOptions) -> PolarsResult<LazyFrame> {
222 let f = MongoScan::new(options.connection_str, options.db, options.collection)?;
223
224 let args = ScanArgsAnonymous {
225 name: "MONGO SCAN",
226 infer_schema_length: options.infer_schema_length,
227 n_rows: options.n_rows,
228 ..ScanArgsAnonymous::default()
229 };
230
231 LazyFrame::anonymous_scan(Arc::new(f), args)
232 }
233}
234
235impl MongoLazyReader for LazyFrame {}