use anyhow::{Result, anyhow};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Embedding width a provider must report before `EmbeddingClient::new` accepts it
/// (the serde default for `EmbeddingConfig::required_dimension`).
pub const DEFAULT_REQUIRED_DIMENSION: usize = 2560;
/// Model name used by the default local Ollama provider.
pub const DEFAULT_OLLAMA_EMBEDDING_MODEL: &str = "qwen3-embedding:4b";
/// Default number of attempts for a single batch request; overridable through the
/// `RUST_MEMEX_EMBED_BATCH_MAX_RETRIES` environment variable.
const DEFAULT_MAX_BATCH_RETRIES: usize = 10;
/// Default upper bound, in seconds, on the delay between batch attempts; overridable
/// through `RUST_MEMEX_EMBED_BATCH_MAX_BACKOFF_SECS`.
const DEFAULT_MAX_BATCH_BACKOFF_SECS: u64 = 30;
/// JSON body POSTed to an embeddings endpoint.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    /// Texts to embed; callers expect one vector back per entry.
    input: Vec<String>,
    /// Model identifier forwarded to the provider.
    model: String,
}
/// The part of an embeddings response this module reads; other fields are ignored.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    /// Returned vectors, consumed in the order the provider lists them.
    data: Vec<EmbeddingData>,
}
/// One element of the response's `data` array.
#[derive(Debug, Deserialize)]
struct EmbeddingData {
    /// The embedding vector; its length is checked against the required dimension.
    embedding: Vec<f32>,
}
/// JSON body POSTed to the reranker endpoint.
#[derive(Debug, Serialize)]
struct RerankRequest {
    /// Query the documents are scored against.
    query: String,
    /// Candidate documents to score.
    documents: Vec<String>,
    /// Reranker model identifier.
    model: String,
}
/// The part of a rerank response this module reads.
#[derive(Debug, Deserialize)]
struct RerankResponse {
    /// One entry per scored document.
    results: Vec<RerankResult>,
}
/// A single reranker score.
#[derive(Debug, Deserialize)]
struct RerankResult {
    /// Presumably the position of the document in the request's `documents` list —
    /// confirm against the reranker server's API.
    index: usize,
    /// Relevance score reported by the reranker.
    score: f32,
}
/// One embedding backend that `EmbeddingClient::new` may probe and connect to.
///
/// Note: the derived `Default` yields `priority = 0` and an empty `endpoint`; the
/// `default_*` functions below only apply when deserializing.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct ProviderConfig {
    /// Label used in logs and error messages.
    #[serde(default)]
    pub name: String,
    /// Scheme, host and optional port; trailing slashes are stripped before use.
    #[serde(default)]
    pub base_url: String,
    /// Model identifier sent with each embedding request.
    #[serde(default)]
    pub model: String,
    /// Providers are tried in ascending order; 10 when omitted from config.
    #[serde(default = "default_priority")]
    pub priority: u8,
    /// Path joined onto `base_url`; `/v1/embeddings` when omitted from config.
    #[serde(default = "default_embeddings_endpoint")]
    pub endpoint: String,
}
/// Serde default for [`ProviderConfig::priority`] when a provider omits it.
fn default_priority() -> u8 {
    const UNSPECIFIED_PRIORITY: u8 = 10;
    UNSPECIFIED_PRIORITY
}
/// Serde default for [`ProviderConfig::endpoint`]: the OpenAI-style embeddings path.
fn default_embeddings_endpoint() -> String {
    String::from("/v1/embeddings")
}
/// Reads a positive `usize` from environment variable `name`.
///
/// Returns `default` when the variable is unset, not a valid `usize`, or zero.
fn env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name).ok().and_then(|raw| raw.parse::<usize>().ok()) {
        Some(parsed) if parsed > 0 => parsed,
        _ => default,
    }
}
/// Reads a positive `u64` from environment variable `name`.
///
/// Returns `default` when the variable is unset, not a valid `u64`, or zero.
fn env_u64(name: &str, default: u64) -> u64 {
    let Some(value) = std::env::var(name)
        .ok()
        .and_then(|raw| raw.parse::<u64>().ok())
    else {
        return default;
    };
    if value > 0 { value } else { default }
}
/// Optional reranker service. `EmbeddingClient::new` only configures reranking when
/// `base_url` is set; otherwise `EmbeddingClient::rerank` returns an error.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct RerankerConfig {
    /// Scheme, host and optional port of the reranker.
    pub base_url: Option<String>,
    /// Reranker model identifier sent with each request.
    pub model: Option<String>,
    /// Path joined onto `base_url`; `/v1/rerank` when omitted from config.
    #[serde(default = "default_rerank_endpoint")]
    pub endpoint: String,
}
/// Serde default for [`RerankerConfig::endpoint`].
fn default_rerank_endpoint() -> String {
    String::from("/v1/rerank")
}
/// Serde default for `EmbeddingConfig::required_dimension`; mirrors the module-wide
/// required width constant.
fn default_dimension() -> usize {
    DEFAULT_REQUIRED_DIMENSION
}
/// Serde default for `max_batch_chars`: total characters allowed in one batch request.
fn default_max_batch_chars() -> usize {
    128_000
}
/// Serde default for `max_batch_items`: number of texts allowed in one batch request.
fn default_max_batch_items() -> usize {
    64
}
/// Joins a base URL and an endpoint path with exactly one `/` between them.
///
/// Trailing slashes on `base_url` and surrounding whitespace on `endpoint` are removed.
/// A separator is inserted only when the endpoint does not already start with `/`.
fn build_provider_endpoint(base_url: &str, endpoint: &str) -> String {
    let root = base_url.trim_end_matches('/');
    let path = endpoint.trim();
    let separator = if path.starts_with('/') { "" } else { "/" };
    format!("{root}{separator}{path}")
}
/// `[embeddings]` configuration: the vector width every provider must produce, batch
/// limits, the provider list, and an optional reranker.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct EmbeddingConfig {
    /// Width each provider's probe must report; mismatching providers are rejected.
    #[serde(default = "default_dimension")]
    pub required_dimension: usize,
    /// Character budget per batch request in `EmbeddingClient::embed_batch`.
    #[serde(default = "default_max_batch_chars")]
    pub max_batch_chars: usize,
    /// Maximum number of texts per batch request.
    #[serde(default = "default_max_batch_items")]
    pub max_batch_items: usize,
    /// Candidate providers, tried in ascending `priority` order.
    #[serde(default)]
    pub providers: Vec<ProviderConfig>,
    /// Reranker settings; reranking stays disabled when `base_url` is unset.
    #[serde(default)]
    pub reranker: RerankerConfig,
}
impl Default for EmbeddingConfig {
    /// A local Ollama provider (priority 1, tried first) backed by the `dragon` host
    /// (priority 2), default batch limits, and no reranker.
    fn default() -> Self {
        let local_ollama = ProviderConfig {
            name: "ollama-local".to_string(),
            base_url: "http://localhost:11434".to_string(),
            model: DEFAULT_OLLAMA_EMBEDDING_MODEL.to_string(),
            priority: 1,
            endpoint: default_embeddings_endpoint(),
        };
        let dragon_fallback = ProviderConfig {
            name: "dragon".to_string(),
            base_url: "http://dragon:12345".to_string(),
            model: "Qwen/Qwen3-Embedding-4B".to_string(),
            priority: 2,
            endpoint: default_embeddings_endpoint(),
        };
        Self {
            required_dimension: default_dimension(),
            max_batch_chars: default_max_batch_chars(),
            max_batch_items: default_max_batch_items(),
            providers: vec![local_ollama, dragon_fallback],
            reranker: RerankerConfig::default(),
        }
    }
}
impl EmbeddingConfig {
    /// Name of the provider `EmbeddingClient::new` probes first, i.e. the one with the
    /// lowest `priority` (ties resolved in configuration order, matching that function's
    /// stable sort). Returns `"none"` when no providers are configured.
    pub fn provider_name(&self) -> String {
        // `min_by_key` returns the first of several equal minima, which is exactly the
        // element a stable `sort_by_key(priority)` would put first.
        self.providers
            .iter()
            .min_by_key(|p| p.priority)
            .map_or_else(|| "none".to_string(), |p| p.name.clone())
    }

