use std::collections::HashMap;
use ferrotorch_core::grad_fns::arithmetic::{add, mul};
use ferrotorch_core::{
FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage, numeric_cast,
};
use ferrotorch_nn::module::{Module, StateDict};
use ferrotorch_nn::parameter::Parameter;
use ferrotorch_nn::{
Embedding, GELU, GeluApproximate, LayerNorm, Linear, reshape_to_heads, standard_attention,
transpose_heads_to_2d,
};
/// Hyper-parameters of the CLIP text tower (the text encoder used by Stable Diffusion).
///
/// Field names mirror the JSON keys read by [`ClipTextConfig::from_json_str`]. The module
/// constructors in this file call [`ClipTextConfig::validate`] (directly or through a
/// sub-module constructor) before building anything.
#[derive(Debug, Clone)]
pub struct ClipTextConfig {
    /// Width of every hidden state: embedding dim, attention width and LayerNorm size.
    pub hidden_size: usize,
    /// Inner width of each MLP block (`fc1` output / `fc2` input).
    pub intermediate_size: usize,
    /// Number of attention heads; must divide `hidden_size`.
    pub num_attention_heads: usize,
    /// Number of stacked encoder layers.
    pub num_hidden_layers: usize,
    /// Rows of the position-embedding table, i.e. the longest accepted token sequence.
    pub max_position_embeddings: usize,
    /// Rows of the token-embedding table.
    pub vocab_size: usize,
    /// Epsilon handed to every `LayerNorm`; must be finite and > 0.
    pub layer_norm_eps: f64,
}
impl Default for ClipTextConfig {
fn default() -> Self {
Self::sd_v1_5()
}
}
impl ClipTextConfig {
    /// Text-tower configuration of the CLIP model used by Stable Diffusion v1.5.
    pub fn sd_v1_5() -> Self {
        Self {
            hidden_size: 768,
            intermediate_size: 3072,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            max_position_embeddings: 77,
            vocab_size: 49408,
            layer_norm_eps: 1e-5,
        }
    }

