1use std::path::{Path, PathBuf};
7use std::time::Duration;
8
9use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
10use tokio::net::UnixStream;
11
12use crate::protocol::{Request, Response, ResponseOutcome, WireError, WireErrorKind, encode_line};
13
/// Upper bound on how long establishing a connection to the management socket may take.
pub const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Upper bound on how long a one-shot call waits for its response frame.
pub const ONESHOT_TIMEOUT: Duration = Duration::from_millis(30_000);
26
/// Client for the line-delimited JSON management protocol spoken over a Unix domain socket.
///
/// Only the socket path is stored; no connection is held between calls, so the client is
/// cheap to clone and to print for diagnostics.
#[derive(Debug, Clone)]
pub struct UnixMgmtClient {
    /// Filesystem path of the management socket to connect to.
    socket_path: PathBuf,
}
33
34impl UnixMgmtClient {
35 pub fn new(socket_path: impl AsRef<Path>) -> Self {
36 Self { socket_path: socket_path.as_ref().to_path_buf() }
37 }
38
39 pub async fn call<A, R>(&self, verb: &str, args: &A) -> Result<R, MgmtClientError>
52 where
53 A: serde::Serialize,
54 R: for<'de> serde::Deserialize<'de>,
55 {
56 let stream = tokio::time::timeout(CONNECT_TIMEOUT, UnixStream::connect(&self.socket_path))
57 .await
58 .map_err(|_| MgmtClientError::Timeout("connect"))??;
59 let (read, mut write) = stream.into_split();
60
61 let req = Request {
62 id: 1,
63 verb: verb.to_string(),
64 args: serde_json::to_value(args).map_err(MgmtClientError::Encode)?,
65 };
66 let bytes = encode_line(&req).map_err(MgmtClientError::Encode)?;
67 write.write_all(&bytes).await?;
68 write.shutdown().await.ok();
72
73 let mut lines = BufReader::new(read).lines();
74 let line = tokio::time::timeout(ONESHOT_TIMEOUT, lines.next_line())
75 .await
76 .map_err(|_| MgmtClientError::Timeout("read"))??
77 .ok_or(MgmtClientError::EmptyResponse)?;
78 let response: Response = serde_json::from_str(&line).map_err(MgmtClientError::Decode)?;
79 match response.outcome {
80 ResponseOutcome::Result { result } => {
81 serde_json::from_value(result).map_err(MgmtClientError::Decode)
82 }
83 ResponseOutcome::Error { error } => Err(MgmtClientError::Server(error)),
84 ResponseOutcome::Event { .. } | ResponseOutcome::End { .. } => {
85 Err(MgmtClientError::Server(WireError::new(
89 WireErrorKind::Internal,
90 "received streaming frame on one-shot call",
91 )))
92 }
93 }
94 }
95
96 pub async fn call_stream<A, F>(
109 &self,
110 verb: &str,
111 args: &A,
112 mut on_event: F,
113 ) -> Result<(), MgmtClientError>
114 where
115 A: serde::Serialize,
116 F: FnMut(serde_json::Value),
117 {
118 let stream = tokio::time::timeout(CONNECT_TIMEOUT, UnixStream::connect(&self.socket_path))
119 .await
120 .map_err(|_| MgmtClientError::Timeout("connect"))??;
121 let (read, mut write) = stream.into_split();
122
123 let req = Request {
124 id: 1,
125 verb: verb.to_string(),
126 args: serde_json::to_value(args).map_err(MgmtClientError::Encode)?,
127 };
128 let bytes = encode_line(&req).map_err(MgmtClientError::Encode)?;
129 write.write_all(&bytes).await?;
130 let mut lines = BufReader::new(read).lines();
133 while let Some(line) = lines.next_line().await? {
134 if line.is_empty() {
135 continue;
136 }
137 let response: Response = serde_json::from_str(&line).map_err(MgmtClientError::Decode)?;
138 match response.outcome {
139 ResponseOutcome::Event { event } => on_event(event),
140 ResponseOutcome::End { .. } => return Ok(()),
141 ResponseOutcome::Error { error } => return Err(MgmtClientError::Server(error)),
142 ResponseOutcome::Result { .. } => {
143 return Err(MgmtClientError::Server(WireError::new(
144 WireErrorKind::Internal,
145 "received one-shot Result on streaming call",
146 )));
147 }
148 }
149 }
150 Ok(())
153 }
154}
155
156#[derive(Debug, thiserror::Error)]
157pub enum MgmtClientError {
158 #[error("io: {0}")]
159 Io(#[from] std::io::Error),
160 #[error("encode: {0}")]
161 Encode(serde_json::Error),
162 #[error("decode: {0}")]
163 Decode(serde_json::Error),
164 #[error("empty response")]
165 EmptyResponse,
166 #[error("server: {kind:?} {message}", kind = .0.kind, message = .0.message)]
167 Server(WireError),
168 #[error("http {status}: {body}")]
172 Http { status: u16, body: String },
173 #[error("timeout: {0}")]
178 Timeout(&'static str),
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::{DispatchOutcome, Handler, handle_conn};
    use async_trait::async_trait;
    use serde::Deserialize;
    use std::sync::Arc;

    #[derive(Debug, Deserialize)]
    struct PingResult {
        pong: bool,
        version: String,
    }

    #[derive(serde::Serialize)]
    struct NoArgs {}

    /// One-shot handler: `ping` succeeds, `bad_shape` returns a payload that does not match
    /// `PingResult`, and any other verb is rejected as `UnknownVerb`.
    struct StubHandler;

    #[async_trait]
    impl Handler for StubHandler {
        async fn dispatch(&self, req: Request) -> DispatchOutcome {
            let result: Result<serde_json::Value, crate::protocol::WireError> =
                match req.verb.as_str() {
                    "ping" => Ok(serde_json::json!({ "pong": true, "version": "test" })),
                    "bad_shape" => Ok(serde_json::json!({ "unrelated": 1 })),
                    _ => Err(WireError::new(
                        WireErrorKind::UnknownVerb,
                        format!("unknown {}", req.verb),
                    )),
                };
            DispatchOutcome::OneShot(result)
        }
    }

    /// Binds a listener on a fresh socket path inside a new temp directory.
    ///
    /// Keep the returned `TempDir` alive for the whole test: dropping it deletes the
    /// directory holding the socket file.
    fn bind_temp_socket() -> (tempfile::TempDir, std::path::PathBuf, tokio::net::UnixListener) {
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("mgmt.sock");
        let listener = tokio::net::UnixListener::bind(&path).expect("bind");
        (dir, path, listener)
    }

    /// Runs `verb`/`args` through the real [`UnixMgmtClient::call`] against a
    /// [`StubHandler`] served by `handle_conn` on a temp Unix socket.
    async fn drive_call<A, R>(verb: &str, args: A) -> Result<R, MgmtClientError>
    where
        A: serde::Serialize,
        R: for<'de> serde::Deserialize<'de>,
    {
        let (_dir, path, listener) = bind_temp_socket();
        let server = tokio::spawn(async move {
            let (stream, _addr) = listener.accept().await.expect("accept");
            let (read, write) = stream.into_split();
            handle_conn(
                read,
                write,
                Arc::new(StubHandler),
                tokio_util::sync::CancellationToken::new(),
            )
            .await
        });

        let result = UnixMgmtClient::new(&path).call(verb, &args).await;
        // Surface a panic in the server task instead of silently discarding it.
        let _ = server.await.expect("server task panicked");
        result
    }

    #[tokio::test]
    async fn client_call_decodes_typed_result() {
        let result: PingResult = drive_call("ping", NoArgs {}).await.expect("ok");
        assert!(result.pong);
        assert_eq!(result.version, "test");
    }

    #[tokio::test]
    async fn client_surfaces_server_error_as_mgmt_client_error_server() {
        let err = drive_call::<_, PingResult>("nope", NoArgs {}).await.expect_err("err");
        match err {
            MgmtClientError::Server(w) => {
                assert_eq!(w.kind, crate::protocol::WireErrorKind::UnknownVerb);
            }
            other => panic!("expected Server, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn client_decode_error_when_result_shape_mismatches() {
        let err = drive_call::<_, PingResult>("bad_shape", NoArgs {}).await.expect_err("err");
        assert!(matches!(err, MgmtClientError::Decode(_)), "unexpected variant: {err:?}");
    }

    /// Streaming handler that emits `{"i": 1}`, `{"i": 2}`, `{"i": 3}` and then ends.
    struct StreamingHandler;

    #[async_trait]
    impl Handler for StreamingHandler {
        async fn dispatch(&self, _req: Request) -> DispatchOutcome {
            use crate::server::EventStream;
            struct ThreeEvents {
                remaining: Vec<serde_json::Value>,
            }
            #[async_trait]
            impl EventStream for ThreeEvents {
                async fn next_event(&mut self) -> Option<serde_json::Value> {
                    self.remaining.pop()
                }
            }
            // Stored in reverse because `pop` takes from the back: events come out 1, 2, 3.
            DispatchOutcome::Stream(Box::new(ThreeEvents {
                remaining: vec![
                    serde_json::json!({ "i": 3 }),
                    serde_json::json!({ "i": 2 }),
                    serde_json::json!({ "i": 1 }),
                ],
            }))
        }
    }

    #[tokio::test]
    async fn client_call_stream_invokes_callback_per_event_until_end() {
        let (_dir, path, listener) = bind_temp_socket();
        let server = tokio::spawn(async move {
            let (stream, _addr) = listener.accept().await.expect("accept");
            let (read, write) = stream.into_split();
            handle_conn(
                read,
                write,
                Arc::new(StreamingHandler),
                tokio_util::sync::CancellationToken::new(),
            )
            .await
        });

        let mut events = Vec::new();
        UnixMgmtClient::new(&path)
            .call_stream("tail", &serde_json::Value::Null, |event| events.push(event))
            .await
            .expect("stream should finish with an End frame");
        // Surface a panic in the server task instead of silently discarding it.
        let _ = server.await.expect("server task panicked");

        assert_eq!(events.len(), 3, "all three events consumed");
        assert_eq!(events[0]["i"], 1, "FIFO order from the stream");
        assert_eq!(events[1]["i"], 2);
        assert_eq!(events[2]["i"], 3);
    }

    #[tokio::test]
    async fn client_io_error_on_missing_socket() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let path = tmp.path().join("not-there.sock");
        let client = UnixMgmtClient::new(&path);
        let err = client
            .call::<_, PingResult>("ping", &NoArgs {})
            .await
            .expect_err("must fail without a server");
        assert!(matches!(err, MgmtClientError::Io(_)), "unexpected variant: {err:?}");
    }
}