use super::{
generators::{ModelWeights, SyntheticWeightGenerator},
ConstructorInput, Device, ModelConfig, ModelFormat, QuantType,
};
use crate::error::RealizarError;
use crate::Result;
/// Common interface for synthetic model fixtures used in tests.
///
/// Implementations pair a `ModelConfig` with deterministically generated
/// weights and can serialize themselves or convert between formats.
pub trait ModelFixture: Send + Sync {
    /// The model configuration this fixture was built from.
    fn config(&self) -> &ModelConfig;
    /// The on-disk format this fixture represents.
    fn format(&self) -> ModelFormat;
    /// Quantization type of the generated weights.
    fn quant_type(&self) -> QuantType;
    /// Run a synthetic forward pass, returning one logit per vocab entry.
    fn forward(&self, device: Device, tokens: &[u32]) -> Result<Vec<f32>>;
    /// Produce a synthetic embedding vector for a single token.
    fn embed(&self, device: Device, token: u32) -> Result<Vec<f32>>;
    /// Serialize the fixture to its format's byte representation.
    fn to_bytes(&self) -> Result<Vec<u8>>;
    /// Convert this fixture to another model format, if supported.
    fn convert_to(&self, target: ModelFormat) -> Result<Box<dyn ModelFixture>>;
    /// Approximate in-memory size of the weights, in bytes.
    fn memory_bytes(&self) -> usize;
}
/// Synthetic GGUF-format model fixture: a config plus generated weights.
pub struct GgufFixture {
    config: ModelConfig,
    weights: ModelWeights,
}
impl GgufFixture {
    /// Build a fixture by generating deterministic synthetic weights for
    /// `config` with the given quantization and RNG seed.
    pub fn new(config: ModelConfig, quant: QuantType, seed: u64) -> Self {
        // `generator` instead of `gen`: clearer, and `gen` is reserved in
        // Rust edition 2024.
        let generator = SyntheticWeightGenerator::new(seed);
        let weights = generator.generate_model_weights(&config, quant);
        Self { config, weights }
    }

    /// Tiny model with the default (grouped-query attention) tiny config.
    pub fn tiny_gqa() -> Self {
        Self::new(ModelConfig::tiny(), QuantType::Q4_0, 42)
    }

    /// Tiny model forced to multi-head attention (kv heads == query heads).
    pub fn tiny_mha() -> Self {
        let mut config = ModelConfig::tiny();
        config.num_kv_heads = config.num_heads;
        Self::new(config, QuantType::Q4_0, 42)
    }

    /// Small model with Q4_K quantization.
    pub fn small() -> Self {
        Self::new(ModelConfig::small(), QuantType::Q4_K, 42)
    }

    /// Build from a generic constructor input; quantization defaults to F32
    /// when none is specified.
    pub fn from_constructor(input: &ConstructorInput) -> Self {
        let quant = input.quantization.unwrap_or(QuantType::F32);
        Self::new(input.config.clone(), quant, input.weights_seed)
    }
}
impl ModelFixture for GgufFixture {
    fn config(&self) -> &ModelConfig {
        &self.config
    }

    fn format(&self) -> ModelFormat {
        ModelFormat::GGUF
    }

    fn quant_type(&self) -> QuantType {
        self.weights.quant_type
    }

    /// Synthetic forward pass: ignores device and tokens and returns a
    /// deterministic linear ramp of logits spanning [-1.0, 1.0) across the
    /// vocabulary.
    fn forward(&self, _device: Device, _tokens: &[u32]) -> Result<Vec<f32>> {
        let output_size = self.config.vocab_size;
        let mut logits = vec![0.0f32; output_size];
        for (i, logit) in logits.iter_mut().enumerate() {
            *logit = ((i as f32) / (output_size as f32) - 0.5) * 2.0;
        }
        Ok(logits)
    }

    /// Synthetic embedding: a deterministic pseudo-random vector derived
    /// from the token id, scaled by 1/sqrt(hidden_dim).
    fn embed(&self, _device: Device, token: u32) -> Result<Vec<f32>> {
        let mut embedding = vec![0.0f32; self.config.hidden_dim];
        let scale = 1.0 / (self.config.hidden_dim as f32).sqrt();
        for (i, val) in embedding.iter_mut().enumerate() {
            *val = ((token as usize * 17 + i * 13) % 1000) as f32 / 1000.0 * scale;
        }
        Ok(embedding)
    }

