use anyhow::{anyhow, Result};
use async_trait::async_trait;
use safetensors::SafeTensors;
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use tracing::{info, debug, warn};
use super::{GenOptions, LoadedModel, ModelSpec, InferenceEngine};
#[derive(Debug)]
pub struct SafeTensorsEngine {
}
impl Default for SafeTensorsEngine {
fn default() -> Self {
Self::new()
}
}
impl SafeTensorsEngine {
pub fn new() -> Self {
Self {}
}
pub fn is_safetensors_model(path: &Path) -> bool {
if let Some(ext) = path.extension().and_then(|s| s.to_str()) {
return ext == "safetensors";
}
if let Ok(data) = fs::read(path) {
if data.len() >= 8 {
return SafeTensors::deserialize(&data).is_ok();
}
}
false
}
pub fn discover_safetensors_models(dir: &Path) -> Result<Vec<PathBuf>> {
let mut models = Vec::new();
if !dir.exists() || !dir.is_dir() {
return Ok(models);
}
for entry in fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
if path.is_file() && Self::is_safetensors_model(&path) {
models.push(path);
}
}
Ok(models)
}
}
#[async_trait]
impl InferenceEngine for SafeTensorsEngine {
async fn load(&self, spec: &ModelSpec) -> Result<Box<dyn LoadedModel>> {
info!("Loading SafeTensors model: {}", spec.base_path.display());
if !Self::is_safetensors_model(&spec.base_path) {
return Err(anyhow!("File is not in SafeTensors format: {}", spec.base_path.display()));
}
let model = SafeTensorsModel::load(spec).await?;
Ok(Box::new(model))
}
}
#[derive(Debug)]
struct SafeTensorsModel {
name: String,
model_data: Vec<u8>, config: ModelConfig,
tokenizer: SimpleTokenizer,
}
#[derive(Debug, Clone)]
struct ModelConfig {
vocab_size: usize,
hidden_size: usize,
num_layers: usize,
max_sequence_length: usize,
}
#[derive(Debug)]
struct SimpleTokenizer {
vocab: HashMap<String, u32>,
reverse_vocab: HashMap<u32, String>,
bos_token: u32,
eos_token: u32,
}
impl SafeTensorsModel {
async fn load(spec: &ModelSpec) -> Result<Self> {
info!("Reading SafeTensors file: {}", spec.base_path.display());
let model_data = fs::read(&spec.base_path)?;
let tensors = SafeTensors::deserialize(&model_data)?;
debug!("SafeTensors loaded with {} tensors", tensors.len());
for name in tensors.names() {
if let Ok(tensor) = tensors.tensor(name) {
debug!("Tensor '{}': shape {:?}, dtype {:?}", name, tensor.shape(), tensor.dtype());
}
}
let config = Self::load_config(&spec.base_path, &tensors).await?;
let tokenizer = Self::load_tokenizer(&spec.base_path).await?;
Ok(SafeTensorsModel {
name: spec.name.clone(),
model_data, config,
tokenizer,
})
}
async fn load_config(model_path: &Path, tensors: &SafeTensors<'_>) -> Result<ModelConfig> {
let config_path = model_path.with_file_name("config.json");
if config_path.exists() {
let config_data = fs::read_to_string(&config_path)?;
let json: serde_json::Value = serde_json::from_str(&config_data)?;
let vocab_size = json["vocab_size"].as_u64().unwrap_or(32000) as usize;
let hidden_size = json["hidden_size"].as_u64().unwrap_or(4096) as usize;
let num_layers = json["num_hidden_layers"].as_u64().unwrap_or(32) as usize;
let max_length = json["max_position_embeddings"].as_u64().unwrap_or(2048) as usize;
return Ok(ModelConfig {
vocab_size,
hidden_size,
num_layers,
max_sequence_length: max_length,
});
}
info!("No config.json found, inferring from tensor shapes");
let vocab_size = if let Ok(tensor) = tensors.tensor("lm_head.weight") {
tensor.shape()[0]
} else if let Ok(tensor) = tensors.tensor("embed_tokens.weight") {
tensor.shape()[0]
} else {
32000 };
let hidden_size = if let Ok(tensor) = tensors.tensor("embed_tokens.weight") {
tensor.shape()[1]
} else {
4096 };
let mut num_layers = 0;
for name in tensors.names() {
if name.contains("layers.") {
if let Some(layer_num_str) = name.split("layers.").nth(1).and_then(|s| s.split('.').next()) {
if let Ok(layer_num) = layer_num_str.parse::<usize>() {
num_layers = num_layers.max(layer_num + 1);
}
}
}
}
if num_layers == 0 {
num_layers = 32; }
Ok(ModelConfig {
vocab_size,
hidden_size,
num_layers,
max_sequence_length: 2048,
})
}
async fn load_tokenizer(model_path: &Path) -> Result<SimpleTokenizer> {
let tokenizer_path = model_path.with_file_name("tokenizer.json");
if tokenizer_path.exists() {
let tokenizer_data = fs::read_to_string(&tokenizer_path)?;
let json: serde_json::Value = serde_json::from_str(&tokenizer_data)?;
if let Some(model) = json.get("model") {
if let Some(vocab) = model.get("vocab") {
let mut vocab_map = HashMap::new();
let mut reverse_vocab = HashMap::new();
if let Some(vocab_obj) = vocab.as_object() {
for (token, id) in vocab_obj {
if let Some(id_num) = id.as_u64() {
vocab_map.insert(token.clone(), id_num as u32);
reverse_vocab.insert(id_num as u32, token.clone());
}
}
}
let bos_token = vocab_map.get("<s>").copied()
.or_else(|| vocab_map.get("<|startoftext|>").copied())
.unwrap_or(0);
let eos_token = vocab_map.get("</s>").copied()
.or_else(|| vocab_map.get("<|endoftext|>").copied())
.unwrap_or(1);
return Ok(SimpleTokenizer {
vocab: vocab_map,
reverse_vocab,
bos_token,
eos_token,
});
}
}
}
warn!("No tokenizer.json found, using simple character tokenizer");
Self::create_simple_char_tokenizer()
}
fn create_simple_char_tokenizer() -> Result<SimpleTokenizer> {
let mut vocab = HashMap::new();
let mut reverse_vocab = HashMap::new();
vocab.insert("<s>".to_string(), 0);
reverse_vocab.insert(0, "<s>".to_string());
vocab.insert("</s>".to_string(), 1);
reverse_vocab.insert(1, "</s>".to_string());
vocab.insert("<unk>".to_string(), 2);
reverse_vocab.insert(2, "<unk>".to_string());
let mut id = 3;
for c in 32u8..=126 { let char_str = (c as char).to_string();
vocab.insert(char_str.clone(), id);
reverse_vocab.insert(id, char_str);
id += 1;
}
Ok(SimpleTokenizer {
vocab,
reverse_vocab,
bos_token: 0,
eos_token: 1,
})
}
}
impl SimpleTokenizer {
fn encode(&self, text: &str) -> Vec<u32> {
let mut tokens = vec![self.bos_token];
for char in text.chars() {
let token = self.vocab.get(&char.to_string()).copied().unwrap_or(2); tokens.push(token);
}
tokens
}
fn decode(&self, tokens: &[u32]) -> String {
let mut text = String::new();
for &token in tokens {
if token == self.bos_token || token == self.eos_token {
continue; }
if let Some(token_str) = self.reverse_vocab.get(&token) {
text.push_str(token_str);
}
}
text
}
}
#[async_trait]
impl LoadedModel for SafeTensorsModel {
async fn generate(
&self,
prompt: &str,
opts: GenOptions,
on_token: Option<Box<dyn FnMut(String) + Send>>,
) -> Result<String> {
info!("Generating with SafeTensors model: prompt length = {}", prompt.len());
let input_tokens = self.tokenizer.encode(prompt);
debug!("Input tokens: {} tokens", input_tokens.len());
let response = self.simple_generate(prompt, &opts).await?;
if let Some(mut callback) = on_token {
for word in response.split_whitespace() {
callback(format!("{} ", word));
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
}
Ok(response)
}
}
impl SafeTensorsModel {
async fn simple_generate(&self, prompt: &str, opts: &GenOptions) -> Result<String> {
warn!("Using simplified SafeTensors generation - full inference not yet implemented");
let response = format!(
"SafeTensors model '{}' loaded successfully with {} layers and vocab size {}. \
Input prompt: '{}' (length: {}). \
This is a demonstration that SafeTensors models can be loaded natively in Rust without Python. \
Full transformer inference coming soon!",
self.name,
self.config.num_layers,
self.config.vocab_size,
prompt,
prompt.len()
);
let words: Vec<&str> = response.split_whitespace().collect();
let limited_response = if words.len() > opts.max_tokens {
words[..opts.max_tokens].join(" ")
} else {
response
};
Ok(limited_response)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::TempDir;
#[test]
fn test_safetensors_engine_creation() {
let engine = SafeTensorsEngine::new();
assert!(format!("{:?}", engine).contains("SafeTensorsEngine"));
}
#[test]
fn test_is_safetensors_model() {
assert!(SafeTensorsEngine::is_safetensors_model(Path::new("model.safetensors")));
assert!(!SafeTensorsEngine::is_safetensors_model(Path::new("model.gguf")));
assert!(!SafeTensorsEngine::is_safetensors_model(Path::new("config.json")));
}
#[test]
fn test_discover_safetensors_models_empty_dir() {
let temp_dir = TempDir::new().unwrap();
let models = SafeTensorsEngine::discover_safetensors_models(temp_dir.path()).unwrap();
assert!(models.is_empty());
}
#[test]
fn test_discover_safetensors_models_with_files() {
let temp_dir = TempDir::new().unwrap();
let safetensors_path = temp_dir.path().join("model.safetensors");
let mut file = fs::File::create(&safetensors_path).unwrap();
file.write_all(&create_minimal_safetensors()).unwrap();
let other_path = temp_dir.path().join("config.json");
fs::write(&other_path, r#"{"test": true}"#).unwrap();
let models = SafeTensorsEngine::discover_safetensors_models(temp_dir.path()).unwrap();
assert_eq!(models.len(), 1);
assert!(models[0].file_name().unwrap().to_str().unwrap().contains("model.safetensors"));
}
#[tokio::test]
async fn test_safetensors_engine_load_invalid_file() {
let engine = SafeTensorsEngine::new();
let spec = ModelSpec {
name: "test".to_string(),
base_path: PathBuf::from("nonexistent.safetensors"),
lora_path: None,
template: Some("chatml".to_string()),
ctx_len: 2048,
n_threads: None,
};
let result = engine.load(&spec).await;
assert!(result.is_err());
}
#[test]
fn test_simple_tokenizer_creation() {
let tokenizer = SafeTensorsModel::create_simple_char_tokenizer().unwrap();
assert!(tokenizer.vocab.contains_key("<s>"));
assert!(tokenizer.vocab.contains_key("</s>"));
assert!(tokenizer.vocab.contains_key("A"));
assert_eq!(tokenizer.bos_token, 0);
assert_eq!(tokenizer.eos_token, 1);
}
#[test]
fn test_simple_tokenizer_encode_decode() {
let tokenizer = SafeTensorsModel::create_simple_char_tokenizer().unwrap();
let text = "Hello";
let tokens = tokenizer.encode(text);
let decoded = tokenizer.decode(&tokens[1..tokens.len()]);
assert_eq!(decoded, text);
}
#[test]
fn test_model_config_creation() {
let config = ModelConfig {
vocab_size: 32000,
hidden_size: 4096,
num_layers: 32,
max_sequence_length: 2048,
};
assert_eq!(config.vocab_size, 32000);
assert_eq!(config.hidden_size, 4096);
assert_eq!(config.num_layers, 32);
assert_eq!(config.max_sequence_length, 2048);
}
fn create_minimal_safetensors() -> Vec<u8> {
let metadata = r#"{"test_tensor":{"dtype":"F32","shape":[1,1],"data_offsets":[0,4]}}"#;
let metadata_bytes = metadata.as_bytes();
let metadata_len = metadata_bytes.len() as u64;
let mut data = Vec::new();
data.extend_from_slice(&metadata_len.to_le_bytes());
data.extend_from_slice(metadata_bytes);
data.extend_from_slice(&[0u8, 0u8, 0u8, 0u8]);
data
}
}