#![cfg(feature = "cuda")]
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Tensor, TensorStorage};
use ferrotorch_gpu::{
CudaBuffer, GpuDevice, GpuError, gpu_bmm_f32, gpu_layernorm, gpu_matmul_f32, gpu_softmax,
kernels::{gpu_add, gpu_broadcast_add, gpu_embed_lookup_batch, gpu_gelu, gpu_scale},
transfer::{cpu_to_gpu, gpu_to_cpu},
};
use ferrotorch_nn::module::{Module, StateDict};
use crate::clip_text_encoder::{ClipTextConfig, ClipTextEncoder};
use crate::safetensors_loader::DropReport;
/// Device-resident affine LayerNorm parameters for one normalization site,
/// consumed by `layernorm_forward`.
#[derive(Debug)]
struct GpuLayerNorm {
    /// Per-feature scale; `normalized_shape` elements (length-checked on load).
    weight: CudaBuffer<f32>,
    /// Per-feature shift; `normalized_shape` elements (length-checked on load).
    bias: CudaBuffer<f32>,
    /// Stability epsilon passed to `gpu_layernorm` (the config's
    /// `layer_norm_eps` narrowed to f32).
    eps: f32,
    /// Width of the normalized dimension; `layernorm_forward` rejects any
    /// other `hidden` value.
    normalized_shape: usize,
}
/// Device-resident linear layer whose weight is stored *transposed*.
///
/// `pop_linear_t` turns the checkpoint's row-major `[out_features, in_features]`
/// weight into row-major `[in_features, out_features]`, so `linear_forward` can
/// compute `x @ weight_t + bias` with a plain matmul.
#[derive(Debug)]
struct GpuLinearT {
    /// Transposed weight, `in_features * out_features` values.
    weight_t: CudaBuffer<f32>,
    /// Bias broadcast over every row of the output; `out_features` values.
    bias: CudaBuffer<f32>,
    in_features: usize,
    out_features: usize,
}
/// Q/K/V/output projections of one self-attention block, each
/// `hidden_size -> hidden_size`.
#[derive(Debug)]
// Every field shares the `_proj` suffix to mirror the checkpoint keys
// (`self_attn.q_proj`, ...), which `clippy::struct_field_names` would flag.
#[allow(clippy::struct_field_names)]
struct GpuClipAttn {
    q_proj: GpuLinearT,
    k_proj: GpuLinearT,
    v_proj: GpuLinearT,
    out_proj: GpuLinearT,
}
/// Feed-forward block of one encoder layer.
#[derive(Debug)]
struct GpuClipMlp {
    /// `hidden_size -> intermediate_size` projection, followed by GELU.
    fc1: GpuLinearT,
    /// `intermediate_size -> hidden_size` projection.
    fc2: GpuLinearT,
}
/// One pre-norm transformer layer, applied in `encode` as
/// `h += self_attn(layer_norm1(h))` followed by `h += mlp(layer_norm2(h))`.
#[derive(Debug)]
struct GpuClipLayer {
    layer_norm1: GpuLayerNorm,
    self_attn: GpuClipAttn,
    layer_norm2: GpuLayerNorm,
    mlp: GpuClipMlp,
}
/// GPU implementation of the CLIP text-encoder forward pass.
///
/// All parameters live in device buffers loaded from a CPU `StateDict` (see
/// `GpuClipTextEncoder::new`). `encode` runs token + position embedding, the
/// stack of pre-norm layers and the final LayerNorm on the device, then copies
/// the result back to a CPU tensor.
///
/// NOTE: the per-head reshapes/transposes currently round-trip through host
/// memory (`reshape_seq_to_heads`, `transpose_last_two`, ...).
#[derive(Debug)]
pub struct GpuClipTextEncoder {
    /// `vocab_size * hidden_size` values, gathered by token id.
    token_embedding: CudaBuffer<f32>,
    /// `max_position_embeddings * hidden_size` values, gathered by position.
    position_embedding: CudaBuffer<f32>,
    /// `num_hidden_layers` encoder layers, applied in order.
    layers: Vec<GpuClipLayer>,
    final_layer_norm: GpuLayerNorm,
    /// Additive causal mask of size `max_position_embeddings²` (row-major):
    /// `0.0` on/below the diagonal, `-inf` above it.
    causal_mask_full: CudaBuffer<f32>,
    config: ClipTextConfig,
    device: GpuDevice,
}
impl GpuClipTextEncoder {
pub fn new(
config: ClipTextConfig,
mut state: StateDict<f32>,
device: GpuDevice,
) -> FerrotorchResult<(Self, DropReport)> {
config.validate()?;
let hidden = config.hidden_size;
let inter = config.intermediate_size;
let vocab = config.vocab_size;
let max_pos = config.max_position_embeddings;
let eps = config.layer_norm_eps as f32;
let token_embedding = pop_tensor(
&mut state,
"embeddings.token_embedding.weight",
vocab * hidden,
&device,
)?;
let position_embedding = pop_tensor(
&mut state,
"embeddings.position_embedding.weight",
max_pos * hidden,
&device,
)?;
let mut layers = Vec::with_capacity(config.num_hidden_layers);
for li in 0..config.num_hidden_layers {
let prefix = format!("encoder.layers.{li}");
let layer_norm1 = pop_layernorm(
&mut state,
&format!("{prefix}.layer_norm1"),
hidden,
eps,
&device,
)?;
let q_proj = pop_linear_t(
&mut state,
&format!("{prefix}.self_attn.q_proj"),
hidden,
hidden,
&device,
)?;
let k_proj = pop_linear_t(
&mut state,
&format!("{prefix}.self_attn.k_proj"),
hidden,
hidden,
&device,
)?;
let v_proj = pop_linear_t(
&mut state,
&format!("{prefix}.self_attn.v_proj"),
hidden,
hidden,
&device,
)?;
let out_proj = pop_linear_t(
&mut state,
&format!("{prefix}.self_attn.out_proj"),
hidden,
hidden,
&device,
)?;
let layer_norm2 = pop_layernorm(
&mut state,
&format!("{prefix}.layer_norm2"),
hidden,
eps,
&device,
)?;
let fc1 = pop_linear_t(
&mut state,
&format!("{prefix}.mlp.fc1"),
hidden,
inter,
&device,
)?;
let fc2 = pop_linear_t(
&mut state,
&format!("{prefix}.mlp.fc2"),
inter,
hidden,
&device,
)?;
layers.push(GpuClipLayer {
layer_norm1,
self_attn: GpuClipAttn {
q_proj,
k_proj,
v_proj,
out_proj,
},
layer_norm2,
mlp: GpuClipMlp { fc1, fc2 },
});
}
let final_layer_norm = pop_layernorm(&mut state, "final_layer_norm", hidden, eps, &device)?;
let mut mask = vec![0.0_f32; max_pos * max_pos];
for i in 0..max_pos {
for j in 0..max_pos {
if j > i {
mask[i * max_pos + j] = f32::NEG_INFINITY;
}
}
}
let causal_mask_full = cpu_to_gpu(&mask, &device).map_err(gpu_err)?;
let mut dropped: Vec<String> = state.keys().cloned().collect();
dropped.sort();
let report = DropReport { dropped };
Ok((
Self {
token_embedding,
position_embedding,
layers,
final_layer_norm,
causal_mask_full,
config,
device,
},
report,
))
}
pub fn from_module(
cpu: &ClipTextEncoder<f32>,
device: &GpuDevice,
) -> FerrotorchResult<(Self, DropReport)> {
let state: StateDict<f32> = cpu.state_dict();
Self::new(cpu.config.clone(), state, device.clone())
}
pub fn encode(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<f32>> {
let cfg = &self.config;
let s = input_ids.len();
if s == 0 {
return Err(FerrotorchError::InvalidArgument {
message: "GpuClipTextEncoder::encode: input_ids is empty".into(),
});
}
if s > cfg.max_position_embeddings {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"GpuClipTextEncoder::encode: seq_len {s} exceeds \
max_position_embeddings {}",
cfg.max_position_embeddings
),
});
}
for &id in input_ids {
if (id as usize) >= cfg.vocab_size {
return Err(FerrotorchError::IndexOutOfBounds {
index: id as usize,
axis: 0,
size: cfg.vocab_size,
});
}
}
let hidden = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let head_dim = cfg.head_dim();
let max_pos = cfg.max_position_embeddings;
let token_ids_f32: Vec<f32> = input_ids.iter().map(|&i| i as f32).collect();
let token_ids_gpu = cpu_to_gpu(&token_ids_f32, &self.device).map_err(gpu_err)?;
let tok_emb = gpu_embed_lookup_batch(
&token_ids_gpu,
&self.token_embedding,
s,
hidden,
&self.device,
)
.map_err(gpu_err)?;
let pos_ids_f32: Vec<f32> = (0..s as u32).map(|i| i as f32).collect();
let pos_ids_gpu = cpu_to_gpu(&pos_ids_f32, &self.device).map_err(gpu_err)?;
let pos_emb = gpu_embed_lookup_batch(
&pos_ids_gpu,
&self.position_embedding,
s,
hidden,
&self.device,
)
.map_err(gpu_err)?;
let mut h = gpu_add(&tok_emb, &pos_emb, &self.device).map_err(gpu_err)?;
let causal_mask = if s == max_pos {
None
} else {
let full = gpu_to_cpu(&self.causal_mask_full, &self.device).map_err(gpu_err)?;
let mut sliced = vec![0.0_f32; s * s];
for i in 0..s {
for j in 0..s {
sliced[i * s + j] = full[i * max_pos + j];
}
}
Some(cpu_to_gpu(&sliced, &self.device).map_err(gpu_err)?)
};
for layer in &self.layers {
let normed1 = layernorm_forward(&layer.layer_norm1, &h, s, hidden, &self.device)?;
let q = linear_forward(&layer.self_attn.q_proj, &normed1, s, &self.device)?;
let k = linear_forward(&layer.self_attn.k_proj, &normed1, s, &self.device)?;
let v = linear_forward(&layer.self_attn.v_proj, &normed1, s, &self.device)?;
let q_heads = reshape_seq_to_heads(&q, s, num_heads, head_dim, &self.device)?;
let k_heads = reshape_seq_to_heads(&k, s, num_heads, head_dim, &self.device)?;
let v_heads = reshape_seq_to_heads(&v, s, num_heads, head_dim, &self.device)?;
let k_heads_t = transpose_last_two(&k_heads, num_heads, s, head_dim, &self.device)?;
let scores = gpu_bmm_f32(
&q_heads,
&k_heads_t,
num_heads,
s,
head_dim,
s,
&self.device,
)
.map_err(gpu_err)?;
let scale = (head_dim as f64).sqrt().recip() as f32;
let scaled = gpu_scale(&scores, scale, &self.device).map_err(gpu_err)?;
let mask_ref = causal_mask.as_ref().unwrap_or(&self.causal_mask_full);
let masked = gpu_broadcast_add(
&scaled,
mask_ref,
&[num_heads, s, s],
&[1, s, s],
&[num_heads, s, s],
&self.device,
)
.map_err(gpu_err)?;
let probs = gpu_softmax(&masked, num_heads * s, s, &self.device).map_err(gpu_err)?;
let attended = gpu_bmm_f32(&probs, &v_heads, num_heads, s, s, head_dim, &self.device)
.map_err(gpu_err)?;
let merged = reshape_heads_to_seq(&attended, num_heads, s, head_dim, &self.device)?;
let attn_out = linear_forward(&layer.self_attn.out_proj, &merged, s, &self.device)?;
h = gpu_add(&h, &attn_out, &self.device).map_err(gpu_err)?;
let normed2 = layernorm_forward(&layer.layer_norm2, &h, s, hidden, &self.device)?;
let mlp_h = linear_forward(&layer.mlp.fc1, &normed2, s, &self.device)?;
let mlp_act = gpu_gelu(&mlp_h, &self.device).map_err(gpu_err)?;
let mlp_out = linear_forward(&layer.mlp.fc2, &mlp_act, s, &self.device)?;
h = gpu_add(&h, &mlp_out, &self.device).map_err(gpu_err)?;
}
let normed = layernorm_forward(&self.final_layer_norm, &h, s, hidden, &self.device)?;
let out_data = gpu_to_cpu(&normed, &self.device).map_err(gpu_err)?;
Tensor::from_storage(TensorStorage::cpu(out_data), vec![1, s, hidden], false)
}
}
/// Wraps a GPU-layer error as an `InvalidArgument`, prefixed with this
/// encoder's name, so it can be propagated with `?`.
fn gpu_err(e: GpuError) -> FerrotorchError {
    let message = format!("GpuClipTextEncoder GPU op failed: {e}");
    FerrotorchError::InvalidArgument { message }
}
/// Removes `key` from `state`, checks it holds exactly `expected_len`
/// elements, and uploads it to `device`.
///
/// # Errors
///
/// - `InvalidArgument` if the key is absent or the upload fails;
/// - `ShapeMismatch` if the element count differs from `expected_len`;
/// - any error from reading the tensor's host data.
fn pop_tensor(
    state: &mut StateDict<f32>,
    key: &str,
    expected_len: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let Some(tensor) = state.remove(key) else {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("GpuClipTextEncoder: missing tensor {key:?}"),
        });
    };
    let host = tensor.data()?;
    let actual_len = host.len();
    if actual_len == expected_len {
        cpu_to_gpu(host, device).map_err(gpu_err)
    } else {
        Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "GpuClipTextEncoder: tensor {key:?} length {actual_len} != expected {expected_len}"
            ),
        })
    }
}
/// Loads the `{prefix}.weight` / `{prefix}.bias` pair of one LayerNorm.
///
/// Both tensors must hold `normalized_shape` elements. Weight is loaded
/// before bias, so a missing or mis-sized weight is reported first.
///
/// # Errors
///
/// Propagates any error from `pop_tensor`.
fn pop_layernorm(
    state: &mut StateDict<f32>,
    prefix: &str,
    normalized_shape: usize,
    eps: f32,
    device: &GpuDevice,
) -> FerrotorchResult<GpuLayerNorm> {
    let weight_key = format!("{prefix}.weight");
    let bias_key = format!("{prefix}.bias");
    // Struct-literal fields are evaluated in source order: weight, then bias.
    Ok(GpuLayerNorm {
        weight: pop_tensor(state, &weight_key, normalized_shape, device)?,
        bias: pop_tensor(state, &bias_key, normalized_shape, device)?,
        eps,
        normalized_shape,
    })
}
/// Loads the `{prefix}.weight` / `{prefix}.bias` pair of a linear layer.
///
/// The weight is transposed on the host from row-major `[out_f, in_f]` to
/// row-major `[in_f, out_f]` before upload, so the forward pass can use
/// `x @ weight_t`. Only the flat element count is validated.
///
/// # Errors
///
/// - `InvalidArgument` if the weight is missing or an upload fails;
/// - `ShapeMismatch` if the weight does not hold `out_f * in_f` elements;
/// - any error from `pop_tensor` for the bias.
fn pop_linear_t(
    state: &mut StateDict<f32>,
    prefix: &str,
    in_f: usize,
    out_f: usize,
    device: &GpuDevice,
) -> FerrotorchResult<GpuLinearT> {
    let weight_key = format!("{prefix}.weight");
    let Some(weight) = state.remove(&weight_key) else {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("GpuClipTextEncoder: missing tensor {weight_key:?}"),
        });
    };
    let row_major = weight.data()?;
    let expected = out_f * in_f;
    if row_major.len() != expected {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "GpuClipTextEncoder: tensor {weight_key:?} length {} != expected {expected}",
                row_major.len()
            ),
        });
    }
    // Element (o, i) of the [out_f, in_f] weight moves to (i, o) of [in_f, out_f].
    let mut transposed = Vec::with_capacity(expected);
    transposed.extend((0..in_f).flat_map(|i| (0..out_f).map(move |o| row_major[o * in_f + i])));
    let weight_t = cpu_to_gpu(&transposed, device).map_err(gpu_err)?;
    let bias = pop_tensor(state, &format!("{prefix}.bias"), out_f, device)?;
    Ok(GpuLinearT {
        weight_t,
        bias,
        in_features: in_f,
        out_features: out_f,
    })
}
/// Applies `ln` to `s` rows of width `hidden` on the device.
///
/// # Errors
///
/// - `ShapeMismatch` if `hidden` differs from the width `ln` was loaded with;
/// - `InvalidArgument` if the GPU kernel fails.
fn layernorm_forward(
    ln: &GpuLayerNorm,
    x: &CudaBuffer<f32>,
    s: usize,
    hidden: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    if hidden == ln.normalized_shape {
        return gpu_layernorm(x, &ln.weight, &ln.bias, s, hidden, ln.eps, device)
            .map_err(gpu_err);
    }
    Err(FerrotorchError::ShapeMismatch {
        message: format!(
            "GpuClipTextEncoder::layernorm: expected hidden={}, got {hidden}",
            ln.normalized_shape
        ),
    })
}
/// Computes `x @ weight_t + bias` for `s` input rows.
///
/// The output is laid out `[s, out_features]`; the bias is broadcast from
/// `[1, out_features]` across all rows.
///
/// # Errors
///
/// `InvalidArgument` if either GPU kernel fails.
fn linear_forward(
    lin: &GpuLinearT,
    x: &CudaBuffer<f32>,
    s: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let (k, n) = (lin.in_features, lin.out_features);
    let product = gpu_matmul_f32(x, &lin.weight_t, s, k, n, device).map_err(gpu_err)?;
    let out_shape = [s, n];
    gpu_broadcast_add(&product, &lin.bias, &out_shape, &[1, n], &out_shape, device)
        .map_err(gpu_err)
}
/// Rearranges `[s, num_heads * head_dim]` into `[num_heads, s, head_dim]`
/// (both row-major), via a round trip through host memory.
///
/// # Errors
///
/// `InvalidArgument` if a device transfer fails.
fn reshape_seq_to_heads(
    x: &CudaBuffer<f32>,
    s: usize,
    num_heads: usize,
    head_dim: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let host = gpu_to_cpu(x, device).map_err(gpu_err)?;
    let row_width = num_heads * head_dim;
    let mut out = vec![0.0_f32; num_heads * s * head_dim];
    for head in 0..num_heads {
        for pos in 0..s {
            // Each (pos, head) pair is a contiguous run of `head_dim` values
            // on both sides, so it can be copied as one slice.
            let src = pos * row_width + head * head_dim;
            let dst = (head * s + pos) * head_dim;
            out[dst..dst + head_dim].copy_from_slice(&host[src..src + head_dim]);
        }
    }
    cpu_to_gpu(&out, device).map_err(gpu_err)
}
/// Inverse of `reshape_seq_to_heads`: rearranges `[num_heads, s, head_dim]`
/// into `[s, num_heads * head_dim]` (both row-major) via host memory.
///
/// # Errors
///
/// `InvalidArgument` if a device transfer fails.
fn reshape_heads_to_seq(
    x: &CudaBuffer<f32>,
    num_heads: usize,
    s: usize,
    head_dim: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let host = gpu_to_cpu(x, device).map_err(gpu_err)?;
    let row_width = num_heads * head_dim;
    let mut out = vec![0.0_f32; s * num_heads * head_dim];
    for pos in 0..s {
        for head in 0..num_heads {
            // Contiguous `head_dim`-long run on both sides.
            let src = (head * s + pos) * head_dim;
            let dst = pos * row_width + head * head_dim;
            out[dst..dst + head_dim].copy_from_slice(&host[src..src + head_dim]);
        }
    }
    cpu_to_gpu(&out, device).map_err(gpu_err)
}
/// Transposes each of `batch` row-major `m x n` matrices into `n x m`, via
/// a round trip through host memory.
///
/// # Errors
///
/// `InvalidArgument` if a device transfer fails.
fn transpose_last_two(
    x: &CudaBuffer<f32>,
    batch: usize,
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let host = gpu_to_cpu(x, device).map_err(gpu_err)?;
    let plane = m * n;
    let mut out = vec![0.0_f32; batch * plane];
    for b in 0..batch {
        // Work on one matrix at a time through per-batch sub-slices.
        let src = &host[b * plane..(b + 1) * plane];
        let dst = &mut out[b * plane..(b + 1) * plane];
        for col in 0..n {
            for row in 0..m {
                dst[col * m + row] = src[row * n + col];
            }
        }
    }
    cpu_to_gpu(&out, device).map_err(gpu_err)
}
#[cfg(all(test, feature = "cuda"))]
mod tests {
    use super::*;
    use crate::clip_text_encoder::{ClipTextConfig, ClipTextEncoder};

