use anyhow::{Context, Result};
use ndarray::Array2;
/// Client for an embedding server reached through ZeroMQ REQ sockets at
/// `tcp://localhost:<port>`; requests and replies are MessagePack-encoded.
#[derive(Debug, Clone)]
pub struct EmbeddingClient {
    /// TCP port of the embedding server on `localhost`.
    port: u16,
}
impl EmbeddingClient {
    /// Number of text chunks sent to the server per request.
    const TEXT_BATCH_SIZE: usize = 128;

    /// Creates a client for the embedding server listening on `localhost:port`.
    ///
    /// No connection is made here; each public call opens its own socket.
    pub fn new(port: u16) -> Self {
        Self { port }
    }

    /// Embeds `chunks`, returning a `(chunks.len(), dim)` matrix whose row `i` is the
    /// embedding the server returned for `chunks[i]`.
    ///
    /// Chunks are sent in batches of 128 over a single connection.
    ///
    /// # Errors
    /// Fails if `chunks` is empty, the server cannot be reached, a reply cannot be
    /// decoded, or the server returns the wrong number of embeddings or rows of
    /// differing width.
    ///
    /// # Panics
    /// Panics if called from inside an async task running on a current-thread Tokio
    /// runtime (Tokio forbids blocking there).
    pub fn compute_text_embeddings(&self, chunks: &[String]) -> Result<Array2<f32>> {
        Self::block_on(self.compute_text_embeddings_async(chunks))
    }

    async fn compute_text_embeddings_async(&self, chunks: &[String]) -> Result<Array2<f32>> {
        if chunks.is_empty() {
            anyhow::bail!("Empty input to embedding server");
        }
        // A REQ socket supports repeated strict send/recv cycles, so one connection
        // serves every batch instead of reconnecting per batch.
        let mut socket = self.connect().await?;
        let mut flat: Vec<f32> = Vec::new();
        // Embedding width, fixed by the first row received.
        let mut dim: Option<usize> = None;
        for batch in chunks.chunks(Self::TEXT_BATCH_SIZE) {
            let request = rmp_serde::to_vec(batch).context("encoding embedding request")?;
            let response = Self::round_trip(&mut socket, request).await?;
            let embeddings: Vec<Vec<f32>> = rmp_serde::from_slice(Self::first_frame(&response)?)
                .context("decoding embedding response")?;
            if embeddings.is_empty() {
                anyhow::bail!("Empty response from embedding server");
            }
            // A short or long reply would silently shift rows relative to their chunks.
            if embeddings.len() != batch.len() {
                anyhow::bail!(
                    "Embedding count mismatch: sent {} chunks, got {} embeddings",
                    batch.len(),
                    embeddings.len()
                );
            }
            for row in embeddings {
                let d = *dim.get_or_insert(row.len());
                // Ragged rows could otherwise still sum to n * d and be reshaped wrongly.
                if row.len() != d {
                    anyhow::bail!(
                        "Ragged embedding response: expected width {}, got {}",
                        d,
                        row.len()
                    );
                }
                flat.extend(row);
            }
        }
        let d = dim.unwrap_or(0);
        Array2::from_shape_vec((chunks.len(), d), flat).context("reshaping embeddings")
    }

    /// Asks the server for the distance between `query` and each node in `node_ids`.
    ///
    /// Returns one value per entry of `node_ids`, in the order the server sends them.
    ///
    /// # Errors
    /// Fails if the server cannot be reached, the reply cannot be decoded, or the
    /// number of distances returned differs from `node_ids.len()`.
    ///
    /// # Panics
    /// Panics if called from inside an async task running on a current-thread Tokio
    /// runtime (Tokio forbids blocking there).
    pub fn compute_distances(&self, node_ids: &[usize], query: &[f32]) -> Result<Vec<f32>> {
        Self::block_on(self.compute_distances_async(node_ids, query))
    }

    async fn compute_distances_async(&self, node_ids: &[usize], query: &[f32]) -> Result<Vec<f32>> {
        let mut socket = self.connect().await?;
        // Wire format: a two-element array `[node_ids, query]`.
        // NOTE(review): `json!` maps non-finite floats in `query` to null — confirm the
        // server never needs NaN/inf here.
        let request: Vec<serde_json::Value> =
            vec![serde_json::json!(node_ids), serde_json::json!(query)];
        let request_bytes = rmp_serde::to_vec(&request).context("encoding distance request")?;
        let response = Self::round_trip(&mut socket, request_bytes).await?;
        let result: Vec<Vec<f32>> = rmp_serde::from_slice(Self::first_frame(&response)?)
            .context("decoding distance response")?;
        // The distances are the first element of the reply.
        let distances = result.into_iter().next().unwrap_or_default();
        // Previously an empty reply silently became an empty result.
        if distances.len() != node_ids.len() {
            anyhow::bail!(
                "Distance count mismatch: requested {} nodes, got {} distances",
                node_ids.len(),
                distances.len()
            );
        }
        Ok(distances)
    }

