1use std::path::{Path, PathBuf};
8
9use lilo_rm_core::{
10 ErrorCode, ProtocolError, RuntimeResponse, RuntimeRpc, read_json_line, write_json_line,
11};
12use thiserror::Error;
13use tokio::io::BufReader;
14use tokio::net::UnixStream;
15
/// Handle for sending RPCs to the `rtmd` daemon over a Unix domain socket.
///
/// Only the socket path is stored; every call to [`RuntimeClient::request`]
/// opens its own connection.
#[derive(Debug, Clone)]
pub struct RuntimeClient {
    /// Path of the daemon's Unix domain socket.
    socket_path: PathBuf,
}
20
21impl RuntimeClient {
22 pub fn new(socket_path: impl Into<PathBuf>) -> Self {
23 Self {
24 socket_path: socket_path.into(),
25 }
26 }
27
28 pub fn socket_path(&self) -> &Path {
29 &self.socket_path
30 }
31
32 pub async fn request(&self, rpc: RuntimeRpc) -> Result<RuntimeResponse, ClientError> {
33 request(&self.socket_path, rpc).await
34 }
35}
36
37#[derive(Debug, Error)]
38pub enum ClientError {
39 #[error("rtmd unavailable at {socket_path}: {source}")]
40 DaemonUnavailable {
41 socket_path: PathBuf,
42 #[source]
43 source: std::io::Error,
44 },
45 #[error("rtmd protocol error: {source}")]
46 Protocol {
47 #[from]
48 source: ProtocolError,
49 },
50 #[error("rtmd returned {code}: {message}")]
51 ErrorResponse { code: ErrorCode, message: String },
52}
53
54impl ClientError {
55 pub const fn code(&self) -> ErrorCode {
56 match self {
57 Self::DaemonUnavailable { .. } => ErrorCode::RuntimeUnavailable,
58 Self::Protocol { .. } => ErrorCode::ProtocolMismatch,
59 Self::ErrorResponse { code, .. } => *code,
60 }
61 }
62}
63
64pub async fn request(
65 socket_path: impl AsRef<Path>,
66 rpc: RuntimeRpc,
67) -> Result<RuntimeResponse, ClientError> {
68 let socket_path = socket_path.as_ref();
69 let stream = UnixStream::connect(socket_path).await.map_err(|source| {
70 ClientError::DaemonUnavailable {
71 socket_path: socket_path.to_path_buf(),
72 source,
73 }
74 })?;
75 request_on_stream(stream, rpc).await
76}
77
78async fn request_on_stream(
79 stream: UnixStream,
80 rpc: RuntimeRpc,
81) -> Result<RuntimeResponse, ClientError> {
82 let (read_half, mut write_half) = stream.into_split();
83 write_json_line(&mut write_half, &rpc).await?;
84
85 let mut reader = BufReader::new(read_half);
86 match read_json_line(&mut reader).await? {
87 RuntimeResponse::Error { code, message } => {
88 Err(ClientError::ErrorResponse { code, message })
89 }
90 response => Ok(response),
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::UnixListener;

    /// Returns a fresh temp dir and a socket path inside it. The `TempDir`
    /// must be kept alive for the duration of the test.
    fn temp_socket_path() -> (tempfile::TempDir, PathBuf) {
        let tempdir = tempfile::tempdir().expect("tempdir");
        let socket_path = tempdir.path().join("rtmd.sock");
        (tempdir, socket_path)
    }

    #[tokio::test]
    async fn missing_socket_reports_daemon_unavailable() {
        let (_tempdir, socket_path) = temp_socket_path();

        let error = request(&socket_path, RuntimeRpc::Version)
            .await
            .expect_err("missing socket should fail");

        // A failed connect must map to the "runtime unavailable" code.
        assert_eq!(error.code(), ErrorCode::RuntimeUnavailable);
        match error {
            ClientError::DaemonUnavailable {
                socket_path: actual,
                ..
            } => assert_eq!(actual, socket_path),
            other => panic!("unexpected client error: {other:?}"),
        }
    }

    #[tokio::test]
    async fn daemon_error_response_preserves_code() {
        let (_tempdir, socket_path) = temp_socket_path();
        let listener = UnixListener::bind(&socket_path).expect("bind test socket");
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("accept client");
            let (read_half, mut write_half) = stream.into_split();
            let mut reader = BufReader::new(read_half);
            let rpc: RuntimeRpc = read_json_line(&mut reader).await.expect("read rpc");
            assert_eq!(rpc, RuntimeRpc::Version);

            write_json_line(
                &mut write_half,
                &RuntimeResponse::error(ErrorCode::SessionNotFound, "missing session"),
            )
            .await
            .expect("write response");
        });

        let error = RuntimeClient::new(&socket_path)
            .request(RuntimeRpc::Version)
            .await
            .expect_err("daemon error response should fail");

        // `code()` must pass through the daemon-supplied code unchanged.
        assert_eq!(error.code(), ErrorCode::SessionNotFound);
        match error {
            ClientError::ErrorResponse { code, message } => {
                assert_eq!(code, ErrorCode::SessionNotFound);
                assert_eq!(message, "missing session");
            }
            other => panic!("unexpected client error: {other:?}"),
        }
        server.await.expect("server task");
    }

    #[tokio::test]
    async fn daemon_hanging_up_without_reply_reports_protocol_error() {
        let (_tempdir, socket_path) = temp_socket_path();
        let listener = UnixListener::bind(&socket_path).expect("bind test socket");
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("accept client");
            let (read_half, write_half) = stream.into_split();
            let mut reader = BufReader::new(read_half);
            let rpc: RuntimeRpc = read_json_line(&mut reader).await.expect("read rpc");
            assert_eq!(rpc, RuntimeRpc::Version);

            // Close the connection without writing any response.
            drop(write_half);
            drop(reader);
        });

        let error = request(&socket_path, RuntimeRpc::Version)
            .await
            .expect_err("closed connection without reply should fail");

        assert_eq!(error.code(), ErrorCode::ProtocolMismatch);
        assert!(
            matches!(error, ClientError::Protocol { .. }),
            "unexpected client error: {error:?}"
        );
        server.await.expect("server task");
    }
}