#![allow(clippy::disallowed_methods)]
use aprender::format::v2::{AprV2Metadata, AprV2Writer, TensorDType};
fn main() {
let mut metadata = AprV2Metadata::new("whisper");
metadata.name = Some("whisper-tiny-test".to_string());
metadata.description = Some("Test Whisper model for APR v2 serving".to_string());
let mut writer = AprV2Writer::new(metadata);
fn f32_bytes(vals: &[f32]) -> Vec<u8> {
vals.iter().flat_map(|f| f.to_le_bytes()).collect()
}
for i in 0..2 {
let prefix = format!("encoder.layers.{}", i);
writer.add_tensor(
&format!("{}.self_attn.q_proj.weight", prefix),
TensorDType::F32,
vec![384, 384],
f32_bytes(&vec![0.01; 384 * 384]),
);
writer.add_tensor(
&format!("{}.self_attn.k_proj.weight", prefix),
TensorDType::F32,
vec![384, 384],
f32_bytes(&vec![0.01; 384 * 384]),
);
writer.add_tensor(
&format!("{}.self_attn.v_proj.weight", prefix),
TensorDType::F32,
vec![384, 384],
f32_bytes(&vec![0.01; 384 * 384]),
);
writer.add_tensor(
&format!("{}.self_attn.out_proj.weight", prefix),
TensorDType::F32,
vec![384, 384],
f32_bytes(&vec![0.01; 384 * 384]),
);
writer.add_tensor(
&format!("{}.mlp.fc1.weight", prefix),
TensorDType::F32,
vec![1536, 384],
f32_bytes(&vec![0.01; 1536 * 384]),
);
writer.add_tensor(
&format!("{}.mlp.fc2.weight", prefix),
TensorDType::F32,
vec![384, 1536],
f32_bytes(&vec![0.01; 384 * 1536]),
);
}
for i in 0..2 {
let prefix = format!("decoder.layers.{}", i);
writer.add_tensor(
&format!("{}.self_attn.q_proj.weight", prefix),
TensorDType::F32,
vec![384, 384],
f32_bytes(&vec![0.02; 384 * 384]),
);
writer.add_tensor(
&format!("{}.self_attn.k_proj.weight", prefix),
TensorDType::F32,
vec![384, 384],
f32_bytes(&vec![0.02; 384 * 384]),
);
writer.add_tensor(
&format!("{}.self_attn.v_proj.weight", prefix),
TensorDType::F32,
vec![384, 384],
f32_bytes(&vec![0.02; 384 * 384]),
);
writer.add_tensor(
&format!("{}.self_attn.out_proj.weight", prefix),
TensorDType::F32,
vec![384, 384],
f32_bytes(&vec![0.02; 384 * 384]),
);
writer.add_tensor(
&format!("{}.cross_attn.q_proj.weight", prefix),
TensorDType::F32,
vec![384, 384],
f32_bytes(&vec![0.03; 384 * 384]),
);
writer.add_tensor(
&format!("{}.cross_attn.k_proj.weight", prefix),
TensorDType::F32,
vec![384, 384],
f32_bytes(&vec![0.03; 384 * 384]),
);
writer.add_tensor(
&format!("{}.cross_attn.v_proj.weight", prefix),
TensorDType::F32,
vec![384, 384],
f32_bytes(&vec![0.03; 384 * 384]),
);
writer.add_tensor(
&format!("{}.cross_attn.out_proj.weight", prefix),
TensorDType::F32,
vec![384, 384],
f32_bytes(&vec![0.03; 384 * 384]),
);
writer.add_tensor(
&format!("{}.mlp.fc1.weight", prefix),
TensorDType::F32,
vec![1536, 384],
f32_bytes(&vec![0.02; 1536 * 384]),
);
writer.add_tensor(
&format!("{}.mlp.fc2.weight", prefix),
TensorDType::F32,
vec![384, 1536],
f32_bytes(&vec![0.02; 384 * 1536]),
);
}
let apr_bytes = writer.write().expect("Failed to write APR");
std::fs::write("/tmp/test-whisper-v2.apr", &apr_bytes).expect("Failed to save");
println!(
"Created /tmp/test-whisper-v2.apr ({} bytes)",
apr_bytes.len()
);
}