use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::{Error, Result};
/// Gemma `task:` value used for queries when the caller supplies no (or only a
/// blank) instruction.
///
/// NOTE(review): the EmbeddingGemma model card lists `search result` as the
/// retrieval-query task (`task: search result | query: ...`); confirm that
/// `search documents` is intentional.
const DEFAULT_GEMMA_QUERY_TASK: &str = "search documents";
/// Qwen3 `Instruct:` text used for queries, and for Qwen3 reranking, when the
/// caller supplies no (or only a blank) instruction.
const DEFAULT_QWEN3_RETRIEVAL_INSTRUCTION: &str = concat!(
    "Given a web search query, ",
    "retrieve relevant passages that answer the query"
);
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(try_from = "Vec<u32>", into = "Vec<u32>")]
pub struct PreparedEmbeddingInput {
token_ids: Vec<u32>,
}
impl PreparedEmbeddingInput {
pub fn new(token_ids: Vec<u32>) -> Result<Self> {
if token_ids.is_empty() {
return Err(Error::EmptyPreparedEmbeddingInput);
}
Ok(Self { token_ids })
}
#[must_use]
pub fn token_ids(&self) -> &[u32] {
&self.token_ids
}
#[must_use]
pub fn token_count(&self) -> usize {
self.token_ids.len()
}
#[must_use]
pub fn into_token_ids(self) -> Vec<u32> {
self.token_ids
}
}
impl AsRef<[u32]> for PreparedEmbeddingInput {
fn as_ref(&self) -> &[u32] {
self.token_ids()
}
}
impl TryFrom<Vec<u32>> for PreparedEmbeddingInput {
type Error = Error;
fn try_from(token_ids: Vec<u32>) -> Result<Self> {
Self::new(token_ids)
}
}
impl From<PreparedEmbeddingInput> for Vec<u32> {
fn from(input: PreparedEmbeddingInput) -> Self {
input.into_token_ids()
}
}
/// A prepared embedding input paired with caller-defined metadata `M`.
///
/// NOTE(review): `meta` presumably lets callers correlate batched results with
/// their origin; confirm against the batching code.
#[derive(Debug, Clone)]
pub struct BatchItem<M> {
    /// Caller-supplied metadata carried alongside `input`.
    pub meta: M,
    /// The non-empty token sequence to embed.
    pub input: PreparedEmbeddingInput,
}
/// Wire dialect used when talking to an embedding or reranking provider.
///
/// Serde uses the lowercase variant name (`"openai"`, `"deepinfra"`). The llama.cpp
/// variant is written as `"llamacpp"` and also accepts `"llama-cpp"`, `"llama_cpp"` and
/// `"llama.cpp"` on input. The default is [`Dialect::OpenAI`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Dialect {
    #[default]
    OpenAI,
    DeepInfra,
    #[serde(rename = "llamacpp")]
    #[serde(alias = "llama-cpp")]
    #[serde(alias = "llama_cpp")]
    #[serde(alias = "llama.cpp")]
    LlamaCpp,
}

impl Dialect {
    /// Canonical lowercase name; identical to the serialized form.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Dialect::OpenAI => "openai",
            Dialect::DeepInfra => "deepinfra",
            Dialect::LlamaCpp => "llamacpp",
        }
    }
}

impl std::fmt::Display for Dialect {
    /// Writes [`Dialect::as_str`] verbatim (`write_str` does not apply width or fill flags).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = Dialect::as_str(*self);
        f.write_str(name)
    }
}

/// Alternate name for [`Dialect`].
pub type ProviderDialect = Dialect;
/// Model family whose prompt conventions are applied when formatting text inputs.
///
/// Serialized in lowercase (`"gemma"`, `"qwen3"`). The default is [`ModelFamily::Qwen3`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelFamily {
    Gemma,
    #[default]
    Qwen3,
}

