use std::fmt;
use std::io::{Read, Write};
use std::process::{Child, Command, Stdio};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
use leindex_embed::protocol::{
self, BatchId, EmbedRequest, EmbedResponse, Frame, MsgType, RerankDocument, RerankRequest,
RerankResponse, Response, WorkerError,
};
/// Reads one length-prefixed frame from `stdout` (in practice the worker's stdout pipe).
///
/// Wire format read here: a 4-byte little-endian payload length, followed by that many
/// payload bytes. Returns the payload bytes without the length prefix.
///
/// Despite its name, this function applies no timeout itself. It blocks until a full frame
/// has been read or the stream fails. The timeout is enforced by the caller, which waits
/// for this result on a channel with `recv_timeout`.
///
/// # Errors
/// Returns `ClientError::Ipc` if the length prefix or the payload cannot be read (including
/// EOF / short reads), or if the declared length exceeds `MAX_RESPONSE_FRAME_SIZE`.
fn read_frame_with_timeout<R: Read + ?Sized>(stdout: &mut R) -> Result<Vec<u8>, ClientError> {
    let mut len_buf = [0u8; 4];
    stdout
        .read_exact(&mut len_buf)
        .map_err(|e| ClientError::Ipc(format!("failed to read frame length: {}", e)))?;
    let payload_len = u32::from_le_bytes(len_buf);
    // Reject oversized frames before allocating, so a corrupt or hostile length
    // cannot trigger a huge allocation.
    if payload_len > MAX_RESPONSE_FRAME_SIZE {
        return Err(ClientError::Ipc(format!(
            "response frame too large: {} bytes (max: {} bytes)",
            payload_len, MAX_RESPONSE_FRAME_SIZE
        )));
    }
    let mut frame_buf = vec![0u8; payload_len as usize];
    stdout
        .read_exact(&mut frame_buf)
        .map_err(|e| ClientError::Ipc(format!("failed to read frame payload: {}", e)))?;
    Ok(frame_buf)
}
/// Process-wide counter that hands out request batch ids; the first id issued is 1.
static BATCH_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Largest response payload, in bytes, accepted from the worker (64 MiB).
const MAX_RESPONSE_FRAME_SIZE: u32 = 64 << 20;
/// Seconds to wait for the worker to answer a single request before giving up.
const IPC_TIMEOUT_SECS: u64 = 300;
/// Returns `binary_name` with the executable suffix for the build target:
/// `.exe` on Windows, no suffix elsewhere.
fn platform_binary_name(binary_name: &str) -> String {
    let suffix = if cfg!(windows) { ".exe" } else { "" };
    format!("{}{}", binary_name, suffix)
}
/// Locates the `leindex-embed` worker executable.
///
/// First looks in the directory of the currently running executable. If nothing exists
/// there, it searches `PATH` via `which`. A `PATH` miss is reported as an
/// `io::ErrorKind::NotFound` error.
fn resolve_worker_binary() -> Result<std::path::PathBuf, std::io::Error> {
    let binary_name = platform_binary_name("leindex-embed");
    let next_to_exe = std::env::current_exe()
        .ok()
        .and_then(|exe| exe.parent().map(|dir| dir.join(&binary_name)))
        .filter(|candidate| candidate.exists());
    if let Some(found) = next_to_exe {
        return Ok(found);
    }
    which::which(&binary_name).map_err(|lookup_err| {
        std::io::Error::new(
            std::io::ErrorKind::NotFound,
            format!("worker binary '{}' not found in PATH: {}", binary_name, lookup_err),
        )
    })
}
/// Worker settings read from the LeIndex config file; `None` means "not configured".
#[derive(Clone, Debug, Default, Eq, PartialEq)]
struct WorkerConfigEnv {
    /// `ort_dylib_path` value; passed to the worker as `ORT_DYLIB_PATH` when that
    /// variable is not already set in the environment.
    ort_dylib_path: Option<String>,
    /// `execution_provider` value (never `auto`, which is treated as unset); passed to the
    /// worker as `LEINDEX_WORKER_EXECUTION_PROVIDER` when that variable is not already set.
    execution_provider: Option<String>,
}
/// Reads `ort_dylib_path` and `execution_provider` from `<leindex home>/config/leindex.toml`.
///
/// Returns all-`None` settings when the home directory is unknown or the file cannot be
/// read. Blank lines and `#` comment lines are skipped. When a key appears more than once,
/// the last occurrence wins. An `execution_provider` of `auto` (any case) is ignored.
fn read_worker_config_env_from_config() -> WorkerConfigEnv {
    let text = match leindex_home_dir()
        .map(|home| home.join("config").join("leindex.toml"))
        .and_then(|path| std::fs::read_to_string(path).ok())
    {
        Some(text) => text,
        None => return WorkerConfigEnv::default(),
    };
    text.lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        .fold(WorkerConfigEnv::default(), |mut acc, line| {
            if let Some(path) = parse_config_assignment(line, "ort_dylib_path") {
                acc.ort_dylib_path = Some(path);
            } else if let Some(provider) = parse_config_assignment(line, "execution_provider") {
                if !provider.eq_ignore_ascii_case("auto") {
                    acc.execution_provider = Some(provider);
                }
            }
            acc
        })
}
fn read_ort_dylib_path_from_config() -> Option<String> {
read_worker_config_env_from_config().ort_dylib_path
}
fn read_execution_provider_from_config() -> Option<String> {
read_worker_config_env_from_config().execution_provider
}
/// Parses a `key = value` config line and returns the value when `line` assigns `key`.
///
/// The line must start with exactly `key`, optionally followed by whitespace, then `=`.
/// Value handling:
/// - Quoted (`"..."` or `'...'`): the value ends at the matching closing quote, so a `#`
///   inside the quotes is kept and anything after the closing quote is ignored.
/// - Unquoted: a `#` starts a trailing comment, which is dropped.
///
/// Surrounding whitespace and stray quote characters are trimmed. Returns `None` when
/// `line` does not assign `key` or the resulting value is empty.
fn parse_config_assignment(line: &str, key: &str) -> Option<String> {
    let rest = line.strip_prefix(key)?.trim_start();
    let raw = rest.strip_prefix('=')?.trim();
    let is_quote = |c: char| c == '"' || c == '\'';
    let value = match raw.chars().next() {
        Some(quote) if is_quote(quote) => {
            let inner = &raw[quote.len_utf8()..];
            // Unterminated quote: keep everything after the opening quote.
            inner.find(quote).map_or(inner, |end| &inner[..end])
        }
        _ => raw.find('#').map_or(raw, |hash| &raw[..hash]),
    };
    // Trimming leftover quote characters keeps the historical leniency for unbalanced
    // quoting such as `"value'`.
    let trimmed = value.trim().trim_matches(is_quote).trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_string())
    }
}
/// Resolves the LeIndex home directory.
///
/// `LEINDEX_HOME` is used when it is set to an absolute path. Otherwise the result is
/// `$HOME/.leindex`. Returns `None` when neither source yields a path.
fn leindex_home_dir() -> Option<std::path::PathBuf> {
    let from_override = std::env::var("LEINDEX_HOME")
        .ok()
        .map(std::path::PathBuf::from)
        .filter(|path| path.is_absolute());
    from_override.or_else(|| {
        std::env::var("HOME")
            .ok()
            .map(|home| std::path::PathBuf::from(home).join(".leindex"))
    })
}
/// Errors returned by `EmbeddingClient` operations.
#[derive(Debug, thiserror::Error)]
pub enum ClientError {
    /// The worker binary could not be resolved, started, or wired up (stdin/stdout).
    #[error("failed to spawn worker: {0}")]
    SpawnFailed(String),
    /// Transport-level failure: locking the worker handle, writing/reading the pipes,
    /// oversized frames, or encoding/decoding frames.
    #[error("IPC error: {0}")]
    Ipc(String),
    /// The worker answered with an error frame carrying this `WorkerError`.
    #[error("worker error: {0}")]
    Worker(WorkerError),
    /// The worker answered with a frame whose type or payload variant was not expected.
    #[error("protocol error: {0}")]
    Protocol(String),
    /// No response arrived within `IPC_TIMEOUT_SECS`.
    #[error(
        "IPC timeout: worker did not respond within {} seconds",
        IPC_TIMEOUT_SECS
    )]
    Timeout,
}
/// Outcome of `EmbeddingClient::embed_with_fallback`.
#[derive(Debug)]
pub enum EmbedResult {
    /// The worker produced embeddings, on the first attempt or on the single retry.
    Success(EmbedResponse),
    /// Both the first attempt and the retry failed.
    Fallback {
        /// Batch id of the first (original) attempt.
        batch_id: BatchId,
        /// Error reported by the retry attempt.
        error: ClientError,
    },
}

