polars_python/
batched_csv.rs

1use std::path::PathBuf;
2use std::sync::Mutex;
3
4use polars::io::csv::read::OwnedBatchedCsvReader;
5use polars::io::mmap::MmapBytesReader;
6use polars::io::RowIndex;
7use polars::prelude::*;
8use polars_utils::open_file;
9use pyo3::prelude::*;
10use pyo3::pybacked::PyBackedStr;
11
12use crate::error::PyPolarsErr;
13use crate::{PyDataFrame, Wrap};
14
15#[pyclass]
16#[repr(transparent)]
17pub struct PyBatchedCsv {
18    reader: Mutex<OwnedBatchedCsvReader>,
19}
20
21#[pymethods]
22#[allow(clippy::wrong_self_convention, clippy::should_implement_trait)]
23impl PyBatchedCsv {
24    #[staticmethod]
25    #[pyo3(signature = (
26        infer_schema_length, chunk_size, has_header, ignore_errors, n_rows, skip_rows, skip_lines,
27        projection, separator, rechunk, columns, encoding, n_threads, path, schema_overrides,
28        overwrite_dtype_slice, low_memory, comment_prefix, quote_char, null_values,
29        missing_utf8_is_empty_string, try_parse_dates, skip_rows_after_header, row_index,
30        eol_char, raise_if_empty, truncate_ragged_lines, decimal_comma)
31    )]
32    fn new(
33        infer_schema_length: Option<usize>,
34        chunk_size: usize,
35        has_header: bool,
36        ignore_errors: bool,
37        n_rows: Option<usize>,
38        skip_rows: usize,
39        skip_lines: usize,
40        projection: Option<Vec<usize>>,
41        separator: &str,
42        rechunk: bool,
43        columns: Option<Vec<String>>,
44        encoding: Wrap<CsvEncoding>,
45        n_threads: Option<usize>,
46        path: PathBuf,
47        schema_overrides: Option<Vec<(PyBackedStr, Wrap<DataType>)>>,
48        overwrite_dtype_slice: Option<Vec<Wrap<DataType>>>,
49        low_memory: bool,
50        comment_prefix: Option<&str>,
51        quote_char: Option<&str>,
52        null_values: Option<Wrap<NullValues>>,
53        missing_utf8_is_empty_string: bool,
54        try_parse_dates: bool,
55        skip_rows_after_header: usize,
56        row_index: Option<(String, IdxSize)>,
57        eol_char: &str,
58        raise_if_empty: bool,
59        truncate_ragged_lines: bool,
60        decimal_comma: bool,
61    ) -> PyResult<PyBatchedCsv> {
62        let null_values = null_values.map(|w| w.0);
63        let eol_char = eol_char.as_bytes()[0];
64        let row_index = row_index.map(|(name, offset)| RowIndex {
65            name: name.into(),
66            offset,
67        });
68        let quote_char = if let Some(s) = quote_char {
69            if s.is_empty() {
70                None
71            } else {
72                Some(s.as_bytes()[0])
73            }
74        } else {
75            None
76        };
77
78        let schema_overrides = schema_overrides.map(|overwrite_dtype| {
79            overwrite_dtype
80                .iter()
81                .map(|(name, dtype)| {
82                    let dtype = dtype.0.clone();
83                    Field::new((&**name).into(), dtype)
84                })
85                .collect::<Schema>()
86        });
87
88        let overwrite_dtype_slice = overwrite_dtype_slice.map(|overwrite_dtype| {
89            overwrite_dtype
90                .iter()
91                .map(|dt| dt.0.clone())
92                .collect::<Vec<_>>()
93        });
94
95        let file = open_file(&path).map_err(PyPolarsErr::from)?;
96        let reader = Box::new(file) as Box<dyn MmapBytesReader>;
97        let reader = CsvReadOptions::default()
98            .with_infer_schema_length(infer_schema_length)
99            .with_has_header(has_header)
100            .with_n_rows(n_rows)
101            .with_skip_rows(skip_rows)
102            .with_skip_rows(skip_lines)
103            .with_ignore_errors(ignore_errors)
104            .with_projection(projection.map(Arc::new))
105            .with_rechunk(rechunk)
106            .with_chunk_size(chunk_size)
107            .with_columns(columns.map(|x| x.into_iter().map(PlSmallStr::from_string).collect()))
108            .with_n_threads(n_threads)
109            .with_dtype_overwrite(overwrite_dtype_slice.map(Arc::new))
110            .with_low_memory(low_memory)
111            .with_schema_overwrite(schema_overrides.map(Arc::new))
112            .with_skip_rows_after_header(skip_rows_after_header)
113            .with_row_index(row_index)
114            .with_raise_if_empty(raise_if_empty)
115            .with_parse_options(
116                CsvParseOptions::default()
117                    .with_separator(separator.as_bytes()[0])
118                    .with_encoding(encoding.0)
119                    .with_missing_is_null(!missing_utf8_is_empty_string)
120                    .with_comment_prefix(comment_prefix)
121                    .with_null_values(null_values)
122                    .with_try_parse_dates(try_parse_dates)
123                    .with_quote_char(quote_char)
124                    .with_eol_char(eol_char)
125                    .with_truncate_ragged_lines(truncate_ragged_lines)
126                    .with_decimal_comma(decimal_comma),
127            )
128            .into_reader_with_file_handle(reader);
129
130        let reader = reader.batched(None).map_err(PyPolarsErr::from)?;
131
132        Ok(PyBatchedCsv {
133            reader: Mutex::new(reader),
134        })
135    }
136
137    fn next_batches(&self, py: Python, n: usize) -> PyResult<Option<Vec<PyDataFrame>>> {
138        let reader = &self.reader;
139        let batches = py.allow_threads(move || {
140            reader
141                .lock()
142                .map_err(|e| PyPolarsErr::Other(e.to_string()))?
143                .next_batches(n)
144                .map_err(PyPolarsErr::from)
145        })?;
146
147        // SAFETY: same memory layout
148        let batches = unsafe {
149            std::mem::transmute::<Option<Vec<DataFrame>>, Option<Vec<PyDataFrame>>>(batches)
150        };
151        Ok(batches)
152    }
153}