use std::borrow::Cow;
use crate::retrieval_context::EmbeddingDocumentInput;
use crate::Result;
/// Kind of text being embedded; passed to `Embedder::embed_batch` alongside the texts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum EmbeddingInputKind {
/// The text is a query.
Query,
/// The text is a document.
Document,
}
/// Prompt layout applied to query and document text before it is embedded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum EmbeddingInputFormat {
/// Queries are passed through unchanged; documents only get an optional
/// `title: …` header.
Raw,
/// Queries become `task: search result | query: …` and documents become
/// `title: … | text: …`.
EmbeddingGemmaRetrievalV1,
}
impl EmbeddingInputFormat {
    /// Picks the input format for a llama.cpp embedding model from its name.
    ///
    /// Names recognised by `is_embeddinggemma_model` get
    /// [`Self::EmbeddingGemmaRetrievalV1`]; every other name gets [`Self::Raw`].
    pub(crate) fn for_llama_cpp_embedding_model(model: &str) -> Self {
        if !is_embeddinggemma_model(model) {
            return Self::Raw;
        }
        Self::EmbeddingGemmaRetrievalV1
    }

    /// Returns the short `input=…` tag that names this format.
    pub(crate) fn identity(self) -> &'static str {
        match self {
            Self::EmbeddingGemmaRetrievalV1 => "input=embeddinggemma-retrieval-v1",
            Self::Raw => "input=raw",
        }
    }
}
/// Renders query and document text according to an [`EmbeddingInputFormat`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct EmbeddingInputFormatter {
/// Format consulted by `format_query` and `format_document`.
format: EmbeddingInputFormat,
}
impl EmbeddingInputFormatter {
    /// Creates a formatter that applies `format`.
    pub(crate) fn new(format: EmbeddingInputFormat) -> Self {
        Self { format }
    }

    /// Creates a formatter that uses [`EmbeddingInputFormat::Raw`].
    pub(crate) fn raw() -> Self {
        Self {
            format: EmbeddingInputFormat::Raw,
        }
    }