    /// Fetches the embeddings stored for `node_ids` as an `(n, d)` matrix, using the
    /// shape reported by the server.
    ///
    /// # Errors
    /// Fails if the server cannot be reached, the reply cannot be decoded, the reported
    /// shape is not a pair of non-negative integers, or the payload length does not
    /// match that shape.
    ///
    /// # Panics
    /// Panics if called from inside an async task running on a current-thread Tokio
    /// runtime (Tokio forbids blocking there).
    pub fn get_embeddings_by_id(&self, node_ids: &[usize]) -> Result<Array2<f32>> {
        Self::block_on(self.get_embeddings_by_id_async(node_ids))
    }

    async fn get_embeddings_by_id_async(&self, node_ids: &[usize]) -> Result<Array2<f32>> {
        let mut socket = self.connect().await?;
        // Wire format: a one-element array `[node_ids]`.
        let request: &[&[usize]] = &[node_ids];
        let request_bytes =
            rmp_serde::to_vec(request).context("encoding embedding-by-id request")?;
        let response = Self::round_trip(&mut socket, request_bytes).await?;
        let mut result: Vec<Vec<f32>> = rmp_serde::from_slice(Self::first_frame(&response)?)
            .context("decoding embedding-by-id response")?;
        // Expected reply: `[[n, d], flat_row_major_data]`.
        if result.len() < 2 {
            anyhow::bail!("Invalid embedding-by-id response");
        }
        let dims = &result[0];
        if dims.len() < 2 {
            anyhow::bail!("Invalid dimensions in response");
        }
        let n = Self::dim_to_usize(dims[0])?;
        let d = Self::dim_to_usize(dims[1])?;
        let expected = n
            .checked_mul(d)
            .context("embedding shape in response overflows usize")?;
        // Take ownership of the payload rather than cloning it.
        let flat_data = result.swap_remove(1);
        if flat_data.len() != expected {
            anyhow::bail!(
                "Data length mismatch: expected {}, got {}",
                expected,
                flat_data.len()
            );
        }
        Array2::from_shape_vec((n, d), flat_data).context("reshaping embeddings")
    }

    /// Opens a REQ socket connected to the embedding server.
    async fn connect(&self) -> Result<zeromq::ReqSocket> {
        use zeromq::Socket;
        let mut socket = zeromq::ReqSocket::new();
        socket
            .connect(&format!("tcp://localhost:{}", self.port))
            .await
            .context("connecting to embedding server")?;
        Ok(socket)
    }

    /// Sends `request` as a single frame and waits for the server's reply.
    async fn round_trip(
        socket: &mut zeromq::ReqSocket,
        request: Vec<u8>,
    ) -> Result<zeromq::ZmqMessage> {
        use zeromq::{SocketRecv, SocketSend, ZmqMessage};
        socket
            .send(ZmqMessage::from(request))
            .await
            .context("sending embedding request")?;
        socket
            .recv()
            .await
            .context("receiving embedding response")
    }

    /// Returns the first frame of a reply. A missing frame is reported explicitly
    /// instead of being decoded as empty bytes.
    fn first_frame(message: &zeromq::ZmqMessage) -> Result<&[u8]> {
        message
            .get(0)
            .map(|frame| frame.as_ref())
            .context("embedding server reply contained no frames")
    }

    /// Converts a shape entry sent as `f32` into a `usize`. Negative, non-finite or
    /// fractional values are rejected instead of being silently saturated by `as`.
    fn dim_to_usize(value: f32) -> Result<usize> {
        if value.is_finite() && value >= 0.0 && value.fract() == 0.0 {
            Ok(value as usize)
        } else {
            anyhow::bail!("Invalid dimensions in response: {}", value)
        }
    }

    /// Drives `future` to completion from synchronous code.
    ///
    /// - On a multi-thread runtime, `block_in_place` makes blocking legal even when
    ///   called from a worker task.
    /// - On a current-thread runtime, `Handle::block_on` is used directly. This works
    ///   from blocking threads but panics from async tasks.
    /// - With no runtime at all, a temporary current-thread runtime is built. It is
    ///   kept alive for the whole call; dropping it early would shut down its I/O
    ///   driver.
    fn block_on<T>(future: impl std::future::Future<Output = Result<T>>) -> Result<T> {
        use tokio::runtime::{Builder, Handle, RuntimeFlavor};
        match Handle::try_current() {
            Ok(handle) if matches!(handle.runtime_flavor(), RuntimeFlavor::MultiThread) => {
                tokio::task::block_in_place(|| handle.block_on(future))
            }
            Ok(handle) => handle.block_on(future),
            Err(_) => Builder::new_current_thread()
                .enable_all()
                .build()
                .context("creating tokio runtime for embedding client")?
                .block_on(future),
        }
    }
}