use std::future::Future;
use std::time::Duration;
use reqwest::Client;
/// Drives the future `f` to completion from synchronous code and returns its output.
///
/// How the future is executed depends on the calling context:
///
/// * No tokio runtime is current: a fresh current-thread runtime is built and blocks
///   on `f` on the calling thread.
/// * The current runtime is multi-threaded: `f` is run through
///   `tokio::task::block_in_place` + `Handle::block_on` on the existing runtime.
/// * The current runtime has any other flavor: `f` is moved to a scoped OS thread that
///   builds its own current-thread runtime, and the caller waits for that thread.
///
/// # Panics
///
/// Panics if a runtime cannot be built, or if the helper thread panics.
pub(crate) fn run_async<F, T>(f: F) -> T
where
    F: Future<Output = T> + Send + 'static,
    T: Send,
{
    /// Builds a throwaway current-thread runtime and blocks on `fut` with it.
    fn block_on_fresh_runtime<Fut: Future>(fut: Fut) -> Fut::Output {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build tokio runtime");
        runtime.block_on(fut)
    }

    // Without an ambient runtime there is nothing to conflict with: run right here.
    let handle = match tokio::runtime::Handle::try_current() {
        Ok(current) => current,
        Err(_) => return block_on_fresh_runtime(f),
    };

    let is_multi_thread = matches!(
        handle.runtime_flavor(),
        tokio::runtime::RuntimeFlavor::MultiThread
    );
    if is_multi_thread {
        return tokio::task::block_in_place(|| handle.block_on(f));
    }

    // Any other flavor: run the future on a separate thread with its own runtime
    // and wait for the result.
    std::thread::scope(|scope| {
        let worker = scope.spawn(move || block_on_fresh_runtime(f));
        worker.join().expect("async worker thread panicked")
    })
}
use serde::{Deserialize, Serialize};
use super::SemanticSearchError;
/// JSON body serialized into the POST request sent to the remote embeddings endpoint.
///
/// Both fields borrow from the caller, so building a request does not copy the texts.
#[derive(Serialize)]
struct EmbedRequest<'a> {
/// Name of the embedding model the server should use.
model: &'a str,
/// Texts to embed in this request.
input: &'a [String],
}
/// Top-level shape deserialized from the embeddings endpoint's JSON response.
#[derive(Deserialize)]
struct EmbedResponse {
/// Returned embedding entries; `RemoteEmbedder::request` sorts these by `index`
/// after parsing.
data: Vec<EmbedData>,
}
/// A single entry of the `data` array in an embeddings response.
#[derive(Deserialize)]
struct EmbedData {
/// Position reported by the server. After sorting, `RemoteEmbedder::request`
/// requires the indices to be exactly `0, 1, 2, ...` with no gaps or duplicates.
index: usize,
/// The embedding vector for this entry.
embedding: Vec<f32>,
}
/// Embedding backend that obtains vectors from a remote HTTP server.
///
/// Constructed with `RemoteEmbedder::new`, which probes the server once to learn the
/// embedding dimension.
#[derive(Clone)]
pub struct RemoteEmbedder {
/// HTTP client used for every request; built with a request timeout.
client: Client,
/// Full endpoint URL, i.e. the base URL with `/v1/embeddings` appended.
url: String,
/// Model name sent in every request body.
model: String,
/// Embedding dimension measured from the probe request at construction time.
dim: usize,
/// Optional key sent as a bearer token when present.
api_key: Option<String>,
}
/// Maximum number of texts sent to the server in a single embeddings request.
const BATCH_SIZE: usize = 64;
/// Timeout, in seconds, configured on the HTTP client for each request.
const REQUEST_TIMEOUT_SECS: u64 = 30;
/// Maximum number of characters (Unicode scalar values, not bytes) kept from each
/// input text; longer texts are cut to this length before being sent.
const MAX_TEXT_CHARS: usize = 2000;
impl RemoteEmbedder {
    /// Builds an HTTP client and probes the remote endpoint to learn the embedding
    /// dimension.
    ///
    /// The endpoint is `{base_url}/v1/embeddings`; trailing slashes on `base_url` are
    /// stripped first. A single-text probe request is sent and the length of the
    /// returned vector becomes the value reported by [`Self::dim`].
    ///
    /// # Errors
    ///
    /// Returns `SemanticSearchError::ModelInitError` when the HTTP client cannot be
    /// built, when the probe yields no embedding or a zero-length one, or when
    /// `expected_dim` is given and differs from the probed dimension. Errors from the
    /// probe request itself are propagated unchanged.
    pub async fn new(
        base_url: &str,
        model: &str,
        expected_dim: Option<usize>,
        api_key: Option<String>,
    ) -> Result<Self, SemanticSearchError> {
        let client = Client::builder()
            .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
            .build()
            .map_err(|e| {
                SemanticSearchError::ModelInitError(format!("HTTP client build failed: {e}"))
            })?;
        let url = format!("{}/v1/embeddings", base_url.trim_end_matches('/'));

        let probe_input = ["probe".to_string()];
        let probe =
            Self::request(&client, &url, model, &probe_input, api_key.as_deref()).await?;

        // A zero-length vector is as unusable as a missing one: reject both here so a
        // broken server is reported at construction time rather than yielding `dim == 0`.
        let actual_dim = probe
            .first()
            .map(|v| v.len())
            .filter(|&dim| dim > 0)
            .ok_or_else(|| {
                SemanticSearchError::ModelInitError(
                    "Remote server returned empty embedding on probe".into(),
                )
            })?;

        if let Some(expected) = expected_dim {
            if actual_dim != expected {
                return Err(SemanticSearchError::ModelInitError(format!(
                    "Remote embedding dim mismatch: expected {expected}, server returned {actual_dim}"
                )));
            }
        }

        tracing::info!(
            target: "semantic",
            "Remote embedding backend ready: url={url} model={model} dim={actual_dim}"
        );

        Ok(Self {
            client,
            url,
            model: model.to_string(),
            dim: actual_dim,
            api_key,
        })
    }