impl ModelFamily {
    /// Lowercase identifier for this family; identical to the serialized form.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Gemma => "gemma",
            Self::Qwen3 => "qwen3",
        }
    }

    /// Instruction used when the caller supplies none, or only whitespace.
    ///
    /// For Gemma this is the `task:` value. For Qwen3 it is the `Instruct:` text, which
    /// [`ModelFamily::format_reranker_input`] also falls back to.
    #[must_use]
    pub fn default_query_instruction(self) -> &'static str {
        match self {
            Self::Gemma => DEFAULT_GEMMA_QUERY_TASK,
            Self::Qwen3 => DEFAULT_QWEN3_RETRIEVAL_INSTRUCTION,
        }
    }

    /// Renders `input` as the prompt text this family expects for embedding.
    ///
    /// * Gemma query: `task: {instruction} | query: {text}`
    /// * Gemma document: `title: {title, or "none"} | text: {text}`
    /// * Qwen3 query: `Instruct: {instruction}\nQuery: {text}`
    /// * Qwen3 document: `{title}\n{text}`, or just `{text}` when there is no title
    ///
    /// `query_instruction` and `input.title` are trimmed. A missing or blank value falls
    /// back to [`ModelFamily::default_query_instruction`] or to the no-title form above.
    /// `query_instruction` is ignored for documents, and `input.text` is used verbatim.
    #[must_use]
    pub fn format_embedding_input(
        self,
        input: &EmbeddingInput,
        query_instruction: Option<&str>,
    ) -> String {
        match (self, input.role) {
            (Self::Gemma, EmbeddingRole::Query) => {
                let instruction = self.resolve_instruction(query_instruction);
                format!("task: {instruction} | query: {}", input.text)
            }
            (Self::Gemma, EmbeddingRole::Document) => {
                let title = normalize_optional_text(input.title.as_deref());
                // Borrow the literal fallback instead of allocating a String for it.
                let title = title.as_deref().unwrap_or("none");
                format!("title: {title} | text: {}", input.text)
            }
            (Self::Qwen3, EmbeddingRole::Query) => {
                let instruction = self.resolve_instruction(query_instruction);
                format!("Instruct: {instruction}\nQuery: {}", input.text)
            }
            (Self::Qwen3, EmbeddingRole::Document) => {
                match normalize_optional_text(input.title.as_deref()) {
                    Some(title) => format!("{title}\n{}", input.text),
                    None => input.text.clone(),
                }
            }
        }
    }

    /// Renders a query/document pair as the prompt text this family's reranker expects.
    ///
    /// * Qwen3: `Instruct: {instruction}\nQuery: {query}\nDocument: {document}`. The
    ///   instruction is trimmed and falls back to
    ///   [`ModelFamily::default_query_instruction`] when missing or blank.
    /// * Gemma: `Query: {query}\nDocument: {document}`; `instruction` is ignored.
    ///
    /// The `token_count` fields of `query` and `document` are not used here.
    ///
    /// NOTE(review): the Qwen3-Reranker model card wraps these labels in angle brackets
    /// (`<Instruct>:`, `<Query>:`, `<Document>:`). Confirm that the serving backend expects
    /// the bare form used here.
    #[must_use]
    pub fn format_reranker_input(
        self,
        query: &RerankQuery,
        document: &RerankDocument,
        instruction: Option<&str>,
    ) -> String {
        match self {
            Self::Qwen3 => {
                let instruction = self.resolve_instruction(instruction);
                format!(
                    "Instruct: {instruction}\nQuery: {}\nDocument: {}",
                    query.text, document.text
                )
            }
            Self::Gemma => format!("Query: {}\nDocument: {}", query.text, document.text),
        }
    }

    /// Trims `explicit` and falls back to [`ModelFamily::default_query_instruction`] when
    /// it is missing or blank.
    ///
    /// The static default is borrowed rather than copied into a fresh `String`.
    fn resolve_instruction(self, explicit: Option<&str>) -> std::borrow::Cow<'static, str> {
        normalize_optional_text(explicit).map_or(
            std::borrow::Cow::Borrowed(self.default_query_instruction()),
            std::borrow::Cow::Owned,
        )
    }
}

impl std::fmt::Display for ModelFamily {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Whether an [`EmbeddingInput`] is a search query or a document.
///
/// Serialized in lowercase (`"query"`, `"document"`). The default is
/// [`EmbeddingRole::Document`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingRole {
    Query,
    #[default]
    Document,
}

impl EmbeddingRole {
    /// Lowercase identifier for this role; identical to the serialized form.
    ///
    /// Mirrors `Dialect::as_str` / `ModelFamily::as_str` so every enum in this module
    /// exposes its canonical name the same way.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Query => "query",
            Self::Document => "document",
        }
    }
}

