
polars_io/csv/read/schema_inference.rs

use polars_buffer::Buffer;
use polars_core::prelude::*;
#[cfg(feature = "polars-time")]
use polars_time::chunkedarray::string::infer as date_infer;
#[cfg(feature = "polars-time")]
use polars_time::prelude::string::Pattern;
use polars_utils::format_pl_smallstr;

use super::splitfields::SplitFields;
use super::{CsvParseOptions, NullValues};
use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE};

/// Low-level CSV schema inference function.
///
/// Use `read_until_start_and_infer_schema` instead.
#[allow(clippy::too_many_arguments)]
pub(super) fn infer_file_schema_impl(
    header_line: &Option<Buffer<u8>>,
    content_lines: &[Buffer<u8>],
    infer_all_as_str: bool,
    parse_options: &CsvParseOptions,
    schema_overwrite: Option<&Schema>,
) -> Schema {
    let mut headers = header_line
        .as_ref()
        .map(|line| infer_headers(line, parse_options))
        .unwrap_or_else(|| Vec::with_capacity(8));

    let extend_header_with_unknown_column = header_line.is_none();

    let mut column_types = vec![PlHashSet::<DataType>::with_capacity(4); headers.len()];
    let mut nulls = vec![false; headers.len()];

    for content_line in content_lines {
        infer_types_from_line(
            content_line,
            infer_all_as_str,
            &mut headers,
            extend_header_with_unknown_column,
            parse_options,
            &mut column_types,
            &mut nulls,
        );
    }

    build_schema(&headers, &column_types, schema_overwrite)
}

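// Illustrative sketch (editorial addition, not part of the upstream crate): roughly how
// `infer_file_schema_impl` above would be called from within this module. How the header
// and content lines are materialized as `Buffer<u8>` values and whether `CsvParseOptions`
// implements `Default` are assumptions here.
//
//     let header = Some(Buffer::from(b"id,price,ts".to_vec()));
//     let lines = [Buffer::from(b"1,2.5,2024-01-01".to_vec())];
//     let schema = infer_file_schema_impl(&header, &lines, false, &CsvParseOptions::default(), None);
//     // Expected with default options: id -> Int64, price -> Float64,
//     // ts -> String (or Date when try_parse_dates is enabled).
//
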
fn infer_headers(mut header_line: &[u8], parse_options: &CsvParseOptions) -> Vec<PlSmallStr> {
    let len = header_line.len();

    if header_line.last().copied() == Some(b'\r') {
        header_line = &header_line[..len - 1];
    }

    let byterecord = SplitFields::new(
        header_line,
        parse_options.separator,
        parse_options.quote_char,
        parse_options.eol_char,
    );

    let headers = byterecord
        .map(|(slice, needs_escaping)| {
            let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                &slice[1..(slice.len() - 1)]
            } else {
                slice
            };
            String::from_utf8_lossy(slice_escaped)
        })
        .collect::<Vec<_>>();

    let mut deduplicated_headers = Vec::with_capacity(headers.len());
    let mut header_names = PlHashMap::with_capacity(headers.len());

    for name in &headers {
        let count = header_names.entry(name.as_ref()).or_insert(0usize);
        if *count != 0 {
            deduplicated_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1))
        } else {
            deduplicated_headers.push(PlSmallStr::from_str(name))
        }
        *count += 1;
    }

    deduplicated_headers
}

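// Illustrative sketch (editorial addition): how `infer_headers` above deduplicates
// repeated header names, assuming a header line of `a,b,a,a` and default parse options:
//
//     a, b, a, a  ->  ["a", "b", "a_duplicated_0", "a_duplicated_1"]
//
// The first repeat gets the suffix `_duplicated_0` because the suffix uses `count - 1`.
//
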
fn infer_types_from_line(
    mut line: &[u8],
    infer_all_as_str: bool,
    headers: &mut Vec<PlSmallStr>,
    extend_header_with_unknown_column: bool,
    parse_options: &CsvParseOptions,
    column_types: &mut Vec<PlHashSet<DataType>>,
    nulls: &mut Vec<bool>,
) {
    let line_len = line.len();
    if line.last().copied() == Some(b'\r') {
        line = &line[..line_len - 1];
    }

    let record = SplitFields::new(
        line,
        parse_options.separator,
        parse_options.quote_char,
        parse_options.eol_char,
    );

    for (i, (slice, needs_escaping)) in record.enumerate() {
        if i >= headers.len() {
            if extend_header_with_unknown_column {
                headers.push(column_name(i));
                column_types.push(Default::default());
                nulls.push(false);
            } else {
                break;
            }
        }

        if infer_all_as_str {
            column_types[i].insert(DataType::String);
            continue;
        }

        if slice.is_empty() {
            nulls[i] = true;
        } else {
            let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                &slice[1..(slice.len() - 1)]
            } else {
                slice
            };
            let s = String::from_utf8_lossy(slice_escaped);
            let dtype = match &parse_options.null_values {
                None => Some(infer_field_schema(
                    &s,
                    parse_options.try_parse_dates,
                    parse_options.decimal_comma,
                )),
                Some(NullValues::AllColumns(names)) => {
                    if !names.iter().any(|nv| nv == s.as_ref()) {
                        Some(infer_field_schema(
                            &s,
                            parse_options.try_parse_dates,
                            parse_options.decimal_comma,
                        ))
                    } else {
                        None
                    }
                },
                Some(NullValues::AllColumnsSingle(name)) => {
                    if s.as_ref() != name.as_str() {
                        Some(infer_field_schema(
                            &s,
                            parse_options.try_parse_dates,
                            parse_options.decimal_comma,
                        ))
                    } else {
                        None
                    }
                },
                Some(NullValues::Named(names)) => {
                    let current_name = &headers[i];
                    let null_name = &names.iter().find(|name| name.0 == current_name);

                    if let Some(null_name) = null_name {
                        if null_name.1.as_str() != s.as_ref() {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        } else {
                            None
                        }
                    } else {
                        Some(infer_field_schema(
                            &s,
                            parse_options.try_parse_dates,
                            parse_options.decimal_comma,
                        ))
                    }
                },
            };
            if let Some(dtype) = dtype {
                column_types[i].insert(dtype);
            }
        }
    }
}

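// Illustrative sketch (editorial addition): what a single call to `infer_types_from_line`
// above contributes for the line `1,,abc`, assuming `headers` already holds three names
// and default parse options (no null values, no date parsing):
//
//     column 0: "1"   -> Int64 added to column_types[0]
//     column 1: ""    -> nulls[1] = true, no dtype recorded
//     column 2: "abc" -> String added to column_types[2]
//
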
fn build_schema(
    headers: &[PlSmallStr],
    column_types: &[PlHashSet<DataType>],
    schema_overwrite: Option<&Schema>,
) -> Schema {
    assert!(headers.len() == column_types.len());

    let get_schema_overwrite = |field_name| {
        if let Some(schema_overwrite) = schema_overwrite {
            // Apply schema_overwrite by column name only. Positional overrides are handled
            // separately via dtype_overwrite.
            if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) {
                return Some((name.clone(), dtype.clone()));
            }
        }

        None
    };

    Schema::from_iter(
        headers
            .iter()
            .zip(column_types)
            .map(|(field_name, type_possibilities)| {
                let (name, dtype) = get_schema_overwrite(field_name).unwrap_or_else(|| {
                    (
                        field_name.clone(),
                        finish_infer_field_schema(type_possibilities),
                    )
                });

                Field::new(name, dtype)
            }),
    )
}

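// Illustrative sketch (editorial addition): `build_schema` above matches `schema_overwrite`
// entries by column name, not by position. Given inferred headers ["a", "b"] with candidate
// types {Int64} and {Float64}, and an overwrite schema containing only `b: String`:
//
//     a -> Int64 (inferred), b -> String (overwritten)
//
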
pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
    // Determine the data type based on the possible types.
    // If there are incompatible types, fall back to DataType::String.
    match possibilities.len() {
        1 => possibilities.iter().next().unwrap().clone(),
        2 if possibilities.contains(&DataType::Int64)
            && possibilities.contains(&DataType::Float64) =>
        {
            // We have both an integer and a double; widen to double.
            DataType::Float64
        },
        // Default to String for conflicting data types (e.g. bool and int).
        _ => DataType::String,
    }
}

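// Illustrative sketch (editorial addition): expected resolutions from
// `finish_infer_field_schema` above for a few candidate sets:
//
//     {Int64}               -> Int64
//     {Int64, Float64}      -> Float64
//     {Boolean, Int64}      -> String
//     {} (only null values) -> String
//
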
/// Infer the data type of a record.
pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
    // When quoting is enabled in the reader, these quotes aren't escaped;
    // we default to String for them.
    let bytes = string.as_bytes();
    if bytes.len() >= 2 && *bytes.first().unwrap() == b'"' && *bytes.last().unwrap() == b'"' {
        if try_parse_dates {
            #[cfg(feature = "polars-time")]
            {
                match date_infer::infer_pattern_single(&string[1..string.len() - 1]) {
                    Some(pattern_with_offset) => match pattern_with_offset {
                        Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
                            DataType::Datetime(TimeUnit::Microseconds, None)
                        },
                        Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
                        Pattern::DatetimeYMDZ => {
                            DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
                        },
                        Pattern::Time => DataType::Time,
                    },
                    None => DataType::String,
                }
            }
            #[cfg(not(feature = "polars-time"))]
            {
                panic!("activate one of {{'dtype-date', 'dtype-datetime', 'dtype-time'}} features")
            }
        } else {
            DataType::String
        }
    }
    // Match the regexes in a particular order.
    else if BOOLEAN_RE.is_match(string) {
        DataType::Boolean
    } else if !decimal_comma && FLOAT_RE.is_match(string)
        || decimal_comma && FLOAT_RE_DECIMAL.is_match(string)
    {
        DataType::Float64
    } else if INTEGER_RE.is_match(string) {
        DataType::Int64
    } else if try_parse_dates {
        #[cfg(feature = "polars-time")]
        {
            match date_infer::infer_pattern_single(string) {
                Some(pattern_with_offset) => match pattern_with_offset {
                    Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
                        DataType::Datetime(TimeUnit::Microseconds, None)
                    },
                    Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
                    Pattern::DatetimeYMDZ => {
                        DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
                    },
                    Pattern::Time => DataType::Time,
                },
                None => DataType::String,
            }
        }
        #[cfg(not(feature = "polars-time"))]
        {
            panic!("activate one of {{'dtype-date', 'dtype-datetime', 'dtype-time'}} features")
        }
    } else {
        DataType::String
    }
}

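// Illustrative sketch (editorial addition): expected results of `infer_field_schema`
// above for a few inputs, assuming the regexes in `crate::utils` behave as their names
// suggest (with `try_parse_dates = false`, `decimal_comma = false`):
//
//     "true"       -> Boolean
//     "3"          -> Int64
//     "3.14"       -> Float64
//     "2024-01-01" -> String (Date only when try_parse_dates is enabled)
//     "\"quoted\"" -> String (quoted fields skip the numeric regexes)
//
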
fn column_name(i: usize) -> PlSmallStr {
    format_pl_smallstr!("column_{}", i + 1)
}
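
// Illustrative note (editorial addition): default column names are 1-based,
// e.g. `column_name(0)` yields "column_1".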