use candle_core::{Device, Result, Tensor};
use std::path::Path;
use tracing::info;
use super::detector::{ModelDetector, QuantizationType};
use super::gguf_model::GgufModel;
use super::llama_4bit::{Llama4Bit, Llama4BitConfig};
#[cfg(feature = "tokenizers")]
use super::llama::Llama;
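/// Runtime-selected wrapper over every supported model backend.
///
/// `BitLlama` holds FP16/ternary checkpoints with a built-in tokenizer (only
/// available with the `tokenizers` feature), `FourBit` holds 4-bit quantized
/// safetensors models, and `Gguf` holds GGUF files.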
#[allow(clippy::large_enum_variant)]
pub enum UnifiedModel {
#[cfg(feature = "tokenizers")]
BitLlama(Llama),
FourBit(Llama4Bit),
Gguf(GgufModel),
}
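/// The on-disk model format reported by [`UnifiedModel::detect_type`].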
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelType {
BitLlama,
FourBit,
Gguf,
}
impl UnifiedModel {
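    /// Detect the model format at `path` (a file or directory) without loading weights.
    ///
    /// Resolution order: `ModelDetector::is_gguf`, `ModelDetector::detect`, the
    /// `quantization.type` field of `config.json`, tensor names in
    /// `model.safetensors`, and finally a `BitLlama` default.
    ///
    /// ```ignore
    /// // Illustrative path; point this at a real model directory.
    /// let model_type = UnifiedModel::detect_type("models/example")?;
    /// ```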
pub fn detect_type<P: AsRef<Path>>(path: P) -> Result<ModelType> {
let path = path.as_ref();
if ModelDetector::is_gguf(path) {
info!("🔍 Detected GGUF model");
return Ok(ModelType::Gguf);
}
let dir = if path.is_file() {
path.parent().unwrap_or(path)
} else {
path
};
if let Ok(info) = ModelDetector::detect(dir) {
match info.quantization {
QuantizationType::Int4 => return Ok(ModelType::FourBit),
QuantizationType::Ternary => return Ok(ModelType::BitLlama),
QuantizationType::FP16 | QuantizationType::BF16 | QuantizationType::FP32 => {
return Ok(ModelType::BitLlama)
}
_ => {}
}
}
let config_path = dir.join("config.json");
if config_path.exists() {
if let Ok(config_str) = std::fs::read_to_string(&config_path) {
if let Ok(config) = serde_json::from_str::<serde_json::Value>(&config_str) {
if let Some(quant) = config.get("quantization") {
if let Some(qtype) = quant.get("type").and_then(|v| v.as_str()) {
if qtype == "int4" || qtype == "4bit" || qtype == "nf4" {
info!("🔍 Detected 4-bit model from config.json");
return Ok(ModelType::FourBit);
}
}
}
}
}
}
let model_path = dir.join("model.safetensors");
if model_path.exists() {
if let Ok(info) = detect_from_safetensors(&model_path) {
return Ok(info);
}
}
info!("🔍 Defaulting to BitLlama (FP16/TTT) model type");
Ok(ModelType::BitLlama)
}
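    /// Detect the model type at `path` and load it with the matching backend.
    ///
    /// Requires the `tokenizers` feature because the `BitLlama` branch relies on
    /// the bundled tokenizer; GGUF models are placed on CUDA when available.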
#[cfg(feature = "tokenizers")]
pub fn load_auto<P: AsRef<Path>>(path: P) -> Result<Self> {
let model_type = Self::detect_type(&path)?;
match model_type {
ModelType::FourBit => {
info!("📦 Loading as 4-bit quantized model");
let model = load_4bit_model(&path)?;
Ok(UnifiedModel::FourBit(model))
}
ModelType::BitLlama => {
info!("📦 Loading as BitLlama (FP16/TTT) model");
let model = Llama::load_auto(&path)?;
Ok(UnifiedModel::BitLlama(model))
}
ModelType::Gguf => {
info!("📦 Loading as GGUF model");
let device = Device::cuda_if_available(0)?;
let model = GgufModel::load(&path, &device)?;
Ok(UnifiedModel::Gguf(model))
}
}
}
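    /// Always `true`: a constructed `UnifiedModel` already holds a loaded model.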
    pub fn is_loaded(&self) -> bool {
        true
    }
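    /// The backend variant currently held by this wrapper.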
pub fn model_type(&self) -> ModelType {
match self {
#[cfg(feature = "tokenizers")]
UnifiedModel::BitLlama(_) => ModelType::BitLlama,
UnifiedModel::FourBit(_) => ModelType::FourBit,
UnifiedModel::Gguf(_) => ModelType::Gguf,
}
}
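    /// Generate up to `max_tokens` tokens of text for `prompt`.
    ///
    /// Only `BitLlama` carries its own tokenizer; `FourBit` and `Gguf` return an
    /// error and must be driven with pre-tokenized input via [`Self::forward_4bit`].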
#[cfg(feature = "tokenizers")]
pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> Result<String> {
match self {
UnifiedModel::BitLlama(llama) => llama.generate(prompt, max_tokens),
            UnifiedModel::FourBit(_) | UnifiedModel::Gguf(_) => Err(candle_core::Error::Msg(
                "This model requires an external tokenizer for generate()".to_string(),
            )),
}
}
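    /// Run a raw forward pass over pre-tokenized `input_ids`, continuing the KV
    /// cache from `start_pos`. Intended for the `FourBit` and `Gguf` backends.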
pub fn forward_4bit(&mut self, input_ids: &Tensor, start_pos: usize) -> Result<Tensor> {
match self {
#[cfg(feature = "tokenizers")]
UnifiedModel::BitLlama(_) => Err(candle_core::Error::Msg(
"Use generate() for BitLlama models".to_string(),
)),
UnifiedModel::FourBit(model) => model.forward(input_ids, start_pos),
UnifiedModel::Gguf(model) => model.forward(input_ids, start_pos),
}
}
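    /// Clear the KV cache of whichever backend is loaded.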
pub fn reset_cache(&mut self) {
match self {
#[cfg(feature = "tokenizers")]
UnifiedModel::BitLlama(llama) => {
llama.model.reset_kv_cache();
}
UnifiedModel::FourBit(model) => {
model.clear_kv_cache();
}
UnifiedModel::Gguf(model) => {
model.reset_cache();
}
}
}
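    /// Stream a completion through `callback`, whose boolean return lets the
    /// caller stop generation early. Only the `BitLlama` backend supports this.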
#[cfg(feature = "tokenizers")]
pub fn stream_completion<F>(
&mut self,
prompt: &str,
max_tokens: usize,
temp: f64,
callback: F,
) -> Result<String>
where
F: FnMut(&str) -> anyhow::Result<bool>,
{
match self {
UnifiedModel::BitLlama(llama) => {
llama.stream_completion(prompt, max_tokens, temp, callback)
}
            UnifiedModel::FourBit(_) | UnifiedModel::Gguf(_) => Err(candle_core::Error::Msg(
                "Streaming is not supported for this model; use the CLI instead.".to_string(),
            )),
}
}
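    /// Update the model from `text`. Only the `BitLlama` backend supports learning.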
#[cfg(feature = "tokenizers")]
pub fn learn(&mut self, text: &str) -> Result<()> {
match self {
UnifiedModel::BitLlama(llama) => llama.learn(text),
UnifiedModel::FourBit(_) | UnifiedModel::Gguf(_) => Err(candle_core::Error::Msg(
"Learning not supported for this model".to_string(),
)),
}
}
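    /// The current soul level; `0` for backends that do not track one.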
#[cfg(feature = "tokenizers")]
pub fn soul_level(&self) -> u64 {
match self {
UnifiedModel::BitLlama(llama) => llama.soul_level,
UnifiedModel::FourBit(_) | UnifiedModel::Gguf(_) => 0,
}
}
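    /// Load persisted soul/memory state from `path`. Only the `BitLlama` backend
    /// supports this.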
#[cfg(feature = "tokenizers")]
pub fn load_memory<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
match self {
UnifiedModel::BitLlama(llama) => llama.load_memory(path),
UnifiedModel::FourBit(_) | UnifiedModel::Gguf(_) => Err(candle_core::Error::Msg(
"Soul loading not supported for this model".to_string(),
)),
}
}
}
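/// Inspect tensor names in a safetensors file: `weight_4bit`, `scales_4bit`, or a
/// `scales`/`qweight` pair indicate a 4-bit model; anything else (including
/// ternary `weight_packed` tensors) maps to `BitLlama`.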
fn detect_from_safetensors<P: AsRef<Path>>(path: P) -> Result<ModelType> {
use candle_core::safetensors::MmapedSafetensors;
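    // SAFETY: the file is memory-mapped read-only; we assume it is not modified
    // while this short-lived mapping is alive.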
let tensors = unsafe { MmapedSafetensors::new(path)? };
let tensor_list = tensors.tensors();
let names: Vec<_> = tensor_list.iter().map(|(n, _)| n.clone()).collect();
let has_4bit = names.iter().any(|n| {
n.contains("weight_4bit")
|| n.contains("scales_4bit")
|| (n.contains("scales") && n.contains("qweight"))
});
if has_4bit {
return Ok(ModelType::FourBit);
}
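    // Ternary-packed checkpoints are served by the BitLlama path, so both the
    // explicit check and the fallback below resolve to the same model type.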
let has_ternary = names.iter().any(|n| n.contains("weight_packed"));
if has_ternary {
return Ok(ModelType::BitLlama);
}
Ok(ModelType::BitLlama)
}
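/// Load a 4-bit quantized model from `path` (a directory, or a file whose parent
/// directory is used), reading `config.json` plus `model.safetensors` or
/// `weight.safetensors`.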
#[allow(dead_code)]
fn load_4bit_model<P: AsRef<Path>>(path: P) -> Result<Llama4Bit> {
let path = path.as_ref();
let dir = if path.is_file() {
path.parent().unwrap_or(path)
} else {
path
};
let config_path = dir.join("config.json");
let mut model_path = dir.join("model.safetensors");
if !model_path.exists() {
model_path = dir.join("weight.safetensors");
}
    let config_str = std::fs::read_to_string(&config_path).map_err(|e| {
        candle_core::Error::Msg(format!("Failed to read {}: {}", config_path.display(), e))
    })?;
    let config: Llama4BitConfig = serde_json::from_str(&config_str).map_err(|e| {
        candle_core::Error::Msg(format!("Failed to parse {}: {}", config_path.display(), e))
    })?;
let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
let tensors = candle_core::safetensors::load(&model_path, &device)?;
Llama4Bit::load(&tensors, config, &device)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
    fn test_detect_type() {
        // A nonexistent path has no GGUF magic, config.json, or safetensors,
        // so detection falls through to the BitLlama default.
        let detected = UnifiedModel::detect_type("nonexistent_path");
        assert!(matches!(detected, Ok(ModelType::BitLlama)));
    }
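    // A minimal end-to-end usage sketch, assuming a model directory exists at the
    // illustrative path below; marked #[ignore] so it never runs unattended.
    #[cfg(feature = "tokenizers")]
    #[test]
    #[ignore]
    fn test_load_auto_and_generate_sketch() {
        let mut model = UnifiedModel::load_auto("models/example-model")
            .expect("model directory should exist when running this ignored test");
        assert!(model.is_loaded());
        let _reply = model
            .generate("Hello", 8)
            .expect("generation should succeed");
    }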
}