    /// Model of the highest-priority provider (see [`Self::provider_name`]), or
    /// `"none"` when no providers are configured.
    pub fn model_name(&self) -> String {
        self.providers
            .iter()
            .min_by_key(|p| p.priority)
            .map_or_else(|| "none".to_string(), |p| p.model.clone())
    }

    /// The vector width every embedding must have.
    pub fn dimension(&self) -> usize {
        self.required_dimension
    }
}
/// Legacy `[mlx]` settings: an embedder on localhost plus a remote "dragon" host.
/// Converted to an [`EmbeddingConfig`] by `MlxConfig::to_embedding_config`.
#[derive(Debug, Clone)]
pub struct MlxConfig {
    /// When true, `EmbeddingClient::from_legacy` refuses to build a client.
    pub disabled: bool,
    /// Port of the embedder on `localhost`.
    pub local_port: u16,
    /// Scheme and host of the dragon machine; ports are appended after a `:`.
    pub dragon_url: String,
    /// Port of the embedder on the dragon host.
    pub dragon_port: u16,
    /// Embedding model used for both the local and the dragon provider.
    pub embedder_model: String,
    /// Reranker model name.
    pub reranker_model: String,
    /// Reranker port expressed relative to `local_port`; note the reranker is then
    /// addressed on `dragon_url`, not on localhost.
    pub reranker_port_offset: u16,
    /// Per-request character budget copied into the generated `EmbeddingConfig`.
    pub max_batch_chars: usize,
    /// Per-request text count copied into the generated `EmbeddingConfig`.
    pub max_batch_items: usize,
}
/// Optional overrides for [`MlxConfig`]; each `Some` field replaces the matching
/// field in `MlxConfig::merge_file_config`. Batch limits are not overridable here.
#[derive(Debug, Clone, Default)]
pub struct MlxMergeOptions {
    pub disabled: Option<bool>,
    pub local_port: Option<u16>,
    pub dragon_url: Option<String>,
    pub dragon_port: Option<u16>,
    pub embedder_model: Option<String>,
    pub reranker_model: Option<String>,
    pub reranker_port_offset: Option<u16>,
}
impl Default for MlxConfig {
    /// Enabled bridge with the embedder on port 12345 both locally and on
    /// `http://dragon`, Qwen3 4B embedding/reranker models, and the reranker one port
    /// above the local embedder.
    fn default() -> Self {
        let embedder_port: u16 = 12345;
        Self {
            disabled: false,
            local_port: embedder_port,
            dragon_url: String::from("http://dragon"),
            dragon_port: embedder_port,
            embedder_model: String::from("Qwen/Qwen3-Embedding-4B"),
            reranker_model: String::from("Qwen/Qwen3-Reranker-4B"),
            reranker_port_offset: 1,
            max_batch_chars: default_max_batch_chars(),
            max_batch_items: default_max_batch_items(),
        }
    }
}
impl MlxConfig {
pub fn from_env() -> Self {
let disabled = std::env::var("DISABLE_MLX")
.map(|v| v == "1" || v.to_lowercase() == "true")
.unwrap_or(false);
let local_port = std::env::var("EMBEDDER_PORT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(12345);
let dragon_url =
std::env::var("DRAGON_BASE_URL").unwrap_or_else(|_| "http://dragon".to_string());
let dragon_port = std::env::var("DRAGON_EMBEDDER_PORT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(local_port);
let reranker_port_offset = std::env::var("RERANKER_PORT")
.ok()
.and_then(|s| s.parse::<u16>().ok())
.map(|rp| rp.saturating_sub(local_port))
.unwrap_or(1);
let embedder_model = std::env::var("EMBEDDER_MODEL")
.unwrap_or_else(|_| "Qwen/Qwen3-Embedding-4B".to_string());
let reranker_model = std::env::var("RERANKER_MODEL")
.unwrap_or_else(|_| "Qwen/Qwen3-Reranker-4B".to_string());
let max_batch_chars = std::env::var("MLX_MAX_BATCH_CHARS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(32000);
let max_batch_items = std::env::var("MLX_MAX_BATCH_ITEMS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(16);
Self {
disabled,
local_port,
dragon_url,
dragon_port,
embedder_model,
reranker_model,
reranker_port_offset,
max_batch_chars,
max_batch_items,
}
}
pub fn merge_file_config(&mut self, opts: MlxMergeOptions) {
if let Some(v) = opts.disabled {
self.disabled = v;
}
if let Some(v) = opts.local_port {
self.local_port = v;
}
if let Some(v) = opts.dragon_url {
self.dragon_url = v;
}
if let Some(v) = opts.dragon_port {
self.dragon_port = v;
}
if let Some(v) = opts.embedder_model {
self.embedder_model = v;
}
if let Some(v) = opts.reranker_model {
self.reranker_model = v;
}
if let Some(v) = opts.reranker_port_offset {
self.reranker_port_offset = v;
}
}
pub fn to_embedding_config(&self) -> EmbeddingConfig {
let reranker_port = self.local_port + self.reranker_port_offset;
let required_dimension = DEFAULT_REQUIRED_DIMENSION;
EmbeddingConfig {
required_dimension,
max_batch_chars: self.max_batch_chars,
max_batch_items: self.max_batch_items,
providers: vec![
ProviderConfig {
name: "local".to_string(),
base_url: format!("http://localhost:{}", self.local_port),
model: self.embedder_model.clone(),
priority: 1,
endpoint: default_embeddings_endpoint(),
},
ProviderConfig {
name: "dragon".to_string(),
base_url: format!("{}:{}", self.dragon_url, self.dragon_port),
model: self.embedder_model.clone(),
priority: 2,
endpoint: default_embeddings_endpoint(),
},
],
reranker: RerankerConfig {
base_url: Some(format!("{}:{}", self.dragon_url, reranker_port)),
model: Some(self.reranker_model.clone()),
endpoint: default_rerank_endpoint(),
},
}
}
pub fn with_batch_limits(mut self, max_chars: usize, max_items: usize) -> Self {
self.max_batch_chars = max_chars;
self.max_batch_items = max_items;
self
}
}
/// HTTP client bound to one validated embedding provider, plus an optional reranker.
///
/// Built by `EmbeddingClient::new` after a dimension probe succeeds; embedding
/// responses are checked against `required_dimension` before being returned.
#[derive(Clone)]
pub struct EmbeddingClient {
    client: Client,
    /// Full embeddings URL (base URL + endpoint) of the connected provider.
    embedder_url: String,
    /// Model name sent with each embedding request.
    embedder_model: String,
    /// Full rerank URL; `None` when no reranker base URL was configured.
    reranker_url: Option<String>,
    reranker_model: Option<String>,
    /// Name of the provider that passed validation.
    connected_to: String,
    /// Expected vector width.
    required_dimension: usize,
    /// Character budget per batch request in `embed_batch`.
    max_batch_chars: usize,
    /// Maximum number of texts per batch request in `embed_batch`.
    max_batch_items: usize,
}
/// Alias of [`EmbeddingClient`] under the legacy MLX name — presumably kept so older
/// callers keep compiling.
pub type MLXBridge = EmbeddingClient;
impl EmbeddingClient {
/// Connects to the first provider, by ascending `priority`, whose embedding endpoint
/// answers a probe with exactly `config.required_dimension` values.
///
/// # Errors
/// Fails when no providers are configured, when the HTTP client cannot be built, or
/// when every provider either fails its probe or reports a different dimension. The
/// error lists each attempt together with a suggested action.
pub async fn new(config: &EmbeddingConfig) -> Result<Self> {
    if config.providers.is_empty() {
        return Err(anyhow!(
            "No embedding providers configured! Add providers to [embeddings.providers]"
        ));
    }
    let client = Client::builder()
        .timeout(Duration::from_secs(300))
        .connect_timeout(Duration::from_secs(10))
        .build()?;

    // The reranker does not depend on which embedding provider wins, so resolve it
    // once. Join it with the same helper as the embedding URLs: plain concatenation
    // turned an endpoint without a leading '/' into "hostv1/rerank". A blank base URL
    // is treated as "not configured".
    let (reranker_url, reranker_model) = match config
        .reranker
        .base_url
        .as_deref()
        .map(str::trim)
        .filter(|base| !base.is_empty())
    {
        Some(base) => (
            Some(build_provider_endpoint(base, &config.reranker.endpoint)),
            config.reranker.model.clone(),
        ),
        None => (None, None),
    };

    // Stable sort: providers sharing a priority keep their configured order.
    let mut providers = config.providers.clone();
    providers.sort_by_key(|p| p.priority);
    let mut tried = Vec::new();
    for provider in &providers {
        let provider_name = if provider.name.trim().is_empty() {
            "<unnamed-provider>"
        } else {
            provider.name.as_str()
        };
        let model = provider.model.trim();
        let embedder_url = build_provider_endpoint(&provider.base_url, &provider.endpoint);
        match probe_provider_dimension(&client, provider).await {
            Ok(actual_dim) if actual_dim == config.required_dimension => {
                tracing::info!(
                    "Embedding: Connected to {} ({}) with model '{}' [{} dims]",
                    provider_name,
                    embedder_url,
                    model,
                    actual_dim
                );
                return Ok(Self {
                    client,
                    embedder_url,
                    // The probe sent the trimmed model name; keep sending exactly that.
                    embedder_model: model.to_string(),
                    reranker_url,
                    reranker_model,
                    connected_to: provider.name.clone(),
                    required_dimension: config.required_dimension,
                    max_batch_chars: config.max_batch_chars,
                    max_batch_items: config.max_batch_items,
                });
            }
            Ok(actual_dim) => {
                let failure = format!(
                    "- {} ({} model='{}'): the configured embedding endpoint returned {} dims, but config.required_dimension={}.\n Action: set [embeddings].required_dimension = {} or choose a {}-dim model.",
                    provider_name,
                    embedder_url,
                    model,
                    actual_dim,
                    config.required_dimension,
                    actual_dim,
                    config.required_dimension
                );
                tracing::error!("Embedding: validation failed: {}", failure);
                tried.push(failure);
            }
            Err(e) => {
                let failure = format!(
                    "- {} ({} model='{}'): {}",
                    provider_name, embedder_url, model, e
                );
                tracing::warn!("Embedding: provider probe failed: {}", failure);
                tried.push(failure);
            }
        }
    }
    Err(anyhow!(
        "No embedding provider passed validation for required_dimension={}. \
        Each provider must succeed on its configured embedding endpoint before rust-memex will start.\nTried:\n{}",
        config.required_dimension,
        tried.join("\n")
    ))
}
/// Builds a client from the legacy `[mlx]` settings, converted with
/// `MlxConfig::to_embedding_config`, after logging a migration warning.
///
/// # Errors
/// Fails immediately when `config.disabled` is set; otherwise as
/// [`EmbeddingClient::new`].
pub async fn from_legacy(config: &MlxConfig) -> Result<Self> {
    if !config.disabled {
        tracing::warn!("Using legacy [mlx] config - please migrate to [embeddings.providers]");
        return Self::new(&config.to_embedding_config()).await;
    }
    Err(anyhow!(
        "Embedding disabled via config. No fallback available!"
    ))
}
/// Builds a client from legacy settings read by `MlxConfig::from_env`.
pub async fn from_env() -> Result<Self> {
    Self::from_legacy(&MlxConfig::from_env()).await
}
/// Name of the provider this client connected to.
pub fn connected_to(&self) -> &str {
    self.connected_to.as_str()
}
/// Vector width every embedding returned by this client must have.
pub fn required_dimension(&self) -> usize {
    let width = self.required_dimension;
    width
}
/// Current batch limits as `(max_batch_chars, max_batch_items)`.
pub fn batch_limits(&self) -> (usize, usize) {
    let chars = self.max_batch_chars;
    let items = self.max_batch_items;
    (chars, items)
}
pub fn clone_with_batch_limits(&self, max_chars: usize, max_items: usize) -> Self {
let mut cloned = self.clone();
cloned.max_batch_chars = max_chars.max(1);
cloned.max_batch_items = max_items.max(1);
cloned
}
/// Test-only client populated with placeholder values; no probe is performed and no
/// reranker is configured.
#[doc(hidden)]
pub fn stub_for_tests() -> Self {
    let placeholder_url = String::from("http://stub:0/v1/embeddings");
    let placeholder_model = String::from("stub");
    Self {
        client: reqwest::Client::new(),
        embedder_url: placeholder_url,
        embedder_model: placeholder_model,
        reranker_url: None,
        reranker_model: None,
        connected_to: String::from("stub-test"),
        required_dimension: 4096,
        max_batch_chars: 32000,
        max_batch_items: 16,
    }
}
pub async fn embed(&mut self, text: &str) -> Result<Vec<f32>> {
let text_preview: String = text.chars().take(100).collect();
tracing::debug!(
"Embedding single text ({} chars): {}{}",
text.chars().count(),
text_preview,
if text.chars().count() > 100 {
"..."
} else {
""
}
);
let request = EmbeddingRequest {
input: vec![text.to_string()],
model: self.embedder_model.clone(),
};
let response = match self
.client
.post(&self.embedder_url)
.json(&request)
.send()
.await
{
Ok(resp) => resp,
Err(e) => {
tracing::error!(
"Embedding request failed: {:?}\n URL: {}\n Model: {}",
e,
self.embedder_url,
self.embedder_model
);
return Err(anyhow!("Embedding request failed: {}", e));
}
};
let status = response.status();
let response_text = response.text().await.unwrap_or_else(|e| {
tracing::warn!("Failed to read response body: {:?}", e);
"<failed to read body>".to_string()
});
if !status.is_success() {
tracing::error!(
"Embedding API error (HTTP {}):\n URL: {}\n Model: {}\n Response: {}",
status,
self.embedder_url,
self.embedder_model,
response_text
);
return Err(anyhow!(
"Embedding API error (HTTP {}): {}",
status,
response_text
));
}
let parsed: EmbeddingResponse = match serde_json::from_str(&response_text) {
Ok(r) => r,
Err(e) => {
tracing::error!(
"Failed to parse embedding response: {:?}\n Response body: {}",
e,
response_text
);
return Err(anyhow!("Failed to parse embedding response: {}", e));
}
};
let embedding = parsed
.data
.into_iter()
.next()
.map(|d| d.embedding)
.ok_or_else(|| {
tracing::error!("No embedding returned in response: {}", response_text);
anyhow!("No embedding returned")
})?;
if embedding.len() != self.required_dimension {
tracing::error!(
"Dimension mismatch! Expected {}, got {}. Model: {}",
self.required_dimension,
embedding.len(),
self.embedder_model
);
return Err(anyhow!(
"Dimension mismatch! Expected {}, got {}. This would corrupt the database!",
self.required_dimension,
embedding.len()
));
}
tracing::debug!("Successfully embedded text ({} dims)", embedding.len());
Ok(embedding)
}
pub async fn embed_batch(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(vec![]);
}
let mut all_embeddings = Vec::with_capacity(texts.len());
let mut current_batch: Vec<String> = Vec::new();
let mut current_batch_indices: Vec<usize> = Vec::new();
let mut current_chars = 0;
let max_text_chars = self.max_batch_chars / 2;
let prepared_texts: Vec<String> = texts
.iter()
.map(|text| {
let char_count = text.chars().count();
if char_count > max_text_chars {
tracing::debug!(
"Text too large ({} chars), truncating to {} chars",
char_count,
max_text_chars
);
truncate_at_boundary(text, max_text_chars)
} else {
text.clone()
}
})
.collect();
let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
let mut failed_indices: Vec<usize> = Vec::new();
for (idx, text_to_embed) in prepared_texts.iter().enumerate() {
let text_len = text_to_embed.chars().count();
if !current_batch.is_empty()
&& (current_chars + text_len > self.max_batch_chars
|| current_batch.len() >= self.max_batch_items)
{
match self.embed_batch_internal(¤t_batch).await {
Ok(batch_embeddings) => {
for (i, emb) in batch_embeddings.into_iter().enumerate() {
if let Some(orig_idx) = current_batch_indices.get(i) {
results[*orig_idx] = Some(emb);
}
}
}
Err(e) => {
tracing::warn!(
"Batch embedding failed for {} texts, will retry individually: {}",
current_batch.len(),
e
);
failed_indices.extend(current_batch_indices.iter().copied());
}
}
current_batch.clear();
current_batch_indices.clear();
current_chars = 0;
}
current_batch.push(text_to_embed.clone());
current_batch_indices.push(idx);
current_chars += text_len;
}
if !current_batch.is_empty() {
match self.embed_batch_internal(¤t_batch).await {
Ok(batch_embeddings) => {
for (i, emb) in batch_embeddings.into_iter().enumerate() {
if let Some(orig_idx) = current_batch_indices.get(i) {
results[*orig_idx] = Some(emb);
}
}
}
Err(e) => {
tracing::warn!(
"Batch embedding failed for {} texts, will retry individually: {}",
current_batch.len(),
e
);
failed_indices.extend(current_batch_indices.iter().copied());
}
}
}
const MAX_RETRIES: usize = 3;
for idx in failed_indices {
let text = &prepared_texts[idx];
let mut attempts = 0;
let mut last_error = String::new();
while attempts < MAX_RETRIES {
match self.embed(text).await {
Ok(embedding) => {
results[idx] = Some(embedding);
tracing::info!(
"Retry succeeded for chunk {} after {} attempts",
idx,
attempts + 1
);
break;
}
Err(e) => {
attempts += 1;
last_error = e.to_string();
tracing::warn!(
"Embed attempt {}/{} failed for chunk {}: {}",
attempts,
MAX_RETRIES,
idx,
e
);
if attempts < MAX_RETRIES {
let delay_ms = 100 * (1 << attempts);
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
}
}
}
}
if results[idx].is_none() {
tracing::error!(
"Chunk {} failed after {} retries: {}",
idx,
MAX_RETRIES,
last_error
);
return Err(anyhow!(
"Failed to embed chunk {} after {} retries: {}",
idx,
MAX_RETRIES,
last_error
));
}
}
for (idx, opt) in results.iter().enumerate() {
match opt {
Some(emb) => all_embeddings.push(emb.clone()),
None => {
return Err(anyhow!(
"Internal error: missing embedding for chunk {}",
idx
));
}
}
}
Ok(all_embeddings)
}
/// Sends one batch request. Transport errors, non-2xx statuses, and unparseable
/// bodies are retried with capped exponential backoff.
///
/// The attempt budget and the backoff cap come from
/// `RUST_MEMEX_EMBED_BATCH_MAX_RETRIES` and `RUST_MEMEX_EMBED_BATCH_MAX_BACKOFF_SECS`,
/// falling back to the module defaults. A well-formed response with the wrong number
/// of vectors, or with any vector of the wrong width, fails immediately without retry.
async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
    let total_chars: usize = texts.iter().map(|t| t.chars().count()).sum();
    tracing::debug!(
        "Embedding batch: {} texts, {} chars total",
        texts.len(),
        total_chars
    );
    for (i, text) in texts.iter().enumerate() {
        let preview: String = text.chars().take(50).collect();
        tracing::trace!(
            " Batch[{}]: {} chars - {}{}",
            i,
            text.chars().count(),
            preview,
            if text.chars().count() > 50 { "..." } else { "" }
        );
    }
    let request = EmbeddingRequest {
        input: texts.to_vec(),
        model: self.embedder_model.clone(),
    };
    let max_batch_retries = env_usize(
        "RUST_MEMEX_EMBED_BATCH_MAX_RETRIES",
        DEFAULT_MAX_BATCH_RETRIES,
    );
    let max_backoff_secs = env_u64(
        "RUST_MEMEX_EMBED_BATCH_MAX_BACKOFF_SECS",
        DEFAULT_MAX_BATCH_BACKOFF_SECS,
    );
    // 2^attempt seconds with the exponent capped at 5 (32 s), then clamped to the cap.
    let backoff_secs_for = |attempt: usize| (1u64 << attempt.min(5)).min(max_backoff_secs);
    let mut attempt = 0;
    loop {
        attempt += 1;
        let response = match self
            .client
            .post(&self.embedder_url)
            .json(&request)
            .send()
            .await
        {
            Ok(resp) => resp,
            Err(e) => {
                if attempt >= max_batch_retries {
                    tracing::error!(
                        "Batch embedding failed after {} retries: {:?}\n URL: {}\n Model: {}",
                        max_batch_retries,
                        e,
                        self.embedder_url,
                        self.embedder_model
                    );
                    return Err(anyhow!(
                        "Embedding request failed after {} retries: {}",
                        max_batch_retries,
                        e
                    ));
                }
                let backoff_secs = backoff_secs_for(attempt);
                tracing::warn!(
                    "Embedding request failed (attempt {}/{}), retrying in {}s: {}",
                    attempt,
                    max_batch_retries,
                    backoff_secs,
                    e
                );
                tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
                continue;
            }
        };
        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            if attempt >= max_batch_retries {
                tracing::error!(
                    "Embedding API error after {} retries: {} - {}",
                    max_batch_retries,
                    status,
                    body
                );
                return Err(anyhow!("Embedding API error: {} - {}", status, body));
            }
            let backoff_secs = backoff_secs_for(attempt);
            tracing::warn!(
                "Embedding API error (attempt {}/{}), retrying in {}s: {} - {}",
                attempt,
                max_batch_retries,
                backoff_secs,
                status,
                body
            );
            tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
            continue;
        }
        let embedding_response: EmbeddingResponse = match response.json().await {
            Ok(r) => r,
            Err(e) => {
                if attempt >= max_batch_retries {
                    return Err(anyhow!("Failed to parse embedding response: {}", e));
                }
                let backoff_secs = backoff_secs_for(attempt);
                tracing::warn!(
                    "Failed to parse response (attempt {}/{}), retrying in {}s: {}",
                    attempt,
                    max_batch_retries,
                    backoff_secs,
                    e
                );
                tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
                continue;
            }
        };
        let embeddings: Vec<Vec<f32>> = embedding_response
            .data
            .into_iter()
            .map(|d| d.embedding)
            .collect();
        if embeddings.len() != texts.len() {
            return Err(anyhow!(
                "Embedding count mismatch: got {} embeddings for {} texts",
                embeddings.len(),
                texts.len()
            ));
        }
        // Check every vector, not only the first: a single wrong-width vector in an
        // otherwise valid batch would otherwise be handed back to the caller.
        if let Some(bad) = embeddings
            .iter()
            .find(|e| e.len() != self.required_dimension)
        {
            return Err(anyhow!(
                "Dimension mismatch: expected {}, got {}",
                self.required_dimension,
                bad.len()
            ));
        }
        return Ok(embeddings);
    }
}
/// Scores `documents` against `query` with the configured reranker.
///
/// Returns `(index, score)` pairs in the order the reranker reports them. Pairs whose
/// index does not refer to an element of `documents` are dropped with a warning, so
/// callers can safely use the index to look up `documents`. An empty `documents`
/// slice returns an empty result without a network call.
///
/// # Errors
/// Fails when no reranker URL or model is configured, on transport errors, on non-2xx
/// responses, or when the body is not a valid rerank response.
pub async fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<(usize, f32)>> {
    let reranker_url = self.reranker_url.as_ref().ok_or_else(|| {
        anyhow!("Reranker not configured. Add [embeddings.reranker] to config.")
    })?;
    let reranker_model = self
        .reranker_model
        .as_ref()
        .ok_or_else(|| anyhow!("Reranker model not configured."))?;
    // Nothing to score: skip the round-trip (configuration errors are still reported).
    if documents.is_empty() {
        return Ok(Vec::new());
    }
    let query_preview: String = query.chars().take(100).collect();
    tracing::debug!(
        "Reranking {} documents for query: {}{}",
        documents.len(),
        query_preview,
        if query.chars().count() > 100 {
            "..."
        } else {
            ""
        }
    );
    let request = RerankRequest {
        query: query.to_string(),
        documents: documents.to_vec(),
        model: reranker_model.clone(),
    };
    let response = match self.client.post(reranker_url).json(&request).send().await {
        Ok(resp) => resp,
        Err(e) => {
            tracing::error!(
                "Rerank request failed: {:?}\n URL: {}\n Model: {}\n Query: {}\n Documents: {}",
                e,
                reranker_url,
                reranker_model,
                query_preview,
                documents.len()
            );
            return Err(anyhow!("Rerank request failed: {}", e));
        }
    };
    let status = response.status();
    let response_text = response.text().await.unwrap_or_else(|e| {
        tracing::warn!("Failed to read rerank response body: {:?}", e);
        "<failed to read body>".to_string()
    });
    if !status.is_success() {
        tracing::error!(
            "Rerank API error (HTTP {}):\n URL: {}\n Model: {}\n Response: {}",
            status,
            reranker_url,
            reranker_model,
            response_text
        );
        return Err(anyhow!(
            "Rerank API error (HTTP {}): {}",
            status,
            response_text
        ));
    }
    let parsed: RerankResponse = match serde_json::from_str(&response_text) {
        Ok(r) => r,
        Err(e) => {
            tracing::error!(
                "Failed to parse rerank response: {:?}\n Response body: {}",
                e,
                response_text
            );
            return Err(anyhow!("Failed to parse rerank response: {}", e));
        }
    };
    tracing::debug!("Rerank complete: {} documents scored", parsed.results.len());
    let mut scored = Vec::with_capacity(parsed.results.len());
    for result in parsed.results {
        if result.index < documents.len() {
            scored.push((result.index, result.score));
        } else {
            // A misbehaving server must not make callers index past `documents`.
            tracing::warn!(
                "Rerank result index {} out of range for {} documents; ignoring it",
                result.index,
                documents.len()
            );
        }
    }
    Ok(scored)
}
}
/// Sends a one-text embedding request to `provider` and returns the length of the
/// first vector it answers with.
///
/// The request uses a 30 s timeout. Blank `base_url`, `endpoint`, or `model` values
/// are rejected before any network call, checked in that order.
///
/// # Errors
/// Fails on invalid provider fields, transport errors, non-2xx statuses (404 adds a
/// hint about the endpoint path), non-embedding JSON, or an empty `data` array.
pub(crate) async fn probe_provider_dimension(
    client: &Client,
    provider: &ProviderConfig,
) -> Result<usize> {
    let base_url = provider.base_url.trim_end_matches('/');
    let endpoint = provider.endpoint.trim();
    let model = provider.model.trim();
    if base_url.is_empty() {
        return Err(anyhow!("provider base_url is empty"));
    }
    if endpoint.is_empty() {
        return Err(anyhow!("provider endpoint is empty"));
    }
    if model.is_empty() {
        return Err(anyhow!("provider model is empty"));
    }

    let embedder_url = build_provider_endpoint(base_url, endpoint);
    let probe = EmbeddingRequest {
        input: vec!["dimension probe".to_string()],
        model: model.to_string(),
    };
    let response = client
        .post(&embedder_url)
        .json(&probe)
        .timeout(Duration::from_secs(30))
        .send()
        .await
        .map_err(|e| anyhow!("POST {} failed: {}", embedder_url, e))?;

    let status = response.status();
    let body = response.text().await.unwrap_or_default();
    if !status.is_success() {
        let snippet: String = body.chars().take(300).collect();
        let hint = match status.as_u16() {
            404 => " Check provider.endpoint; Ollama and OpenAI-compatible servers typically use /v1/embeddings.",
            _ => "",
        };
        return Err(anyhow!(
            "POST {} returned {} for model '{}': {}{}",
            embedder_url,
            status,
            model,
            snippet,
            hint
        ));
    }

    let parsed: EmbeddingResponse = serde_json::from_str(&body).map_err(|e| {
        let snippet: String = body.chars().take(200).collect();
        anyhow!(
            "POST {} returned non-embedding JSON for model '{}': {} (body: {})",
            embedder_url,
            model,
            e,
            snippet
        )
    })?;
    match parsed.data.first() {
        Some(first) => Ok(first.embedding.len()),
        None => Err(anyhow!(
            "POST {} returned no embeddings for model '{}'",
            embedder_url,
            model
        )),
    }
}
/// Shortens `text` to at most `max_chars` characters, preferring a natural break.
///
/// Within the first `max_chars` characters it prefers, in order:
/// 1. the last sentence terminator (`.`, `!`, `?`, newline) lying past the halfway
///    point, kept in the output;
/// 2. otherwise the last whitespace (space, tab, newline) after the first byte,
///    excluded from the output;
/// 3. otherwise a hard cut at `max_chars`.
///
/// Text that already fits is returned unchanged. Cut positions come from char
/// boundaries, so multi-byte UTF-8 input never causes a slicing panic.
fn truncate_at_boundary(text: &str, max_chars: usize) -> String {
    // `nth(max_chars)` is `None` exactly when the text has at most `max_chars` chars.
    let Some((byte_idx, _)) = text.char_indices().nth(max_chars) else {
        return text.to_string();
    };
    let truncated = &text[..byte_idx];
    let half_byte_idx = text
        .char_indices()
        .nth(max_chars / 2)
        .map_or(0, |(idx, _)| idx);

    // The matched terminator is ASCII (1 byte), so `..=pos` ends on a char boundary.
    if let Some(pos) = truncated
        .rfind(['.', '!', '?', '\n'])
        .filter(|&pos| pos > half_byte_idx)
    {
        return text[..=pos].to_string();
    }
    // Require pos > 0: cutting before whitespace at byte 0 (e.g. a leading space)
    // would throw away everything and return an empty string.
    if let Some(pos) = truncated
        .rfind([' ', '\t', '\n'])
        .filter(|&pos| pos > 0)
    {
        return text[..pos].to_string();
    }
    truncated.to_string()
}
/// Token ceiling shared by every [`TokenConfig`] preset.
pub const DEFAULT_MAX_TOKENS: usize = 35_000;

/// Character-ratio heuristic for estimating token counts without a tokenizer.
#[derive(Debug, Clone)]
pub struct TokenConfig {
    /// Limit enforced by `validate_chunk_tokens` and `validate_batch_tokens`.
    pub max_tokens: usize,
    /// Average characters per token assumed by `estimate_tokens`.
    pub chars_per_token: f32,
}

impl Default for TokenConfig {
    /// Default preset: 3.0 characters per token.
    fn default() -> Self {
        Self::preset(3.0)
    }
}

impl TokenConfig {
    /// A config with the default ceiling and the given characters-per-token ratio.
    fn preset(chars_per_token: f32) -> Self {
        Self {
            max_tokens: DEFAULT_MAX_TOKENS,
            chars_per_token,
        }
    }

