use anyhow::Result;
use candle_core::{DType, IndexOp, Module, Tensor, D};
use candle_transformers::models::stable_diffusion::clip::{
self, ClipTextTransformer, Config as ClipConfig,
};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::Tokenizer;
use super::park;
use super::t5::T5Encoder;
/// A CLIP text encoder bundled with its tokenizer and the bookkeeping needed to
/// drop, reload, or park its weights.
struct ClipWithTokenizer {
/// Live text transformer; `None` once the weights have been dropped or parked.
model: Option<ClipTextTransformer>,
/// CLIP configuration: consulted for the padding token and reused whenever the
/// model is rebuilt.
config: ClipConfig,
/// Tokenizer behind an `Arc`, so an already-loaded instance can be passed in
/// instead of being read from disk again.
tokenizer: Arc<Tokenizer>,
/// Length every token sequence is padded or truncated to before encoding.
max_position_embeddings: usize,
/// Device the weights are loaded onto and the token tensor is created on.
device: candle_core::Device,
/// Tensors returned by `park::load_tensors_to_cpu`; when present they are used
/// to rebuild the model instead of reloading from the weights file.
parked_tensors: Option<HashMap<String, Tensor>>,
}
impl ClipWithTokenizer {
    /// Loads the encoder weights and reads the tokenizer from `tokenizer_path`.
    ///
    /// Convenience wrapper around [`Self::load_with_tokenizer`] that never
    /// supplies a cached tokenizer.
    ///
    /// # Errors
    /// Propagates any failure from [`Self::load_with_tokenizer`].
    #[allow(dead_code)]
    #[allow(clippy::too_many_arguments)]
    fn load(
        encoder_path: &PathBuf,
        tokenizer_path: &PathBuf,
        config: ClipConfig,
        max_position_embeddings: usize,
        device: &candle_core::Device,
        dtype: DType,
        component: &str,
        progress: &crate::progress::ProgressReporter,
    ) -> Result<Self> {
        Self::load_with_tokenizer(
            encoder_path,
            tokenizer_path,
            None,
            config,
            max_position_embeddings,
            device,
            dtype,
            component,
            progress,
        )
    }

    /// Loads the encoder weights from `encoder_path` and pairs them with a tokenizer.
    ///
    /// * `cached_tokenizer` - when `Some`, used as-is and `tokenizer_path` is never read;
    ///   when `None`, the tokenizer is loaded from `tokenizer_path`.
    /// * `max_position_embeddings` - length prompts are padded/truncated to when encoding.
    /// * `component` / `progress` - forwarded to the weight loader for progress reporting.
    ///
    /// # Errors
    /// Fails if the weights cannot be loaded, the model cannot be built from them,
    /// or the tokenizer file cannot be read.
    #[allow(clippy::too_many_arguments)]
    fn load_with_tokenizer(
        encoder_path: &PathBuf,
        tokenizer_path: &PathBuf,
        cached_tokenizer: Option<Arc<Tokenizer>>,
        config: ClipConfig,
        max_position_embeddings: usize,
        device: &candle_core::Device,
        dtype: DType,
        component: &str,
        progress: &crate::progress::ProgressReporter,
    ) -> Result<Self> {
        let vb = crate::weight_loader::load_safetensors_with_progress(
            std::slice::from_ref(encoder_path),
            dtype,
            device,
            component,
            progress,
        )?;
        let model = ClipTextTransformer::new(vb, &config)?;

        // The model is built first so weight errors surface before tokenizer errors.
        let tokenizer = match cached_tokenizer {
            Some(tokenizer) => tokenizer,
            None => Tokenizer::from_file(tokenizer_path)
                .map(Arc::new)
                .map_err(|e| anyhow::anyhow!("failed to load CLIP tokenizer: {e}"))?,
        };

        Ok(Self {
            model: Some(model),
            config,
            tokenizer,
            max_position_embeddings,
            device: device.clone(),
            parked_tensors: None,
        })
    }

    /// Encodes `prompt` and returns `(penultimate_layer_embeddings, pooled_embedding)`.
    ///
    /// The prompt is tokenized, padded or truncated to `max_position_embeddings`
    /// (see `prepare_clip_tokens`), and run through the text transformer once.
    /// The pooled embedding is the final hidden state at the EOS position of the
    /// single batch entry.
    ///
    /// # Errors
    /// Fails if the weights are not currently loaded, the padding token is not in
    /// the tokenizer vocabulary, tokenization fails, or the forward pass fails.
    fn encode_text_to_embedding(&self, prompt: &str) -> Result<(Tensor, Tensor)> {
        let clip = self
            .model
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("CLIP model unavailable (weights dropped)"))?;