impl EmbedResult {
    /// Returns `true` when embeddings were produced.
    pub fn is_success(&self) -> bool {
        !self.is_fallback()
    }

    /// Returns `true` when both attempts failed.
    pub fn is_fallback(&self) -> bool {
        matches!(self, Self::Fallback { .. })
    }

    /// Returns the embeddings, or `None` for a fallback.
    pub fn into_success(self) -> Option<EmbedResponse> {
        if let Self::Success(response) = self {
            Some(response)
        } else {
            None
        }
    }
}
/// Client for the out-of-process `leindex-embed` worker.
///
/// Cloning is cheap: every clone shares the same `Arc<Mutex<Option<WorkerHandle>>>`, so
/// all clones talk to one worker process. `None` in the mutex means no worker is running.
#[derive(Clone)]
pub struct EmbeddingClient {
    worker: Arc<Mutex<Option<WorkerHandle>>>,
}

impl fmt::Debug for EmbeddingClient {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("EmbeddingClient");
        // `try_lock`, not `lock`: an in-flight request can hold this mutex for up to
        // IPC_TIMEOUT_SECS. Formatting must not block behind it, nor deadlock when the
        // formatting thread itself holds the lock.
        match self.worker.try_lock() {
            Ok(slot) => out.field("worker", &slot.is_some()),
            Err(std::sync::TryLockError::WouldBlock) => out.field("worker", &"<busy>"),
            Err(std::sync::TryLockError::Poisoned(_)) => out.field("worker", &"<poisoned>"),
        };
        out.finish()
    }
}
/// A running worker process together with the thread that reads its stdout.
struct WorkerHandle {
    /// The spawned worker process.
    child: Child,
    /// Write end of the worker's stdin; becomes `None` once it has been taken and closed.
    stdin: Option<std::process::ChildStdin>,
    /// Thread that services `ReadRequest`s by reading frames from the worker's stdout.
    read_thread: thread::JoinHandle<()>,
    /// Sends commands to `read_thread`.
    read_request_tx: mpsc::Sender<ReadRequest>,
}

