use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokenizers::Tokenizer;
use super::park;
use super::qwen3_bf16::{Bf16Qwen3Encoder, Qwen3BF16Config};
use super::qwen3_gguf::GgufQwen3Encoder;
/// The two Qwen3 encoder backends this module can drive.
///
/// Both variants expose `forward` and `forward_with_layers`; the inherent
/// methods on this enum forward to whichever backend is held.
pub(crate) enum Qwen3Model {
/// Encoder built by `Bf16Qwen3Encoder::load` from safetensors weights.
BF16(Bf16Qwen3Encoder),
/// Encoder built by `GgufQwen3Encoder::load` from a GGUF file.
Quantized(GgufQwen3Encoder),
}
impl Qwen3Model {
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
match self {
Self::BF16(m) => m.forward(input_ids),
Self::Quantized(m) => m.forward(input_ids),
}
}
pub fn forward_with_layers(
&mut self,
input_ids: &Tensor,
layer_indices: &[usize],
) -> Result<Tensor> {
match self {
Self::BF16(m) => m.forward_with_layers(input_ids, layer_indices),
Self::Quantized(m) => m.forward_with_layers(input_ids, layer_indices),
}
}
}
/// Qwen3 text encoder together with its tokenizer and the bookkeeping needed
/// to drop, park, and later restore its weights.
pub(crate) struct Qwen3Encoder {
/// Loaded backend; `None` once the weights have been dropped or parked.
pub model: Option<Qwen3Model>,
/// Tokenizer used to turn the templated prompt into token ids.
pub tokenizer: Arc<Tokenizer>,
/// Device the weights are loaded onto and on which input ids are created.
pub device: Device,
/// Result of `crate::device::is_gpu(device)` captured at load time.
pub on_gpu: bool,
/// `true` when built by the GGUF loader, `false` for the safetensors loader.
pub is_quantized: bool,
// Weight file paths remembered so the model can be rebuilt later
// (a single GGUF path for the quantized backend).
encoder_paths: Vec<PathBuf>,
// Dtype handed to the safetensors loader; the GGUF loader stores `DType::F32`.
dtype: DType,
// Configuration passed to `Bf16Qwen3Encoder::load`.
bf16_config: Qwen3BF16Config,
// Tensors returned by `park::load_tensors_to_cpu`, kept so the model can be
// rebuilt from them instead of from the weight files.
parked_tensors: Option<HashMap<String, Tensor>>,
}
/// Wraps `prompt` in a single-turn chat template: a user turn containing the
/// prompt verbatim, followed by an opened (empty) assistant turn.
fn format_prompt_for_qwen3(prompt: &str) -> String {
    // Text emitted before the prompt.
    const USER_TURN_OPEN: &str = "<|im_start|>user\n";
    // Text emitted after the prompt: closes the user turn, opens the assistant turn.
    const ASSISTANT_TURN_OPEN: &str = "<|im_end|>\n<|im_start|>assistant\n";

    [USER_TURN_OPEN, prompt, ASSISTANT_TURN_OPEN].concat()
}
/// Builds the same template as `format_prompt_for_qwen3` and appends an
/// empty `<think>` block after the opened assistant turn.
fn format_prompt_for_flux2(prompt: &str) -> String {
    let mut templated = format_prompt_for_qwen3(prompt);
    templated.push_str("<think>\n\n</think>\n\n");
    templated
}
impl Qwen3Encoder {
    /// Loads the safetensors encoder and reads the tokenizer from `tokenizer_path`.
    ///
    /// Thin wrapper over [`Self::load_bf16_with_tokenizer`] that supplies no
    /// pre-loaded tokenizer.
    ///
    /// # Errors
    /// Same as [`Self::load_bf16_with_tokenizer`].
    // NOTE(review): presumably unused in some build configurations — confirm
    // before removing the allow.
    #[allow(dead_code)]
    pub fn load_bf16(
        encoder_paths: &[PathBuf],
        tokenizer_path: &PathBuf,
        device: &Device,
        dtype: DType,
        bf16_config: &Qwen3BF16Config,
        progress: &crate::progress::ProgressReporter,
    ) -> Result<Self> {
        Self::load_bf16_with_tokenizer(
            encoder_paths,
            tokenizer_path,
            None,
            device,
            dtype,
            bf16_config,
            progress,
        )
    }

