Skip to main content

entrenar/config/train/batches/
json.rs

1//! JSON batch loading
2
3use super::super::demo::create_demo_batches;
4use crate::error::{Error, Result};
5use crate::train::Batch;
6use crate::Tensor;
7use std::path::Path;
8
9/// Load batches from JSON file
10pub fn load_json_batches(path: &Path, batch_size: usize) -> Result<Vec<Batch>> {
11    println!("  Loading JSON: {}", path.display());
12
13    // Try to load as JSON array of {input, target} objects
14    let content = std::fs::read_to_string(path).map_err(|e| {
15        Error::ConfigError(format!("Failed to read JSON {}: {}", path.display(), e))
16    })?;
17
18    #[derive(serde::Deserialize)]
19    struct Example {
20        input: Vec<f32>,
21        target: Vec<f32>,
22    }
23
24    #[derive(serde::Deserialize)]
25    struct DataFile {
26        examples: Vec<Example>,
27    }
28
29    // Try structured format first
30    if let Ok(data) = serde_json::from_str::<DataFile>(&content) {
31        println!("  Loaded {} examples from JSON", data.examples.len());
32        let batches: Vec<Batch> = data
33            .examples
34            .chunks(batch_size.max(1))
35            .map(|chunk| {
36                let input_data: Vec<f32> = chunk.iter().flat_map(|ex| ex.input.clone()).collect();
37                let target_data: Vec<f32> = chunk.iter().flat_map(|ex| ex.target.clone()).collect();
38                Batch::new(
39                    Tensor::from_vec(input_data, false),
40                    Tensor::from_vec(target_data, false),
41                )
42            })
43            .collect();
44        return Ok(batches);
45    }
46
47    // Try array of examples
48    if let Ok(examples) = serde_json::from_str::<Vec<Example>>(&content) {
49        println!("  Loaded {} examples from JSON array", examples.len());
50        let batches: Vec<Batch> = examples
51            .chunks(batch_size.max(1))
52            .map(|chunk| {
53                let input_data: Vec<f32> = chunk.iter().flat_map(|ex| ex.input.clone()).collect();
54                let target_data: Vec<f32> = chunk.iter().flat_map(|ex| ex.target.clone()).collect();
55                Batch::new(
56                    Tensor::from_vec(input_data, false),
57                    Tensor::from_vec(target_data, false),
58                )
59            })
60            .collect();
61        return Ok(batches);
62    }
63
64    eprintln!("Warning: Could not parse JSON data format, using demo data");
65    Ok(create_demo_batches(batch_size))
66}