use std::fmt;
use std::io::{Read, Write};
use std::process::{Child, Command, Stdio};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
use leindex_embed::protocol::{
self, BatchId, EmbedRequest, EmbedResponse, Frame, MsgType, RerankDocument, RerankRequest,
RerankResponse, Response, WorkerError,
};
/// Reads one length-prefixed frame from the worker's stdout.
///
/// The wire layout consumed here is a 4-byte little-endian payload length
/// followed by exactly that many payload bytes. Only the payload bytes are
/// returned.
///
/// Despite its name, this function applies no timeout of its own: both reads
/// block until the pipe delivers the bytes or fails. The deadline is enforced
/// by the caller waiting on the reader thread's channel.
///
/// # Errors
///
/// Returns [`ClientError::Ipc`] when either read fails (including EOF) or when
/// the declared payload length exceeds `MAX_RESPONSE_FRAME_SIZE`.
fn read_frame_with_timeout(stdout: &mut std::process::ChildStdout) -> Result<Vec<u8>, ClientError> {
    let mut header = [0u8; 4];
    stdout
        .read_exact(&mut header)
        .map_err(|e| ClientError::Ipc(format!("failed to read frame length: {}", e)))?;

    // Reject oversized frames before allocating a buffer for them.
    let declared_len = u32::from_le_bytes(header);
    if declared_len > MAX_RESPONSE_FRAME_SIZE {
        return Err(ClientError::Ipc(format!(
            "response frame too large: {} bytes (max: {} bytes)",
            declared_len, MAX_RESPONSE_FRAME_SIZE
        )));
    }

    let mut payload = vec![0u8; declared_len as usize];
    stdout
        .read_exact(&mut payload)
        .map_err(|e| ClientError::Ipc(format!("failed to read frame payload: {}", e)))?;
    Ok(payload)
}
/// Process-wide counter from which `EmbeddingClient::next_batch_id` draws batch ids; starts at 1.
static BATCH_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Largest payload length (64 MiB) a response frame may declare before
/// `read_frame_with_timeout` rejects it.
const MAX_RESPONSE_FRAME_SIZE: u32 = 64 * 1024 * 1024;
/// Seconds `send_and_receive` waits for the reader thread to deliver a response frame;
/// also interpolated into the `ClientError::Timeout` message.
const IPC_TIMEOUT_SECS: u64 = 30;
/// Returns `binary_name` as an executable file name for the compile target:
/// `.exe` is appended on Windows, and the name is returned unchanged elsewhere.
fn platform_binary_name(binary_name: &str) -> String {
    let suffix = if cfg!(windows) { ".exe" } else { "" };
    let mut file_name = String::with_capacity(binary_name.len() + suffix.len());
    file_name.push_str(binary_name);
    file_name.push_str(suffix);
    file_name
}
/// Locates the `leindex-embed` worker executable.
///
/// Lookup order:
/// 1. a file with the platform-specific worker name sitting next to the
///    currently running executable;
/// 2. the first match found through `which::which`.
///
/// # Errors
///
/// Returns an [`std::io::Error`] of kind `NotFound` when neither lookup
/// produces a path.
fn resolve_worker_binary() -> Result<std::path::PathBuf, std::io::Error> {
    let binary_name = platform_binary_name("leindex-embed");

    // Prefer a worker shipped alongside this executable.
    let sibling = std::env::current_exe()
        .ok()
        .and_then(|exe| exe.parent().map(|dir| dir.join(&binary_name)))
        .filter(|candidate| candidate.exists());
    if let Some(path) = sibling {
        return Ok(path);
    }

    match which::which(&binary_name) {
        Ok(path) => Ok(path),
        Err(e) => Err(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            format!("worker binary '{}' not found in PATH: {}", binary_name, e),
        )),
    }
}
/// Errors produced by [`EmbeddingClient`] while talking to the worker process.
#[derive(Debug, thiserror::Error)]
pub enum ClientError {
/// The worker binary could not be resolved, the process could not be spawned,
/// or one of its stdio pipes could not be taken.
#[error("failed to spawn worker: {0}")]
SpawnFailed(String),
/// Transport-level failure: locking the worker handle, writing or reading the
/// pipes, encoding or decoding a frame, or a response whose batch id does not
/// match the request.
#[error("IPC error: {0}")]
Ipc(String),
/// The worker answered with an error frame; carries the error it reported.
#[error("worker error: {0}")]
Worker(WorkerError),
/// The response frame type, or the payload variant inside it, was not the one
/// expected for the request.
#[error("protocol error: {0}")]
Protocol(String),
/// No response frame arrived within `IPC_TIMEOUT_SECS` seconds.
#[error(
"IPC timeout: worker did not respond within {} seconds",
IPC_TIMEOUT_SECS
)]
Timeout,
}
/// Outcome of [`EmbeddingClient::embed_with_fallback`].
#[derive(Debug)]
pub enum EmbedResult {
/// The worker produced embeddings, on the first attempt or on the retry.
Success(EmbedResponse),
/// Both the first attempt and the single retry failed.
Fallback {
/// Batch id assigned to the first attempt.
batch_id: BatchId,
/// Error returned by the retry attempt.
error: ClientError,
},
}
impl EmbedResult {
    /// Returns `true` when this result carries an embedding response.
    pub fn is_success(&self) -> bool {
        match self {
            EmbedResult::Success(_) => true,
            EmbedResult::Fallback { .. } => false,
        }
    }