impl std::fmt::Display for EmbeddingRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// One piece of text to embed, tagged with its [`EmbeddingRole`].
///
/// Serialized with camelCase keys.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct EmbeddingInput {
    /// Query or document. When absent from the payload, [`EmbeddingRole::default`]
    /// (`Document`) is used.
    #[serde(default)]
    pub role: EmbeddingRole,
    /// Raw text, used verbatim by [`ModelFamily::format_embedding_input`].
    pub text: String,
    /// Optional title. Defaults to `None` when absent and is omitted from serialized
    /// output when `None`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub title: Option<String>,
}

/// Embedding vectors produced by an [`EmbeddingProvider`].
///
/// NOTE(review): presumably one vector per input, in input order; confirm with
/// provider implementations.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct EmbedOutput {
    /// The embedding vectors.
    pub embeddings: Vec<Vec<f32>>,
}
/// Query side of a rerank request (serialized with camelCase keys).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RerankQuery {
    /// Query text, used verbatim by [`ModelFamily::format_reranker_input`].
    pub text: String,
    /// Caller-supplied token count (serialized as `tokenCount`).
    ///
    /// NOTE(review): not validated against `text` here; confirm how callers compute it.
    pub token_count: usize,
}

/// One candidate document in a rerank request (serialized with camelCase keys).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RerankDocument {
    /// Document text, used verbatim by [`ModelFamily::format_reranker_input`].
    pub text: String,
    /// Caller-supplied token count (serialized as `tokenCount`).
    ///
    /// NOTE(review): not validated against `text` here; confirm how callers compute it.
    pub token_count: usize,
}
/// Answer returned by [`BatchingStrategy::add`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddDecision {
    /// Keep accumulating into the current batch.
    Continue,
    /// The current batch should be flushed.
    ///
    /// NOTE(review): whether the item that triggered this belongs to the flushed batch
    /// or to the next one is up to implementors; confirm.
    Flush,
}

/// Incremental batching policy.
///
/// Items are reported one at a time through `add`, which answers whether to continue
/// or flush. Implementors must be `Send`.
pub trait BatchingStrategy: Send {
    /// Maximum number of items allowed in one batch.
    fn max_items_per_batch(&self) -> usize;

    /// Maximum total token count allowed in one batch.
    fn max_tokens_per_batch(&self) -> usize;

    /// Records an item of `token_count` tokens and reports whether to flush.
    fn add(&mut self, token_count: usize) -> AddDecision;

