use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::{Error, Result};
/// Default task label for Gemma-style query prompts (see
/// `ModelFamily::default_query_instruction`).
const DEFAULT_GEMMA_QUERY_TASK: &str = "search documents";
/// Default instruction for Qwen3-style query and reranker prompts (see
/// `ModelFamily::default_query_instruction`).
const DEFAULT_QWEN3_RETRIEVAL_INSTRUCTION: &str =
    "Given a web search query, retrieve relevant passages that answer the query";
/// Text together with its already-encoded model token ids, ready for an
/// embedding request. Construct via [`PreparedEmbeddingInput::new`], which
/// rejects empty token sequences.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct PreparedEmbeddingInput {
    // Model token ids produced by a `Tokenizer`; guaranteed non-empty by `new`.
    pub(crate) token_ids: Vec<u32>,
    // The original text the ids were encoded from.
    pub(crate) text: String,
}
impl PreparedEmbeddingInput {
pub(crate) fn new(token_ids: Vec<u32>, text: String) -> Result<Self> {
if token_ids.is_empty() {
return Err(Error::EmptyPreparedEmbeddingInput);
}
Ok(Self { token_ids, text })
}
#[must_use]
pub(crate) fn token_count(&self) -> usize {
self.token_ids.len()
}
}
/// One unit of work queued for an embedding batch, pairing caller metadata
/// `M` with the text to embed.
#[derive(Debug, Default)]
pub struct BatchItem<M> {
    // Caller-supplied metadata, carried through unchanged.
    pub meta: M,
    // Whether `text` is a query or a document.
    pub role: EmbeddingRole,
    // Raw text to embed.
    pub text: String,
    // Optional document title.
    pub title: Option<String>,
    // Token count for `text` — presumably precomputed by a `Tokenizer`;
    // confirm against the batching callers.
    pub token_count: usize,
}
/// Wire dialect of the embedding/reranking provider API.
///
/// Serialized in lowercase; `LlamaCpp` also accepts the `llama-cpp`,
/// `llama_cpp`, and `llama.cpp` spellings on deserialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Dialect {
    // The default dialect.
    #[default]
    OpenAI,
    DeepInfra,
    #[serde(
        rename = "llamacpp",
        alias = "llama-cpp",
        alias = "llama_cpp",
        alias = "llama.cpp"
    )]
    LlamaCpp,
}
impl Dialect {
    /// Canonical lowercase name for this dialect, matching its serde form.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::LlamaCpp => "llamacpp",
            Self::DeepInfra => "deepinfra",
            Self::OpenAI => "openai",
        }
    }
}
impl std::fmt::Display for Dialect {
    /// Displays the canonical lowercase name (same as [`Dialect::as_str`]).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
/// Alias for [`Dialect`] — presumably kept for name compatibility with
/// existing callers; confirm before removing.
pub type ProviderDialect = Dialect;
/// Tokenization backends available to the crate.
///
/// `Clone` is cheap for the encoder variants: the underlying tokenizers are
/// shared behind `Arc`.
#[derive(Clone)]
pub enum Tokenizer {
    // Character counting only; cannot yield model token ids (see
    // `Tokenizer::prepare`, which rejects this variant).
    Characters,
    // tiktoken BPE encoder, labeled by its encoding name.
    Tiktoken {
        encoding: String,
        tokenizer: Arc<tiktoken_rs::CoreBPE>,
    },
    // Hugging Face `tokenizers` encoder, labeled by its model id.
    HuggingFace {
        model_id: String,
        tokenizer: Arc<tokenizers::Tokenizer>,
    },
}
impl Tokenizer {
    /// Encodes `text` into model token ids, producing a
    /// [`PreparedEmbeddingInput`] that carries both the ids and the text.
    ///
    /// # Errors
    ///
    /// * [`Error::InvalidConfiguration`] for the `Characters` variant (it has
    ///   no id vocabulary) or when the Hugging Face encoder fails.
    /// * Whatever [`PreparedEmbeddingInput::new`] returns for an empty encoding.
    pub(crate) fn prepare(&self, text: String) -> Result<PreparedEmbeddingInput> {
        let token_ids = match self {
            Self::Tiktoken { tokenizer, .. } => tokenizer.encode_ordinary(&text),
            Self::HuggingFace { tokenizer, .. } => {
                let encoding = tokenizer
                    .encode(text.as_str(), false)
                    .map_err(|error| Error::InvalidConfiguration {
                        message: format!("failed to encode with HF tokenizer: {error}"),
                    })?;
                encoding.get_ids().to_vec()
            }
            Self::Characters => {
                return Err(Error::InvalidConfiguration {
                    message: "embedding preparation requires a tokenizer that yields model token ids; the characters tokenizer only counts characters".to_string(),
                });
            }
        };
        PreparedEmbeddingInput::new(token_ids, text)
    }
}
impl std::fmt::Debug for Tokenizer {
    /// Compact debug form naming the backend and its encoding/model label;
    /// the wrapped tokenizers themselves are not `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Tiktoken { encoding, .. } => write!(f, "Tiktoken({encoding})"),
            Self::HuggingFace { model_id, .. } => write!(f, "HuggingFace({model_id})"),
            Self::Characters => write!(f, "Characters"),
        }
    }
}
/// Model family, which determines the prompt layout used for embedding and
/// reranking inputs (see `format_embedding_input` / `format_reranker_input`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelFamily {
    Gemma,
    // The default family.
    #[default]
    Qwen3,
}
impl ModelFamily {
    /// Canonical lowercase name for this family, matching its serde form.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Qwen3 => "qwen3",
            Self::Gemma => "gemma",
        }
    }

    /// Instruction text used when the caller supplies none (or a blank one).
    #[must_use]
    pub fn default_query_instruction(self) -> &'static str {
        match self {
            Self::Qwen3 => DEFAULT_QWEN3_RETRIEVAL_INSTRUCTION,
            Self::Gemma => DEFAULT_GEMMA_QUERY_TASK,
        }
    }

    /// Renders `input` into the prompt layout this family expects.
    ///
    /// Queries embed an instruction — the trimmed `query_instruction`, or the
    /// family default when it is missing or blank. Documents embed an optional
    /// title and ignore `query_instruction` entirely.
    #[must_use]
    pub fn format_embedding_input(
        self,
        input: &EmbeddingInput,
        query_instruction: Option<&str>,
    ) -> String {
        match input.role {
            EmbeddingRole::Query => {
                // Blank or absent instructions fall back to the family default.
                let task = normalize_optional_text(query_instruction)
                    .unwrap_or_else(|| self.default_query_instruction().to_string());
                match self {
                    Self::Gemma => format!("task: {task} | query: {}", input.text),
                    Self::Qwen3 => format!("Instruct: {task}\nQuery: {}", input.text),
                }
            }
            EmbeddingRole::Document => match self {
                Self::Gemma => {
                    // Gemma always renders a title slot, using "none" as filler.
                    let heading = normalize_optional_text(input.title.as_deref())
                        .unwrap_or_else(|| "none".to_string());
                    format!("title: {heading} | text: {}", input.text)
                }
                // Qwen3 documents are the bare text, optionally prefixed by a title line.
                Self::Qwen3 => match normalize_optional_text(input.title.as_deref()) {
                    Some(heading) => format!("{heading}\n{}", input.text),
                    None => input.text.clone(),
                },
            },
        }
    }

    /// Renders a query/document pair into this family's reranker prompt.
    ///
    /// Only Qwen3 prompts carry an instruction line (trimmed `instruction`,
    /// falling back to the family default); Gemma ignores `instruction`.
    #[must_use]
    pub fn format_reranker_input(
        self,
        query: &RerankQuery,
        document: &RerankDocument,
        instruction: Option<&str>,
    ) -> String {
        if let Self::Gemma = self {
            return format!("Query: {}\nDocument: {}", query.text, document.text);
        }
        let task = normalize_optional_text(instruction)
            .unwrap_or_else(|| self.default_query_instruction().to_string());
        format!(
            "Instruct: {task}\nQuery: {}\nDocument: {}",
            query.text, document.text
        )
    }
}
impl std::fmt::Display for ModelFamily {
    /// Displays the canonical lowercase name (same as [`ModelFamily::as_str`]).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
/// Whether a text participates as the search query or as a candidate
/// document; serialized in lowercase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingRole {
    Query,
    // The default role.
    #[default]
    Document,
}
impl std::fmt::Display for EmbeddingRole {
    /// Displays the lowercase role name, matching the serde representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Query => "query",
            Self::Document => "document",
        };
        f.write_str(label)
    }
}
/// A single text to embed, serialized with camelCase field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EmbeddingInput {
    // Defaults to `Document` when absent from the input.
    #[serde(default)]
    pub role: EmbeddingRole,
    // Raw text to embed.
    pub text: String,
    // Optional title; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub title: Option<String>,
    // Token count for `text` — presumably precomputed by a tokenizer;
    // confirm against callers.
    pub token_count: usize,
}
/// Result of an embedding request: one vector per input, in input order —
/// TODO confirm ordering guarantee against provider implementations.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EmbedOutput {
    pub embeddings: Vec<Vec<f32>>,
}
/// The query side of a rerank request.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RerankQuery {
    // Raw query text.
    pub text: String,
    // Token count for `text` — presumably precomputed by a tokenizer.
    pub token_count: usize,
}
/// One candidate document in a rerank request.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RerankDocument {
    // Raw document text.
    pub text: String,
    // Token count for `text` — presumably precomputed by a tokenizer.
    pub token_count: usize,
}
/// Outcome of [`BatchingStrategy::add`]: whether the current batch can keep
/// accepting items or should be flushed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddDecision {
    Continue,
    Flush,
}
/// Policy that accumulates per-item token counts and decides when a batch is
/// full.
pub trait BatchingStrategy: Send {
    /// Registers one more item of `token_count` tokens and reports whether
    /// the batch can continue growing or must be flushed.
    fn add(&mut self, token_count: usize) -> AddDecision;
    /// Notifies the strategy that the current batch was flushed — presumably
    /// resetting its accumulated state; confirm with implementations.
    fn flush(&mut self);
    /// Upper bound on items per batch.
    fn max_items_per_batch(&self) -> usize;
    /// Upper bound on total tokens per batch.
    fn max_tokens_per_batch(&self) -> usize;
}
/// Backend capable of embedding a batch of inputs.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Embeds every item of `input`, returning the collected vectors.
    async fn embed(&self, input: &[EmbeddingInput]) -> Result<EmbedOutput>;
}
/// Backend capable of scoring documents against a query.
#[async_trait]
pub trait RerankingProvider: Send + Sync {
    /// Scores `documents` against `query` — presumably one score per document
    /// in input order; confirm with implementations.
    async fn rerank(&self, query: &RerankQuery, documents: &[RerankDocument]) -> Result<Vec<f64>>;
}
/// Trims `value` and promotes the result to an owned `String`, treating
/// missing or all-whitespace input as absent.
fn normalize_optional_text(value: Option<&str>) -> Option<String> {
    value
        .map(str::trim)
        .filter(|trimmed| !trimmed.is_empty())
        .map(ToOwned::to_owned)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Construction must refuse token-id vectors with no entries.
    #[test]
    fn prepared_embedding_input_rejects_empty_tokens() {
        let err = PreparedEmbeddingInput::new(Vec::new(), String::new()).unwrap_err();
        assert!(matches!(err, Error::EmptyPreparedEmbeddingInput));
    }

    // `token_count` reflects the number of ids, not the text length.
    #[test]
    fn prepared_embedding_input_reports_token_count() {
        let input = PreparedEmbeddingInput::new(vec![1, 2, 3], "test".to_string()).unwrap();
        assert_eq!(input.token_count(), 3);
        assert_eq!(input.token_ids, &[1, 2, 3]);
    }

    // A caller-provided task string replaces Gemma's default task label.
    #[test]
    fn gemma_query_formatting_uses_custom_task() {
        let input = EmbeddingInput {
            role: EmbeddingRole::Query,
            text: "rust async runtime".to_string(),
            title: None,
            token_count: 3,
        };
        let formatted = ModelFamily::Gemma.format_embedding_input(&input, Some("custom task"));
        assert_eq!(formatted, "task: custom task | query: rust async runtime");
    }

    // Both `None` and a whitespace-only instruction fall back to the default task.
    #[test]
    fn gemma_query_formatting_uses_default_task_for_missing_or_blank_instruction() {
        let input = EmbeddingInput {
            role: EmbeddingRole::Query,
            text: "rust async runtime".to_string(),
            title: None,
            token_count: 3,
        };
        let expected = format!(
            "task: {} | query: rust async runtime",
            ModelFamily::Gemma.default_query_instruction()
        );
        assert_eq!(
            ModelFamily::Gemma.format_embedding_input(&input, None),
            expected
        );
        assert_eq!(
            ModelFamily::Gemma.format_embedding_input(&input, Some(" ")),
            expected
        );
    }

    // Gemma documents always render a title slot, using the literal "none"
    // when no title is set; the query instruction is ignored for documents.
    #[test]
    fn gemma_document_formatting_uses_title_or_none() {
        let with_title = EmbeddingInput {
            role: EmbeddingRole::Document,
            text: "Rust enables fearless concurrency".to_string(),
            title: Some("Rust".to_string()),
            token_count: 4,
        };
        let without_title = EmbeddingInput {
            role: EmbeddingRole::Document,
            text: "Rust enables fearless concurrency".to_string(),
            title: None,
            token_count: 4,
        };
        assert_eq!(
            ModelFamily::Gemma.format_embedding_input(&with_title, Some("ignored")),
            "title: Rust | text: Rust enables fearless concurrency"
        );
        assert_eq!(
            ModelFamily::Gemma.format_embedding_input(&without_title, Some("ignored")),
            "title: none | text: Rust enables fearless concurrency"
        );
    }

    // Qwen3 queries use the "Instruct:/Query:" layout with either the default
    // or a caller-supplied instruction.
    #[test]
    fn qwen3_query_formatting_uses_default_and_override() {
        let input = EmbeddingInput {
            role: EmbeddingRole::Query,
            text: "rust ownership".to_string(),
            title: None,
            token_count: 2,
        };
        assert_eq!(
            ModelFamily::Qwen3.format_embedding_input(&input, None),
            format!(
                "Instruct: {}\nQuery: rust ownership",
                ModelFamily::Qwen3.default_query_instruction()
            )
        );
        assert_eq!(
            ModelFamily::Qwen3.format_embedding_input(&input, Some("custom instruction")),
            "Instruct: custom instruction\nQuery: rust ownership"
        );
    }

    // Surrounding whitespace in a custom instruction is stripped.
    #[test]
    fn qwen3_query_formatting_trims_custom_instruction() {
        let input = EmbeddingInput {
            role: EmbeddingRole::Query,
            text: "rust ownership".to_string(),
            title: None,
            token_count: 2,
        };
        assert_eq!(
            ModelFamily::Qwen3.format_embedding_input(&input, Some("  custom instruction  ")),
            "Instruct: custom instruction\nQuery: rust ownership"
        );
    }

    // Qwen3 documents are the bare text (optionally title-prefixed); the
    // query instruction must have no effect.
    #[test]
    fn qwen3_document_formatting_ignores_query_instruction() {
        let titled = EmbeddingInput {
            role: EmbeddingRole::Document,
            text: "Borrow checking catches aliasing bugs".to_string(),
            title: Some("Borrow Checker".to_string()),
            token_count: 4,
        };
        let untitled = EmbeddingInput {
            role: EmbeddingRole::Document,
            text: "Borrow checking catches aliasing bugs".to_string(),
            title: None,
            token_count: 4,
        };
        assert_eq!(
            ModelFamily::Qwen3.format_embedding_input(&titled, Some("ignored")),
            "Borrow Checker\nBorrow checking catches aliasing bugs"
        );
        assert_eq!(
            ModelFamily::Qwen3.format_embedding_input(&untitled, Some("ignored")),
            "Borrow checking catches aliasing bugs"
        );
    }

    // Qwen3 reranker prompts carry an instruction line with the same
    // default/override behavior as query embedding prompts.
    #[test]
    fn qwen3_reranker_formatting_uses_default_and_override() {
        let query = RerankQuery {
            text: "memory safety".to_string(),
            token_count: 2,
        };
        let document = RerankDocument {
            text: "Rust prevents data races".to_string(),
            token_count: 4,
        };
        assert_eq!(
            ModelFamily::Qwen3.format_reranker_input(&query, &document, None),
            format!(
                "Instruct: {}\nQuery: memory safety\nDocument: Rust prevents data races",
                ModelFamily::Qwen3.default_query_instruction()
            )
        );
        assert_eq!(
            ModelFamily::Qwen3.format_reranker_input(&query, &document, Some("rank docs")),
            "Instruct: rank docs\nQuery: memory safety\nDocument: Rust prevents data races"
        );
    }
}