    /// Preset named for English text: 4.0 characters per token.
    pub fn english() -> Self {
        Self::preset(4.0)
    }

    /// Preset named for multilingual text: 2.5 characters per token.
    pub fn for_multilingual_text() -> Self {
        Self::preset(2.5)
    }

    /// Returns `self` with `max_tokens` replaced by `max`.
    pub fn with_max_tokens(self, max: usize) -> Self {
        Self {
            max_tokens: max,
            ..self
        }
    }
}
/// Estimates the token count of `text` as its character count divided by
/// `config.chars_per_token`, rounded up.
pub fn estimate_tokens(text: &str, config: &TokenConfig) -> usize {
    let chars = text.chars().count() as f32;
    let tokens = chars / config.chars_per_token;
    tokens.ceil() as usize
}
/// Checks that `chunk`'s estimated token count does not exceed `config.max_tokens`.
///
/// # Errors
/// Returns an error describing the estimate, the limit, and the character count when
/// the chunk is over the limit.
pub fn validate_chunk_tokens(chunk: &str, config: &TokenConfig) -> Result<()> {
    let estimated = estimate_tokens(chunk, config);
    if estimated <= config.max_tokens {
        return Ok(());
    }
    Err(anyhow!(
        "Chunk exceeds token limit: ~{} tokens > {} max (text: {} chars). \
         Consider reducing chunk_size or enabling truncation.",
        estimated,
        config.max_tokens,
        chunk.chars().count()
    ))
}
/// Character budget that keeps a chunk at 80% of `config.max_tokens` under the
/// configured characters-per-token ratio.
pub fn safe_chunk_size(config: &TokenConfig) -> usize {
    const HEADROOM: f32 = 0.8;
    let budget_tokens = (config.max_tokens as f32 * HEADROOM) as usize;
    let budget_chars = budget_tokens as f32 * config.chars_per_token;
    budget_chars as usize
}
/// Returns `text` unchanged when it fits within [`safe_chunk_size`] characters;
/// otherwise returns it truncated at a natural boundary within that budget.
pub fn truncate_to_token_limit(text: &str, config: &TokenConfig) -> String {
    let limit = safe_chunk_size(config);
    match text.chars().count() {
        n if n <= limit => text.to_string(),
        _ => truncate_at_boundary(text, limit),
    }
}
/// Lists every text whose estimated token count exceeds `config.max_tokens`, as
/// `(index, estimated_tokens)` pairs in input order.
pub fn validate_batch_tokens(texts: &[String], config: &TokenConfig) -> Vec<(usize, usize)> {
    let mut over_limit = Vec::new();
    for (idx, text) in texts.iter().enumerate() {
        let estimated = estimate_tokens(text, config);
        if estimated > config.max_tokens {
            over_limit.push((idx, estimated));
        }
    }
    over_limit
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Json, Router, extract::State, routing::post};
    use serde_json::json;