        // Resolve the padding token id directly; `get_vocab(true)` would build the
        // entire vocabulary map just to look up a single entry.
        let (pad_token, pad_error) = match &self.config.pad_with {
            Some(padding) => (padding.as_str(), "Failed to tokenize CLIP padding"),
            None => ("<|endoftext|>", "Failed to tokenize CLIP end-of-text"),
        };
        let pad_id = self
            .tokenizer
            .token_to_id(pad_token)
            .ok_or_else(|| anyhow::anyhow!("{pad_error}"))?;

        let raw_tokens = self
            .tokenizer
            .encode(prompt, true)
            .map_err(|e| anyhow::anyhow!("CLIP tokenization failed: {e}"))?
            .get_ids()
            .to_vec();
        let (tokens, eos_position) =
            prepare_clip_tokens(raw_tokens, self.max_position_embeddings, pad_id);
        let tokens = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;

        // One forward pass yields both outputs: candle's `forward_until_encoder_layer`
        // always runs every layer and returns the final (layer-normed) hidden states
        // first, with the requested intermediate layer second. A second pass just to
        // obtain the final hidden states would repeat the whole transformer.
        let (text_embeddings, text_embeddings_penultimate) =
            clip.forward_until_encoder_layer(&tokens, usize::MAX, -2)?;
        let text_embeddings_pooled = text_embeddings.i((0, eos_position, ..))?;

        Ok((text_embeddings_penultimate, text_embeddings_pooled))
    }

    /// Releases both the live model and any parked tensors.
    fn drop_weights(&mut self) {
        self.model = None;
        self.parked_tensors = None;
    }

    /// Rebuilds the model from the safetensors file at `encoder_path` on `self.device`.
    ///
    /// Parked tensors, if any, are left untouched.
    ///
    /// # Errors
    /// Fails if the weights cannot be loaded or the model cannot be built.
    fn reload(
        &mut self,
        encoder_path: &PathBuf,
        dtype: DType,
        component: &str,
        progress: &crate::progress::ProgressReporter,
    ) -> Result<()> {
        let vb = crate::weight_loader::load_safetensors_with_progress(
            std::slice::from_ref(encoder_path),
            dtype,
            &self.device,
            component,
            progress,
        )?;
        self.model = Some(ClipTextTransformer::new(vb, &self.config)?);
        Ok(())
    }

    /// Drops the live model while keeping a parked copy of the tensors.
    ///
    /// The weights file is only read when no parked copy is held yet. A copy kept
    /// from an earlier park (it survives `unpark_to_gpu`) is reused, so repeated
    /// park/unpark cycles do not hit the disk again.
    ///
    /// # Errors
    /// Fails if the tensors cannot be loaded; the live model is left in place then.
    fn park_to_cpu(&mut self, encoder_path: &PathBuf) -> Result<()> {
        if self.parked_tensors.is_none() {
            let parked = park::load_tensors_to_cpu(std::slice::from_ref(encoder_path))?;
            self.parked_tensors = Some(parked);
        }
        self.model = None;
        Ok(())
    }

    /// Ensures a live model exists on `self.device`.
    ///
    /// No-op when the model is already loaded. Otherwise the model is rebuilt from
    /// the parked tensors when available, falling back to [`Self::reload`] from
    /// `encoder_path`. The parked tensors are kept afterwards.
    ///
    /// # Errors
    /// Fails if the model cannot be rebuilt from either source.
    fn unpark_to_gpu(
        &mut self,
        encoder_path: &PathBuf,
        dtype: DType,
        component: &str,
        progress: &crate::progress::ProgressReporter,
    ) -> Result<()> {
        if self.model.is_some() {
            return Ok(());
        }
        match self.parked_tensors.as_ref() {
            Some(parked) => {
                let vb = park::varbuilder_from_parked(parked, dtype, &self.device);
                self.model = Some(ClipTextTransformer::new(vb, &self.config)?);
                Ok(())
            }
            None => self.reload(encoder_path, dtype, component, progress),
        }
    }

    /// Returns `true` when there is no live model but parked tensors are held.
    fn is_parked(&self) -> bool {
        self.model.is_none() && self.parked_tensors.is_some()
    }
}
/// The three text encoders used for SD3 prompt conditioning (CLIP-L, CLIP-G and
/// T5), together with the weight paths needed to reload or park them.
pub(crate) struct SD3TripleEncoder {
/// CLIP-L text encoder with its tokenizer.
clip_l: ClipWithTokenizer,
/// CLIP-G text encoder with its tokenizer.
clip_g: ClipWithTokenizer,
/// Bias-free linear layer loaded from the `text_projection` prefix of the CLIP-G
/// weights; applied to the pooled CLIP-G embedding.
clip_g_text_projection: candle_nn::Linear,
/// T5 text encoder.
t5: T5Encoder,
/// Weights file for CLIP-L, kept for reload/park operations.
clip_l_path: PathBuf,
/// Weights file for CLIP-G, kept for reload/park operations.
clip_g_path: PathBuf,
/// Weights file for T5, kept for reload operations.
t5_path: PathBuf,
/// Result of `crate::device::is_gpu` for the device passed at load time.
pub on_gpu: bool,
}
impl SD3TripleEncoder {
    /// Loads all three encoders, reading every tokenizer from its file.
    ///
    /// Convenience wrapper around [`Self::load_with_tokenizers`] that supplies no
    /// cached tokenizers.
    ///
    /// # Errors
    /// Propagates any failure from [`Self::load_with_tokenizers`].
    #[allow(dead_code)]
    #[allow(clippy::too_many_arguments)]
    pub fn load(
        clip_l_path: &PathBuf,
        clip_l_tokenizer_path: &PathBuf,
        clip_g_path: &PathBuf,
        clip_g_tokenizer_path: &PathBuf,
        t5_path: &PathBuf,
        t5_tokenizer_path: &PathBuf,
        device: &candle_core::Device,
        dtype: DType,
        progress: &crate::progress::ProgressReporter,
    ) -> Result<Self> {
        Self::load_with_tokenizers(
            clip_l_path,
            clip_l_tokenizer_path,
            None,
            clip_g_path,
            clip_g_tokenizer_path,
            None,
            t5_path,
            t5_tokenizer_path,
            None,
            device,
            dtype,
            progress,
        )
    }

