entrenar/config/train/batches/
json.rs1use super::super::demo::create_demo_batches;
4use crate::error::{Error, Result};
5use crate::train::Batch;
6use crate::Tensor;
7use std::path::Path;
8
9pub fn load_json_batches(path: &Path, batch_size: usize) -> Result<Vec<Batch>> {
11 println!(" Loading JSON: {}", path.display());
12
13 let content = std::fs::read_to_string(path).map_err(|e| {
15 Error::ConfigError(format!("Failed to read JSON {}: {}", path.display(), e))
16 })?;
17
18 #[derive(serde::Deserialize)]
19 struct Example {
20 input: Vec<f32>,
21 target: Vec<f32>,
22 }
23
24 #[derive(serde::Deserialize)]
25 struct DataFile {
26 examples: Vec<Example>,
27 }
28
29 if let Ok(data) = serde_json::from_str::<DataFile>(&content) {
31 println!(" Loaded {} examples from JSON", data.examples.len());
32 let batches: Vec<Batch> = data
33 .examples
34 .chunks(batch_size.max(1))
35 .map(|chunk| {
36 let input_data: Vec<f32> = chunk.iter().flat_map(|ex| ex.input.clone()).collect();
37 let target_data: Vec<f32> = chunk.iter().flat_map(|ex| ex.target.clone()).collect();
38 Batch::new(
39 Tensor::from_vec(input_data, false),
40 Tensor::from_vec(target_data, false),
41 )
42 })
43 .collect();
44 return Ok(batches);
45 }
46
47 if let Ok(examples) = serde_json::from_str::<Vec<Example>>(&content) {
49 println!(" Loaded {} examples from JSON array", examples.len());
50 let batches: Vec<Batch> = examples
51 .chunks(batch_size.max(1))
52 .map(|chunk| {
53 let input_data: Vec<f32> = chunk.iter().flat_map(|ex| ex.input.clone()).collect();
54 let target_data: Vec<f32> = chunk.iter().flat_map(|ex| ex.target.clone()).collect();
55 Batch::new(
56 Tensor::from_vec(input_data, false),
57 Tensor::from_vec(target_data, false),
58 )
59 })
60 .collect();
61 return Ok(batches);
62 }
63
64 eprintln!("Warning: Could not parse JSON data format, using demo data");
65 Ok(create_demo_batches(batch_size))
66}