    /// Loads the safetensors encoder, reusing `tokenizer` when one is given.
    ///
    /// The weights at `encoder_paths` are loaded with `dtype` onto `device`
    /// (reporting through `progress`) and handed to `Bf16Qwen3Encoder::load`
    /// together with `bf16_config`. When `tokenizer` is `None` it is read from
    /// `tokenizer_path` instead.
    ///
    /// # Errors
    /// Fails if the weights cannot be loaded, the encoder cannot be built, or
    /// the tokenizer file cannot be read.
    #[allow(clippy::too_many_arguments)]
    pub fn load_bf16_with_tokenizer(
        encoder_paths: &[PathBuf],
        tokenizer_path: &PathBuf,
        tokenizer: Option<Arc<Tokenizer>>,
        device: &Device,
        dtype: DType,
        bf16_config: &Qwen3BF16Config,
        progress: &crate::progress::ProgressReporter,
    ) -> Result<Self> {
        let vb = crate::weight_loader::load_safetensors_with_progress(
            encoder_paths,
            dtype,
            device,
            "Qwen3 encoder",
            progress,
        )?;
        let model = Qwen3Model::BF16(Bf16Qwen3Encoder::load(bf16_config, vb)?);
        // Resolved after the weights so that weight errors are reported first.
        let tokenizer = Self::resolve_tokenizer(tokenizer, tokenizer_path)?;
        Ok(Self {
            model: Some(model),
            tokenizer,
            device: device.clone(),
            on_gpu: crate::device::is_gpu(device),
            is_quantized: false,
            encoder_paths: encoder_paths.to_vec(),
            dtype,
            bf16_config: *bf16_config,
            parked_tensors: None,
        })
    }

    /// Loads the GGUF encoder and reads the tokenizer from `tokenizer_path`.
    ///
    /// Thin wrapper over [`Self::load_gguf_with_tokenizer`] that supplies no
    /// pre-loaded tokenizer.
    ///
    /// # Errors
    /// Same as [`Self::load_gguf_with_tokenizer`].
    // NOTE(review): presumably unused in some build configurations — confirm
    // before removing the allow.
    #[allow(dead_code)]
    pub fn load_gguf(
        gguf_path: &Path,
        tokenizer_path: &PathBuf,
        device: &Device,
        bf16_config: &Qwen3BF16Config,
    ) -> Result<Self> {
        Self::load_gguf_with_tokenizer(gguf_path, tokenizer_path, None, device, bf16_config)
    }

    /// Loads the GGUF encoder, reusing `tokenizer` when one is given.
    ///
    /// `bf16_config` is not used to build the GGUF model; it is only stored
    /// on the returned value.
    ///
    /// # Errors
    /// Fails if the GGUF file cannot be loaded or the tokenizer file cannot
    /// be read.
    pub fn load_gguf_with_tokenizer(
        gguf_path: &Path,
        tokenizer_path: &PathBuf,
        tokenizer: Option<Arc<Tokenizer>>,
        device: &Device,
        bf16_config: &Qwen3BF16Config,
    ) -> Result<Self> {
        let model = Qwen3Model::Quantized(GgufQwen3Encoder::load(gguf_path, device)?);
        let tokenizer = Self::resolve_tokenizer(tokenizer, tokenizer_path)?;
        Ok(Self {
            model: Some(model),
            tokenizer,
            device: device.clone(),
            on_gpu: crate::device::is_gpu(device),
            is_quantized: true,
            encoder_paths: vec![gguf_path.to_path_buf()],
            // Within this impl the stored dtype is only read on the
            // safetensors reload/unpark paths.
            dtype: DType::F32,
            bf16_config: *bf16_config,
            parked_tensors: None,
        })
    }

    /// Encodes `prompt` using the plain chat template.
    ///
    /// Returns the embedding moved to `target_device` and converted to
    /// `target_dtype`, together with the number of tokens that were encoded.
    ///
    /// # Errors
    /// Fails if the weights are not loaded, tokenization fails, or any tensor
    /// operation fails.
    pub fn encode(
        &mut self,
        prompt: &str,
        target_device: &Device,
        target_dtype: DType,
    ) -> Result<(Tensor, usize)> {
        let formatted = format_prompt_for_qwen3(prompt);
        self.encode_formatted(&formatted, target_device, target_dtype, None)
    }

