use std::path::{Path, PathBuf};
use std::time::Duration;
use anyhow::{Context, Result};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::UnixStream;
use tokio::time::timeout;
use super::protocol::{HealthInfo, Operation, Request, Response, PROTOCOL_VERSION};
use super::get_socket_path;
/// Client for the daemon's Unix-socket, line-delimited JSON request/response interface.
///
/// Each request opens a fresh connection to `socket_path`.
#[derive(Debug, Clone)]
pub struct DaemonClient {
    /// Path of the daemon's Unix domain socket.
    socket_path: PathBuf,
    /// Timeout applied while waiting for the daemon's response.
    timeout: Duration,
}
impl DaemonClient {
    /// Upper bound on how long `connect` waits for the socket connection.
    const CONNECT_TIMEOUT: Duration = Duration::from_millis(500);

    /// Request/response timeout used by [`DaemonClient::new`].
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Creates a client whose socket path is derived from `beads_dir` via `get_socket_path`.
    ///
    /// The request/response timeout defaults to 30 seconds; override it with
    /// [`with_timeout`](Self::with_timeout).
    pub fn new(beads_dir: impl AsRef<Path>) -> Self {
        Self {
            socket_path: get_socket_path(beads_dir.as_ref()),
            timeout: Self::DEFAULT_TIMEOUT,
        }
    }

    /// Returns the client with its request/response timeout replaced by `timeout`.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Returns `true` if a `Health` request to the daemon succeeds.
    pub async fn is_available(&self) -> bool {
        self.health().await.is_ok()
    }

    /// Opens a connection to the daemon socket, giving up after `CONNECT_TIMEOUT`.
    async fn connect(&self) -> Result<UnixStream> {
        timeout(Self::CONNECT_TIMEOUT, UnixStream::connect(&self.socket_path))
            .await
            .context("Connection timeout")?
            .context("Failed to connect to daemon")
    }

    /// Sends `request` as one newline-terminated JSON line and parses the single
    /// JSON line the daemon writes back as a [`Response`].
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be serialized, the connection cannot be
    /// established, the write and read do not finish within the client's timeout,
    /// the daemon closes the connection without replying, or the reply is not
    /// valid JSON for `Response`.
    pub async fn send(&self, request: &Request) -> Result<Response> {
        let mut payload = serde_json::to_string(request).context("Failed to serialize request")?;
        payload.push('\n');

        let stream = self.connect().await?;
        // Bound the whole exchange, not just the read: a daemon that stops
        // draining its socket would otherwise let the write block indefinitely.
        let line = timeout(self.timeout, Self::exchange(stream, &payload))
            .await
            .context("Response timeout")??;

        serde_json::from_str(&line).context("Failed to parse response")
    }

    /// Writes `payload` to `stream` and returns the first line the daemon sends back.
    async fn exchange(stream: UnixStream, payload: &str) -> Result<String> {
        let (reader, mut writer) = stream.into_split();
        writer
            .write_all(payload.as_bytes())
            .await
            .context("Failed to send request")?;
        writer.flush().await.context("Failed to send request")?;

        let mut reader = BufReader::new(reader);
        let mut line = String::new();
        let read = reader
            .read_line(&mut line)
            .await
            .context("Failed to read response")?;
        // `read_line` returns 0 at EOF; without this check an empty string would
        // reach the JSON parser and surface as a misleading parse error.
        if read == 0 {
            anyhow::bail!("Daemon closed the connection without sending a response");
        }
        Ok(line)
    }

    /// Requests the daemon's health information.
    ///
    /// # Errors
    ///
    /// Fails if the exchange fails, if the daemon reports failure (its error
    /// message, or "Unknown error" if none was given), or if the response data
    /// cannot be parsed as [`HealthInfo`].
    pub async fn health(&self) -> Result<HealthInfo> {
        let request = Request::new(Operation::Health, "client");
        let response = self.send(&request).await?;
        if response.success {
            response.parse_data().context("Failed to parse health info")
        } else {
            anyhow::bail!(response.error.unwrap_or_else(|| "Unknown error".to_string()))
        }
    }

    /// Sends a `Ping` request.
    ///
    /// # Errors
    ///
    /// Fails if the exchange fails or the daemon reports failure (its error
    /// message, or "Ping failed" if none was given).
    pub async fn ping(&self) -> Result<()> {
        let request = Request::new(Operation::Ping, "client");
        let response = self.send(&request).await?;
        if response.success {
            Ok(())
        } else {
            anyhow::bail!(response.error.unwrap_or_else(|| "Ping failed".to_string()))
        }
    }

    /// Asks the daemon to shut down.
    ///
    /// Best-effort: any error from the exchange is discarded and `Ok(())` is always
    /// returned. NOTE(review): presumably because the daemon may exit before
    /// replying — confirm against the server implementation.
    pub async fn shutdown(&self) -> Result<()> {
        let request = Request::new(Operation::Shutdown, "client");
        let _ = self.send(&request).await;
        Ok(())
    }

    /// Returns whether the daemon's reported protocol version is compatible with
    /// this client's `PROTOCOL_VERSION` (same major version).
    ///
    /// # Errors
    ///
    /// Propagates any error from [`health`](Self::health).
    pub async fn check_compatibility(&self) -> Result<bool> {
        let health = self.health().await?;
        Ok(is_compatible(&health.protocol_version, PROTOCOL_VERSION))
    }
}
/// Reports whether two protocol versions share a major version, meaning the
/// numeric component before the first `.`.
///
/// The check is lenient: if either major component does not parse as a `u32`,
/// the versions are considered compatible.
fn is_compatible(server_version: &str, client_version: &str) -> bool {
    /// Extracts the leading `.`-separated component as a `u32`, if it parses.
    fn major(version: &str) -> Option<u32> {
        version.split('.').next()?.parse().ok()
    }

    major(server_version)
        .zip(major(client_version))
        .map_or(true, |(server, client)| server == client)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_version_compatibility() {
        assert!(is_compatible("1.0.0", "1.0.0"));
        assert!(is_compatible("1.0.0", "1.1.0"));
        assert!(is_compatible("1.2.3", "1.0.0"));
        assert!(!is_compatible("2.0.0", "1.0.0"));
        assert!(!is_compatible("1.0.0", "2.0.0"));
    }

    /// The fallback branch: when either major component fails to parse,
    /// `is_compatible` is lenient and returns `true`.
    #[test]
    fn test_unparseable_versions_are_treated_as_compatible() {
        assert!(is_compatible("garbage", "1.0.0"));
        assert!(is_compatible("1.0.0", ""));
        assert!(is_compatible("v2.0.0", "1.0.0"));
    }

    /// Versions without a `.` are compared on the whole string as the major part.
    #[test]
    fn test_major_only_versions() {
        assert!(is_compatible("3", "3.9.1"));
        assert!(!is_compatible("3", "4"));
    }
}