pub mod model;
pub use model::{BitMambaStudent, BitLinear, RMSNorm, BitMambaBlock};
use anyhow::{Error, Result};
use candle_core::{DType, Device};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
/// Repository identifier that [`load`] passes to [`load_from_repo`] when the
/// caller does not name a repository explicitly.
pub const DEFAULT_MODEL_REPO: &str = "rileyseaburg/bitmamba-student";
/// Loads the model and tokenizer from the default repository.
///
/// This is a convenience wrapper around [`load_from_repo`] that supplies
/// [`DEFAULT_MODEL_REPO`] as the repository id.
///
/// # Errors
///
/// Returns whatever error [`load_from_repo`] produces.
pub fn load() -> Result<(BitMambaStudent, Tokenizer)> {
    let default_repo = DEFAULT_MODEL_REPO;
    load_from_repo(default_repo)
}
pub fn load_from_repo(repo_id: &str) -> Result<(BitMambaStudent, Tokenizer)> {
let device = Device::Cpu;
let api = Api::new()?;
let repo = api.repo(Repo::new(repo_id.to_string(), RepoType::Model));
let model_path = repo.get("model.safetensors")?;
let tokenizer_path = repo.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(Error::msg)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)? };
let model = BitMambaStudent::load(vb, device)?;
Ok((model, tokenizer))
}
/// Loads a model from a local safetensors file at `path`.
///
/// The weights are memory-mapped as `F32` and the model is built on the CPU
/// device.
///
/// # Errors
///
/// Returns an error if the file cannot be memory-mapped as safetensors or if
/// `BitMambaStudent::load` fails.
pub fn load_model_from_file(path: &str) -> Result<BitMambaStudent> {
    let weights_file = std::path::Path::new(path);
    let cpu = Device::Cpu;

    // SAFETY: NOTE(review) - the call is `unsafe` because it memory-maps the
    // file; this presumably requires that the file is not modified or truncated
    // while the mapping is alive. Confirm callers uphold this for `path`.
    let weights = unsafe {
        VarBuilder::from_mmaped_safetensors(&[weights_file], DType::F32, &cpu)
    }?;

    BitMambaStudent::load(weights, cpu)
}
pub fn load_tokenizer_from_file(path: &str) -> Result<Tokenizer> {
Tokenizer::from_file(path).map_err(Error::msg)
}