    /// Encodes `prompt` using the template with the empty `<think>` block and
    /// calls the model's `forward_with_layers` with `layer_indices`.
    ///
    /// Returns the embedding moved to `target_device` and converted to
    /// `target_dtype`, together with the number of tokens that were encoded.
    ///
    /// # Errors
    /// Fails if the weights are not loaded, tokenization fails, or any tensor
    /// operation fails.
    pub fn encode_with_layers(
        &mut self,
        prompt: &str,
        target_device: &Device,
        target_dtype: DType,
        layer_indices: &[usize],
    ) -> Result<(Tensor, usize)> {
        let formatted = format_prompt_for_flux2(prompt);
        self.encode_formatted(&formatted, target_device, target_dtype, Some(layer_indices))
    }

    /// Clears both the loaded model and any parked CPU tensors.
    pub fn drop_weights(&mut self) {
        self.model = None;
        self.parked_tensors = None;
    }

    /// Rebuilds the model from the stored weight paths.
    ///
    /// The quantized backend is reloaded from the first stored path; the
    /// safetensors backend is reloaded from all stored paths with the stored
    /// dtype and config. `self.model` is only replaced once loading succeeds.
    ///
    /// # Errors
    /// Fails if no weight path is stored for the quantized backend, or if
    /// loading the weights or building the encoder fails.
    pub fn reload(&mut self, progress: &crate::progress::ProgressReporter) -> Result<()> {
        let model = if self.is_quantized {
            // Checked access: an empty path list becomes an error, not a panic.
            let gguf_path = self.encoder_paths.first().ok_or_else(|| {
                anyhow::anyhow!("Qwen3 encoder has no weight path to reload from")
            })?;
            Qwen3Model::Quantized(GgufQwen3Encoder::load(gguf_path, &self.device)?)
        } else {
            let vb = crate::weight_loader::load_safetensors_with_progress(
                &self.encoder_paths,
                self.dtype,
                &self.device,
                "Qwen3 encoder",
                progress,
            )?;
            Qwen3Model::BF16(Bf16Qwen3Encoder::load(&self.bf16_config, vb)?)
        };
        self.model = Some(model);
        Ok(())
    }

    /// Releases the loaded model while keeping (or creating) a CPU copy of
    /// the weights for a later [`Self::unpark_to_gpu`].
    ///
    /// The quantized backend is not parked; its weights are simply dropped.
    ///
    /// # Errors
    /// Fails if the weight files cannot be loaded to the CPU.
    pub fn park_to_cpu(&mut self) -> Result<()> {
        // `unpark_to_gpu` leaves the CPU copy in place, so when one exists it
        // is enough to release the model; loading the same paths again would
        // only repeat the disk read.
        if self.parked_tensors.is_some() {
            self.model = None;
            return Ok(());
        }
        if self.is_quantized {
            self.drop_weights();
            return Ok(());
        }
        let parked = park::load_tensors_to_cpu(&self.encoder_paths)?;
        self.parked_tensors = Some(parked);
        self.model = None;
        Ok(())
    }

    /// Makes sure a model is loaded again.
    ///
    /// Does nothing when a model is already present. Otherwise the model is
    /// rebuilt from the parked CPU tensors when available (they are kept, not
    /// consumed), and from the weight files via [`Self::reload`] when not.
    ///
    /// # Errors
    /// Fails if the encoder cannot be rebuilt.
    pub fn unpark_to_gpu(&mut self, progress: &crate::progress::ProgressReporter) -> Result<()> {
        if self.model.is_some() {
            return Ok(());
        }
        let Some(parked) = self.parked_tensors.as_ref() else {
            return self.reload(progress);
        };
        let vb = park::varbuilder_from_parked(parked, self.dtype, &self.device);
        self.model = Some(Qwen3Model::BF16(Bf16Qwen3Encoder::load(
            &self.bf16_config,
            vb,
        )?));
        Ok(())
    }

    /// Returns `true` when no model is loaded but parked CPU tensors exist.
    pub fn is_parked(&self) -> bool {
        self.model.is_none() && self.parked_tensors.is_some()
    }

    /// Returns the supplied tokenizer, or loads one from `tokenizer_path`
    /// when none was supplied.
    fn resolve_tokenizer(
        tokenizer: Option<Arc<Tokenizer>>,
        tokenizer_path: &Path,
    ) -> Result<Arc<Tokenizer>> {
        match tokenizer {
            Some(shared) => Ok(shared),
            None => Tokenizer::from_file(tokenizer_path)
                .map(Arc::new)
                .map_err(|e| anyhow::anyhow!("failed to load Qwen3 tokenizer: {e}")),
        }
    }

