1use std::path::{Path, PathBuf};
7
8use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
9use tokio::net::UnixStream;
10
11use crate::protocol::{Request, Response, ResponseOutcome, WireError, WireErrorKind, encode_line};
12
/// Client for the line-delimited JSON management protocol spoken over a Unix
/// domain socket.
///
/// The client only stores the socket path; every request opens its own
/// connection, so the value is cheap to clone and share.
#[derive(Debug, Clone)]
pub struct UnixMgmtClient {
    /// Filesystem path of the server's Unix domain socket.
    socket_path: PathBuf,
}
19
20impl UnixMgmtClient {
21 pub fn new(socket_path: impl AsRef<Path>) -> Self {
22 Self { socket_path: socket_path.as_ref().to_path_buf() }
23 }
24
25 pub async fn call<A, R>(&self, verb: &str, args: &A) -> Result<R, MgmtClientError>
38 where
39 A: serde::Serialize,
40 R: for<'de> serde::Deserialize<'de>,
41 {
42 let stream = UnixStream::connect(&self.socket_path).await?;
43 let (read, mut write) = stream.into_split();
44
45 let req = Request {
46 id: 1,
47 verb: verb.to_string(),
48 args: serde_json::to_value(args).map_err(MgmtClientError::Encode)?,
49 };
50 let bytes = encode_line(&req).map_err(MgmtClientError::Encode)?;
51 write.write_all(&bytes).await?;
52 write.shutdown().await.ok();
56
57 let mut lines = BufReader::new(read).lines();
58 let line = lines.next_line().await?.ok_or(MgmtClientError::EmptyResponse)?;
59 let response: Response = serde_json::from_str(&line).map_err(MgmtClientError::Decode)?;
60 match response.outcome {
61 ResponseOutcome::Result { result } => {
62 serde_json::from_value(result).map_err(MgmtClientError::Decode)
63 }
64 ResponseOutcome::Error { error } => Err(MgmtClientError::Server(error)),
65 ResponseOutcome::Event { .. } | ResponseOutcome::End { .. } => {
66 Err(MgmtClientError::Server(WireError {
70 kind: WireErrorKind::Internal,
71 message: "received streaming frame on one-shot call".to_string(),
72 }))
73 }
74 }
75 }
76
77 pub async fn call_stream<A, F>(
90 &self,
91 verb: &str,
92 args: &A,
93 mut on_event: F,
94 ) -> Result<(), MgmtClientError>
95 where
96 A: serde::Serialize,
97 F: FnMut(serde_json::Value),
98 {
99 let stream = UnixStream::connect(&self.socket_path).await?;
100 let (read, mut write) = stream.into_split();
101
102 let req = Request {
103 id: 1,
104 verb: verb.to_string(),
105 args: serde_json::to_value(args).map_err(MgmtClientError::Encode)?,
106 };
107 let bytes = encode_line(&req).map_err(MgmtClientError::Encode)?;
108 write.write_all(&bytes).await?;
109 let mut lines = BufReader::new(read).lines();
112 while let Some(line) = lines.next_line().await? {
113 if line.is_empty() {
114 continue;
115 }
116 let response: Response = serde_json::from_str(&line).map_err(MgmtClientError::Decode)?;
117 match response.outcome {
118 ResponseOutcome::Event { event } => on_event(event),
119 ResponseOutcome::End { .. } => return Ok(()),
120 ResponseOutcome::Error { error } => return Err(MgmtClientError::Server(error)),
121 ResponseOutcome::Result { .. } => {
122 return Err(MgmtClientError::Server(WireError {
123 kind: WireErrorKind::Internal,
124 message: "received one-shot Result on streaming call".to_string(),
125 }));
126 }
127 }
128 }
129 Ok(())
132 }
133}
134
135#[derive(Debug, thiserror::Error)]
136pub enum MgmtClientError {
137 #[error("io: {0}")]
138 Io(#[from] std::io::Error),
139 #[error("encode: {0}")]
140 Encode(serde_json::Error),
141 #[error("decode: {0}")]
142 Decode(serde_json::Error),
143 #[error("empty response")]
144 EmptyResponse,
145 #[error("server: {kind:?} {message}", kind = .0.kind, message = .0.message)]
146 Server(WireError),
147 #[error("http {status}: {body}")]
151 Http { status: u16, body: String },
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::{DispatchOutcome, Handler, handle_conn};
    use async_trait::async_trait;
    use serde::Deserialize;
    use std::sync::Arc;

    /// Typed payload expected from the `ping` verb.
    #[derive(Debug, Deserialize)]
    struct PingResult {
        pong: bool,
        version: String,
    }

    /// Serializes to an empty JSON object, for verbs that take no arguments.
    #[derive(serde::Serialize)]
    struct NoArgs {}

    /// One-shot handler with three behaviors:
    /// - `ping` succeeds.
    /// - `bad_shape` returns a payload that does not match `PingResult`.
    /// - Any other verb is rejected with `UnknownVerb`.
    struct StubHandler;

    #[async_trait]
    impl Handler for StubHandler {
        async fn dispatch(&self, req: Request) -> DispatchOutcome {
            let result: Result<serde_json::Value, crate::protocol::WireError> =
                match req.verb.as_str() {
                    "ping" => Ok(serde_json::json!({ "pong": true, "version": "test" })),
                    "bad_shape" => Ok(serde_json::json!({ "unrelated": 1 })),
                    _ => Err(WireError {
                        kind: WireErrorKind::UnknownVerb,
                        message: format!("unknown {}", req.verb),
                    }),
                };
            DispatchOutcome::OneShot(result)
        }
    }

    /// Sends one request through `handle_conn` over in-memory pipes and decodes the
    /// first reply line. This mirrors `UnixMgmtClient::call`'s decoding without a
    /// socket; the `unix_client_*` tests below cover the real client.
    async fn drive_call<A, R>(verb: &str, args: A) -> Result<R, MgmtClientError>
    where
        A: serde::Serialize,
        R: for<'de> serde::Deserialize<'de>,
    {
        let (c2s_r, mut c2s_w) = tokio::io::duplex(8192);
        let (s2c_w, s2c_r) = tokio::io::duplex(8192);
        let server = tokio::spawn(handle_conn(c2s_r, s2c_w, Arc::new(StubHandler)));

        let req = Request {
            id: 1,
            verb: verb.to_string(),
            args: serde_json::to_value(&args).expect("args serialize"),
        };
        let bytes = encode_line(&req).expect("encode");
        c2s_w.write_all(&bytes).await.expect("write");
        // Close the client->server pipe so the server sees end-of-input.
        drop(c2s_w);

        let mut lines = BufReader::new(s2c_r).lines();
        let line = lines
            .next_line()
            .await
            .map_err(MgmtClientError::Io)?
            .ok_or(MgmtClientError::EmptyResponse)?;
        let response: Response = serde_json::from_str(&line).map_err(MgmtClientError::Decode)?;
        let _ = server.await;
        match response.outcome {
            ResponseOutcome::Result { result } => {
                serde_json::from_value(result).map_err(MgmtClientError::Decode)
            }
            ResponseOutcome::Error { error } => Err(MgmtClientError::Server(error)),
            other => panic!("unexpected outcome: {other:?}"),
        }
    }

    #[tokio::test]
    async fn client_call_decodes_typed_result() {
        let result: PingResult = drive_call("ping", NoArgs {}).await.expect("ok");
        assert!(result.pong);
        assert_eq!(result.version, "test");
    }

    #[tokio::test]
    async fn client_surfaces_server_error_as_mgmt_client_error_server() {
        let err = drive_call::<_, PingResult>("nope", NoArgs {}).await.expect_err("err");
        match err {
            MgmtClientError::Server(w) => {
                assert_eq!(w.kind, crate::protocol::WireErrorKind::UnknownVerb);
            }
            other => panic!("expected Server, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn client_decode_error_when_result_shape_mismatches() {
        let err = drive_call::<_, PingResult>("bad_shape", NoArgs {}).await.expect_err("err");
        assert!(matches!(err, MgmtClientError::Decode(_)), "unexpected variant: {err:?}");
    }

    /// Streaming handler that ignores the request and emits `{"i": 1}`, `{"i": 2}`,
    /// `{"i": 3}` in that order. The values are popped from the end of a vec stored
    /// in reverse.
    struct StreamingHandler;

    #[async_trait]
    impl Handler for StreamingHandler {
        async fn dispatch(&self, _req: Request) -> DispatchOutcome {
            use crate::server::EventStream;
            struct ThreeEvents {
                remaining: Vec<serde_json::Value>,
            }
            #[async_trait]
            impl EventStream for ThreeEvents {
                async fn next_event(&mut self) -> Option<serde_json::Value> {
                    self.remaining.pop()
                }
            }
            DispatchOutcome::Stream(Box::new(ThreeEvents {
                remaining: vec![
                    serde_json::json!({ "i": 3 }),
                    serde_json::json!({ "i": 2 }),
                    serde_json::json!({ "i": 1 }),
                ],
            }))
        }
    }

    #[tokio::test]
    async fn client_call_stream_invokes_callback_per_event_until_end() {
        let (c2s_r, mut c2s_w) = tokio::io::duplex(8192);
        let (s2c_w, s2c_r) = tokio::io::duplex(8192);
        let server = tokio::spawn(crate::server::handle_conn(c2s_r, s2c_w, Arc::new(StreamingHandler)));

        let req = Request { id: 1, verb: "tail".to_string(), args: serde_json::Value::Null };
        let bytes = encode_line(&req).expect("encode");
        c2s_w.write_all(&bytes).await.expect("write");

        let mut events = Vec::new();
        let mut lines = BufReader::new(s2c_r).lines();
        loop {
            let line = lines.next_line().await.expect("read").expect("line");
            let resp: Response = serde_json::from_str(&line).expect("parse");
            match resp.outcome {
                ResponseOutcome::Event { event } => events.push(event),
                ResponseOutcome::End { .. } => break,
                other => panic!("unexpected: {other:?}"),
            }
        }
        // Keep the client->server pipe open until End, like `call_stream` does.
        drop(c2s_w);
        let _ = server.await;
        assert_eq!(events.len(), 3, "all three events consumed");
        assert_eq!(events[0]["i"], 1, "FIFO order from the stream");
        assert_eq!(events[1]["i"], 2);
        assert_eq!(events[2]["i"], 3);
    }

    #[tokio::test]
    async fn client_io_error_on_missing_socket() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let path = tmp.path().join("not-there.sock");
        let client = UnixMgmtClient::new(&path);
        let err = client
            .call::<_, PingResult>("ping", &NoArgs {})
            .await
            .expect_err("must fail without a server");
        assert!(matches!(err, MgmtClientError::Io(_)), "unexpected variant: {err:?}");
    }

    /// Binds a Unix listener on a fresh socket path inside a temporary directory.
    ///
    /// Keep the returned `TempDir` alive for as long as the socket is in use.
    fn bind_temp_socket() -> (tempfile::TempDir, PathBuf, tokio::net::UnixListener) {
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("mgmt.sock");
        let listener = tokio::net::UnixListener::bind(&path).expect("bind");
        (dir, path, listener)
    }

    #[tokio::test]
    async fn unix_client_call_round_trips_over_socket() {
        let (_dir, path, listener) = bind_temp_socket();
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("accept");
            let (read, write) = stream.into_split();
            handle_conn(read, write, Arc::new(StubHandler)).await
        });

        let client = UnixMgmtClient::new(&path);
        let result: PingResult = client.call("ping", &NoArgs {}).await.expect("ok");
        let _ = server.await;
        assert!(result.pong);
        assert_eq!(result.version, "test");
    }

    #[tokio::test]
    async fn unix_client_call_surfaces_server_error_over_socket() {
        let (_dir, path, listener) = bind_temp_socket();
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("accept");
            let (read, write) = stream.into_split();
            handle_conn(read, write, Arc::new(StubHandler)).await
        });

        let client = UnixMgmtClient::new(&path);
        let err = client
            .call::<_, PingResult>("nope", &NoArgs {})
            .await
            .expect_err("unknown verb must fail");
        let _ = server.await;
        match err {
            MgmtClientError::Server(w) => assert_eq!(w.kind, WireErrorKind::UnknownVerb),
            other => panic!("expected Server, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn unix_client_call_stream_delivers_events_in_order() {
        let (_dir, path, listener) = bind_temp_socket();
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("accept");
            let (read, write) = stream.into_split();
            handle_conn(read, write, Arc::new(StreamingHandler)).await
        });

        let client = UnixMgmtClient::new(&path);
        let mut events = Vec::new();
        client
            .call_stream("tail", &serde_json::Value::Null, |event| events.push(event))
            .await
            .expect("stream ends with End");
        let _ = server.await;
        assert_eq!(
            events,
            vec![
                serde_json::json!({ "i": 1 }),
                serde_json::json!({ "i": 2 }),
                serde_json::json!({ "i": 3 }),
            ]
        );
    }
}