    /// Returns the identity tag of the format this formatter applies.
    pub(crate) fn format_identity(self) -> &'static str {
        EmbeddingInputFormat::identity(self.format)
    }

    /// Renders `query` for embedding.
    ///
    /// The raw format borrows `query` unchanged; the EmbeddingGemma format
    /// allocates a new string with the `task: search result | query: ` prefix.
    pub(crate) fn format_query<'a>(self, query: &'a str) -> Cow<'a, str> {
        match self.format {
            EmbeddingInputFormat::EmbeddingGemmaRetrievalV1 => {
                let prompted = format!("task: search result | query: {query}");
                Cow::Owned(prompted)
            }
            EmbeddingInputFormat::Raw => Cow::Borrowed(query),
        }
    }

    /// Renders a document for embedding.
    ///
    /// The raw format delegates to `format_raw_document`. The EmbeddingGemma
    /// format always emits a title slot, substituting the literal `none` when
    /// the title is missing or blank after trimming.
    pub(crate) fn format_document(self, input: &EmbeddingDocumentInput<'_>) -> String {
        match self.format {
            EmbeddingInputFormat::EmbeddingGemmaRetrievalV1 => {
                let body = &input.text;
                let title = normalized_title(input.title).unwrap_or("none");
                format!("title: {title} | text: {body}")
            }
            EmbeddingInputFormat::Raw => format_raw_document(input),
        }
    }
}
/// Value returned by the default implementation of
/// `Embedder::preferred_document_batch_window`.
pub(crate) const DEFAULT_DOCUMENT_BATCH_WINDOW: usize = 64;
/// Backend that turns batches of text into `f32` embedding vectors.
pub(crate) trait Embedder: Send + Sync {
/// Embeds `texts`, all of which are of the given `kind`.
///
/// Returns a list of `f32` vectors on success.
/// NOTE(review): presumably one vector per entry of `texts`, in the same
/// order — confirm against implementors.
fn embed_batch(&self, kind: EmbeddingInputKind, texts: &[String]) -> Result<Vec<Vec<f32>>>;
/// Preferred document batch window for this embedder.
///
/// Defaults to [`DEFAULT_DOCUMENT_BATCH_WINDOW`].
/// NOTE(review): presumably the number of documents handed to one
/// `embed_batch` call — confirm with callers.
fn preferred_document_batch_window(&self) -> usize {
DEFAULT_DOCUMENT_BATCH_WINDOW
}
}
/// Measures document text in tokens.
pub(crate) trait EmbeddingDocumentSizer: Send + Sync {
/// Returns the number of tokens in `text`, or an error if counting fails.
fn count_document_tokens(&self, text: &str) -> Result<usize>;
/// Reports whether a text of `_byte_len` bytes fits within `_max_tokens`,
/// judged from the byte length alone.
///
/// The default implementation ignores both arguments and returns `false`.
/// NOTE(review): presumably a cheap pre-check that lets callers skip
/// tokenizing, with `false` meaning "cannot tell" rather than "does not fit"
/// — confirm with callers.
fn fits_within_token_limit_by_byte_len(&self, _byte_len: usize, _max_tokens: usize) -> bool {
false
}
}
/// Renders a document for the raw input format.
///
/// * No usable title (missing or blank after trimming): the text is returned
///   unchanged.
/// * Usable title but whitespace-only text: just `title: <title>`.
/// * Otherwise: `title: <title>`, a blank line, then the text as given.
fn format_raw_document(input: &EmbeddingDocumentInput<'_>) -> String {
    match normalized_title(input.title) {
        None => input.text.clone(),
        Some(title) if input.text.trim().is_empty() => format!("title: {title}"),
        Some(title) => format!("title: {title}\n\n{}", input.text),
    }
}
/// Trims surrounding whitespace from `title`.
///
/// Returns `None` when the title is absent or empty after trimming.
fn normalized_title(title: Option<&str>) -> Option<&str> {
    let trimmed = title?.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed)
    }
}
/// Reports whether `model` names an EmbeddingGemma model.
///
/// The name is reduced to its ASCII letters and digits and lowercased before
/// matching, so separators, case and non-ASCII characters are ignored
/// (`Embedding-Gemma_300M` matches).
fn is_embeddinggemma_model(model: &str) -> bool {
    const NEEDLE: &str = "embeddinggemma";
    let normalized: String = model
        .chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|ch| ch.to_ascii_lowercase())
        .collect();
    normalized.contains(NEEDLE)
}
#[cfg(test)]
mod tests {
    use super::{EmbeddingInputFormat, EmbeddingInputFormatter};
    use crate::retrieval_context::EmbeddingDocumentInput;

    /// Builds a formatter configured for the EmbeddingGemma retrieval format.
    fn gemma_formatter() -> EmbeddingInputFormatter {
        EmbeddingInputFormatter::new(EmbeddingInputFormat::EmbeddingGemmaRetrievalV1)
    }

    #[test]
    fn raw_formatter_keeps_document_title_and_text() {
        let document = EmbeddingDocumentInput {
            title: Some("Guide"),
            text: String::from("heading: Setup\n\nbody text"),
        };
        let rendered = EmbeddingInputFormatter::raw().format_document(&document);
        assert_eq!(rendered, "title: Guide\n\nheading: Setup\n\nbody text");
    }

    #[test]
    fn embeddinggemma_formatter_uses_title_slot() {
        let document = EmbeddingDocumentInput {
            title: Some("Guide"),
            text: String::from("heading: Setup\n\nbody text"),
        };
        let rendered = gemma_formatter().format_document(&document);
        assert_eq!(rendered, "title: Guide | text: heading: Setup\n\nbody text");
    }

    #[test]
    fn embeddinggemma_formatter_uses_none_for_missing_title() {
        let document = EmbeddingDocumentInput {
            title: None,
            text: String::from("body text"),
        };
        let rendered = gemma_formatter().format_document(&document);
        assert_eq!(rendered, "title: none | text: body text");
    }

    #[test]
    fn embeddinggemma_formats_query_with_task_prompt() {
        let rendered = gemma_formatter().format_query("tax question");
        assert_eq!(rendered, "task: search result | query: tax question");
    }
}