1use std::str::FromStr;
2use std::sync::Arc;
3use std::time::Duration;
4
5#[cfg(unix)]
6use std::path::PathBuf;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use thiserror::Error;
11use tokio::io::{self, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
12use tokio::sync::Mutex;
13use tokio::time;
14
15#[cfg(unix)]
16use tokio::net::UnixStream;
17use tokio::net::{
18 TcpStream, tcp::OwnedReadHalf as TcpOwnedReadHalf, tcp::OwnedWriteHalf as TcpOwnedWriteHalf,
19};
20#[cfg(unix)]
21use tokio::net::{
22 unix::OwnedReadHalf as UnixOwnedReadHalf, unix::OwnedWriteHalf as UnixOwnedWriteHalf,
23};
24
/// Response timeout applied by [`CommandClient::connect`]: thirty seconds, expressed in
/// milliseconds.
const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_millis(30_000);
26
/// Transport over which a [`CommandClient`] exchanges line-delimited JSON with its host.
///
/// Usually obtained with [`str::parse`] from `stdio`, `tcp://<addr>`, or (Unix only)
/// `unix://<path>`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum CommandEndpoint {
    /// This process's stdout (requests) and stdin (responses). This is the default.
    #[default]
    Stdio,
    /// A Unix domain socket at the given filesystem path.
    #[cfg(unix)]
    UnixSocket(PathBuf),
    /// A TCP address, handed unchanged to `TcpStream::connect`.
    Tcp(String),
}
36
37impl FromStr for CommandEndpoint {
38 type Err = CommandEndpointParseError;
39
40 fn from_str(s: &str) -> Result<Self, Self::Err> {
41 let value = s.trim();
42 if value.eq_ignore_ascii_case("stdio") {
43 return Ok(CommandEndpoint::Stdio);
44 }
45
46 #[cfg(unix)]
47 if let Some(path) = value.strip_prefix("unix://") {
48 return Ok(CommandEndpoint::UnixSocket(PathBuf::from(path)));
49 }
50
51 if let Some(addr) = value.strip_prefix("tcp://") {
52 return Ok(CommandEndpoint::Tcp(addr.to_owned()));
53 }
54
55 Err(CommandEndpointParseError::InvalidCommandEndpoint(
56 value.to_owned(),
57 ))
58 }
59}
60
61#[derive(Debug, Error, Clone)]
63pub enum CommandEndpointParseError {
64 #[error("invalid command endpoint: {0}")]
65 InvalidCommandEndpoint(String),
66}
67
68#[derive(Clone, Debug)]
95pub struct CommandClient {
96 inner: Arc<CommandClientInner>,
97}
98
99#[derive(Debug)]
100struct CommandClientInner {
101 endpoint: CommandEndpoint,
102 writer: CommandWriter,
103 reader: CommandReader,
104 timeout: Duration,
105}
106
107impl CommandClient {
108 pub async fn connect(endpoint: CommandEndpoint) -> Result<Self, CommandError> {
123 Self::connect_with_timeout(endpoint, DEFAULT_COMMAND_TIMEOUT).await
124 }
125
126 pub async fn connect_with_timeout(
142 endpoint: CommandEndpoint,
143 timeout: Duration,
144 ) -> Result<Self, CommandError> {
145 let (writer, reader) = match &endpoint {
146 CommandEndpoint::Stdio => (
147 CommandWriter::Stdio(Mutex::new(tokio::io::stdout())),
148 CommandReader::Stdio(Mutex::new(BufReader::new(tokio::io::stdin()))),
149 ),
150 CommandEndpoint::Tcp(addr) => {
151 let stream = TcpStream::connect(addr).await?;
152 let (read_half, write_half) = stream.into_split();
153 (
154 CommandWriter::Tcp(Mutex::new(write_half)),
155 CommandReader::Tcp(Mutex::new(BufReader::new(read_half))),
156 )
157 }
158 #[cfg(unix)]
159 CommandEndpoint::UnixSocket(path) => {
160 let stream = UnixStream::connect(path).await?;
161 let (read_half, write_half) = stream.into_split();
162 (
163 CommandWriter::Unix(Mutex::new(write_half)),
164 CommandReader::Unix(Mutex::new(BufReader::new(read_half))),
165 )
166 }
167 };
168
169 Ok(Self {
170 inner: Arc::new(CommandClientInner {
171 endpoint,
172 writer,
173 reader,
174 timeout,
175 }),
176 })
177 }
178
179 pub fn endpoint(&self) -> &CommandEndpoint {
181 &self.inner.endpoint
182 }
183
184 pub async fn send(&self, request: CommandRequest) -> Result<CommandResponse, CommandError> {
199 self.inner.writer.send(&request).await?;
200
201 let response = time::timeout(self.inner.timeout, self.inner.reader.read()).await;
202 let response = match response {
203 Ok(result) => result?,
204 Err(_) => return Err(CommandError::Timeout(self.inner.timeout)),
205 };
206
207 if response.ok {
208 Ok(response)
209 } else {
210 let diagnostic = response
211 .diagnostic
212 .clone()
213 .unwrap_or_else(|| "host returned failure".to_owned());
214 Err(CommandError::CommandFailure {
215 diagnostic,
216 payload: response.payload.clone(),
217 })
218 }
219 }
220}
221
222#[derive(Debug, Clone, Serialize, Deserialize)]
224pub struct CommandRequest {
225 pub command: String,
227 #[serde(default)]
229 pub payload: serde_json::Value,
230}
231
232impl CommandRequest {
233 pub fn new(command: impl Into<String>, payload: serde_json::Value) -> Self {
235 Self {
236 command: command.into(),
237 payload,
238 }
239 }
240
241 pub fn empty(command: impl Into<String>) -> Self {
243 Self::new(command, serde_json::Value::Null)
244 }
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
249pub struct CommandResponse {
250 pub ok: bool,
252 #[serde(default)]
254 pub payload: serde_json::Value,
255 #[serde(default)]
257 pub diagnostic: Option<String>,
258}
259
260impl CommandResponse {
261 pub fn ok() -> Self {
263 Self {
264 ok: true,
265 payload: serde_json::Value::Null,
266 diagnostic: None,
267 }
268 }
269}
270#[derive(Debug, Error)]
272pub enum CommandError {
273 #[error("command failed: {diagnostic}")]
274 CommandFailure { diagnostic: String, payload: Value },
275 #[error("command transport closed")]
276 TransportClosed,
277 #[error("command timed out after {0:?}")]
278 Timeout(Duration),
279 #[error("io error: {0}")]
280 Io(#[from] io::Error),
281 #[error("invalid command payload: {0}")]
282 Serialization(#[from] serde_json::Error),
283}
284
285#[derive(Debug)]
286enum CommandWriter {
287 Stdio(Mutex<tokio::io::Stdout>),
288 Tcp(Mutex<TcpOwnedWriteHalf>),
289 #[cfg(unix)]
290 Unix(Mutex<UnixOwnedWriteHalf>),
291}
292
293#[derive(Debug)]
294enum CommandReader {
295 Stdio(Mutex<BufReader<tokio::io::Stdin>>),
296 Tcp(Mutex<BufReader<TcpOwnedReadHalf>>),
297 #[cfg(unix)]
298 Unix(Mutex<BufReader<UnixOwnedReadHalf>>),
299}
300
301impl CommandWriter {
302 async fn send(&self, request: &CommandRequest) -> Result<(), CommandError> {
303 let line = serde_json::to_string(request)?;
304 match self {
305 CommandWriter::Stdio(writer) => Self::write_line(writer, &line).await,
306 CommandWriter::Tcp(writer) => Self::write_line(writer, &line).await,
307 #[cfg(unix)]
308 CommandWriter::Unix(writer) => Self::write_line(writer, &line).await,
309 }
310 }
311
312 async fn write_line<W>(writer: &Mutex<W>, line: &str) -> Result<(), CommandError>
313 where
314 W: AsyncWrite + Unpin + Send,
315 {
316 let mut guard = writer.lock().await;
317 guard.write_all(line.as_bytes()).await?;
318 guard.write_all(b"\n").await?;
319 guard.flush().await?;
320 Ok(())
321 }
322}
323
324impl CommandReader {
325 async fn read(&self) -> Result<CommandResponse, CommandError> {
326 match self {
327 CommandReader::Stdio(reader) => Self::read_line(reader).await,
328 CommandReader::Tcp(reader) => Self::read_line(reader).await,
329 #[cfg(unix)]
330 CommandReader::Unix(reader) => Self::read_line(reader).await,
331 }
332 }
333
334 async fn read_line<R>(reader: &Mutex<BufReader<R>>) -> Result<CommandResponse, CommandError>
335 where
336 R: AsyncRead + Unpin + Send,
337 {
338 let mut guard = reader.lock().await;
339 let mut buf = String::new();
340 let read = guard.read_line(&mut buf).await?;
341 if read == 0 {
342 return Err(CommandError::TransportClosed);
343 }
344 let response = serde_json::from_str(&buf)?;
345 Ok(response)
346 }
347}