use anyhow::{Result, anyhow};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Request body POSTed to the embeddings endpoint.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
// Texts to embed; the client expects exactly one embedding back per entry.
input: Vec<String>,
// Model name forwarded to the server unchanged.
model: String,
}
/// Response body from the embeddings endpoint. Only `data` is read; any other fields the
/// server sends are ignored by serde.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
// One entry per input text, consumed in order.
data: Vec<EmbeddingData>,
}
/// A single element of `EmbeddingResponse::data`.
#[derive(Debug, Deserialize)]
struct EmbeddingData {
// The embedding vector; its length is validated against the configured dimension.
embedding: Vec<f32>,
}
/// Request body POSTed to the rerank endpoint.
#[derive(Debug, Serialize)]
struct RerankRequest {
// Query the documents are scored against.
query: String,
// Candidate documents; results refer back to them by position.
documents: Vec<String>,
// Reranker model name forwarded to the server.
model: String,
}
/// Response body from the rerank endpoint; only `results` is read.
#[derive(Debug, Deserialize)]
struct RerankResponse {
results: Vec<RerankResult>,
}
#[derive(Debug, Deserialize)]
struct RerankResult {
index: usize,
score: f32,
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct ProviderConfig {
#[serde(default)]
pub name: String,
#[serde(default)]
pub base_url: String,
#[serde(default)]
pub model: String,
#[serde(default = "default_priority")]
pub priority: u8,
#[serde(default = "default_embeddings_endpoint")]
pub endpoint: String,
}
/// Serde default for `ProviderConfig::priority`: a low preference (10), so explicitly
/// prioritised providers (1, 2, ...) are tried first.
fn default_priority() -> u8 {
    const FALLBACK_PRIORITY: u8 = 10;
    FALLBACK_PRIORITY
}

/// Serde default for `ProviderConfig::endpoint`: the `/v1/embeddings` path.
fn default_embeddings_endpoint() -> String {
    String::from("/v1/embeddings")
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct RerankerConfig {
pub base_url: Option<String>,
pub model: Option<String>,
#[serde(default = "default_rerank_endpoint")]
pub endpoint: String,
}
/// Serde default for `RerankerConfig::endpoint`: the `/v1/rerank` path.
fn default_rerank_endpoint() -> String {
    String::from("/v1/rerank")
}

/// Serde default for `EmbeddingConfig::required_dimension` (4096-dimensional vectors).
fn default_dimension() -> usize {
    4_096
}

/// Serde default for `EmbeddingConfig::max_batch_chars`: at most 128 000 characters per request.
fn default_max_batch_chars() -> usize {
    128_000
}

/// Serde default for `EmbeddingConfig::max_batch_items`: at most 64 texts per request.
fn default_max_batch_items() -> usize {
    64
}
/// Embedding configuration: the required vector size, request batching limits, the
/// embedding providers to try, and an optional reranker.
///
/// Every field has a serde default, so a partial config section still deserializes.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct EmbeddingConfig {
// Exact length every returned embedding must have (default 4096); the client rejects
// responses with any other length.
#[serde(default = "default_dimension")]
pub required_dimension: usize,
// Upper bound on the total characters packed into one batched embedding request
// (default 128000).
#[serde(default = "default_max_batch_chars")]
pub max_batch_chars: usize,
// Upper bound on the number of texts packed into one batched request (default 64).
#[serde(default = "default_max_batch_items")]
pub max_batch_items: usize,
// Candidate providers; `EmbeddingClient::new` tries them in ascending `priority` order
// and uses the first that passes a health check.
#[serde(default)]
pub providers: Vec<ProviderConfig>,
// Reranker settings; falls back to `RerankerConfig::default()` when the section is absent.
#[serde(default)]
pub reranker: RerankerConfig,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
required_dimension: 4096,
max_batch_chars: default_max_batch_chars(),
max_batch_items: default_max_batch_items(),
providers: vec![
ProviderConfig {
name: "ollama-local".to_string(),
base_url: "http://localhost:11434".to_string(),
model: "qwen3-embedding:8b".to_string(),
priority: 1,
endpoint: default_embeddings_endpoint(),
},
ProviderConfig {
name: "dragon".to_string(),
base_url: "http://dragon:12345".to_string(),
model: "Qwen/Qwen3-Embedding-4B".to_string(),
priority: 2,
endpoint: default_embeddings_endpoint(),
},
],
reranker: RerankerConfig::default(),
}
}
}
impl EmbeddingConfig {
    /// Name of the preferred provider, or `"none"` if no providers are configured.
    ///
    /// "Preferred" means lowest `priority` value (first in list order on ties), which is the
    /// provider `EmbeddingClient::new` tries first. Looking only at the first list entry
    /// would misreport it whenever the config lists providers out of priority order.
    /// This describes configuration only; the provider actually connected may differ if
    /// the preferred one fails its health check.
    pub fn provider_name(&self) -> String {
        self.providers
            .iter()
            .min_by_key(|p| p.priority)
            .map_or_else(|| "none".to_string(), |p| p.name.clone())
    }

