use anyhow::Result;
use candle_core::{DType, Device, Module, Tensor};
use candle_transformers::models::clip;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::Tokenizer;
use super::park;
/// Length, in token ids, of every chunk produced by the chunking helpers in this file
/// (BOS + content + EOS + padding).
pub(crate) const CLIP_CHUNK_LEN: usize = 77;
/// Content-token budget per chunk: `CLIP_CHUNK_LEN` minus the two slots used by BOS and EOS.
pub(crate) const CLIP_MAX_TOKENS_PER_CHUNK: usize = CLIP_CHUNK_LEN - 2;
/// Vocabulary entry looked up to resolve the beginning-of-sequence token id.
pub(crate) const CLIP_BOS_TOKEN: &str = "<|startoftext|>";
/// Vocabulary entry looked up to resolve the end-of-sequence token id.
pub(crate) const CLIP_EOS_TOKEN: &str = "<|endoftext|>";
/// Reports whether chunked long-prompt encoding has been switched on.
///
/// The switch is the `MOLD_LONG_PROMPTS` environment variable. Only the exact value `1`
/// turns it on; an unset, unreadable, or differently-valued variable leaves it off.
pub(crate) fn long_prompts_enabled() -> bool {
    matches!(std::env::var("MOLD_LONG_PROMPTS").as_deref(), Ok("1"))
}
/// Resolves the BOS and EOS token ids from `tokenizer`'s vocabulary.
///
/// Returns `(bos, eos)`. If the tokenizer has no entry for `CLIP_BOS_TOKEN` or
/// `CLIP_EOS_TOKEN`, the hard-coded ids 49406 and 49407 are used respectively.
fn clip_special_ids(tokenizer: &Tokenizer) -> (u32, u32) {
    let id_or = |token: &str, fallback: u32| tokenizer.token_to_id(token).unwrap_or(fallback);
    (id_or(CLIP_BOS_TOKEN, 49406), id_or(CLIP_EOS_TOKEN, 49407))
}
/// Removes at most one leading `bos_id` and at most one trailing `eos_id` from `ids`.
///
/// Interior occurrences of either id are left untouched, and input that carries no
/// specials at its edges is returned unchanged.
fn strip_specials(ids: Vec<u32>, bos_id: u32, eos_id: u32) -> Vec<u32> {
    // Peel the front first, then look at whatever is left for a trailing EOS.
    let without_bos = match ids.as_slice() {
        [head, rest @ ..] if *head == bos_id => rest,
        everything => everything,
    };
    let content = match without_bos {
        [rest @ .., tail] if *tail == eos_id => rest,
        everything => everything,
    };
    content.to_vec()
}
/// Splits content token ids into fixed-length, BOS/EOS-framed chunks.
///
/// Every returned chunk is exactly `CLIP_CHUNK_LEN` ids long and laid out as
/// `[bos_id, content…, eos_id, padding…]`, where the padding id is `eos_id`.
///
/// * `raw_ids` – content ids without BOS/EOS. An empty slice still yields one chunk
///   holding only BOS, EOS and padding.
/// * `max_per_chunk` – content ids per chunk. The value is clamped to
///   `1..=CLIP_CHUNK_LEN - 2`: zero would make `slice::chunks` panic, and anything
///   larger would overflow the fixed chunk length so that the final `resize` cuts off
///   content ids and the EOS marker.
/// * `bos_id` / `eos_id` – ids written at the start and end of each chunk.
pub(crate) fn chunk_token_ids(
    raw_ids: &[u32],
    max_per_chunk: usize,
    bos_id: u32,
    eos_id: u32,
) -> Vec<Vec<u32>> {
    let pad_id = eos_id;
    // Leave room for BOS and EOS in every chunk, and never ask `chunks` for size 0.
    let per_chunk = max_per_chunk.clamp(1, CLIP_CHUNK_LEN - 2);

    // Wraps one window of content ids into a full-length chunk.
    let frame = |content: &[u32]| {
        let mut chunk = Vec::with_capacity(CLIP_CHUNK_LEN);
        chunk.push(bos_id);
        chunk.extend_from_slice(content);
        chunk.push(eos_id);
        chunk.resize(CLIP_CHUNK_LEN, pad_id);
        chunk
    };

    if raw_ids.is_empty() {
        // `chunks` yields nothing for an empty slice; callers still need one chunk.
        return vec![frame(raw_ids)];
    }
    raw_ids.chunks(per_chunk).map(frame).collect()
}
/// Tokenizes `prompt` and returns one `[1, CLIP_CHUNK_LEN]` tensor of token ids per chunk.
///
/// The prompt is encoded without asking the tokenizer for special tokens, any BOS/EOS
/// that still shows up at the edges is stripped, and the remaining ids are framed by
/// `chunk_token_ids` with a budget of `CLIP_MAX_TOKENS_PER_CHUNK` content ids per chunk.
///
/// # Errors
/// Fails if tokenization fails or if a tensor cannot be created on `device`.
pub(crate) fn tokenize_chunks(
    prompt: &str,
    tokenizer: &Tokenizer,
    device: &Device,
) -> Result<Vec<Tensor>> {
    let (bos_id, eos_id) = clip_special_ids(tokenizer);

    let encoding = tokenizer
        .encode(prompt, false)
        .map_err(|e| anyhow::anyhow!("CLIP tokenization failed: {e}"))?;
    let content = strip_specials(encoding.get_ids().to_vec(), bos_id, eos_id);

    let mut tensors = Vec::new();
    for ids in chunk_token_ids(&content, CLIP_MAX_TOKENS_PER_CHUNK, bos_id, eos_id) {
        let row = Tensor::new(ids.as_slice(), device)?;
        // Add the leading batch dimension.
        tensors.push(row.unsqueeze(0)?);
    }
    Ok(tensors)
}
/// Hyper-parameters used to build the CLIP text transformer in this file
/// (the weights the loaders label "CLIP-L").
pub fn config() -> clip::text_model::ClipTextConfig {
    // Vocabulary and sequence limits.
    let vocab_size = 49408;
    let max_position_embeddings = 77;

    // Transformer dimensions.
    let embed_dim = 768;
    let projection_dim = 768;
    let intermediate_size = 3072;
    let num_hidden_layers = 12;
    let num_attention_heads = 12;

    let activation = clip::text_model::Activation::QuickGelu;

    clip::text_model::ClipTextConfig {
        vocab_size,
        projection_dim,
        activation,
        intermediate_size,
        embed_dim,
        max_position_embeddings,
        pad_with: None,
        num_hidden_layers,
        num_attention_heads,
    }
}
/// CLIP text encoder together with its tokenizer and the bookkeeping needed to drop,
/// park, and restore its weights.
pub(crate) struct ClipEncoder {
/// The transformer, or `None` after `drop_weights` / `park_to_cpu` until the weights
/// are rebuilt by `reload` or `unpark_to_gpu`.
pub model: Option<clip::text_model::ClipTextTransformer>,
/// Shared tokenizer; `tokenizer_arc` hands out further handles to it.
pub tokenizer: Arc<Tokenizer>,
/// Device the weights are loaded onto and on which input id tensors are created.
pub device: Device,
/// Result of `crate::device::is_gpu(device)` captured when the encoder was loaded.
pub on_gpu: bool,
/// Weights path remembered from loading; used by `park_to_cpu` and `unpark_to_gpu`.
encoder_path: PathBuf,
/// Tensors obtained from `park::load_tensors_to_cpu`, kept so `unpark_to_gpu` can
/// rebuild the model from them; `None` when nothing has been parked.
parked_tensors: Option<HashMap<String, Tensor>>,
}
impl ClipEncoder {
/// Loads the encoder with neither a cached tokenizer nor cached tensors.
///
/// Thin wrapper around `load_with_tokenizer`; see
/// `load_with_tokenizer_and_tensors` for the failure conditions.
#[allow(dead_code)]
pub fn load(
    encoder_path: &PathBuf,
    tokenizer_path: &PathBuf,
    device: &Device,
    dtype: DType,
    progress: &crate::progress::ProgressReporter,
) -> Result<Self> {
    let no_cached_tokenizer = None;
    Self::load_with_tokenizer(
        encoder_path,
        tokenizer_path,
        device,
        dtype,
        progress,
        no_cached_tokenizer,
    )
}
/// Loads the encoder, reusing `cached_tokenizer` when one is supplied.
///
/// Weights always go through the regular loading path here because no cached tensors
/// are forwarded to `load_with_tokenizer_and_tensors`.
pub fn load_with_tokenizer(
    encoder_path: &PathBuf,
    tokenizer_path: &PathBuf,
    device: &Device,
    dtype: DType,
    progress: &crate::progress::ProgressReporter,
    cached_tokenizer: Option<Arc<Tokenizer>>,
) -> Result<Self> {
    let no_cached_tensors = None;
    Self::load_with_tokenizer_and_tensors(
        encoder_path,
        tokenizer_path,
        device,
        dtype,
        progress,
        cached_tokenizer,
        no_cached_tensors,
    )
}
/// Builds the encoder, preferring caller-supplied caches over disk reads.
///
/// * `cached_tensors` – when present, the model is built from these tensors via
///   `park::varbuilder_from_parked`; otherwise the weights at `encoder_path` are loaded
///   with progress reporting under the label "CLIP-L".
/// * `cached_tokenizer` – when present it is used as-is; otherwise the tokenizer is
///   read from `tokenizer_path`.
///
/// The returned encoder has its model resident and no parked tensors.
///
/// # Errors
/// Fails if the weights cannot be loaded, the transformer cannot be constructed, or
/// the tokenizer file cannot be read.
pub fn load_with_tokenizer_and_tensors(
    encoder_path: &PathBuf,
    tokenizer_path: &PathBuf,
    device: &Device,
    dtype: DType,
    progress: &crate::progress::ProgressReporter,
    cached_tokenizer: Option<Arc<Tokenizer>>,
    cached_tensors: Option<Arc<HashMap<String, Tensor>>>,
) -> Result<Self> {
    let vb = match cached_tensors {
        Some(tensors) => park::varbuilder_from_parked(tensors.as_ref(), dtype, device),
        None => crate::weight_loader::load_safetensors_with_progress(
            std::slice::from_ref(encoder_path),
            dtype,
            device,
            "CLIP-L",
            progress,
        )?,
    };
    let model = clip::text_model::ClipTextTransformer::new(vb.pp("text_model"), &config())?;

    let tokenizer = if let Some(shared) = cached_tokenizer {
        shared
    } else {
        let loaded = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow::anyhow!("failed to load CLIP tokenizer: {e}"))?;
        Arc::new(loaded)
    };