    /// Returns the embedding dimension measured when this embedder was constructed.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Embeds `texts` and returns one vector per input, in input order.
    ///
    /// Each text is cut to at most `MAX_TEXT_CHARS` characters, and the inputs are sent
    /// in batches of at most `BATCH_SIZE` texts, one request per batch, sequentially.
    /// An empty `texts` slice yields an empty result without contacting the server.
    ///
    /// # Errors
    ///
    /// Returns `SemanticSearchError::EmbeddingError` when a request fails, when the
    /// server returns a different number of embeddings than inputs in a batch, or when
    /// any returned vector's length differs from [`Self::dim`].
    pub async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, SemanticSearchError> {
        let mut results: Vec<Vec<f32>> = Vec::with_capacity(texts.len());

        for (batch_idx, batch) in texts.chunks(BATCH_SIZE).enumerate() {
            // Index in `texts` of the first element of this batch (for error messages).
            let offset = batch_idx * BATCH_SIZE;

            // Truncate per batch so only one batch of copies is alive at a time.
            let truncated: Vec<String> = batch
                .iter()
                .map(|t| Self::truncate_chars(t, MAX_TEXT_CHARS))
                .collect();

            let embeddings = Self::request(
                &self.client,
                &self.url,
                &self.model,
                &truncated,
                self.api_key.as_deref(),
            )
            .await?;

            if embeddings.len() != batch.len() {
                return Err(SemanticSearchError::EmbeddingError(format!(
                    "Remote server returned {} embeddings for {} inputs",
                    embeddings.len(),
                    batch.len()
                )));
            }

            // `request` returns vectors ordered by index and batches are processed in
            // order, so pushing here already produces input order; no re-sort needed.
            for (i, emb) in embeddings.into_iter().enumerate() {
                if emb.len() != self.dim {
                    return Err(SemanticSearchError::EmbeddingError(format!(
                        "Remote embedding at index {} has dim {}, expected {}",
                        offset + i,
                        emb.len(),
                        self.dim
                    )));
                }
                results.push(emb);
            }
        }

        Ok(results)
    }

    /// Returns an owned copy of `text` limited to its first `max_chars` characters.
    ///
    /// The cut is made on a character boundary, and the scan stops as soon as the limit
    /// is reached instead of walking the whole string.
    fn truncate_chars(text: &str, max_chars: usize) -> String {
        match text.char_indices().nth(max_chars) {
            // `byte_idx` is the start of the first character past the limit.
            Some((byte_idx, _)) => text[..byte_idx].to_owned(),
            None => text.to_owned(),
        }
    }

    /// Sends one embeddings request and returns the vectors ordered by response index.
    ///
    /// POSTs `{ model, input: texts }` as JSON to `url`, adding a bearer token when
    /// `api_key` is provided.
    ///
    /// # Errors
    ///
    /// Returns `SemanticSearchError::EmbeddingError` when the request cannot be sent,
    /// when the server answers with a non-success status (the response body is included
    /// in the message on a best-effort basis), when the body cannot be parsed, or when
    /// the returned indices are not exactly `0..n`.
    async fn request(
        client: &Client,
        url: &str,
        model: &str,
        texts: &[String],
        api_key: Option<&str>,
    ) -> Result<Vec<Vec<f32>>, SemanticSearchError> {
        let body = EmbedRequest {
            model,
            input: texts,
        };
        let mut req = client.post(url).json(&body);
        if let Some(key) = api_key {
            req = req.bearer_auth(key);
        }

        let resp = req.send().await.map_err(|e| {
            SemanticSearchError::EmbeddingError(format!("Remote embed request failed: {e}"))
        })?;

        let status = resp.status();
        if !status.is_success() {
            // Best effort: an unreadable error body just becomes an empty string.
            let text = resp.text().await.unwrap_or_default();
            return Err(SemanticSearchError::EmbeddingError(format!(
                "Remote embed server returned {status}: {text}"
            )));
        }

        let parsed: EmbedResponse = resp.json().await.map_err(|e| {
            SemanticSearchError::EmbeddingError(format!("Failed to parse embed response: {e}"))
        })?;

        let mut data = parsed.data;
        // Unstable sort is safe: equal indices are rejected by the check below.
        data.sort_unstable_by_key(|d| d.index);
        for (expected, d) in data.iter().enumerate() {
            if d.index != expected {
                return Err(SemanticSearchError::EmbeddingError(format!(
                    "Remote embed response has non-contiguous index: expected {expected}, got {}",
                    d.index
                )));
            }
        }

        Ok(data.into_iter().map(|d| d.embedding).collect())
    }
}