    /// Loads CLIP-L, the CLIP-G text projection, CLIP-G and T5 (in that order).
    ///
    /// Each `*_tokenizer` argument, when `Some`, is handed to the corresponding
    /// encoder loader in place of reading the matching `*_tokenizer_path`.
    /// Both CLIP encoders use a sequence length of 77 tokens.
    ///
    /// # Errors
    /// Fails if any weights file, model construction, or tokenizer load fails.
    #[allow(clippy::too_many_arguments)]
    pub fn load_with_tokenizers(
        clip_l_path: &PathBuf,
        clip_l_tokenizer_path: &PathBuf,
        clip_l_tokenizer: Option<Arc<Tokenizer>>,
        clip_g_path: &PathBuf,
        clip_g_tokenizer_path: &PathBuf,
        clip_g_tokenizer: Option<Arc<Tokenizer>>,
        t5_path: &PathBuf,
        t5_tokenizer_path: &PathBuf,
        t5_tokenizer: Option<Arc<Tokenizer>>,
        device: &candle_core::Device,
        dtype: DType,
        progress: &crate::progress::ProgressReporter,
    ) -> Result<Self> {
        let max_position_embeddings = 77usize;

        let clip_l = ClipWithTokenizer::load_with_tokenizer(
            clip_l_path,
            clip_l_tokenizer_path,
            clip_l_tokenizer,
            clip::Config::sdxl(),
            max_position_embeddings,
            device,
            dtype,
            "SD3 CLIP-L",
            progress,
        )?;

        // The projection layer lives in the CLIP-G weights file but is not part of
        // `ClipTextTransformer`, so it is read through its own var builder.
        let clip_g_vb = crate::weight_loader::load_safetensors_with_progress(
            std::slice::from_ref(clip_g_path),
            dtype,
            device,
            "SD3 CLIP-G projection",
            progress,
        )?;
        let text_projection =
            candle_nn::linear_no_bias(1280, 1280, clip_g_vb.pp("text_projection"))?;

        let clip_g = ClipWithTokenizer::load_with_tokenizer(
            clip_g_path,
            clip_g_tokenizer_path,
            clip_g_tokenizer,
            clip::Config::sdxl2(),
            max_position_embeddings,
            device,
            dtype,
            "SD3 CLIP-G",
            progress,
        )?;

        let on_gpu = crate::device::is_gpu(device);

        let t5 = T5Encoder::load_with_tokenizer(
            t5_path,
            t5_tokenizer_path,
            device,
            dtype,
            progress,
            t5_tokenizer,
        )?;

        Ok(Self {
            clip_l,
            clip_g,
            clip_g_text_projection: text_projection,
            t5,
            clip_l_path: clip_l_path.clone(),
            clip_g_path: clip_g_path.clone(),
            t5_path: t5_path.clone(),
            on_gpu,
        })
    }