    /// Provider fixture on the standard `/v1/embeddings` path.
    fn provider(name: &str, base_url: &str, model: &str, priority: u8) -> ProviderConfig {
        ProviderConfig {
            name: name.into(),
            base_url: base_url.into(),
            model: model.into(),
            priority,
            endpoint: "/v1/embeddings".into(),
        }
    }

    /// Handler that answers every request with one constant vector of width `dim`.
    async fn mock_embeddings(State(dim): State<usize>) -> Json<serde_json::Value> {
        let embedding = vec![0.25_f32; dim];
        Json(json!({ "data": [{ "embedding": embedding }] }))
    }

    /// Serves `mock_embeddings` on an ephemeral localhost port and returns the base URL.
    async fn spawn_mock_embedding_server(dim: usize) -> String {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let app = Router::new()
            .route("/v1/embeddings", post(mock_embeddings))
            .with_state(dim);
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        tokio::time::sleep(Duration::from_millis(10)).await;
        format!("http://{addr}")
    }

    #[test]
    fn test_provider_sorting() {
        let mut providers = [
            provider("low", "http://a", "m", 10),
            provider("high", "http://b", "m", 1),
        ];
        providers.sort_by_key(|p| p.priority);
        let names: Vec<&str> = providers.iter().map(|p| p.name.as_str()).collect();
        assert_eq!(names, ["high", "low"]);
    }

