Skip to main content

entrenar/config/train/batches/
parquet.rs

1//! Parquet batch loading using alimentar
2
3use super::super::arrow::arrow_array_to_f32;
4use super::super::demo::create_demo_batches;
5use super::rebatch::rebatch;
6use crate::error::{Error, Result};
7use crate::train::Batch;
8use crate::Tensor;
9use alimentar::{ArrowDataset, Dataset};
10use arrow::datatypes::Schema;
11use arrow::record_batch::RecordBatch;
12use std::path::Path;
13
/// Names of the input and target columns detected in a parquet schema.
///
/// Both names borrow from the schema's field names, so the pair cannot outlive the schema it
/// was detected from. The type is a pair of string slices, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ColumnPair<'a> {
    /// Column whose values become the batch inputs.
    input_name: &'a str,
    /// Column whose values become the batch targets.
    target_name: &'a str,
}
19
/// Pick the input column out of a list of schema column names.
///
/// Recognised names are `input`, `input_ids`, `x` and `features`. Names are scanned in the
/// order given (schema order), so the earliest recognised column wins. Returns `None` when no
/// column matches.
fn detect_input_column<'a>(column_names: &[&'a str]) -> Option<&'a str> {
    column_names
        .iter()
        .copied()
        .find(|name| matches!(*name, "input" | "input_ids" | "x" | "features"))
}
27
/// Pick the target column out of a list of schema column names.
///
/// Recognised names are `target`, `output`, `labels` and `y`. Names are scanned in the order
/// given (schema order), so the earliest recognised column wins. Returns `None` when no column
/// matches.
fn detect_target_column<'a>(column_names: &[&'a str]) -> Option<&'a str> {
    column_names
        .iter()
        .copied()
        .find(|name| matches!(*name, "target" | "output" | "labels" | "y"))
}
35
36/// Detect input/target column pair from schema
37fn detect_columns<'a>(column_names: &[&'a str]) -> Option<ColumnPair<'a>> {
38    let input_name = detect_input_column(column_names)?;
39    let target_name = detect_target_column(column_names)?;
40    Some(ColumnPair { input_name, target_name })
41}
42
43/// Log column detection warning and return demo batches
44fn handle_missing_columns(column_names: &[&str], batch_size: usize) -> Vec<Batch> {
45    eprintln!("Warning: Could not find input/target columns in parquet (found: {column_names:?})");
46    eprintln!("  Expected columns like: input/target, x/y, features/labels");
47    create_demo_batches(batch_size)
48}
49
50/// Convert a single record batch to a training batch
51fn record_batch_to_training_batch(
52    record_batch: &RecordBatch,
53    schema: &Schema,
54    input_name: &str,
55    target_name: &str,
56) -> Result<Batch> {
57    let input_idx = schema
58        .index_of(input_name)
59        .map_err(|e| Error::ConfigError(format!("Column not found: {e}")))?;
60    let target_idx = schema
61        .index_of(target_name)
62        .map_err(|e| Error::ConfigError(format!("Column not found: {e}")))?;
63
64    let input_array = record_batch.column(input_idx);
65    let target_array = record_batch.column(target_idx);
66
67    let input_data = arrow_array_to_f32(input_array)?;
68    let target_data = arrow_array_to_f32(target_array)?;
69
70    Ok(Batch::new(Tensor::from_vec(input_data, false), Tensor::from_vec(target_data, false)))
71}
72
73/// Process all record batches from dataset
74fn process_record_batches(dataset: &ArrowDataset, columns: &ColumnPair<'_>) -> Result<Vec<Batch>> {
75    let schema = dataset.schema();
76    let mut batches = Vec::new();
77
78    for record_batch in dataset.iter() {
79        let batch = record_batch_to_training_batch(
80            &record_batch,
81            &schema,
82            columns.input_name,
83            columns.target_name,
84        )?;
85        batches.push(batch);
86    }
87
88    Ok(batches)
89}
90
91/// Load batches from parquet file using alimentar
92pub fn load_parquet_batches(path: &Path, batch_size: usize) -> Result<Vec<Batch>> {
93    println!("  Loading parquet: {}", path.display());
94
95    let dataset = ArrowDataset::from_parquet(path).map_err(|e| {
96        Error::ConfigError(format!("Failed to load parquet {}: {}", path.display(), e))
97    })?;
98
99    println!("  Loaded {} rows from parquet", dataset.len());
100
101    let schema = dataset.schema();
102    let column_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
103
104    let columns = match detect_columns(&column_names) {
105        Some(cols) => cols,
106        None => return Ok(handle_missing_columns(&column_names, batch_size)),
107    };
108
109    println!("  Using columns: input='{}', target='{}'", columns.input_name, columns.target_name);
110
111    let mut batches = process_record_batches(&dataset, &columns)?;
112
113    // Re-batch to desired batch size if needed
114    if batches.len() > 1 && batch_size > 0 {
115        batches = rebatch(batches, batch_size);
116    }
117
118    Ok(batches)
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float32Array, Float64Array, Int32Array};
    use arrow::datatypes::{DataType, Field};
    use std::sync::Arc;

    /// Two non-nullable `Float32` columns named `input` and `target`.
    fn make_test_schema() -> Schema {
        Schema::new(vec![
            Field::new("input", DataType::Float32, false),
            Field::new("target", DataType::Float32, false),
        ])
    }

    /// Four-row record batch matching `make_test_schema`.
    fn make_test_record_batch() -> RecordBatch {
        let schema = Arc::new(make_test_schema());
        let input = Float32Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let target = Float32Array::from(vec![0.0, 1.0, 0.0, 1.0]);
        RecordBatch::try_new(schema, vec![Arc::new(input), Arc::new(target)])
            .expect("conversion should succeed")
    }

    // Column-name detection. Fixed-size arrays coerce to `&[&str]`, so no heap-allocated
    // `vec!` is needed (clippy::useless_vec).

    #[test]
    fn test_detect_input_column_input() {
        let cols = ["input", "target"];
        assert_eq!(detect_input_column(&cols), Some("input"));
    }

    #[test]
    fn test_detect_input_column_input_ids() {
        let cols = ["input_ids", "labels"];
        assert_eq!(detect_input_column(&cols), Some("input_ids"));
    }

    #[test]
    fn test_detect_input_column_x() {
        let cols = ["x", "y"];
        assert_eq!(detect_input_column(&cols), Some("x"));
    }

    #[test]
    fn test_detect_input_column_features() {
        let cols = ["features", "labels"];
        assert_eq!(detect_input_column(&cols), Some("features"));
    }

    #[test]
    fn test_detect_input_column_none() {
        let cols = ["foo", "bar"];
        assert_eq!(detect_input_column(&cols), None);
    }

    #[test]
    fn test_detect_input_column_empty_schema() {
        assert_eq!(detect_input_column(&[]), None);
    }

    #[test]
    fn test_detect_input_column_follows_schema_order() {
        // The earliest recognised column in the schema wins, not the first accepted name.
        let cols = ["features", "input", "target"];
        assert_eq!(detect_input_column(&cols), Some("features"));
    }

    #[test]
    fn test_detect_target_column_target() {
        let cols = ["input", "target"];
        assert_eq!(detect_target_column(&cols), Some("target"));
    }

    #[test]
    fn test_detect_target_column_output() {
        let cols = ["input", "output"];
        assert_eq!(detect_target_column(&cols), Some("output"));
    }

    #[test]
    fn test_detect_target_column_labels() {
        let cols = ["features", "labels"];
        assert_eq!(detect_target_column(&cols), Some("labels"));
    }

    #[test]
    fn test_detect_target_column_y() {
        let cols = ["x", "y"];
        assert_eq!(detect_target_column(&cols), Some("y"));
    }

    #[test]
    fn test_detect_target_column_none() {
        let cols = ["foo", "bar"];
        assert_eq!(detect_target_column(&cols), None);
    }

    #[test]
    fn test_detect_target_column_empty_schema() {
        assert_eq!(detect_target_column(&[]), None);
    }

    #[test]
    fn test_detect_columns_success() {
        let cols = ["input", "target"];
        let pair = detect_columns(&cols).expect("input/target pair should be detected");
        assert_eq!(pair.input_name, "input");
        assert_eq!(pair.target_name, "target");
    }

    #[test]
    fn test_detect_columns_mixed_conventions() {
        // Input and target names are detected independently, so conventions may be mixed.
        let cols = ["x", "labels"];
        let pair = detect_columns(&cols).expect("x/labels pair should be detected");
        assert_eq!(pair.input_name, "x");
        assert_eq!(pair.target_name, "labels");
    }

    #[test]
    fn test_detect_columns_missing_input() {
        let cols = ["foo", "target"];
        assert!(detect_columns(&cols).is_none());
    }

    #[test]
    fn test_detect_columns_missing_target() {
        let cols = ["input", "bar"];
        assert!(detect_columns(&cols).is_none());
    }

    #[test]
    fn test_handle_missing_columns_returns_demo_batches() {
        let cols = ["foo", "bar"];
        let batches = handle_missing_columns(&cols, 32);
        assert!(!batches.is_empty());
    }

    // Record batch -> training batch conversion.

    #[test]
    fn test_record_batch_to_training_batch_success() {
        let record_batch = make_test_record_batch();
        let schema = make_test_schema();
        let batch = record_batch_to_training_batch(&record_batch, &schema, "input", "target")
            .expect("conversion should succeed");
        assert_eq!(batch.inputs.data().len(), 4);
        assert_eq!(batch.targets.data().len(), 4);
    }

    #[test]
    fn test_record_batch_to_training_batch_invalid_input_column() {
        let record_batch = make_test_record_batch();
        let schema = make_test_schema();
        let result =
            record_batch_to_training_batch(&record_batch, &schema, "nonexistent", "target");
        assert!(result.is_err());
    }

    #[test]
    fn test_record_batch_to_training_batch_invalid_target_column() {
        let record_batch = make_test_record_batch();
        let schema = make_test_schema();
        let result = record_batch_to_training_batch(&record_batch, &schema, "input", "nonexistent");
        assert!(result.is_err());
    }

    #[test]
    fn test_record_batch_with_float64() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Float64, false),
            Field::new("y", DataType::Float64, false),
        ]));
        let input = Float64Array::from(vec![1.0, 2.0, 3.0]);
        let target = Float64Array::from(vec![0.0, 1.0, 2.0]);
        let record_batch =
            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(input), Arc::new(target)])
                .expect("conversion should succeed");

        let batch = record_batch_to_training_batch(&record_batch, &schema, "x", "y")
            .expect("float64 columns should convert");
        assert_eq!(batch.inputs.data().len(), 3);
        assert_eq!(batch.targets.data().len(), 3);
    }

    #[test]
    fn test_record_batch_with_int32() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("features", DataType::Int32, false),
            Field::new("labels", DataType::Int32, false),
        ]));
        let input = Int32Array::from(vec![1, 2, 3]);
        let target = Int32Array::from(vec![0, 1, 0]);
        let record_batch =
            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(input), Arc::new(target)])
                .expect("conversion should succeed");

        let batch = record_batch_to_training_batch(&record_batch, &schema, "features", "labels")
            .expect("int32 columns should convert");
        assert_eq!(batch.inputs.data().len(), 3);
        assert_eq!(batch.targets.data().len(), 3);
    }

    #[test]
    fn test_column_pair_fields() {
        let pair = ColumnPair { input_name: "input", target_name: "target" };
        assert_eq!(pair.input_name, "input");
        assert_eq!(pair.target_name, "target");
    }
}