1use rayon::iter::{IntoParallelIterator, ParallelIterator};
2
3use crate::dataframe::DataFrame;
4use crate::series::{Series, SeriesTrait};
5use crate::AxionResult;
6use crate::AxionError;
7use crate::dtype::DataType;
8use std::fs::File;
9use std::path::Path;
10use std::collections::{HashMap, HashSet};
11use std::io::{BufReader, BufRead};
12
13#[derive(Debug, Clone)]
37pub struct ReadCsvOptions {
38 pub delimiter: u8,
40 pub has_header: bool,
43 pub infer_schema: bool,
46 pub infer_schema_length: Option<usize>,
49 pub dtypes: Option<HashMap<String, DataType>>,
52 pub skip_rows: usize,
54 pub comment_char: Option<u8>,
56 pub use_columns: Option<Vec<String>>,
59 pub na_values: Option<HashSet<String>>,
61}
62
63impl Default for ReadCsvOptions {
64 fn default() -> Self {
65 ReadCsvOptions {
66 delimiter: b',',
67 has_header: true,
68 infer_schema: true,
69 infer_schema_length: Some(100),
70 dtypes: None,
71 skip_rows: 0,
72 comment_char: None,
73 use_columns: None,
74 na_values: None,
75 }
76 }
77}
78
79impl ReadCsvOptions {
80 pub fn builder() -> ReadCsvOptionsBuilder {
82 ReadCsvOptionsBuilder::new()
83 }
84}
85
86#[derive(Debug, Clone, Default)]
100pub struct ReadCsvOptionsBuilder {
101 delimiter: Option<u8>,
102 has_header: Option<bool>,
103 infer_schema: Option<bool>,
104 infer_schema_length: Option<Option<usize>>,
105 dtypes: Option<HashMap<String, DataType>>,
106 skip_rows: Option<usize>,
107 comment_char: Option<Option<u8>>,
108 use_columns: Option<Vec<String>>,
109 na_values: Option<HashSet<String>>,
110}
111
112impl ReadCsvOptionsBuilder {
113 pub fn new() -> Self {
115 Default::default()
116 }
117
118 pub fn with_delimiter(mut self, delimiter: u8) -> Self {
120 self.delimiter = Some(delimiter);
121 self
122 }
123
124 pub fn with_header(mut self, has_header: bool) -> Self {
126 self.has_header = Some(has_header);
127 self
128 }
129
130 pub fn infer_schema(mut self, infer: bool) -> Self {
132 self.infer_schema = Some(infer);
133 self
134 }
135
136 pub fn infer_schema_length(mut self, length: Option<usize>) -> Self {
138 self.infer_schema_length = Some(length);
139 self
140 }
141
142 pub fn with_dtypes(mut self, dtypes: HashMap<String, DataType>) -> Self {
144 self.dtypes = Some(dtypes);
145 self
146 }
147
148 pub fn add_dtype(mut self, column_name: String, dtype: DataType) -> Self {
150 self.dtypes.get_or_insert_with(HashMap::new).insert(column_name, dtype);
151 self
152 }
153
154 pub fn skip_rows(mut self, n: usize) -> Self {
156 self.skip_rows = Some(n);
157 self
158 }
159
160 pub fn comment_char(mut self, char_opt: Option<u8>) -> Self {
162 self.comment_char = Some(char_opt);
163 self
164 }
165
166 pub fn use_columns(mut self, columns: Vec<String>) -> Self {
168 self.use_columns = Some(columns);
169 self
170 }
171
172 pub fn add_use_column(mut self, column_name: String) -> Self {
174 self.use_columns.get_or_insert_with(Vec::new).push(column_name);
175 self
176 }
177
178 pub fn na_values(mut self, values: Option<HashSet<String>>) -> Self {
180 self.na_values = values;
181 self
182 }
183
184 pub fn add_na_value(mut self, value: String) -> Self {
186 self.na_values
187 .get_or_insert_with(HashSet::new)
188 .insert(value);
189 self
190 }
191
192 pub fn build(self) -> ReadCsvOptions {
194 let defaults = ReadCsvOptions::default();
195 ReadCsvOptions {
196 delimiter: self.delimiter.unwrap_or(defaults.delimiter),
197 has_header: self.has_header.unwrap_or(defaults.has_header),
198 infer_schema: self.infer_schema.unwrap_or(defaults.infer_schema),
199 infer_schema_length: self.infer_schema_length.unwrap_or(defaults.infer_schema_length),
200 dtypes: self.dtypes.or(defaults.dtypes),
201 skip_rows: self.skip_rows.unwrap_or(defaults.skip_rows),
202 comment_char: self.comment_char.unwrap_or(defaults.comment_char),
203 use_columns: self.use_columns.or(defaults.use_columns),
204 na_values: self.na_values.or(defaults.na_values),
205 }
206 }
207}
208
/// Parses `s` as a signed 64-bit integer. Returns `None` if it is not one;
/// no whitespace trimming is done.
fn try_parse_i64(s: &str) -> Option<i64> {
    match s.parse::<i64>() {
        Ok(value) => Some(value),
        Err(_) => None,
    }
}
213
/// Parses `s` as a 64-bit float using Rust's standard float syntax, which
/// also accepts forms such as `inf` and `NaN`. Returns `None` on failure.
fn try_parse_f64(s: &str) -> Option<f64> {
    if let Ok(value) = s.parse::<f64>() {
        Some(value)
    } else {
        None
    }
}
218
/// Parses a permissive boolean literal, ignoring ASCII case.
///
/// True spellings: `true`, `t`, `yes`, `y`, `1`. False spellings: `false`,
/// `f`, `no`, `n`, `0`. Anything else yields `None`, including values with
/// surrounding whitespace.
fn try_parse_bool(s: &str) -> Option<bool> {
    // This runs once per cell during type inference and parsing. Comparing in
    // place avoids allocating a lowercased `String` on every call. All accepted
    // spellings are ASCII, so ASCII-only case folding matches the same inputs
    // a full Unicode lowercase would.
    const TRUE_SPELLINGS: [&str; 5] = ["true", "t", "yes", "y", "1"];
    const FALSE_SPELLINGS: [&str; 5] = ["false", "f", "no", "n", "0"];

    if TRUE_SPELLINGS.iter().any(|word| s.eq_ignore_ascii_case(word)) {
        Some(true)
    } else if FALSE_SPELLINGS.iter().any(|word| s.eq_ignore_ascii_case(word)) {
        Some(false)
    } else {
        None
    }
}
227
228fn infer_column_type(
230 column_values: &[Option<String>],
231 infer_length: Option<usize>,
232) -> DataType {
233 let non_empty_values = column_values
234 .iter()
235 .filter_map(|opt_s| opt_s.as_ref().map(|s| s.as_str()))
236 .filter(|s| !s.is_empty());
237
238 let sample: Vec<&str> = if let Some(len) = infer_length {
239 non_empty_values.take(len).collect()
240 } else {
241 non_empty_values.collect()
242 };
243
244 if sample.is_empty() {
245 return DataType::String;
246 }
247
248 if sample.iter().all(|s| try_parse_i64(s).is_some()) {
249 return DataType::Int64;
250 }
251 if sample.iter().all(|s| try_parse_f64(s).is_some()) {
252 return DataType::Float64;
253 }
254 if sample.iter().all(|s| try_parse_bool(s).is_some()) {
255 return DataType::Bool;
256 }
257 DataType::String
258}
259
260fn parse_column_as_type(
262 column_name: String,
263 string_data: Vec<Option<String>>,
264 target_type: &DataType,
265) -> AxionResult<Box<dyn SeriesTrait>> {
266 match target_type {
267 DataType::Int64 => {
268 let parsed_data: Vec<Option<i64>> = string_data
269 .into_iter()
270 .map(|opt_s| opt_s.and_then(|s| try_parse_i64(&s)))
271 .collect();
272 Ok(Box::new(Series::<i64>::new_from_options(column_name, parsed_data)))
273 }
274 DataType::Float64 => {
275 let parsed_data: Vec<Option<f64>> = string_data
276 .into_iter()
277 .map(|opt_s| opt_s.and_then(|s| try_parse_f64(&s)))
278 .collect();
279 Ok(Box::new(Series::<f64>::new_from_options(column_name, parsed_data)))
280 }
281 DataType::Bool => {
282 let parsed_data: Vec<Option<bool>> = string_data
283 .into_iter()
284 .map(|opt_s| opt_s.and_then(|s| try_parse_bool(&s)))
285 .collect();
286 Ok(Box::new(Series::<bool>::new_from_options(column_name, parsed_data)))
287 }
288 DataType::String => {
289 Ok(Box::new(Series::<String>::new_from_options(column_name, string_data)))
290 }
291 dt => Err(AxionError::UnsupportedOperation(format!(
292 "无法将 CSV 列 '{}' 解析为类型 {:?}。CSV 解析仅支持 Int64、Float64、Bool、String 类型。",
293 column_name, dt
294 ))),
295 }
296}
297
298pub fn read_csv(filepath: impl AsRef<Path>, options: Option<ReadCsvOptions>) -> AxionResult<DataFrame> {
330 let opts = options.unwrap_or_default();
331
332 let file = File::open(filepath.as_ref())
333 .map_err(|e| AxionError::IoError(format!("无法打开文件 {:?}: {}", filepath.as_ref(), e)))?;
334
335 let mut buf_reader = BufReader::new(file);
336
337 if opts.skip_rows > 0 {
339 let mut line_buffer = String::new();
340 for i in 0..opts.skip_rows {
341 match buf_reader.read_line(&mut line_buffer) {
342 Ok(0) => {
343 return Err(AxionError::CsvError(format!(
344 "CSV 文件行数少于需要跳过的行数 {},在第 {} 行到达文件末尾。",
345 opts.skip_rows, i
346 )));
347 }
348 Ok(_) => {
349 line_buffer.clear();
350 }
351 Err(e) => {
352 return Err(AxionError::IoError(format!("跳过行时出错: {}", e)));
353 }
354 }
355 }
356 }
357
358 let mut rdr_builder = csv::ReaderBuilder::new();
359 rdr_builder.delimiter(opts.delimiter);
360 rdr_builder.has_headers(false);
361 if let Some(comment) = opts.comment_char {
362 rdr_builder.comment(Some(comment));
363 }
364
365 let rdr = rdr_builder.from_reader(buf_reader);
366 let mut records_iter = rdr.into_records();
367
368 let original_file_headers: Vec<String>;
370 let mut first_data_row_buffer: Option<csv::StringRecord> = None;
371
372 if opts.has_header {
373 if let Some(header_result) = records_iter.next() {
374 original_file_headers = header_result
375 .map_err(|e| AxionError::CsvError(format!("读取 CSV 表头失败: {}", e)))?
376 .iter()
377 .map(|s| s.to_string())
378 .collect::<Vec<String>>();
379 if original_file_headers.is_empty() && !Path::new(filepath.as_ref()).metadata().map_or(true, |m| m.len() == 0) {
380 return Err(AxionError::CsvError("CSV 表头行存在但为空。".to_string()));
381 }
382 } else {
383 return Ok(DataFrame::new_empty());
384 }
385 } else if let Some(first_record_result) = records_iter.next() {
386 let record = first_record_result.map_err(|e| AxionError::CsvError(format!("读取第一条数据记录失败: {}", e)))?;
387 if record.iter().all(|field| field.is_empty()) && !record.is_empty() {
388 original_file_headers = (0..record.len()).map(|i| format!("column_{}", i)).collect();
389 } else if record.is_empty() {
390 return Ok(DataFrame::new_empty());
391 } else {
392 original_file_headers = (0..record.len()).map(|i| format!("column_{}", i)).collect();
393 }
394 first_data_row_buffer = Some(record);
395 } else {
396 return Ok(DataFrame::new_empty());
397 }
398 if original_file_headers.is_empty() {
399 return Ok(DataFrame::new_empty());
400 }
401
402 let (final_headers_to_use, column_indices_to_read): (Vec<String>, Vec<usize>) =
403 if let Some(ref wanted_columns) = opts.use_columns {
404 if wanted_columns.is_empty() {
405 (Vec::new(), Vec::new())
406 } else {
407 let mut final_h = Vec::new();
408 let mut indices = Vec::new();
409 let original_header_map: HashMap<&String, usize> = original_file_headers.iter().enumerate().map(|(i, h_name)| (h_name, i)).collect();
410
411 for col_name_to_use in wanted_columns {
412 if let Some(&original_index) = original_header_map.get(col_name_to_use) {
413 final_h.push(col_name_to_use.clone());
414 indices.push(original_index);
415 } else {
416 return Err(AxionError::CsvError(format!(
417 "use_columns 中指定的列 '{}' 在 CSV 表头中未找到: {:?}",
418 col_name_to_use, original_file_headers
419 )));
420 }
421 }
422 (final_h, indices)
423 }
424 } else {
425 (original_file_headers.clone(), (0..original_file_headers.len()).collect())
426 };
427
428 if final_headers_to_use.is_empty() {
429 return Ok(DataFrame::new_empty());
430 }
431
432 let num_selected_columns = final_headers_to_use.len();
433 let mut column_data_str: Vec<Vec<Option<String>>> = vec![Vec::new(); num_selected_columns];
434
435 let process_record_logic = |record: &csv::StringRecord,
436 col_data_storage: &mut Vec<Vec<Option<String>>>| -> AxionResult<()> {
437
438 if opts.comment_char.is_some() && record.iter().all(|field| field.is_empty()) {
439 return Ok(());
440 }
441
442 if record.len() != original_file_headers.len() {
443 return Err(AxionError::CsvError(format!(
444 "CSV 记录有 {} 个字段,但表头有 {} 个字段。记录: {:?}",
445 record.len(),
446 original_file_headers.len(),
447 record
448 )));
449 }
450
451 for (target_idx, &original_field_idx) in column_indices_to_read.iter().enumerate() {
452 if let Some(field_str_val) = record.get(original_field_idx) {
453 let is_user_defined_na = opts.na_values
454 .as_ref()
455 .is_some_and(|na_set| na_set.contains(field_str_val));
456
457 if is_user_defined_na || field_str_val.is_empty() {
458 col_data_storage[target_idx].push(None);
459 } else {
460 col_data_storage[target_idx].push(Some(field_str_val.to_string()));
461 }
462 } else {
463 return Err(AxionError::CsvError(format!(
464 "内部错误或记录长度不一致: 尝试访问索引 {} 的字段,但记录只有 {} 个字段。",
465 original_field_idx, record.len()
466 )));
467 }
468 }
469 Ok(())
470 };
471
472 if let Some(ref record) = first_data_row_buffer {
473 process_record_logic(record, &mut column_data_str)?
474 }
475
476 for result in records_iter {
477 match result {
478 Ok(record) => {
479 process_record_logic(&record, &mut column_data_str)?
480 }
481 Err(e) => {
482 return Err(AxionError::CsvError(format!("读取 CSV 记录失败: {}", e)));
483 }
484 }
485 }
486
487 let mut data_to_process: Vec<(String, Vec<Option<String>>, DataType)> = Vec::with_capacity(num_selected_columns);
488
489 for i in 0..num_selected_columns {
490 let column_name = final_headers_to_use[i].clone();
491 let current_column_str_data = std::mem::take(&mut column_data_str[i]);
492
493 let final_dtype = if let Some(ref manual_dtypes) = opts.dtypes {
494 manual_dtypes.get(&column_name).cloned().unwrap_or_else(|| {
495 if opts.infer_schema {
496 infer_column_type(¤t_column_str_data, opts.infer_schema_length)
497 } else {
498 DataType::String
499 }
500 })
501 } else if opts.infer_schema {
502 infer_column_type(¤t_column_str_data, opts.infer_schema_length)
503 } else {
504 DataType::String
505 };
506 data_to_process.push((column_name, current_column_str_data, final_dtype));
507 }
508
509 let series_results: Vec<AxionResult<Box<dyn SeriesTrait>>> = data_to_process
510 .into_par_iter()
511 .map(|(col_name, str_data, dtype)| {
512 parse_column_as_type(col_name, str_data, &dtype)
513 })
514 .collect();
515
516 let mut series_vec: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(num_selected_columns);
517 for result in series_results {
518 match result {
519 Ok(series) => series_vec.push(series),
520 Err(e) => return Err(e),
521 }
522 }
523
524 DataFrame::new(series_vec)
525}
526
/// Quoting policy for fields written to CSV.
///
/// NOTE(review): the variant names appear to mirror the `csv` crate's
/// `QuoteStyle`. The writer that consumes this is not shown here, so the exact
/// mapping is unconfirmed. The per-variant descriptions below are presumed
/// from the names.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum QuoteStyle {
    /// Presumably: quote every field.
    Always,
    /// Presumably: quote only fields that need it. This is the default.
    #[default]
    Necessary,
    /// Presumably: never quote fields.
    Never,
    /// Presumably: quote every field that is not numeric.
    NonNumeric,
}
542
543#[derive(Debug, Clone)]
559pub struct WriteCsvOptions {
560 pub has_header: bool,
562 pub delimiter: u8,
564 pub na_rep: String,
566 pub quote_style: QuoteStyle,
568 pub line_terminator: String,
570}
571
572impl Default for WriteCsvOptions {
573 fn default() -> Self {
574 WriteCsvOptions {
575 has_header: true,
576 delimiter: b',',
577 na_rep: "".to_string(),
578 quote_style: QuoteStyle::default(),
579 line_terminator: "\n".to_string(),
580 }
581 }
582}
583
584impl WriteCsvOptions {
585 pub fn builder() -> WriteCsvOptionsBuilder {
587 WriteCsvOptionsBuilder::new()
588 }
589}
590
591#[derive(Debug, Clone, Default)]
595pub struct WriteCsvOptionsBuilder {
596 has_header: Option<bool>,
597 delimiter: Option<u8>,
598 na_rep: Option<String>,
599 quote_style: Option<QuoteStyle>,
600 line_terminator: Option<String>,
601}
602
603impl WriteCsvOptionsBuilder {
604 pub fn new() -> Self {
606 Default::default()
607 }
608
609 pub fn with_header(mut self, has_header: bool) -> Self {
611 self.has_header = Some(has_header);
612 self
613 }
614
615 pub fn with_delimiter(mut self, delimiter: u8) -> Self {
617 self.delimiter = Some(delimiter);
618 self
619 }
620
621 pub fn na_representation(mut self, na_rep: String) -> Self {
623 self.na_rep = Some(na_rep);
624 self
625 }
626
627 pub fn quote_style(mut self, style: QuoteStyle) -> Self {
629 self.quote_style = Some(style);
630 self
631 }
632
633 pub fn line_terminator(mut self, terminator: String) -> Self {
637 self.line_terminator = Some(terminator);
638 self
639 }
640
641 pub fn build(self) -> WriteCsvOptions {
645 let defaults = WriteCsvOptions::default();
646 WriteCsvOptions {
647 has_header: self.has_header.unwrap_or(defaults.has_header),
648 delimiter: self.delimiter.unwrap_or(defaults.delimiter),
649 na_rep: self.na_rep.unwrap_or(defaults.na_rep),
650 quote_style: self.quote_style.unwrap_or(defaults.quote_style),
651 line_terminator: self.line_terminator.unwrap_or(defaults.line_terminator),
652 }
653 }
654}