    /// Signals that a batch was flushed.
    ///
    /// NOTE(review): presumably resets accumulated counters; confirm with implementors.
    fn flush(&mut self);
}
/// Backend that turns prepared token sequences into embedding vectors.
///
/// Implementors must be `Send + Sync`; `#[async_trait]` makes the async method
/// usable through trait objects.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
/// Embeds every element of `input`.
///
/// NOTE(review): presumably returns one vector per input, in input order; confirm
/// with implementations.
async fn embed(&self, input: &[PreparedEmbeddingInput]) -> Result<EmbedOutput>;
}
/// Backend that scores candidate documents against a query.
#[async_trait]
pub trait RerankingProvider: Send + Sync {
/// Scores each element of `documents` against `query`.
///
/// NOTE(review): presumably one score per document, in input order, with higher
/// meaning more relevant; confirm with implementations.
async fn rerank(&self, query: &RerankQuery, documents: &[RerankDocument]) -> Result<Vec<f64>>;
}
/// Trims `value` and returns it as an owned string.
///
/// Returns `None` when `value` is `None` or consists only of whitespace.
fn normalize_optional_text(value: Option<&str>) -> Option<String> {
    value
        .map(str::trim)
        .filter(|trimmed| !trimmed.is_empty())
        .map(ToOwned::to_owned)
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- PreparedEmbeddingInput -------------------------------------------------

    #[test]
    fn prepared_embedding_input_rejects_empty_tokens() {
        let err = PreparedEmbeddingInput::new(Vec::new()).unwrap_err();
        assert!(matches!(err, Error::EmptyPreparedEmbeddingInput));
    }

    #[test]
    fn prepared_embedding_input_reports_token_count() {
        let input = PreparedEmbeddingInput::new(vec![1, 2, 3]).unwrap();
        assert_eq!(input.token_count(), 3);
        assert_eq!(input.token_ids(), &[1, 2, 3]);
    }

    #[test]
    fn prepared_embedding_input_converts_to_and_from_vec() {
        let input = PreparedEmbeddingInput::try_from(vec![7_u32, 8]).unwrap();
        let slice: &[u32] = input.as_ref();
        assert_eq!(slice, &[7, 8]);
        let round_trip: Vec<u32> = input.into();
        assert_eq!(round_trip, vec![7, 8]);
        // TryFrom must enforce the same non-empty invariant as `new`.
        assert!(matches!(
            PreparedEmbeddingInput::try_from(Vec::<u32>::new()),
            Err(Error::EmptyPreparedEmbeddingInput)
        ));
    }

    // ---- Enum names and defaults ------------------------------------------------

    #[test]
    fn enum_display_matches_as_str() {
        for dialect in [Dialect::OpenAI, Dialect::DeepInfra, Dialect::LlamaCpp] {
            assert_eq!(dialect.to_string(), dialect.as_str());
        }
        for family in [ModelFamily::Gemma, ModelFamily::Qwen3] {
            assert_eq!(family.to_string(), family.as_str());
        }
        assert_eq!(Dialect::LlamaCpp.as_str(), "llamacpp");
        assert_eq!(ModelFamily::Qwen3.as_str(), "qwen3");
        assert_eq!(EmbeddingRole::Query.to_string(), "query");
        assert_eq!(EmbeddingRole::Document.to_string(), "document");
    }

    #[test]
    fn enum_defaults() {
        assert_eq!(Dialect::default(), Dialect::OpenAI);
        assert_eq!(ModelFamily::default(), ModelFamily::Qwen3);
        assert_eq!(EmbeddingRole::default(), EmbeddingRole::Document);
    }

    // ---- normalize_optional_text -----------------------------------------------

    #[test]
    fn normalize_optional_text_trims_and_drops_blank_values() {
        assert_eq!(normalize_optional_text(None), None);
        assert_eq!(normalize_optional_text(Some(" \t ")), None);
        assert_eq!(
            normalize_optional_text(Some("  x y  ")),
            Some("x y".to_string())
        );
    }

    // ---- Embedding prompt formatting ---------------------------------------------

    #[test]
    fn gemma_query_formatting_uses_custom_task() {
        let input = EmbeddingInput {
            role: EmbeddingRole::Query,
            text: "rust async runtime".to_string(),
            title: None,
        };
        let formatted = ModelFamily::Gemma.format_embedding_input(&input, Some("custom task"));
        assert_eq!(formatted, "task: custom task | query: rust async runtime");
    }

    #[test]
    fn gemma_query_formatting_uses_default_task_for_missing_or_blank_instruction() {
        let input = EmbeddingInput {
            role: EmbeddingRole::Query,
            text: "rust async runtime".to_string(),
            title: None,
        };
        let expected = format!(
            "task: {} | query: rust async runtime",
            ModelFamily::Gemma.default_query_instruction()
        );
        assert_eq!(
            ModelFamily::Gemma.format_embedding_input(&input, None),
            expected
        );
        assert_eq!(
            ModelFamily::Gemma.format_embedding_input(&input, Some(" ")),
            expected
        );
    }

    #[test]
    fn gemma_document_formatting_uses_title_or_none() {
        let with_title = EmbeddingInput {
            role: EmbeddingRole::Document,
            text: "Rust enables fearless concurrency".to_string(),
            title: Some("Rust".to_string()),
        };
        let without_title = EmbeddingInput {
            role: EmbeddingRole::Document,
            text: "Rust enables fearless concurrency".to_string(),
            title: None,
        };
        assert_eq!(
            ModelFamily::Gemma.format_embedding_input(&with_title, Some("ignored")),
            "title: Rust | text: Rust enables fearless concurrency"
        );
        assert_eq!(
            ModelFamily::Gemma.format_embedding_input(&without_title, Some("ignored")),
            "title: none | text: Rust enables fearless concurrency"
        );
    }

    #[test]
    fn qwen3_query_formatting_uses_default_and_override() {
        let input = EmbeddingInput {
            role: EmbeddingRole::Query,
            text: "rust ownership".to_string(),
            title: None,
        };
        assert_eq!(
            ModelFamily::Qwen3.format_embedding_input(&input, None),
            format!(
                "Instruct: {}\nQuery: rust ownership",
                ModelFamily::Qwen3.default_query_instruction()
            )
        );
        assert_eq!(
            ModelFamily::Qwen3.format_embedding_input(&input, Some("custom instruction")),
            "Instruct: custom instruction\nQuery: rust ownership"
        );
    }

    #[test]
    fn qwen3_query_formatting_trims_custom_instruction() {
        let input = EmbeddingInput {
            role: EmbeddingRole::Query,
            text: "rust ownership".to_string(),
            title: None,
        };
        assert_eq!(
            ModelFamily::Qwen3.format_embedding_input(&input, Some(" custom instruction ")),
            "Instruct: custom instruction\nQuery: rust ownership"
        );
    }

    #[test]
    fn qwen3_document_formatting_ignores_query_instruction() {
        let titled = EmbeddingInput {
            role: EmbeddingRole::Document,
            text: "Borrow checking catches aliasing bugs".to_string(),
            title: Some("Borrow Checker".to_string()),
        };
        let untitled = EmbeddingInput {
            role: EmbeddingRole::Document,
            text: "Borrow checking catches aliasing bugs".to_string(),
            title: None,
        };
        assert_eq!(
            ModelFamily::Qwen3.format_embedding_input(&titled, Some("ignored")),
            "Borrow Checker\nBorrow checking catches aliasing bugs"
        );
        assert_eq!(
            ModelFamily::Qwen3.format_embedding_input(&untitled, Some("ignored")),
            "Borrow checking catches aliasing bugs"
        );
    }

    #[test]
    fn document_titles_are_trimmed_and_blank_titles_ignored() {
        let padded = EmbeddingInput {
            role: EmbeddingRole::Document,
            text: "body".to_string(),
            title: Some("  Heading ".to_string()),
        };
        let blank = EmbeddingInput {
            role: EmbeddingRole::Document,
            text: "body".to_string(),
            title: Some("   ".to_string()),
        };
        assert_eq!(
            ModelFamily::Qwen3.format_embedding_input(&padded, None),
            "Heading\nbody"
        );
        assert_eq!(ModelFamily::Qwen3.format_embedding_input(&blank, None), "body");
        assert_eq!(
            ModelFamily::Gemma.format_embedding_input(&padded, None),
            "title: Heading | text: body"
        );
        assert_eq!(
            ModelFamily::Gemma.format_embedding_input(&blank, None),
            "title: none | text: body"
        );
    }

    // ---- Reranker prompt formatting ----------------------------------------------

    #[test]
    fn qwen3_reranker_formatting_uses_default_and_override() {
        let query = RerankQuery {
            text: "memory safety".to_string(),
            token_count: 2,
        };
        let document = RerankDocument {
            text: "Rust prevents data races".to_string(),
            token_count: 4,
        };
        assert_eq!(
            ModelFamily::Qwen3.format_reranker_input(&query, &document, None),
            format!(
                "Instruct: {}\nQuery: memory safety\nDocument: Rust prevents data races",
                ModelFamily::Qwen3.default_query_instruction()
            )
        );
        assert_eq!(
            ModelFamily::Qwen3.format_reranker_input(&query, &document, Some("rank docs")),
            "Instruct: rank docs\nQuery: memory safety\nDocument: Rust prevents data races"
        );
    }

    #[test]
    fn qwen3_reranker_blank_instruction_falls_back_to_default() {
        let query = RerankQuery {
            text: "memory safety".to_string(),
            token_count: 2,
        };
        let document = RerankDocument {
            text: "Rust prevents data races".to_string(),
            token_count: 4,
        };
        assert_eq!(
            ModelFamily::Qwen3.format_reranker_input(&query, &document, Some("   ")),
            format!(
                "Instruct: {}\nQuery: memory safety\nDocument: Rust prevents data races",
                ModelFamily::Qwen3.default_query_instruction()
            )
        );
    }

    #[test]
    fn gemma_reranker_formatting_ignores_instruction() {
        let query = RerankQuery {
            text: "memory safety".to_string(),
            token_count: 2,
        };
        let document = RerankDocument {
            text: "Rust prevents data races".to_string(),
            token_count: 4,
        };
        let expected = "Query: memory safety\nDocument: Rust prevents data races";
        assert_eq!(
            ModelFamily::Gemma.format_reranker_input(&query, &document, Some("ignored")),
            expected
        );
        assert_eq!(
            ModelFamily::Gemma.format_reranker_input(&query, &document, None),
            expected
        );
    }
}