use std::path::{Path, PathBuf};
use anyhow::{Context, Result, anyhow};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::UnixStream;
/// Protocol version string placed in the `jsonrpc` field of every request.
const JSONRPC_VERSION: &str = "2.0";
/// Method name sent by `Bm25Client::index`.
const METHOD_INDEX: &str = "index";
/// Method name sent by `Bm25Client::search`.
const METHOD_SEARCH: &str = "search";
/// Method name sent by `Bm25Client::delete`.
const METHOD_DELETE: &str = "delete";
/// Returns the Unix-socket path used for the BM25 daemon of `palace`.
///
/// The socket lives in the directory named by the `TMPDIR` environment
/// variable. When `TMPDIR` is unset, not valid Unicode, or blank after
/// trimming, `/tmp` is used instead. The file name is
/// `trusty-bm25-{palace}.sock`.
pub fn socket_path_for_palace(palace: &str) -> PathBuf {
    let base_dir = std::env::var("TMPDIR")
        .ok()
        .filter(|value| !value.trim().is_empty())
        .map_or_else(|| PathBuf::from("/tmp"), PathBuf::from);
    let file_name = format!("trusty-bm25-{palace}.sock");
    base_dir.join(file_name)
}
/// One entry of the hit list returned by `Bm25Client::search`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BM25Hit {
/// Identifier of the matched document.
pub doc_id: String,
/// Score the daemon reported for this document.
pub score: f32,
}
/// JSON-RPC request envelope serialised and sent to the daemon.
///
/// `P` is the method-specific parameter payload.
#[derive(Debug, Serialize)]
struct RpcRequest<'a, P: Serialize> {
/// Protocol version; `Bm25Client::call` fills this with `JSONRPC_VERSION`.
jsonrpc: &'a str,
/// Name of the remote method to invoke.
method: &'a str,
/// Method-specific parameters.
params: P,
/// Request identifier.
id: u64,
}
/// Parameters of the `index` method.
#[derive(Debug, Serialize)]
struct IndexParams<'a> {
/// Identifier under which the document is indexed.
doc_id: &'a str,
/// Document text handed to the daemon for indexing.
text: &'a str,
}
/// Parameters of the `search` method.
#[derive(Debug, Serialize)]
struct SearchParams<'a> {
/// Query text.
query: &'a str,
/// Presumably the maximum number of hits requested; the limit itself is
/// applied by the daemon, not by this client — confirm against the daemon.
top_k: usize,
}
/// Parameters of the `delete` method.
#[derive(Debug, Serialize)]
struct DeleteParams<'a> {
/// Identifier of the document to delete.
doc_id: &'a str,
}
/// JSON-RPC response envelope read back from the daemon.
///
/// Both fields are optional on the wire. `Bm25Client::call` checks `error`
/// first and then requires `result` to be present.
// NOTE(review): the explicit `default = "Option::default"` path (rather than a
// bare `#[serde(default)]`) presumably avoids an inferred `T: Default` bound —
// confirm before simplifying.
#[derive(Debug, Deserialize)]
struct RpcResponse<T> {
/// Method-specific result payload; `None` when absent from the response.
#[serde(default = "Option::default")]
result: Option<T>,
/// Error object; `None` when absent from the response.
#[serde(default = "Option::default")]
error: Option<RpcError>,
}
/// Error object carried in a response's `error` field.
///
/// `Bm25Client::call` formats both fields into the error it returns.
#[derive(Debug, Deserialize)]
struct RpcError {
/// Numeric error code reported by the daemon.
code: i32,
/// Error message reported by the daemon.
message: String,
}
/// Result payload of the `index` method.
#[derive(Debug, Deserialize)]
struct IndexResult {
/// Defaults to `false` when the field is missing; `Bm25Client::index`
/// treats `false` as a failure.
#[serde(default)]
indexed: bool,
}
/// Result payload of the `delete` method.
#[derive(Debug, Deserialize)]
struct DeleteResult {
/// Defaults to `false` when the field is missing; `Bm25Client::delete`
/// reads the flag but does not act on it.
#[serde(default)]
deleted: bool,
}
/// Result payload of the `search` method.
#[derive(Debug, Deserialize)]
struct SearchResult {
/// Hits returned by the daemon; defaults to an empty list when missing.
#[serde(default)]
hits: Vec<BM25Hit>,
}
/// Client for the BM25 daemon, addressed by the path of its Unix socket.
///
/// The client stores only the path; a new connection is opened for every
/// request.
#[derive(Debug, Clone)]
pub struct Bm25Client {
/// Path of the Unix socket to connect to.
socket_path: PathBuf,
}
impl Bm25Client {
pub fn for_palace(palace: impl Into<String>) -> Self {
let palace = palace.into();
Self {
socket_path: socket_path_for_palace(&palace),
}
}
pub fn new(socket_path: PathBuf) -> Self {
Self { socket_path }
}
pub fn socket_path(&self) -> &Path {
&self.socket_path
}
pub async fn index(&self, doc_id: &str, text: &str) -> Result<()> {
let params = IndexParams { doc_id, text };
let res: IndexResult = self.call(METHOD_INDEX, ¶ms).await?;
if !res.indexed {
anyhow::bail!("bm25 daemon reported indexed=false for doc_id={doc_id}");
}
Ok(())
}
pub async fn search(&self, query: &str, top_k: usize) -> Result<Vec<BM25Hit>> {
let params = SearchParams { query, top_k };
let res: SearchResult = self.call(METHOD_SEARCH, ¶ms).await?;
Ok(res.hits)
}
pub async fn delete(&self, doc_id: &str) -> Result<()> {
let params = DeleteParams { doc_id };
let res: DeleteResult = self.call(METHOD_DELETE, ¶ms).await?;
let _ = res.deleted;
Ok(())
}
async fn call<P: Serialize, R: serde::de::DeserializeOwned>(
&self,
method: &'static str,
params: &P,
) -> Result<R> {
let stream = UnixStream::connect(&self.socket_path)
.await
.with_context(|| {
format!(
"connect to bm25 daemon at {} (method={method})",
self.socket_path.display()
)
})?;
let (read_half, mut write_half) = stream.into_split();
let req = RpcRequest {
jsonrpc: JSONRPC_VERSION,
method,
params,
id: 1,
};
let mut payload = serde_json::to_vec(&req).context("serialise bm25 JSON-RPC request")?;
payload.push(b'\n');
write_half
.write_all(&payload)
.await
.context("write bm25 JSON-RPC request to daemon")?;
write_half
.shutdown()
.await
.context("half-close write side of bm25 daemon socket")?;
let mut reader = BufReader::new(read_half);
let mut line = String::new();
let n = reader
.read_line(&mut line)
.await
.context("read bm25 JSON-RPC response from daemon")?;
if n == 0 {
anyhow::bail!("bm25 daemon closed connection before responding (method={method})");
}
let resp: RpcResponse<R> = serde_json::from_str(line.trim()).with_context(|| {
format!(
"decode bm25 JSON-RPC response (method={method}, raw={})",
line.trim()
)
})?;
if let Some(err) = resp.error {
anyhow::bail!("bm25 daemon error {}: {}", err.code, err.message);
}
resp.result
.ok_or_else(|| anyhow!("bm25 daemon response missing both result and error"))
}
}
/// Locates the `trusty-bm25-daemon` executable.
///
/// Lookup order:
/// 1. The `TRUSTY_BM25_DAEMON_BIN` environment variable. When it is set it
///    is authoritative: the path must be an existing file, otherwise an
///    error is returned without trying the other locations.
/// 2. A `trusty-bm25-daemon` (or `trusty-bm25-daemon.exe`) file next to the
///    current executable.
/// 3. The directories listed in `PATH`.
///
/// # Errors
///
/// Fails when the override points at something that is not an existing
/// file, or when none of the locations contains the binary.
pub fn locate_bm25_daemon_binary() -> anyhow::Result<std::path::PathBuf> {
    // `var_os` rather than `var`: a non-Unicode override must still be
    // honoured instead of being silently skipped in favour of the fallbacks.
    if let Some(explicit) = std::env::var_os("TRUSTY_BM25_DAEMON_BIN") {
        let p = std::path::PathBuf::from(&explicit);
        if p.is_file() {
            return Ok(p);
        }
        anyhow::bail!("TRUSTY_BM25_DAEMON_BIN={explicit:?} does not point to an existing file");
    }

    if let Ok(exe) = std::env::current_exe()
        && let Some(dir) = exe.parent()
    {
        let sibling = dir.join("trusty-bm25-daemon");
        if sibling.is_file() {
            return Ok(sibling);
        }
        let sibling_exe = dir.join("trusty-bm25-daemon.exe");
        if sibling_exe.is_file() {
            return Ok(sibling_exe);
        }
    }

    if let Ok(found) = which_bm25_daemon() {
        return Ok(found);
    }

    anyhow::bail!(
        "could not locate trusty-bm25-daemon binary. \
         Set TRUSTY_BM25_DAEMON_BIN=/path/to/trusty-bm25-daemon or ensure \
         it is on PATH (or install via `cargo install trusty-memory`)."
    )
}
/// Searches the directories listed in `PATH` for `trusty-bm25-daemon`.
///
/// On Windows each directory is additionally checked for
/// `trusty-bm25-daemon.exe`. Empty `PATH` entries are skipped, so an unset
/// or empty `PATH` never resolves to a file in the current working
/// directory.
///
/// # Errors
///
/// Fails when no `PATH` directory contains the binary.
fn which_bm25_daemon() -> anyhow::Result<std::path::PathBuf> {
    // `var_os` + `split_paths` handle the platform separator and non-Unicode
    // entries, unlike splitting a `String` by hand.
    let path_var = std::env::var_os("PATH").unwrap_or_default();
    for dir in std::env::split_paths(&path_var) {
        // Joining onto an empty directory would yield a relative path and
        // therefore probe the current working directory.
        if dir.as_os_str().is_empty() {
            continue;
        }
        let candidate = dir.join("trusty-bm25-daemon");
        if candidate.is_file() {
            return Ok(candidate);
        }
        #[cfg(windows)]
        {
            let candidate_exe = dir.join("trusty-bm25-daemon.exe");
            if candidate_exe.is_file() {
                return Ok(candidate_exe);
            }
        }
    }
    anyhow::bail!("trusty-bm25-daemon not found on PATH")
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{OsStr, OsString};
    use std::sync::{Mutex, MutexGuard};

    /// Environment variable consulted first by `locate_bm25_daemon_binary`.
    const DAEMON_BIN_ENV: &str = "TRUSTY_BM25_DAEMON_BIN";

    /// Serialises the tests that mutate `DAEMON_BIN_ENV`; the test harness
    /// runs tests on parallel threads and the environment is process-global.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Overrides `DAEMON_BIN_ENV` for the guard's lifetime and restores the
    /// previous state (including "unset") when dropped.
    struct DaemonBinEnvGuard {
        previous: Option<OsString>,
        // Declared last so the lock is released only after `drop` has
        // restored the variable.
        _lock: MutexGuard<'static, ()>,
    }

    impl DaemonBinEnvGuard {
        fn set(value: impl AsRef<OsStr>) -> Self {
            // A poisoned lock only means another test panicked; the guard it
            // held has already restored the variable, so keep going.
            let lock = ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            let previous = std::env::var_os(DAEMON_BIN_ENV);
            // SAFETY: writers in this module are serialised by `ENV_LOCK`, and
            // the other tests here touch the environment only through
            // `std::env`. NOTE(review): confirm no other test in the same
            // binary reads the environment outside `std::env`.
            unsafe { std::env::set_var(DAEMON_BIN_ENV, value) };
            Self {
                previous,
                _lock: lock,
            }
        }
    }

    impl Drop for DaemonBinEnvGuard {
        fn drop(&mut self) {
            // SAFETY: `ENV_LOCK` is still held here (see `set`).
            match self.previous.take() {
                Some(value) => unsafe { std::env::set_var(DAEMON_BIN_ENV, value) },
                None => unsafe { std::env::remove_var(DAEMON_BIN_ENV) },
            }
        }
    }

    #[test]
    fn socket_path_uses_tmpdir_and_palace_name() {
        let p = socket_path_for_palace("my-palace");
        let fname = p.file_name().and_then(|s| s.to_str()).unwrap_or("");
        assert!(
            fname.starts_with("trusty-bm25-"),
            "filename must start with trusty-bm25-: {fname}"
        );
        assert!(
            fname.contains("my-palace"),
            "filename must include palace name: {fname}"
        );
        assert!(
            fname.ends_with(".sock"),
            "filename must end with .sock: {fname}"
        );
        assert!(p.parent().is_some());
    }

    #[test]
    fn for_palace_uses_palace_specific_path() {
        let c = Bm25Client::for_palace("alpha");
        let fname = c
            .socket_path()
            .file_name()
            .and_then(|s| s.to_str())
            .unwrap_or("");
        assert!(fname.contains("alpha"), "got: {fname}");
    }

    #[test]
    fn index_request_serialises_as_jsonrpc_2_0() {
        let req = RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            method: METHOD_INDEX,
            params: IndexParams {
                doc_id: "doc-1",
                text: "hello world",
            },
            id: 1,
        };
        let s = serde_json::to_string(&req).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
        assert!(s.contains("\"method\":\"index\""));
        assert!(s.contains("\"doc_id\":\"doc-1\""));
        assert!(s.contains("\"text\":\"hello world\""));
    }

    #[test]
    fn search_request_serialises_with_query_and_top_k() {
        let req = RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            method: METHOD_SEARCH,
            params: SearchParams {
                query: "cargo test",
                top_k: 5,
            },
            id: 1,
        };
        let s = serde_json::to_string(&req).unwrap();
        assert!(s.contains("\"method\":\"search\""));
        assert!(s.contains("\"query\":\"cargo test\""));
        assert!(s.contains("\"top_k\":5"));
    }

    #[test]
    fn delete_request_serialises_with_doc_id() {
        let req = RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            method: METHOD_DELETE,
            params: DeleteParams { doc_id: "x" },
            id: 1,
        };
        let s = serde_json::to_string(&req).unwrap();
        assert!(s.contains("\"method\":\"delete\""));
        assert!(s.contains("\"doc_id\":\"x\""));
    }

    #[test]
    fn bm25_hit_round_trips() {
        let h = BM25Hit {
            doc_id: "drawer-1".into(),
            score: 0.42,
        };
        let s = serde_json::to_string(&h).unwrap();
        let back: BM25Hit = serde_json::from_str(&s).unwrap();
        assert_eq!(back.doc_id, "drawer-1");
        assert!((back.score - 0.42).abs() < 1e-6);
    }

    #[test]
    fn locate_bm25_daemon_binary_prefers_env_override() {
        let exe = std::env::current_exe().expect("current_exe");
        // Assert after the guard is dropped so a failure cannot poison the lock.
        let result = {
            let _env = DaemonBinEnvGuard::set(&exe);
            locate_bm25_daemon_binary()
        };
        assert_eq!(result.expect("must find via env var"), exe);
    }

    #[test]
    fn locate_bm25_daemon_binary_env_override_nonexistent_errors() {
        let result = {
            let _env = DaemonBinEnvGuard::set("/nonexistent/trusty-bm25-daemon");
            locate_bm25_daemon_binary()
        };
        assert!(
            result.is_err(),
            "expected error for non-existent path, got: {result:?}"
        );
    }
}