use super::Embedder;
use crate::{HippoError, Result, EMBEDDING_DIM};
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Settings for talking to an external HTTP embeddings endpoint.
///
/// The values are checked by `ExternalEmbeddingConfig::validate`, which
/// `ExternalEmbedder::new` calls before building its HTTP client.
///
/// `Debug` is implemented by hand so that `api_key` never appears in logs or
/// panic messages; every other field is printed as usual.
#[derive(Clone)]
pub struct ExternalEmbeddingConfig {
    /// Endpoint URL that embedding requests are POSTed to (`http://` or `https://`).
    pub url: String,
    /// Model name sent in the request body.
    pub model: String,
    /// Expected embedding dimension; validation requires it to equal `EMBEDDING_DIM`.
    pub dim: usize,
    /// API key sent as a bearer token; an empty string means no `Authorization` header.
    pub api_key: String,
    /// Timeout handed to the HTTP client builder.
    pub timeout: Duration,
    /// Maximum number of texts sent in a single request; must be at least 1.
    pub batch_size: usize,
    /// Number of retries after a network error, HTTP 429 or HTTP 5xx.
    pub max_retries: u32,
}

impl std::fmt::Debug for ExternalEmbeddingConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show only whether a key is configured, never the secret itself.
        let api_key = if self.api_key.is_empty() {
            "<unset>"
        } else {
            "<redacted>"
        };
        f.debug_struct("ExternalEmbeddingConfig")
            .field("url", &self.url)
            .field("model", &self.model)
            .field("dim", &self.dim)
            .field("api_key", &api_key)
            .field("timeout", &self.timeout)
            .field("batch_size", &self.batch_size)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}
impl ExternalEmbeddingConfig {
    /// Checks that the configuration can be used to build an embedder.
    ///
    /// The rules are evaluated in a fixed order and the first violation is
    /// reported: empty URL, empty model name, dimension different from
    /// `EMBEDDING_DIM`, zero batch size, and finally a URL that does not start
    /// with `http://` or `https://`.
    ///
    /// # Errors
    ///
    /// Returns `HippoError::Config` with a message describing the violated rule.
    pub fn validate(&self) -> Result<()> {
        // Every failure is a configuration error; only the message differs.
        let reject = |reason: String| -> Result<()> { Err(HippoError::Config(reason)) };

        if self.url.is_empty() {
            return reject("external embedding url is empty".into());
        }
        if self.model.is_empty() {
            return reject("external embedding model name is empty".into());
        }
        if self.dim != EMBEDDING_DIM {
            return reject(format!(
                "external embedding dim {} != schema-required {} (DB swap compat — \
                 mcp-memory-service-rs uses FLOAT[384])",
                self.dim, EMBEDDING_DIM
            ));
        }
        if self.batch_size == 0 {
            return reject("external embedding batch_size must be ≥ 1".into());
        }

        // The scheme test is a plain, case-sensitive prefix match.
        let has_http_scheme = ["http://", "https://"]
            .iter()
            .any(|prefix| self.url.starts_with(*prefix));
        if !has_http_scheme {
            return reject(format!(
                "external embedding url must start with http:// or https://: got {:?}",
                self.url
            ));
        }

        Ok(())
    }
}
/// Embedder that obtains vectors from a remote HTTP embeddings endpoint.
///
/// Instances are created through `ExternalEmbedder::new`, which validates the
/// configuration before building the HTTP client.
pub struct ExternalEmbedder {
/// Validated settings (URL, model, batching and retry policy).
cfg: ExternalEmbeddingConfig,
/// HTTP client built with the configured request timeout.
client: reqwest::Client,
/// Headers attached to every request: content type, user agent and, when an
/// API key is configured, the bearer `Authorization` header.
headers: HeaderMap,
}
impl ExternalEmbedder {
pub fn new(cfg: ExternalEmbeddingConfig) -> Result<Self> {
cfg.validate()?;
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
if !cfg.api_key.is_empty() {
let bearer = format!("Bearer {}", cfg.api_key);
let mut v = HeaderValue::from_str(&bearer)
.map_err(|e| HippoError::Config(format!("invalid api_key for header: {e}")))?;
v.set_sensitive(true);
headers.insert(AUTHORIZATION, v);
}
headers.insert(
HeaderName::from_static("user-agent"),
HeaderValue::from_static(concat!("claude-hippo/", env!("CARGO_PKG_VERSION"))),
);
let client = reqwest::Client::builder()
.timeout(cfg.timeout)
.build()
.map_err(|e| HippoError::Config(format!("reqwest client build: {e}")))?;
Ok(Self {
cfg,
client,
headers,
})
}
pub fn config(&self) -> &ExternalEmbeddingConfig {
&self.cfg
}
async fn send_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
let body = EmbeddingsRequest {
model: &self.cfg.model,
input: texts,
encoding_format: "float",
dimensions: self.cfg.dim as u32,
};
let mut attempt: u32 = 0;
loop {
let resp_result = self
.client
.post(&self.cfg.url)
.headers(self.headers.clone())
.json(&body)
.send()
.await;
let resp = match resp_result {
Ok(r) => r,
Err(e) => {
if attempt >= self.cfg.max_retries {
return Err(HippoError::Embedding(format!(
"external embeddings: network error after {} retries: {e}",
attempt
)));
}
tokio::time::sleep(backoff_delay(attempt)).await;
attempt += 1;
continue;
}
};
let status = resp.status();
if status.is_success() {
let parsed: EmbeddingsResponse = resp.json().await.map_err(|e| {
HippoError::Embedding(format!("external embeddings: bad JSON body: {e}"))
})?;
return self.normalize_response(parsed);
}
let retriable = status.as_u16() == 429 || (500..600).contains(&status.as_u16());
let body_text = resp.text().await.unwrap_or_default();
if !retriable || attempt >= self.cfg.max_retries {
return Err(classify_http_error(status, body_text, &self.cfg));
}
tokio::time::sleep(backoff_delay(attempt)).await;
attempt += 1;
}
}
fn normalize_response(&self, parsed: EmbeddingsResponse) -> Result<Vec<Vec<f32>>> {
let mut data = parsed.data;
data.sort_by_key(|d| d.index);
let mut out = Vec::with_capacity(data.len());
for d in data {
if d.embedding.len() != self.cfg.dim {
return Err(HippoError::Embedding(format!(
"external embeddings: model {:?} returned dim {} (expected {} for DB \
schema FLOAT[384] — reject rather than silently truncate)",
self.cfg.model,
d.embedding.len(),
self.cfg.dim
)));
}
let mut v = d.embedding;
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
for x in v.iter_mut() {
*x /= norm;
}
out.push(v);
}
Ok(out)
}
}
impl Embedder for ExternalEmbedder {
    /// Embeds a single text by sending it as a one-element batch.
    ///
    /// # Errors
    ///
    /// Propagates any error from `embed_batch`, and returns
    /// `HippoError::Embedding` if the batch unexpectedly yields no vector.
    fn embed_one(&self, text: &str) -> Result<Vec<f32>> {
        let v = self.embed_batch(&[text])?;
        v.into_iter()
            .next()
            .ok_or_else(|| HippoError::Embedding("external embeddings: empty response".into()))
    }

