use std::path::PathBuf;
use std::process::Stdio;
use std::sync::{Arc, OnceLock};
use anyhow::{Context, Result, anyhow, bail};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout};
use tokio::sync::Mutex;
/// Process-wide shared client, populated lazily (and only on success) by
/// [`HebbClient::global`]; once set it is returned to every later caller.
static GLOBAL: OnceLock<Arc<Mutex<HebbClient>>> = OnceLock::new();
/// One candidate voice returned by the hebb `voice_match` tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoiceMatch {
    /// Identifier of the stored voice (the entry's `voice_id` string).
    pub voice_id: String,
    /// Display name, present only when the entry carried a string `name`.
    pub name: Option<String>,
    /// Similarity score reported by hebb, narrowed from JSON `f64` to `f32`.
    pub similarity: f32,
}
/// Client for a `hebb-mcp` subprocess, speaking newline-delimited JSON-RPC
/// (the MCP protocol) over the child's stdin/stdout.
pub struct HebbClient {
    /// Handle to the spawned subprocess; never read, only held alongside the pipes.
    _child: Child,
    /// Pipe into the subprocess; each outgoing JSON message is written as one line.
    stdin: ChildStdin,
    /// Buffered pipe out of the subprocess; incoming messages are read line by line.
    stdout: BufReader<ChildStdout>,
    /// JSON-RPC id for the next outgoing request; incremented after each use.
    next_id: u64,
}
impl HebbClient {
    /// Spawns `hebb-mcp` with piped stdin/stdout and performs the MCP handshake.
    ///
    /// # Errors
    ///
    /// Fails if the binary cannot be located or spawned, or if the handshake
    /// fails. The child is configured with `kill_on_drop`, so a client that is
    /// discarded (e.g. after a failed handshake) does not leave the subprocess
    /// running.
    async fn spawn() -> Result<Self> {
        let binary = locate_hebb_binary().ok_or_else(|| anyhow!("hebb-mcp not found in PATH"))?;
        let mut child = tokio::process::Command::new(&binary)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::null())
            // tokio does not kill children on drop by default; without this a
            // dropped client would orphan the subprocess.
            .kill_on_drop(true)
            .spawn()
            .with_context(|| format!("failed to spawn hebb-mcp at {}", binary.display()))?;
        // Both pipes were requested above, so `take` cannot return `None`.
        let stdin = child.stdin.take().expect("stdin piped");
        let stdout = BufReader::new(child.stdout.take().expect("stdout piped"));
        let mut client = Self {
            _child: child,
            stdin,
            stdout,
            next_id: 0,
        };
        client.handshake().await?;
        Ok(client)
    }

    /// Sends MCP `initialize`, checks it did not error, then sends the
    /// `notifications/initialized` notification.
    ///
    /// # Errors
    ///
    /// Fails on I/O errors or if the `initialize` response carries an `error`.
    async fn handshake(&mut self) -> Result<()> {
        let init_response = self
            .send_request(
                "initialize",
                json!({
                    "protocolVersion": "2025-11-25",
                    // NOTE(review): `sampling` is advertised but there is no sampling
                    // handler; `sampling/createMessage` requests from the server are
                    // answered with "Method not found" by `send_request`.
                    "capabilities": { "sampling": {} },
                    "clientInfo": { "name": "nab-mcp", "version": env!("CARGO_PKG_VERSION") }
                }),
            )
            .await
            .context("MCP initialize handshake failed")?;
        if init_response.get("error").is_some() {
            bail!("hebb-mcp initialize returned error: {init_response}");
        }
        self.send_notification("notifications/initialized", json!({}))
            .await
            .context("hebb-mcp initialized notification failed")?;
        Ok(())
    }

    /// Returns the process-wide shared client, spawning it on first use.
    ///
    /// First-time initialization is serialized so concurrent callers share a
    /// single subprocess instead of each spawning one and racing on the set.
    ///
    /// # Errors
    ///
    /// Propagates spawn/handshake failures. A failed initialization is not
    /// cached, so a later call retries.
    pub async fn global() -> Result<Arc<Mutex<Self>>> {
        // Guards the slow path only; the fast path below never touches it.
        static INIT_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

        if let Some(client) = GLOBAL.get() {
            return Ok(Arc::clone(client));
        }
        let _init_guard = INIT_LOCK.get_or_init(|| Mutex::new(())).lock().await;
        // Another caller may have finished initialization while we waited.
        if let Some(client) = GLOBAL.get() {
            return Ok(Arc::clone(client));
        }
        let client = Self::spawn().await?;
        let arc = Arc::new(Mutex::new(client));
        let _ = GLOBAL.set(Arc::clone(&arc));
        Ok(Arc::clone(GLOBAL.get().expect("just set")))
    }

    /// Reports whether a `hebb-mcp` binary can be found (without spawning it).
    pub fn is_available() -> bool {
        locate_hebb_binary().is_some()
    }

    /// Sends a JSON-RPC request and waits for the response with the same id.
    ///
    /// Messages carrying a `method` are server-originated (requests or
    /// notifications). Their ids live in the server's own id space, so they are
    /// never treated as our response; requests among them are answered so the
    /// server does not block waiting on us. Other non-matching messages are
    /// ignored.
    ///
    /// # Errors
    ///
    /// Fails on write/read errors, invalid JSON, or a closed stdout.
    async fn send_request(&mut self, method: &str, params: Value) -> Result<Value> {
        let id = self.next_id;
        self.next_id += 1;
        let msg = json!({
            "jsonrpc": "2.0",
            "id": id,
            "method": method,
            "params": params,
        });
        self.write_line(&msg).await?;
        let expected_id = json!(id);
        loop {
            let message = self.read_response_line().await?;
            if message.get("method").is_some() {
                self.answer_server_message(&message).await?;
                continue;
            }
            if message.get("id") == Some(&expected_id) {
                return Ok(message);
            }
        }
    }

    /// Replies to a server-initiated request: `ping` gets an empty result,
    /// anything else gets JSON-RPC error -32601. Notifications (no id) are
    /// acknowledged by doing nothing.
    async fn answer_server_message(&mut self, message: &Value) -> Result<()> {
        let Some(id) = message.get("id").filter(|id| !id.is_null()) else {
            return Ok(());
        };
        let reply = if message.get("method").and_then(Value::as_str) == Some("ping") {
            json!({ "jsonrpc": "2.0", "id": id, "result": {} })
        } else {
            json!({
                "jsonrpc": "2.0",
                "id": id,
                "error": { "code": -32601, "message": "Method not found" },
            })
        };
        self.write_line(&reply).await
    }

    /// Sends a JSON-RPC notification (no id, no response expected).
    async fn send_notification(&mut self, method: &str, params: Value) -> Result<()> {
        let msg = json!({
            "jsonrpc": "2.0",
            "method": method,
            "params": params,
        });
        self.write_line(&msg).await
    }

    /// Serializes `value` compactly, writes it plus a trailing newline, and flushes.
    async fn write_line(&mut self, value: &Value) -> Result<()> {
        let mut line =
            serde_json::to_string(value).context("failed to serialize JSON-RPC message")?;
        // Compact serde_json output escapes embedded newlines, so one message is
        // exactly one line.
        line.push('\n');
        self.stdin
            .write_all(line.as_bytes())
            .await
            .context("write to hebb-mcp stdin failed")?;
        self.stdin.flush().await.context("flush hebb-mcp stdin")?;
        Ok(())
    }

    /// Reads the next non-blank line from the subprocess and parses it as JSON.
    ///
    /// # Errors
    ///
    /// Fails on read errors, EOF (subprocess closed stdout), or invalid JSON.
    async fn read_response_line(&mut self) -> Result<Value> {
        let mut line = String::new();
        loop {
            line.clear();
            let n = self
                .stdout
                .read_line(&mut line)
                .await
                .context("read from hebb-mcp stdout failed")?;
            if n == 0 {
                bail!("hebb-mcp subprocess closed stdout unexpectedly");
            }
            let trimmed = line.trim();
            if !trimmed.is_empty() {
                return serde_json::from_str(trimmed)
                    .with_context(|| format!("hebb-mcp response is not valid JSON: {trimmed}"));
            }
        }
    }

    /// Invokes an MCP tool and returns its structured result.
    ///
    /// Prefers `result.structuredContent`; otherwise parses the first `text`
    /// content item as JSON.
    ///
    /// # Errors
    ///
    /// Fails on transport errors, a JSON-RPC `error`, a missing `result`, a
    /// result with `isError: true` (message taken from its first text item),
    /// or when no usable content is present.
    pub async fn call_tool(&mut self, name: &str, arguments: Value) -> Result<Value> {
        let response = self
            .send_request(
                "tools/call",
                json!({ "name": name, "arguments": arguments }),
            )
            .await
            .with_context(|| format!("tools/call '{name}' failed"))?;
        if let Some(err) = response.get("error") {
            bail!("hebb tool '{name}' RPC error: {err}");
        }
        let result = response
            .get("result")
            .ok_or_else(|| anyhow!("hebb response missing 'result' field"))?;
        if result
            .get("isError")
            .and_then(Value::as_bool)
            .unwrap_or(false)
        {
            let msg = extract_first_text_content(result)
                .unwrap_or_else(|| format!("tool '{name}' returned isError=true"));
            bail!("{msg}");
        }
        if let Some(structured) = result.get("structuredContent") {
            return Ok(structured.clone());
        }
        let text = extract_first_text_content(result)
            .ok_or_else(|| anyhow!("hebb tool '{name}' returned no content"))?;
        serde_json::from_str(&text)
            .with_context(|| format!("hebb tool '{name}' text content is not valid JSON: {text}"))
    }

    /// Converts an embedding to a JSON array, rejecting NaN/infinite entries
    /// (serde_json would otherwise silently encode them as `null`).
    fn embedding_to_json(embedding: &[f32]) -> Result<Value> {
        if let Some(index) = embedding.iter().position(|f| !f.is_finite()) {
            bail!("embedding contains a non-finite value at index {index}");
        }
        Ok(Value::Array(embedding.iter().map(|&f| json!(f)).collect()))
    }

    /// Looks up stored voices similar to `embedding` via the `voice_match` tool.
    ///
    /// # Errors
    ///
    /// Fails if `embedding` contains non-finite values, the tool call fails,
    /// or the result cannot be parsed into [`VoiceMatch`] entries.
    pub async fn voice_match(
        &mut self,
        embedding: &[f32],
        threshold: f32,
        limit: u32,
    ) -> Result<Vec<VoiceMatch>> {
        let embedding_values = Self::embedding_to_json(embedding)?;
        let result = self
            .call_tool(
                "voice_match",
                json!({
                    "embedding": embedding_values,
                    "threshold": threshold,
                    "limit": limit,
                }),
            )
            .await?;
        parse_voice_matches(&result)
    }

    /// Stores a voice embedding via the `voice_remember` tool and returns its id.
    ///
    /// # Errors
    ///
    /// Fails if `embedding` contains non-finite values, the tool call fails,
    /// or the response lacks a string `voice_id`.
    pub async fn voice_remember(
        &mut self,
        embedding: &[f32],
        source: &str,
        name: Option<&str>,
    ) -> Result<String> {
        let embedding_values = Self::embedding_to_json(embedding)?;
        let mut args = json!({
            "embedding": embedding_values,
            "source": source,
        });
        if let Some(n) = name {
            args["name"] = json!(n);
        }
        let result = self.call_tool("voice_remember", args).await?;
        result
            .get("voice_id")
            .and_then(Value::as_str)
            .map(String::from)
            .ok_or_else(|| anyhow!("voice_remember response missing 'voice_id'"))
    }

    /// Writes `value` under `namespace`/`key` via the `kv_set` tool, optionally
    /// attaching `content_text`. The tool's result payload is discarded.
    ///
    /// # Errors
    ///
    /// Fails if the tool call fails (see [`HebbClient::call_tool`]).
    pub async fn kv_set(
        &mut self,
        namespace: &str,
        key: &str,
        value: Value,
        content_text: Option<&str>,
    ) -> Result<()> {
        let mut args = json!({
            "namespace": namespace,
            "key": key,
            "value": value,
        });
        if let Some(text) = content_text {
            args["content_text"] = json!(text);
        }
        self.call_tool("kv_set", args).await?;
        Ok(())
    }

    /// Reads `namespace`/`key` via the `kv_get` tool.
    ///
    /// Returns `Ok(None)` when the call fails with an error that
    /// `is_not_found_error` classifies as a missing key.
    ///
    /// # Errors
    ///
    /// Any other tool-call failure is propagated.
    pub async fn kv_get(&mut self, namespace: &str, key: &str) -> Result<Option<Value>> {
        let result = self
            .call_tool("kv_get", json!({ "namespace": namespace, "key": key }))
            .await;
        match result {
            Ok(v) => Ok(Some(v)),
            Err(e) if is_not_found_error(&e) => Ok(None),
            Err(e) => Err(e),
        }
    }
}
/// Finds the `hebb-mcp` executable.
///
/// Searches `PATH` first, then falls back to the managed install location
/// `<data_local_dir>/hebb/bin/hebb-mcp`. Returns `None` if neither is found
/// or the platform has no local data directory.
fn locate_hebb_binary() -> Option<PathBuf> {
    match which::which("hebb-mcp") {
        Ok(on_path) => Some(on_path),
        Err(_) => {
            let candidate = dirs::data_local_dir()?.join("hebb/bin/hebb-mcp");
            if candidate.exists() {
                Some(candidate)
            } else {
                None
            }
        }
    }
}
/// Returns the `text` of the first `{"type": "text"}` item in `result.content`.
///
/// Yields `None` if `content` is absent or not an array, if no text-typed item
/// exists, or if that first text item has no string `text` field.
fn extract_first_text_content(result: &Value) -> Option<String> {
    let items = result.get("content")?.as_array()?;
    let first_text_item = items
        .iter()
        .find(|item| matches!(item.get("type").and_then(Value::as_str), Some("text")))?;
    first_text_item.get("text")?.as_str().map(str::to_owned)
}
fn parse_voice_matches(result: &Value) -> Result<Vec<VoiceMatch>> {
let matches = result
.get("matches")
.and_then(Value::as_array)
.ok_or_else(|| anyhow!("voice_match result missing 'matches' array"))?;
matches
.iter()
.map(|m| {
Ok(VoiceMatch {
voice_id: m
.get("voice_id")
.and_then(Value::as_str)
.ok_or_else(|| anyhow!("voice_match entry missing 'voice_id'"))?
.to_string(),
name: m.get("name").and_then(Value::as_str).map(String::from),
similarity: parse_similarity_f32(m)?,
})
})
.collect()
}
#[allow(clippy::cast_precision_loss)] fn parse_similarity_f32(match_value: &Value) -> Result<f32> {
let similarity = match_value
.get("similarity")
.and_then(Value::as_f64)
.ok_or_else(|| anyhow!("voice_match entry missing 'similarity'"))?;
if !similarity.is_finite()
|| similarity < f64::from(f32::MIN)
|| similarity > f64::from(f32::MAX)
{
bail!("voice_match similarity out of f32 range");
}
#[allow(clippy::cast_possible_truncation)]
let similarity = similarity as f32;
Ok(similarity)
}
/// Heuristically decides whether `err` means "the requested key does not exist".
///
/// Matches known phrases in the lowercased top-level error message. A JSON-RPC
/// "Method not found" (code -32601) is explicitly excluded: it means the server
/// does not implement the call at all, which must surface as an error rather
/// than be reported as an absent key.
fn is_not_found_error(err: &anyhow::Error) -> bool {
    const NOT_FOUND_PHRASES: [&str; 3] = ["not found", "no such key", "key_not_found"];
    let msg = err.to_string().to_ascii_lowercase();
    if msg.contains("method not found") {
        return false;
    }
    NOT_FOUND_PHRASES.iter().any(|phrase| msg.contains(phrase))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `is_available` is synchronous and only probes for the binary, so it
    /// needs no async runtime and must not panic whether or not hebb-mcp is
    /// installed.
    #[test]
    fn is_available_returns_bool_without_panic() {
        let _available: bool = HebbClient::is_available();
    }

    #[test]
    fn parse_voice_matches_well_formed() {
        let result = json!({
            "matches": [
                { "voice_id": "v_abc123", "name": "Alice", "similarity": 0.92 },
                { "voice_id": "v_def456", "name": null, "similarity": 0.75 }
            ]
        });
        let matches = parse_voice_matches(&result).expect("should parse");
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].voice_id, "v_abc123");
        assert_eq!(matches[0].name.as_deref(), Some("Alice"));
        assert!((matches[0].similarity - 0.92).abs() < 1e-4);
        assert_eq!(matches[1].voice_id, "v_def456");
        assert!(matches[1].name.is_none());
    }

    #[test]
    fn parse_voice_matches_missing_key_returns_err() {
        let result = json!({ "something_else": [] });
        let err = parse_voice_matches(&result);
        assert!(err.is_err(), "expected Err, got Ok");
    }

    #[test]
    fn parse_voice_matches_empty_array() {
        let result = json!({ "matches": [] });
        let matches = parse_voice_matches(&result).expect("should parse");
        assert!(matches.is_empty());
    }

    #[test]
    fn parse_voice_matches_missing_similarity_returns_err() {
        let result = json!({ "matches": [{ "voice_id": "v_abc123" }] });
        assert!(parse_voice_matches(&result).is_err());
    }

    #[test]
    fn parse_voice_matches_missing_voice_id_returns_err() {
        let result = json!({ "matches": [{ "similarity": 0.5 }] });
        assert!(parse_voice_matches(&result).is_err());
    }

    #[test]
    fn extract_first_text_content_picks_text_type() {
        let result = json!({
            "content": [
                { "type": "image", "data": "base64..." },
                { "type": "text", "text": "{\"voice_id\": \"v_x\"}" },
            ]
        });
        let text = extract_first_text_content(&result);
        assert_eq!(text.as_deref(), Some("{\"voice_id\": \"v_x\"}"));
    }

    #[test]
    fn extract_first_text_content_none_when_no_text() {
        let result = json!({ "content": [{ "type": "image", "data": "..." }] });
        let text = extract_first_text_content(&result);
        assert!(text.is_none());
    }

    #[test]
    fn extract_first_text_content_none_when_content_missing() {
        assert!(extract_first_text_content(&json!({})).is_none());
    }

    #[test]
    fn is_not_found_error_matches_known_phrases() {
        let cases = [
            (anyhow!("key not found in namespace"), true),
            (anyhow!("no such key: urls/abc"), true),
            (anyhow!("key_not_found"), true),
            (anyhow!("connection refused"), false),
            (anyhow!("internal server error"), false),
        ];
        for (err, expected) in cases {
            assert_eq!(is_not_found_error(&err), expected, "mismatch for: {err}");
        }
    }
}