entrenar/config/train/batches/
parquet.rs1use super::super::arrow::arrow_array_to_f32;
4use super::super::demo::create_demo_batches;
5use super::rebatch::rebatch;
6use crate::error::{Error, Result};
7use crate::train::Batch;
8use crate::Tensor;
9use alimentar::{ArrowDataset, Dataset};
10use arrow::datatypes::Schema;
11use arrow::record_batch::RecordBatch;
12use std::path::Path;
13
/// Names of the input and target columns detected in a parquet schema.
///
/// Both names are borrowed (`&'a str`); in `load_parquet_batches` they borrow from the
/// dataset schema's field names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ColumnPair<'a> {
    /// Column whose values become the batch inputs.
    input_name: &'a str,
    /// Column whose values become the batch targets.
    target_name: &'a str,
}
19
/// Picks the column that conventionally holds model inputs.
///
/// Recognised names (exact, case-sensitive): `input`, `input_ids`, `x`, `features`.
/// If several are present, the one that appears first in `column_names` is returned;
/// `None` if there is no match.
fn detect_input_column<'a>(column_names: &[&'a str]) -> Option<&'a str> {
    column_names
        .iter()
        .copied()
        .find(|candidate| matches!(*candidate, "input" | "input_ids" | "x" | "features"))
}
27
/// Picks the column that conventionally holds training targets.
///
/// Recognised names (exact, case-sensitive): `target`, `output`, `labels`, `y`.
/// If several are present, the one that appears first in `column_names` is returned;
/// `None` if there is no match.
fn detect_target_column<'a>(column_names: &[&'a str]) -> Option<&'a str> {
    column_names
        .iter()
        .copied()
        .find(|candidate| matches!(*candidate, "target" | "output" | "labels" | "y"))
}
35
36fn detect_columns<'a>(column_names: &[&'a str]) -> Option<ColumnPair<'a>> {
38 let input_name = detect_input_column(column_names)?;
39 let target_name = detect_target_column(column_names)?;
40 Some(ColumnPair { input_name, target_name })
41}
42
43fn handle_missing_columns(column_names: &[&str], batch_size: usize) -> Vec<Batch> {
45 eprintln!("Warning: Could not find input/target columns in parquet (found: {column_names:?})");
46 eprintln!(" Expected columns like: input/target, x/y, features/labels");
47 create_demo_batches(batch_size)
48}
49
50fn record_batch_to_training_batch(
52 record_batch: &RecordBatch,
53 schema: &Schema,
54 input_name: &str,
55 target_name: &str,
56) -> Result<Batch> {
57 let input_idx = schema
58 .index_of(input_name)
59 .map_err(|e| Error::ConfigError(format!("Column not found: {e}")))?;
60 let target_idx = schema
61 .index_of(target_name)
62 .map_err(|e| Error::ConfigError(format!("Column not found: {e}")))?;
63
64 let input_array = record_batch.column(input_idx);
65 let target_array = record_batch.column(target_idx);
66
67 let input_data = arrow_array_to_f32(input_array)?;
68 let target_data = arrow_array_to_f32(target_array)?;
69
70 Ok(Batch::new(Tensor::from_vec(input_data, false), Tensor::from_vec(target_data, false)))
71}
72
73fn process_record_batches(dataset: &ArrowDataset, columns: &ColumnPair<'_>) -> Result<Vec<Batch>> {
75 let schema = dataset.schema();
76 let mut batches = Vec::new();
77
78 for record_batch in dataset.iter() {
79 let batch = record_batch_to_training_batch(
80 &record_batch,
81 &schema,
82 columns.input_name,
83 columns.target_name,
84 )?;
85 batches.push(batch);
86 }
87
88 Ok(batches)
89}
90
91pub fn load_parquet_batches(path: &Path, batch_size: usize) -> Result<Vec<Batch>> {
93 println!(" Loading parquet: {}", path.display());
94
95 let dataset = ArrowDataset::from_parquet(path).map_err(|e| {
96 Error::ConfigError(format!("Failed to load parquet {}: {}", path.display(), e))
97 })?;
98
99 println!(" Loaded {} rows from parquet", dataset.len());
100
101 let schema = dataset.schema();
102 let column_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
103
104 let columns = match detect_columns(&column_names) {
105 Some(cols) => cols,
106 None => return Ok(handle_missing_columns(&column_names, batch_size)),
107 };
108
109 println!(" Using columns: input='{}', target='{}'", columns.input_name, columns.target_name);
110
111 let mut batches = process_record_batches(&dataset, &columns)?;
112
113 if batches.len() > 1 && batch_size > 0 {
115 batches = rebatch(batches, batch_size);
116 }
117
118 Ok(batches)
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float32Array, Float64Array, Int32Array};
    use arrow::datatypes::{DataType, Field};
    use std::sync::Arc;

    /// Schema with two non-nullable `Float32` columns, `input` and `target`.
    fn make_test_schema() -> Schema {
        Schema::new(vec![
            Field::new("input", DataType::Float32, false),
            Field::new("target", DataType::Float32, false),
        ])
    }

    /// Four-row record batch conforming to `make_test_schema`.
    fn make_test_record_batch() -> RecordBatch {
        let schema = Arc::new(make_test_schema());
        let input = Float32Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let target = Float32Array::from(vec![0.0, 1.0, 0.0, 1.0]);
        RecordBatch::try_new(schema, vec![Arc::new(input), Arc::new(target)])
            .expect("conversion should succeed")
    }

    #[test]
    fn test_detect_input_column_input() {
        let cols = ["input", "target"];
        assert_eq!(detect_input_column(&cols), Some("input"));
    }

    #[test]
    fn test_detect_input_column_input_ids() {
        let cols = ["input_ids", "labels"];
        assert_eq!(detect_input_column(&cols), Some("input_ids"));
    }

    #[test]
    fn test_detect_input_column_x() {
        let cols = ["x", "y"];
        assert_eq!(detect_input_column(&cols), Some("x"));
    }

    #[test]
    fn test_detect_input_column_features() {
        let cols = ["features", "labels"];
        assert_eq!(detect_input_column(&cols), Some("features"));
    }

    #[test]
    fn test_detect_input_column_prefers_schema_order() {
        // When several recognised names exist, the earliest column in the schema wins.
        let cols = ["features", "input"];
        assert_eq!(detect_input_column(&cols), Some("features"));
    }

    #[test]
    fn test_detect_input_column_none() {
        let cols = ["foo", "bar"];
        assert_eq!(detect_input_column(&cols), None);
    }

    #[test]
    fn test_detect_target_column_target() {
        let cols = ["input", "target"];
        assert_eq!(detect_target_column(&cols), Some("target"));
    }

    #[test]
    fn test_detect_target_column_output() {
        let cols = ["input", "output"];
        assert_eq!(detect_target_column(&cols), Some("output"));
    }

    #[test]
    fn test_detect_target_column_labels() {
        let cols = ["features", "labels"];
        assert_eq!(detect_target_column(&cols), Some("labels"));
    }

    #[test]
    fn test_detect_target_column_y() {
        let cols = ["x", "y"];
        assert_eq!(detect_target_column(&cols), Some("y"));
    }

    #[test]
    fn test_detect_target_column_none() {
        let cols = ["foo", "bar"];
        assert_eq!(detect_target_column(&cols), None);
    }

    #[test]
    fn test_detect_columns_success() {
        let cols = ["input", "target"];
        let pair = detect_columns(&cols).expect("input and target should both be detected");
        assert_eq!(pair.input_name, "input");
        assert_eq!(pair.target_name, "target");
    }

    #[test]
    fn test_detect_columns_missing_input() {
        let cols = ["foo", "target"];
        assert!(detect_columns(&cols).is_none());
    }

    #[test]
    fn test_detect_columns_missing_target() {
        let cols = ["input", "bar"];
        assert!(detect_columns(&cols).is_none());
    }

    #[test]
    fn test_handle_missing_columns_returns_demo_batches() {
        let cols = ["foo", "bar"];
        let batches = handle_missing_columns(&cols, 32);
        assert!(!batches.is_empty());
    }

    #[test]
    fn test_record_batch_to_training_batch_success() {
        let record_batch = make_test_record_batch();
        let schema = make_test_schema();
        let batch = record_batch_to_training_batch(&record_batch, &schema, "input", "target")
            .expect("conversion of a valid record batch should succeed");
        assert_eq!(batch.inputs.data().len(), 4);
        assert_eq!(batch.targets.data().len(), 4);
    }

    #[test]
    fn test_record_batch_to_training_batch_invalid_input_column() {
        let record_batch = make_test_record_batch();
        let schema = make_test_schema();
        let result =
            record_batch_to_training_batch(&record_batch, &schema, "nonexistent", "target");
        assert!(matches!(result, Err(Error::ConfigError(_))));
    }

    #[test]
    fn test_record_batch_to_training_batch_invalid_target_column() {
        let record_batch = make_test_record_batch();
        let schema = make_test_schema();
        let result = record_batch_to_training_batch(&record_batch, &schema, "input", "nonexistent");
        assert!(matches!(result, Err(Error::ConfigError(_))));
    }

    #[test]
    fn test_record_batch_with_float64() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Float64, false),
            Field::new("y", DataType::Float64, false),
        ]));
        let input = Float64Array::from(vec![1.0, 2.0, 3.0]);
        let target = Float64Array::from(vec![0.0, 1.0, 2.0]);
        let record_batch =
            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(input), Arc::new(target)])
                .expect("conversion should succeed");

        let batch = record_batch_to_training_batch(&record_batch, &schema, "x", "y")
            .expect("Float64 columns should convert");
        assert_eq!(batch.inputs.data().len(), 3);
        assert_eq!(batch.targets.data().len(), 3);
    }

    #[test]
    fn test_record_batch_with_int32() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("features", DataType::Int32, false),
            Field::new("labels", DataType::Int32, false),
        ]));
        let input = Int32Array::from(vec![1, 2, 3]);
        let target = Int32Array::from(vec![0, 1, 0]);
        let record_batch =
            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(input), Arc::new(target)])
                .expect("conversion should succeed");

        let batch = record_batch_to_training_batch(&record_batch, &schema, "features", "labels")
            .expect("Int32 columns should convert");
        assert_eq!(batch.inputs.data().len(), 3);
        assert_eq!(batch.targets.data().len(), 3);
    }

    #[test]
    fn test_load_parquet_batches_missing_file_is_config_error() {
        let result = load_parquet_batches(Path::new("/nonexistent/entrenar/missing.parquet"), 32);
        assert!(matches!(result, Err(Error::ConfigError(_))));
    }

    #[test]
    fn test_column_pair_fields() {
        let pair = ColumnPair { input_name: "input", target_name: "target" };
        assert_eq!(pair.input_name, "input");
        assert_eq!(pair.target_name, "target");
    }
}