    #[test]
    fn test_legacy_conversion() {
        let legacy = MlxConfig {
            disabled: false,
            local_port: 12345,
            dragon_url: "http://dragon".into(),
            dragon_port: 12345,
            embedder_model: "test-model".into(),
            reranker_model: "rerank-model".into(),
            reranker_port_offset: 1,
            max_batch_chars: 32000,
            max_batch_items: 16,
        };
        let config = legacy.to_embedding_config();
        assert_eq!(config.providers.len(), 2);
        assert_eq!(config.providers[0].base_url, "http://localhost:12345");
        assert!(config.reranker.base_url.is_some());
        assert_eq!(
            (config.max_batch_chars, config.max_batch_items),
            (32000, 16)
        );
    }

    #[test]
    fn test_default_config() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.required_dimension, DEFAULT_REQUIRED_DIMENSION);
        assert_eq!(config.max_batch_chars, 128_000);
        assert_eq!(config.max_batch_items, 64);
        let first = config
            .providers
            .first()
            .expect("default config lists at least one provider");
        assert_eq!(first.model, DEFAULT_OLLAMA_EMBEDDING_MODEL);
    }

    #[tokio::test]
    async fn test_probe_provider_dimension_reads_actual_dimension() {
        let base_url = spawn_mock_embedding_server(2560).await;
        let client = Client::new();
        let mock = provider("mock", &base_url, "mock-embedder", 1);
        let dim = probe_provider_dimension(&client, &mock).await.unwrap();
        assert_eq!(dim, 2560);
    }

    #[tokio::test]
    async fn test_embedding_client_fails_fast_on_dimension_mismatch() {
        let base_url = spawn_mock_embedding_server(2560).await;
        let config = EmbeddingConfig {
            required_dimension: 1024,
            providers: vec![provider("mock", &base_url, "mock-embedder", 1)],
            ..EmbeddingConfig::default()
        };
        let Err(err) = EmbeddingClient::new(&config).await else {
            panic!("dimension mismatch should fail");
        };
        let message = err.to_string();
        assert!(message.contains("returned 2560 dims"));
        assert!(message.contains("required_dimension=1024"));
    }

    #[test]
    fn test_truncate_at_boundary() {
        let cases = [
            ("Hello world. This is a test.", 15, "Hello world."),
            ("Hello world this is a test", 15, "Hello world"),
            ("Short text", 100, "Short text"),
        ];
        for (input, limit, expected) in cases {
            assert_eq!(truncate_at_boundary(input, limit), expected);
        }
    }

    #[test]
    fn test_token_estimation() {
        let text = "Hello world"; // 11 characters
        let default_tokens = estimate_tokens(text, &TokenConfig::default());
        assert!((3..=5).contains(&default_tokens));
        let english_tokens = estimate_tokens(text, &TokenConfig::english());
        assert!((2..=4).contains(&english_tokens));
    }

    #[test]
    fn default_token_ceiling_stays_above_long_transcript_floor() {
        let config = TokenConfig::default();
        assert_eq!(DEFAULT_MAX_TOKENS, 35_000);
        assert_eq!(config.max_tokens, DEFAULT_MAX_TOKENS);
        assert!(config.max_tokens >= 35_000);
    }

    #[test]
    fn test_chunk_validation() {
        let config = TokenConfig::default().with_max_tokens(100);
        assert!(validate_chunk_tokens("Hello world", &config).is_ok());
        // 1000 chars at 3.0 chars/token is ~334 tokens, well over the 100 limit.
        assert!(validate_chunk_tokens(&"a".repeat(1000), &config).is_err());
    }

    #[test]
    fn test_safe_chunk_size() {
        // 80% of 35_000 tokens at 3.0 chars/token = 84_000 chars.
        let safe = safe_chunk_size(&TokenConfig::default());
        assert!(safe > 80_000 && safe < 90_000);
    }

    #[test]
    fn test_batch_validation() {
        let config = TokenConfig::default().with_max_tokens(10);
        let texts = vec![
            "short".to_string(),      // ~2 tokens
            "a".repeat(100),          // ~34 tokens: over the limit
            "also short".to_string(), // ~4 tokens
            "b".repeat(200),          // ~67 tokens: over the limit
        ];
        let flagged: Vec<usize> = validate_batch_tokens(&texts, &config)
            .iter()
            .map(|&(idx, _)| idx)
            .collect();
        assert_eq!(flagged, vec![1, 3]);
    }
}
/// Converts embedding vectors to a fixed target width.
///
/// Shrinking uses block averaging when both widths are powers of two and the source is
/// a multiple of the target; otherwise it keeps the leading components. Growing pads
/// with zeros. Every path that changes the width re-normalizes the result to unit L2
/// length (all-zero vectors are left as-is); vectors already at the target width are
/// returned untouched.
#[derive(Debug, Clone)]
pub struct DimensionAdapter {
    /// Width the adapter was created for (used by `needs_adaptation`; `adapt` looks at
    /// each vector's actual length).
    pub source_dim: usize,
    /// Width every adapted vector is converted to.
    pub target_dim: usize,
}

