Skip to main content

svod_model/
state.rs

1use std::collections::HashMap;
2use std::path::Path;
3
4use snafu::{ResultExt, Snafu};
5use svod_dtype::DType;
6use svod_tensor::Tensor;
7
/// Map from tensor key to tensor. This is what [`load_safetensors`] returns,
/// keyed by the tensor names stored in the file.
pub type StateDict = HashMap<String, Tensor>;
9
10#[derive(Debug, Snafu)]
11pub enum Error {
12    #[snafu(display("failed to read file: {source}"))]
13    Io { source: std::io::Error },
14    #[snafu(display("failed to deserialize safetensors"))]
15    Safetensors { source: safetensors::SafeTensorError },
16    #[snafu(display("unsupported dtype in safetensors: {dtype}"))]
17    UnsupportedDtype { dtype: String },
18    #[snafu(display("missing key in state dict: {key}"))]
19    MissingKey { key: String },
20    #[snafu(display("{source}"))]
21    Tensor {
22        #[snafu(source(from(svod_tensor::error::Error, Box::new)))]
23        source: Box<svod_tensor::error::Error>,
24    },
25}
26
/// Module-local shorthand for a `Result` whose error type is this module's [`Error`].
type Result<T> = std::result::Result<T, Error>;
28
29pub fn load_safetensors(path: &Path) -> Result<StateDict> {
30    let data = std::fs::read(path).context(IoSnafu)?;
31    let tensors = safetensors::SafeTensors::deserialize(&data).context(SafetensorsSnafu)?;
32    let mut sd = StateDict::new();
33    for (name, view) in tensors.tensors() {
34        let dtype = convert_dtype(view.dtype())?;
35        let shape: Vec<usize> = view.shape().to_vec();
36        let tensor = Tensor::from_raw_bytes(view.data(), &shape, dtype).context(TensorSnafu)?;
37        sd.insert(name.to_string(), tensor);
38    }
39    Ok(sd)
40}
41
42fn convert_dtype(dt: safetensors::Dtype) -> Result<DType> {
43    use safetensors::Dtype as ST;
44    match dt {
45        ST::F32 => Ok(DType::Float32),
46        ST::F16 => Ok(DType::Float16),
47        ST::BF16 => Ok(DType::BFloat16),
48        ST::F64 => Ok(DType::Float64),
49        ST::I32 => Ok(DType::Int32),
50        ST::I64 => Ok(DType::Int64),
51        ST::I16 => Ok(DType::Int16),
52        ST::I8 => Ok(DType::Int8),
53        ST::U8 => Ok(DType::UInt8),
54        ST::BOOL => Ok(DType::Bool),
55        other => Err(Error::UnsupportedDtype { dtype: format!("{other:?}") }),
56    }
57}
58
/// Conversion between a value and a [`StateDict`].
///
/// NOTE(review): no implementation is visible in this file; the descriptions
/// below follow the signatures and the `state_field!` / `load_state_field!`
/// macros, which build keys with [`prefixed`]. Confirm against implementors.
pub trait HasStateDict {
    /// Return a [`StateDict`] for `self`; `prefix` is presumably the key
    /// prefix passed on to [`prefixed`].
    fn state_dict(&self, prefix: &str) -> StateDict;
    /// Update `self` from `sd`, presumably reading the keys under `prefix`.
    /// Returns an [`Error`] on failure (e.g. `Error::MissingKey` when
    /// [`get_tensor`] is used for the lookups).
    fn load_state_dict(&mut self, sd: &StateDict, prefix: &str) -> Result<()>;
}
63
64/// Helper: get a tensor from a state dict by key, returning an error if missing.
65pub fn get_tensor(sd: &StateDict, key: &str) -> Result<Tensor> {
66    sd.get(key).cloned().ok_or_else(|| Error::MissingKey { key: key.to_string() })
67}
68
/// Build a state-dict key from `prefix` and `name`.
///
/// An empty `prefix` yields `name` unchanged; otherwise the result is
/// `prefix`, a single `.`, then `name`.
pub fn prefixed(prefix: &str, name: &str) -> String {
    match prefix {
        "" => name.to_owned(),
        scope => [scope, name].join("."),
    }
}
73
/// Store a clone of each listed field of `$owner` in the state dict `$dict`.
///
/// The key for a field is `$crate::state::prefixed($scope, "<field>")`, i.e.
/// `<prefix>.<field>`, or just `<field>` when the prefix is empty. Field
/// identifiers are stringified verbatim. Any value previously stored under a
/// key is discarded.
///
/// `$dict` and `$scope` are expanded once per field, so pass plain variables
/// rather than expressions with side effects.
#[macro_export]
macro_rules! state_field {
    ($dict:expr, $scope:expr, $owner:ident, [$($name:ident),+ $(,)?]) => {
        $(
            let _ = $dict.insert(
                $crate::state::prefixed($scope, stringify!($name)),
                $owner.$name.clone(),
            );
        )+
    };
}
87
/// Assign each listed field of `$owner` from the state dict `$dict`.
///
/// Counterpart of [`state_field!`]: the key for a field is
/// `$crate::state::prefixed($scope, "<field>")` and the value is fetched with
/// `$crate::state::get_tensor`, so `$dict` must be usable as `&StateDict`.
///
/// The expansion uses `?`, so it must appear inside a function whose error
/// type can be built from this module's `Error`. Fields are assigned in list
/// order, and expansion stops at the first missing key, leaving earlier
/// fields already updated.
#[macro_export]
macro_rules! load_state_field {
    ($owner:ident, $dict:expr, $scope:expr, [$($name:ident),+ $(,)?]) => {
        $(
            $owner.$name = $crate::state::get_tensor(
                $dict,
                $crate::state::prefixed($scope, stringify!($name)).as_str(),
            )?;
        )+
    };
}