    /// Model of the preferred provider (same selection as [`Self::provider_name`]),
    /// or `"none"` if no providers are configured.
    pub fn model_name(&self) -> String {
        self.providers
            .iter()
            .min_by_key(|p| p.priority)
            .map_or_else(|| "none".to_string(), |p| p.model.clone())
    }

    /// Required embedding length (same value as `required_dimension`).
    pub fn dimension(&self) -> usize {
        self.required_dimension
    }
}
/// Legacy `[mlx]` / environment-variable configuration. `to_embedding_config` converts it
/// into an `EmbeddingConfig` with a localhost provider and a "dragon" fallback provider.
#[derive(Debug, Clone)]
pub struct MlxConfig {
// When true, `EmbeddingClient::from_legacy` refuses to build a client.
pub disabled: bool,
// Embedder port on localhost (env `EMBEDDER_PORT`).
pub local_port: u16,
// Scheme + host of the remote "dragon" machine, without a port; the port is appended
// with `:` when URLs are built.
pub dragon_url: String,
// Embedder port on the dragon host (env `DRAGON_EMBEDDER_PORT`; `from_env` falls back
// to `local_port`).
pub dragon_port: u16,
// Embedding model name used for both providers.
pub embedder_model: String,
// Reranker model name.
pub reranker_model: String,
// The reranker is addressed at `dragon_url` on port `local_port + reranker_port_offset`.
pub reranker_port_offset: u16,
// Batching limits forwarded into `EmbeddingConfig`.
pub max_batch_chars: usize,
pub max_batch_items: usize,
}
impl Default for MlxConfig {
    /// Enabled config with the embedder on port 12345 both locally and on `http://dragon`,
    /// the reranker one port above it, Qwen3 4B models, and the shared batch-limit defaults.
    fn default() -> Self {
        const EMBEDDER_PORT: u16 = 12345;
        let dragon_url = String::from("http://dragon");
        let embedder_model = String::from("Qwen/Qwen3-Embedding-4B");
        let reranker_model = String::from("Qwen/Qwen3-Reranker-4B");
        Self {
            disabled: false,
            local_port: EMBEDDER_PORT,
            dragon_port: EMBEDDER_PORT,
            dragon_url,
            embedder_model,
            reranker_model,
            reranker_port_offset: 1,
            max_batch_chars: default_max_batch_chars(),
            max_batch_items: default_max_batch_items(),
        }
    }
}
impl MlxConfig {
pub fn from_env() -> Self {
let disabled = std::env::var("DISABLE_MLX")
.map(|v| v == "1" || v.to_lowercase() == "true")
.unwrap_or(false);
let local_port = std::env::var("EMBEDDER_PORT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(12345);
let dragon_url =
std::env::var("DRAGON_BASE_URL").unwrap_or_else(|_| "http://dragon".to_string());
let dragon_port = std::env::var("DRAGON_EMBEDDER_PORT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(local_port);
let reranker_port_offset = std::env::var("RERANKER_PORT")
.ok()
.and_then(|s| s.parse::<u16>().ok())
.map(|rp| rp.saturating_sub(local_port))
.unwrap_or(1);
let embedder_model = std::env::var("EMBEDDER_MODEL")
.unwrap_or_else(|_| "Qwen/Qwen3-Embedding-4B".to_string());
let reranker_model = std::env::var("RERANKER_MODEL")
.unwrap_or_else(|_| "Qwen/Qwen3-Reranker-4B".to_string());
let max_batch_chars = std::env::var("MLX_MAX_BATCH_CHARS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(32000);
let max_batch_items = std::env::var("MLX_MAX_BATCH_ITEMS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(16);
Self {
disabled,
local_port,
dragon_url,
dragon_port,
embedder_model,
reranker_model,
reranker_port_offset,
max_batch_chars,
max_batch_items,
}
}
#[allow(clippy::too_many_arguments)]
pub fn merge_file_config(
&mut self,
disabled: Option<bool>,
local_port: Option<u16>,
dragon_url: Option<String>,
dragon_port: Option<u16>,
embedder_model: Option<String>,
reranker_model: Option<String>,
reranker_port_offset: Option<u16>,
) {
if let Some(v) = disabled {
self.disabled = v;
}
if let Some(v) = local_port {
self.local_port = v;
}
if let Some(v) = dragon_url {
self.dragon_url = v;
}
if let Some(v) = dragon_port {
self.dragon_port = v;
}
if let Some(v) = embedder_model {
self.embedder_model = v;
}
if let Some(v) = reranker_model {
self.reranker_model = v;
}
if let Some(v) = reranker_port_offset {
self.reranker_port_offset = v;
}
}
pub fn to_embedding_config(&self) -> EmbeddingConfig {
let reranker_port = self.local_port + self.reranker_port_offset;
EmbeddingConfig {
required_dimension: 4096,
max_batch_chars: self.max_batch_chars,
max_batch_items: self.max_batch_items,
providers: vec![
ProviderConfig {
name: "local".to_string(),
base_url: format!("http://localhost:{}", self.local_port),
model: self.embedder_model.clone(),
priority: 1,
endpoint: default_embeddings_endpoint(),
},
ProviderConfig {
name: "dragon".to_string(),
base_url: format!("{}:{}", self.dragon_url, self.dragon_port),
model: self.embedder_model.clone(),
priority: 2,
endpoint: default_embeddings_endpoint(),
},
],
reranker: RerankerConfig {
base_url: Some(format!("{}:{}", self.dragon_url, reranker_port)),
model: Some(self.reranker_model.clone()),
endpoint: default_rerank_endpoint(),
},
}
}
pub fn with_batch_limits(mut self, max_chars: usize, max_items: usize) -> Self {
self.max_batch_chars = max_chars;
self.max_batch_items = max_items;
self
}
}
/// HTTP client for an embeddings endpoint (request `{input, model}`, response
/// `data[].embedding`) plus an optional rerank endpoint.
///
/// Build it with `new`, which health-checks providers in priority order and keeps the
/// first one that responds.
pub struct EmbeddingClient {
// Shared reqwest client (300 s request timeout, 10 s connect timeout, set in `new`).
client: Client,
// Full embeddings URL: provider base URL + endpoint path.
embedder_url: String,
embedder_model: String,
// `None` when no reranker is configured; `rerank` then returns an error.
reranker_url: Option<String>,
reranker_model: Option<String>,
// Name of the provider that passed the health check.
connected_to: String,
// Every returned embedding must have exactly this length.
required_dimension: usize,
// Batching limits used by `embed_batch`.
max_batch_chars: usize,
max_batch_items: usize,
}
/// Alias under the legacy MLX name (presumably kept for older call sites).
pub type MLXBridge = EmbeddingClient;
impl EmbeddingClient {
/// Connect to the first healthy embedding provider, trying them in ascending `priority`.
///
/// # Errors
/// Fails if no providers are configured, the HTTP client cannot be built, or every
/// provider fails its health check (the error lists each failure).
pub async fn new(config: &EmbeddingConfig) -> Result<Self> {
    if config.providers.is_empty() {
        return Err(anyhow!(
            "No embedding providers configured! Add providers to [embeddings.providers]"
        ));
    }
    let client = Client::builder()
        .connect_timeout(Duration::from_secs(10))
        .timeout(Duration::from_secs(300))
        .build()?;

    // The reranker endpoint does not depend on which embedding provider wins, so it is
    // resolved once, up front.
    let (reranker_url, reranker_model) = match &config.reranker.base_url {
        Some(rr_base) => (
            Some(format!(
                "{}{}",
                rr_base.trim_end_matches('/'),
                config.reranker.endpoint
            )),
            config.reranker.model.clone(),
        ),
        None => (None, None),
    };

    // Stable sort: providers sharing a priority keep their configured order.
    let mut ordered = config.providers.clone();
    ordered.sort_by_key(|p| p.priority);

    let mut failures: Vec<String> = Vec::new();
    for candidate in &ordered {
        let base = candidate.base_url.trim_end_matches('/');
        if let Err(e) = Self::health_check(&client, base).await {
            tracing::warn!(
                "Embedding: {} ({}) unavailable: {}",
                candidate.name,
                base,
                e
            );
            failures.push(format!("- {} ({}): {}", candidate.name, base, e));
            continue;
        }
        tracing::info!("Embedding: Connected to {} ({})", candidate.name, base);
        return Ok(Self {
            embedder_url: format!("{}{}", base, candidate.endpoint),
            embedder_model: candidate.model.clone(),
            connected_to: candidate.name.clone(),
            client,
            reranker_url,
            reranker_model,
            required_dimension: config.required_dimension,
            max_batch_chars: config.max_batch_chars,
            max_batch_items: config.max_batch_items,
        });
    }
    Err(anyhow!(
        "All embedding providers unavailable!\nTried:\n{}",
        failures.join("\n")
    ))
}
/// Build a client from the legacy `[mlx]` config, logging a migration warning.
///
/// # Errors
/// Fails immediately if the legacy config is disabled; otherwise as [`Self::new`].
pub async fn from_legacy(config: &MlxConfig) -> Result<Self> {
    if config.disabled {
        return Err(anyhow!(
            "Embedding disabled via config. No fallback available!"
        ));
    }
    tracing::warn!("Using legacy [mlx] config - please migrate to [embeddings.providers]");
    Self::new(&config.to_embedding_config()).await
}

/// Build a client from environment variables (see `MlxConfig::from_env`).
pub async fn from_env() -> Result<Self> {
    Self::from_legacy(&MlxConfig::from_env()).await
}
/// Probe `base_url` for liveness with a 5 s timeout per request.
///
/// Succeeds on a 2xx from `/v1/models`. If that path returns 404, it falls back to
/// `/api/tags` and succeeds on a 2xx there. Any other status, or a connection error, fails.
async fn health_check(client: &Client, base_url: &str) -> Result<()> {
    let probe = |path: &str| {
        client
            .get(format!("{}{}", base_url, path))
            .timeout(Duration::from_secs(5))
            .send()
    };

    let models = match probe("/v1/models").await {
        Ok(resp) => resp,
        Err(e) => return Err(anyhow!("Connection failed: {}", e)),
    };
    let status = models.status();
    if status.is_success() {
        return Ok(());
    }
    if status.as_u16() != 404 {
        return Err(anyhow!("Health check failed: {}", status));
    }

    // `/v1/models` is missing; try the alternative listing endpoint.
    let tags = probe("/api/tags").await?;
    if tags.status().is_success() {
        Ok(())
    } else {
        Err(anyhow!("Neither /v1/models nor /api/tags available"))
    }
}
pub fn connected_to(&self) -> &str {
&self.connected_to
}
pub fn required_dimension(&self) -> usize {
self.required_dimension
}
pub async fn embed(&mut self, text: &str) -> Result<Vec<f32>> {
let text_preview: String = text.chars().take(100).collect();
tracing::debug!(
"Embedding single text ({} chars): {}{}",
text.chars().count(),
text_preview,
if text.chars().count() > 100 {
"..."
} else {
""
}
);
let request = EmbeddingRequest {
input: vec![text.to_string()],
model: self.embedder_model.clone(),
};
let response = match self
.client
.post(&self.embedder_url)
.json(&request)
.send()
.await
{
Ok(resp) => resp,
Err(e) => {
tracing::error!(
"Embedding request failed: {:?}\n URL: {}\n Model: {}",
e,
self.embedder_url,
self.embedder_model
);
return Err(anyhow!("Embedding request failed: {}", e));
}
};
let status = response.status();
let response_text = response.text().await.unwrap_or_else(|e| {
tracing::warn!("Failed to read response body: {:?}", e);
"<failed to read body>".to_string()
});
if !status.is_success() {
tracing::error!(
"Embedding API error (HTTP {}):\n URL: {}\n Model: {}\n Response: {}",
status,
self.embedder_url,
self.embedder_model,
response_text
);
return Err(anyhow!(
"Embedding API error (HTTP {}): {}",
status,
response_text
));
}
let parsed: EmbeddingResponse = match serde_json::from_str(&response_text) {
Ok(r) => r,
Err(e) => {
tracing::error!(
"Failed to parse embedding response: {:?}\n Response body: {}",
e,
response_text
);
return Err(anyhow!("Failed to parse embedding response: {}", e));
}
};
let embedding = parsed
.data
.into_iter()
.next()
.map(|d| d.embedding)
.ok_or_else(|| {
tracing::error!("No embedding returned in response: {}", response_text);
anyhow!("No embedding returned")
})?;
if embedding.len() != self.required_dimension {
tracing::error!(
"Dimension mismatch! Expected {}, got {}. Model: {}",
self.required_dimension,
embedding.len(),
self.embedder_model
);
return Err(anyhow!(
"Dimension mismatch! Expected {}, got {}. This would corrupt the database!",
self.required_dimension,
embedding.len()
));
}
tracing::debug!("Successfully embedded text ({} dims)", embedding.len());
Ok(embedding)
}
pub async fn embed_batch(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(vec![]);
}
let mut all_embeddings = Vec::with_capacity(texts.len());
let mut current_batch: Vec<String> = Vec::new();
let mut current_batch_indices: Vec<usize> = Vec::new();
let mut current_chars = 0;
let max_text_chars = self.max_batch_chars / 2;
let prepared_texts: Vec<String> = texts
.iter()
.map(|text| {
let char_count = text.chars().count();
if char_count > max_text_chars {
tracing::debug!(
"Text too large ({} chars), truncating to {} chars",
char_count,
max_text_chars
);
truncate_at_boundary(text, max_text_chars)
} else {
text.clone()
}
})
.collect();
let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
let mut failed_indices: Vec<usize> = Vec::new();
for (idx, text_to_embed) in prepared_texts.iter().enumerate() {
let text_len = text_to_embed.chars().count();
if !current_batch.is_empty()
&& (current_chars + text_len > self.max_batch_chars
|| current_batch.len() >= self.max_batch_items)
{
match self.embed_batch_internal(¤t_batch).await {
Ok(batch_embeddings) => {
for (i, emb) in batch_embeddings.into_iter().enumerate() {
if let Some(orig_idx) = current_batch_indices.get(i) {
results[*orig_idx] = Some(emb);
}
}
}
Err(e) => {
tracing::warn!(
"Batch embedding failed for {} texts, will retry individually: {}",
current_batch.len(),
e
);
failed_indices.extend(current_batch_indices.iter().copied());
}
}
current_batch.clear();
current_batch_indices.clear();
current_chars = 0;
}
current_batch.push(text_to_embed.clone());
current_batch_indices.push(idx);
current_chars += text_len;
}
if !current_batch.is_empty() {
match self.embed_batch_internal(¤t_batch).await {
Ok(batch_embeddings) => {
for (i, emb) in batch_embeddings.into_iter().enumerate() {
if let Some(orig_idx) = current_batch_indices.get(i) {
results[*orig_idx] = Some(emb);
}
}
}
Err(e) => {
tracing::warn!(
"Batch embedding failed for {} texts, will retry individually: {}",
current_batch.len(),
e
);
failed_indices.extend(current_batch_indices.iter().copied());
}
}
}
const MAX_RETRIES: usize = 3;
for idx in failed_indices {
let text = &prepared_texts[idx];
let mut attempts = 0;
let mut last_error = String::new();
while attempts < MAX_RETRIES {
match self.embed(text).await {
Ok(embedding) => {
results[idx] = Some(embedding);
tracing::info!(
"Retry succeeded for chunk {} after {} attempts",
idx,
attempts + 1
);
break;
}
Err(e) => {
attempts += 1;
last_error = e.to_string();
tracing::warn!(
"Embed attempt {}/{} failed for chunk {}: {}",
attempts,
MAX_RETRIES,
idx,
e
);
if attempts < MAX_RETRIES {
let delay_ms = 100 * (1 << attempts);
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
}
}
}
}
if results[idx].is_none() {
tracing::error!(
"Chunk {} failed after {} retries: {}",
idx,
MAX_RETRIES,
last_error
);
return Err(anyhow!(
"Failed to embed chunk {} after {} retries: {}",
idx,
MAX_RETRIES,
last_error
));
}
}
for (idx, opt) in results.iter().enumerate() {
match opt {
Some(emb) => all_embeddings.push(emb.clone()),
None => {
return Err(anyhow!(
"Internal error: missing embedding for chunk {}",
idx
));
}
}
}
Ok(all_embeddings)
}
/// Embed `texts` in a single request, retrying transient failures.
///
/// Transport errors, parse failures, 5xx/3xx responses and 408/429 are retried up to
/// 10 attempts with exponential backoff (2, 4, 8, 16, then 30 s). Any other 4xx means the
/// server rejected this payload, so it fails immediately and lets `embed_batch` fall back to
/// per-item requests instead of sleeping through ~3 minutes of backoff.
///
/// # Errors
/// Also fails, without retrying, if the number of embeddings differs from `texts.len()` or
/// if *any* embedding's length differs from `required_dimension`.
async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
    const MAX_BATCH_RETRIES: usize = 10;
    const MAX_BACKOFF_SECS: u64 = 30;

    let total_chars: usize = texts.iter().map(|t| t.chars().count()).sum();
    tracing::debug!(
        "Embedding batch: {} texts, {} chars total",
        texts.len(),
        total_chars
    );
    for (i, text) in texts.iter().enumerate() {
        let preview: String = text.chars().take(50).collect();
        tracing::trace!(
            " Batch[{}]: {} chars - {}{}",
            i,
            text.chars().count(),
            preview,
            if text.chars().count() > 50 { "..." } else { "" }
        );
    }
    let request = EmbeddingRequest {
        input: texts.to_vec(),
        model: self.embedder_model.clone(),
    };

    let mut attempt: usize = 0;
    loop {
        attempt += 1;
        let backoff_secs = (1u64 << attempt.min(5)).min(MAX_BACKOFF_SECS);

        let response = match self
            .client
            .post(&self.embedder_url)
            .json(&request)
            .send()
            .await
        {
            Ok(resp) => resp,
            Err(e) => {
                if attempt >= MAX_BATCH_RETRIES {
                    tracing::error!(
                        "Batch embedding failed after {} retries: {:?}\n URL: {}\n Model: {}",
                        MAX_BATCH_RETRIES,
                        e,
                        self.embedder_url,
                        self.embedder_model
                    );
                    return Err(anyhow!(
                        "Embedding request failed after {} retries: {}",
                        MAX_BATCH_RETRIES,
                        e
                    ));
                }
                tracing::warn!(
                    "Embedding request failed (attempt {}/{}), retrying in {}s: {}",
                    attempt,
                    MAX_BATCH_RETRIES,
                    backoff_secs,
                    e
                );
                tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
                continue;
            }
        };

        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();
            // 408 Request Timeout and 429 Too Many Requests are transient. Other 4xx codes
            // (e.g. 400, 413) will reject the identical payload every time.
            let retryable = !status.is_client_error() || matches!(status.as_u16(), 408 | 429);
            if !retryable || attempt >= MAX_BATCH_RETRIES {
                tracing::error!(
                    "Embedding API error after {} attempt(s): {} - {}",
                    attempt,
                    status,
                    body
                );
                return Err(anyhow!("Embedding API error: {} - {}", status, body));
            }
            tracing::warn!(
                "Embedding API error (attempt {}/{}), retrying in {}s: {} - {}",
                attempt,
                MAX_BATCH_RETRIES,
                backoff_secs,
                status,
                body
            );
            tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
            continue;
        }

        let embedding_response: EmbeddingResponse = match response.json().await {
            Ok(r) => r,
            Err(e) => {
                if attempt >= MAX_BATCH_RETRIES {
                    tracing::error!(
                        "Failed to parse embedding response after {} retries: {}",
                        MAX_BATCH_RETRIES,
                        e
                    );
                    return Err(anyhow!("Failed to parse embedding response: {}", e));
                }
                tracing::warn!(
                    "Failed to parse response (attempt {}/{}), retrying in {}s: {}",
                    attempt,
                    MAX_BATCH_RETRIES,
                    backoff_secs,
                    e
                );
                tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
                continue;
            }
        };

        let embeddings: Vec<Vec<f32>> = embedding_response
            .data
            .into_iter()
            .map(|d| d.embedding)
            .collect();
        if embeddings.len() != texts.len() {
            return Err(anyhow!(
                "Embedding count mismatch: got {} embeddings for {} texts",
                embeddings.len(),
                texts.len()
            ));
        }
        // Check every vector, not just the first: one wrong-length vector anywhere in the
        // batch would otherwise be stored alongside valid ones.
        if let Some(bad) = embeddings
            .iter()
            .find(|e| e.len() != self.required_dimension)
        {
            return Err(anyhow!(
                "Dimension mismatch: expected {}, got {}",
                self.required_dimension,
                bad.len()
            ));
        }
        return Ok(embeddings);
    }
}
/// Score `documents` against `query` using the configured reranker.
///
/// Returns `(index, score)` pairs in the order the server sent them. Every `index` is
/// guaranteed to be a valid position in `documents`. An empty `documents` slice returns an
/// empty result without contacting the server.
///
/// # Errors
/// Fails if no reranker URL or model is configured, on transport errors, non-2xx responses,
/// unparsable bodies, or a result index outside `documents`.
pub async fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<(usize, f32)>> {
    let reranker_url = self.reranker_url.as_ref().ok_or_else(|| {
        anyhow!("Reranker not configured. Add [embeddings.reranker] to config.")
    })?;
    let reranker_model = self
        .reranker_model
        .as_ref()
        .ok_or_else(|| anyhow!("Reranker model not configured."))?;

    // Nothing to score: skip the pointless round-trip.
    if documents.is_empty() {
        tracing::debug!("Rerank skipped: no documents");
        return Ok(Vec::new());
    }

    let query_preview: String = query.chars().take(100).collect();
    tracing::debug!(
        "Reranking {} documents for query: {}{}",
        documents.len(),
        query_preview,
        if query.chars().count() > 100 {
            "..."
        } else {
            ""
        }
    );
    let request = RerankRequest {
        query: query.to_string(),
        documents: documents.to_vec(),
        model: reranker_model.clone(),
    };
    let response = match self.client.post(reranker_url).json(&request).send().await {
        Ok(resp) => resp,
        Err(e) => {
            tracing::error!(
                "Rerank request failed: {:?}\n URL: {}\n Model: {}\n Query: {}\n Documents: {}",
                e,
                reranker_url,
                reranker_model,
                query_preview,
                documents.len()
            );
            return Err(anyhow!("Rerank request failed: {}", e));
        }
    };
    let status = response.status();
    let response_text = response.text().await.unwrap_or_else(|e| {
        tracing::warn!("Failed to read rerank response body: {:?}", e);
        "<failed to read body>".to_string()
    });
    if !status.is_success() {
        tracing::error!(
            "Rerank API error (HTTP {}):\n URL: {}\n Model: {}\n Response: {}",
            status,
            reranker_url,
            reranker_model,
            response_text
        );
        return Err(anyhow!(
            "Rerank API error (HTTP {}): {}",
            status,
            response_text
        ));
    }
    let parsed: RerankResponse = match serde_json::from_str(&response_text) {
        Ok(r) => r,
        Err(e) => {
            tracing::error!(
                "Failed to parse rerank response: {:?}\n Response body: {}",
                e,
                response_text
            );
            return Err(anyhow!("Failed to parse rerank response: {}", e));
        }
    };

    // Callers index `documents` with the returned positions; reject bad indices here
    // rather than letting a misbehaving server cause an out-of-bounds panic downstream.
    if let Some(bad) = parsed.results.iter().find(|r| r.index >= documents.len()) {
        tracing::error!(
            "Rerank response has out-of-range index {} ({} documents)\n URL: {}\n Model: {}",
            bad.index,
            documents.len(),
            reranker_url,
            reranker_model
        );
        return Err(anyhow!(
            "Rerank response index {} out of range for {} documents",
            bad.index,
            documents.len()
        ));
    }

    tracing::debug!("Rerank complete: {} documents scored", parsed.results.len());
    Ok(parsed
        .results
        .into_iter()
        .map(|r| (r.index, r.score))
        .collect())
}
}
fn truncate_at_boundary(text: &str, max_chars: usize) -> String {
let char_count = text.chars().count();
if char_count <= max_chars {
return text.to_string();
}
let byte_idx = text
.char_indices()
.nth(max_chars)
.map(|(idx, _)| idx)
.unwrap_or(text.len());
let truncated = &text[..byte_idx];
let half_byte_idx = text
.char_indices()
.nth(max_chars / 2)
.map(|(idx, _)| idx)
.unwrap_or(0);
if let Some(pos) = truncated.rfind(['.', '!', '?', '\n'])
&& pos > half_byte_idx
{
return text[..=pos].to_string();
}
if let Some(pos) = truncated.rfind([' ', '\t', '\n']) {
return text[..pos].to_string();
}
truncated.to_string()
}
/// Heuristic settings for estimating token usage from character counts.
#[derive(Debug, Clone)]
pub struct TokenConfig {
    /// Maximum number of tokens a single chunk may use.
    pub max_tokens: usize,
    /// Assumed average number of characters per token.
    pub chars_per_token: f32,
}

impl Default for TokenConfig {
    /// 8192 tokens at 3 characters per token (between the English and multilingual presets).
    fn default() -> Self {
        Self::from_ratio(3.0)
    }
}

impl TokenConfig {
    const DEFAULT_MAX_TOKENS: usize = 8192;

