use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::Semaphore;
use tracing;
use crate::embedding::config::ParallelConfig;
use crate::embedding::error::{ApiError, EmbeddingError};
use crate::embedding::provider::{EmbeddingProvider, Vector};
/// JSON request body for Ollama's `/api/embed` endpoint.
#[derive(Serialize)]
struct OllamaRequest {
    // Name of the embedding model to run (e.g. "nomic-embed-text").
    model: String,
    // Texts to embed; the API returns one vector per input, in order.
    input: Vec<String>,
}
/// JSON response body from Ollama's `/api/embed` endpoint.
#[derive(Deserialize)]
struct OllamaResponse {
    // One embedding vector per input text, in request order.
    embeddings: Vec<Vec<f32>>,
}
/// Embedding provider backed by a (typically local) Ollama server.
///
/// Cloning is cheap in spirit: the semaphore is shared via `Arc`, so clones
/// cooperate on the same concurrency limit (note: `reqwest::Client` is also
/// internally reference-counted).
#[derive(Clone)]
pub struct OllamaProvider {
    // HTTP client with a request timeout configured at construction.
    client: Client,
    // Full URL of the embed endpoint, e.g. "http://localhost:11434/api/embed".
    endpoint: String,
    // Model name sent with every request; also gates nomic-specific sanitization.
    model: String,
    // Expected embedding dimensionality; responses are validated against it.
    dimension: usize,
    // Controls whether/how large batches are split and embedded in parallel.
    parallel_config: ParallelConfig,
    // Bounds concurrent sub-batch requests; shared across clones.
    semaphore: Arc<Semaphore>,
}
impl OllamaProvider {
    /// Default embed endpoint of a locally running Ollama server.
    pub const DEFAULT_ENDPOINT: &'static str = "http://localhost:11434/api/embed";
    /// Default embedding model.
    pub const DEFAULT_MODEL: &'static str = "mxbai-embed-large";
    /// Per-request HTTP timeout in seconds.
    const REQUEST_TIMEOUT_SECS: u64 = 60;

    /// Replaces characters known to degrade `nomic-embed-text` embedding
    /// quality (markdown pipes, brackets, unicode arrows, box-drawing
    /// characters) with plain-ASCII substitutes. Idempotent.
    #[allow(clippy::manual_string_new)]
    fn sanitize_for_nomic(text: &str) -> String {
        text.replace('|', " ")
            .replace('[', "(")
            .replace(']', ")")
            .replace('→', "->")
            .replace('←', "<-")
            .replace('↔', "<->")
            .replace(['├', '└'], "+")
            .replace('│', " ")
            .replace('─', "-")
            .replace(['┌', '┐', '┘', '┤', '┬', '┴', '┼'], "+")
    }

    /// Truncates `text`, keeping only characters whose starting byte offset
    /// is below `max_bytes` (this is safe on UTF-8 boundaries).
    ///
    /// NOTE(review): a multi-byte character that starts just below the limit
    /// is kept whole, so the result may exceed `max_bytes` by up to three
    /// bytes — this matches the previous inline truncation behavior, which
    /// was duplicated in both branches of `embed_batch_raw`.
    fn truncate_to_limit(text: String, max_bytes: usize) -> String {
        if text.len() > max_bytes {
            text.char_indices()
                .take_while(|(i, _)| *i < max_bytes)
                .map(|(_, c)| c)
                .collect()
        } else {
            text
        }
    }

    /// Creates a provider with the default parallelism configuration.
    ///
    /// # Errors
    /// Returns an error if the underlying HTTP client cannot be built.
    pub fn new(endpoint: String, model: String, dimension: usize) -> Result<Self, EmbeddingError> {
        Self::new_with_config(endpoint, model, dimension, ParallelConfig::default())
    }

    /// Creates a provider with an explicit parallelism configuration.
    ///
    /// # Errors
    /// Returns an error if the underlying HTTP client cannot be built.
    pub fn new_with_config(
        endpoint: String,
        model: String,
        dimension: usize,
        config: ParallelConfig,
    ) -> Result<Self, EmbeddingError> {
        let client = Client::builder()
            .timeout(std::time::Duration::from_secs(Self::REQUEST_TIMEOUT_SECS))
            .build()?;
        // The semaphore bounds how many sub-batch requests are in flight at once.
        let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
        Ok(Self {
            client,
            endpoint,
            model,
            dimension,
            parallel_config: config,
            semaphore,
        })
    }