    /// Width of a single attention head (`hidden_size / num_attention_heads`).
    ///
    /// # Panics
    /// Divides by `num_attention_heads`, so it panics when that field is zero. Call
    /// [`Self::validate`] first for configs that did not come from this type's constructors.
    #[inline]
    #[must_use]
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }

    /// Checks the invariants every module constructor relies on.
    ///
    /// # Errors
    /// [`FerrotorchError::InvalidArgument`] when any size field is zero, when
    /// `hidden_size` is not a multiple of `num_attention_heads`, or when
    /// `layer_norm_eps` is not a finite positive number.
    pub fn validate(&self) -> FerrotorchResult<()> {
        if self.hidden_size == 0
            || self.intermediate_size == 0
            || self.num_attention_heads == 0
            || self.num_hidden_layers == 0
            || self.max_position_embeddings == 0
            || self.vocab_size == 0
        {
            return Err(FerrotorchError::InvalidArgument {
                message: "ClipTextConfig: all size fields must be > 0".into(),
            });
        }
        if self.hidden_size % self.num_attention_heads != 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ClipTextConfig: hidden_size {} not divisible by num_attention_heads {}",
                    self.hidden_size, self.num_attention_heads,
                ),
            });
        }
        if !self.layer_norm_eps.is_finite() || self.layer_norm_eps <= 0.0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ClipTextConfig: layer_norm_eps must be finite and > 0, got {}",
                    self.layer_norm_eps,
                ),
            });
        }
        Ok(())
    }

    /// Parses a HuggingFace-style `config.json`.
    ///
    /// Both a flat `CLIPTextConfig` document and a full `CLIPConfig` document are accepted.
    /// In a full `CLIPConfig`, the text-tower settings live under `"text_config"`.
    /// Missing or `null` keys keep their [`Self::sd_v1_5`] default. Keys this struct does
    /// not model (e.g. `hidden_act`) are ignored.
    ///
    /// # Errors
    /// [`FerrotorchError::InvalidArgument`] in any of these cases:
    /// - `s` is not valid JSON, or its top-level value is not a JSON object;
    /// - a known key holds a value of the wrong type (e.g. a negative or string
    ///   `hidden_size`);
    /// - the resulting config fails [`Self::validate`].
    pub fn from_json_str(s: &str) -> FerrotorchResult<Self> {
        let root: serde_json::Value =
            serde_json::from_str(s).map_err(|e| FerrotorchError::InvalidArgument {
                message: format!("ClipTextConfig::from_json_str: bad JSON: {e}"),
            })?;
        if !root.is_object() {
            return Err(FerrotorchError::InvalidArgument {
                message: "ClipTextConfig::from_json_str: top-level JSON value must be an object"
                    .into(),
            });
        }
        // A full CLIPConfig nests the text tower's settings under "text_config".
        let v = match root.get("text_config") {
            Some(nested) if nested.is_object() => nested,
            _ => &root,
        };

        // A present-but-malformed value is an error rather than a silent fallback to the
        // default: falling back would build a wrongly-shaped model for the checkpoint.
        let read_usize = |key: &str| -> FerrotorchResult<Option<usize>> {
            match v.get(key) {
                None | Some(serde_json::Value::Null) => Ok(None),
                Some(raw) => raw
                    .as_u64()
                    .and_then(|n| usize::try_from(n).ok())
                    .map(Some)
                    .ok_or_else(|| FerrotorchError::InvalidArgument {
                        message: format!(
                            "ClipTextConfig::from_json_str: {key:?} must be a non-negative \
                             integer, got {raw}"
                        ),
                    }),
            }
        };

        let mut cfg = Self::default();
        if let Some(n) = read_usize("hidden_size")? {
            cfg.hidden_size = n;
        }
        if let Some(n) = read_usize("intermediate_size")? {
            cfg.intermediate_size = n;
        }
        if let Some(n) = read_usize("num_attention_heads")? {
            cfg.num_attention_heads = n;
        }
        if let Some(n) = read_usize("num_hidden_layers")? {
            cfg.num_hidden_layers = n;
        }
        if let Some(n) = read_usize("max_position_embeddings")? {
            cfg.max_position_embeddings = n;
        }
        if let Some(n) = read_usize("vocab_size")? {
            cfg.vocab_size = n;
        }
        match v.get("layer_norm_eps") {
            None | Some(serde_json::Value::Null) => {}
            Some(raw) => {
                cfg.layer_norm_eps =
                    raw.as_f64()
                        .ok_or_else(|| FerrotorchError::InvalidArgument {
                            message: format!(
                                "ClipTextConfig::from_json_str: \"layer_norm_eps\" must be a \
                                 number, got {raw}"
                            ),
                        })?;
            }
        }
        cfg.validate()?;
        Ok(cfg)
    }

    /// Reads `path` and parses it with [`Self::from_json_str`].
    ///
    /// # Errors
    /// [`FerrotorchError::InvalidArgument`] when the file cannot be read as UTF-8 text, plus
    /// every error [`Self::from_json_str`] can return.
    pub fn from_file(path: &std::path::Path) -> FerrotorchResult<Self> {
        let s = std::fs::read_to_string(path).map_err(|e| FerrotorchError::InvalidArgument {
            message: format!(
                "ClipTextConfig::from_file: failed to read {}: {e}",
                path.display(),
            ),
        })?;
        Self::from_json_str(&s)
    }
}
/// Copies `t` into a fresh CPU tensor laid out with `shape`.
///
/// # Errors
/// `ShapeMismatch` when `shape` does not describe exactly `t.numel()` elements. Errors
/// from `data_vec` / `from_storage` are propagated.
///
/// NOTE(review): the result is rebuilt from a raw data copy, so it is a brand-new tensor.
/// Only the `requires_grad` flag is copied over. Gradients presumably do not flow back
/// through this reshape into `t` — confirm against ferrotorch's autograd before training
/// these modules. A graph-aware reshape op would fix it.
fn reshape_owned<T: Float>(t: &Tensor<T>, shape: Vec<usize>) -> FerrotorchResult<Tensor<T>> {
    let target_numel = shape.iter().product::<usize>();
    let source_numel = t.numel();
    if target_numel == source_numel {
        let buffer = t.data_vec()?;
        return Tensor::from_storage(TensorStorage::cpu(buffer), shape, t.requires_grad());
    }
    Err(FerrotorchError::ShapeMismatch {
        message: format!(
            "ClipTextEncoder reshape: target {shape:?} (= {target_numel} elements) does not \
             match source numel {source_numel}"
        ),
    })
}
/// Converts integer ids into a 1-D, non-grad tensor of `T`, the form `Embedding::forward`
/// is given throughout this file.
///
/// # Errors
/// Propagates the first `numeric_cast::cast` failure.
///
/// Note: an `f32` element type represents integers exactly only up to 2^24, which
/// comfortably covers CLIP's vocabulary and position ranges.
fn float_index_tensor<T: Float>(ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
    let mut values: Vec<T> = Vec::with_capacity(ids.len());
    for &id in ids {
        values.push(numeric_cast::cast::<u32, T>(id)?);
    }
    let len = values.len();
    Tensor::from_storage(TensorStorage::cpu(values), vec![len], false)
}
/// Input stage of the CLIP text tower: token embeddings plus learned absolute position
/// embeddings.
#[derive(Debug)]
pub struct ClipTextEmbeddings<T: Float> {
    /// Lookup table with `vocab_size` rows of width `hidden_size`, indexed by token id.
    pub token_embedding: Embedding<T>,
    /// Lookup table with `max_position_embeddings` rows of width `hidden_size`, indexed by
    /// position `0..S`.
    pub position_embedding: Embedding<T>,
    // Cached from the config; used to shape the `[1, S, hidden_size]` output.
    hidden_size: usize,
    // Longest sequence `forward_from_ids` accepts.
    max_position_embeddings: usize,
    // Train/eval flag reported by `is_training`.
    training: bool,
}
impl<T: Float> ClipTextEmbeddings<T> {
    /// Builds randomly-initialised token and position tables for `cfg`.
    ///
    /// # Errors
    /// Propagates `cfg.validate()` and `Embedding::new` failures.
    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
        cfg.validate()?;
        let token_embedding = Embedding::new(cfg.vocab_size, cfg.hidden_size, None)?;
        let position_embedding =
            Embedding::new(cfg.max_position_embeddings, cfg.hidden_size, None)?;
        Ok(Self {
            token_embedding,
            position_embedding,
            hidden_size: cfg.hidden_size,
            max_position_embeddings: cfg.max_position_embeddings,
            training: false,
        })
    }

    /// Embeds `input_ids` and adds the embeddings of positions `0..input_ids.len()`.
    ///
    /// Returns a `[1, seq_len, hidden_size]` tensor.
    ///
    /// # Errors
    /// `InvalidArgument` for an empty slice or one longer than `max_position_embeddings`.
    /// Embedding, add and reshape errors propagate.
    pub fn forward_from_ids(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
        let seq_len = input_ids.len();
        if seq_len == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "ClipTextEmbeddings::forward_from_ids needs at least one token".into(),
            });
        }
        if seq_len > self.max_position_embeddings {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ClipTextEmbeddings: sequence length {seq_len} exceeds \
                     max_position_embeddings {}",
                    self.max_position_embeddings,
                ),
            });
        }
        let positions: Vec<u32> = (0..seq_len as u32).collect();
        let token_vectors = self
            .token_embedding
            .forward(&float_index_tensor::<T>(input_ids)?)?;
        let position_vectors = self
            .position_embedding
            .forward(&float_index_tensor::<T>(&positions)?)?;
        let combined = add(&token_vectors, &position_vectors)?;
        reshape_owned(&combined, vec![1, seq_len, self.hidden_size])
    }
}
impl<T: Float> Module<T> for ClipTextEmbeddings<T> {
    /// Embeds `input` (token ids stored as `T`) and adds position embeddings for positions
    /// `0..input.numel()`, returning a `[1, seq_len, hidden_size]` tensor.
    ///
    /// # Errors
    /// `InvalidArgument` for an empty input or one with more than `max_position_embeddings`
    /// elements. These are the same checks as [`ClipTextEmbeddings::forward_from_ids`].
    /// Embedding, add and reshape errors propagate.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let seq_len = input.numel();
        if seq_len == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "ClipTextEmbeddings::forward needs at least one token".into(),
            });
        }
        // Validate up front so an over-long input gets a descriptive error. Otherwise it
        // would rely on the position table's own index checking.
        if seq_len > self.max_position_embeddings {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ClipTextEmbeddings: sequence length {seq_len} exceeds \
                     max_position_embeddings {}",
                    self.max_position_embeddings,
                ),
            });
        }
        let word_2d = self.token_embedding.forward(input)?;
        let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
        let pos_idx = float_index_tensor::<T>(&pos_ids)?;
        let pos_2d = self.position_embedding.forward(&pos_idx)?;
        let summed = add(&word_2d, &pos_2d)?;
        reshape_owned(&summed, vec![1, seq_len, self.hidden_size])
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.token_embedding.parameters());
        out.extend(self.position_embedding.parameters());
        out
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.token_embedding.parameters_mut());
        out.extend(self.position_embedding.parameters_mut());
        out
    }

    /// Names follow the HF layout: `token_embedding.*`, `position_embedding.*`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut out = Vec::new();
        for (n, p) in self.token_embedding.named_parameters() {
            out.push((format!("token_embedding.{n}"), p));
        }
        for (n, p) in self.position_embedding.named_parameters() {
            out.push((format!("position_embedding.{n}"), p));
        }
        out
    }

    /// Switches this module and both embedding tables to training mode. Propagation
    /// matches `ClipEncoder`/`ClipTextEncoder`.
    fn train(&mut self) {
        self.training = true;
        self.token_embedding.train();
        self.position_embedding.train();
    }

    /// Switches this module and both embedding tables to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
        self.token_embedding.eval();
        self.position_embedding.eval();
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn state_dict(&self) -> StateDict<T> {
        self.named_parameters()
            .into_iter()
            .map(|(n, p)| (n, p.tensor().clone()))
            .collect()
    }

    /// Loads `token_embedding.*` / `position_embedding.*` entries into the sub-tables.
    ///
    /// In strict mode, any key outside those two prefixes is rejected. Strictness is
    /// also forwarded to each child.
    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
        // Sub-dictionary of `state` with `prefix.` stripped from matching keys.
        let extract = |prefix: &str| -> StateDict<T> {
            let p = format!("{prefix}.");
            state
                .iter()
                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
                .collect()
        };
        if strict {
            let prefixes = ["token_embedding", "position_embedding"];
            for k in state.keys() {
                if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
                    return Err(FerrotorchError::InvalidArgument {
                        message: format!("unexpected key in ClipTextEmbeddings state_dict: {k:?}"),
                    });
                }
            }
        }
        self.token_embedding
            .load_state_dict(&extract("token_embedding"), strict)?;
        self.position_embedding
            .load_state_dict(&extract("position_embedding"), strict)?;
        Ok(())
    }
}
/// Multi-head self-attention block of a CLIP text encoder layer.
///
/// Field names (`q_proj`, `k_proj`, `v_proj`, `out_proj`) match the HF parameter layout.
#[derive(Debug)]
pub struct ClipSelfAttention<T: Float> {
    /// Query projection, `hidden -> hidden` with bias.
    pub q_proj: Linear<T>,
    /// Key projection, `hidden -> hidden` with bias.
    pub k_proj: Linear<T>,
    /// Value projection, `hidden -> hidden` with bias.
    pub v_proj: Linear<T>,
    /// Output projection applied after the heads are merged, `hidden -> hidden` with bias.
    pub out_proj: Linear<T>,
    // Number of heads the hidden width is split into.
    num_heads: usize,
    // Per-head width; `num_heads * head_dim == hidden`.
    head_dim: usize,
    // Model width, also the last dimension the forward pass requires.
    hidden: usize,
    // Train/eval flag reported by `is_training`.
    training: bool,
}
impl<T: Float> ClipSelfAttention<T> {
    /// Builds the four `hidden -> hidden` projections (all with bias) for `cfg`.
    ///
    /// # Errors
    /// Propagates `cfg.validate()` and `Linear::new` failures.
    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
        cfg.validate()?;
        let width = cfg.hidden_size;
        // Every projection has the same square shape; they are built in q, k, v, out order.
        let projection = || Linear::<T>::new(width, width, true);
        Ok(Self {
            q_proj: projection()?,
            k_proj: projection()?,
            v_proj: projection()?,
            out_proj: projection()?,
            num_heads: cfg.num_attention_heads,
            head_dim: cfg.head_dim(),
            hidden: width,
            training: false,
        })
    }
}
impl<T: Float> Module<T> for ClipSelfAttention<T> {
    /// Multi-head self-attention over a `[1, S, hidden]` input, returning `[1, S, hidden]`.
    ///
    /// The Q/K/V projections are split into `num_heads` heads of `head_dim` each and
    /// combined by `standard_attention`. That call's final flag is `true`, presumably the
    /// causal mask — TODO confirm against ferrotorch_nn. The heads are then merged back and
    /// projected by `out_proj`.
    ///
    /// # Errors
    /// `ShapeMismatch` unless the input has rank 3, batch size 1 and last dimension `hidden`.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let shape = input.shape();
        if shape.len() != 3 || shape[0] != 1 || shape[2] != self.hidden {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "ClipSelfAttention expects [1, S, {}], got {:?}",
                    self.hidden, shape,
                ),
            });
        }
        let seq_len = shape[1];
        let q = self.q_proj.forward(input)?;
        let k = self.k_proj.forward(input)?;
        let v = self.v_proj.forward(input)?;
        // Drop the batch axis: the head helpers work on `[S, hidden]`.
        let q2 = reshape_owned(&q, vec![seq_len, self.hidden])?;
        let k2 = reshape_owned(&k, vec![seq_len, self.hidden])?;
        let v2 = reshape_owned(&v, vec![seq_len, self.hidden])?;
        let q_h = reshape_to_heads(&q2, self.num_heads, seq_len, self.head_dim)?;
        let k_h = reshape_to_heads(&k2, self.num_heads, seq_len, self.head_dim)?;
        let v_h = reshape_to_heads(&v2, self.num_heads, seq_len, self.head_dim)?;
        let ctx = standard_attention(&q_h, &k_h, &v_h, true)?;
        let ctx2 = transpose_heads_to_2d(&ctx, self.num_heads, seq_len, self.head_dim)?;
        let ctx3 = reshape_owned(&ctx2, vec![1, seq_len, self.hidden])?;
        self.out_proj.forward(&ctx3)
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.q_proj.parameters());
        out.extend(self.k_proj.parameters());
        out.extend(self.v_proj.parameters());
        out.extend(self.out_proj.parameters());
        out
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.q_proj.parameters_mut());
        out.extend(self.k_proj.parameters_mut());
        out.extend(self.v_proj.parameters_mut());
        out.extend(self.out_proj.parameters_mut());
        out
    }

    /// Names follow the HF layout: `{q,k,v,out}_proj.{weight,bias}`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut out = Vec::new();
        for (n, p) in self.q_proj.named_parameters() {
            out.push((format!("q_proj.{n}"), p));
        }
        for (n, p) in self.k_proj.named_parameters() {
            out.push((format!("k_proj.{n}"), p));
        }
        for (n, p) in self.v_proj.named_parameters() {
            out.push((format!("v_proj.{n}"), p));
        }
        for (n, p) in self.out_proj.named_parameters() {
            out.push((format!("out_proj.{n}"), p));
        }
        out
    }

    /// Switches this block and all four projections to training mode.
    fn train(&mut self) {
        self.training = true;
        self.q_proj.train();
        self.k_proj.train();
        self.v_proj.train();
        self.out_proj.train();
    }

    /// Switches this block and all four projections to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
        self.q_proj.eval();
        self.k_proj.eval();
        self.v_proj.eval();
        self.out_proj.eval();
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn state_dict(&self) -> StateDict<T> {
        self.named_parameters()
            .into_iter()
            .map(|(n, p)| (n, p.tensor().clone()))
            .collect()
    }

    /// Loads `{q,k,v,out}_proj.*` entries. In strict mode, other keys are rejected and
    /// strictness is forwarded to each projection.
    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
        // Sub-dictionary of `state` with `prefix.` stripped from matching keys.
        let extract = |prefix: &str| -> StateDict<T> {
            let p = format!("{prefix}.");
            state
                .iter()
                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
                .collect()
        };
        if strict {
            let prefixes = ["q_proj", "k_proj", "v_proj", "out_proj"];
            for k in state.keys() {
                if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
                    return Err(FerrotorchError::InvalidArgument {
                        message: format!("unexpected key in ClipSelfAttention state_dict: {k:?}"),
                    });
                }
            }
        }
        self.q_proj.load_state_dict(&extract("q_proj"), strict)?;
        self.k_proj.load_state_dict(&extract("k_proj"), strict)?;
        self.v_proj.load_state_dict(&extract("v_proj"), strict)?;
        self.out_proj
            .load_state_dict(&extract("out_proj"), strict)?;
        Ok(())
    }
}
/// Feed-forward block of a CLIP text encoder layer: `fc2(activation(fc1(x)))`.
#[derive(Debug)]
pub struct ClipMlp<T: Float> {
    /// `hidden -> intermediate` projection with bias.
    pub fc1: Linear<T>,
    /// `intermediate -> hidden` projection with bias.
    pub fc2: Linear<T>,
    // GELU with the sigmoid approximation (CLIP's "quick_gelu").
    activation: GELU,
    // Train/eval flag reported by `is_training`.
    training: bool,
}
impl<T: Float> ClipMlp<T> {
    /// Builds `fc1: hidden -> intermediate`, `fc2: intermediate -> hidden` (both with
    /// bias) and the sigmoid-approximated GELU activation.
    ///
    /// # Errors
    /// Propagates `cfg.validate()` and `Linear::new` failures.
    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
        cfg.validate()?;
        let (hidden, inner) = (cfg.hidden_size, cfg.intermediate_size);
        let fc1 = Linear::new(hidden, inner, true)?;
        let fc2 = Linear::new(inner, hidden, true)?;
        Ok(Self {
            fc1,
            fc2,
            activation: GELU::with_approximate(GeluApproximate::Sigmoid),
            training: false,
        })
    }
}
impl<T: Float> Module<T> for ClipMlp<T> {
    /// Computes `fc2(activation(fc1(input)))`.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let h = self.fc1.forward(input)?;
        let h = self.activation.forward(&h)?;
        self.fc2.forward(&h)
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.fc1.parameters());
        out.extend(self.fc2.parameters());
        out
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.fc1.parameters_mut());
        out.extend(self.fc2.parameters_mut());
        out
    }

    /// Names follow the HF layout: `fc1.*`, `fc2.*`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut out = Vec::new();
        for (n, p) in self.fc1.named_parameters() {
            out.push((format!("fc1.{n}"), p));
        }
        for (n, p) in self.fc2.named_parameters() {
            out.push((format!("fc2.{n}"), p));
        }
        out
    }

    /// Switches this block and both projections to training mode. The parameter-free
    /// activation is left alone.
    fn train(&mut self) {
        self.training = true;
        self.fc1.train();
        self.fc2.train();
    }

    /// Switches this block and both projections to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
        self.fc1.eval();
        self.fc2.eval();
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn state_dict(&self) -> StateDict<T> {
        self.named_parameters()
            .into_iter()
            .map(|(n, p)| (n, p.tensor().clone()))
            .collect()
    }

    /// Loads `fc1.*` / `fc2.*` entries. In strict mode, other keys are rejected and
    /// strictness is forwarded to each projection.
    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
        // Sub-dictionary of `state` with `prefix.` stripped from matching keys.
        let extract = |prefix: &str| -> StateDict<T> {
            let p = format!("{prefix}.");
            state
                .iter()
                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
                .collect()
        };
        if strict {
            for k in state.keys() {
                if !(k.starts_with("fc1.") || k.starts_with("fc2.")) {
                    return Err(FerrotorchError::InvalidArgument {
                        message: format!("unexpected key in ClipMlp state_dict: {k:?}"),
                    });
                }
            }
        }
        self.fc1.load_state_dict(&extract("fc1"), strict)?;
        self.fc2.load_state_dict(&extract("fc2"), strict)?;
        Ok(())
    }
}
/// One pre-LayerNorm transformer layer of the CLIP text encoder:
/// `h = x + self_attn(layer_norm1(x))`, then `h + mlp(layer_norm2(h))`.
#[derive(Debug)]
pub struct ClipEncoderLayer<T: Float> {
    /// LayerNorm applied before self-attention.
    pub layer_norm1: LayerNorm<T>,
    /// Multi-head self-attention sub-block.
    pub self_attn: ClipSelfAttention<T>,
    /// LayerNorm applied before the MLP.
    pub layer_norm2: LayerNorm<T>,
    /// Feed-forward sub-block.
    pub mlp: ClipMlp<T>,
    // Train/eval flag reported by `is_training`.
    training: bool,
}
impl<T: Float> ClipEncoderLayer<T> {
    /// Builds both LayerNorms (width `hidden_size`, eps `layer_norm_eps`, affine) plus the
    /// attention and MLP sub-blocks for `cfg`.
    ///
    /// # Errors
    /// Propagates `cfg.validate()` and every sub-module construction failure.
    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
        // Validate before touching `LayerNorm::new`, consistent with the other
        // constructors, so an invalid eps is reported as a config error.
        cfg.validate()?;
        Ok(Self {
            layer_norm1: LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?,
            self_attn: ClipSelfAttention::new(cfg)?,
            layer_norm2: LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?,
            mlp: ClipMlp::new(cfg)?,
            training: false,
        })
    }
}
impl<T: Float> Module<T> for ClipEncoderLayer<T> {
    /// Pre-LN residual layer: `h = x + attn(ln1(x))`, returns `h + mlp(ln2(h))`.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let normed = self.layer_norm1.forward(input)?;
        let attn_out = self.self_attn.forward(&normed)?;
        let after_attn = add(input, &attn_out)?;
        let normed_ffn = self.layer_norm2.forward(&after_attn)?;
        let mlp_out = self.mlp.forward(&normed_ffn)?;
        add(&after_attn, &mlp_out)
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.layer_norm1.parameters());
        out.extend(self.self_attn.parameters());
        out.extend(self.layer_norm2.parameters());
        out.extend(self.mlp.parameters());
        out
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.layer_norm1.parameters_mut());
        out.extend(self.self_attn.parameters_mut());
        out.extend(self.layer_norm2.parameters_mut());
        out.extend(self.mlp.parameters_mut());
        out
    }

    /// Names follow the HF layout: `layer_norm1.*`, `self_attn.*`, `layer_norm2.*`, `mlp.*`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut out = Vec::new();
        for (n, p) in self.layer_norm1.named_parameters() {
            out.push((format!("layer_norm1.{n}"), p));
        }
        for (n, p) in self.self_attn.named_parameters() {
            out.push((format!("self_attn.{n}"), p));
        }
        for (n, p) in self.layer_norm2.named_parameters() {
            out.push((format!("layer_norm2.{n}"), p));
        }
        for (n, p) in self.mlp.named_parameters() {
            out.push((format!("mlp.{n}"), p));
        }
        out
    }

    /// Switches this layer and every sub-module to training mode.
    ///
    /// Without this propagation, `ClipEncoder::train()` (which forwards to each layer)
    /// would leave the attention, MLP and LayerNorm flags stuck in eval mode.
    fn train(&mut self) {
        self.training = true;
        self.layer_norm1.train();
        self.self_attn.train();
        self.layer_norm2.train();
        self.mlp.train();
    }

    /// Switches this layer and every sub-module to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
        self.layer_norm1.eval();
        self.self_attn.eval();
        self.layer_norm2.eval();
        self.mlp.eval();
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn state_dict(&self) -> StateDict<T> {
        self.named_parameters()
            .into_iter()
            .map(|(n, p)| (n, p.tensor().clone()))
            .collect()
    }

    /// Routes prefixed entries to each sub-module. In strict mode, keys outside the four
    /// known prefixes are rejected and strictness is forwarded.
    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
        // Sub-dictionary of `state` with `prefix.` stripped from matching keys.
        let extract = |prefix: &str| -> StateDict<T> {
            let p = format!("{prefix}.");
            state
                .iter()
                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
                .collect()
        };
        if strict {
            let prefixes = ["layer_norm1", "self_attn", "layer_norm2", "mlp"];
            for k in state.keys() {
                if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
                    return Err(FerrotorchError::InvalidArgument {
                        message: format!("unexpected key in ClipEncoderLayer state_dict: {k:?}"),
                    });
                }
            }
        }
        self.layer_norm1
            .load_state_dict(&extract("layer_norm1"), strict)?;
        self.self_attn
            .load_state_dict(&extract("self_attn"), strict)?;
        self.layer_norm2
            .load_state_dict(&extract("layer_norm2"), strict)?;
        self.mlp.load_state_dict(&extract("mlp"), strict)?;
        Ok(())
    }
}
/// Stack of `num_hidden_layers` [`ClipEncoderLayer`]s applied in order.
#[derive(Debug)]
pub struct ClipEncoder<T: Float> {
    /// Encoder layers; parameter names are `layers.{index}.*`.
    pub layers: Vec<ClipEncoderLayer<T>>,
    // Train/eval flag reported by `is_training`.
    training: bool,
}
impl<T: Float> ClipEncoder<T> {
    /// Builds `cfg.num_hidden_layers` freshly-initialised encoder layers.
    ///
    /// # Errors
    /// Propagates `cfg.validate()` and the first layer-construction failure.
    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
        cfg.validate()?;
        let layers = (0..cfg.num_hidden_layers)
            .map(|_| ClipEncoderLayer::new(cfg))
            .collect::<FerrotorchResult<Vec<_>>>()?;
        Ok(Self {
            layers,
            training: false,
        })
    }
}
impl<T: Float> Module<T> for ClipEncoder<T> {
    /// Feeds `input` through every layer in order and returns the last layer's output.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let mut h = input.clone();
        for l in &self.layers {
            h = l.forward(&h)?;
        }
        Ok(h)
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut out = Vec::new();
        for l in &self.layers {
            out.extend(l.parameters());
        }
        out
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut out = Vec::new();
        for l in &mut self.layers {
            out.extend(l.parameters_mut());
        }
        out
    }

    /// Names follow the HF layout: `layers.{i}.*`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut out = Vec::new();
        for (i, l) in self.layers.iter().enumerate() {
            for (n, p) in l.named_parameters() {
                out.push((format!("layers.{i}.{n}"), p));
            }
        }
        out
    }

    fn train(&mut self) {
        self.training = true;
        for l in &mut self.layers {
            l.train();
        }
    }

    fn eval(&mut self) {
        self.training = false;
        for l in &mut self.layers {
            l.eval();
        }
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn state_dict(&self) -> StateDict<T> {
        self.named_parameters()
            .into_iter()
            .map(|(n, p)| (n, p.tensor().clone()))
            .collect()
    }

    /// Routes `layers.{i}.*` entries to layer `i`.
    ///
    /// In strict mode, every key must address an existing layer. A key such as
    /// `layers.99.*` on a 12-layer encoder (or a malformed `layers.x` key) is rejected
    /// rather than silently ignored.
    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
        // Sub-dictionary of `state` with `prefix.` stripped from matching keys.
        let extract = |prefix: &str| -> StateDict<T> {
            let p = format!("{prefix}.");
            state
                .iter()
                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
                .collect()
        };
        if strict {
            // Exactly the prefixes that `extract` below will consume.
            let valid_prefixes: Vec<String> = (0..self.layers.len())
                .map(|i| format!("layers.{i}."))
                .collect();
            for k in state.keys() {
                if !valid_prefixes.iter().any(|p| k.starts_with(p.as_str())) {
                    return Err(FerrotorchError::InvalidArgument {
                        message: format!("unexpected key in ClipEncoder state_dict: {k:?}"),
                    });
                }
            }
        }
        for (i, l) in self.layers.iter_mut().enumerate() {
            l.load_state_dict(&extract(&format!("layers.{i}")), strict)?;
        }
        Ok(())
    }
}
/// Complete CLIP text tower: embeddings, then the encoder stack, then a final LayerNorm.
///
/// Parameter names (`embeddings.*`, `encoder.layers.{i}.*`, `final_layer_norm.*`) follow
/// the HuggingFace CLIP text-model layout without its `text_model.` prefix.
#[derive(Debug)]
pub struct ClipTextEncoder<T: Float> {
    /// Token + position embedding stage.
    pub embeddings: ClipTextEmbeddings<T>,
    /// Stack of transformer layers.
    pub encoder: ClipEncoder<T>,
    /// LayerNorm applied to the last layer's output.
    pub final_layer_norm: LayerNorm<T>,
    /// The configuration passed to `new` (validated there).
    pub config: ClipTextConfig,
    // Train/eval flag reported by `is_training`.
    training: bool,
}
impl<T: Float> ClipTextEncoder<T> {
    /// Builds a randomly-initialised text encoder for `cfg`.
    ///
    /// # Errors
    /// Propagates `cfg.validate()` and any sub-module construction failure.
    pub fn new(cfg: ClipTextConfig) -> FerrotorchResult<Self> {
        cfg.validate()?;
        let embeddings = ClipTextEmbeddings::new(&cfg)?;
        let encoder = ClipEncoder::new(&cfg)?;
        let final_layer_norm = LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?;
        Ok(Self {
            embeddings,
            encoder,
            final_layer_norm,
            config: cfg,
            training: false,
        })
    }

