use candle_core::{DType, Device, Tensor};
use crate::error::{MIError, Result};
use crate::hooks::{HookCache, HookSpec};
use crate::tokenizer::MITokenizer;
/// Backend abstraction over a concrete model implementation.
///
/// Implementors expose the model's dimensions plus a hooked forward pass
/// (`forward`) and a logit projection (`project_to_vocab`). The two provided
/// methods are optional capabilities: backends without a chat template or
/// addressable embedding rows keep the defaults.
pub trait MIBackend: Send + Sync {
    /// Number of layers in the model.
    fn num_layers(&self) -> usize;
    /// Width of the hidden/residual representation.
    fn hidden_size(&self) -> usize;
    /// Size of the output vocabulary.
    fn vocab_size(&self) -> usize;
    /// Number of attention (or equivalent) heads.
    fn num_heads(&self) -> usize;
    /// Runs a forward pass over `input_ids`, capturing the activations
    /// requested by `hooks` into a [`HookCache`].
    fn forward(&self, input_ids: &Tensor, hooks: &HookSpec) -> Result<HookCache>;
    /// Projects a hidden-state tensor to vocabulary logits.
    fn project_to_vocab(&self, hidden: &Tensor) -> Result<Tensor>;
    /// Formats `_prompt` (and optional `_system_prompt`) with the model's chat
    /// template. Default: `None` — backend has no chat template.
    fn chat_template(&self, _prompt: &str, _system_prompt: Option<&str>) -> Option<String> {
        None
    }
    /// Returns the embedding row for `_token_id`.
    /// Default: an error — backend does not expose raw embeddings.
    fn embedding_vector(&self, _token_id: u32) -> Result<Tensor> {
        Err(MIError::Hook(
            "embedding_vector not supported for this backend".into(),
        ))
    }
}
/// A loaded model: the concrete backend, the device its weights live on,
/// and an optional tokenizer (required only by text-level helpers such as
/// `forward_text`).
pub struct MIModel {
    // Trait object so transformer and RWKV backends share one wrapper type.
    backend: Box<dyn MIBackend>,
    // Device the weights were loaded onto (CUDA when available, else CPU).
    device: Device,
    // Best-effort: stays None when no tokenizer.json was found or it failed to parse.
    tokenizer: Option<MITokenizer>,
}
impl MIModel {
    /// Downloads `model_id` from the Hugging Face Hub and loads it onto the
    /// best available device (CUDA 0 when present, otherwise CPU).
    ///
    /// The backend is selected from the `model_type` field of `config.json`:
    /// recognized transformer types use the generic transformer loader,
    /// recognized RWKV types use the RWKV loader, and (with the `transformer`
    /// feature) anything else falls back to auto-configuration driven by the
    /// checkpoint's tensor names. A tokenizer is attached when
    /// `tokenizer.json` is present and parses; tokenizer failures are
    /// deliberately swallowed (best effort).
    ///
    /// # Errors
    /// Returns an error when the download fails, `config.json` is missing or
    /// malformed, `model_type` is absent, or the safetensors weights cannot
    /// be resolved or loaded.
    #[cfg(any(feature = "transformer", feature = "rwkv"))]
    pub fn from_pretrained(model_id: &str) -> Result<Self> {
        let device = Self::select_device()?;
        // Weights are materialized as F32 regardless of on-disk dtype.
        let dtype = DType::F32;
        let fetch_config = crate::download::fetch_config_builder()
            .build()
            .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;
        let files =
            hf_fetch_model::download_files_with_config_blocking(model_id.to_owned(), &fetch_config)
                .map(hf_fetch_model::DownloadOutcome::into_inner)
                .map_err(|e| MIError::Download(e.to_string()))?;
        let config_path = files
            .get("config.json")
            .ok_or_else(|| MIError::Config("config.json not found in downloaded files".into()))?;
        let config_str = std::fs::read_to_string(config_path)
            .map_err(|e| MIError::Config(format!("read config.json: {e}")))?;
        let json: serde_json::Value = serde_json::from_str(&config_str)
            .map_err(|e| MIError::Config(format!("parse config.json: {e}")))?;
        let model_type = json
            .get("model_type")
            .and_then(serde_json::Value::as_str)
            .ok_or_else(|| MIError::Config("missing 'model_type' field".into()))?;
        // Best effort: a missing or unparsable tokenizer.json leaves this None.
        let tokenizer = files
            .get("tokenizer.json")
            .and_then(|p| MITokenizer::from_hf_path(p).ok());
        let weights_paths = resolve_safetensors_paths(&files)?;
        let vb = create_var_builder(&weights_paths, dtype, &device)?;
        match model_type {
            #[cfg(feature = "transformer")]
            mt if crate::config::SUPPORTED_MODEL_TYPES.contains(&mt) => {
                use crate::config::TransformerConfig;
                use crate::transformer::GenericTransformer;
                let config = TransformerConfig::from_hf_config(&json)?;
                let transformer = GenericTransformer::load(config, &device, dtype, vb)?;
                Ok(Self::with_tokenizer(
                    Box::new(transformer),
                    device,
                    tokenizer,
                ))
            }
            #[cfg(feature = "rwkv")]
            mt if crate::rwkv::SUPPORTED_RWKV_MODEL_TYPES.contains(&mt) => {
                use crate::rwkv::{GenericRwkv, RwkvConfig};
                let config = RwkvConfig::from_hf_config(&json)?;
                let rwkv = GenericRwkv::load(config, &device, dtype, vb)?;
                Ok(Self::with_tokenizer(Box::new(rwkv), device, tokenizer))
            }
            // Unknown model_type: attempt auto-configuration from tensor names.
            #[cfg(feature = "transformer")]
            _unknown => {
                use crate::config::TransformerConfig;
                use crate::transformer::GenericTransformer;
                let tensor_names = extract_tensor_names(&files)?;
                TransformerConfig::check_auto_compatibility(&json, &tensor_names).into_result()?;
                let config = TransformerConfig::from_hf_config_auto(&json, &tensor_names)?;
                let transformer = GenericTransformer::load(config, &device, dtype, vb)?;
                Ok(Self::with_tokenizer(
                    Box::new(transformer),
                    device,
                    tokenizer,
                ))
            }
            #[cfg(not(feature = "transformer"))]
            other => Err(MIError::Config(format!(
                "unsupported model_type: '{other}' (enable the `transformer` feature for auto-config)"
            ))),
        }
    }
    /// Picks CUDA device 0 when available, otherwise the CPU.
    #[cfg(any(feature = "transformer", feature = "rwkv"))]
    fn select_device() -> Result<Device> {
        Device::cuda_if_available(0).map_err(MIError::Model)
    }
    /// Wraps an already-constructed backend with no tokenizer.
    #[must_use]
    pub fn new(backend: Box<dyn MIBackend>, device: Device) -> Self {
        Self {
            backend,
            device,
            tokenizer: None,
        }
    }
    /// Wraps an already-constructed backend with an optional tokenizer.
    #[must_use]
    pub fn with_tokenizer(
        backend: Box<dyn MIBackend>,
        device: Device,
        tokenizer: Option<MITokenizer>,
    ) -> Self {
        Self {
            backend,
            device,
            tokenizer,
        }
    }
    /// The device the model's weights live on.
    #[must_use]
    pub const fn device(&self) -> &Device {
        &self.device
    }
    /// The attached tokenizer, if any.
    #[must_use]
    pub const fn tokenizer(&self) -> Option<&MITokenizer> {
        self.tokenizer.as_ref()
    }
    /// Number of layers (delegates to the backend).
    #[must_use]
    pub fn num_layers(&self) -> usize {
        self.backend.num_layers()
    }
    /// Hidden/residual width (delegates to the backend).
    #[must_use]
    pub fn hidden_size(&self) -> usize {
        self.backend.hidden_size()
    }
    /// Vocabulary size (delegates to the backend).
    #[must_use]
    pub fn vocab_size(&self) -> usize {
        self.backend.vocab_size()
    }
    /// Number of heads (delegates to the backend).
    #[must_use]
    pub fn num_heads(&self) -> usize {
        self.backend.num_heads()
    }
    /// Hooked forward pass (delegates to the backend).
    pub fn forward(&self, input_ids: &Tensor, hooks: &HookSpec) -> Result<HookCache> {
        self.backend.forward(input_ids, hooks)
    }
    /// Projects hidden states to vocabulary logits (delegates to the backend).
    pub fn project_to_vocab(&self, hidden: &Tensor) -> Result<Tensor> {
        self.backend.project_to_vocab(hidden)
    }
    /// Direct access to the underlying backend.
    #[must_use]
    pub fn backend(&self) -> &dyn MIBackend {
        &*self.backend
    }
    /// Tokenizes `text`, runs a hooked forward pass over it (batch size 1),
    /// and returns the activation cache together with the token encoding.
    ///
    /// # Errors
    /// Fails when no tokenizer is attached, tokenization fails, or the
    /// forward pass fails.
    pub fn forward_text(&self, text: &str, hooks: &HookSpec) -> Result<TextForwardResult> {
        let tokenizer = self
            .tokenizer()
            .ok_or_else(|| MIError::Config("forward_text requires a tokenizer".into()))?;
        let encoding = tokenizer.encode_with_offsets(text)?;
        // Add a leading batch dimension: [seq] -> [1, seq].
        let input = Tensor::new(&encoding.ids[..], &self.device)?.unsqueeze(0)?;
        let cache = self.forward(&input, hooks)?;
        Ok(TextForwardResult { cache, encoding })
    }
}
/// Result of [`MIModel::forward_text`]: the captured activation cache plus
/// the token encoding (with offsets) that produced it.
#[derive(Debug)]
pub struct TextForwardResult {
    // Activations captured during the forward pass.
    cache: HookCache,
    // Tokenization of the input text, including per-token offsets.
    encoding: crate::util::positioning::EncodingWithOffsets,
}
impl TextForwardResult {
    /// Borrows the captured activation cache.
    #[must_use]
    pub const fn cache(&self) -> &HookCache {
        &self.cache
    }
    /// Consumes `self`, returning the activation cache (drops the encoding).
    #[must_use]
    pub fn into_cache(self) -> HookCache {
        self.cache
    }
    /// Borrows the token encoding (ids, tokens, offsets).
    #[must_use]
    pub const fn encoding(&self) -> &crate::util::positioning::EncodingWithOffsets {
        &self.encoding
    }
    /// The model output tensor stored in the cache.
    #[must_use]
    pub const fn output(&self) -> &Tensor {
        self.cache.output()
    }
    /// Returns the tensor captured at `hook`, erroring if it was not captured.
    pub fn require(&self, hook: &crate::hooks::HookPoint) -> Result<&Tensor> {
        self.cache.require(hook)
    }
    /// Returns the tensor captured at `hook`, or `None` if it was not captured.
    #[must_use]
    pub fn get(&self, hook: &crate::hooks::HookPoint) -> Option<&Tensor> {
        self.cache.get(hook)
    }
    /// The string form of each input token.
    #[must_use]
    pub fn tokens(&self) -> &[String] {
        &self.encoding.tokens
    }
    /// Number of tokens in the input sequence.
    #[must_use]
    pub const fn seq_len(&self) -> usize {
        self.encoding.len()
    }
}
/// Samples a token id from `logits`.
///
/// A non-positive `temperature` means greedy decoding (argmax); otherwise the
/// logits are temperature-scaled and sampled from the softmax distribution.
pub fn sample_token(logits: &Tensor, temperature: f32) -> Result<u32> {
    match temperature {
        t if t <= 0.0 => argmax(logits),
        t => sample_with_temperature(logits, t),
    }
}
fn argmax(logits: &Tensor) -> Result<u32> {
let logits_f32 = logits.to_dtype(DType::F32)?;
let logits_vec: Vec<f32> = logits_f32.flatten_all()?.to_vec1()?;
let (max_idx, _) = logits_vec
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.ok_or_else(|| MIError::Model(candle_core::Error::Msg("empty logits".into())))?;
#[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
Ok(max_idx as u32)
}
/// Samples a token id from the temperature-scaled softmax of `logits`.
///
/// Uses the standard max-subtraction trick for numerical stability, then
/// walks the cumulative distribution with a single uniform draw. The final
/// index serves as a fallback if floating-point rounding keeps the cumulative
/// sum just below the draw.
fn sample_with_temperature(logits: &Tensor, temperature: f32) -> Result<u32> {
    use rand::Rng;
    let raw: Vec<f32> = logits.to_dtype(DType::F32)?.flatten_all()?.to_vec1()?;
    if raw.is_empty() {
        return Err(MIError::Model(candle_core::Error::Msg(
            "empty logits".into(),
        )));
    }
    // Stability shift: subtract the max of the scaled logits before exp().
    let max_scaled = raw
        .iter()
        .map(|x| x / temperature)
        .fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = raw
        .iter()
        .map(|x| (x / temperature - max_scaled).exp())
        .collect();
    let total: f32 = weights.iter().sum();
    let mut rng = rand::thread_rng();
    // Uniform draw in [0, 1).
    let threshold: f32 = rng.r#gen();
    let mut cumulative = 0.0f32;
    let mut chosen = weights.len() - 1;
    for (idx, w) in weights.iter().enumerate() {
        cumulative += w / total;
        if threshold < cumulative {
            chosen = idx;
            break;
        }
    }
    #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
    Ok(chosen as u32)
}
/// Softmax probability of `token_id` at the final sequence position.
///
/// Accepts logits of rank 1 (`[vocab]`), 2 (`[seq, vocab]`), or 3
/// (`[batch, seq, vocab]`, batch index 0); the last position along the
/// sequence axis is used for ranks 2 and 3.
pub fn extract_token_prob(logits: &Tensor, token_id: u32) -> Result<f32> {
    use candle_core::IndexOp;
    let logits_f32 = logits.to_dtype(DType::F32)?;
    let n = logits_f32.dims().len();
    let last_logits = if n == 1 {
        logits_f32
    } else if n == 2 {
        let last = logits_f32.dim(0)? - 1;
        logits_f32.i(last)?
    } else if n == 3 {
        let last = logits_f32.dim(1)? - 1;
        logits_f32.i((0, last))?
    } else {
        return Err(MIError::Model(candle_core::Error::Msg(format!(
            "extract_token_prob: expected 1-3 dims, got {n}"
        ))));
    };
    let probs = candle_nn::ops::softmax_last_dim(&last_logits)?;
    #[allow(clippy::as_conversions)]
    let prob = probs.i(token_id as usize)?.to_scalar::<f32>()?;
    Ok(prob)
}
/// Outcome of a text-generation run: the prompt, the produced text, and the
/// token ids on both sides.
#[derive(Debug, Clone)]
pub struct GenerationResult {
    /// The original prompt text.
    pub prompt: String,
    /// Prompt plus generated continuation.
    pub full_text: String,
    /// Only the newly generated portion.
    pub generated_text: String,
    /// Token ids of the prompt.
    pub prompt_tokens: Vec<u32>,
    /// Token ids produced during generation.
    pub generated_tokens: Vec<u32>,
    /// Total token count (presumably prompt + generated — set by the caller).
    pub total_tokens: usize,
}
/// Minimal view of `model.safetensors.index.json`: only the tensor-name →
/// shard-filename map is needed to discover which shard files to load.
#[cfg(any(feature = "transformer", feature = "rwkv"))]
#[derive(serde::Deserialize)]
struct SafetensorsIndex {
    // Maps each tensor name to the safetensors shard file containing it.
    weight_map: std::collections::HashMap<String, String>,
}
/// Lists the tensor names in the downloaded checkpoint, preferring the shard
/// index (`model.safetensors.index.json`) over a single `model.safetensors`.
///
/// # Errors
/// Fails when neither file is present, or name extraction itself fails.
#[cfg(feature = "transformer")]
fn extract_tensor_names(
    files: &std::collections::HashMap<String, std::path::PathBuf>,
) -> Result<Vec<String>> {
    match (
        files.get("model.safetensors.index.json"),
        files.get("model.safetensors"),
    ) {
        (Some(index_path), _) => crate::config::tensor_names_from_index(index_path),
        (None, Some(st_path)) => crate::config::tensor_names_from_safetensors(st_path),
        (None, None) => Err(MIError::Config(
            "no safetensors files found for tensor name extraction".into(),
        )),
    }
}
/// Resolves the on-disk safetensors file(s) for the downloaded model.
///
/// With a shard index present, every shard named in `weight_map` must have
/// been downloaded; otherwise a single `model.safetensors` is expected.
///
/// # Errors
/// Fails when the index cannot be read/parsed, a referenced shard is missing,
/// or no safetensors file exists at all.
#[cfg(any(feature = "transformer", feature = "rwkv"))]
fn resolve_safetensors_paths(
    files: &std::collections::HashMap<String, std::path::PathBuf>,
) -> Result<Vec<std::path::PathBuf>> {
    let Some(index_path) = files.get("model.safetensors.index.json") else {
        // No shard index: single-file model.
        return files
            .get("model.safetensors")
            .map(|p| vec![p.clone()])
            .ok_or_else(|| {
                MIError::Model(candle_core::Error::Msg(
                    "model.safetensors not found in downloaded files".into(),
                ))
            });
    };
    let index_str = std::fs::read_to_string(index_path)
        .map_err(|e| MIError::Model(candle_core::Error::Msg(format!("read index: {e}"))))?;
    let index: SafetensorsIndex = serde_json::from_str(&index_str)
        .map_err(|e| MIError::Config(format!("parse index: {e}")))?;
    // BTreeSet yields the same sorted, deduplicated shard order as sort()+dedup().
    let shard_names: std::collections::BTreeSet<&String> = index.weight_map.values().collect();
    shard_names
        .into_iter()
        .map(|shard_name| {
            files.get(shard_name.as_str()).cloned().ok_or_else(|| {
                MIError::Model(candle_core::Error::Msg(format!(
                    "shard {shard_name} not found in downloaded files"
                )))
            })
        })
        .collect()
}
/// Builds a `VarBuilder` over the given safetensors paths, choosing the
/// loading strategy at compile time: memory-mapped when the `mmap` feature is
/// enabled, otherwise fully buffered in memory (single file only).
#[cfg(any(feature = "transformer", feature = "rwkv"))]
fn create_var_builder(
    paths: &[std::path::PathBuf],
    dtype: DType,
    device: &Device,
) -> Result<candle_nn::VarBuilder<'static>> {
    #[cfg(feature = "mmap")]
    {
        mmap_var_builder(paths, dtype, device)
    }
    #[cfg(not(feature = "mmap"))]
    {
        buffered_var_builder(paths, dtype, device)
    }
}
/// Loads a single safetensors file entirely into memory and wraps it in a
/// `VarBuilder`. Sharded (multi-file) models are rejected — they require the
/// `mmap` feature.
///
/// # Errors
/// Fails for zero or multiple paths, on read failure, or if candle rejects
/// the buffer.
#[cfg(all(any(feature = "transformer", feature = "rwkv"), not(feature = "mmap")))]
fn buffered_var_builder(
    paths: &[std::path::PathBuf],
    dtype: DType,
    device: &Device,
) -> Result<candle_nn::VarBuilder<'static>> {
    let path = match paths {
        [single] => single,
        [] => {
            return Err(MIError::Model(candle_core::Error::Msg(
                "no safetensors files".into(),
            )))
        }
        _ => {
            return Err(MIError::Config(format!(
                "this model is sharded across {} files and requires the `mmap` feature.\n \
                 Library: candle-mi = {{ features = [\"mmap\"] }}\n \
                 Example: cargo run --features mmap --example <name>",
                paths.len()
            )))
        }
    };
    let bytes = std::fs::read(path).map_err(|e| {
        MIError::Model(candle_core::Error::Msg(format!(
            "read {}: {e}",
            path.display()
        )))
    })?;
    let vb = candle_nn::VarBuilder::from_buffered_safetensors(bytes, dtype, device)?;
    Ok(vb)
}
/// Builds a `VarBuilder` by memory-mapping the safetensors files at `paths`,
/// avoiding a full in-memory copy of the weights.
#[cfg(all(any(feature = "transformer", feature = "rwkv"), feature = "mmap"))]
#[allow(unsafe_code)]
fn mmap_var_builder(
    paths: &[std::path::PathBuf],
    dtype: DType,
    device: &Device,
) -> Result<candle_nn::VarBuilder<'static>> {
    // SAFETY: `from_mmaped_safetensors` requires that the mapped files are not
    // modified for the lifetime of the returned VarBuilder. The paths come
    // from the download cache and are presumably stable while the model is
    // alive — NOTE(review): confirm nothing rewrites these cache files
    // concurrently.
    let vb = unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(paths, dtype, device)? };
    Ok(vb)
}