use aprender::linear_model::LinearRegression;
use aprender::primitives::{Matrix, Vector};
use aprender::traits::Estimator;
use std::fs;
use std::path::Path;
/// Verifies that `save_safetensors` produces a file at the requested path.
#[test]
fn test_linear_regression_save_safetensors_creates_file() {
    let x = Matrix::from_vec(4, 2, vec![1.0, 2.0, 2.0, 1.0, 3.0, 4.0, 4.0, 3.0])
        .expect("Matrix::from_vec should accept the test data");
    let y = Vector::from_vec(vec![5.0, 4.0, 11.0, 10.0]);
    let mut model = LinearRegression::new();
    model
        .fit(&x, &y)
        .expect("fitting the test data should succeed");

    let path = "test_model.safetensors";
    let saved = model.save_safetensors(path);

    // Probe for the file and clean up *before* asserting, so a failing
    // assertion cannot leave the artifact behind in the working directory.
    let created = Path::new(path).exists();
    fs::remove_file(path).ok();

    saved.expect("save_safetensors should succeed");
    assert!(created, "SafeTensors file should be created");
}
/// Verifies that a saved file starts with an 8-byte little-endian integer
/// (read here as the metadata length) whose value is positive and below 10000.
#[test]
fn test_safetensors_header_format() {
    let x = Matrix::from_vec(2, 1, vec![1.0, 2.0])
        .expect("Matrix::from_vec should accept the test data");
    let y = Vector::from_vec(vec![3.0, 4.0]);
    let mut model = LinearRegression::new();
    model
        .fit(&x, &y)
        .expect("fitting the test data should succeed");

    let path = "test_header.safetensors";
    let saved = model.save_safetensors(path);
    let read = fs::read(path);
    // The file content is now in memory; remove the file before any assertion
    // can fail so the working directory is never left dirty.
    fs::remove_file(path).ok();

    saved.expect("save_safetensors should succeed");
    let bytes = read.expect("saved file should be readable");

    assert!(bytes.len() >= 8, "File must be at least 8 bytes");
    let header_bytes: [u8; 8] = bytes[0..8]
        .try_into()
        .expect("an 8-byte slice converts to [u8; 8]");
    let metadata_len = u64::from_le_bytes(header_bytes);
    assert!(metadata_len > 0, "Metadata length must be > 0");
    assert!(metadata_len < 10000, "Metadata length should be reasonable");
}
/// Verifies the JSON metadata block of a saved model: it must describe a
/// `coefficients` tensor (dtype `F32`, with `shape` and `data_offsets`) and an
/// `intercept` tensor (dtype `F32`, shape `[1]`).
#[test]
// NOTE(review): this allowance was already present; presumably it is needed
// for the `serde_json::json!` expansion — confirm against the clippy config.
#[allow(clippy::disallowed_methods)]
fn test_safetensors_json_metadata_structure() {
    let x = Matrix::from_vec(3, 2, vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
        .expect("Matrix::from_vec should accept the test data");
    let y = Vector::from_vec(vec![1.0, 2.0, 3.0]);
    let mut model = LinearRegression::new();
    model
        .fit(&x, &y)
        .expect("fitting the test data should succeed");

    let path = "test_metadata.safetensors";
    let saved = model.save_safetensors(path);
    let read = fs::read(path);
    // Clean up before any assertion can fail.
    fs::remove_file(path).ok();

    saved.expect("save_safetensors should succeed");
    let bytes = read.expect("saved file should be readable");

    // Checked slicing gives a descriptive failure instead of an index panic
    // when the file is shorter than the header claims.
    let header_bytes: [u8; 8] = bytes
        .get(..8)
        .expect("file should contain the 8-byte length header")
        .try_into()
        .expect("an 8-byte slice converts to [u8; 8]");
    let metadata_len = usize::try_from(u64::from_le_bytes(header_bytes))
        .expect("metadata length should fit in usize");
    let metadata_json = bytes
        .get(8..)
        .and_then(|rest| rest.get(..metadata_len))
        .expect("file should contain the full metadata block");
    let metadata_str =
        std::str::from_utf8(metadata_json).expect("metadata should be valid UTF-8");
    let metadata: serde_json::Value =
        serde_json::from_str(metadata_str).expect("metadata should be valid JSON");

    assert!(
        metadata.get("coefficients").is_some(),
        "Must have 'coefficients' tensor"
    );
    let coeff_meta = &metadata["coefficients"];
    assert_eq!(coeff_meta["dtype"], "F32", "Coefficients must be F32");
    assert!(coeff_meta.get("shape").is_some(), "Must have shape");
    assert!(
        coeff_meta.get("data_offsets").is_some(),
        "Must have data_offsets"
    );

    assert!(
        metadata.get("intercept").is_some(),
        "Must have 'intercept' tensor"
    );
    let intercept_meta = &metadata["intercept"];
    assert_eq!(intercept_meta["dtype"], "F32", "Intercept must be F32");
    assert_eq!(
        intercept_meta["shape"],
        serde_json::json!([1]),
        "Intercept shape must be [1]"
    );
}
/// Verifies that the `coefficients` tensor bytes located via `data_offsets`
/// decode as little-endian `f32` values: the byte length must be a multiple of
/// four, the value count must match the fitted model, and every value must be
/// finite.
#[test]
fn test_safetensors_coefficients_serialization() {
    let x = Matrix::from_vec(3, 2, vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
        .expect("Matrix::from_vec should accept the test data");
    let y = Vector::from_vec(vec![2.0, 3.0, 5.0]);
    let mut model = LinearRegression::new();
    model
        .fit(&x, &y)
        .expect("fitting the test data should succeed");

    let path = "test_coeffs.safetensors";
    let saved = model.save_safetensors(path);
    let read = fs::read(path);
    // Clean up before any assertion can fail.
    fs::remove_file(path).ok();

    saved.expect("save_safetensors should succeed");
    let bytes = read.expect("saved file should be readable");

    let header_bytes: [u8; 8] = bytes
        .get(..8)
        .expect("file should contain the 8-byte length header")
        .try_into()
        .expect("an 8-byte slice converts to [u8; 8]");
    let metadata_len = usize::try_from(u64::from_le_bytes(header_bytes))
        .expect("metadata length should fit in usize");
    let after_header = bytes
        .get(8..)
        .expect("file should contain the 8-byte length header");
    let metadata_json = after_header
        .get(..metadata_len)
        .expect("file should contain the full metadata block");
    // Offsets in the metadata are relative to the first byte after the JSON.
    let tensor_data = &after_header[metadata_len..];

    let metadata: serde_json::Value = serde_json::from_str(
        std::str::from_utf8(metadata_json).expect("metadata should be valid UTF-8"),
    )
    .expect("metadata should be valid JSON");

    let offsets = metadata["coefficients"]["data_offsets"]
        .as_array()
        .expect("coefficients.data_offsets should be a JSON array");
    assert_eq!(offsets.len(), 2, "data_offsets must be a [start, end] pair");
    let offset_at = |idx: usize| -> usize {
        let raw = offsets[idx]
            .as_u64()
            .expect("data offset should be an unsigned integer");
        usize::try_from(raw).expect("data offset should fit in usize")
    };
    let (start, end) = (offset_at(0), offset_at(1));
    let coeff_bytes = tensor_data
        .get(start..end)
        .expect("coefficient offsets should lie within the tensor data section");

    assert_eq!(
        coeff_bytes.len() % 4,
        0,
        "Coefficient data must be multiple of 4 bytes"
    );
    assert_eq!(
        coeff_bytes.len() / 4,
        model.coefficients().len(),
        "Serialized coefficient count must match the fitted model"
    );
    // Decode every value; previously the decoded value was discarded, so the
    // loop verified nothing.
    for (i, chunk) in coeff_bytes.chunks_exact(4).enumerate() {
        let f32_bytes: [u8; 4] = chunk
            .try_into()
            .expect("chunks_exact(4) yields 4-byte chunks");
        let value = f32::from_le_bytes(f32_bytes);
        assert!(
            value.is_finite(),
            "Coefficient {i} must be finite, got {value}"
        );
    }
}
/// Verifies that saving a fitted model and loading it back preserves the
/// coefficients and intercept (within 1e-6) and the predictions on the
/// training data (within 1e-5).
#[test]
fn test_safetensors_roundtrip() {
    let x = Matrix::from_vec(
        5,
        3,
        vec![
            1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0,
        ],
    )
    .expect("Matrix::from_vec should accept the test data");
    let y = Vector::from_vec(vec![2.0, 3.0, 4.0, 5.0, 6.0]);
    let mut model_original = LinearRegression::new();
    model_original
        .fit(&x, &y)
        .expect("fitting the test data should succeed");
    let original_coeffs = model_original.coefficients();
    let original_intercept = model_original.intercept();

    let path = "test_roundtrip.safetensors";
    let saved = model_original.save_safetensors(path);
    let loaded = LinearRegression::load_safetensors(path);
    // The file is no longer needed once loaded; remove it before any
    // assertion can fail so the working directory is never left dirty.
    fs::remove_file(path).ok();

    saved.expect("save_safetensors should succeed");
    let model_loaded = loaded.expect("load_safetensors should read back the saved model");

    let loaded_coeffs = model_loaded.coefficients();
    assert_eq!(
        loaded_coeffs.len(),
        original_coeffs.len(),
        "Coefficient count must match"
    );
    for i in 0..original_coeffs.len() {
        let diff = (loaded_coeffs[i] - original_coeffs[i]).abs();
        assert!(
            diff < 1e-6,
            "Coefficient {} must match: {} vs {}",
            i,
            original_coeffs[i],
            loaded_coeffs[i]
        );
    }

    let diff = (model_loaded.intercept() - original_intercept).abs();
    assert!(
        diff < 1e-6,
        "Intercept must match: {} vs {}",
        original_intercept,
        model_loaded.intercept()
    );

    let pred_original = model_original.predict(&x);
    let pred_loaded = model_loaded.predict(&x);
    // Guard the indexed comparison below: a length mismatch should be
    // reported as such rather than surfacing as an index panic.
    assert_eq!(
        pred_loaded.len(),
        pred_original.len(),
        "Prediction count must match"
    );
    for i in 0..pred_original.len() {
        let diff = (pred_loaded[i] - pred_original[i]).abs();
        assert!(
            diff < 1e-5,
            "Prediction {i} must match: {} vs {}",
            pred_original[i],
            pred_loaded[i]
        );
    }
}
/// Verifies that loading a path that does not exist yields an error whose
/// message indicates the file is missing.
#[test]
fn test_safetensors_file_does_not_exist_error() {
    let result = LinearRegression::load_safetensors("nonexistent.safetensors");
    assert!(
        result.is_err(),
        "Loading nonexistent file should return error"
    );
    let error_msg = result.expect_err("Expected error in test");

    // Unix reports "No such file or directory"; Windows reports "The system
    // cannot find the file specified". Accept both so the test is portable.
    // NOTE(review): assumes the loader surfaces the OS error text — confirm.
    let mentions_missing_file = error_msg.contains("No such file")
        || error_msg.contains("not found")
        || error_msg.contains("cannot find");
    assert!(
        mentions_missing_file,
        "Error should mention file not found, got: {error_msg}"
    );
}
/// Verifies that loading a file too short to hold a header is rejected.
#[test]
fn test_safetensors_corrupted_header_error() {
    let path = "test_corrupted.safetensors";
    // Three bytes: fewer than the 8-byte length prefix a SafeTensors file
    // starts with.
    fs::write(path, [1, 2, 3]).expect("writing the truncated fixture file should succeed");

    let result = LinearRegression::load_safetensors(path);
    // Remove the fixture before asserting so a failure cannot leave it behind.
    fs::remove_file(path).ok();

    assert!(
        result.is_err(),
        "Loading corrupted file should return error"
    );
}