impl DimensionAdapter {
    pub fn new(source_dim: usize, target_dim: usize) -> Self {
        Self {
            source_dim,
            target_dim,
        }
    }

    /// True when the configured source and target widths differ.
    pub fn needs_adaptation(&self) -> bool {
        self.source_dim != self.target_dim
    }

    /// Converts `embedding` to `target_dim`, choosing `expand` or `contract` from the
    /// vector's actual length.
    pub fn adapt(&self, embedding: Vec<f32>) -> Vec<f32> {
        if embedding.len() == self.target_dim {
            return embedding;
        }
        if embedding.len() < self.target_dim {
            self.expand(embedding)
        } else {
            self.contract(embedding)
        }
    }

    /// Zero-pads `embedding` to `target_dim` and re-normalizes it. Input that is not
    /// shorter than the target is truncated instead (see `truncate`).
    pub fn expand(&self, embedding: Vec<f32>) -> Vec<f32> {
        if embedding.len() >= self.target_dim {
            return self.truncate(embedding);
        }
        let mut padded = embedding;
        padded.resize(self.target_dim, 0.0);
        self.normalize(&mut padded);
        padded
    }

    /// Shrinks `embedding` to `target_dim`: block averaging for power-of-two reductions,
    /// otherwise truncation. Input that is not longer than the target is returned as-is.
    pub fn contract(&self, embedding: Vec<f32>) -> Vec<f32> {
        if embedding.len() <= self.target_dim {
            return embedding;
        }
        if self.is_power_of_two_reduction(embedding.len()) {
            self.average_reduction(embedding)
        } else {
            self.truncate(embedding)
        }
    }