    /// Tokenizes an already-templated prompt, runs the model, and converts
    /// the result for the caller.
    ///
    /// `layer_indices` selects `forward_with_layers` when `Some` and plain
    /// `forward` when `None`.
    fn encode_formatted(
        &mut self,
        formatted: &str,
        target_device: &Device,
        target_dtype: DType,
        layer_indices: Option<&[usize]>,
    ) -> Result<(Tensor, usize)> {
        // Checked first so missing weights are reported before tokenizing.
        let model = self
            .model
            .as_mut()
            .ok_or_else(|| anyhow::anyhow!("Qwen3 model unavailable (weights dropped)"))?;
        let tokens = self
            .tokenizer
            .encode(formatted, true)
            .map_err(|e| anyhow::anyhow!("Qwen3 tokenization failed: {e}"))?
            .get_ids()
            .to_vec();
        let token_count = tokens.len();
        // Input ids form a single row: shape (1, token_count).
        let input_ids = Tensor::from_vec(tokens, (1, token_count), &self.device)?;
        let emb = match layer_indices {
            Some(indices) => model.forward_with_layers(&input_ids, indices)?,
            None => model.forward(&input_ids)?,
        };
        let emb = emb.to_device(target_device)?.to_dtype(target_dtype)?;
        Ok((emb, token_count))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Marker fragments shared by the assertions below.
    const USER_HEADER: &str = "<|im_start|>user\n";
    const ASSISTANT_HEADER: &str = "<|im_start|>assistant\n";
    const TURN_END: &str = "<|im_end|>";
    const EMPTY_THINK: &str = "<think>\n\n</think>\n\n";

    /// The plain template opens a user turn, carries the prompt, ends on the
    /// assistant header, and has no thinking block.
    #[test]
    fn z_image_chat_template() {
        let templated = format_prompt_for_qwen3("a cat");
        assert!(templated.starts_with(USER_HEADER));
        assert!(templated.contains("a cat"));
        assert!(templated.ends_with(ASSISTANT_HEADER));
        assert!(!templated.contains("<think>"));
    }

    /// The FLUX.2 template keeps the chat structure and ends with the empty
    /// thinking block.
    #[test]
    fn flux2_chat_template_includes_thinking() {
        let templated = format_prompt_for_flux2("a sunset");
        assert!(templated.starts_with(USER_HEADER));
        assert!(templated.contains("a sunset"));
        assert!(templated.contains(ASSISTANT_HEADER));
        assert!(templated.contains(EMPTY_THINK));
        assert!(templated.ends_with(EMPTY_THINK));
    }

    /// The FLUX.2 template is exactly the plain template plus the thinking block.
    #[test]
    fn templates_differ_only_in_thinking_block() {
        let plain = format_prompt_for_qwen3("test");
        let with_think = format_prompt_for_flux2("test");
        assert_eq!(with_think, [plain.as_str(), EMPTY_THINK].concat());
    }

    /// An empty prompt still yields a complete template.
    #[test]
    fn test_qwen3_template_empty_prompt() {
        let plain = format_prompt_for_qwen3("");
        assert_eq!(
            plain,
            "<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n"
        );

        let with_think = format_prompt_for_flux2("");
        assert!(with_think.contains(TURN_END));
        assert!(with_think.ends_with(EMPTY_THINK));
    }

    /// Angle brackets, braces and ampersands in the prompt pass through untouched.
    #[test]
    fn test_flux2_template_preserves_special_chars() {
        let templated = format_prompt_for_flux2("a <robot> in {brackets} & symbols <>");
        for fragment in ["<robot>", "{brackets}", "& symbols <>"] {
            assert!(templated.contains(fragment));
        }
        assert!(templated.starts_with(USER_HEADER));
        assert!(templated.contains(TURN_END));
    }

    /// Both templates match their expected output byte for byte.
    #[test]
    fn test_templates_exact_structure() {
        let prompt = "hello";

        let plain = format_prompt_for_qwen3(prompt);
        assert_eq!(
            plain,
            "<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n"
        );

        let with_think = format_prompt_for_flux2(prompt);
        assert_eq!(
            with_think,
            "<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        );
    }
}