    /// Embeds `texts`, issuing one HTTP request per `cfg.batch_size` texts and
    /// concatenating the results in request order.
    ///
    /// This call blocks the current thread until every request has finished.
    /// An empty input returns an empty vector without any network traffic.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any batch; vectors from batches
    /// that already succeeded are discarded.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let mut out = Vec::with_capacity(texts.len());
        // Chunk the borrowed slice and build the owned strings once per batch,
        // so each text is copied exactly once.
        for chunk in texts.chunks(self.cfg.batch_size) {
            let owned: Vec<String> = chunk.iter().map(|s| (*s).to_owned()).collect();
            let part = run_async_in_sync(self.send_batch(&owned))?;
            out.extend(part);
        }
        Ok(out)
    }
}
/// JSON body sent to the embeddings endpoint for one batch.
///
/// The struct borrows its data so building a request does not copy the texts.
/// NOTE(review): the field names look like an OpenAI-compatible embeddings
/// request — confirm against the target server.
#[derive(Debug, Serialize)]
struct EmbeddingsRequest<'a> {
/// Model name taken from the configuration.
model: &'a str,
/// Texts to embed in this batch.
input: &'a [String],
/// Requested encoding of the returned vectors; `send_batch` always passes "float".
encoding_format: &'static str,
/// Requested embedding dimension, taken from the configured `dim`.
dimensions: u32,
}
/// JSON body expected from the embeddings endpoint on success.
///
/// Only `data` is read by this file; the remaining fields are optional and
/// deserialized without being used.
#[derive(Debug, Deserialize)]
struct EmbeddingsResponse {
/// One entry per returned embedding.
data: Vec<EmbeddingDatum>,
// Deserialized but never read, hence the lint allowance.
#[allow(dead_code)]
model: Option<String>,
// Deserialized but never read, hence the lint allowance.
#[allow(dead_code)]
object: Option<String>,
// Kept as untyped JSON because its contents are never inspected here.
#[allow(dead_code)]
usage: Option<serde_json::Value>,
}
/// A single embedding entry inside `EmbeddingsResponse::data`.
#[derive(Debug, Deserialize)]
struct EmbeddingDatum {
/// Raw embedding values as returned by the server (normalized later).
embedding: Vec<f32>,
/// Position used to sort the entries; falls back to 0 when the server omits it.
#[serde(default)]
index: usize,
// Deserialized but never read, hence the lint allowance.
#[allow(dead_code)]
object: Option<String>,
}
/// Returns how long to wait before retry number `attempt` (zero-based).
///
/// The delay starts at 200 ms and doubles with each attempt; the exponent is
/// clamped to 5 and the final value is capped at 5 000 ms.
fn backoff_delay(attempt: u32) -> Duration {
    const BASE_MS: u64 = 200;
    const CAP_MS: u64 = 5_000;
    const MAX_EXPONENT: u32 = 5;

    let exponent = std::cmp::min(attempt, MAX_EXPONENT);
    let uncapped_ms = BASE_MS.saturating_mul(2_u64.pow(exponent));
    Duration::from_millis(std::cmp::min(uncapped_ms, CAP_MS))
}
/// Turns a non-success HTTP response into a descriptive `HippoError::Embedding`.
///
/// The message contains a short classification of the status code, the status
/// itself, the configured URL and model, and at most the first 400 characters
/// of the response body.
fn classify_http_error(
    status: reqwest::StatusCode,
    body: String,
    cfg: &ExternalEmbeddingConfig,
) -> HippoError {
    let kind = match status.as_u16() {
        401 => "auth: API key invalid or missing",
        403 => "auth: API key rejected for this model",
        404 => "endpoint not found (URL or model name wrong)",
        429 => "rate limited (gave up after retries)",
        500..=599 => "upstream 5xx (gave up after retries)",
        _ => "unexpected HTTP error",
    };

    // Cut on a character boundary so multi-byte text is never split.
    let body_preview: &str = match body.char_indices().nth(400) {
        Some((cut, _)) => &body[..cut],
        None => body.as_str(),
    };

    HippoError::Embedding(format!(
        "external embeddings: {kind} — status={} url={} model={} body={:?}",
        status, cfg.url, cfg.model, body_preview
    ))
}
/// Drives `fut` to completion from synchronous code and returns its result.
///
/// Three situations are handled:
/// - no Tokio runtime on this thread: a temporary current-thread runtime is
///   built and used directly;
/// - inside a multi-thread runtime: the future is run on the existing handle
///   from within `block_in_place`;
/// - inside any other runtime flavor: a scoped helper thread builds its own
///   current-thread runtime and runs the future there.
///
/// # Errors
///
/// Returns whatever `fut` returns, or `HippoError::Embedding` when the
/// temporary runtime cannot be built or the helper thread panics.
fn run_async_in_sync<F, T>(fut: F) -> Result<T>
where
    F: std::future::Future<Output = Result<T>> + Send,
    T: Send,
{
    /// Builds a throwaway current-thread runtime and blocks on `fut` with it.
    fn drive_on_fresh_runtime<F, T>(fut: F) -> Result<T>
    where
        F: std::future::Future<Output = Result<T>>,
    {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|e| HippoError::Embedding(format!("tokio runtime: {e}")))?;
        rt.block_on(fut)
    }

    match tokio::runtime::Handle::try_current() {
        // Plain synchronous caller: nothing to coordinate with.
        Err(_) => drive_on_fresh_runtime(fut),
        Ok(handle) => match handle.runtime_flavor() {
            tokio::runtime::RuntimeFlavor::MultiThread => {
                tokio::task::block_in_place(|| handle.block_on(fut))
            }
            // Any other flavor: do the blocking work on a separate thread
            // with its own runtime instead of blocking the current one.
            _ => std::thread::scope(|scope| {
                let worker = scope.spawn(|| drive_on_fresh_runtime(fut));
                match worker.join() {
                    Ok(outcome) => outcome,
                    Err(_) => Err(HippoError::Embedding("embedding worker panicked".into())),
                }
            }),
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// URL used by the tests that need a fully valid endpoint.
    const VALID_URL: &str = "https://example.com/v1/embeddings";

    /// Returns a configuration that is valid except possibly for `url`.
    fn sample_config(url: &str) -> ExternalEmbeddingConfig {
        ExternalEmbeddingConfig {
            url: url.to_string(),
            model: String::from("text-embedding-3-small"),
            dim: EMBEDDING_DIM,
            api_key: String::from("sk-test"),
            timeout: Duration::from_secs(2),
            batch_size: 4,
            max_retries: 2,
        }
    }

    #[test]
    fn validate_rejects_wrong_dim() {
        let cfg = ExternalEmbeddingConfig {
            dim: 768,
            ..sample_config(VALID_URL)
        };
        let msg = cfg.validate().unwrap_err().to_string();
        // Both the offending and the required dimension must be reported.
        for needle in ["768", "384"] {
            assert!(msg.contains(needle), "got {msg}");
        }
    }

    #[test]
    fn validate_rejects_empty_url() {
        let outcome = sample_config("").validate();
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_rejects_non_http_url() {
        let outcome = sample_config("file:///etc/passwd").validate();
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_rejects_zero_batch() {
        let cfg = ExternalEmbeddingConfig {
            batch_size: 0,
            ..sample_config("https://example.com")
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn new_builds_when_valid() {
        let embedder = ExternalEmbedder::new(sample_config(VALID_URL)).expect("build");
        assert_eq!(embedder.config().model, "text-embedding-3-small");
    }

    #[test]
    fn backoff_grows_then_caps() {
        let first = backoff_delay(0);
        let later = backoff_delay(3);
        let far_out = backoff_delay(8);
        assert!(later > first);
        assert!(far_out.as_millis() <= 5_000);
    }
}