    /// Creates a provider pointing at the default local endpoint with the
    /// default model (`mxbai-embed-large`, 1024 dimensions).
    pub fn default_config() -> Result<Self, EmbeddingError> {
        Self::new(
            Self::DEFAULT_ENDPOINT.to_string(),
            Self::DEFAULT_MODEL.to_string(),
            1024,
        )
    }

    /// Splits `texts` into sub-batches of `parallel_config.sub_batch_size`,
    /// embeds them concurrently (bounded by the shared semaphore), and
    /// reassembles the results in the original input order.
    ///
    /// # Errors
    /// Fails if any sub-batch fails or a spawned task panics/joins badly;
    /// the error identifies the failing sub-batch index.
    async fn embed_batch_parallel(
        &self,
        texts: Vec<String>,
    ) -> Result<Vec<Vector>, EmbeddingError> {
        let total_texts = texts.len();
        let sub_batch_size = self.parallel_config.sub_batch_size;
        let sub_batches: Vec<Vec<String>> = texts
            .chunks(sub_batch_size)
            .map(|chunk| chunk.to_vec())
            .collect();
        let num_batches = sub_batches.len();
        tracing::info!(
            "Parallel batch embedding: {} texts in {} sub-batches (size: {}, concurrency: {})",
            total_texts,
            num_batches,
            sub_batch_size,
            self.parallel_config.max_concurrency
        );
        let start = std::time::Instant::now();
        // Spawn one task per sub-batch; the semaphore (not the spawn count)
        // limits actual concurrency. Each task reports its index so results
        // can be re-ordered after completion.
        let handles: Vec<_> = sub_batches
            .into_iter()
            .enumerate()
            .map(|(idx, batch)| {
                let semaphore = self.semaphore.clone();
                let this = self.clone();
                let batch_size = batch.len();
                tokio::spawn(async move {
                    // The semaphore is never closed while tasks hold clones of it.
                    let _permit = semaphore
                        .acquire()
                        .await
                        .expect("embedding semaphore unexpectedly closed");
                    let batch_start = std::time::Instant::now();
                    tracing::debug!("Starting sub-batch {} ({} texts)", idx, batch_size);
                    let result = this.embed_batch_raw(batch).await;
                    let elapsed = batch_start.elapsed();
                    tracing::debug!(
                        "Sub-batch {} completed in {:.2}s ({} texts)",
                        idx,
                        elapsed.as_secs_f64(),
                        batch_size
                    );
                    (idx, result)
                })
            })
            .collect();
        let mut results: Vec<(usize, Result<Vec<Vector>, EmbeddingError>)> = Vec::new();
        for handle in handles {
            let (idx, result) = handle.await.map_err(|e| {
                EmbeddingError::Api(ApiError::InvalidResponse(format!("Task join error: {}", e)))
            })?;
            results.push((idx, result));
        }
        // Tasks may complete out of order; restore the input ordering.
        results.sort_by_key(|(idx, _)| *idx);
        let mut embeddings = Vec::with_capacity(total_texts);
        for (idx, result) in results {
            let batch_embeddings = result.map_err(|e| {
                EmbeddingError::Api(ApiError::InvalidResponse(format!(
                    "Sub-batch {} failed: {}",
                    idx, e
                )))
            })?;
            embeddings.extend(batch_embeddings);
        }
        let elapsed = start.elapsed();
        let throughput = total_texts as f64 / elapsed.as_secs_f64();
        tracing::info!(
            "Parallel batch complete: {} texts in {:.2}s ({:.0} texts/sec)",
            total_texts,
            elapsed.as_secs_f64(),
            throughput
        );
        Ok(embeddings)
    }