    /// Keeps the first `target_dim` components and re-normalizes, so truncated vectors
    /// have unit length like the outputs of the padding and averaging paths.
    /// Vectors that already fit are returned untouched.
    fn truncate(&self, mut embedding: Vec<f32>) -> Vec<f32> {
        if embedding.len() <= self.target_dim {
            return embedding;
        }
        embedding.truncate(self.target_dim);
        self.normalize(&mut embedding);
        embedding
    }

    fn is_power_of_two_reduction(&self, source_len: usize) -> bool {
        source_len > self.target_dim
            && source_len.is_power_of_two()
            && self.target_dim.is_power_of_two()
            && source_len.is_multiple_of(self.target_dim)
    }

    /// Averages consecutive groups of `len / target_dim` components, then
    /// re-normalizes. Callers guarantee the length is an exact multiple of the target.
    fn average_reduction(&self, embedding: Vec<f32>) -> Vec<f32> {
        let factor = embedding.len() / self.target_dim;
        let mut result: Vec<f32> = embedding
            .chunks_exact(factor)
            .map(|chunk| chunk.iter().sum::<f32>() / factor as f32)
            .collect();
        self.normalize(&mut result);
        result
    }

    /// Scales `vec` to unit L2 norm; near-zero vectors are left unchanged.
    fn normalize(&self, vec: &mut [f32]) {
        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for v in vec.iter_mut() {
                *v /= norm;
            }
        }
    }
}
/// Adapts a query embedding of any width to `target_dim` so it can be compared with
/// vectors stored at that width.
pub fn cross_dimension_search_adapt(query_embedding: Vec<f32>, target_dim: usize) -> Vec<f32> {
    DimensionAdapter::new(query_embedding.len(), target_dim).adapt(query_embedding)
}
#[cfg(test)]
mod dimension_adapter_tests {
    use super::*;