    /// Runs the full text tower on a sequence of token ids.
    ///
    /// Returns the final LayerNorm'd hidden states, shaped `[1, input_ids.len(), hidden]`.
    pub fn forward_from_ids(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
        let h = self.embeddings.forward_from_ids(input_ids)?;
        let h = self.encoder.forward(&h)?;
        self.final_layer_norm.forward(&h)
    }

    /// Same as [`Self::forward_from_ids`], but takes the ids as a tensor of `T`.
    ///
    /// Accepts a 1-D `[S]` tensor or a single-row batch `[1, S]`, the layout HuggingFace
    /// tokenizers produce for one prompt. Every element must be a finite, non-negative
    /// integer that fits in `u32`.
    ///
    /// # Errors
    /// `ShapeMismatch` for any other shape. `InvalidArgument` for an element that is not
    /// a valid id. Otherwise, whatever [`Self::forward_from_ids`] returns.
    pub fn forward_from_id_tensor(&self, ids: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let single_row = match ids.ndim() {
            1 => true,
            2 => ids.shape()[0] == 1,
            _ => false,
        };
        if !single_row {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "ClipTextEncoder::forward_from_id_tensor expects 1-D ids or a [1, S] \
                     batch, got {:?}",
                    ids.shape()
                ),
            });
        }
        let data = ids.data_vec()?;
        let mut u32_ids: Vec<u32> = Vec::with_capacity(data.len());
        for (i, v) in data.iter().enumerate() {
            let f = num_traits::ToPrimitive::to_f64(v).ok_or_else(|| {
                FerrotorchError::InvalidArgument {
                    message: format!(
                        "ClipTextEncoder::forward_from_id_tensor: id at {i} \
                         not representable as f64"
                    ),
                }
            })?;
            if !f.is_finite() || f < 0.0 || f > f64::from(u32::MAX) || f.fract() != 0.0 {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "ClipTextEncoder::forward_from_id_tensor: id at {i} ({f}) \
                         is not a non-negative integer"
                    ),
                });
            }
            // Exact: `f` is a non-negative integer within u32 range (checked above).
            u32_ids.push(f as u32);
        }
        self.forward_from_ids(&u32_ids)
    }

    /// Loads a HuggingFace CLIP text-model state dict.
    ///
    /// An optional `text_model.` prefix is stripped from every key. The non-parameter
    /// `embeddings.position_ids` buffer is always dropped. With `strict == false`, other
    /// unknown keys (vision tower, projections, ...) are dropped too; with
    /// `strict == true`, they are an error. The returned report lists every dropped
    /// original key, sorted.
    ///
    /// # Errors
    /// `InvalidArgument` in any of these cases:
    /// - an unknown key is found in strict mode;
    /// - two keys map to the same parameter (e.g. both `text_model.X` and `X` are present),
    ///   which would otherwise be resolved by HashMap iteration order;
    /// - the remapped dict fails `load_state_dict`.
    pub fn load_hf_state_dict(
        &mut self,
        hf_state: &StateDict<T>,
        strict: bool,
    ) -> FerrotorchResult<crate::safetensors_loader::DropReport> {
        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
        let mut dropped: Vec<String> = Vec::new();
        for (k, v) in hf_state {
            let after = k
                .strip_prefix("text_model.")
                .map_or_else(|| k.clone(), str::to_owned);
            if after == "embeddings.position_ids" {
                dropped.push(k.clone());
                continue;
            }
            let is_known = after.starts_with("embeddings.token_embedding.")
                || after.starts_with("embeddings.position_embedding.")
                || after.starts_with("encoder.")
                || after.starts_with("final_layer_norm.");
            if is_known {
                if remapped.contains_key(&after) {
                    return Err(FerrotorchError::InvalidArgument {
                        message: format!(
                            "ClipTextEncoder::load_hf_state_dict: key {k:?} maps to \
                             {after:?}, which another key already provides (both the \
                             prefixed and unprefixed forms are present)"
                        ),
                    });
                }
                remapped.insert(after, v.clone());
                continue;
            }
            if strict {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "ClipTextEncoder::load_hf_state_dict: key {k:?} is not a \
                         known CLIP text-tower parameter and strict mode is on. \
                         Pass strict=false to drop unknown keys."
                    ),
                });
            }
            dropped.push(k.clone());
        }
        dropped.sort();
        self.load_state_dict(&remapped, strict)?;
        Ok(crate::safetensors_loader::DropReport { dropped })
    }
}
impl<T: Float> Module<T> for ClipTextEncoder<T> {
    /// Runs the text tower.
    ///
    /// - Rank-3 `[1, S, hidden]` input is treated as already-embedded hidden states and
    ///   goes through the encoder stack and final LayerNorm.
    /// - Rank-1 `[S]` or rank-2 `[1, S]` input is treated as token ids and routed through
    ///   [`ClipTextEncoder::forward_from_id_tensor`] (embeddings included), like HF
    ///   `CLIPTextModel.forward(input_ids)`.
    ///
    /// Such low-rank inputs always failed before: every encoder layer's attention rejects
    /// non-rank-3 tensors.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        if input.ndim() <= 2 {
            return self.forward_from_id_tensor(input);
        }
        let h = self.encoder.forward(input)?;
        self.final_layer_norm.forward(&h)
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.embeddings.parameters());
        out.extend(self.encoder.parameters());
        out.extend(self.final_layer_norm.parameters());
        out
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.embeddings.parameters_mut());
        out.extend(self.encoder.parameters_mut());
        out.extend(self.final_layer_norm.parameters_mut());
        out
    }

    /// Names: `embeddings.*`, `encoder.layers.{i}.*`, `final_layer_norm.*`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut out = Vec::new();
        for (n, p) in self.embeddings.named_parameters() {
            out.push((format!("embeddings.{n}"), p));
        }
        for (n, p) in self.encoder.named_parameters() {
            out.push((format!("encoder.{n}"), p));
        }
        for (n, p) in self.final_layer_norm.named_parameters() {
            out.push((format!("final_layer_norm.{n}"), p));
        }
        out
    }

    fn train(&mut self) {
        self.training = true;
        self.embeddings.train();
        self.encoder.train();
        self.final_layer_norm.train();
    }

    fn eval(&mut self) {
        self.training = false;
        self.embeddings.eval();
        self.encoder.eval();
        self.final_layer_norm.eval();
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn state_dict(&self) -> StateDict<T> {
        self.named_parameters()
            .into_iter()
            .map(|(n, p)| (n, p.tensor().clone()))
            .collect()
    }

    /// Routes `embeddings.*`, `encoder.*` and `final_layer_norm.*` entries to the matching
    /// sub-module. In strict mode, any other key is rejected and strictness is forwarded.
    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
        // Sub-dictionary of `state` with `prefix.` stripped from matching keys.
        let extract = |prefix: &str| -> StateDict<T> {
            let p = format!("{prefix}.");
            state
                .iter()
                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
                .collect()
        };
        if strict {
            for k in state.keys() {
                if !(k.starts_with("embeddings.")
                    || k.starts_with("encoder.")
                    || k.starts_with("final_layer_norm."))
                {
                    return Err(FerrotorchError::InvalidArgument {
                        message: format!("unexpected key in ClipTextEncoder state_dict: {k:?}"),
                    });
                }
            }
        }
        self.embeddings
            .load_state_dict(&extract("embeddings"), strict)?;
        self.encoder.load_state_dict(&extract("encoder"), strict)?;
        self.final_layer_norm
            .load_state_dict(&extract("final_layer_norm"), strict)?;
        Ok(())
    }
}
// NOTE(review): nothing in the visible source calls this function, hence the
// `#[allow(dead_code)]`. Its apparent purpose is to keep the `mul` import from
// `ferrotorch_core::grad_fns::arithmetic` referenced, so that import does not trigger an
// unused-import warning. If `mul` is genuinely unneeded, removing it from the import and
// deleting this helper would be cleaner.
#[allow(dead_code)]
fn _unused_mul_ref<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    mul(a, b)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A deliberately tiny configuration (hidden 8, 2 heads, 1 layer, 6 positions) so each
    /// test builds and runs a full model quickly.
    fn tiny_cfg() -> ClipTextConfig {
        ClipTextConfig {
            hidden_size: 8,
            intermediate_size: 16,
            num_attention_heads: 2,
            num_hidden_layers: 1,
            max_position_embeddings: 6,
            vocab_size: 32,
            layer_norm_eps: 1e-5,
        }
    }

    #[test]
    fn sd_v1_5_config_is_canonical() {
        let c = ClipTextConfig::sd_v1_5();
        assert_eq!(c.hidden_size, 768);
        assert_eq!(c.intermediate_size, 3072);
        assert_eq!(c.num_attention_heads, 12);
        assert_eq!(c.num_hidden_layers, 12);
        assert_eq!(c.max_position_embeddings, 77);
        assert_eq!(c.vocab_size, 49408);
        assert_eq!(c.head_dim(), 64);
        c.validate().unwrap();
    }

    #[test]
    fn validate_catches_bad_head_count() {
        let mut c = tiny_cfg();
        // hidden_size 8 is not divisible by 3 heads.
        c.num_attention_heads = 3;
        assert!(c.validate().is_err());
    }

    #[test]
    fn from_json_str_round_trip() {
        let json = r#"{
            "hidden_size": 768,
            "intermediate_size": 3072,
            "num_attention_heads": 12,
            "num_hidden_layers": 12,
            "max_position_embeddings": 77,
            "vocab_size": 49408,
            "layer_norm_eps": 1e-5,
            "hidden_act": "quick_gelu"
        }"#;
        let c = ClipTextConfig::from_json_str(json).unwrap();
        assert_eq!(c.hidden_size, 768);
        assert_eq!(c.intermediate_size, 3072);
        assert_eq!(c.num_attention_heads, 12);
        assert_eq!(c.num_hidden_layers, 12);
        assert_eq!(c.max_position_embeddings, 77);
    }

    #[test]
    fn embeddings_forward_shape() {
        let emb = ClipTextEmbeddings::<f32>::new(&tiny_cfg()).unwrap();
        let ids = [1u32, 5, 7, 9];
        let out = emb.forward_from_ids(&ids).unwrap();
        assert_eq!(out.shape(), &[1, 4, 8]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite(), "embedding non-finite: {v}");
        }
    }

    #[test]
    fn embeddings_reject_too_long_sequence() {
        let emb = ClipTextEmbeddings::<f32>::new(&tiny_cfg()).unwrap();
        // tiny_cfg allows at most 6 positions.
        let ids: Vec<u32> = (0..7).collect();
        assert!(emb.forward_from_ids(&ids).is_err());
    }

    #[test]
    fn self_attention_forward_shape() {
        let attn = ClipSelfAttention::<f32>::new(&tiny_cfg()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.1f32; 5 * 8]),
            vec![1, 5, 8],
            false,
        )
        .unwrap();
        let out = attn.forward(&x).unwrap();
        assert_eq!(out.shape(), &[1, 5, 8]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    /// Two inputs that share their first two rows but differ afterwards must produce
    /// identical outputs for those first two rows if attention is causal.
    #[test]
    fn self_attention_is_actually_causal() {
        let attn = ClipSelfAttention::<f32>::new(&tiny_cfg()).unwrap();
        let mut a = vec![0.1f32; 4 * 8];
        for i in 0..2 * 8 {
            a[i] = ((i + 1) as f32).sin();
        }
        let mut b = a.clone();
        for i in (2 * 8)..(4 * 8) {
            b[i] = ((i + 11) as f32).sin();
        }
        let xa = Tensor::from_storage(TensorStorage::cpu(a), vec![1, 4, 8], false).unwrap();
        let xb = Tensor::from_storage(TensorStorage::cpu(b), vec![1, 4, 8], false).unwrap();
        let oa = attn.forward(&xa).unwrap();
        let ob = attn.forward(&xb).unwrap();
        let da = oa.data().unwrap();
        let db = ob.data().unwrap();
        for i in 0..2 * 8 {
            assert!(
                (da[i] - db[i]).abs() < 1e-5,
                "row {} ({}) differs between runs: {} vs {}",
                i / 8,
                i % 8,
                da[i],
                db[i]
            );
        }
    }

    #[test]
    fn mlp_uses_quick_gelu() {
        let mlp = ClipMlp::<f32>::new(&tiny_cfg()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.0f32; 3 * 8]),
            vec![1, 3, 8],
            false,
        )
        .unwrap();
        let out = mlp.forward(&x).unwrap();
        assert_eq!(out.shape(), &[1, 3, 8]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    /// Pins the activation numerically: quick_gelu(x) = x * sigmoid(1.702 * x).
    /// At x = 1 this is ~0.8458, while exact GELU gives ~0.8413, a gap far larger than
    /// the tolerance used here.
    #[test]
    fn mlp_activation_matches_quick_gelu_formula() {
        let mlp = ClipMlp::<f32>::new(&tiny_cfg()).unwrap();
        let inputs = [1.0f32, -1.0, 2.5];
        let x = Tensor::from_storage(
            TensorStorage::cpu(inputs.to_vec()),
            vec![inputs.len()],
            false,
        )
        .unwrap();
        let y = mlp.activation.forward(&x).unwrap();
        for (&xi, &yi) in inputs.iter().zip(y.data().unwrap().iter()) {
            let expected = xi / (1.0 + (-1.702 * xi).exp());
            assert!(
                (yi - expected).abs() < 1e-4,
                "quick_gelu({xi}) = {yi}, expected {expected}"
            );
        }
    }

    #[test]
    fn encoder_layer_forward_shape() {
        let layer = ClipEncoderLayer::<f32>::new(&tiny_cfg()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.1f32; 5 * 8]),
            vec![1, 5, 8],
            false,
        )
        .unwrap();
        let out = layer.forward(&x).unwrap();
        assert_eq!(out.shape(), &[1, 5, 8]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn encoder_layer_named_parameters_use_hf_layout() {
        let layer = ClipEncoderLayer::<f32>::new(&tiny_cfg()).unwrap();
        let names: Vec<String> = layer
            .named_parameters()
            .into_iter()
            .map(|(n, _)| n)
            .collect();
        for k in [
            "layer_norm1.weight",
            "layer_norm1.bias",
            "self_attn.q_proj.weight",
            "self_attn.q_proj.bias",
            "self_attn.k_proj.weight",
            "self_attn.v_proj.weight",
            "self_attn.out_proj.weight",
            "self_attn.out_proj.bias",
            "layer_norm2.weight",
            "mlp.fc1.weight",
            "mlp.fc1.bias",
            "mlp.fc2.weight",
            "mlp.fc2.bias",
        ] {
            assert!(
                names.iter().any(|n| n == k),
                "missing parameter key {k:?} in {names:?}"
            );
        }
    }

    #[test]
    fn tiny_encoder_forward_from_ids_shape() {
        let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let ids = vec![1u32, 5, 7];
        let out = enc.forward_from_ids(&ids).unwrap();
        assert_eq!(out.shape(), &[1, 3, 8]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn tiny_named_parameters_use_hf_layout() {
        let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let names: Vec<String> = enc.named_parameters().into_iter().map(|(n, _)| n).collect();
        for k in [
            "embeddings.token_embedding.weight",
            "embeddings.position_embedding.weight",
            "encoder.layers.0.layer_norm1.weight",
            "encoder.layers.0.self_attn.q_proj.weight",
            "encoder.layers.0.self_attn.out_proj.bias",
            "encoder.layers.0.layer_norm2.bias",
            "encoder.layers.0.mlp.fc1.weight",
            "encoder.layers.0.mlp.fc2.bias",
            "final_layer_norm.weight",
            "final_layer_norm.bias",
        ] {
            assert!(
                names.iter().any(|n| n == k),
                "missing parameter key {k:?} in {names:?}"
            );
        }
    }

    #[test]
    fn round_trip_state_dict() {
        let src = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let sd = src.state_dict();
        let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        dst.load_state_dict(&sd, true).unwrap();
        let ids = vec![2u32, 4, 6];
        let a = src.forward_from_ids(&ids).unwrap();
        let b = dst.forward_from_ids(&ids).unwrap();
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((x - y).abs() < 1e-5, "round-trip differs: {x} vs {y}");
        }
    }

    #[test]
    fn load_hf_state_dict_strips_text_model_prefix() {
        let src = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let bare = src.state_dict();
        let mut prefixed: StateDict<f32> = HashMap::new();
        for (k, v) in bare {
            prefixed.insert(format!("text_model.{k}"), v);
        }
        prefixed.insert(
            "text_model.embeddings.position_ids".into(),
            ferrotorch_core::zeros::<f32>(&[1, 6]).unwrap(),
        );
        let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert_eq!(
            rep.dropped,
            vec!["text_model.embeddings.position_ids".to_string()]
        );
        let ids = vec![1u32, 2, 3];
        let a = src.forward_from_ids(&ids).unwrap();
        let b = dst.forward_from_ids(&ids).unwrap();
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((x - y).abs() < 1e-5);
        }
    }

    #[test]
    fn load_hf_state_dict_strict_rejects_unknown_key() {
        let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let mut sd: StateDict<f32> = HashMap::new();
        sd.insert(
            "mystery.key".into(),
            ferrotorch_core::zeros::<f32>(&[1]).unwrap(),
        );
        assert!(dst.load_hf_state_dict(&sd, true).is_err());
    }

    #[test]
    fn forward_from_id_tensor_matches_forward_from_ids() {
        let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let ids = vec![1u32, 5, 7];
        let id_tensor = float_index_tensor::<f32>(&ids).unwrap();
        let a = enc.forward_from_ids(&ids).unwrap();
        let b = enc.forward_from_id_tensor(&id_tensor).unwrap();
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((x - y).abs() < 1e-5);
        }
    }
}