    /// Returns `true` when both embedding attempts failed.
    pub fn is_fallback(&self) -> bool {
        match self {
            EmbedResult::Fallback { .. } => true,
            EmbedResult::Success(_) => false,
        }
    }

    /// Consumes the result, yielding the response on success and `None` on
    /// fallback (the fallback's batch id and error are discarded).
    pub fn into_success(self) -> Option<EmbedResponse> {
        if let EmbedResult::Success(response) = self {
            Some(response)
        } else {
            None
        }
    }
}
/// Client for the external embedding worker process.
///
/// Clones share the same worker slot through the `Arc`; the slot is `None`
/// until a worker has been spawned and again after it has been killed.
pub struct EmbeddingClient {
// Shared, mutex-guarded slot holding the running worker, if any.
worker: Arc<Mutex<Option<WorkerHandle>>>,
}
impl fmt::Debug for EmbeddingClient {
    /// Formats the client as `EmbeddingClient { worker: <state> }`.
    ///
    /// `<state>` is `Ok(true)` / `Ok(false)` depending on whether a worker is
    /// currently installed, `Err(WouldBlock)` when the worker mutex is held
    /// elsewhere, or `Err(Poisoned(..))` when the mutex is poisoned.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Use `try_lock`, not `lock`: the mutex is held for the full duration
        // of an in-flight request (up to `IPC_TIMEOUT_SECS`), and `std`'s
        // `Mutex` is not reentrant. Formatting must neither stall behind a
        // request nor deadlock when the formatting thread already holds the
        // guard.
        let worker_state = self.worker.try_lock().map(|slot| slot.is_some());
        f.debug_struct("EmbeddingClient")
            .field("worker", &worker_state)
            .finish()
    }
}
impl Clone for EmbeddingClient {
    /// Returns another handle to the same shared worker slot; no new worker
    /// process is created by cloning.
    fn clone(&self) -> Self {
        let worker = Arc::clone(&self.worker);
        Self { worker }
    }
}
/// Everything needed to talk to, and later tear down, one spawned worker process.
struct WorkerHandle {
// The worker child process.
child: Child,
// Write end of the worker's stdin; an `Option` so it can be taken and dropped
// during shutdown, which closes the pipe.
stdin: Option<std::process::ChildStdin>,
// Thread that owns the worker's stdout and reads one frame per `ReadRequest::Read`.
read_thread: thread::JoinHandle<()>,
// Channel used to ask `read_thread` for the next frame or to shut it down.
read_request_tx: std::sync::mpsc::Sender<ReadRequest>,
}
/// Commands sent to the reader thread that owns the worker's stdout.
enum ReadRequest {
// Read one frame from stdout and send the outcome back on `tx`.
Read {
tx: mpsc::Sender<Result<Vec<u8>, ClientError>>,
},
// Leave the read loop so the thread can be joined.
Shutdown,
}
impl EmbeddingClient {
/// Creates a client with an empty worker slot. No process is spawned here;
/// the worker is started on demand by `ensure_worker`.
pub fn new() -> Self {
    let worker = Arc::new(Mutex::new(None));
    Self { worker }
}
/// Allocates the next batch id from the process-wide `BATCH_COUNTER`.
fn next_batch_id() -> BatchId {
    let raw = BATCH_COUNTER.fetch_add(1, Ordering::Relaxed);
    BatchId::new(raw)
}
/// Makes sure the worker slot is populated, spawning a worker if it is empty.
///
/// Only the presence of a handle is checked; whether the process behind an
/// existing handle is still alive is not verified here.
///
/// # Errors
///
/// Returns [`ClientError::Ipc`] when the worker mutex is poisoned, or whatever
/// error `spawn_worker` reports.
fn ensure_worker(&self) -> Result<(), ClientError> {
    let mut slot = match self.worker.lock() {
        Ok(slot) => slot,
        Err(e) => {
            return Err(ClientError::Ipc(format!(
                "failed to lock worker handle: {}",
                e
            )));
        }
    };
    if slot.is_none() {
        self.spawn_worker(&mut slot)?;
    }
    Ok(())
}
/// Spawns the worker process and installs its handle into `guard`.
///
/// The child is started with piped stdin/stdout and stderr discarded. A
/// dedicated reader thread takes ownership of the child's stdout and serves
/// `ReadRequest`s: each `Read` yields one frame (or an error) on the supplied
/// channel, and `Shutdown` — or the request channel closing — ends the thread.
///
/// # Errors
///
/// Returns [`ClientError::SpawnFailed`] when the binary cannot be resolved,
/// the process cannot be spawned, or a stdio pipe is missing.
fn spawn_worker(
    &self,
    guard: &mut std::sync::MutexGuard<'_, Option<WorkerHandle>>,
) -> Result<(), ClientError> {
    let worker_path = match resolve_worker_binary() {
        Ok(path) => path,
        Err(e) => {
            return Err(ClientError::SpawnFailed(format!(
                "failed to resolve worker binary: {}",
                e
            )));
        }
    };

    let mut command = Command::new(&worker_path);
    command
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::null());
    let mut child = command
        .spawn()
        .map_err(|e| ClientError::SpawnFailed(e.to_string()))?;