    /// Preset with the default token limit and the given characters-per-token ratio.
    fn from_ratio(chars_per_token: f32) -> Self {
        Self {
            max_tokens: Self::DEFAULT_MAX_TOKENS,
            chars_per_token,
        }
    }

    /// Preset assuming 4 characters per token.
    pub fn english() -> Self {
        Self::from_ratio(4.0)
    }

    /// Preset assuming 2.5 characters per token.
    pub fn multilingual() -> Self {
        Self::from_ratio(2.5)
    }

    /// Return a copy with `max_tokens` replaced.
    pub fn with_max_tokens(self, max: usize) -> Self {
        Self {
            max_tokens: max,
            ..self
        }
    }
}
/// Estimate the token count of `text` as `ceil(chars / chars_per_token)`, counting
/// Unicode scalar values rather than bytes.
pub fn estimate_tokens(text: &str, config: &TokenConfig) -> usize {
    let chars = text.chars().count() as f32;
    let ratio = config.chars_per_token;
    (chars / ratio).ceil() as usize
}
/// Check that `chunk`'s estimated token count is within `config.max_tokens`.
///
/// # Errors
/// Returns an error describing the estimate, the limit and the character count when
/// the chunk is over the limit.
pub fn validate_chunk_tokens(chunk: &str, config: &TokenConfig) -> Result<()> {
    let estimated = estimate_tokens(chunk, config);
    if estimated <= config.max_tokens {
        return Ok(());
    }
    Err(anyhow!(
        "Chunk exceeds token limit: ~{} tokens > {} max (text: {} chars). \
         Consider reducing chunk_size or enabling truncation.",
        estimated,
        config.max_tokens,
        chunk.chars().count()
    ))
}
/// Largest chunk size, in characters, that stays within 80% of `config.max_tokens`
/// under the configured characters-per-token ratio.
pub fn safe_chunk_size(config: &TokenConfig) -> usize {
    // Leave 20% headroom because the token count is only an estimate.
    const SAFETY_FACTOR: f32 = 0.8;
    let budget_tokens = (config.max_tokens as f32 * SAFETY_FACTOR) as usize;
    let budget_chars = budget_tokens as f32 * config.chars_per_token;
    budget_chars as usize
}
/// Shorten `text` to at most [`safe_chunk_size`] characters, cutting at a natural
/// boundary where possible.
pub fn truncate_to_token_limit(text: &str, config: &TokenConfig) -> String {
    // `truncate_at_boundary` already returns the text unchanged when it fits.
    truncate_at_boundary(text, safe_chunk_size(config))
}
/// Find the texts whose estimated token count exceeds `config.max_tokens`.
///
/// Returns `(index, estimated_tokens)` for each offending text, in input order; an empty
/// vector means every text is within the limit.
pub fn validate_batch_tokens(texts: &[String], config: &TokenConfig) -> Vec<(usize, usize)> {
    let mut over_limit = Vec::new();
    for (position, text) in texts.iter().enumerate() {
        let estimate = estimate_tokens(text, config);
        if estimate > config.max_tokens {
            over_limit.push((position, estimate));
        }
    }
    over_limit
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider fixture differing only in name, URL and priority.
    fn provider(name: &str, base_url: &str, priority: u8) -> ProviderConfig {
        ProviderConfig {
            name: name.into(),
            base_url: base_url.into(),
            model: "m".into(),
            priority,
            endpoint: "/v1/embeddings".into(),
        }
    }

    #[test]
    fn test_provider_sorting() {
        let mut providers = [provider("low", "http://a", 10), provider("high", "http://b", 1)];
        providers.sort_by_key(|p| p.priority);
        let names: Vec<&str> = providers.iter().map(|p| p.name.as_str()).collect();
        assert_eq!(names, ["high", "low"]);
    }

    #[test]
    fn test_legacy_conversion() {
        let legacy = MlxConfig {
            disabled: false,
            local_port: 12345,
            dragon_url: "http://dragon".into(),
            dragon_port: 12345,
            embedder_model: "test-model".into(),
            reranker_model: "rerank-model".into(),
            reranker_port_offset: 1,
            max_batch_chars: 32000,
            max_batch_items: 16,
        };
        let converted = legacy.to_embedding_config();
        assert_eq!(converted.providers.len(), 2);
        assert_eq!(converted.providers[0].base_url, "http://localhost:12345");
        assert!(converted.reranker.base_url.is_some());
        assert_eq!(converted.max_batch_chars, 32000);
        assert_eq!(converted.max_batch_items, 16);
    }

    #[test]
    fn test_default_config() {
        let defaults = EmbeddingConfig::default();
        assert_eq!(defaults.required_dimension, 4096);
        assert_eq!(defaults.max_batch_chars, 128000);
        assert_eq!(defaults.max_batch_items, 64);
        assert!(!defaults.providers.is_empty());
    }

    #[test]
    fn test_truncate_at_boundary() {
        let cases = [
            // Sentence terminator wins and is kept.
            ("Hello world. This is a test.", 15, "Hello world."),
            // Otherwise cut at the last word boundary.
            ("Hello world this is a test", 15, "Hello world"),
            // Text within the limit is untouched.
            ("Short text", 100, "Short text"),
        ];
        for (input, limit, expected) in cases {
            assert_eq!(truncate_at_boundary(input, limit), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_token_estimation() {
        let sample = "Hello world";
        let default_estimate = estimate_tokens(sample, &TokenConfig::default());
        assert!((3..=5).contains(&default_estimate));
        let english_estimate = estimate_tokens(sample, &TokenConfig::english());
        assert!((2..=4).contains(&english_estimate));
    }

    #[test]
    fn test_chunk_validation() {
        let config = TokenConfig::default().with_max_tokens(100);
        assert!(validate_chunk_tokens("Hello world", &config).is_ok());
        // 1000 chars at 3 chars/token is ~334 tokens, well over 100.
        assert!(validate_chunk_tokens(&"a".repeat(1000), &config).is_err());
    }

    #[test]
    fn test_safe_chunk_size() {
        let safe = safe_chunk_size(&TokenConfig::default());
        assert!(safe > 15000 && safe < 25000);
    }

    #[test]
    fn test_batch_validation() {
        let config = TokenConfig::default().with_max_tokens(10);
        let texts = vec![
            "short".to_string(),
            "a".repeat(100),
            "also short".to_string(),
            "b".repeat(200),
        ];
        let failing: Vec<usize> = validate_batch_tokens(&texts, &config)
            .into_iter()
            .map(|(idx, _)| idx)
            .collect();
        assert_eq!(failing, [1, 3]);
    }
}