use crate::dynamic_quantization::{DynamicQuantizer, QuantStrategy, QuantizedWeightStorage};
use crate::error::{ModelError, ModelResult};
use crate::huggingface::{HuggingFaceHub, ModelConfig};
use crate::loader::ModelLoader;
use scirs2_core::ndarray::{s, Array2};
use std::collections::HashMap;
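/// Strategy for converting HuggingFace checkpoint tensors into Kizzasi's
/// native naming and layout conventions.
///
/// `MambaHF` remaps Mamba-style checkpoints, `RwkvHF` remaps RWKV-style
/// checkpoints, and `Direct` copies tensors through under their original names.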
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConversionStrategy {
MambaHF,
RwkvHF,
Direct,
}
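/// Loads HuggingFace checkpoints and converts them into Kizzasi's native
/// tensor layout, optionally quantizing weights along the way.
///
/// A minimal usage sketch (illustrative only; the repo id is a placeholder
/// and network access is required):
///
/// ```ignore
/// let mut loader = HuggingFaceModelLoader::new()?;
/// let (config, weights) = loader
///     .detect_and_load("state-spaces/mamba-130m", None)
///     .await?;
/// ```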
pub struct HuggingFaceModelLoader {
hub: HuggingFaceHub,
strategy: ConversionStrategy,
}
impl HuggingFaceModelLoader {
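/// Creates a loader with the `Direct` strategy; `detect_and_load` may
/// override the strategy based on the model's config.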
pub fn new() -> ModelResult<Self> {
Ok(Self {
hub: HuggingFaceHub::new()?,
strategy: ConversionStrategy::Direct,
})
}
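/// Overrides the conversion strategy (builder style).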
pub fn with_strategy(mut self, strategy: ConversionStrategy) -> Self {
self.strategy = strategy;
self
}
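/// Sets a HuggingFace access token for gated or private repositories.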
pub fn with_token(mut self, token: impl Into<String>) -> Self {
self.hub = self.hub.with_token(token);
self
}
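/// Loads the model config, picks a conversion strategy from its
/// `model_type` (or, failing that, its first architecture name), and
/// returns the config together with the converted weights.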
pub async fn detect_and_load(
&mut self,
repo_id: &str,
revision: Option<&str>,
) -> ModelResult<(ModelConfig, HashMap<String, Array2<f32>>)> {
let config = self.hub.load_config(repo_id, revision).await?;
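// Prefer the explicit `model_type` field; fall back to the first
// entry in `architecture` when `model_type` is absent.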
if let Some(model_type) = &config.model_type {
self.strategy = match model_type.to_lowercase().as_str() {
"mamba" | "mamba2" => ConversionStrategy::MambaHF,
"rwkv" | "rwkv6" | "rwkv7" => ConversionStrategy::RwkvHF,
_ => ConversionStrategy::Direct,
};
} else if let Some(arch) = &config.architecture {
if let Some(arch_name) = arch.first() {
self.strategy = match arch_name.to_lowercase().as_str() {
s if s.contains("mamba") => ConversionStrategy::MambaHF,
s if s.contains("rwkv") => ConversionStrategy::RwkvHF,
_ => ConversionStrategy::Direct,
};
}
}
tracing::info!(
"Detected model type: {:?}, using strategy: {:?}",
config.model_type,
self.strategy
);
let weights = self.load_and_convert_weights(repo_id, revision).await?;
Ok((config, weights))
}
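/// Downloads the checkpoint and converts its tensors according to the
/// current strategy.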
pub async fn load_and_convert_weights(
&self,
repo_id: &str,
revision: Option<&str>,
) -> ModelResult<HashMap<String, Array2<f32>>> {
let loader = self.hub.load_model_loader(repo_id, revision).await?;
match self.strategy {
ConversionStrategy::MambaHF => self.convert_mamba_weights(&loader),
ConversionStrategy::RwkvHF => self.convert_rwkv_weights(&loader),
ConversionStrategy::Direct => {
let mut weights = HashMap::new();
for name in loader.list_tensors() {
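// Only 2-D tensors are kept in Direct mode; anything else is skipped.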
if let Ok(tensor) = loader.load_array2(&name) {
weights.insert(name, tensor);
}
}
Ok(weights)
}
}
}
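/// Converts HuggingFace Mamba tensors to Kizzasi names, splitting the
/// fused `x_proj` weight into separate delta/B/C projections along the way.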
fn convert_mamba_weights(
&self,
loader: &ModelLoader,
) -> ModelResult<HashMap<String, Array2<f32>>> {
let mut kizzasi_weights = HashMap::new();
let tensor_names = loader.list_tensors();
tracing::info!(
"Converting {} HuggingFace Mamba tensors to Kizzasi format",
tensor_names.len()
);
for hf_name in tensor_names {
let kizzasi_name = self.convert_mamba_name(&hf_name);
if hf_name.contains("mixer.x_proj") && hf_name.ends_with(".weight") {
if let Ok(x_proj) = loader.load_array2(&hf_name) {
let (_intermediate_size, combined_dim) = x_proj.dim();
// Mamba's default SSM state size; the fused x_proj packs
// [delta | B | C] as dt_rank + 2 * state_size along the second axis.
let state_size = 16;
// checked_sub avoids a usize underflow panic when the tensor is
// smaller than 2 * state_size.
if let Some(dt_rank) = combined_dim.checked_sub(2 * state_size).filter(|&r| r > 0) {
// `convert_mamba_name` has already rewritten the name to
// "...ssm.x_proj...", so replace only "x_proj" here to avoid
// producing a doubled "ssm.ssm." prefix.
let delta_proj = x_proj.slice(s![.., ..dt_rank]).to_owned();
let delta_name = kizzasi_name.replace("x_proj", "delta_proj");
kizzasi_weights.insert(delta_name, delta_proj);
let b_proj = x_proj
.slice(s![.., dt_rank..dt_rank + state_size])
.to_owned();
let b_name = kizzasi_name.replace("x_proj", "b_proj");
kizzasi_weights.insert(b_name, b_proj);
let c_proj = x_proj.slice(s![.., dt_rank + state_size..]).to_owned();
let c_name = kizzasi_name.replace("x_proj", "c_proj");
kizzasi_weights.insert(c_name, c_proj);
tracing::debug!(
"Split x_proj {} into delta/b/c projections (dt_rank={}, state_size={})",
hf_name,
dt_rank,
state_size
);
continue;
} else {
tracing::warn!(
"Could not infer dimensions for x_proj splitting: combined_dim={}, state_size={}",
combined_dim,
state_size
);
}
}
}
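// Everything else: copy 2-D tensors through, and promote 1-D tensors
// (e.g. biases and norm scales) to (len, 1) column vectors.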
if let Ok(tensor) = loader.load_array2(&hf_name) {
kizzasi_weights.insert(kizzasi_name, tensor);
} else if let Ok(tensor) = loader.load_array1(&hf_name) {
let len = tensor.len();
let tensor_2d = tensor
.to_shape((len, 1))
.map_err(|e| {
ModelError::simple_load_error(format!("Failed to reshape tensor: {}", e))
})?
.to_owned();
kizzasi_weights.insert(kizzasi_name, tensor_2d);
}
}
tracing::info!("Converted to {} Kizzasi tensors", kizzasi_weights.len());
Ok(kizzasi_weights)
}
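/// Maps a HuggingFace Mamba tensor name to its Kizzasi equivalent,
/// e.g. `backbone.layers.0.mixer.A_log` -> `layers.0.ssm.log_a`.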
fn convert_mamba_name(&self, hf_name: &str) -> String {
let mut name = hf_name.to_string();
name = name.replace("backbone.", "");
if name.starts_with("embeddings") || name == "embedding.weight" {
return "input_proj".to_string();
}
if name.starts_with("lm_head") {
return name.replace("lm_head", "output_proj");
}
name = name.replace(".mixer.", ".");
name = name.replace("conv1d.", "conv.");
name = name.replace(".A_log", ".ssm.log_a");
name = name.replace(".D.", ".ssm.d_skip.");
name = name.replace(".D", ".ssm.d_skip");
name = name.replace("dt_proj", "ssm.dt_proj");
name = name.replace("x_proj", "ssm.x_proj");
name
}
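/// Converts HuggingFace RWKV tensors to Kizzasi names; 1-D tensors are
/// promoted to (len, 1) column vectors.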
fn convert_rwkv_weights(
&self,
loader: &ModelLoader,
) -> ModelResult<HashMap<String, Array2<f32>>> {
let mut kizzasi_weights = HashMap::new();
let tensor_names = loader.list_tensors();
tracing::info!(
"Converting {} HuggingFace RWKV tensors to Kizzasi format",
tensor_names.len()
);
for hf_name in tensor_names {
let kizzasi_name = self.convert_rwkv_name(&hf_name);
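// Copy 2-D tensors through; promote 1-D tensors to column vectors.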
if let Ok(tensor) = loader.load_array2(&hf_name) {
kizzasi_weights.insert(kizzasi_name, tensor);
} else if let Ok(tensor) = loader.load_array1(&hf_name) {
let len = tensor.len();
let tensor_2d = tensor
.to_shape((len, 1))
.map_err(|e| {
ModelError::simple_load_error(format!("Failed to reshape tensor: {}", e))
})?
.to_owned();
kizzasi_weights.insert(kizzasi_name, tensor_2d);
}
}
tracing::info!("Converted to {} Kizzasi tensors", kizzasi_weights.len());
Ok(kizzasi_weights)
}
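/// Maps a HuggingFace RWKV tensor name to its Kizzasi equivalent,
/// e.g. `blocks.0.att.time_decay` -> `layers.0.time_mix.decay`.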
fn convert_rwkv_name(&self, hf_name: &str) -> String {
let mut name = hf_name.to_string();
if name.starts_with("emb.weight") || name == "emb" {
return "input_proj".to_string();
}
if name.starts_with("head.weight") || name.starts_with("head.") {
return name.replace("head", "output_proj");
}
name = name.replace("blocks.", "layers.");
name = name.replace("ln1.", "norm.");
name = name.replace("ln2.", "norm2.");
name = name.replace(".att.", ".time_mix.");
name = name.replace("time_decay", "decay");
name = name.replace("time_first", "first");
name = name.replace(".ffn.", ".channel_mix.");
name
}
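/// Returns a reference to the underlying HuggingFace hub client.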
pub fn hub(&self) -> &HuggingFaceHub {
&self.hub
}
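/// Returns the currently selected conversion strategy.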
pub fn strategy(&self) -> ConversionStrategy {
self.strategy
}
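/// Loads and converts the weights, then quantizes them with the given
/// strategy, returning the config, quantized weights, and memory stats.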
pub async fn load_and_quantize(
&mut self,
repo_id: &str,
revision: Option<&str>,
quant_strategy: QuantStrategy,
) -> ModelResult<(
ModelConfig,
HashMap<String, QuantizedWeightStorage>,
crate::dynamic_quantization::QuantizationStats,
)> {
let (config, weights) = self.detect_and_load(repo_id, revision).await?;
let quantizer = DynamicQuantizer::new().with_strategy(quant_strategy);
let quantized_weights = quantizer.quantize_weights(&weights)?;
let stats = quantizer.calculate_memory_savings(&weights, &quantized_weights);
tracing::info!(
"Quantized {} weights using {:?}: {:.2}x compression ({} → {})",
quantized_weights.len(),
quant_strategy,
stats.compression_ratio,
crate::dynamic_quantization::QuantizationStats::format_size(stats.original_size_bytes),
crate::dynamic_quantization::QuantizationStats::format_size(stats.quantized_size_bytes)
);
Ok((config, quantized_weights, stats))
}
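/// Convenience wrapper around `load_and_quantize` using the
/// `MixedPrecision` strategy.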
pub async fn load_with_auto_quantization(
&mut self,
repo_id: &str,
revision: Option<&str>,
) -> ModelResult<(
ModelConfig,
HashMap<String, QuantizedWeightStorage>,
crate::dynamic_quantization::QuantizationStats,
)> {
// `load_and_quantize` runs detection itself, so a separate
// `detect_and_load` pass here would only repeat the download.
self.load_and_quantize(repo_id, revision, QuantStrategy::MixedPrecision)
.await
}
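/// End-to-end load: detect, convert, quantize, and build a runnable model
/// via `ModelFactory`.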
pub async fn load_model(
&mut self,
repo_id: &str,
revision: Option<&str>,
quant_strategy: QuantStrategy,
) -> ModelResult<Box<dyn crate::AutoregressiveModel>> {
use crate::factory::ModelFactory;
tracing::info!("Loading model '{}' end-to-end", repo_id);
let (config, quantized_weights, stats) = self
.load_and_quantize(repo_id, revision, quant_strategy)
.await?;
tracing::info!(
"Creating model instance: compression={:.1}x, memory_saved={}",
stats.compression_ratio,
crate::dynamic_quantization::QuantizationStats::format_size(stats.memory_saved_bytes)
);
let model = ModelFactory::create_from_config(&config, quantized_weights)?;
tracing::info!("Model loaded successfully: {}", model.model_type());
Ok(model)
}
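/// Like `load_model`, but defaults to `MixedPrecision` quantization.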
pub async fn load_model_auto(
&mut self,
repo_id: &str,
revision: Option<&str>,
) -> ModelResult<Box<dyn crate::AutoregressiveModel>> {
self.load_model(repo_id, revision, QuantStrategy::MixedPrecision)
.await
}
}
impl Default for HuggingFaceModelLoader {
fn default() -> Self {
Self::new().expect("Failed to create default HuggingFaceModelLoader")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_conversion_strategy() {
let loader = HuggingFaceModelLoader::new().unwrap();
assert_eq!(loader.strategy(), ConversionStrategy::Direct);
let loader = loader.with_strategy(ConversionStrategy::MambaHF);
assert_eq!(loader.strategy(), ConversionStrategy::MambaHF);
}
#[test]
fn test_mamba_name_conversion() {
let loader = HuggingFaceModelLoader::new().unwrap();
assert_eq!(
loader.convert_mamba_name("backbone.embeddings.weight"),
"input_proj"
);
assert_eq!(
loader.convert_mamba_name("backbone.layers.0.norm.weight"),
"layers.0.norm.weight"
);
assert_eq!(
loader.convert_mamba_name("backbone.layers.0.mixer.in_proj.weight"),
"layers.0.in_proj.weight"
);
assert_eq!(
loader.convert_mamba_name("backbone.layers.0.mixer.conv1d.weight"),
"layers.0.conv.weight"
);
assert_eq!(
loader.convert_mamba_name("backbone.layers.0.mixer.A_log"),
"layers.0.ssm.log_a"
);
assert_eq!(
loader.convert_mamba_name("backbone.layers.0.mixer.D"),
"layers.0.ssm.d_skip"
);
assert_eq!(
loader.convert_mamba_name("lm_head.weight"),
"output_proj.weight"
);
}
#[test]
fn test_rwkv_name_conversion() {
let loader = HuggingFaceModelLoader::new().unwrap();
assert_eq!(loader.convert_rwkv_name("emb.weight"), "input_proj");
assert_eq!(
loader.convert_rwkv_name("blocks.0.ln1.weight"),
"layers.0.norm.weight"
);
assert_eq!(
loader.convert_rwkv_name("blocks.0.att.time_decay"),
"layers.0.time_mix.decay"
);
assert_eq!(
loader.convert_rwkv_name("blocks.0.att.key.weight"),
"layers.0.time_mix.key.weight"
);
assert_eq!(
loader.convert_rwkv_name("blocks.0.ffn.key.weight"),
"layers.0.channel_mix.key.weight"
);
assert_eq!(
loader.convert_rwkv_name("head.weight"),
"output_proj.weight"
);
}
#[test]
fn test_with_token() {
let loader = HuggingFaceModelLoader::new()
.unwrap()
.with_token("test_token");
assert_eq!(loader.hub().token.as_deref(), Some("test_token"));
}
}