    /// Euclidean length of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn test_expand_1024_to_4096() {
        let expanded = DimensionAdapter::new(1024, 4096).expand(vec![0.1f32; 1024]);
        assert_eq!(expanded.len(), 4096);
        // Original components survive (rescaled); the padded tail stays zero.
        assert!(expanded[0].abs() > 1e-10);
        assert!(expanded[4095].abs() < 1e-10);
    }

    #[test]
    fn test_contract_4096_to_1024() {
        let contracted = DimensionAdapter::new(4096, 1024).contract(vec![0.1f32; 4096]);
        assert_eq!(contracted.len(), 1024);
        assert!((l2_norm(&contracted) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_adapt_auto_detect() {
        for (source, target) in [(1024usize, 4096usize), (4096, 1024)] {
            let adapter = DimensionAdapter::new(source, target);
            let result = adapter.adapt(vec![0.1f32; source]);
            assert_eq!(result.len(), target);
        }
    }

    #[test]
    fn test_no_adaptation_needed() {
        let adapter = DimensionAdapter::new(4096, 4096);
        assert!(!adapter.needs_adaptation());
        let embedding = vec![0.1f32; 4096];
        assert_eq!(adapter.adapt(embedding.clone()), embedding);
    }

    #[test]
    fn test_average_reduction_preserves_info() {
        let ramp: Vec<f32> = (0..4096).map(|i| i as f32 / 4096.0).collect();
        let contracted = DimensionAdapter::new(4096, 2048).contract(ramp);
        assert_eq!(contracted.len(), 2048);
        assert!((l2_norm(&contracted) - 1.0).abs() < 1e-5);
    }
}