axion_data/io/
csv.rs

1use rayon::iter::{IntoParallelIterator, ParallelIterator};
2
3use crate::dataframe::DataFrame;
4use crate::series::{Series, SeriesTrait};
5use crate::AxionResult;
6use crate::AxionError;
7use crate::dtype::DataType;
8use std::fs::File;
9use std::path::Path;
10use std::collections::{HashMap, HashSet};
11use std::io::{BufReader, BufRead};
12
13/// CSV 读取配置选项
14/// 
15/// 提供了丰富的 CSV 文件读取配置,支持自定义分隔符、数据类型推断、
16/// 列选择等功能。
17/// 
18/// # 示例
19/// 
20/// ```rust
21/// use axion::io::csv::{ReadCsvOptions, read_csv};
22/// use axion::dtype::DataType;
23/// use std::collections::HashMap;
24/// 
25/// // 使用默认配置
26/// let df1 = read_csv("data.csv", None)?;
27/// 
28/// // 使用自定义配置
29/// let options = ReadCsvOptions::builder()
30///     .with_delimiter(b';')
31///     .with_header(true)
32///     .infer_schema(true)
33///     .build();
34/// let df2 = read_csv("data.csv", Some(options))?;
35/// ```
36#[derive(Debug, Clone)]
37pub struct ReadCsvOptions {
38    /// 字段分隔符,默认为 `,`
39    pub delimiter: u8,
40    /// CSV 文件是否包含表头行,默认为 `true`
41    /// 如果为 `false`,列名将自动生成为 "column_0", "column_1", ...
42    pub has_header: bool,
43    /// 尝试推断列的数据类型,默认为 `true`
44    /// 如果为 `false`,所有列将被读取为字符串
45    pub infer_schema: bool,
46    /// 用于类型推断的最大行数,默认为 `100`
47    /// 如果为 `None`,则使用所有行进行推断
48    pub infer_schema_length: Option<usize>,
49    /// 可选的 HashMap,用于手动指定某些列的数据类型
50    /// 手动指定的类型将覆盖类型推断的结果
51    pub dtypes: Option<HashMap<String, DataType>>,
52    /// 跳过文件开头的 N 行,默认为 `0`
53    pub skip_rows: usize,
54    /// 将以此字符开头的行视作注释并忽略,默认为 `None`
55    pub comment_char: Option<u8>,
56    /// 可选的列选择器,指定要读取的列名子集
57    /// 如果为 `None`,则读取所有列
58    pub use_columns: Option<Vec<String>>,
59    /// 一组应被视为空值的字符串,默认为 `None`
60    pub na_values: Option<HashSet<String>>,
61}
62
63impl Default for ReadCsvOptions {
64    fn default() -> Self {
65        ReadCsvOptions {
66            delimiter: b',',
67            has_header: true,
68            infer_schema: true,
69            infer_schema_length: Some(100),
70            dtypes: None,
71            skip_rows: 0,
72            comment_char: None,
73            use_columns: None,
74            na_values: None,
75        }
76    }
77}
78
79impl ReadCsvOptions {
80    /// 创建一个新的 ReadCsvOptions 构建器,使用默认值
81    pub fn builder() -> ReadCsvOptionsBuilder {
82        ReadCsvOptionsBuilder::new()
83    }
84}
85
86/// ReadCsvOptions 的构建器
87/// 
88/// 提供了一种链式调用的方式来配置 CSV 读取选项。
89/// 
90/// # 示例
91/// 
92/// ```rust
93/// let options = ReadCsvOptions::builder()
94///     .with_delimiter(b';')
95///     .with_header(true)
96///     .skip_rows(2)
97///     .build();
98/// ```
99#[derive(Debug, Clone, Default)]
100pub struct ReadCsvOptionsBuilder {
101    delimiter: Option<u8>,
102    has_header: Option<bool>,
103    infer_schema: Option<bool>,
104    infer_schema_length: Option<Option<usize>>,
105    dtypes: Option<HashMap<String, DataType>>,
106    skip_rows: Option<usize>,
107    comment_char: Option<Option<u8>>,
108    use_columns: Option<Vec<String>>,
109    na_values: Option<HashSet<String>>,
110}
111
112impl ReadCsvOptionsBuilder {
113    /// 创建一个新的构建器实例
114    pub fn new() -> Self {
115        Default::default()
116    }
117
118    /// 设置字段分隔符
119    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
120        self.delimiter = Some(delimiter);
121        self
122    }
123
124    /// 设置是否包含表头行
125    pub fn with_header(mut self, has_header: bool) -> Self {
126        self.has_header = Some(has_header);
127        self
128    }
129
130    /// 设置是否启用类型推断
131    pub fn infer_schema(mut self, infer: bool) -> Self {
132        self.infer_schema = Some(infer);
133        self
134    }
135
136    /// 设置类型推断使用的行数
137    pub fn infer_schema_length(mut self, length: Option<usize>) -> Self {
138        self.infer_schema_length = Some(length);
139        self
140    }
141
142    /// 设置列数据类型映射
143    pub fn with_dtypes(mut self, dtypes: HashMap<String, DataType>) -> Self {
144        self.dtypes = Some(dtypes);
145        self
146    }
147    
148    /// 添加单个列的数据类型
149    pub fn add_dtype(mut self, column_name: String, dtype: DataType) -> Self {
150        self.dtypes.get_or_insert_with(HashMap::new).insert(column_name, dtype);
151        self
152    }
153
154    /// 设置跳过的行数
155    pub fn skip_rows(mut self, n: usize) -> Self {
156        self.skip_rows = Some(n);
157        self
158    }
159
160    /// 设置注释字符
161    pub fn comment_char(mut self, char_opt: Option<u8>) -> Self {
162        self.comment_char = Some(char_opt);
163        self
164    }
165
166    /// 设置要读取的列
167    pub fn use_columns(mut self, columns: Vec<String>) -> Self {
168        self.use_columns = Some(columns);
169        self
170    }
171
172    /// 添加要读取的列
173    pub fn add_use_column(mut self, column_name: String) -> Self {
174        self.use_columns.get_or_insert_with(Vec::new).push(column_name);
175        self
176    }
177    
178    /// 设置 null 值表示
179    pub fn na_values(mut self, values: Option<HashSet<String>>) -> Self {
180        self.na_values = values;
181        self
182    }
183
184    /// 添加 null 值表示
185    pub fn add_na_value(mut self, value: String) -> Self {
186        self.na_values
187            .get_or_insert_with(HashSet::new)
188            .insert(value);
189        self
190    }
191
192    /// 构建最终的 `ReadCsvOptions` 实例
193    pub fn build(self) -> ReadCsvOptions {
194        let defaults = ReadCsvOptions::default();
195        ReadCsvOptions {
196            delimiter: self.delimiter.unwrap_or(defaults.delimiter),
197            has_header: self.has_header.unwrap_or(defaults.has_header),
198            infer_schema: self.infer_schema.unwrap_or(defaults.infer_schema),
199            infer_schema_length: self.infer_schema_length.unwrap_or(defaults.infer_schema_length),
200            dtypes: self.dtypes.or(defaults.dtypes),
201            skip_rows: self.skip_rows.unwrap_or(defaults.skip_rows),
202            comment_char: self.comment_char.unwrap_or(defaults.comment_char),
203            use_columns: self.use_columns.or(defaults.use_columns),
204            na_values: self.na_values.or(defaults.na_values),
205        }
206    }
207}
208
209/// 尝试解析字符串为 i64
210fn try_parse_i64(s: &str) -> Option<i64> {
211    s.parse::<i64>().ok()
212}
213
214/// 尝试解析字符串为 f64
215fn try_parse_f64(s: &str) -> Option<f64> {
216    s.parse::<f64>().ok()
217}
218
219/// 尝试解析字符串为布尔值
220fn try_parse_bool(s: &str) -> Option<bool> {
221    match s.to_lowercase().as_str() {
222        "true" | "t" | "yes" | "y" | "1" => Some(true),
223        "false" | "f" | "no" | "n" | "0" => Some(false),
224        _ => None,
225    }
226}
227
228/// 推断单列的数据类型
229fn infer_column_type(
230    column_values: &[Option<String>],
231    infer_length: Option<usize>,
232) -> DataType {
233    let non_empty_values = column_values
234        .iter()
235        .filter_map(|opt_s| opt_s.as_ref().map(|s| s.as_str()))
236        .filter(|s| !s.is_empty());
237
238    let sample: Vec<&str> = if let Some(len) = infer_length {
239        non_empty_values.take(len).collect()
240    } else {
241        non_empty_values.collect()
242    };
243
244    if sample.is_empty() {
245        return DataType::String;
246    }
247
248    if sample.iter().all(|s| try_parse_i64(s).is_some()) {
249        return DataType::Int64;
250    }
251    if sample.iter().all(|s| try_parse_f64(s).is_some()) {
252        return DataType::Float64;
253    }
254    if sample.iter().all(|s| try_parse_bool(s).is_some()) {
255        return DataType::Bool;
256    }
257    DataType::String
258}
259
260/// 将字符串列解析为指定类型的 Series
261fn parse_column_as_type(
262    column_name: String,
263    string_data: Vec<Option<String>>,
264    target_type: &DataType,
265) -> AxionResult<Box<dyn SeriesTrait>> {
266    match target_type {
267        DataType::Int64 => {
268            let parsed_data: Vec<Option<i64>> = string_data
269                .into_iter()
270                .map(|opt_s| opt_s.and_then(|s| try_parse_i64(&s)))
271                .collect();
272            Ok(Box::new(Series::<i64>::new_from_options(column_name, parsed_data)))
273        }
274        DataType::Float64 => {
275            let parsed_data: Vec<Option<f64>> = string_data
276                .into_iter()
277                .map(|opt_s| opt_s.and_then(|s| try_parse_f64(&s)))
278                .collect();
279            Ok(Box::new(Series::<f64>::new_from_options(column_name, parsed_data)))
280        }
281        DataType::Bool => {
282            let parsed_data: Vec<Option<bool>> = string_data
283                .into_iter()
284                .map(|opt_s| opt_s.and_then(|s| try_parse_bool(&s)))
285                .collect();
286            Ok(Box::new(Series::<bool>::new_from_options(column_name, parsed_data)))
287        }
288        DataType::String => {
289            Ok(Box::new(Series::<String>::new_from_options(column_name, string_data)))
290        }
291        dt => Err(AxionError::UnsupportedOperation(format!(
292            "无法将 CSV 列 '{}' 解析为类型 {:?}。CSV 解析仅支持 Int64、Float64、Bool、String 类型。",
293            column_name, dt
294        ))),
295    }
296}
297
298/// 从 CSV 文件读取数据到 DataFrame
299/// 
300/// 支持自动类型推断、列选择、注释行处理等高级功能。
301/// 
302/// # 参数
303/// 
304/// * `filepath` - CSV 文件路径
305/// * `options` - 可选的读取配置,如果为 None 则使用默认配置
306/// 
307/// # 返回值
308/// 
309/// 成功时返回包含 CSV 数据的 DataFrame
310/// 
311/// # 错误
312/// 
313/// * `AxionError::IoError` - 文件读取失败
314/// * `AxionError::CsvError` - CSV 格式错误或解析失败
315/// 
316/// # 示例
317/// 
318/// ```rust
319/// // 使用默认配置读取
320/// let df = read_csv("data.csv", None)?;
321/// 
322/// // 使用自定义配置读取
323/// let options = ReadCsvOptions::builder()
324///     .with_delimiter(b';')
325///     .infer_schema(true)
326///     .build();
327/// let df = read_csv("data.csv", Some(options))?;
328/// ```
329pub fn read_csv(filepath: impl AsRef<Path>, options: Option<ReadCsvOptions>) -> AxionResult<DataFrame> {
330    let opts = options.unwrap_or_default();
331
332    let file = File::open(filepath.as_ref())
333        .map_err(|e| AxionError::IoError(format!("无法打开文件 {:?}: {}", filepath.as_ref(), e)))?;
334    
335    let mut buf_reader = BufReader::new(file);
336
337    // 跳过指定行数
338    if opts.skip_rows > 0 {
339        let mut line_buffer = String::new();
340        for i in 0..opts.skip_rows {
341            match buf_reader.read_line(&mut line_buffer) {
342                Ok(0) => {
343                    return Err(AxionError::CsvError(format!(
344                        "CSV 文件行数少于需要跳过的行数 {},在第 {} 行到达文件末尾。",
345                        opts.skip_rows, i
346                    )));
347                }
348                Ok(_) => {
349                    line_buffer.clear();
350                }
351                Err(e) => {
352                    return Err(AxionError::IoError(format!("跳过行时出错: {}", e)));
353                }
354            }
355        }
356    }
357
358    let mut rdr_builder = csv::ReaderBuilder::new();
359    rdr_builder.delimiter(opts.delimiter);
360    rdr_builder.has_headers(false);
361    if let Some(comment) = opts.comment_char {
362        rdr_builder.comment(Some(comment));
363    }
364
365    let rdr = rdr_builder.from_reader(buf_reader); 
366    let mut records_iter = rdr.into_records();
367
368    // 确定文件表头和第一行数据
369    let original_file_headers: Vec<String>;
370    let mut first_data_row_buffer: Option<csv::StringRecord> = None;
371
372    if opts.has_header {
373        if let Some(header_result) = records_iter.next() {
374            original_file_headers = header_result
375                .map_err(|e| AxionError::CsvError(format!("读取 CSV 表头失败: {}", e)))?
376                .iter()
377                .map(|s| s.to_string())
378                .collect::<Vec<String>>();
379            if original_file_headers.is_empty() && !Path::new(filepath.as_ref()).metadata().map_or(true, |m| m.len() == 0) {
380                 return Err(AxionError::CsvError("CSV 表头行存在但为空。".to_string()));
381            }
382        } else {
383            return Ok(DataFrame::new_empty());
384        }
385    } else if let Some(first_record_result) = records_iter.next() {
386        let record = first_record_result.map_err(|e| AxionError::CsvError(format!("读取第一条数据记录失败: {}", e)))?;
387        if record.iter().all(|field| field.is_empty()) && !record.is_empty() { 
388             original_file_headers = (0..record.len()).map(|i| format!("column_{}", i)).collect();
389        } else if record.is_empty() { 
390             return Ok(DataFrame::new_empty());
391        } else {
392            original_file_headers = (0..record.len()).map(|i| format!("column_{}", i)).collect();
393        }
394        first_data_row_buffer = Some(record); 
395    } else {
396        return Ok(DataFrame::new_empty());
397    }
398    if original_file_headers.is_empty() {
399        return Ok(DataFrame::new_empty());
400    }
401
402    let (final_headers_to_use, column_indices_to_read): (Vec<String>, Vec<usize>) =
403        if let Some(ref wanted_columns) = opts.use_columns {
404            if wanted_columns.is_empty() { 
405                (Vec::new(), Vec::new())
406            } else {
407                let mut final_h = Vec::new();
408                let mut indices = Vec::new();
409                let original_header_map: HashMap<&String, usize> = original_file_headers.iter().enumerate().map(|(i, h_name)| (h_name, i)).collect();
410
411                for col_name_to_use in wanted_columns {
412                    if let Some(&original_index) = original_header_map.get(col_name_to_use) {
413                        final_h.push(col_name_to_use.clone());
414                        indices.push(original_index);
415                    } else {
416                        return Err(AxionError::CsvError(format!(
417                            "use_columns 中指定的列 '{}' 在 CSV 表头中未找到: {:?}",
418                            col_name_to_use, original_file_headers
419                        )));
420                    }
421                }
422                (final_h, indices)
423            }
424        } else {
425            (original_file_headers.clone(), (0..original_file_headers.len()).collect())
426        };
427
428    if final_headers_to_use.is_empty() {
429        return Ok(DataFrame::new_empty());
430    }
431
432    let num_selected_columns = final_headers_to_use.len();
433    let mut column_data_str: Vec<Vec<Option<String>>> = vec![Vec::new(); num_selected_columns];
434
435    let process_record_logic = |record: &csv::StringRecord,
436                                 col_data_storage: &mut Vec<Vec<Option<String>>>| -> AxionResult<()> {
437        
438        if opts.comment_char.is_some() && record.iter().all(|field| field.is_empty()) {
439            return Ok(()); 
440        }
441
442        if record.len() != original_file_headers.len() {
443            return Err(AxionError::CsvError(format!(
444                "CSV 记录有 {} 个字段,但表头有 {} 个字段。记录: {:?}",
445                record.len(),
446                original_file_headers.len(),
447                record
448            )));
449        }
450
451        for (target_idx, &original_field_idx) in column_indices_to_read.iter().enumerate() {
452            if let Some(field_str_val) = record.get(original_field_idx) {
453                let is_user_defined_na = opts.na_values
454                    .as_ref()
455                    .is_some_and(|na_set| na_set.contains(field_str_val));
456
457                if is_user_defined_na || field_str_val.is_empty() {
458                    col_data_storage[target_idx].push(None);
459                } else {
460                    col_data_storage[target_idx].push(Some(field_str_val.to_string()));
461                }
462            } else {
463                return Err(AxionError::CsvError(format!(
464                    "内部错误或记录长度不一致: 尝试访问索引 {} 的字段,但记录只有 {} 个字段。",
465                    original_field_idx, record.len()
466                )));
467            }
468        }
469        Ok(())
470    };
471
472    if let Some(ref record) = first_data_row_buffer {
473        process_record_logic(record, &mut column_data_str)?
474    }
475
476    for result in records_iter { 
477        match result {
478            Ok(record) => {
479                process_record_logic(&record, &mut column_data_str)?
480            }
481            Err(e) => {
482                return Err(AxionError::CsvError(format!("读取 CSV 记录失败: {}", e)));
483            }
484        }
485    }
486
487    let mut data_to_process: Vec<(String, Vec<Option<String>>, DataType)> = Vec::with_capacity(num_selected_columns);
488
489    for i in 0..num_selected_columns {
490        let column_name = final_headers_to_use[i].clone();
491        let current_column_str_data = std::mem::take(&mut column_data_str[i]); 
492
493        let final_dtype = if let Some(ref manual_dtypes) = opts.dtypes {
494            manual_dtypes.get(&column_name).cloned().unwrap_or_else(|| {
495                if opts.infer_schema {
496                    infer_column_type(&current_column_str_data, opts.infer_schema_length)
497                } else {
498                    DataType::String
499                }
500            })
501        } else if opts.infer_schema {
502            infer_column_type(&current_column_str_data, opts.infer_schema_length)
503        } else {
504            DataType::String
505        };
506        data_to_process.push((column_name, current_column_str_data, final_dtype));
507    }
508
509    let series_results: Vec<AxionResult<Box<dyn SeriesTrait>>> = data_to_process
510        .into_par_iter() 
511        .map(|(col_name, str_data, dtype)| {
512            parse_column_as_type(col_name, str_data, &dtype)
513        })
514        .collect(); 
515
516    let mut series_vec: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(num_selected_columns);
517    for result in series_results {
518        match result {
519            Ok(series) => series_vec.push(series),
520            Err(e) => return Err(e), 
521        }
522    }
523
524    DataFrame::new(series_vec)
525}
526
527/// CSV 引用样式
528/// 
529/// 控制 CSV 写入时字段的引号使用策略。
530#[derive(Default, Debug, Clone, PartialEq, Eq)]
531pub enum QuoteStyle {
532    /// 总是为所有字段加上引号
533    Always,
534    /// 仅在字段包含分隔符、引号或换行符时加上引号(默认)
535    #[default]
536    Necessary,
537    /// 从不为字段加上引号(如果字段包含特殊字符,可能导致 CSV 格式无效)
538    Never,
539    /// 仅为非数字字段加上引号
540    NonNumeric,
541}
542
543/// CSV 写入配置选项
544/// 
545/// 控制 DataFrame 导出为 CSV 文件时的格式设置。
546/// 
547/// # 示例
548/// 
549/// ```rust
550/// use axion::io::csv::{WriteCsvOptions, QuoteStyle};
551/// 
552/// let options = WriteCsvOptions::builder()
553///     .with_header(true)
554///     .with_delimiter(b';')
555///     .quote_style(QuoteStyle::Always)
556///     .build();
557/// ```
558#[derive(Debug, Clone)]
559pub struct WriteCsvOptions {
560    /// 是否写入表头行,默认为 `true`
561    pub has_header: bool,
562    /// 字段分隔符,默认为 `,`
563    pub delimiter: u8,
564    /// 用于表示 null 值的字符串,默认为空字符串 `""`
565    pub na_rep: String,
566    /// 字段的引用样式,默认为 `QuoteStyle::Necessary`
567    pub quote_style: QuoteStyle,
568    /// 行终止符,默认为 `\n`
569    pub line_terminator: String,
570}
571
572impl Default for WriteCsvOptions {
573    fn default() -> Self {
574        WriteCsvOptions {
575            has_header: true,
576            delimiter: b',',
577            na_rep: "".to_string(),
578            quote_style: QuoteStyle::default(),
579            line_terminator: "\n".to_string(),
580        }
581    }
582}
583
584impl WriteCsvOptions {
585    /// 创建一个新的 WriteCsvOptions 构建器,使用默认值
586    pub fn builder() -> WriteCsvOptionsBuilder {
587        WriteCsvOptionsBuilder::new()
588    }
589}
590
591/// WriteCsvOptions 的构建器
592/// 
593/// 提供了一种链式调用的方式来配置 CSV 写入选项。
594#[derive(Debug, Clone, Default)]
595pub struct WriteCsvOptionsBuilder {
596    has_header: Option<bool>,
597    delimiter: Option<u8>,
598    na_rep: Option<String>,
599    quote_style: Option<QuoteStyle>,
600    line_terminator: Option<String>,
601}
602
603impl WriteCsvOptionsBuilder {
604    /// 创建一个新的构建器实例
605    pub fn new() -> Self {
606        Default::default()
607    }
608
609    /// 设置是否写入表头行
610    pub fn with_header(mut self, has_header: bool) -> Self {
611        self.has_header = Some(has_header);
612        self
613    }
614
615    /// 设置字段分隔符
616    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
617        self.delimiter = Some(delimiter);
618        self
619    }
620
621    /// 设置用于表示 null 值的字符串
622    pub fn na_representation(mut self, na_rep: String) -> Self {
623        self.na_rep = Some(na_rep);
624        self
625    }
626
627    /// 设置字段的引用样式
628    pub fn quote_style(mut self, style: QuoteStyle) -> Self {
629        self.quote_style = Some(style);
630        self
631    }
632
633    /// 设置行终止符
634    /// 
635    /// 例如:`"\n"` (LF), `"\r\n"` (CRLF)
636    pub fn line_terminator(mut self, terminator: String) -> Self {
637        self.line_terminator = Some(terminator);
638        self
639    }
640
641    /// 构建最终的 WriteCsvOptions 实例
642    /// 
643    /// 未在构建器中设置的字段将使用默认值
644    pub fn build(self) -> WriteCsvOptions {
645        let defaults = WriteCsvOptions::default();
646        WriteCsvOptions {
647            has_header: self.has_header.unwrap_or(defaults.has_header),
648            delimiter: self.delimiter.unwrap_or(defaults.delimiter),
649            na_rep: self.na_rep.unwrap_or(defaults.na_rep),
650            quote_style: self.quote_style.unwrap_or(defaults.quote_style),
651            line_terminator: self.line_terminator.unwrap_or(defaults.line_terminator),
652        }
653    }
654}