/// Commands understood by a worker's reader thread.
enum ReadRequest {
    /// Read the next frame and deliver the outcome on `tx`.
    Read {
        tx: mpsc::Sender<Result<Vec<u8>, ClientError>>,
    },
    /// Leave the reader loop.
    Shutdown,
}
impl Default for EmbeddingClient {
fn default() -> Self {
Self::new()
}
}
impl EmbeddingClient {
    /// Creates a client with no worker process; the worker is spawned lazily on first use.
    pub fn new() -> Self {
        Self {
            worker: Arc::new(Mutex::new(None)),
        }
    }

    /// Returns the next process-wide batch id (increasing, starting at 1).
    fn next_batch_id() -> BatchId {
        BatchId::new(BATCH_COUNTER.fetch_add(1, Ordering::Relaxed))
    }

    /// Locks the shared worker slot, mapping a poisoned lock to `ClientError::Ipc`.
    fn lock_worker(
        &self,
    ) -> Result<std::sync::MutexGuard<'_, Option<WorkerHandle>>, ClientError> {
        self.worker
            .lock()
            .map_err(|e| ClientError::Ipc(format!("failed to lock worker handle: {}", e)))
    }

    /// Spawns the worker process if none is currently running.
    fn ensure_worker(&self) -> Result<(), ClientError> {
        let mut guard = self.lock_worker()?;
        if guard.is_some() {
            return Ok(());
        }
        self.spawn_worker(&mut guard)
    }

    /// Spawns the worker binary and stores its handle in `guard`.
    ///
    /// The following are forwarded to the worker's environment:
    /// - `LEINDEX_MODEL_PATH`.
    /// - `LEINDEX_WORKER_EXECUTION_PROVIDER`, falling back to the config file.
    /// - `ORT_DYLIB_PATH`, but only when it is not already set, using the config file value.
    ///
    /// stdin/stdout are piped and stderr is inherited. A dedicated thread reads frames from
    /// stdout on request.
    fn spawn_worker(
        &self,
        guard: &mut std::sync::MutexGuard<'_, Option<WorkerHandle>>,
    ) -> Result<(), ClientError> {
        let worker_path = resolve_worker_binary().map_err(|e| {
            ClientError::SpawnFailed(format!("failed to resolve worker binary: {}", e))
        })?;
        let mut cmd = Command::new(&worker_path);
        if let Ok(model_path) = std::env::var("LEINDEX_MODEL_PATH") {
            cmd.env("LEINDEX_MODEL_PATH", &model_path);
        }
        if let Ok(provider) = std::env::var("LEINDEX_WORKER_EXECUTION_PROVIDER") {
            cmd.env("LEINDEX_WORKER_EXECUTION_PROVIDER", &provider);
        } else if let Some(provider) = read_execution_provider_from_config() {
            cmd.env("LEINDEX_WORKER_EXECUTION_PROVIDER", &provider);
        }
        if std::env::var_os("ORT_DYLIB_PATH").is_none() {
            if let Some(path) = read_ort_dylib_path_from_config() {
                cmd.env("ORT_DYLIB_PATH", &path);
            }
        }
        let mut child = cmd
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .spawn()
            .map_err(|e| ClientError::SpawnFailed(e.to_string()))?;
        let stdin = child
            .stdin
            .take()
            .ok_or_else(|| ClientError::SpawnFailed("failed to open worker stdin".to_string()))?;
        let stdout = child
            .stdout
            .take()
            .ok_or_else(|| ClientError::SpawnFailed("failed to open worker stdout".to_string()))?;
        let (read_request_tx, read_request_rx) = mpsc::channel::<ReadRequest>();
        let read_thread = thread::spawn(move || {
            let mut stdout = stdout;
            while let Ok(request) = read_request_rx.recv() {
                match request {
                    ReadRequest::Read { tx } => {
                        let result = read_frame_with_timeout(&mut stdout);
                        let _ = tx.send(result);
                    }
                    ReadRequest::Shutdown => {
                        break;
                    }
                }
            }
        });
        **guard = Some(WorkerHandle {
            child,
            stdin: Some(stdin),
            read_thread,
            read_request_tx,
        });
        Ok(())
    }

    /// Stops a worker and its reader thread, consuming the handle. This is best-effort
    /// cleanup: every failure is ignored.
    ///
    /// Sequence:
    /// 1. Ask the reader thread to stop.
    /// 2. Send SIGTERM (Unix only).
    /// 3. Close the worker's stdin.
    /// 4. Poll for exit for up to 2 seconds.
    /// 5. Force-kill the process, join the reader thread, and reap the process.
    fn shutdown_handle(mut handle: WorkerHandle) {
        let _ = handle.read_request_tx.send(ReadRequest::Shutdown);
        #[cfg(unix)]
        {
            let pid = handle.child.id() as libc::pid_t;
            if pid > 0 {
                // SAFETY: `kill(2)` has no memory-safety preconditions. Nothing in this file
                // waits on the child before this point, so the pid still belongs to our
                // (possibly zombie) child and cannot have been recycled.
                unsafe {
                    libc::kill(pid, libc::SIGTERM);
                }
            }
        }
        drop(handle.stdin.take());
        let deadline = std::time::Instant::now() + Duration::from_secs(2);
        loop {
            match handle.child.try_wait() {
                Ok(None) if std::time::Instant::now() < deadline => {
                    thread::sleep(Duration::from_millis(50));
                }
                _ => break,
            }
        }
        // No signal is sent if the child was already reaped above: std remembers the status.
        let _ = handle.child.kill();
        let _ = handle.read_thread.join();
        let _ = handle.child.wait();
    }

    /// Terminates the worker process, if any; the next request spawns a fresh one.
    ///
    /// Does nothing when the worker lock is poisoned.
    pub fn kill_worker(&self) {
        if let Ok(mut guard) = self.worker.lock() {
            if let Some(handle) = guard.take() {
                Self::shutdown_handle(handle);
            }
        }
    }

    /// Embeds `texts`, retrying once with a freshly spawned worker on failure.
    ///
    /// Returns `EmbedResult::Fallback` carrying the first attempt's batch id and the retry's
    /// error when both attempts fail.
    pub fn embed_with_fallback(&self, texts: &[String], expected_dim: usize) -> EmbedResult {
        let batch_id = Self::next_batch_id();
        match self.embed_attempt(batch_id, texts, expected_dim) {
            Ok(response) => EmbedResult::Success(response),
            Err(first_error) => {
                tracing::warn!(
                    batch_id = %batch_id,
                    error = %first_error,
                    "ONNX worker failed on first attempt, retrying once"
                );
                self.kill_worker();
                let retry_batch_id = Self::next_batch_id();
                match self.embed_attempt(retry_batch_id, texts, expected_dim) {
                    Ok(response) => {
                        tracing::info!(
                            original_batch = %batch_id,
                            retry_batch = %retry_batch_id,
                            "ONNX worker retry succeeded"
                        );
                        EmbedResult::Success(response)
                    }
                    Err(retry_error) => {
                        tracing::warn!(
                            batch_id = %batch_id,
                            retry_batch_id = %retry_batch_id,
                            first_error = %first_error,
                            retry_error = %retry_error,
                            "ONNX worker fallback for batch {}: {} (retry exhausted, degrading to TF-IDF)",
                            batch_id,
                            retry_error
                        );
                        self.kill_worker();
                        EmbedResult::Fallback {
                            batch_id,
                            error: retry_error,
                        }
                    }
                }
            }
        }
    }

    /// Sends one embed request tagged with `batch_id` and decodes the reply.
    fn embed_attempt(
        &self,
        batch_id: BatchId,
        texts: &[String],
        expected_dim: usize,
    ) -> Result<EmbedResponse, ClientError> {
        self.ensure_worker()?;
        let request = EmbedRequest {
            texts: texts.to_vec(),
            expected_dim,
        };
        let frame = protocol::embed_request_frame(batch_id, request)
            .map_err(|e| ClientError::Ipc(e.to_string()))?;
        let response_frame = self.send_and_receive(frame)?;
        match response_frame.header.msg_type {
            MsgType::EmbedResponse => match Self::decode_response(response_frame)? {
                Response::Embed(embed_resp) => Ok(embed_resp),
                _ => Err(ClientError::Protocol("expected Embed response".to_string())),
            },
            MsgType::Error => Err(Self::worker_error(response_frame)),
            other => Err(ClientError::Protocol(format!(
                "unexpected response type: {:?}",
                other
            ))),
        }
    }

    /// Embeds `texts` in a single attempt (no retry); see `embed_with_fallback` for retries.
    pub fn embed(
        &self,
        texts: &[String],
        expected_dim: usize,
    ) -> Result<EmbedResponse, ClientError> {
        self.embed_attempt(Self::next_batch_id(), texts, expected_dim)
    }

    /// Asks the worker to rerank `documents` against `query` in a single attempt.
    pub fn rerank(
        &self,
        query: &str,
        documents: Vec<RerankDocument>,
    ) -> Result<RerankResponse, ClientError> {
        self.ensure_worker()?;
        let batch_id = Self::next_batch_id();
        let request = RerankRequest {
            query: query.to_string(),
            documents,
        };
        let frame = protocol::rerank_request_frame(batch_id, request)
            .map_err(|e| ClientError::Ipc(e.to_string()))?;
        let response_frame = self.send_and_receive(frame)?;
        match response_frame.header.msg_type {
            MsgType::RerankResponse => match Self::decode_response(response_frame)? {
                Response::Rerank(rerank_resp) => Ok(rerank_resp),
                _ => Err(ClientError::Protocol(
                    "expected Rerank response".to_string(),
                )),
            },
            MsgType::Error => Err(Self::worker_error(response_frame)),
            other => Err(ClientError::Protocol(format!(
                "unexpected response type: {:?}",
                other
            ))),
        }
    }

    /// Decodes the payload of `response_frame` as a `Response`.
    fn decode_response(response_frame: Frame) -> Result<Response, ClientError> {
        response_frame
            .decode_payload()
            .map_err(|e| ClientError::Ipc(e.to_string()))
    }

    /// Converts an error-typed response frame into the `ClientError` it carries.
    fn worker_error(response_frame: Frame) -> ClientError {
        match Self::decode_response(response_frame) {
            Ok(Response::Error(err)) => ClientError::Worker(err),
            Ok(_) => ClientError::Protocol("expected Error response".to_string()),
            Err(decode_err) => decode_err,
        }
    }

    /// Sends `frame` to the worker and returns the matching response frame.
    ///
    /// The worker lock is held for the whole exchange, which serializes requests across
    /// clones. A worker is spawned if none is running.
    ///
    /// On any failure after the worker was reached, the worker is shut down before the lock
    /// is released. Such failures include write, read, timeout, decode, and batch-id
    /// mismatch. Shutting down under the lock means no other clone can observe the broken
    /// worker, and a worker respawned by another clone can never be killed by mistake. The
    /// next request spawns a fresh worker.
    fn send_and_receive(&self, frame: Frame) -> Result<Frame, ClientError> {
        let mut guard = self.lock_worker()?;
        if guard.is_none() {
            self.spawn_worker(&mut guard)?;
        }
        let handle = guard
            .as_mut()
            .ok_or_else(|| ClientError::Ipc("worker not running".to_string()))?;
        let wire = frame
            .encode_wire()
            .map_err(|e| ClientError::Ipc(e.to_string()))?;
        let request_batch_id = frame.header.batch_id;
        let result = Self::exchange(handle, &wire).and_then(|response| {
            if response.header.batch_id != request_batch_id {
                Err(ClientError::Ipc(format!(
                    "response batch_id mismatch: expected {}, got {}",
                    request_batch_id, response.header.batch_id
                )))
            } else {
                Ok(response)
            }
        });
        if result.is_err() {
            if let Some(broken) = guard.take() {
                Self::shutdown_handle(broken);
            }
        }
        result
    }

    /// Writes one encoded request to the worker's stdin, then waits up to
    /// `IPC_TIMEOUT_SECS` for the reader thread to deliver the next frame.
    ///
    /// Any error leaves the pipes in an unknown state; the caller discards the worker.
    fn exchange(handle: &mut WorkerHandle, wire: &[u8]) -> Result<Frame, ClientError> {
        let stdin = handle
            .stdin
            .as_mut()
            .ok_or_else(|| ClientError::Ipc("worker stdin not available".into()))?;
        stdin
            .write_all(wire)
            .map_err(|e| ClientError::Ipc(format!("failed to write to worker: {}", e)))?;
        stdin
            .flush()
            .map_err(|e| ClientError::Ipc(format!("failed to flush worker stdin: {}", e)))?;

        let (tx, rx) = mpsc::channel();
        handle
            .read_request_tx
            .send(ReadRequest::Read { tx })
            .map_err(|_e| ClientError::Ipc("reader thread channel closed".to_string()))?;

        let frame_buf = match rx.recv_timeout(Duration::from_secs(IPC_TIMEOUT_SECS)) {
            Ok(Ok(frame_buf)) => frame_buf,
            // The reader already returns a `ClientError::Ipc`. Unwrap its message instead of
            // re-wrapping its Display output, which would print "IPC error:" twice.
            Ok(Err(ClientError::Ipc(msg))) if msg.contains("too large") => {
                return Err(ClientError::Ipc(msg))
            }
            Ok(Err(ClientError::Ipc(msg))) => {
                return Err(ClientError::Ipc(format!(
                    "failed to read from worker: {}",
                    msg
                )))
            }
            Ok(Err(other)) => return Err(other),
            Err(mpsc::RecvTimeoutError::Timeout) => return Err(ClientError::Timeout),
            Err(mpsc::RecvTimeoutError::Disconnected) => {
                return Err(ClientError::Ipc(
                    "reader thread disconnected unexpectedly".to_string(),
                ))
            }
        };
        Frame::from_wire_bytes(&frame_buf).map_err(|e| ClientError::Ipc(e.to_string()))
    }
}
impl Drop for EmbeddingClient {
    /// Shuts the shared worker down when the last clone of the client is dropped.
    fn drop(&mut self) {
        // NOTE(review): if the last two clones are dropped concurrently, both `try_unwrap`
        // calls can fail and the worker is never shut down explicitly. `Arc::into_inner`
        // (Rust 1.70+) guarantees exactly one caller wins; consider switching if the MSRV
        // allows it.
        let Ok(worker) = Arc::try_unwrap(std::mem::take(&mut self.worker)) else {
            return;
        };
        let slot = worker.into_inner().unwrap_or_else(|e| e.into_inner());
        if let Some(handle) = slot {
            Self::shutdown_handle(handle);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use leindex_embed::protocol::ErrorKind;
    #[test]
    fn test_client_creation() {
        let _client = EmbeddingClient::new();
    }
    #[test]
    fn test_client_debug_impl() {
        let client = EmbeddingClient::new();
        let debug_str = format!("{:?}", client);
        assert!(debug_str.contains("EmbeddingClient"));
    }
    #[test]
    fn test_client_clone_shares_worker() {
        let client = EmbeddingClient::new();
        let cloned = client.clone();
        assert!(
            Arc::ptr_eq(&client.worker, &cloned.worker),
            "clones must share one worker slot"
        );
        let _ = format!("{:?}", cloned);
    }
    #[test]
    fn test_client_error_display() {
        let err = ClientError::SpawnFailed("not found".to_string());
        assert!(err.to_string().contains("not found"));
        let worker_err = WorkerError {
            kind: ErrorKind::ModelNotFound,
            message: "missing model".to_string(),
        };
        let err = ClientError::Worker(worker_err);
        assert!(err.to_string().contains("missing model"));
    }
    #[test]
    fn test_embed_result_success() {
        let response = EmbedResponse::new(vec![1.0, 2.0, 3.0, 4.0], 1, 4);
        let result = EmbedResult::Success(response);
        assert!(result.is_success());
        assert!(!result.is_fallback());
        assert!(result.into_success().is_some());
    }
    #[test]
    fn test_embed_result_fallback() {
        let error = ClientError::Worker(WorkerError {
            kind: ErrorKind::Inference,
            message: "worker crashed".to_string(),
        });
        let result = EmbedResult::Fallback {
            batch_id: BatchId::new(42),
            error,
        };
        assert!(!result.is_success());
        assert!(result.is_fallback());
        assert!(result.into_success().is_none());
    }
    #[test]
    fn test_batch_id_monotonic() {
        let id1 = EmbeddingClient::next_batch_id();
        let id2 = EmbeddingClient::next_batch_id();
        assert!(
            id2.0 > id1.0,
            "batch IDs should be monotonically increasing"
        );
    }
    use std::sync::Mutex;
    static TEST_ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Environment variables the tests below mutate.
    const GUARDED_VARS: [&str; 2] = ["LEINDEX_HOME", "HOME"];

    /// Serializes env-mutating tests and restores `GUARDED_VARS` when dropped.
    ///
    /// Restoring happens even when the test panics. A poisoned lock (from an earlier
    /// failing test) is recovered, so one failure does not cascade into the others.
    struct EnvGuard {
        _lock: std::sync::MutexGuard<'static, ()>,
        saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
    }

    impl EnvGuard {
        fn new() -> Self {
            let lock = TEST_ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved = GUARDED_VARS
                .iter()
                .map(|&key| (key, std::env::var_os(key)))
                .collect();
            Self { _lock: lock, saved }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Runs before `_lock` is released, so restoration is also serialized.
            for (key, value) in &self.saved {
                match value {
                    Some(original) => std::env::set_var(key, original),
                    None => std::env::remove_var(key),
                }
            }
        }
    }

    #[test]
    fn test_read_ort_dylib_path_from_config_returns_value() {
        let _env = EnvGuard::new();
        let tmp = tempfile::tempdir().unwrap();
        std::env::set_var("LEINDEX_HOME", tmp.path());
        let cfg_dir = tmp.path().join("config");
        std::fs::create_dir_all(&cfg_dir).unwrap();
        std::fs::write(
            cfg_dir.join("leindex.toml"),
            "[neural]\nenabled = true\nexecution_provider = \"cpu\"\nort_dylib_path = \"/opt/onnxruntime/libonnxruntime.so\"\nort_version = \"1.25.0\"\nmodel_dir = \"/models\"\n",
        )
        .unwrap();
        let parsed = read_ort_dylib_path_from_config();
        assert_eq!(
            parsed.as_deref(),
            Some("/opt/onnxruntime/libonnxruntime.so")
        );
        assert_eq!(
            read_execution_provider_from_config().as_deref(),
            Some("cpu")
        );
    }
    #[test]
    fn test_read_execution_provider_from_config_skips_auto() {
        let _env = EnvGuard::new();
        let tmp = tempfile::tempdir().unwrap();
        std::env::set_var("LEINDEX_HOME", tmp.path());
        let cfg_dir = tmp.path().join("config");
        std::fs::create_dir_all(&cfg_dir).unwrap();
        std::fs::write(
            cfg_dir.join("leindex.toml"),
            "[neural]\nenabled = true\nexecution_provider = \"auto\"\n",
        )
        .unwrap();
        assert_eq!(read_execution_provider_from_config(), None);
        std::fs::write(
            cfg_dir.join("leindex.toml"),
            "[neural]\nenabled = true\nexecution_provider = \"migraphx\"\n",
        )
        .unwrap();
        assert_eq!(
            read_execution_provider_from_config().as_deref(),
            Some("migraphx")
        );
    }
    #[test]
    fn test_read_ort_dylib_path_from_config_returns_none_when_absent() {
        let _env = EnvGuard::new();
        let tmp = tempfile::tempdir().unwrap();
        std::env::set_var("LEINDEX_HOME", tmp.path());
        assert_eq!(read_ort_dylib_path_from_config(), None);
        let cfg_dir = tmp.path().join("config");
        std::fs::create_dir_all(&cfg_dir).unwrap();
        std::fs::write(
            cfg_dir.join("leindex.toml"),
            "[neural]\nenabled = true\nmodel_dir = \"/models\"\n",
        )
        .unwrap();
        assert_eq!(read_ort_dylib_path_from_config(), None);
    }
    #[test]
    fn test_read_ort_dylib_path_from_config_handles_single_quotes() {
        let _env = EnvGuard::new();
        let tmp = tempfile::tempdir().unwrap();
        std::env::set_var("LEINDEX_HOME", tmp.path());
        let cfg_dir = tmp.path().join("config");
        std::fs::create_dir_all(&cfg_dir).unwrap();
        std::fs::write(
            cfg_dir.join("leindex.toml"),
            "[neural]\nort_dylib_path = '/quote/ort.so'\n",
        )
        .unwrap();
        assert_eq!(
            read_ort_dylib_path_from_config().as_deref(),
            Some("/quote/ort.so")
        );
    }
    #[test]
    fn test_leindex_home_dir_prefers_env_override() {
        let _env = EnvGuard::new();
        std::env::set_var("LEINDEX_HOME", "/custom/leindex/home");
        assert_eq!(
            leindex_home_dir(),
            Some(std::path::PathBuf::from("/custom/leindex/home"))
        );
    }
    #[test]
    fn test_leindex_home_dir_falls_back_to_home() {
        let _env = EnvGuard::new();
        std::env::remove_var("LEINDEX_HOME");
        std::env::set_var("HOME", "/home/testuser");
        let home = leindex_home_dir();
        assert_eq!(
            home,
            Some(std::path::PathBuf::from("/home/testuser/.leindex"))
        );
    }
    #[test]
    fn test_leindex_home_dir_relative_env_ignored() {
        let _env = EnvGuard::new();
        std::env::set_var("LEINDEX_HOME", "relative/path");
        std::env::set_var("HOME", "/home/fallback");
        let home = leindex_home_dir();
        assert_eq!(
            home,
            Some(std::path::PathBuf::from("/home/fallback/.leindex"))
        );
    }
}