    /// Deliberately tiny config so the tests run quickly on any GPU.
    fn tiny_cfg() -> ClipTextConfig {
        ClipTextConfig {
            hidden_size: 8,
            intermediate_size: 16,
            num_attention_heads: 2,
            num_hidden_layers: 1,
            max_position_embeddings: 6,
            vocab_size: 32,
            layer_norm_eps: 1e-5,
        }
    }

    /// Builds a CPU encoder from `tiny_cfg` plus its GPU mirror.
    ///
    /// Returns `None` when no CUDA device is available, so tests skip
    /// instead of failing on GPU-less machines.
    fn build() -> Option<(ClipTextConfig, ClipTextEncoder<f32>, GpuClipTextEncoder)> {
        let device = GpuDevice::new(0).ok()?;
        let cfg = tiny_cfg();
        let cpu = ClipTextEncoder::<f32>::new(cfg.clone()).unwrap();
        let (gpu, report) = GpuClipTextEncoder::from_module(&cpu, &device).unwrap();
        assert!(
            report.dropped.is_empty(),
            "unexpected dropped keys: {:?}",
            report.dropped
        );
        Some((cfg, cpu, gpu))
    }

    /// Largest element-wise absolute difference between two equal-length slices.
    ///
    /// NaN is propagated (rather than ignored, as a plain `>` scan would do) so
    /// a non-finite output makes the caller's `< tol` assertion fail.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len(), "output length mismatch");
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs())
            .fold(0.0_f32, |acc, d| if d.is_nan() || d > acc { d } else { acc })
    }

    #[test]
    fn gpu_clip_matches_cpu_tiny() {
        let Some((_cfg, cpu, gpu)) = build() else {
            return;
        };
        // Full-length input: exercises the cached full causal mask.
        let ids = vec![1u32, 5, 7, 11, 17, 23];
        let cpu_out = cpu.forward_from_ids(&ids).unwrap();
        let gpu_out = gpu.encode(&ids).unwrap();
        assert_eq!(cpu_out.shape(), gpu_out.shape());
        let max_abs = max_abs_diff(cpu_out.data().unwrap(), gpu_out.data().unwrap());
        assert!(max_abs < 1e-3, "gpu vs cpu tiny CLIP max_abs = {max_abs}");
    }

    #[test]
    fn gpu_clip_short_seq_matches_cpu() {
        let Some((_cfg, cpu, gpu)) = build() else {
            return;
        };
        // Shorter than max_position_embeddings: exercises the per-call mask.
        let ids = vec![1u32, 5, 7];
        let cpu_out = cpu.forward_from_ids(&ids).unwrap();
        let gpu_out = gpu.encode(&ids).unwrap();
        assert_eq!(cpu_out.shape(), gpu_out.shape());
        let max_abs = max_abs_diff(cpu_out.data().unwrap(), gpu_out.data().unwrap());
        assert!(max_abs < 1e-3, "gpu vs cpu short CLIP max_abs = {max_abs}");
    }

    #[test]
    fn gpu_clip_short_seq_is_causal() {
        let Some((cfg, _cpu, gpu)) = build() else {
            return;
        };
        let ids_a = vec![1u32, 5, 7];
        let mut ids_b = ids_a.clone();
        ids_b[2] = 9u32;
        let oa = gpu.encode(&ids_a).unwrap();
        let ob = gpu.encode(&ids_b).unwrap();
        let da = oa.data().unwrap();
        let db = ob.data().unwrap();
        // Rows 0 and 1 attend only to positions <= 1, so changing token 2
        // must leave them untouched...
        let prefix = 2 * cfg.hidden_size;
        let d_prefix = max_abs_diff(&da[..prefix], &db[..prefix]);
        assert!(d_prefix < 1e-5, "rows 0..2 changed by a later token: {d_prefix}");
        // ...while row 2 sees the new token, so the test is not vacuous.
        let d_last = max_abs_diff(&da[prefix..], &db[prefix..]);
        assert!(d_last > 0.0, "row 2 did not react to its own token change");
    }

    #[test]
    fn gpu_clip_rejects_invalid_input() {
        let Some((cfg, _cpu, gpu)) = build() else {
            return;
        };
        assert!(matches!(
            gpu.encode(&[]),
            Err(FerrotorchError::InvalidArgument { .. })
        ));
        let too_long = vec![1u32; cfg.max_position_embeddings + 1];
        assert!(matches!(
            gpu.encode(&too_long),
            Err(FerrotorchError::InvalidArgument { .. })
        ));
        let out_of_vocab = vec![cfg.vocab_size as u32];
        assert!(matches!(
            gpu.encode(&out_of_vocab),
            Err(FerrotorchError::IndexOutOfBounds { .. })
        ));
    }
}