    let stdin = match child.stdin.take() {
        Some(pipe) => pipe,
        None => {
            return Err(ClientError::SpawnFailed(
                "failed to open worker stdin".to_string(),
            ));
        }
    };
    let stdout = match child.stdout.take() {
        Some(pipe) => pipe,
        None => {
            return Err(ClientError::SpawnFailed(
                "failed to open worker stdout".to_string(),
            ));
        }
    };

    let (read_request_tx, read_request_rx) = mpsc::channel::<ReadRequest>();
    let read_thread = thread::spawn(move || {
        let mut stdout = stdout;
        // Iterating the receiver ends once every sender has been dropped.
        for request in read_request_rx {
            match request {
                ReadRequest::Shutdown => break,
                ReadRequest::Read { tx } => {
                    // The requester may have timed out and dropped its
                    // receiver; a failed send is not an error here.
                    let _ = tx.send(read_frame_with_timeout(&mut stdout));
                }
            }
        }
    });

    let handle = WorkerHandle {
        child,
        stdin: Some(stdin),
        read_thread,
        read_request_tx,
    };
    **guard = Some(handle);
    Ok(())
}
pub fn kill_worker(&self) {
if let Ok(mut guard) = self.worker.lock() {
if let Some(handle) = guard.as_mut() {
let _ = handle.read_request_tx.send(ReadRequest::Shutdown);
#[cfg(unix)]
{
let pid = handle.child.id() as libc::pid_t;
if pid > 0 {
unsafe {
libc::kill(pid, libc::SIGTERM);
}
}
}
#[cfg(not(unix))]
{
drop(handle.stdin.take());
}
}
if let Some(mut handle) = guard.take() {
#[cfg(unix)]
{
drop(handle.stdin.take());
match handle.child.try_wait() {
Ok(Some(_status)) => {
let _ = handle.read_thread.join();
let _ = handle.child.wait();
return;
}
Ok(None) => {
let deadline =
std::time::Instant::now() + std::time::Duration::from_secs(2);
loop {
match handle.child.try_wait() {
Ok(Some(_)) => break,
Ok(None) if std::time::Instant::now() < deadline => {
std::thread::sleep(std::time::Duration::from_millis(50));
}
_ => break,
}
}
}
Err(_) => {}
}
}
#[cfg(not(unix))]
{
let deadline =
std::time::Instant::now() + std::time::Duration::from_secs(2);
loop {
match handle.child.try_wait() {
Ok(Some(_)) => break,
Ok(None) if std::time::Instant::now() < deadline => {
std::thread::sleep(std::time::Duration::from_millis(50));
}
_ => break,
}
}
}
let _ = handle.child.kill();
let _ = handle.read_thread.join();
let _ = handle.child.wait();
}
}
}
/// Embeds `texts`, retrying once on a fresh worker before giving up.
///
/// On the first failure the worker is killed and a second attempt is made
/// under a new batch id. If that also fails, the worker is killed again and
/// [`EmbedResult::Fallback`] is returned carrying the *first* attempt's batch
/// id and the *retry's* error, leaving the caller to degrade as it sees fit.
pub fn embed_with_fallback(&self, texts: &[String], expected_dim: usize) -> EmbedResult {
    let batch_id = Self::next_batch_id();
    let first_error = match self.embed_attempt(batch_id, texts, expected_dim) {
        Ok(response) => return EmbedResult::Success(response),
        Err(e) => e,
    };

    tracing::warn!(
        batch_id = %batch_id,
        error = %first_error,
        "ONNX worker failed on first attempt, retrying once"
    );
    // Start the retry from a clean slate: the next attempt respawns the worker.
    self.kill_worker();

    let retry_batch_id = Self::next_batch_id();
    let retry_error = match self.embed_attempt(retry_batch_id, texts, expected_dim) {
        Ok(response) => {
            tracing::info!(
                original_batch = %batch_id,
                retry_batch = %retry_batch_id,
                "ONNX worker retry succeeded"
            );
            return EmbedResult::Success(response);
        }
        Err(e) => e,
    };

    tracing::warn!(
        batch_id = %batch_id,
        retry_batch_id = %retry_batch_id,
        first_error = %first_error,
        retry_error = %retry_error,
        "ONNX worker fallback for batch {}: {} (retry exhausted, degrading to TF-IDF)",
        batch_id,
        retry_error
    );
    self.kill_worker();
    EmbedResult::Fallback {
        batch_id,
        error: retry_error,
    }
}
fn embed_attempt(
&self,
batch_id: BatchId,
texts: &[String],
expected_dim: usize,
) -> Result<EmbedResponse, ClientError> {
self.ensure_worker()?;
let request = EmbedRequest {
texts: texts.to_vec(),
expected_dim,
};
let frame = protocol::embed_request_frame(batch_id, request)
.map_err(|e| ClientError::Ipc(e.to_string()))?;
let response_frame = self.send_and_receive(frame)?;
match response_frame.header.msg_type {
MsgType::EmbedResponse => {
let response: Response = response_frame
.decode_payload()
.map_err(|e| ClientError::Ipc(e.to_string()))?;
match response {
Response::Embed(embed_resp) => Ok(embed_resp),
_ => Err(ClientError::Protocol("expected Embed response".to_string())),
}
}
MsgType::Error => {
let response: Response = response_frame
.decode_payload()
.map_err(|e| ClientError::Ipc(e.to_string()))?;
match response {
Response::Error(err) => Err(ClientError::Worker(err)),
_ => Err(ClientError::Protocol("expected Error response".to_string())),
}
}
other => Err(ClientError::Protocol(format!(
"unexpected response type: {:?}",
other
))),
}
}
/// Embeds `texts` with a single attempt and no retry.
///
/// A fresh batch id is allocated and the request is handed to
/// `embed_attempt`, which spawns the worker if necessary, performs the
/// round-trip, and decodes the reply. Use `embed_with_fallback` when a retry
/// on a fresh worker is wanted.
///
/// # Errors
///
/// - [`ClientError::SpawnFailed`] when no worker could be started.
/// - [`ClientError::Worker`] when the worker replies with an error frame.
/// - [`ClientError::Protocol`] when the reply has an unexpected type or
///   payload variant.
/// - [`ClientError::Ipc`] / [`ClientError::Timeout`] for transport failures.
pub fn embed(
    &self,
    texts: &[String],
    expected_dim: usize,
) -> Result<EmbedResponse, ClientError> {
    // This body used to be a line-for-line copy of `embed_attempt`;
    // delegating keeps the response handling in one place. The only
    // difference is that the batch id is now allocated before the worker
    // check, exactly as `embed_with_fallback` already does.
    self.embed_attempt(Self::next_batch_id(), texts, expected_dim)
}
pub fn rerank(
&self,
query: &str,
documents: Vec<RerankDocument>,
) -> Result<RerankResponse, ClientError> {
self.ensure_worker()?;
let batch_id = Self::next_batch_id();
let request = RerankRequest {
query: query.to_string(),
documents,
};
let frame = protocol::rerank_request_frame(batch_id, request)
.map_err(|e| ClientError::Ipc(e.to_string()))?;
let response_frame = self.send_and_receive(frame)?;
match response_frame.header.msg_type {
MsgType::RerankResponse => {
let response: Response = response_frame
.decode_payload()
.map_err(|e| ClientError::Ipc(e.to_string()))?;
match response {
Response::Rerank(rerank_resp) => Ok(rerank_resp),
_ => Err(ClientError::Protocol(
"expected Rerank response".to_string(),
)),
}
}
MsgType::Error => {
let response: Response = response_frame
.decode_payload()
.map_err(|e| ClientError::Ipc(e.to_string()))?;
match response {
Response::Error(err) => Err(ClientError::Worker(err)),
_ => Err(ClientError::Protocol("expected Error response".to_string())),
}
}
other => Err(ClientError::Protocol(format!(
"unexpected response type: {:?}",
other
))),
}
}
fn send_and_receive(&self, frame: Frame) -> Result<Frame, ClientError> {
let mut guard = self
.worker
.lock()
.map_err(|e| ClientError::Ipc(format!("failed to lock worker handle: {}", e)))?;
let handle = guard
.as_mut()
.ok_or_else(|| ClientError::Ipc("worker not running".to_string()))?;
let wire = frame
.encode_wire()
.map_err(|e| ClientError::Ipc(e.to_string()))?;
let request_batch_id = frame.header.batch_id;
if let Err(e) = handle
.stdin
.as_mut()
.ok_or_else(|| ClientError::Ipc("worker stdin not available".into()))?
.write_all(&wire)
{
drop(guard);
self.kill_worker();
return Err(ClientError::Ipc(format!(
"failed to write to worker: {}",
e
)));
}
if let Err(e) = handle
.stdin
.as_mut()
.ok_or_else(|| ClientError::Ipc("worker stdin not available".into()))?
.flush()
{
drop(guard);
self.kill_worker();
return Err(ClientError::Ipc(format!(
"failed to flush worker stdin: {}",
e
)));
}
let (tx, rx) = mpsc::channel();
handle
.read_request_tx
.send(ReadRequest::Read { tx })
.map_err(|_e| ClientError::Ipc("reader thread channel closed".to_string()))?;
match rx.recv_timeout(Duration::from_secs(IPC_TIMEOUT_SECS)) {
Ok(Ok(frame_buf)) => {
let response = match Frame::from_wire_bytes(&frame_buf) {
Ok(response) => response,
Err(e) => {
drop(guard);
self.kill_worker();
return Err(ClientError::Ipc(e.to_string()));
}
};
if response.header.batch_id != request_batch_id {
drop(guard);
self.kill_worker();
return Err(ClientError::Ipc(format!(
"response batch_id mismatch: expected {}, got {}",
request_batch_id, response.header.batch_id
)));
}
Ok(response)
}
Ok(Err(e)) => {
drop(guard);
self.kill_worker();
if e.to_string().contains("too large") {
Err(ClientError::Ipc(e.to_string()))
} else {
Err(ClientError::Ipc(format!(
"failed to read from worker: {}",
e
)))
}
}
Err(mpsc::RecvTimeoutError::Timeout) => {
drop(guard);
self.kill_worker();
Err(ClientError::Timeout)
}
Err(mpsc::RecvTimeoutError::Disconnected) => {
drop(guard);
self.kill_worker();
Err(ClientError::Ipc(
"reader thread disconnected unexpectedly".to_string(),
))
}
}
}
}
impl Drop for EmbeddingClient {
    /// Shuts the worker down when the last clone of the client is dropped.
    ///
    /// Clones share the worker through an `Arc`; only the drop that holds the
    /// final reference performs the teardown, and earlier drops return
    /// immediately. The teardown mirrors `kill_worker`: request reader
    /// shutdown, `SIGTERM` on Unix, close stdin, wait up to two seconds, then
    /// kill, join the reader thread, and reap the child. A poisoned mutex is
    /// recovered so the worker is still cleaned up.
    fn drop(&mut self) {
        let shared = std::mem::take(&mut self.worker);
        let mutex = match Arc::try_unwrap(shared) {
            Ok(mutex) => mutex,
            // Other clones are still alive; they own the worker now.
            Err(_) => return,
        };
        let slot = mutex
            .into_inner()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let mut handle = match slot {
            Some(handle) => handle,
            None => return,
        };

        let _ = handle.read_request_tx.send(ReadRequest::Shutdown);

        #[cfg(unix)]
        {
            let pid = handle.child.id() as libc::pid_t;
            if pid > 0 {
                // SAFETY: `kill(2)` only takes integer arguments and touches
                // no memory owned by this process. This handle's child has
                // not been waited on before this point in this function, so
                // the pid should still refer to our child.
                unsafe {
                    libc::kill(pid, libc::SIGTERM);
                }
            }
        }

        // Closing stdin gives the worker an EOF to react to on every platform.
        drop(handle.stdin.take());

        // Give the worker up to two seconds to exit on its own.
        let grace = std::time::Duration::from_secs(2);
        let started = std::time::Instant::now();
        while let Ok(None) = handle.child.try_wait() {
            if started.elapsed() >= grace {
                break;
            }
            std::thread::sleep(std::time::Duration::from_millis(50));
        }

        let _ = handle.child.kill();
        let _ = handle.read_thread.join();
        let _ = handle.child.wait();
    }
}
// Unit tests for the client's types; none of them spawns a worker process.
#[cfg(test)]
mod tests {
use super::*;
use leindex_embed::protocol::ErrorKind;
// Constructing a client must not panic.
#[test]
fn test_client_creation() {
let _client = EmbeddingClient::new();
}
// The Debug output names the type.
#[test]
fn test_client_debug_impl() {
let client = EmbeddingClient::new();
let debug_str = format!("{:?}", client);
assert!(debug_str.contains("EmbeddingClient"));
}
// A clone can be created and formatted alongside the original.
#[test]
fn test_client_clone_shares_worker() {
let client = EmbeddingClient::new();
let cloned = client.clone();
let _ = format!("{:?}", cloned);
}
// Display output of errors includes the wrapped message.
#[test]
fn test_client_error_display() {
let err = ClientError::SpawnFailed("not found".to_string());
assert!(err.to_string().contains("not found"));
let worker_err = WorkerError {
kind: ErrorKind::ModelNotFound,
message: "missing model".to_string(),
};
let err = ClientError::Worker(worker_err);
assert!(err.to_string().contains("missing model"));
}
// Predicates and conversion for the Success variant.
#[test]
fn test_embed_result_success() {
let response = EmbedResponse::new(vec![1.0, 2.0, 3.0, 4.0], 1, 4);
let result = EmbedResult::Success(response);
assert!(result.is_success());
assert!(!result.is_fallback());
assert!(result.into_success().is_some());
}
// Predicates and conversion for the Fallback variant.
#[test]
fn test_embed_result_fallback() {
let error = ClientError::Worker(WorkerError {
kind: ErrorKind::Inference,
message: "worker crashed".to_string(),
});
let result = EmbedResult::Fallback {
batch_id: BatchId::new(42),
error,
};
assert!(!result.is_success());
assert!(result.is_fallback());
assert!(result.into_success().is_none());
}
// Two consecutive ids from the shared counter are strictly increasing.
#[test]
fn test_batch_id_monotonic() {
let id1 = EmbeddingClient::next_batch_id();
let id2 = EmbeddingClient::next_batch_id();
assert!(
id2.0 > id1.0,
"batch IDs should be monotonically increasing"
);
}
}