1use std::str::FromStr;
2use std::sync::Arc;
3use std::time::Duration;
4
5#[cfg(unix)]
6use std::path::PathBuf;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use thiserror::Error;
11use tokio::io::{self, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
12use tokio::sync::Mutex;
13use tokio::time;
14
15#[cfg(unix)]
16use tokio::net::UnixStream;
17use tokio::net::{
18 TcpStream, tcp::OwnedReadHalf as TcpOwnedReadHalf, tcp::OwnedWriteHalf as TcpOwnedWriteHalf,
19};
20#[cfg(unix)]
21use tokio::net::{
22 unix::OwnedReadHalf as UnixOwnedReadHalf, unix::OwnedWriteHalf as UnixOwnedWriteHalf,
23};
24
/// Per-command response timeout used by `CommandClient::connect` and
/// `CommandClient::unavailable` (30 seconds).
const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_millis(30_000);
26
/// Transport a command client talks to.
///
/// Parsed from strings via [`FromStr`]: `stdio`, `disabled` / `unavailable`,
/// `unix://<path>` (Unix targets only) and `tcp://<addr>`.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub enum CommandEndpoint {
    /// This process's standard input and output (the default).
    #[default]
    Stdio,
    /// A Unix domain socket at the given filesystem path.
    #[cfg(unix)]
    UnixSocket(PathBuf),
    /// A TCP address, passed unchanged to `TcpStream::connect` (e.g. `host:port`).
    Tcp(String),
    /// No command channel; connecting to it fails with `CommandError::Unavailable`.
    Unavailable,
}
38
39impl FromStr for CommandEndpoint {
40 type Err = CommandEndpointParseError;
41
42 fn from_str(s: &str) -> Result<Self, Self::Err> {
43 let value = s.trim();
44 if value.eq_ignore_ascii_case("stdio") {
45 return Ok(CommandEndpoint::Stdio);
46 }
47
48 if value.eq_ignore_ascii_case("disabled") || value.eq_ignore_ascii_case("unavailable") {
49 return Ok(CommandEndpoint::Unavailable);
50 }
51
52 #[cfg(unix)]
53 if let Some(path) = value.strip_prefix("unix://") {
54 return Ok(CommandEndpoint::UnixSocket(PathBuf::from(path)));
55 }
56
57 if let Some(addr) = value.strip_prefix("tcp://") {
58 return Ok(CommandEndpoint::Tcp(addr.to_owned()));
59 }
60
61 Err(CommandEndpointParseError::InvalidCommandEndpoint(
62 value.to_owned(),
63 ))
64 }
65}
66
67#[derive(Debug, Error, Clone)]
69pub enum CommandEndpointParseError {
70 #[error("invalid command endpoint: {0}")]
71 InvalidCommandEndpoint(String),
72}
73
74#[derive(Clone, Debug)]
101pub struct CommandClient {
102 inner: Arc<CommandClientInner>,
103}
104
105#[derive(Debug)]
106struct CommandClientInner {
107 endpoint: CommandEndpoint,
108 writer: CommandWriter,
109 reader: CommandReader,
110 timeout: Duration,
111}
112
113impl CommandClient {
114 pub async fn connect(endpoint: CommandEndpoint) -> Result<Self, CommandError> {
129 Self::connect_with_timeout(endpoint, DEFAULT_COMMAND_TIMEOUT).await
130 }
131
132 pub async fn connect_with_timeout(
148 endpoint: CommandEndpoint,
149 timeout: Duration,
150 ) -> Result<Self, CommandError> {
151 let (writer, reader) = match &endpoint {
152 CommandEndpoint::Stdio => (
153 CommandWriter::Stdio(Mutex::new(tokio::io::stdout())),
154 CommandReader::Stdio(Mutex::new(BufReader::new(tokio::io::stdin()))),
155 ),
156 CommandEndpoint::Tcp(addr) => {
157 let stream = TcpStream::connect(addr).await?;
158 let (read_half, write_half) = stream.into_split();
159 (
160 CommandWriter::Tcp(Mutex::new(write_half)),
161 CommandReader::Tcp(Mutex::new(BufReader::new(read_half))),
162 )
163 }
164 #[cfg(unix)]
165 CommandEndpoint::UnixSocket(path) => {
166 let stream = UnixStream::connect(path).await?;
167 let (read_half, write_half) = stream.into_split();
168 (
169 CommandWriter::Unix(Mutex::new(write_half)),
170 CommandReader::Unix(Mutex::new(BufReader::new(read_half))),
171 )
172 }
173 CommandEndpoint::Unavailable => {
174 return Err(CommandError::Unavailable(
175 "command endpoint marked unavailable".into(),
176 ));
177 }
178 };
179
180 Ok(Self {
181 inner: Arc::new(CommandClientInner {
182 endpoint,
183 writer,
184 reader,
185 timeout,
186 }),
187 })
188 }
189
190 pub fn unavailable(reason: impl Into<String>) -> Self {
195 let reason = reason.into();
196 let shared = Arc::new(reason);
197 Self {
198 inner: Arc::new(CommandClientInner {
199 endpoint: CommandEndpoint::Unavailable,
200 writer: CommandWriter::Unavailable(shared.clone()),
201 reader: CommandReader::Unavailable(shared),
202 timeout: DEFAULT_COMMAND_TIMEOUT,
203 }),
204 }
205 }
206
207 pub fn endpoint(&self) -> &CommandEndpoint {
209 &self.inner.endpoint
210 }
211
212 pub async fn send(&self, request: CommandRequest) -> Result<CommandResponse, CommandError> {
227 self.inner.writer.send(&request).await?;
228
229 let response = time::timeout(self.inner.timeout, self.inner.reader.read()).await;
230 let response = match response {
231 Ok(result) => result?,
232 Err(_) => return Err(CommandError::Timeout(self.inner.timeout)),
233 };
234
235 if response.ok {
236 Ok(response)
237 } else {
238 let diagnostic = response
239 .diagnostic
240 .clone()
241 .unwrap_or_else(|| "host returned failure".to_owned());
242 Err(CommandError::CommandFailure {
243 diagnostic,
244 payload: response.payload.clone(),
245 })
246 }
247 }
248}
249
250#[derive(Debug, Clone, Serialize, Deserialize)]
252pub struct CommandRequest {
253 pub command: String,
255 #[serde(default)]
257 pub payload: serde_json::Value,
258}
259
260impl CommandRequest {
261 pub fn new(command: impl Into<String>, payload: serde_json::Value) -> Self {
263 Self {
264 command: command.into(),
265 payload,
266 }
267 }
268
269 pub fn empty(command: impl Into<String>) -> Self {
271 Self::new(command, serde_json::Value::Null)
272 }
273}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
277pub struct CommandResponse {
278 pub ok: bool,
280 #[serde(default)]
282 pub payload: serde_json::Value,
283 #[serde(default)]
285 pub diagnostic: Option<String>,
286}
287
288impl CommandResponse {
289 pub fn ok() -> Self {
291 Self {
292 ok: true,
293 payload: serde_json::Value::Null,
294 diagnostic: None,
295 }
296 }
297}
298#[derive(Debug, Error)]
300pub enum CommandError {
301 #[error("command failed: {diagnostic}")]
302 CommandFailure { diagnostic: String, payload: Value },
303 #[error("command transport closed")]
304 TransportClosed,
305 #[error("command timed out after {0:?}")]
306 Timeout(Duration),
307 #[error("io error: {0}")]
308 Io(#[from] io::Error),
309 #[error("invalid command payload: {0}")]
310 Serialization(#[from] serde_json::Error),
311 #[error("command channel unavailable: {0}")]
312 Unavailable(String),
313}
314
315#[derive(Debug)]
316enum CommandWriter {
317 Stdio(Mutex<tokio::io::Stdout>),
318 Tcp(Mutex<TcpOwnedWriteHalf>),
319 #[cfg(unix)]
320 Unix(Mutex<UnixOwnedWriteHalf>),
321 Unavailable(Arc<String>),
322}
323
324#[derive(Debug)]
325enum CommandReader {
326 Stdio(Mutex<BufReader<tokio::io::Stdin>>),
327 Tcp(Mutex<BufReader<TcpOwnedReadHalf>>),
328 #[cfg(unix)]
329 Unix(Mutex<BufReader<UnixOwnedReadHalf>>),
330 Unavailable(Arc<String>),
331}
332
333impl CommandWriter {
334 async fn send(&self, request: &CommandRequest) -> Result<(), CommandError> {
335 let line = serde_json::to_string(request)?;
336 match self {
337 CommandWriter::Stdio(writer) => Self::write_line(writer, &line).await,
338 CommandWriter::Tcp(writer) => Self::write_line(writer, &line).await,
339 #[cfg(unix)]
340 CommandWriter::Unix(writer) => Self::write_line(writer, &line).await,
341 CommandWriter::Unavailable(reason) => {
342 Err(CommandError::Unavailable(reason.as_ref().clone()))
343 }
344 }
345 }
346
347 async fn write_line<W>(writer: &Mutex<W>, line: &str) -> Result<(), CommandError>
348 where
349 W: AsyncWrite + Unpin + Send,
350 {
351 let mut guard = writer.lock().await;
352 guard.write_all(line.as_bytes()).await?;
353 guard.write_all(b"\n").await?;
354 guard.flush().await?;
355 Ok(())
356 }
357}
358
359impl CommandReader {
360 async fn read(&self) -> Result<CommandResponse, CommandError> {
361 match self {
362 CommandReader::Stdio(reader) => Self::read_line(reader).await,
363 CommandReader::Tcp(reader) => Self::read_line(reader).await,
364 #[cfg(unix)]
365 CommandReader::Unix(reader) => Self::read_line(reader).await,
366 CommandReader::Unavailable(reason) => {
367 Err(CommandError::Unavailable(reason.as_ref().clone()))
368 }
369 }
370 }
371
372 async fn read_line<R>(reader: &Mutex<BufReader<R>>) -> Result<CommandResponse, CommandError>
373 where
374 R: AsyncRead + Unpin + Send,
375 {
376 let mut guard = reader.lock().await;
377 let mut buf = String::new();
378 let read = guard.read_line(&mut buf).await?;
379 if read == 0 {
380 return Err(CommandError::TransportClosed);
381 }
382 let response = serde_json::from_str(&buf)?;
383 Ok(response)
384 }
385}