yscv_model/dataset/
csv.rs1use std::path::Path;
2
3use crate::ModelError;
4
5use super::helpers::{
6 adapter_sample_len, build_supervised_dataset_from_flat_values, load_dataset_text_file,
7 validate_adapter_sample_shape, validate_csv_delimiter, validate_finite_values,
8};
9
/// Settings that control how delimited text is parsed into a supervised dataset.
///
/// Each data row is expected to hold the flattened input values followed by the
/// flattened target values. Instances are created through
/// `SupervisedCsvConfig::new` and refined with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SupervisedCsvConfig {
    /// Shape of a single input sample.
    input_shape: Vec<usize>,
    /// Shape of a single target sample.
    target_shape: Vec<usize>,
    /// Character that separates columns within a row.
    delimiter: char,
    /// Whether the first non-empty, non-comment line is a header to skip.
    has_header: bool,
}
18
19impl SupervisedCsvConfig {
20 pub fn new(input_shape: Vec<usize>, target_shape: Vec<usize>) -> Result<Self, ModelError> {
21 validate_adapter_sample_shape("input_shape", &input_shape)?;
22 validate_adapter_sample_shape("target_shape", &target_shape)?;
23 Ok(Self {
24 input_shape,
25 target_shape,
26 delimiter: ',',
27 has_header: false,
28 })
29 }
30
31 pub fn with_delimiter(mut self, delimiter: char) -> Result<Self, ModelError> {
32 validate_csv_delimiter(delimiter)?;
33 self.delimiter = delimiter;
34 Ok(self)
35 }
36
37 pub fn with_header(mut self, has_header: bool) -> Self {
38 self.has_header = has_header;
39 self
40 }
41
42 pub fn input_shape(&self) -> &[usize] {
43 &self.input_shape
44 }
45
46 pub fn target_shape(&self) -> &[usize] {
47 &self.target_shape
48 }
49
50 pub fn delimiter(&self) -> char {
51 self.delimiter
52 }
53
54 pub fn has_header(&self) -> bool {
55 self.has_header
56 }
57}
58
59pub fn parse_supervised_dataset_csv(
65 content: &str,
66 config: &SupervisedCsvConfig,
67) -> Result<super::types::SupervisedDataset, ModelError> {
68 let input_row_len = adapter_sample_len("input_shape", config.input_shape())?;
69 let target_row_len = adapter_sample_len("target_shape", config.target_shape())?;
70 let expected_columns = input_row_len.checked_add(target_row_len).ok_or_else(|| {
71 ModelError::InvalidDatasetAdapterShape {
72 field: "row_columns",
73 shape: vec![input_row_len, target_row_len],
74 message: "column count overflow".to_string(),
75 }
76 })?;
77
78 let mut input_values = Vec::new();
79 let mut target_values = Vec::new();
80 let mut sample_count = 0usize;
81 let mut header_skipped = false;
82 let mut row_values = Vec::with_capacity(expected_columns);
83
84 for (line_idx, raw_line) in content.lines().enumerate() {
85 let line_number = line_idx + 1;
86 let line = raw_line.trim();
87 if line.is_empty() || line.starts_with('#') {
88 continue;
89 }
90 if config.has_header() && !header_skipped {
91 header_skipped = true;
92 continue;
93 }
94
95 let columns = line
96 .split(config.delimiter())
97 .map(str::trim)
98 .collect::<Vec<_>>();
99 if columns.len() != expected_columns {
100 return Err(ModelError::InvalidDatasetRecordColumns {
101 line: line_number,
102 expected: expected_columns,
103 got: columns.len(),
104 });
105 }
106
107 row_values.clear();
108 for (column_idx, value_str) in columns.iter().enumerate() {
109 let value = value_str
110 .parse::<f32>()
111 .map_err(|error| ModelError::DatasetCsvParse {
112 line: line_number,
113 column: column_idx + 1,
114 message: error.to_string(),
115 })?;
116 row_values.push(value);
117 }
118
119 let (input_row, target_row) = row_values.split_at(input_row_len);
120 validate_finite_values(line_number, "input", input_row)?;
121 validate_finite_values(line_number, "target", target_row)?;
122 input_values.extend_from_slice(input_row);
123 target_values.extend_from_slice(target_row);
124
125 sample_count =
126 sample_count
127 .checked_add(1)
128 .ok_or_else(|| ModelError::InvalidDatasetAdapterShape {
129 field: "sample_count",
130 shape: vec![sample_count],
131 message: "sample count overflow".to_string(),
132 })?;
133 }
134
135 if sample_count == 0 {
136 return Err(ModelError::EmptyDataset);
137 }
138
139 build_supervised_dataset_from_flat_values(
140 config.input_shape(),
141 config.target_shape(),
142 sample_count,
143 input_values,
144 target_values,
145 )
146}
147
148pub fn load_supervised_dataset_csv_file<P: AsRef<Path>>(
150 path: P,
151 config: &SupervisedCsvConfig,
152) -> Result<super::types::SupervisedDataset, ModelError> {
153 let content = load_dataset_text_file(path)?;
154 parse_supervised_dataset_csv(&content, config)
155}