use async_trait::async_trait;
use super::{EmbedRequest, EmbedResponse, EmbedderClient, EmbedderError};
/// [`EmbedderClient`] implementation that calls a remote embedding service over HTTP.
///
/// Each call to `embed_batch` POSTs the whole batch as JSON to a single `/embed`
/// endpoint derived from the base URL given to [`RemoteEmbedderClient::new`].
#[derive(Clone, Debug)]
pub struct RemoteEmbedderClient {
    // reqwest HTTP client used for every request this instance sends.
    client: reqwest::Client,
    // Full endpoint URL: the base URL with trailing '/' trimmed, plus "/embed".
    embed_url: String,
}
impl RemoteEmbedderClient {
    /// Creates a client targeting the embedding service rooted at `base_url`.
    ///
    /// All trailing `/` characters are trimmed from `base_url` before `/embed` is
    /// appended, so `"http://host:7890"` and `"http://host:7890/"` both resolve to
    /// `"http://host:7890/embed"`. The URL is not parsed or validated here.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest::Client` cannot be built. Per the
    /// `reqwest::ClientBuilder::build` documentation this happens when a TLS backend
    /// cannot be initialized or the resolver cannot load the system configuration.
    pub fn new(base_url: impl Into<String>) -> Self {
        let base = base_url.into();
        // Trim every trailing '/' so we never produce "...//embed".
        let embed_url = format!("{}/embed", base.trim_end_matches('/'));
        let client = reqwest::Client::builder().build().expect(
            "failed to build reqwest client: TLS backend or system resolver configuration could not be initialized",
        );
        tracing::debug!(embed_url = %embed_url, "RemoteEmbedderClient constructed");
        Self { client, embed_url }
    }

    /// Returns the base URL this client was built from, without trailing `/`.
    ///
    /// Recovered by removing the `/embed` suffix that [`RemoteEmbedderClient::new`]
    /// always appends; the fallback to the full URL is purely defensive.
    pub fn base_url(&self) -> &str {
        self.embed_url
            .strip_suffix("/embed")
            .unwrap_or(&self.embed_url)
    }
}
#[async_trait]
impl EmbedderClient for RemoteEmbedderClient {
    /// Embeds `texts` with a single JSON POST to this client's `/embed` endpoint.
    ///
    /// An empty `texts` returns an empty vector immediately, without sending a request.
    ///
    /// # Errors
    ///
    /// - Transport or JSON-decoding failures from `reqwest`, converted via `?`.
    /// - [`EmbedderError::RemoteError`] when the server replies with a non-success
    ///   status; the response body (or `"(unreadable body)"` if it cannot be read)
    ///   is attached.
    /// - [`EmbedderError::DimensionMismatch`] when the number of returned vectors
    ///   differs from the number of texts sent.
    async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbedderError> {
        let batch_len = texts.len();
        if batch_len == 0 {
            return Ok(Vec::new());
        }

        let payload = EmbedRequest { texts };
        tracing::debug!(url = %self.embed_url, n = batch_len, "RemoteEmbedderClient: sending batch");
        let http_response = self
            .client
            .post(&self.embed_url)
            .json(&payload)
            .send()
            .await?;

        let status = http_response.status();
        if !status.is_success() {
            // Best effort: a failure to read the error body must not mask the status.
            let body = match http_response.text().await {
                Ok(text) => text,
                Err(_) => String::from("(unreadable body)"),
            };
            return Err(EmbedderError::RemoteError {
                status: status.as_u16(),
                body,
            });
        }

        let decoded: EmbedResponse = http_response.json().await?;
        let vectors = decoded.vectors;
        let got = vectors.len();
        if got != batch_len {
            return Err(EmbedderError::DimensionMismatch {
                sent: batch_len,
                got,
            });
        }

        tracing::debug!(url = %self.embed_url, n = batch_len, "RemoteEmbedderClient: batch complete");
        Ok(vectors)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn remote_client_construction() {
        let c = RemoteEmbedderClient::new("http://127.0.0.1:7890");
        assert_eq!(c.embed_url, "http://127.0.0.1:7890/embed");
        assert_eq!(c.base_url(), "http://127.0.0.1:7890");
    }

    #[test]
    fn remote_client_strips_trailing_slash() {
        let c = RemoteEmbedderClient::new("http://127.0.0.1:7890/");
        assert_eq!(c.embed_url, "http://127.0.0.1:7890/embed");
        assert_eq!(c.base_url(), "http://127.0.0.1:7890");
    }

    #[test]
    fn remote_client_strips_repeated_trailing_slashes() {
        let c = RemoteEmbedderClient::new("http://127.0.0.1:7890///");
        assert_eq!(c.embed_url, "http://127.0.0.1:7890/embed");
        assert_eq!(c.base_url(), "http://127.0.0.1:7890");
    }

    #[test]
    fn remote_client_keeps_path_prefix() {
        let c = RemoteEmbedderClient::new("http://127.0.0.1:7890/v1/");
        assert_eq!(c.embed_url, "http://127.0.0.1:7890/v1/embed");
        assert_eq!(c.base_url(), "http://127.0.0.1:7890/v1");
    }

    #[tokio::test]
    async fn empty_batch_short_circuits() {
        // Port 1 is presumably not serving anything: if the empty-input short-circuit
        // were removed, this would fail with a connection error rather than pass.
        let c = RemoteEmbedderClient::new("http://127.0.0.1:1");
        let result = c
            .embed_batch(vec![])
            .await
            .expect("empty batch short-circuits");
        assert!(result.is_empty());
    }
}