use crate::models::with_tracing::{linear_b, Linear, RmsNorm};
use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;
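/// Configuration for the Z-Image text encoder, a decoder-only transformer with
/// grouped-query attention and QK-norm. Every field falls back to the
/// `z_image()` defaults when missing from the deserialized JSON. Note that
/// `head_dim` is explicit and need not equal `hidden_size / num_attention_heads`
/// (with the defaults, 128 vs 2560 / 32 = 80).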
#[derive(Debug, Clone, serde::Deserialize)]
pub struct TextEncoderConfig {
#[serde(default = "default_vocab_size")]
pub vocab_size: usize,
#[serde(default = "default_hidden_size")]
pub hidden_size: usize,
#[serde(default = "default_intermediate_size")]
pub intermediate_size: usize,
#[serde(default = "default_num_hidden_layers")]
pub num_hidden_layers: usize,
#[serde(default = "default_num_attention_heads")]
pub num_attention_heads: usize,
#[serde(default = "default_num_key_value_heads")]
pub num_key_value_heads: usize,
#[serde(default = "default_head_dim")]
pub head_dim: usize,
#[serde(default = "default_rms_norm_eps")]
pub rms_norm_eps: f64,
#[serde(default = "default_rope_theta")]
pub rope_theta: f64,
#[serde(default = "default_attention_bias")]
pub attention_bias: bool,
#[serde(default = "default_hidden_act")]
pub hidden_act: Activation,
#[serde(default = "default_max_position_embeddings")]
pub max_position_embeddings: usize,
}
fn default_vocab_size() -> usize {
151936
}
fn default_hidden_size() -> usize {
2560
}
fn default_intermediate_size() -> usize {
9728
}
fn default_num_hidden_layers() -> usize {
36
}
fn default_num_attention_heads() -> usize {
32
}
fn default_num_key_value_heads() -> usize {
8
}
fn default_head_dim() -> usize {
128
}
fn default_rms_norm_eps() -> f64 {
1e-6
}
fn default_rope_theta() -> f64 {
1_000_000.0
}
fn default_attention_bias() -> bool {
false
}
fn default_hidden_act() -> Activation {
Activation::Silu
}
fn default_max_position_embeddings() -> usize {
40960
}
impl Default for TextEncoderConfig {
fn default() -> Self {
Self::z_image()
}
}
impl TextEncoderConfig {
pub fn z_image() -> Self {
Self {
vocab_size: 151936,
hidden_size: 2560,
intermediate_size: 9728,
num_hidden_layers: 36,
num_attention_heads: 32,
num_key_value_heads: 8,
head_dim: 128,
rms_norm_eps: 1e-6,
rope_theta: 1_000_000.0,
attention_bias: false,
hidden_act: Activation::Silu,
max_position_embeddings: 40960,
}
}
}
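/// Precomputed sine/cosine tables for rotary position embeddings (RoPE),
/// covering `max_position_embeddings` positions at `head_dim / 2` frequencies.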
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &TextEncoderConfig, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim;
let max_seq_len = cfg.max_position_embeddings;
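        // Inverse frequencies theta^(-2i/dim) for i in 0..dim/2, as in standard RoPE.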
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
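        // Outer product of positions and inverse frequencies: (max_seq_len, head_dim / 2).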
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?.to_dtype(dtype)?,
cos: freqs.cos()?.to_dtype(dtype)?,
})
}
fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
let (_, _, seq_len, _) = q.dims4()?;
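        // Slice the precomputed tables to the seq_len window starting at `offset`.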
let cos = self.cos.narrow(0, offset, seq_len)?;
let sin = self.sin.narrow(0, offset, seq_len)?;
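        // `candle_nn::rotary_emb::rope` is the non-interleaved (half-rotation) variant.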
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
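/// Gated feed-forward block computing `down_proj(act(gate_proj(x)) * up_proj(x))`,
/// i.e. SwiGLU when `hidden_act` is SiLU.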
#[derive(Debug, Clone)]
struct Mlp {
gate_proj: candle_nn::Linear,
up_proj: candle_nn::Linear,
down_proj: candle_nn::Linear,
act_fn: Activation,
}
impl Mlp {
fn new(cfg: &TextEncoderConfig, vb: VarBuilder) -> Result<Self> {
Ok(Self {
gate_proj: candle_nn::linear_no_bias(
cfg.hidden_size,
cfg.intermediate_size,
vb.pp("gate_proj"),
)?,
up_proj: candle_nn::linear_no_bias(
cfg.hidden_size,
cfg.intermediate_size,
vb.pp("up_proj"),
)?,
down_proj: candle_nn::linear_no_bias(
cfg.intermediate_size,
cfg.hidden_size,
vb.pp("down_proj"),
)?,
act_fn: cfg.hidden_act,
})
}
}
impl Module for Mlp {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = x.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
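/// Repeats each key/value head `n_rep` times along the head axis so that
/// grouped-query attention can be computed as ordinary multi-head attention
/// with a batched matmul.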
fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
x.unsqueeze(2)?
.broadcast_as((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
}
}
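/// Grouped-query self-attention with rotary position embeddings and per-head
/// RMS normalization of queries and keys (QK-norm).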
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
q_norm: RmsNorm,
k_norm: RmsNorm,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
}
impl Attention {
fn new(
cfg: &TextEncoderConfig,
rotary_emb: Arc<RotaryEmbedding>,
vb: VarBuilder,
) -> Result<Self> {
let head_dim = cfg.head_dim;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let q_proj = linear_b(
cfg.hidden_size,
num_heads * head_dim,
cfg.attention_bias,
vb.pp("q_proj"),
)?;
let k_proj = linear_b(
cfg.hidden_size,
num_kv_heads * head_dim,
cfg.attention_bias,
vb.pp("k_proj"),
)?;
let v_proj = linear_b(
cfg.hidden_size,
num_kv_heads * head_dim,
cfg.attention_bias,
vb.pp("v_proj"),
)?;
let o_proj = linear_b(
num_heads * head_dim,
cfg.hidden_size,
cfg.attention_bias,
vb.pp("o_proj"),
)?;
let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
let hidden_size = head_dim * cfg.num_attention_heads;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
hidden_size,
rotary_emb,
})
}
fn forward(&self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
let (b, l, _) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
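        // Split heads: (b, l, n * head_dim) -> (b, n, l, head_dim).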
let q = q
.reshape((b, l, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b, l, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b, l, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
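        // QK-norm: RMS-normalize each per-head query/key vector over head_dim
        // before the rotary embedding is applied.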
let q_flat = q.flatten(0, 2)?;
let k_flat = k.flatten(0, 2)?;
let q_flat = self.q_norm.forward(&q_flat)?;
let k_flat = self.k_norm.forward(&k_flat)?;
let q = q_flat.reshape((b, self.num_heads, l, self.head_dim))?;
let k = k_flat.reshape((b, self.num_kv_heads, l, self.head_dim))?;
let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
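        // Expand the key/value heads so each group of query heads sees its shared kv head.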
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
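        // Scaled dot-product attention.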
let scale = 1.0 / (self.head_dim as f64).sqrt();
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(m) = attn_mask {
scores = scores.broadcast_add(m)?;
}
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
let ctx = probs.matmul(&v)?;
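        // Merge heads back: (b, n_heads, l, head_dim) -> (b, l, n_heads * head_dim).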
ctx.transpose(1, 2)?
.reshape((b, l, self.hidden_size))?
.apply(&self.o_proj)
}
}
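/// Pre-norm transformer decoder block: RMSNorm -> self-attention -> residual,
/// then RMSNorm -> gated MLP -> residual.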
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: Mlp,
ln1: RmsNorm,
ln2: RmsNorm,
}
impl DecoderLayer {
fn new(cfg: &TextEncoderConfig, rotary: Arc<RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(cfg, rotary, vb.pp("self_attn"))?;
let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let ln2 = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
mlp,
ln1,
ln2,
})
}
fn forward(&self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
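        // Pre-norm residuals: x = x + attn(norm1(x)); x = x + mlp(norm2(x)).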
let h = self.ln1.forward(x)?;
let h = self.self_attn.forward(&h, mask, offset)?;
let x = (x + h)?;
let h2 = self.ln2.forward(&x)?;
let h2 = h2.apply(&self.mlp)?;
x + h2
}
}
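/// Z-Image text encoder: token embeddings followed by a stack of decoder
/// layers. `forward` returns the hidden states produced by the penultimate
/// layer rather than the output of the full stack.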
#[derive(Debug, Clone)]
pub struct ZImageTextEncoder {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
num_hidden_layers: usize,
device: Device,
dtype: DType,
}
impl ZImageTextEncoder {
pub fn new(cfg: &TextEncoderConfig, vb: VarBuilder) -> Result<Self> {
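        // All transformer weights live under the "model" prefix of the checkpoint.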
let vb_model = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_model.pp("embed_tokens"))?;
let rotary = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_layers = vb_model.pp("layers");
for i in 0..cfg.num_hidden_layers {
layers.push(DecoderLayer::new(cfg, rotary.clone(), vb_layers.pp(i))?);
}
Ok(Self {
embed_tokens,
layers,
num_hidden_layers: cfg.num_hidden_layers,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
    fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> Result<Tensor> {
        let minf = f32::NEG_INFINITY;
        // Position i may attend to positions j <= i + offset; later positions are masked.
        let mask: Vec<_> = (0..tgt)
            .flat_map(|i| {
                (0..(tgt + offset)).map(move |j| if j <= i + offset { 0.0 } else { minf })
            })
            .collect();
        // The mask data is batch-independent: build it once as (tgt, tgt + offset)
        // and broadcast over the batch and head dimensions.
        Tensor::from_slice(&mask, (tgt, tgt + offset), &self.device)?
            .broadcast_as((b, 1, tgt, tgt + offset))?
            .to_dtype(self.dtype)
    }
pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let (b, l) = input_ids.dims2()?;
let mut hidden_states = self.embed_tokens.forward(input_ids)?;
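        // With a single query position and no prefix there is nothing to mask.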
let causal = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, 0)?)
};
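        // Z-Image consumes the hidden states from the penultimate decoder layer;
        // the final layer is never run.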
let target_layer = self.num_hidden_layers - 2;
for (i, layer) in self.layers.iter().enumerate() {
hidden_states = layer.forward(&hidden_states, causal.as_ref(), 0)?;
if i == target_layer {
return Ok(hidden_states);
}
}
candle::bail!("Layer index out of bounds")
}
    pub fn hidden_size(&self) -> usize {
        // The embedding table has shape (vocab_size, hidden_size).
        self.embed_tokens.hidden_size()
    }
}
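// Minimal usage sketch, assuming weights stored as safetensors; the checkpoint
// path and token ids below are placeholders, not actual Z-Image assets:
//
//     use candle::{DType, Device, Tensor};
//     use candle_nn::VarBuilder;
//
//     let device = Device::Cpu;
//     let vb = unsafe {
//         VarBuilder::from_mmaped_safetensors(&["text_encoder.safetensors"], DType::F32, &device)?
//     };
//     let model = ZImageTextEncoder::new(&TextEncoderConfig::z_image(), vb)?;
//     let input_ids = Tensor::new(&[[9707u32, 11, 1879]], &device)?; // (batch, seq_len)
//     let hidden = model.forward(&input_ids)?; // (batch, seq_len, hidden_size)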