    /// Sends one embedding request for `texts`, retrying on network and 5xx
    /// errors with exponential backoff (500ms, 1s, 2s; up to 3 retries).
    ///
    /// Validates that the response contains exactly one embedding per input
    /// and that each embedding has the configured dimension.
    ///
    /// # Errors
    /// - `Network` after exhausting retries on transport failures.
    /// - `Api` for non-retryable HTTP statuses or malformed responses.
    /// - `DimensionMismatch` if the model returns unexpected vector sizes.
    async fn embed_batch_raw(&self, texts: Vec<String>) -> Result<Vec<Vector>, EmbeddingError> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let batch_size = texts.len();
        // Cap prompt size so individual texts stay within model limits.
        const MAX_CHARS: usize = 6000;
        // nomic-embed-text is sensitive to certain punctuation/unicode, so it
        // additionally gets sanitized before truncation; other models are
        // only truncated.
        let processed_texts: Vec<String> = if self.model == "nomic-embed-text" {
            texts
                .into_iter()
                .map(|t| Self::truncate_to_limit(Self::sanitize_for_nomic(&t), MAX_CHARS))
                .collect()
        } else {
            texts
                .into_iter()
                .map(|t| Self::truncate_to_limit(t, MAX_CHARS))
                .collect()
        };
        // Debug-log a preview of the first text to help diagnose encoding issues.
        if !processed_texts.is_empty() {
            let first = &processed_texts[0];
            let non_ascii: Vec<char> = first.chars().filter(|c| !c.is_ascii()).collect();
            if !non_ascii.is_empty() {
                tracing::debug!(
                    "Batch has {} non-ASCII chars after processing: {:?}",
                    non_ascii.len(),
                    non_ascii.iter().take(10).collect::<Vec<_>>()
                );
            }
            tracing::debug!(
                "First text preview ({} chars): {:?}",
                first.len(),
                first.chars().take(80).collect::<String>()
            );
        }
        let request_body = OllamaRequest {
            model: self.model.clone(),
            input: processed_texts,
        };
        const MAX_RETRIES: u32 = 3;
        const INITIAL_BACKOFF_MS: u64 = 500;
        let mut last_error: Option<EmbeddingError> = None;
        for attempt in 0..=MAX_RETRIES {
            if attempt > 0 {
                // Exponential backoff: 500ms, 1000ms, 2000ms.
                let backoff_ms = INITIAL_BACKOFF_MS * (1 << (attempt - 1));
                tracing::warn!(
                    "Retry {}/{} for batch of {} texts after {}ms backoff",
                    attempt,
                    MAX_RETRIES,
                    batch_size,
                    backoff_ms
                );
                tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
            }
            let response = match self
                .client
                .post(&self.endpoint)
                .json(&request_body)
                .send()
                .await
            {
                Ok(r) => r,
                Err(e) => {
                    // Transport failures are retryable.
                    tracing::error!(
                        "Failed to send batch of {} texts (attempt {}): {}",
                        batch_size,
                        attempt + 1,
                        e
                    );
                    last_error = Some(EmbeddingError::Network(e));
                    continue;
                }
            };
            let status = response.status();
            if status.is_success() {
                let body: OllamaResponse = match response.json().await {
                    Ok(b) => b,
                    Err(e) => {
                        return Err(EmbeddingError::Api(ApiError::InvalidResponse(format!(
                            "Failed to parse batch response for {} texts: {}",
                            batch_size, e
                        ))));
                    }
                };
                // The API must return exactly one embedding per input text.
                if body.embeddings.len() != batch_size {
                    return Err(EmbeddingError::Api(ApiError::InvalidResponse(format!(
                        "Batch size mismatch: sent {} texts but got {} embeddings",
                        batch_size,
                        body.embeddings.len()
                    ))));
                }
                // Every embedding must match the configured dimensionality.
                let expected_dim = self.dimension();
                for embedding in body.embeddings.iter() {
                    if embedding.len() != expected_dim {
                        use crate::embedding::error::DimensionMismatchError;
                        return Err(EmbeddingError::DimensionMismatch(
                            DimensionMismatchError::new(
                                expected_dim,
                                embedding.len(),
                                "Ollama".to_string(),
                                self.model.clone(),
                                self.dimension,
                            ),
                        ));
                    }
                }
                return Ok(body.embeddings);
            }
            let error_msg = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            match status.as_u16() {
                // Server errors are retryable; everything else fails fast.
                500..=599 => {
                    tracing::warn!(
                        "Server error {} for batch of {} texts: {} (attempt {}/{})",
                        status.as_u16(),
                        batch_size,
                        error_msg,
                        attempt + 1,
                        MAX_RETRIES + 1
                    );
                    last_error = Some(EmbeddingError::Api(ApiError::ServerError {
                        status: status.as_u16(),
                        message: format!("Batch of {} texts failed: {}", batch_size, error_msg),
                    }));
                    continue;
                }
                429 => {
                    return Err(EmbeddingError::Api(ApiError::RateLimit {
                        retry_after_ms: 1000,
                    }));
                }
                401 => {
                    return Err(EmbeddingError::Api(ApiError::Authentication(error_msg)));
                }
                400 => {
                    return Err(EmbeddingError::Api(ApiError::BadRequest(format!(
                        "Batch of {} texts rejected: {}",
                        batch_size, error_msg
                    ))));
                }
                _ => {
                    return Err(EmbeddingError::Api(ApiError::InvalidResponse(format!(
                        "HTTP {} for batch of {} texts: {}",
                        status, batch_size, error_msg
                    ))));
                }
            }
        }
        // All attempts exhausted: surface the last retryable error.
        Err(last_error.unwrap_or_else(|| {
            EmbeddingError::Api(ApiError::ServerError {
                status: 500,
                message: format!(
                    "Batch of {} texts failed after {} retries",
                    batch_size,
                    MAX_RETRIES + 1
                ),
            })
        }))
    }
}
#[async_trait]
impl EmbeddingProvider for OllamaProvider {
async fn embed(&self, text: String) -> Result<Vector, EmbeddingError> {
let embeddings = self.embed_batch_raw(vec![text]).await?;
Ok(embeddings.into_iter().next().unwrap())
}
async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vector>, EmbeddingError> {
if self.parallel_config.enabled && texts.len() > self.parallel_config.sub_batch_size {
self.embed_batch_parallel(texts).await
} else {
self.embed_batch_raw(texts).await
}
}
fn dimension(&self) -> usize {
self.dimension
}
fn provider_name(&self) -> &'static str {
"ollama"
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- Construction & configuration -----------------------------------