    /// Encodes `prompt` with all three encoders and returns `(context, y)`.
    ///
    /// * `context` - CLIP-L and CLIP-G penultimate-layer embeddings concatenated on
    ///   the last dimension, zero-padded with 2048 extra features, then concatenated
    ///   with the T5 embeddings along the second-to-last dimension.
    /// * `y` - pooled CLIP-L embedding concatenated with the projected pooled CLIP-G
    ///   embedding, with a leading dimension of size 1.
    ///
    /// Both results are moved to `target_device` and converted to `target_dtype`.
    ///
    /// # Errors
    /// Fails if any encoder's weights are unavailable or any tensor operation fails.
    pub fn encode(
        &mut self,
        prompt: &str,
        target_device: &candle_core::Device,
        target_dtype: DType,
    ) -> Result<(Tensor, Tensor)> {
        let (clip_l_embeddings, clip_l_pooled) = self.clip_l.encode_text_to_embedding(prompt)?;
        let (clip_g_embeddings, clip_g_pooled) = self.clip_g.encode_text_to_embedding(prompt)?;

        let clip_g_pooled_projected = self
            .clip_g_text_projection
            .forward(&clip_g_pooled.unsqueeze(0)?)?
            .squeeze(0)?;
        let y = Tensor::cat(&[&clip_l_pooled, &clip_g_pooled_projected], 0)?.unsqueeze(0)?;

        // NOTE(review): the 2048 zero features presumably widen the CLIP embeddings
        // to the T5 feature size so the two can be stacked — confirm against the model.
        let clip_embeddings_concat =
            Tensor::cat(&[&clip_l_embeddings, &clip_g_embeddings], D::Minus1)?
                .pad_with_zeros(D::Minus1, 0, 2048)?;

        // `Tensor::cat` requires matching dtype and device. Align the T5 output with
        // the CLIP embeddings instead of assuming the CLIP weights are F16 and live
        // on `target_device`; with F16 CLIP weights this is what the code did before.
        let t5_embeddings = self
            .t5
            .encode(prompt, target_device, target_dtype)?
            .to_device(clip_embeddings_concat.device())?
            .to_dtype(clip_embeddings_concat.dtype())?;

        let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?
            .to_device(target_device)?
            .to_dtype(target_dtype)?;
        let y = y.to_device(target_device)?.to_dtype(target_dtype)?;

        Ok((context, y))
    }

    /// Drops the weights of all three encoders.
    ///
    /// The CLIP-G text projection layer is not released by this call.
    pub fn drop_weights(&mut self) {
        self.clip_l.drop_weights();
        self.clip_g.drop_weights();
        self.t5.drop_weights();
    }

    /// Reloads all three encoders from their stored weight paths.
    ///
    /// # Errors
    /// Stops at, and returns, the first encoder reload that fails.
    pub fn reload(
        &mut self,
        dtype: DType,
        progress: &crate::progress::ProgressReporter,
    ) -> Result<()> {
        self.clip_l
            .reload(&self.clip_l_path, dtype, "SD3 CLIP-L", progress)?;
        self.clip_g
            .reload(&self.clip_g_path, dtype, "SD3 CLIP-G", progress)?;
        self.t5.reload(&self.t5_path, dtype, progress)?;
        Ok(())
    }

    /// Parks all three encoders.
    ///
    /// # Errors
    /// Stops at, and returns, the first park operation that fails; encoders parked
    /// before the failure stay parked.
    pub fn park_to_cpu(&mut self) -> Result<()> {
        self.clip_l.park_to_cpu(&self.clip_l_path)?;
        self.clip_g.park_to_cpu(&self.clip_g_path)?;
        self.t5.park_to_cpu()?;
        Ok(())
    }

    /// Restores all three encoders so they are ready for [`Self::encode`].
    ///
    /// # Errors
    /// Stops at, and returns, the first unpark operation that fails.
    pub fn unpark_to_gpu(
        &mut self,
        dtype: DType,
        progress: &crate::progress::ProgressReporter,
    ) -> Result<()> {
        self.clip_l
            .unpark_to_gpu(&self.clip_l_path, dtype, "SD3 CLIP-L", progress)?;
        self.clip_g
            .unpark_to_gpu(&self.clip_g_path, dtype, "SD3 CLIP-G", progress)?;
        self.t5.unpark_to_gpu(dtype, progress)?;
        Ok(())
    }