    Ok(Self {
        model: Some(model),
        tokenizer,
        device: device.clone(),
        on_gpu: crate::device::is_gpu(device),
        encoder_path: encoder_path.clone(),
        parked_tensors: None,
    })
}
/// Returns another shared handle to the tokenizer (a reference-count bump, not a copy).
pub fn tokenizer_arc(&self) -> Arc<Tokenizer> {
    Arc::clone(&self.tokenizer)
}
/// Encodes `prompt` and returns the embedding on `target_device` as `target_dtype`.
///
/// Uses the chunked path when `long_prompts_enabled()` is true and the truncating
/// path otherwise.
///
/// # Errors
/// Propagates any failure from the selected encoding path or from the device / dtype
/// conversion.
pub fn encode(
    &mut self,
    prompt: &str,
    target_device: &Device,
    target_dtype: DType,
) -> Result<Tensor> {
    let encoded = if long_prompts_enabled() {
        self.encode_chunked(prompt)
    } else {
        self.encode_truncated(prompt)
    };
    let on_target = encoded?.to_device(target_device)?;
    Ok(on_target.to_dtype(target_dtype)?)
}
fn encode_truncated(&self, prompt: &str) -> Result<Tensor> {
let clip = self
.model
.as_ref()
.ok_or_else(|| anyhow::anyhow!("CLIP model unavailable"))?;
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(|e| anyhow::anyhow!("CLIP tokenization failed: {e}"))?
.get_ids()
.to_vec();
tokens.truncate(CLIP_CHUNK_LEN);
let input_ids = Tensor::new(&tokens[..], &self.device)?.unsqueeze(0)?;
Ok(clip.forward(&input_ids)?)
}
/// Encodes `prompt` chunk by chunk and averages the per-chunk outputs.
///
/// A prompt that fits in one chunk returns that chunk's output unchanged. With several
/// chunks the outputs are concatenated along dimension 0 and reduced with a
/// keep-dimension mean over that same dimension.
///
/// # Errors
/// Fails when the weights are not resident (`model` is `None`), or when tokenization,
/// a forward pass, or the tensor reduction fails.
fn encode_chunked(&self, prompt: &str) -> Result<Tensor> {
    let Some(clip) = self.model.as_ref() else {
        return Err(anyhow::anyhow!("CLIP model unavailable"));
    };

    let chunks = tokenize_chunks(prompt, &self.tokenizer, &self.device)?;
    debug_assert!(!chunks.is_empty(), "tokenize_chunks always emits ≥1 chunk");

    // Stops at the first chunk whose forward pass fails.
    let mut pooled = chunks
        .iter()
        .map(|chunk| clip.forward(chunk))
        .collect::<Result<Vec<_>, _>>()?;

    match pooled.len() {
        1 => Ok(pooled.pop().expect("len==1")),
        _ => {
            let stacked = Tensor::cat(&pooled, 0)?;
            Ok(stacked.mean_keepdim(0)?)
        }
    }
}
/// Releases both the resident model and any parked tensors.
///
/// Afterwards `is_parked()` is false and the weights must be loaded again before the
/// encoder can be used.
pub fn drop_weights(&mut self) {
    drop(self.model.take());
    drop(self.parked_tensors.take());
}
/// Rebuilds the model from the weights at `encoder_path` on the encoder's own device.
///
/// Only `model` is replaced; the remembered weights path and any parked tensors are
/// left as they were. If loading or construction fails, `model` is not modified.
///
/// # Errors
/// Fails if the weights cannot be loaded or the transformer cannot be constructed.
pub fn reload(
    &mut self,
    encoder_path: &PathBuf,
    dtype: DType,
    progress: &crate::progress::ProgressReporter,
) -> Result<()> {
    let weights = crate::weight_loader::load_safetensors_with_progress(
        std::slice::from_ref(encoder_path),
        dtype,
        &self.device,
        "CLIP-L",
        progress,
    )?;
    let rebuilt =
        clip::text_model::ClipTextTransformer::new(weights.pp("text_model"), &config())?;
    self.model = Some(rebuilt);
    Ok(())
}
/// Releases the resident model while keeping a CPU copy of the weights so that
/// `unpark_to_gpu` can rebuild it later.
///
/// A CPU copy that is already held is reused. `unpark_to_gpu` leaves `parked_tensors`
/// in place, so after a park → unpark cycle the copy from `encoder_path` is still
/// there, and calling `park::load_tensors_to_cpu` on that same path again would only
/// repeat the work. When no copy exists yet it is loaded from `encoder_path`.
///
/// # Errors
/// Fails if the tensors cannot be loaded; in that case the model stays resident and
/// nothing is changed.
pub fn park_to_cpu(&mut self) -> Result<()> {
    if self.parked_tensors.is_none() {
        let parked = park::load_tensors_to_cpu(std::slice::from_ref(&self.encoder_path))?;
        self.parked_tensors = Some(parked);
    }
    // Drop the model only once a CPU copy is guaranteed to exist.
    self.model = None;
    Ok(())
}
/// Makes sure the model is resident again on the encoder's device.
///
/// Does nothing when the model is already present. Otherwise the model is rebuilt from
/// the parked tensors when there are any (they are kept afterwards), and from the
/// remembered weights path via `reload` when there are none.
///
/// # Errors
/// Fails if the transformer cannot be constructed or the fallback reload fails.
pub fn unpark_to_gpu(
    &mut self,
    dtype: DType,
    progress: &crate::progress::ProgressReporter,
) -> Result<()> {
    if self.model.is_some() {
        return Ok(());
    }
    match self.parked_tensors.as_ref() {
        Some(parked) => {
            let vb = park::varbuilder_from_parked(parked, dtype, &self.device);
            let rebuilt =
                clip::text_model::ClipTextTransformer::new(vb.pp("text_model"), &config())?;
            self.model = Some(rebuilt);
            Ok(())
        }
        None => {
            // Nothing parked: fall back to the weights path remembered at load time.
            let path = self.encoder_path.clone();
            self.reload(&path, dtype, progress)
        }
    }
}
/// True exactly when the model has been released but parked tensors are still held.
pub fn is_parked(&self) -> bool {
    matches!((&self.model, &self.parked_tensors), (None, Some(_)))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an encoder with no model, no parked tensors, a placeholder tokenizer and
    /// a weights path that is not expected to exist.
    fn make_test_encoder() -> ClipEncoder {
        let placeholder_model = tokenizers::models::wordpiece::WordPiece::default();
        ClipEncoder {
            model: None,
            tokenizer: Arc::new(tokenizers::Tokenizer::new(placeholder_model)),
            device: Device::Cpu,
            on_gpu: false,
            encoder_path: std::env::temp_dir().join("nonexistent-clip-tokenizer.json"),
            parked_tensors: None,
        }
    }

    /// `is_parked` follows `parked_tensors` while the model is absent, and
    /// `drop_weights` clears the parked state.
    #[test]
    fn test_is_parked_state_machine() {
        let mut encoder = make_test_encoder();
        assert!(!encoder.is_parked());

        encoder.parked_tensors = Some(HashMap::new());
        assert!(encoder.is_parked());

        encoder.drop_weights();
        assert!(!encoder.is_parked());
        assert!(encoder.parked_tensors.is_none());
    }

    /// Parking an already-parked encoder must not touch the existing parked map.
    #[test]
    fn test_park_when_already_parked_is_noop() {
        let canary = Tensor::zeros((1,), DType::F32, &Device::Cpu).unwrap();

        let mut encoder = make_test_encoder();
        encoder.parked_tensors = Some(HashMap::from([("canary".to_string(), canary)]));
        encoder.model = None;
        assert!(encoder.is_parked());

        encoder.park_to_cpu().expect("re-park is noop");

        assert!(encoder.is_parked());
        let parked = encoder.parked_tensors.as_ref().unwrap();
        assert!(
            parked.contains_key("canary"),
            "re-park preserved the existing parked map"
        );
    }
}
#[cfg(test)]
mod chunking_tests {
    use super::*;