    #[test]
    fn test_ollama_provider_creation() {
        let provider = OllamaProvider::new(
            "http://localhost:11434/api/embed".to_string(),
            "nomic-embed-text".to_string(),
            768,
        );
        assert!(provider.is_ok());
        let provider = provider.unwrap();
        assert_eq!(provider.dimension(), 768);
        assert_eq!(provider.provider_name(), "ollama");
    }

    #[test]
    fn test_ollama_provider_default_config() {
        let provider = OllamaProvider::default_config();
        assert!(provider.is_ok());
        let provider = provider.unwrap();
        assert_eq!(provider.endpoint, OllamaProvider::DEFAULT_ENDPOINT);
        assert_eq!(provider.model, OllamaProvider::DEFAULT_MODEL);
        // Pin the constant's value so an accidental change is caught.
        assert_eq!(provider.model, "mxbai-embed-large");
        assert_eq!(provider.dimension(), 1024);
    }

    #[test]
    fn test_ollama_provider_clone() {
        let provider = OllamaProvider::default_config().unwrap();
        let cloned = provider.clone();
        assert_eq!(provider.dimension(), cloned.dimension());
        assert_eq!(provider.provider_name(), cloned.provider_name());
    }

    // --- Request/response (de)serialization ------------------------------

    #[test]
    fn test_ollama_request_serialization_single() {
        let request = OllamaRequest {
            model: "nomic-embed-text".to_string(),
            input: vec!["test text".to_string()],
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("nomic-embed-text"));
        assert!(json.contains("test text"));
        assert!(json.contains("\"input\":["));
    }