    /// Returns `true` only when every encoder reports itself as parked.
    pub fn is_parked(&self) -> bool {
        self.clip_l.is_parked() && self.clip_g.is_parked() && self.t5.is_parked()
    }

    /// Returns `true` only when every encoder currently holds a live model.
    pub fn is_loaded(&self) -> bool {
        self.clip_l.model.is_some() && self.clip_g.model.is_some() && self.t5.model.is_some()
    }
}
/// Fits a raw CLIP token sequence to exactly `max_len` entries.
///
/// Sequences longer than `max_len` are cut to `max_len`, and the last surviving
/// slot is overwritten with the original final token so the end-of-sequence
/// marker is not lost. Shorter sequences are right-padded with `pad_id`.
///
/// Returns the fitted tokens together with the index of the last non-padding
/// token (`0` when the input is empty).
fn prepare_clip_tokens(mut raw_tokens: Vec<u32>, max_len: usize, pad_id: u32) -> (Vec<u32>, usize) {
    let raw_len = raw_tokens.len();

    if raw_len > max_len {
        // `raw_len > max_len >= 0` guarantees at least one element, so this index is in bounds.
        let eos_id = raw_tokens[raw_len - 1];
        raw_tokens.truncate(max_len);
        if let Some(final_slot) = raw_tokens.last_mut() {
            *final_slot = eos_id;
        }
        tracing::debug!(
            "SD3 CLIP prompt exceeded {} tokens ({} raw); truncated with EOS preserved",
            max_len,
            raw_len,
        );
    }

    // Recorded before padding so it points at the last real token.
    let eos_position = raw_tokens.len().saturating_sub(1);

    // The length can no longer exceed `max_len`, so `resize` only ever appends padding.
    raw_tokens.resize(max_len, pad_id);

    (raw_tokens, eos_position)
}
#[cfg(test)]
mod tests {
    use super::prepare_clip_tokens;
    use std::iter::once;

    const MAX_LEN: usize = 77;
    const PAD_ID: u32 = 0;
    const EOS_ID: u32 = 49407;

    /// A prompt shorter than the limit is right-padded and keeps its EOS index.
    #[test]
    fn pads_short_prompt_to_max_len() {
        let (tokens, eos) =
            prepare_clip_tokens(vec![49406, 10, 20, 30, EOS_ID], MAX_LEN, PAD_ID);

        assert_eq!(tokens.len(), MAX_LEN, "must pad up to max_len");
        assert_eq!(eos, 4, "eos_position tracks the raw EOS slot");
        assert_eq!(tokens[4], EOS_ID, "EOS preserved at original position");
        assert_eq!(tokens[5], PAD_ID, "pads follow the real tokens");
        assert_eq!(tokens.last().copied(), Some(PAD_ID));
    }

    /// A prompt that already has the target length passes through unchanged.
    #[test]
    fn leaves_exact_length_untouched() {
        let raw: Vec<u32> = (1..MAX_LEN as u32).chain(once(EOS_ID)).collect();
        assert_eq!(raw.len(), MAX_LEN);

        let (tokens, eos) = prepare_clip_tokens(raw.clone(), MAX_LEN, PAD_ID);

        assert_eq!(tokens.len(), MAX_LEN);
        assert_eq!(eos, MAX_LEN - 1);
        assert_eq!(tokens, raw);
    }

    /// An overlong prompt is cut to the limit with EOS moved into the final slot.
    #[test]
    fn truncates_overlong_prompt_preserving_eos() {
        let raw: Vec<u32> = (1..=131).chain(once(EOS_ID)).collect();
        assert_eq!(raw.len(), 132);

        let (tokens, eos) = prepare_clip_tokens(raw, MAX_LEN, PAD_ID);

        assert_eq!(tokens.len(), MAX_LEN, "overlong sequence must be truncated");
        assert_eq!(eos, MAX_LEN - 1, "eos_position must land on the last slot");
        assert_eq!(
            tokens[MAX_LEN - 1],
            EOS_ID,
            "EOS must be preserved in the final slot so pooled output reads EOS hidden state",
        );
    }

    /// Empty input becomes pure padding with the EOS index clamped to zero.
    #[test]
    fn handles_empty_input() {
        let (tokens, eos) = prepare_clip_tokens(Vec::new(), MAX_LEN, PAD_ID);

        assert_eq!(tokens.len(), MAX_LEN);
        assert_eq!(eos, 0);
        assert!(tokens.iter().all(|&token| token == PAD_ID));
    }
}