    const BOS_ID: u32 = 49406;
    const EOS_ID: u32 = 49407;

    /// Printed when the optional real-tokenizer fixture cannot be found.
    const SKIP_NOTICE: &str = "skipping: no CLIP tokenizer fixture (set MOLD_TEST_CLIP_TOKENIZER \
        or place tokenizer.json under ~/.mold/models/shared/clip-vit-large-patch14/)";

    /// Chunks `raw` with the production content budget and the fixture BOS/EOS ids.
    fn chunk_default(raw: &[u32]) -> Vec<Vec<u32>> {
        chunk_token_ids(raw, CLIP_MAX_TOKENS_PER_CHUNK, BOS_ID, EOS_ID)
    }

    /// True when every id in `ids` equals `expected`.
    fn filled_with(ids: &[u32], expected: u32) -> bool {
        ids.iter().all(|&id| id == expected)
    }

    #[test]
    fn tokenize_chunks_short_prompt_one_chunk() {
        let raw = (1..=10).collect::<Vec<u32>>();
        let chunks = chunk_default(&raw);
        assert_eq!(chunks.len(), 1, "≤75 content tokens fit in one chunk");

        let only = &chunks[0];
        assert_eq!(only.len(), CLIP_CHUNK_LEN);
        assert_eq!(only[0], BOS_ID);
        assert_eq!(&only[1..=10], raw.as_slice());
        assert_eq!(only[11], EOS_ID);
        assert!(filled_with(&only[12..], EOS_ID));
    }