    #[test]
    fn test_ollama_request_serialization_batch() {
        let request = OllamaRequest {
            model: "nomic-embed-text".to_string(),
            input: vec!["text1".to_string(), "text2".to_string()],
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"nomic-embed-text\""));
        assert!(json.contains("\"input\":[\"text1\",\"text2\"]"));
    }

    #[test]
    fn test_ollama_response_deserialization_single() {
        let json = r#"{"embeddings":[[0.1,0.2,0.3]]}"#;
        let response: OllamaResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.embeddings.len(), 1);
        assert_eq!(response.embeddings[0].len(), 3);
        assert_eq!(response.embeddings[0][0], 0.1);
    }

    #[test]
    fn test_ollama_response_deserialization_batch() {
        let json = r#"{"embeddings":[[0.1,0.2],[0.3,0.4]]}"#;
        let response: OllamaResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.embeddings.len(), 2);
        assert_eq!(response.embeddings[0].len(), 2);
        assert_eq!(response.embeddings[1].len(), 2);
        assert_eq!(response.embeddings[0][0], 0.1);
        assert_eq!(response.embeddings[0][1], 0.2);
        assert_eq!(response.embeddings[1][0], 0.3);
        assert_eq!(response.embeddings[1][1], 0.4);
    }

    // --- Batch embedding behavior ----------------------------------------

    #[tokio::test]
    async fn test_embed_batch_empty_input() {
        let provider = OllamaProvider::default_config().unwrap();
        let result = provider.embed_batch(vec![]).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_embed_batch_raw_empty_returns_empty() {
        let provider = OllamaProvider::default_config().unwrap();
        let result = provider.embed_batch_raw(vec![]).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 0);
    }

    #[tokio::test]
    #[ignore] // requires a running Ollama server
    async fn test_ollama_batch_api_integration() {
        let provider = OllamaProvider::default_config().unwrap();
        let texts = vec!["hello".to_string(), "world".to_string()];
        let embeddings = provider.embed_batch(texts).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 1024);
        assert_eq!(embeddings[1].len(), 1024);
    }

    #[tokio::test]
    #[ignore] // requires a running Ollama server
    async fn test_ollama_single_embed_uses_batch_api() {
        let provider = OllamaProvider::default_config().unwrap();
        let embedding = provider.embed("test".to_string()).await.unwrap();
        assert_eq!(embedding.len(), 1024);
    }

    #[test]
    fn test_sub_batch_splitting() {
        let texts: Vec<String> = (0..105).map(|i| i.to_string()).collect();
        let batches: Vec<Vec<String>> = texts.chunks(50).map(|c| c.to_vec()).collect();
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0].len(), 50);
        assert_eq!(batches[1].len(), 50);
        assert_eq!(batches[2].len(), 5);
    }

    #[test]
    fn test_result_merge_ordering() {
        let mut results = vec![
            (2, vec!["c1".to_string(), "c2".to_string()]),
            (0, vec!["a1".to_string(), "a2".to_string()]),
            (1, vec!["b1".to_string(), "b2".to_string()]),
        ];
        results.sort_by_key(|(idx, _)| *idx);
        let merged: Vec<String> = results.into_iter().flat_map(|(_, v)| v).collect();
        assert_eq!(merged, vec!["a1", "a2", "b1", "b2", "c1", "c2"]);
    }

    // --- Parallelism configuration ---------------------------------------

    #[test]
    fn test_parallel_config_construction() {
        let config = ParallelConfig {
            enabled: true,
            sub_batch_size: 50,
            max_concurrency: 8,
        };
        let provider = OllamaProvider::new_with_config(
            "http://localhost:11434/api/embed".to_string(),
            "nomic-embed-text".to_string(),
            768,
            config.clone(),
        )
        .unwrap();
        assert!(provider.parallel_config.enabled);
        assert_eq!(provider.parallel_config.sub_batch_size, 50);
        assert_eq!(provider.parallel_config.max_concurrency, 8);
    }

    #[test]
    fn test_parallel_config_defaults() {
        let provider = OllamaProvider::default_config().unwrap();
        assert!(provider.parallel_config.enabled);
        assert_eq!(provider.parallel_config.sub_batch_size, 50);
        assert_eq!(provider.parallel_config.max_concurrency, 8);
    }

    #[tokio::test]
    async fn test_small_batch_uses_raw_not_parallel() {
        let config = ParallelConfig {
            enabled: true,
            sub_batch_size: 50,
            max_concurrency: 8,
        };
        let provider = OllamaProvider::new_with_config(
            "http://localhost:11434/api/embed".to_string(),
            "nomic-embed-text".to_string(),
            768,
            config,
        )
        .unwrap();
        let texts: Vec<String> = (0..10).map(|i| format!("text_{}", i)).collect();
        // Below the sub-batch threshold, embed_batch takes the raw path.
        assert!(texts.len() <= provider.parallel_config.sub_batch_size);
    }

    #[tokio::test]
    async fn test_large_batch_triggers_parallel() {
        let config = ParallelConfig {
            enabled: true,
            sub_batch_size: 50,
            max_concurrency: 8,
        };
        let provider = OllamaProvider::new_with_config(
            "http://localhost:11434/api/embed".to_string(),
            "nomic-embed-text".to_string(),
            768,
            config,
        )
        .unwrap();
        let texts: Vec<String> = (0..100).map(|i| format!("text_{}", i)).collect();
        // Above the sub-batch threshold, embed_batch takes the parallel path.
        assert!(texts.len() > provider.parallel_config.sub_batch_size);
    }

    #[tokio::test]
    #[ignore] // requires a running Ollama server
    async fn test_parallel_preserves_order() {
        let config = ParallelConfig {
            enabled: true,
            sub_batch_size: 10,
            max_concurrency: 4,
        };
        let provider = OllamaProvider::new_with_config(
            OllamaProvider::DEFAULT_ENDPOINT.to_string(),
            OllamaProvider::DEFAULT_MODEL.to_string(),
            1024,
            config,
        )
        .unwrap();
        let texts: Vec<String> = (0..50).map(|i| format!("text_{}", i)).collect();
        let embeddings = provider.embed_batch(texts.clone()).await.unwrap();
        assert_eq!(embeddings.len(), 50);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 1024);
        }
    }

    #[test]
    fn test_parallel_disabled_config() {
        let config = ParallelConfig {
            enabled: false,
            sub_batch_size: 50,
            max_concurrency: 8,
        };
        let provider = OllamaProvider::new_with_config(
            "http://localhost:11434/api/embed".to_string(),
            "nomic-embed-text".to_string(),
            768,
            config,
        )
        .unwrap();
        assert!(!provider.parallel_config.enabled);
        let texts: Vec<String> = (0..100).map(|i| format!("text_{}", i)).collect();
        // Even for a large batch, a disabled config keeps the raw path.
        assert!(texts.len() > provider.parallel_config.sub_batch_size);
        assert!(!provider.parallel_config.enabled);
    }

    // --- Dimension handling ----------------------------------------------

    #[test]
    fn test_ollama_accepts_dimension_1024() {
        let provider = OllamaProvider::new(
            "http://localhost:11434/api/embed".to_string(),
            "mxbai-embed-large".to_string(),
            1024,
        );
        assert!(provider.is_ok());
        let provider = provider.unwrap();
        assert_eq!(provider.dimension(), 1024);
        assert_eq!(provider.provider_name(), "ollama");
    }

    #[test]
    fn test_dimension_returns_configured_value() {
        let provider_768 = OllamaProvider::new(
            "http://localhost:11434/api/embed".to_string(),
            "nomic-embed-text".to_string(),
            768,
        )
        .unwrap();
        assert_eq!(provider_768.dimension(), 768);
        let provider_1024 = OllamaProvider::new(
            "http://localhost:11434/api/embed".to_string(),
            "mxbai-embed-large".to_string(),
            1024,
        )
        .unwrap();
        assert_eq!(provider_1024.dimension(), 1024);
        let provider_512 = OllamaProvider::new(
            "http://localhost:11434/api/embed".to_string(),
            "custom-model".to_string(),
            512,
        )
        .unwrap();
        assert_eq!(provider_512.dimension(), 512);
    }

    #[test]
    fn test_backward_compatibility_dimension_768() {
        let provider = OllamaProvider::new(
            "http://localhost:11434/api/embed".to_string(),
            "nomic-embed-text".to_string(),
            768,
        );
        assert!(provider.is_ok());
        let provider = provider.unwrap();
        assert_eq!(provider.dimension(), 768);
    }

    #[test]
    fn test_new_with_config_accepts_dimension() {
        let config = ParallelConfig::default();
        let provider = OllamaProvider::new_with_config(
            "http://localhost:11434/api/embed".to_string(),
            "mxbai-embed-large".to_string(),
            1024,
            config,
        );
        assert!(provider.is_ok());
        let provider = provider.unwrap();
        assert_eq!(provider.dimension(), 1024);
    }

    // --- nomic-specific sanitization --------------------------------------

    #[test]
    fn test_sanitize_for_nomic_replaces_pipes() {
        let input = "function | table | data";
        let output = OllamaProvider::sanitize_for_nomic(input);
        assert_eq!(output, "function   table   data");
        assert!(!output.contains('|'));
    }

    #[test]
    fn test_sanitize_for_nomic_replaces_brackets() {
        let input = "[x] checkbox [link](url)";
        let output = OllamaProvider::sanitize_for_nomic(input);
        assert_eq!(output, "(x) checkbox (link)(url)");
        assert!(!output.contains('['));
        assert!(!output.contains(']'));
    }

    #[test]
    fn test_sanitize_for_nomic_replaces_unicode_arrows() {
        let input = "a → b ← c ↔ d";
        let output = OllamaProvider::sanitize_for_nomic(input);
        assert_eq!(output, "a -> b <- c <-> d");
        assert!(!output.contains('→'));
        assert!(!output.contains('←'));
        assert!(!output.contains('↔'));
    }

    #[test]
    fn test_sanitize_for_nomic_replaces_box_drawing() {
        let input = "├── file\n└── dir\n│ ├── nested";
        let output = OllamaProvider::sanitize_for_nomic(input);
        assert!(!output.contains('├'));
        assert!(!output.contains('└'));
        assert!(!output.contains('│'));
        assert!(!output.contains('─'));
        assert!(output.contains('+'));
        assert!(output.contains('-'));
    }

    #[test]
    fn test_sanitize_for_nomic_all_problematic_chars() {
        let input = "| [ ] → ← ↔ ├ └ │ ─ ┌ ┐ ┘ ┤ ┬ ┴ ┼";
        let output = OllamaProvider::sanitize_for_nomic(input);
        let problematic_chars = [
            '|', '[', ']', '→', '←', '↔', '├', '└', '│', '─', '┌', '┐', '┘', '┤', '┬', '┴', '┼',
        ];
        for ch in &problematic_chars {
            assert!(!output.contains(*ch), "Output still contains: {}", ch);
        }
    }

    #[test]
    fn test_sanitize_for_nomic_preserves_normal_text() {
        let input = "function calculateTotal(a, b) { return a + b; }";
        let output = OllamaProvider::sanitize_for_nomic(input);
        assert_eq!(output, input);
    }

    #[tokio::test]
    async fn test_conditional_sanitization_nomic_embed_text() {
        let provider = OllamaProvider::new(
            "http://localhost:11434/api/embed".to_string(),
            "nomic-embed-text".to_string(),
            768,
        )
        .unwrap();
        assert_eq!(provider.model, "nomic-embed-text");
        let test_text = "| table | [link] → symbol";
        let sanitized = OllamaProvider::sanitize_for_nomic(test_text);
        assert!(!sanitized.contains('|'));
        assert!(!sanitized.contains('['));
        assert!(!sanitized.contains('→'));
    }

    #[tokio::test]
    async fn test_conditional_sanitization_mxbai_embed_large() {
        let provider = OllamaProvider::new(
            "http://localhost:11434/api/embed".to_string(),
            "mxbai-embed-large".to_string(),
            1024,
        )
        .unwrap();
        assert_eq!(provider.model, "mxbai-embed-large");
    }

    #[test]
    fn test_sanitize_for_nomic_idempotent() {
        let input = "| [x] → ├ test";
        let once = OllamaProvider::sanitize_for_nomic(input);
        let twice = OllamaProvider::sanitize_for_nomic(&once);
        assert_eq!(once, twice);
    }

    #[test]
    fn test_sanitize_for_nomic_empty_string() {
        let input = "";
        let output = OllamaProvider::sanitize_for_nomic(input);
        assert_eq!(output, "");
    }

    #[test]
    fn test_sanitize_for_nomic_unicode_preserved() {
        let input = "Hello 世界 مرحبا שלום";
        let output = OllamaProvider::sanitize_for_nomic(input);
        assert_eq!(output, input);
    }

    #[test]
    fn test_model_comparison_exact_match() {
        let nomic_provider = OllamaProvider::new(
            "http://localhost:11434/api/embed".to_string(),
            "nomic-embed-text".to_string(),
            768,
        )
        .unwrap();
        assert_eq!(nomic_provider.model, "nomic-embed-text");
        let mxbai_provider = OllamaProvider::new(
            "http://localhost:11434/api/embed".to_string(),
            "mxbai-embed-large".to_string(),
            1024,
        )
        .unwrap();
        assert_eq!(mxbai_provider.model, "mxbai-embed-large");
        assert_ne!(nomic_provider.model, mxbai_provider.model);
    }
}