use super::*;
use crate::train::Batch;
use crate::Tensor;
#[cfg(not(target_arch = "wasm32"))]
use std::io::Write;
#[cfg(not(target_arch = "wasm32"))]
use tempfile::NamedTempFile;
/// Builds a [`Batch`] fixture from raw input and target values.
///
/// Both tensors are created with `false` as the second `Tensor::from_vec`
/// argument (presumably the gradient-tracking flag — confirm against `Tensor`).
fn make_batch(input: Vec<f32>, target: Vec<f32>) -> Batch {
    let input_tensor = Tensor::from_vec(input, false);
    let target_tensor = Tensor::from_vec(target, false);
    Batch::new(input_tensor, target_tensor)
}
/// Rebatching an empty collection must produce an empty result.
#[test]
fn test_rebatch_empty() {
    let rebatched = rebatch(Vec::<Batch>::new(), 4);
    assert!(rebatched.is_empty());
}
/// A lone batch rebatched with a batch size of 2 comes back as exactly one batch.
#[test]
fn test_rebatch_single_batch() {
    let only_batch = make_batch(vec![1.0, 2.0, 3.0, 4.0], vec![0.1, 0.2]);
    let rebatched = rebatch(vec![only_batch], 2);
    assert_eq!(rebatched.len(), 1);
}
/// Three input batches regrouped with a batch size of 2 yield two output batches.
#[test]
fn test_rebatch_multiple_batches() {
    let first = make_batch(vec![1.0, 2.0], vec![0.1]);
    let second = make_batch(vec![3.0, 4.0], vec![0.2]);
    let third = make_batch(vec![5.0, 6.0], vec![0.3]);

    let regrouped = rebatch(vec![first, second, third], 2);
    assert_eq!(regrouped.len(), 2);
}
/// A single three-value batch rebatched with a batch size of 1 stays one batch.
#[test]
fn test_rebatch_batch_size_one() {
    let source = vec![make_batch(vec![1.0, 2.0, 3.0], vec![0.1, 0.2, 0.3])];
    let batch_count = rebatch(source, 1).len();
    assert_eq!(batch_count, 1);
}
/// Two single-value batches with a batch size of 100 collapse into one output batch.
#[test]
fn test_rebatch_large_batch_size() {
    // (input, target) pairs, one pair per source batch.
    let source: Vec<Batch> = [(1.0, 0.1), (2.0, 0.2)]
        .iter()
        .map(|&(x, y)| make_batch(vec![x], vec![y]))
        .collect();

    let merged = rebatch(source, 100);
    assert_eq!(merged.len(), 1);
}
/// Loading training batches from a path that does not exist is expected to
/// succeed and still return at least one batch.
///
/// NOTE(review): presumably the loader substitutes fallback data when the
/// file is missing — confirm against `load_training_batches`.
#[test]
fn test_load_training_batches_nonexistent_file() {
    use crate::config::schema::TrainSpec;
    use std::path::PathBuf;

    let model = crate::config::ModelRef {
        path: PathBuf::from("model.bin"),
        layers: vec![],
        ..Default::default()
    };
    let data = crate::config::DataConfig {
        train: PathBuf::from("/nonexistent/path/data.parquet"),
        val: None,
        batch_size: 8,
        auto_infer_types: false,
        seq_len: None,
        ..Default::default()
    };
    let optimizer = crate::config::OptimSpec {
        name: "adam".to_string(),
        lr: 0.001,
        params: Default::default(),
    };
    let spec = TrainSpec {
        model,
        data,
        optimizer,
        lora: None,
        quantize: None,
        merge: None,
        training: Default::default(),
        publish: None,
    };

    let outcome = load_training_batches(&spec);
    assert!(outcome.is_ok());
    let loaded = outcome.expect("operation should succeed");
    assert!(!loaded.is_empty());
}
/// Loading training batches from a file with an unrecognised `.xyz` extension
/// is expected to succeed and still return at least one batch.
///
/// NOTE(review): presumably the loader substitutes fallback data for unknown
/// formats — confirm against `load_training_batches`.
#[test]
#[cfg(not(target_arch = "wasm32"))]
fn test_load_training_batches_unsupported_format() {
    use crate::config::schema::TrainSpec;
    use std::path::PathBuf;

    let mut scratch =
        NamedTempFile::with_suffix(".xyz").expect("temp file creation should succeed");
    writeln!(scratch, "test data").expect("file write should succeed");

    let model = crate::config::ModelRef {
        path: PathBuf::from("model.bin"),
        layers: vec![],
        ..Default::default()
    };
    let data = crate::config::DataConfig {
        train: scratch.path().to_path_buf(),
        val: None,
        batch_size: 8,
        auto_infer_types: false,
        seq_len: None,
        ..Default::default()
    };
    let optimizer = crate::config::OptimSpec {
        name: "adam".to_string(),
        lr: 0.001,
        params: Default::default(),
    };
    let spec = TrainSpec {
        model,
        data,
        optimizer,
        lora: None,
        quantize: None,
        merge: None,
        training: Default::default(),
        publish: None,
    };

    let outcome = load_training_batches(&spec);
    assert!(outcome.is_ok());
    let loaded = outcome.expect("operation should succeed");
    assert!(!loaded.is_empty());
}
/// A JSON object with a top-level `"examples"` array of input/target records
/// loads successfully and yields at least one batch.
#[test]
#[cfg(not(target_arch = "wasm32"))]
fn test_load_json_batches_structured_format() {
    let mut json_file =
        NamedTempFile::with_suffix(".json").expect("temp file creation should succeed");
    // Braces are doubled because `writeln!` treats the literal as a format string.
    writeln!(
        json_file,
        r#"{{
"examples": [
{{"input": [1.0, 2.0], "target": [0.1]}},
{{"input": [3.0, 4.0], "target": [0.2]}}
]
}}"#
    )
    .expect("operation should succeed");

    let json_path = json_file.path();
    let outcome = load_json_batches(json_path, 2);
    assert!(outcome.is_ok());
    let loaded = outcome.expect("operation should succeed");
    assert!(!loaded.is_empty());
}
/// A bare top-level JSON array of input/target records loads successfully and
/// yields at least one batch.
#[test]
#[cfg(not(target_arch = "wasm32"))]
fn test_load_json_batches_array_format() {
    let mut json_file =
        NamedTempFile::with_suffix(".json").expect("temp file creation should succeed");
    // Braces are doubled because `writeln!` treats the literal as a format string.
    writeln!(
        json_file,
        r#"[
{{"input": [1.0, 2.0], "target": [0.1]}},
{{"input": [3.0, 4.0], "target": [0.2]}},
{{"input": [5.0, 6.0], "target": [0.3]}}
]"#
    )
    .expect("operation should succeed");

    let json_path = json_file.path();
    let outcome = load_json_batches(json_path, 2);
    assert!(outcome.is_ok());
    let loaded = outcome.expect("operation should succeed");
    assert!(!loaded.is_empty());
}
/// Valid JSON that matches neither expected layout is still expected to load
/// successfully and yield at least one batch.
///
/// NOTE(review): presumably the loader substitutes fallback data for an
/// unrecognised schema — confirm against `load_json_batches`.
#[test]
#[cfg(not(target_arch = "wasm32"))]
fn test_load_json_batches_invalid_format() {
    let mut json_file =
        NamedTempFile::with_suffix(".json").expect("temp file creation should succeed");
    // Braces are doubled because `writeln!` treats the literal as a format string.
    writeln!(json_file, r#"{{"some": "other", "format": true}}"#)
        .expect("file write should succeed");

    let json_path = json_file.path();
    let outcome = load_json_batches(json_path, 2);
    assert!(outcome.is_ok());
    let loaded = outcome.expect("operation should succeed");
    assert!(!loaded.is_empty());
}
/// Calling the JSON loader directly on a missing file must report an error.
#[test]
#[cfg(not(target_arch = "wasm32"))]
fn test_load_json_batches_file_not_found() {
    use std::path::Path;

    let missing = Path::new("/nonexistent/file.json");
    assert!(load_json_batches(missing, 2).is_err());
}
/// Two JSON examples loaded with a batch size of 1 produce exactly two batches.
#[test]
#[cfg(not(target_arch = "wasm32"))]
fn test_load_json_batches_batch_size_one() {
    let mut json_file =
        NamedTempFile::with_suffix(".json").expect("temp file creation should succeed");
    // Braces are doubled because `writeln!` treats the literal as a format string.
    writeln!(
        json_file,
        r#"[
{{"input": [1.0], "target": [0.1]}},
{{"input": [2.0], "target": [0.2]}}
]"#
    )
    .expect("operation should succeed");

    let json_path = json_file.path();
    let outcome = load_json_batches(json_path, 1);
    assert!(outcome.is_ok());
    let loaded = outcome.expect("operation should succeed");
    assert_eq!(loaded.len(), 2);
}
/// A batch size of zero must not make the JSON loader fail: loading a
/// single-example array is still expected to return `Ok`.
#[test]
#[cfg(not(target_arch = "wasm32"))]
fn test_load_json_batches_batch_size_zero() {
    let mut json_file =
        NamedTempFile::with_suffix(".json").expect("temp file creation should succeed");
    // Braces are doubled because `writeln!` treats the literal as a format string.
    writeln!(json_file, r#"[{{"input": [1.0], "target": [0.1]}}]"#)
        .expect("file write should succeed");

    let json_path = json_file.path();
    assert!(load_json_batches(json_path, 0).is_ok());
}
/// `load_training_batches` succeeds when the spec's training data points at a
/// `.json` file containing a one-example array.
#[test]
#[cfg(not(target_arch = "wasm32"))]
fn test_load_training_batches_json_file() {
    use crate::config::schema::TrainSpec;
    use std::path::PathBuf;

    let mut json_file =
        NamedTempFile::with_suffix(".json").expect("temp file creation should succeed");
    // Braces are doubled because `writeln!` treats the literal as a format string.
    writeln!(json_file, r#"[{{"input": [1.0, 2.0], "target": [0.5]}}]"#)
        .expect("file write should succeed");

    let model = crate::config::ModelRef {
        path: PathBuf::from("model.bin"),
        layers: vec![],
        ..Default::default()
    };
    let data = crate::config::DataConfig {
        train: json_file.path().to_path_buf(),
        val: None,
        batch_size: 4,
        auto_infer_types: false,
        seq_len: None,
        ..Default::default()
    };
    let optimizer = crate::config::OptimSpec {
        name: "adam".to_string(),
        lr: 0.001,
        params: Default::default(),
    };
    let spec = TrainSpec {
        model,
        data,
        optimizer,
        lora: None,
        quantize: None,
        merge: None,
        training: Default::default(),
        publish: None,
    };

    assert!(load_training_batches(&spec).is_ok());
}
/// Rebatching with a batch size of 1 keeps every input value, in the original order.
#[test]
fn test_rebatch_preserves_data() {
    let first = make_batch(vec![1.0, 2.0], vec![10.0]);
    let second = make_batch(vec![3.0, 4.0], vec![20.0]);

    let rebatched = rebatch(vec![first, second], 1);
    assert_eq!(rebatched.len(), 2);

    // Concatenate the input values of every output batch, in output order.
    let mut flattened: Vec<f32> = Vec::new();
    for batch in rebatched.iter() {
        flattened.extend(batch.inputs.data().iter().copied());
    }
    assert_eq!(flattened, vec![1.0, 2.0, 3.0, 4.0]);
}
/// Four single-value batches regrouped with a batch size of 2 divide evenly
/// into two output batches.
#[test]
fn test_rebatch_exact_batch_size() {
    // (input, target) pairs, one pair per source batch.
    let mut source: Vec<Batch> = Vec::with_capacity(4);
    for &(x, y) in &[(1.0, 0.1), (2.0, 0.2), (3.0, 0.3), (4.0, 0.4)] {
        source.push(make_batch(vec![x], vec![y]));
    }

    let regrouped = rebatch(source, 2);
    assert_eq!(regrouped.len(), 2);
}
/// Five single-value batches regrouped with a batch size of 2 yield three
/// output batches, so the leftover batch is kept rather than dropped.
#[test]
fn test_rebatch_remainder() {
    let inputs = [1.0, 2.0, 3.0, 4.0, 5.0];
    let targets = [0.1, 0.2, 0.3, 0.4, 0.5];
    let source: Vec<Batch> = inputs
        .iter()
        .zip(targets.iter())
        .map(|(&x, &y)| make_batch(vec![x], vec![y]))
        .collect();

    let regrouped = rebatch(source, 2);
    assert_eq!(regrouped.len(), 3);
}