use std::{
collections::HashMap,
io::{Cursor, Read},
};
use serde::{Deserialize, Serialize};
use crate::error::{RealizarError, Result};
use crate::inference::simd_bf16_to_f32;
/// Locates a companion file (for example `config.json`) in the same directory as `model_path`.
///
/// Two candidates are checked, in order:
///
/// 1. `<parent>/<stem>.<companion_name>`: a companion named after the model file, e.g.
///    `model.safetensors` + `config.json` → `model.config.json`. This candidate is skipped
///    when the model's file name ends in `.index.json`. For such paths the stem would be,
///    e.g., `model.safetensors.index`.
/// 2. `<parent>/<companion_name>`: the plain companion name.
///
/// Returns the first candidate that exists. Returns `None` when:
/// - neither candidate exists;
/// - `model_path` has no parent; or
/// - `model_path` has no final file-name component (e.g. it ends in `..`).
///
/// File names are handled as `OsStr`. A model file whose name is not valid UTF-8 still
/// gets both lookups instead of aborting the search.
pub fn find_sibling_file(
    model_path: &std::path::Path,
    companion_name: &str,
) -> Option<std::path::PathBuf> {
    let parent = model_path.parent()?;
    let filename = model_path.file_name()?;

    // The suffix is pure ASCII, so a lossy conversion detects it reliably even when other
    // bytes of the name are not valid UTF-8.
    let is_index = filename.to_string_lossy().ends_with(".index.json");
    if !is_index {
        if let Some(stem) = model_path.file_stem() {
            // Build the name as an OsString so a non-UTF-8 stem is preserved verbatim.
            let mut prefixed_name = stem.to_os_string();
            prefixed_name.push(".");
            prefixed_name.push(companion_name);
            let prefixed = parent.join(prefixed_name);
            if prefixed.exists() {
                return Some(prefixed);
            }
        }
    }

    let plain = parent.join(companion_name);
    plain.exists().then_some(plain)
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub enum SafetensorsDtype {
F32,
F16,
BF16,
I32,
I64,
U8,
Bool,
}
/// Per-tensor record deserialized from the safetensors header via serde.
///
/// No rename attributes are applied, so serde matches these field names verbatim
/// against the JSON keys (`dtype`, `shape`, `data_offsets`).
#[derive(Debug, Deserialize)]
struct TensorMetadata {
    /// Element type; see `SafetensorsDtype` for the accepted spellings.
    dtype: SafetensorsDtype,
    /// Dimension sizes, as listed in the header's `shape` array.
    shape: Vec<usize>,
    /// Two offsets.
    ///
    /// NOTE(review): presumably the `[start, end)` byte range of this tensor within the
    /// data section that follows the header. TODO confirm against the parser.
    data_offsets: [usize; 2],
}
/// Public description of one named tensor.
///
/// Carries the same `dtype`, `shape` and `data_offsets` fields as the private header
/// record `TensorMetadata`, plus the tensor's `name`.
#[derive(Debug, Clone, PartialEq)]
pub struct SafetensorsTensorInfo {
    /// Tensor name.
    ///
    /// NOTE(review): presumably the header key this entry was read from. Confirm in
    /// the parser.
    pub name: String,
    /// Element type of the tensor.
    pub dtype: SafetensorsDtype,
    /// Dimension sizes of the tensor.
    pub shape: Vec<usize>,
    /// Two offsets locating the tensor's bytes.
    ///
    /// NOTE(review): presumably `[start, end)` relative to the data section. Confirm
    /// whether they are relative to `SafetensorsModel::data` or to the whole file.
    pub data_offsets: [usize; 2],
}
/// A safetensors model held in memory: tensor descriptions keyed by name, plus an owned
/// byte buffer.
///
/// Note that the derived `Clone` copies the entire `data` buffer.
#[derive(Debug, Clone)]
pub struct SafetensorsModel {
    /// Tensor descriptions keyed by tensor name.
    ///
    /// NOTE(review): presumably each key equals the value's `name` field. Confirm in the
    /// parser.
    pub tensors: HashMap<String, SafetensorsTensorInfo>,
    /// Owned raw bytes.
    ///
    /// NOTE(review): presumably the tensor data that each `data_offsets` pair indexes
    /// into. Confirm whether the buffer starts at the data section or at the file start.
    pub data: Vec<u8>,
}
// Additional items of this module are defined in the files below. `include!` splices
// each file's text in here at compile time. Those items therefore share this module's
// scope, including the `use` declarations and types above.
include!("safetensors_parser.rs");
include!("mapped_model.rs");
include!("shard.rs");