oxidized_transformers/models/hf/
from_hf.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
use candle_core::{DType, Device};
use candle_nn::var_builder::SimpleBackend;
use candle_nn::VarBuilder;
use serde::{Deserialize, Serialize};
use snafu::{ResultExt, Snafu};

use crate::architectures::BuildArchitecture;
use crate::error::BoxedError;
use crate::util::renaming_backend::RenamingBackend;

/// Errors that can occur while constructing a model via [`FromHF::from_hf`].
#[derive(Debug, Snafu)]
pub enum FromHFError {
    /// Converting the Hugging Face configuration into the model
    /// configuration (`FromHF::Config`) failed.
    #[snafu(display("Cannot convert Hugging Face model config"))]
    ConvertConfig { source: BoxedError },

    /// Building the model from the converted configuration failed.
    #[snafu(display("Cannot build model"))]
    BuildModel { source: BoxedError },
}

/// Loading support for Hugging Face transformers checkpoints.
///
/// Implementors describe how an HF configuration maps onto a native model
/// configuration and how parameter names are translated; the provided
/// [`FromHF::from_hf`] method then wires these together.
pub trait FromHF {
    /// Native model configuration, convertible from [`FromHF::HFConfig`] and
    /// able to build [`FromHF::Model`].
    type Config: BuildArchitecture<Architecture = Self::Model>
        + TryFrom<Self::HFConfig, Error = BoxedError>;

    /// Configuration type as used by HF transformers.
    type HFConfig: Clone;

    /// The model type produced by loading.
    ///
    /// This is distinct from `Self`: `Self` is usually a unit struct that
    /// only carries the loading logic, whereas `Model` is the concrete
    /// model, e.g. `TransformerDecoder`.
    type Model;

    /// Build a model from an HF configuration and a parameter backend.
    ///
    /// * `hf_config` - The Hugging Face transformers model configuration,
    ///   including the dtype to load parameters with.
    /// * `backend` - The parameter store backend.
    /// * `device` - The device to place the model on.
    ///
    /// # Errors
    ///
    /// Returns [`FromHFError::ConvertConfig`] when the HF configuration
    /// cannot be converted and [`FromHFError::BuildModel`] when building
    /// the model fails.
    fn from_hf(
        hf_config: HFConfigWithDType<Self::HFConfig>,
        backend: Box<dyn SimpleBackend>,
        device: &Device,
    ) -> Result<Self::Model, FromHFError> {
        let dtype = hf_config.dtype();

        // The clone could be avoided with `TryFrom<&...>`, but that would
        // spread lifetime annotations throughout the implementations.
        let owned_hf_config = hf_config.config().clone();
        let model_config = <Self::Config as TryFrom<Self::HFConfig>>::try_from(owned_hf_config)
            .context(ConvertConfigSnafu)?;

        // Wrap the backend so that lookups go through the renaming function.
        let renaming = RenamingBackend::new(backend, Self::rename_parameters());
        let var_builder = VarBuilder::from_backend(Box::new(renaming), dtype, device.clone());

        model_config.build(var_builder).context(BuildModelSnafu)
    }

    /// Provide the parameter renaming function.
    ///
    /// The returned function should map an Oxidized Transformers parameter
    /// name to the corresponding Hugging Face transformers parameter name.
    fn rename_parameters() -> impl Fn(&str) -> String + Send + Sync;
}

/// Torch dtype as stored in a Hugging Face configuration.
///
/// Variant names are (de)serialized in lowercase, i.e. `"bfloat16"`,
/// `"float16"` and `"float32"`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "lowercase")]
enum TorchDType {
    /// Mapped to `DType::BF16` by `HFConfigWithDType::dtype`.
    BFloat16,
    /// Mapped to `DType::F16` by `HFConfigWithDType::dtype`.
    Float16,
    /// Mapped to `DType::F32` by `HFConfigWithDType::dtype`.
    Float32,
}

/// Simple wrapper for a HF config that exposes the dtype.
///
/// The wrapped configuration is flattened, so its fields are deserialized
/// from the same level as `torch_dtype` rather than from a nested object.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, Deserialize, Eq, PartialEq)]
pub struct HFConfigWithDType<T> {
    /// The model-specific configuration.
    #[serde(flatten)]
    config: T,
    /// The dtype read from the `torch_dtype` field.
    torch_dtype: TorchDType,
}

impl<T> HFConfigWithDType<T> {
    /// Get the configuration.
    pub fn config(&self) -> &T {
        &self.config
    }

    /// Get the dtype.
    pub fn dtype(&self) -> DType {
        match self.torch_dtype {
            TorchDType::BFloat16 => DType::BF16,
            TorchDType::Float16 => DType::F16,
            TorchDType::Float32 => DType::F32,
        }
    }
}