// oxidized_transformers/models/hf/from_hf.rs
use candle_core::{DType, Device};
use candle_nn::var_builder::SimpleBackend;
use candle_nn::VarBuilder;
use serde::{Deserialize, Serialize};
use snafu::{ResultExt, Snafu};
use crate::architectures::BuildArchitecture;
use crate::error::BoxedError;
use crate::util::renaming_backend::RenamingBackend;
/// Errors that can occur while building a model from a Hugging Face
/// configuration.
#[derive(Snafu, Debug)]
pub enum FromHFError {
    /// The Hugging Face configuration could not be converted into the
    /// configuration type used to build the model.
    #[snafu(display("Cannot convert Hugging Face model config"))]
    ConvertConfig {
        /// The underlying conversion error.
        source: BoxedError,
    },

    /// Building the model from the converted configuration failed.
    #[snafu(display("Cannot build model"))]
    BuildModel {
        /// The underlying build error.
        source: BoxedError,
    },
}
/// Models that can be constructed from a Hugging Face configuration and a
/// parameter backend.
pub trait FromHF {
    /// Configuration type from which the model is built.
    ///
    /// It must be convertible from the Hugging Face configuration and must
    /// build [`Self::Model`].
    type Config: BuildArchitecture<Architecture = Self::Model>
        + TryFrom<Self::HFConfig, Error = BoxedError>;

    /// Hugging Face configuration type.
    type HFConfig: Clone;

    /// Model type produced by [`FromHF::from_hf`].
    type Model;

    /// Build a model from a Hugging Face configuration.
    ///
    /// * `hf_config` - Hugging Face configuration together with the data
    ///   type to load the parameters as.
    /// * `backend` - Backend that provides the model parameters. It is
    ///   wrapped in a [`RenamingBackend`] that uses the function returned by
    ///   [`FromHF::rename_parameters`].
    /// * `device` - Device handed to the [`VarBuilder`].
    ///
    /// # Errors
    ///
    /// Returns [`FromHFError::ConvertConfig`] when the Hugging Face
    /// configuration cannot be converted to [`Self::Config`] and
    /// [`FromHFError::BuildModel`] when building the model fails.
    fn from_hf(
        hf_config: HFConfigWithDType<Self::HFConfig>,
        backend: Box<dyn SimpleBackend>,
        device: &Device,
    ) -> Result<Self::Model, FromHFError> {
        // Read the dtype first: the configuration is moved out of the
        // wrapper below, which avoids cloning it.
        let dtype = hf_config.dtype();
        let config = Self::Config::try_from(hf_config.config).context(ConvertConfigSnafu)?;

        let rename_backend = RenamingBackend::new(backend, Self::rename_parameters());
        let vb = VarBuilder::from_backend(Box::new(rename_backend), dtype, device.clone());

        config.build(vb).context(BuildModelSnafu)
    }

    /// Returns the parameter renaming function that is passed to
    /// [`RenamingBackend`].
    ///
    /// NOTE(review): presumably this maps the parameter names requested by
    /// the model to the names stored in the Hugging Face checkpoint; confirm
    /// against `RenamingBackend`.
    fn rename_parameters() -> impl Fn(&str) -> String + Send + Sync;
}
/// Data types accepted for the `torch_dtype` field of a Hugging Face
/// configuration.
///
/// Each variant is (de)serialized as the lowercase form of its name.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
enum TorchDType {
    /// Serialized as `bfloat16`.
    #[serde(rename = "bfloat16")]
    BFloat16,

    /// Serialized as `float16`.
    #[serde(rename = "float16")]
    Float16,

    /// Serialized as `float32`.
    #[serde(rename = "float32")]
    Float32,
}
/// A Hugging Face configuration paired with its `torch_dtype`.
///
/// The wrapped configuration is flattened during deserialization, so its
/// fields are read from the same level as `torch_dtype`.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq)]
pub struct HFConfigWithDType<T> {
    /// The model-specific configuration.
    #[serde(flatten)]
    config: T,

    /// Data type specified by the configuration.
    torch_dtype: TorchDType,
}
impl<T> HFConfigWithDType<T> {
    /// Get the data type of the configuration as a [`DType`].
    pub fn dtype(&self) -> DType {
        let Self { torch_dtype, .. } = self;
        match *torch_dtype {
            TorchDType::Float32 => DType::F32,
            TorchDType::Float16 => DType::F16,
            TorchDType::BFloat16 => DType::BF16,
        }
    }

    /// Get a reference to the wrapped configuration.
    pub fn config(&self) -> &T {
        let Self { config, .. } = self;
        config
    }
}