use std::collections::HashMap;
use std::fs;
use std::path::Path;
/// Build-script entry point: ensures `tests/fixtures` exists and generates the
/// safetensors fixtures inside it.
///
/// # Panics
/// Panics (failing the build) if the fixtures directory cannot be created.
fn main() {
    // Restrict re-runs to changes of this script. NOTE: because only `build.rs`
    // is watched, deleting a generated fixture does not by itself trigger a
    // re-run; touch `build.rs` (or `cargo clean`) to regenerate it.
    println!("cargo:rerun-if-changed=build.rs");

    // Cargo runs build scripts with the package root as the working directory,
    // so this relative path resolves inside the package.
    let fixtures_dir = Path::new("tests/fixtures");
    // `create_dir_all` is a no-op success when the directory already exists, so
    // no separate (and racy) `exists()` check is needed.
    fs::create_dir_all(fixtures_dir)
        .unwrap_or_else(|err| panic!("failed to create {}: {}", fixtures_dir.display(), err));
    generate_safetensors_fixtures(fixtures_dir);
}
fn generate_safetensors_fixtures(dir: &Path) {
let model_v1_path = dir.join("model_v1.safetensors");
if !model_v1_path.exists() {
let tensors_v1 = create_model_v1_tensors();
let metadata_v1: HashMap<String, String> = HashMap::from([
("format".to_string(), "pt".to_string()),
("version".to_string(), "1.0".to_string()),
]);
safetensors::serialize_to_file(&tensors_v1, &Some(metadata_v1), &model_v1_path).unwrap();
}
let model_v2_path = dir.join("model_v2.safetensors");
if !model_v2_path.exists() {
let tensors_v2 = create_model_v2_tensors();
let metadata_v2: HashMap<String, String> = HashMap::from([
("format".to_string(), "pt".to_string()),
("version".to_string(), "2.0".to_string()),
]);
safetensors::serialize_to_file(&tensors_v2, &Some(metadata_v2), &model_v2_path).unwrap();
}
}
/// Tensors for the `model_v1` fixture: two F32 matrices filled with a
/// deterministic linear ramp (see [`ramp_tensor`]).
///
/// - `fc1.weight`: shape `[512, 256]`, values `x * 0.1`
/// - `fc2.weight`: shape `[256, 128]`, values `x * 0.15`
fn create_model_v1_tensors() -> HashMap<String, safetensors::tensor::TensorView<'static>> {
    HashMap::from([
        ("fc1.weight".to_string(), ramp_tensor(512, 256, 0.1, 0.0)),
        ("fc2.weight".to_string(), ramp_tensor(256, 128, 0.15, 0.0)),
    ])
}

/// Tensors for the `model_v2` fixture: same names and shapes as v1, but with a
/// different scale and offset so the two fixtures differ element-wise.
///
/// - `fc1.weight`: shape `[512, 256]`, values `x * 0.15 + 0.05`
/// - `fc2.weight`: shape `[256, 128]`, values `x * 0.2 - 0.03`
fn create_model_v2_tensors() -> HashMap<String, safetensors::tensor::TensorView<'static>> {
    HashMap::from([
        ("fc1.weight".to_string(), ramp_tensor(512, 256, 0.15, 0.05)),
        ("fc2.weight".to_string(), ramp_tensor(256, 128, 0.2, -0.03)),
    ])
}

/// Builds a `[rows, cols]` F32 tensor whose flattened element `i` equals
/// `x * scale + offset`, where `x = (i / (rows * cols)) * 2 - 1` sweeps `[-1, 1)`.
///
/// Values are stored as little-endian `f32` bytes.
fn ramp_tensor(
    rows: usize,
    cols: usize,
    scale: f32,
    offset: f32,
) -> safetensors::tensor::TensorView<'static> {
    let len = rows * cols;
    let bytes: Vec<u8> = (0..len)
        .map(|i| {
            let x = (i as f32 / len as f32) * 2.0 - 1.0;
            x * scale + offset
        })
        .flat_map(f32::to_le_bytes)
        .collect();
    // `TensorView` borrows its data, and the fixture builders must return
    // `'static` views, so the buffer is deliberately leaked. The build script is
    // a short-lived process, so the leak is bounded.
    let data: &'static [u8] = Box::leak(bytes.into_boxed_slice());
    safetensors::tensor::TensorView::new(safetensors::Dtype::F32, vec![rows, cols], data)
        .expect("buffer holds exactly rows * cols little-endian f32 values")
}