    #[test]
    fn tokenize_chunks_two_chunks() {
        let raw = (1..=100).collect::<Vec<u32>>();
        let chunks = chunk_default(&raw);
        assert_eq!(chunks.len(), 2, "100 tokens straddle the 75-token boundary");

        let (first, second) = (&chunks[0], &chunks[1]);
        assert_eq!(first[0], BOS_ID);
        assert_eq!(&first[1..=75], &raw[..75]);
        assert_eq!(first[76], EOS_ID);

        assert_eq!(second[0], BOS_ID);
        assert_eq!(&second[1..=25], &raw[75..]);
        assert_eq!(second[26], EOS_ID);
        assert!(filled_with(&second[27..], EOS_ID));
    }

    #[test]
    fn tokenize_chunks_exact_75_boundary() {
        let raw = (1..=75).collect::<Vec<u32>>();
        let chunks = chunk_default(&raw);
        assert_eq!(
            chunks.len(),
            1,
            "exactly 75 content tokens fit in one chunk"
        );

        let only = &chunks[0];
        assert_eq!(only.len(), CLIP_CHUNK_LEN);
        assert_eq!(only[0], BOS_ID);
        assert_eq!(only[76], EOS_ID, "EOS lands in the last slot");
    }

    #[test]
    fn tokenize_chunks_empty_prompt() {
        let chunks = chunk_default(&[]);
        assert_eq!(chunks.len(), 1);

        let only = &chunks[0];
        assert_eq!(only.len(), CLIP_CHUNK_LEN);
        assert_eq!(only[0], BOS_ID);
        assert_eq!(only[1], EOS_ID);
        assert!(
            filled_with(&only[2..], EOS_ID),
            "remaining slots are EOS-padded",
        );
    }

