use crate::embedding::types::{EmbeddingProvider, EmbeddingResult};
/// Embedding provider backed by fastembed's `Qwen3TextEmbedding` model.
///
/// Construct it with `Qwen3Provider::new` or `Qwen3Provider::with_options`.
pub struct Qwen3Provider {
/// Loaded model that computes the embeddings.
inner: fastembed::Qwen3TextEmbedding,
/// Model identifier given at construction; also reported as the provider name.
model_id: String,
/// Embedding dimension, read from the model config's `hidden_size` at construction.
dim: usize,
}
/// Model identifier for the Qwen3 0.6B embedding model, suitable for `Qwen3Provider::new`.
pub const QWEN3_EMBEDDING_0_6B: &str = "Qwen/Qwen3-Embedding-0.6B";
/// Model identifier for the Qwen3 8B embedding model.
pub const QWEN3_EMBEDDING_8B: &str = "Qwen/Qwen3-Embedding-8B";
/// Model identifier for the Qwen3-VL 2B embedding model.
///
/// NOTE(review): this names a vision-language model; whether the text embedding loader used
/// by `Qwen3Provider` can load it is not shown in this file — confirm before relying on it.
pub const QWEN3_VL_EMBEDDING_2B: &str = "Qwen/Qwen3-VL-Embedding-2B";
/// `max_length` value that `Qwen3Provider::new` forwards to `Qwen3Provider::with_options`.
///
/// Presumably a limit counted in tokens — TODO confirm against the fastembed documentation.
const DEFAULT_MAX_LENGTH: usize = 512;
impl Qwen3Provider {
    /// Creates a provider for `model_id` using the CPU device, `F32` weights and
    /// `DEFAULT_MAX_LENGTH` as the maximum input length.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Qwen3Provider::with_options`] produces.
    pub fn new(model_id: &str) -> anyhow::Result<Self> {
        let device = candle_core::Device::Cpu;
        let dtype = candle_core::DType::F32;
        Self::with_options(model_id, device, dtype, DEFAULT_MAX_LENGTH)
    }

    /// Creates a provider for `model_id` with an explicit device, dtype and maximum length.
    ///
    /// The model is loaded through `fastembed::Qwen3TextEmbedding::from_hf`, and the
    /// embedding dimension is taken from the loaded model's `hidden_size` config value.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `from_hf`.
    pub fn with_options(
        model_id: &str,
        device: candle_core::Device,
        dtype: candle_core::DType,
        max_length: usize,
    ) -> anyhow::Result<Self> {
        let model = fastembed::Qwen3TextEmbedding::from_hf(model_id, &device, dtype, max_length)?;
        let provider = Self {
            // Read the dimension first: fields are evaluated in order, and the next field
            // moves `model` into the struct.
            dim: model.config().hidden_size,
            inner: model,
            model_id: model_id.to_owned(),
        };
        Ok(provider)
    }

    /// Returns the model identifier this provider was constructed with.
    pub fn model_id(&self) -> &str {
        self.model_id.as_str()
    }
}
/// Maximum number of bytes of the input text kept in an embedding's preview.
const PREVIEW_MAX_BYTES: usize = 64;

/// Builds the short preview string stored next to an embedding.
///
/// Texts of at most `PREVIEW_MAX_BYTES` bytes are returned unchanged. Longer texts are cut
/// at the last UTF-8 character boundary at or before that byte offset and suffixed with `…`.
fn text_preview(text: &str) -> String {
    if text.len() <= PREVIEW_MAX_BYTES {
        return text.to_string();
    }
    // Slicing at a fixed byte offset panics when the offset lands inside a multi-byte
    // character, so step back to the nearest boundary (offset 0 is always a boundary).
    let mut end = PREVIEW_MAX_BYTES;
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}…", &text[..end])
}

impl EmbeddingProvider for Qwen3Provider {
    /// Returns the embedding dimension recorded when the provider was constructed.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Returns the provider name, which is the model identifier.
    fn name(&self) -> &str {
        &self.model_id
    }

    /// Embeds a single text and returns its vector together with a short preview of the text.
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying model, and fails with
    /// `"empty embedding output"` if the model returns no embedding.
    fn embed(&self, text: &str) -> anyhow::Result<EmbeddingResult> {
        let vector = self
            .inner
            .embed(&[text])?
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("empty embedding output"))?;
        Ok(EmbeddingResult {
            vector,
            text_preview: text_preview(text),
        })
    }

    /// Embeds every text in `texts`, returning one result per input in the same order.
    ///
    /// An empty `texts` slice yields an empty vector without calling the model.
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying model, and fails if the model returns fewer
    /// embeddings than inputs (previously such a shortfall was silently truncated).
    fn embed_batch(&self, texts: &[&str]) -> anyhow::Result<Vec<EmbeddingResult>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let embeddings = self.inner.embed(texts)?;
        let results: Vec<EmbeddingResult> = embeddings
            .into_iter()
            .zip(texts.iter().copied())
            .map(|(vector, text)| EmbeddingResult {
                vector,
                text_preview: text_preview(text),
            })
            .collect();
        // `zip` stops at the shorter side, so a short result means the model dropped inputs.
        anyhow::ensure!(
            results.len() == texts.len(),
            "embedding count mismatch: got {} embeddings for {} texts",
            results.len(),
            texts.len()
        );
        Ok(results)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The exported model-id constants must keep their exact string values.
    #[test]
    fn model_id_constants() {
        let cases = [
            (QWEN3_EMBEDDING_0_6B, "Qwen/Qwen3-Embedding-0.6B"),
            (QWEN3_EMBEDDING_8B, "Qwen/Qwen3-Embedding-8B"),
            (QWEN3_VL_EMBEDDING_2B, "Qwen/Qwen3-VL-Embedding-2B"),
        ];
        for (actual, expected) in cases {
            assert_eq!(actual, expected);
        }
    }

    /// A single embedding is non-empty and has the dimension the provider reports.
    #[test]
    #[ignore = "requires model download"]
    fn embed_with_qwen3_0_6b() {
        let provider = Qwen3Provider::new(QWEN3_EMBEDDING_0_6B).unwrap();
        let embedded = provider.embed("hello world").unwrap();
        assert!(!embedded.vector.is_empty());
        assert_eq!(embedded.vector.len(), provider.dim());
    }

    /// A batch yields one result per input, each with the provider's dimension.
    #[test]
    #[ignore = "requires model download"]
    fn embed_batch_with_qwen3() {
        let provider = Qwen3Provider::new(QWEN3_EMBEDDING_0_6B).unwrap();
        let inputs = ["hello", "world", "foo bar"];
        let batch = provider.embed_batch(&inputs).unwrap();
        assert_eq!(batch.len(), 3);
        batch
            .iter()
            .for_each(|item| assert_eq!(item.vector.len(), provider.dim()));
    }
}