use std::path::Path;
use snafu::ResultExt;
use crate::state::{self, HasStateDict, StateDict};
use crate::gigaam::ctc::CTCHead;
use crate::gigaam::encoder::Encoder;
use crate::gigaam::error::{Error, HubSnafu, StateSnafu};
use crate::gigaam::rnnt::RnntHead;
use crate::gigaam::{GigaAmConfig, Result, remap};
use crate::sentencepiece;
/// A GigaAM model: an encoder followed by either a CTC head or an RN-T (transducer) head.
///
/// Built with `from_hub`, `from_hub_with_revision`, `from_dir`, `from_safetensors`,
/// `from_state_dict` or `with_random_weights`.
#[derive(Clone)]
pub struct GigaAm {
/// Configuration the model was built from; `config.transducer` selects the head kind.
pub config: GigaAmConfig,
/// Encoder network, shared by both head kinds.
pub encoder: Encoder,
/// Output head: CTC when the config has no transducer section, RN-T otherwise.
pub head: Head,
}
/// Output head of a `GigaAm` model.
///
/// The variant is chosen at load time: `Ctc` when `config.transducer` is `None`,
/// `Rnnt` when it is `Some`.
#[derive(Clone)]
pub enum Head {
/// CTC output head.
Ctc(CTCHead),
/// RN-T head together with the decoding settings (`RnntRuntime`) it is used with.
Rnnt { head: RnntHead, runtime: RnntRuntime },
}
/// Decoding settings stored next to an RN-T head, taken from the transducer config
/// (and, for the vocabulary, optionally from a `tokenizer.model` file).
#[derive(Clone)]
pub struct RnntRuntime {
/// Output token strings. By convention `num_classes == vocabulary.len() + 1`, the extra
/// class being a blank token at the end; `GigaAm::from_state_dict` rejects other lengths.
pub vocabulary: Vec<String>,
/// Copied from the transducer config's `max_symbols_per_step`.
/// NOTE(review): presumably caps how many symbols the decoder may emit per encoder
/// frame — confirm in the decoding code.
pub max_symbols_per_step: usize,
/// Copied from the transducer config's `sentencepiece` flag.
/// NOTE(review): presumably tells detokenization to treat tokens as SentencePiece
/// pieces — confirm in the decoding code.
pub sentencepiece: bool,
}
impl Head {
pub fn as_ctc(&self) -> Option<&CTCHead> {
if let Head::Ctc(h) = self { Some(h) } else { None }
}
pub fn as_rnnt(&self) -> Option<(&RnntHead, &RnntRuntime)> {
if let Head::Rnnt { head, runtime } = self { Some((head, runtime)) } else { None }
}
pub(crate) fn expect_ctc(&self, ctx: &str) -> Result<&CTCHead> {
self.as_ctc().ok_or_else(|| Error::DecoderConfig {
message: format!("{ctx} requires a CTC head; this model has an RN-T head"),
})
}
pub(crate) fn expect_rnnt(&self, ctx: &str) -> Result<(&RnntHead, &RnntRuntime)> {
self.as_rnnt().ok_or_else(|| Error::DecoderConfig {
message: format!("{ctx} requires an RN-T head; this model has a CTC head"),
})
}
}
impl GigaAm {
pub fn from_hub(model_id: &str) -> Result<Self> {
Self::from_hub_with_revision(model_id, "main")
}
pub fn from_hub_with_revision(model_id: &str, revision: &str) -> Result<Self> {
let api = hf_hub::api::sync::Api::new().context(HubSnafu)?;
let repo =
api.repo(hf_hub::Repo::with_revision(model_id.to_string(), hf_hub::RepoType::Model, revision.to_string()));
let config_path = repo.get("config.json").context(HubSnafu)?;
let weights_path = repo.get("model.safetensors").context(HubSnafu)?;
let config = GigaAmConfig::from_json(&config_path)?;
let tokenizer_path = if config.transducer.is_some() { repo.get("tokenizer.model").ok() } else { None };
Self::from_safetensors(&weights_path, tokenizer_path.as_deref(), config)
}
pub fn from_dir(dir: &Path) -> Result<Self> {
let config_path = dir.join("config.json");
let weights_path = dir.join("model.safetensors");
let config = GigaAmConfig::from_json(&config_path)?;
let tokenizer_path = dir.join("tokenizer.model");
let tokenizer_path =
if config.transducer.is_some() && tokenizer_path.exists() { Some(tokenizer_path) } else { None };
Self::from_safetensors(&weights_path, tokenizer_path.as_deref(), config)
}
pub fn from_safetensors(weights: &Path, tokenizer: Option<&Path>, config: GigaAmConfig) -> Result<Self> {
let sd = state::load_safetensors(weights).context(StateSnafu)?;
let vocab_override = tokenizer
.map(sentencepiece::load_vocab)
.transpose()
.map_err(|e| Error::DecoderConfig { message: e.to_string() })?;
Self::from_state_dict(&sd, config, vocab_override)
}
pub fn from_state_dict(sd: &StateDict, config: GigaAmConfig, vocab_override: Option<Vec<String>>) -> Result<Self> {
let is_pytorch = sd.keys().any(|k| {
k.starts_with("encoder.")
|| k.starts_with("model.encoder.")
|| k.starts_with("head.decoder.")
|| k.starts_with("head.joint.")
});
let sd_owned = if is_pytorch { remap::remap_pytorch(sd.clone(), &config)? } else { sd.clone() };
let sd = &sd_owned;
let encoder = Encoder::from_state_dict(sd, &config)?;
let head = match &config.transducer {
None => {
let mut h = CTCHead::empty(&config);
h.load_state_dict(sd, "head").context(StateSnafu)?;
Head::Ctc(h)
}
Some(tr) => {
let vocabulary = vocab_override.unwrap_or_else(|| tr.vocabulary.clone());
if vocabulary.len() + 1 != tr.num_classes {
return Err(Error::DecoderConfig {
message: format!(
"RN-T vocabulary length + 1 ({}) != num_classes ({}); \
convention is one blank token at the end",
vocabulary.len() + 1,
tr.num_classes
),
});
}
let mut h = RnntHead::empty(
config.d_model,
tr.pred_hidden,
tr.pred_rnn_layers,
tr.joint_hidden,
tr.num_classes,
);
h.load_state_dict(sd, "head").context(StateSnafu)?;
h.predictor.prepare_for_inference()?;
Head::Rnnt {
head: h,
runtime: RnntRuntime {
vocabulary,
max_symbols_per_step: tr.max_symbols_per_step,
sentencepiece: tr.sentencepiece,
},
}
}
};
Ok(Self { config, encoder, head })
}
pub fn with_random_weights(config: GigaAmConfig) -> Self {
let encoder = Encoder::with_random_weights(&config);
let head = match &config.transducer {
None => Head::Ctc(CTCHead::empty(&config)),
Some(tr) => Head::Rnnt {
head: RnntHead::empty(
config.d_model,
tr.pred_hidden,
tr.pred_rnn_layers,
tr.joint_hidden,
tr.num_classes,
),
runtime: RnntRuntime {
vocabulary: tr.vocabulary.clone(),
max_symbols_per_step: tr.max_symbols_per_step,
sentencepiece: tr.sentencepiece,
},
},
};
Self { config, encoder, head }
}
}