    #[test]
    fn tokenize_chunks_padding_uses_eos() {
        const CUSTOM_EOS: u32 = 12345;
        let raw: [u32; 3] = [10, 20, 30];
        let chunks = chunk_token_ids(&raw, CLIP_MAX_TOKENS_PER_CHUNK, BOS_ID, CUSTOM_EOS);
        assert_eq!(chunks.len(), 1);

        let only = &chunks[0];
        assert_eq!(only[0], BOS_ID);
        assert_eq!(only[1..=3], [10, 20, 30]);
        assert_eq!(only[4], CUSTOM_EOS, "EOS marker after content");
        assert!(
            filled_with(&only[5..], CUSTOM_EOS),
            "pad_id == eos_id: padding fills with EOS",
        );
    }

    #[test]
    fn tokenize_chunks_three_chunks_at_150_tokens() {
        // 150 content ids fill exactly two chunks…
        let exact = (1..=150).collect::<Vec<u32>>();
        let chunks = chunk_default(&exact);
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[1][0], BOS_ID);
        assert_eq!(&chunks[1][1..=75], &exact[75..]);
        assert_eq!(chunks[1][76], EOS_ID);

        // …and one more id spills into a third.
        let spill = (1..=151).collect::<Vec<u32>>();
        let chunks = chunk_default(&spill);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[2][1], 151, "last chunk holds the trailing token");
        assert_eq!(chunks[2][2], EOS_ID, "EOS immediately after the lone token");
    }

    #[test]
    fn strip_specials_removes_leading_bos_and_trailing_eos() {
        let framed = vec![BOS_ID, 10, 20, EOS_ID];
        assert_eq!(strip_specials(framed, BOS_ID, EOS_ID), vec![10, 20]);
    }

    #[test]
    fn strip_specials_leaves_clean_input_alone() {
        let clean = vec![10, 20, 30];
        assert_eq!(strip_specials(clean.clone(), BOS_ID, EOS_ID), clean);
    }

    #[test]
    fn strip_specials_handles_only_specials() {
        let only_specials = vec![BOS_ID, EOS_ID];
        assert_eq!(
            strip_specials(only_specials, BOS_ID, EOS_ID),
            Vec::<u32>::new()
        );
    }

    /// Serializes the tests in this module that mutate `MOLD_LONG_PROMPTS`. A poisoned
    /// lock is recovered so one failing test does not cascade into the others.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
        match LOCK.get_or_init(|| std::sync::Mutex::new(())).lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Sets (`Some`) or clears (`None`) `MOLD_LONG_PROMPTS`, reads the flag, clears the
    /// variable again, and returns what was read. Runs entirely under `env_lock`.
    fn flag_with_env(value: Option<&str>) -> bool {
        let _guard = env_lock();
        // SAFETY: writers of this variable inside this module are serialized by
        // `env_lock`; assumes no other thread reads the process environment through
        // non-Rust APIs while tests run.
        unsafe {
            match value {
                Some(v) => std::env::set_var("MOLD_LONG_PROMPTS", v),
                None => std::env::remove_var("MOLD_LONG_PROMPTS"),
            }
        }
        let enabled = long_prompts_enabled();
        // SAFETY: same reasoning as above; the lock is still held.
        unsafe { std::env::remove_var("MOLD_LONG_PROMPTS") };
        enabled
    }

    #[test]
    fn long_prompts_enabled_env_default_off() {
        assert!(!flag_with_env(None), "default must be off");
    }

    #[test]
    fn long_prompts_enabled_env_set_to_1() {
        assert!(
            flag_with_env(Some("1")),
            "MOLD_LONG_PROMPTS=1 must enable chunking"
        );
    }

    #[test]
    fn long_prompts_enabled_env_other_value_off() {
        assert!(
            !flag_with_env(Some("true")),
            "only the literal '1' should enable chunking"
        );
    }

    /// Looks for a real CLIP tokenizer, first at `$MOLD_TEST_CLIP_TOKENIZER`, then under
    /// `$HOME/.mold/...`. Returns the first candidate that exists and parses.
    fn try_load_clip_tokenizer() -> Option<Tokenizer> {
        let explicit = std::env::var("MOLD_TEST_CLIP_TOKENIZER").ok();
        let under_home = std::env::var("HOME")
            .ok()
            .map(|h| format!("{h}/.mold/models/shared/clip-vit-large-patch14/tokenizer.json"));

        explicit
            .into_iter()
            .chain(under_home)
            .filter(|path| std::path::Path::new(path).exists())
            .find_map(|path| Tokenizer::from_file(&path).ok())
    }

    #[test]
    fn tokenize_chunks_short_prompt_with_real_tokenizer() {
        let tokenizer = match try_load_clip_tokenizer() {
            Some(tokenizer) => tokenizer,
            None => {
                eprintln!("{}", SKIP_NOTICE);
                return;
            }
        };

        let chunks = tokenize_chunks("a cat", &tokenizer, &Device::Cpu)
            .expect("real tokenizer must accept a short prompt");
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].dims(), &[1, CLIP_CHUNK_LEN]);
    }

    #[test]
    fn tokenize_chunks_long_prompt_with_real_tokenizer_grows_chunks() {
        let tokenizer = match try_load_clip_tokenizer() {
            Some(tokenizer) => tokenizer,
            None => {
                eprintln!("{}", SKIP_NOTICE);
                return;
            }
        };

        // 200 repetitions of "alpha" separated by single spaces, no trailing space.
        let long_prompt = vec!["alpha"; 200].join(" ");

        let short =
            tokenize_chunks("a cat", &tokenizer, &Device::Cpu).expect("short prompt tokenizes");
        let long = tokenize_chunks(&long_prompt, &tokenizer, &Device::Cpu)
            .expect("long prompt tokenizes");

        assert!(
            long.len() > short.len(),
            "long prompt produces more chunks than a short one ({} vs {})",
            long.len(),
            short.len(),
        );
        for chunk in long.iter() {
            assert_eq!(chunk.dims(), &[1, CLIP_CHUNK_LEN]);
        }
    }
}