    /// Serialize a minimal GGUF v3 header plus llama.* metadata key/values.
    ///
    /// Layout: magic "GGUF", version (u32 LE), tensor count (u64 LE),
    /// metadata KV count (u64 LE), then the KV pairs. Tensor data is
    /// intentionally omitted — this is a header-only fixture.
    fn to_bytes(&self) -> Result<Vec<u8>> {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(b"GGUF");
        bytes.extend_from_slice(&3u32.to_le_bytes());
        // Token embedding + 9 tensors per layer + output norm + lm head
        // (presumed breakdown — matches the generator's layout).
        let tensor_count = 1 + self.config.num_layers * 9 + 2;
        bytes.extend_from_slice(&(tensor_count as u64).to_le_bytes());
        // BUG FIX: the header previously declared 10 metadata KV pairs while
        // exactly 5 are written below, which would desynchronize any parser
        // walking the KV table. The declared count now matches the writes.
        bytes.extend_from_slice(&5u64.to_le_bytes());
        write_gguf_kv(
            &mut bytes,
            "llama.attention.head_count",
            self.config.num_heads as u32,
        );
        write_gguf_kv(
            &mut bytes,
            "llama.attention.head_count_kv",
            self.config.num_kv_heads as u32,
        );
        write_gguf_kv(
            &mut bytes,
            "llama.embedding_length",
            self.config.hidden_dim as u32,
        );
        write_gguf_kv(
            &mut bytes,
            "llama.block_count",
            self.config.num_layers as u32,
        );
        write_gguf_kv(
            &mut bytes,
            "llama.vocab_size",
            self.config.vocab_size as u32,
        );
        Ok(bytes)
    }

    /// Convert to another format; PyTorch is unsupported.
    fn convert_to(&self, target: ModelFormat) -> Result<Box<dyn ModelFixture>> {
        match target {
            ModelFormat::APR => Ok(Box::new(AprFixture::from_gguf(self)?)),
            ModelFormat::Safetensors => Ok(Box::new(SafetensorsFixture::from_gguf(self)?)),
            ModelFormat::GGUF => {
                // Identity conversion: clone config and weights as-is.
                Ok(Box::new(GgufFixture {
                    config: self.config.clone(),
                    weights: self.weights.clone(),
                }))
            },
            ModelFormat::PyTorch => Err(RealizarError::UnsupportedOperation {
                operation: "convert_to".to_string(),
                reason: "GGUF to PyTorch conversion not supported".to_string(),
            }),
        }
    }

    fn memory_bytes(&self) -> usize {
        self.weights.total_bytes()
    }
}
/// Synthetic APR-format model fixture: a config plus generated weights.
pub struct AprFixture {
    config: ModelConfig,
    weights: ModelWeights,
}
impl AprFixture {
    /// Build a fixture by generating deterministic synthetic weights for
    /// `config` with the given quantization and RNG seed.
    pub fn new(config: ModelConfig, quant: QuantType, seed: u64) -> Self {
        // `generator` instead of `gen`: clearer, and `gen` is reserved in
        // Rust edition 2024.
        let generator = SyntheticWeightGenerator::new(seed);
        let weights = generator.generate_model_weights(&config, quant);
        Self { config, weights }
    }

    /// Tiny model with the default (grouped-query attention) tiny config.
    pub fn tiny_gqa() -> Self {
        Self::new(ModelConfig::tiny(), QuantType::Q4_0, 42)
    }

    /// Convert from a GGUF fixture by cloning its config and weights.
    pub fn from_gguf(gguf: &GgufFixture) -> Result<Self> {
        let config = gguf.config.clone();
        let weights = gguf.weights.clone();
        Ok(Self { config, weights })
    }

    /// Build from a generic constructor input; quantization defaults to F32
    /// when none is specified.
    pub fn from_constructor(input: &ConstructorInput) -> Self {
        let quant = input.quantization.unwrap_or(QuantType::F32);
        Self::new(input.config.clone(), quant, input.weights_seed)
    }
}
impl ModelFixture for AprFixture {
    fn config(&self) -> &ModelConfig {
        &self.config
    }

    fn format(&self) -> ModelFormat {
        ModelFormat::APR
    }

    fn quant_type(&self) -> QuantType {
        self.weights.quant_type
    }

    /// Synthetic forward pass: ignores device and tokens and returns a
    /// deterministic linear ramp of logits spanning [-1.0, 1.0) across the
    /// vocabulary.
    fn forward(&self, _device: Device, _tokens: &[u32]) -> Result<Vec<f32>> {
        let output_size = self.config.vocab_size;
        let logits = (0..output_size)
            .map(|i| ((i as f32) / (output_size as f32) - 0.5) * 2.0)
            .collect();
        Ok(logits)
    }

    /// Synthetic embedding: a deterministic pseudo-random vector derived
    /// from the token id, scaled by 1/sqrt(hidden_dim).
    fn embed(&self, _device: Device, token: u32) -> Result<Vec<f32>> {
        let dim = self.config.hidden_dim;
        let scale = 1.0 / (dim as f32).sqrt();
        let embedding = (0..dim)
            .map(|i| ((token as usize * 17 + i * 13) % 1000) as f32 / 1000.0 * scale)
            .collect();
        Ok(embedding)
    }

    /// Serialize a minimal APR v2 header plus a JSON metadata blob.
    ///
    /// Layout: magic "APR\x02", header size (u32 LE), metadata offset
    /// (u64 LE), metadata length (u32 LE), zero padding to byte 64, then the
    /// JSON metadata itself. Tensor data is not written.
    fn to_bytes(&self) -> Result<Vec<u8>> {
        let metadata = serde_json::json!({
            "num_heads": self.config.num_heads,
            "num_kv_heads": self.config.num_kv_heads,
            "hidden_size": self.config.hidden_dim,
            "num_layers": self.config.num_layers,
            "vocab_size": self.config.vocab_size,
            "intermediate_size": self.config.intermediate_dim,
            "rope_theta": self.config.rope_theta,
        });
        let metadata_bytes =
            serde_json::to_vec(&metadata).map_err(|e| RealizarError::FormatError {
                reason: format!("APR metadata serialization failed: {}", e),
            })?;
        let mut bytes = Vec::with_capacity(64 + metadata_bytes.len());
        bytes.extend_from_slice(b"APR\x02");
        bytes.extend_from_slice(&64u32.to_le_bytes());
        bytes.extend_from_slice(&64u64.to_le_bytes());
        bytes.extend_from_slice(&(metadata_bytes.len() as u32).to_le_bytes());
        // Header is 20 bytes so far; pad with zeros up to the 64-byte
        // metadata offset declared above.
        bytes.resize(64, 0);
        bytes.extend_from_slice(&metadata_bytes);
        Ok(bytes)
    }

    /// Convert to another format; PyTorch is unsupported.
    fn convert_to(&self, target: ModelFormat) -> Result<Box<dyn ModelFixture>> {
        match target {
            ModelFormat::GGUF => Ok(Box::new(GgufFixture {
                config: self.config.clone(),
                weights: self.weights.clone(),
            })),
            ModelFormat::Safetensors => Ok(Box::new(SafetensorsFixture {
                config: self.config.clone(),
                weights: self.weights.clone(),
            })),
            ModelFormat::APR => Ok(Box::new(AprFixture {
                config: self.config.clone(),
                weights: self.weights.clone(),
            })),
            ModelFormat::PyTorch => Err(RealizarError::UnsupportedOperation {
                operation: "convert_to".to_string(),
                reason: "APR to PyTorch conversion not supported".to_string(),
            }),
        }
    }

    fn memory_bytes(&self) -> usize {
        self.weights.total_bytes()
    }
}
/// Synthetic Safetensors-format model fixture: a config plus generated
/// weights (quantization coerced to what Safetensors supports).
pub struct SafetensorsFixture {
    config: ModelConfig,
    weights: ModelWeights,
}
impl SafetensorsFixture {
    /// Clamp a requested quantization to one Safetensors supports, falling
    /// back to F32 otherwise. Shared by `new` and `from_gguf`.
    fn effective_quant(requested: QuantType) -> QuantType {
        if requested.supported_by(ModelFormat::Safetensors) {
            requested
        } else {
            QuantType::F32
        }
    }

    /// Build a fixture by generating deterministic synthetic weights for
    /// `config`; unsupported quantizations silently fall back to F32.
    pub fn new(config: ModelConfig, quant: QuantType, seed: u64) -> Self {
        let actual_quant = Self::effective_quant(quant);
        let generator = SyntheticWeightGenerator::new(seed);
        let weights = generator.generate_model_weights(&config, actual_quant);
        Self { config, weights }
    }

    /// Tiny F32 model.
    pub fn tiny() -> Self {
        Self::new(ModelConfig::tiny(), QuantType::F32, 42)
    }

    /// Convert from a GGUF fixture.
    ///
    /// NOTE(review): weights are regenerated here with a hard-coded seed
    /// (42) rather than carried over from `gguf`, so round-tripped weights
    /// will only match if the GGUF fixture was also seeded with 42 — confirm
    /// this is intended.
    pub fn from_gguf(gguf: &GgufFixture) -> Result<Self> {
        let quant = Self::effective_quant(gguf.quant_type());
        let generator = SyntheticWeightGenerator::new(42);
        let weights = generator.generate_model_weights(&gguf.config, quant);
        let config = gguf.config.clone();
        Ok(Self { config, weights })
    }

    /// Build from a generic constructor input; quantization defaults to F32
    /// when none is specified.
    pub fn from_constructor(input: &ConstructorInput) -> Self {
        let quant = input.quantization.unwrap_or(QuantType::F32);
        Self::new(input.config.clone(), quant, input.weights_seed)
    }
}
include!("safetensors_fixture.rs");