containerflare_command/
lib.rs

1use std::str::FromStr;
2use std::sync::Arc;
3use std::time::Duration;
4
5#[cfg(unix)]
6use std::path::PathBuf;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use thiserror::Error;
11use tokio::io::{self, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
12use tokio::sync::Mutex;
13use tokio::time;
14
15#[cfg(unix)]
16use tokio::net::UnixStream;
17use tokio::net::{
18    TcpStream, tcp::OwnedReadHalf as TcpOwnedReadHalf, tcp::OwnedWriteHalf as TcpOwnedWriteHalf,
19};
20#[cfg(unix)]
21use tokio::net::{
22    unix::OwnedReadHalf as UnixOwnedReadHalf, unix::OwnedWriteHalf as UnixOwnedWriteHalf,
23};
24
/// Timeout used by `CommandClient::connect` when the caller does not supply one (30 s).
const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_millis(30_000);
26
/// Describes how the container establishes the host command channel transport.
///
/// Values are usually obtained by parsing a string (see the [`FromStr`] implementation).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum CommandEndpoint {
    /// The process's own stdin/stdout pipes (the default).
    #[default]
    Stdio,
    /// A Unix domain socket at the given filesystem path (Unix targets only).
    #[cfg(unix)]
    UnixSocket(PathBuf),
    /// A TCP socket address, e.g. `host:port`.
    Tcp(String),
}
36
37impl FromStr for CommandEndpoint {
38    type Err = CommandEndpointParseError;
39
40    fn from_str(s: &str) -> Result<Self, Self::Err> {
41        let value = s.trim();
42        if value.eq_ignore_ascii_case("stdio") {
43            return Ok(CommandEndpoint::Stdio);
44        }
45
46        #[cfg(unix)]
47        if let Some(path) = value.strip_prefix("unix://") {
48            return Ok(CommandEndpoint::UnixSocket(PathBuf::from(path)));
49        }
50
51        if let Some(addr) = value.strip_prefix("tcp://") {
52            return Ok(CommandEndpoint::Tcp(addr.to_owned()));
53        }
54
55        Err(CommandEndpointParseError::InvalidCommandEndpoint(
56            value.to_owned(),
57        ))
58    }
59}
60
61/// Errors encountered while parsing a [`CommandEndpoint`] from a string.
62#[derive(Debug, Error, Clone)]
63pub enum CommandEndpointParseError {
64    #[error("invalid command endpoint: {0}")]
65    InvalidCommandEndpoint(String),
66}
67
68/// High-level client that talks to Cloudflare's host-managed command channel.
69///
70/// Commands are framed as JSON lines and travel over stdin/stdout (default), TCP, or
71/// Unix sockets (when enabled). Responses are deserialized back into [`CommandResponse`]
72/// instances and surfaced through async APIs.
73///
74/// # Transport Modes
75/// - `stdio`: bidirectional pipes that the Workers container sidecar keeps open.
76/// - `tcp://host:port`: an explicit TCP socket managed by the sidecar.
77/// - `unix://path` *(Unix only)*: a Unix domain socket exposed by the sidecar.
78///
79/// # Errors
80/// All async constructors and [`CommandClient::send`] return [`CommandError`] when the transport
81/// cannot be established, the host drops the channel, or the host reports a failure.
82///
83/// # Examples
84/// ```ignore
85/// use containerflare_command::{CommandClient, CommandEndpoint, CommandRequest};
86///
87/// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
88/// let client = CommandClient::connect(CommandEndpoint::Stdio).await?;
89/// let response = client.send(CommandRequest::empty("health_check")).await?;
90/// assert!(response.ok);
91/// # Ok(())
92/// # }
93/// ```
94#[derive(Clone, Debug)]
95pub struct CommandClient {
96    inner: Arc<CommandClientInner>,
97}
98
99#[derive(Debug)]
100struct CommandClientInner {
101    endpoint: CommandEndpoint,
102    writer: CommandWriter,
103    reader: CommandReader,
104    timeout: Duration,
105}
106
107impl CommandClient {
108    /// Connects to the configured endpoint using the default timeout.
109    ///
110    /// # Parameters
111    /// * `endpoint` - Transport descriptor (stdio, TCP, or Unix socket).
112    ///
113    /// # Returns
114    /// A connected [`CommandClient`] ready to issue commands.
115    ///
116    /// # Errors
117    /// Returns [`CommandError`] if the underlying transport cannot be opened or is closed
118    /// before the connection is established.
119    ///
120    /// # Panics
121    /// Does not panic.
122    pub async fn connect(endpoint: CommandEndpoint) -> Result<Self, CommandError> {
123        Self::connect_with_timeout(endpoint, DEFAULT_COMMAND_TIMEOUT).await
124    }
125
126    /// Connects to the endpoint and enforces a custom read timeout.
127    ///
128    /// # Parameters
129    /// * `endpoint` - Transport descriptor (stdio, TCP, or Unix socket).
130    /// * `timeout` - Maximum duration to wait for each response before failing.
131    ///
132    /// # Returns
133    /// A connected [`CommandClient`] that enforces the provided timeout for every command.
134    ///
135    /// # Errors
136    /// Returns [`CommandError`] if the underlying transport cannot be opened or the timeout
137    /// elapses while establishing the connection.
138    ///
139    /// # Panics
140    /// Does not panic.
141    pub async fn connect_with_timeout(
142        endpoint: CommandEndpoint,
143        timeout: Duration,
144    ) -> Result<Self, CommandError> {
145        let (writer, reader) = match &endpoint {
146            CommandEndpoint::Stdio => (
147                CommandWriter::Stdio(Mutex::new(tokio::io::stdout())),
148                CommandReader::Stdio(Mutex::new(BufReader::new(tokio::io::stdin()))),
149            ),
150            CommandEndpoint::Tcp(addr) => {
151                let stream = TcpStream::connect(addr).await?;
152                let (read_half, write_half) = stream.into_split();
153                (
154                    CommandWriter::Tcp(Mutex::new(write_half)),
155                    CommandReader::Tcp(Mutex::new(BufReader::new(read_half))),
156                )
157            }
158            #[cfg(unix)]
159            CommandEndpoint::UnixSocket(path) => {
160                let stream = UnixStream::connect(path).await?;
161                let (read_half, write_half) = stream.into_split();
162                (
163                    CommandWriter::Unix(Mutex::new(write_half)),
164                    CommandReader::Unix(Mutex::new(BufReader::new(read_half))),
165                )
166            }
167        };
168
169        Ok(Self {
170            inner: Arc::new(CommandClientInner {
171                endpoint,
172                writer,
173                reader,
174                timeout,
175            }),
176        })
177    }
178
179    /// Returns the endpoint backing this client.
180    pub fn endpoint(&self) -> &CommandEndpoint {
181        &self.inner.endpoint
182    }
183
184    /// Sends a command request and waits for a response (or timeout).
185    ///
186    /// # Parameters
187    /// * `request` - Structured command for the Workers sidecar.
188    ///
189    /// # Returns
190    /// The [`CommandResponse`] emitted by the sidecar.
191    ///
192    /// # Errors
193    /// Returns [`CommandError`] if the channel closes, the response payload cannot be
194    /// deserialized, the command reports a failure, or the read timeout elapses.
195    ///
196    /// # Panics
197    /// Does not panic.
198    pub async fn send(&self, request: CommandRequest) -> Result<CommandResponse, CommandError> {
199        self.inner.writer.send(&request).await?;
200
201        let response = time::timeout(self.inner.timeout, self.inner.reader.read()).await;
202        let response = match response {
203            Ok(result) => result?,
204            Err(_) => return Err(CommandError::Timeout(self.inner.timeout)),
205        };
206
207        if response.ok {
208            Ok(response)
209        } else {
210            let diagnostic = response
211                .diagnostic
212                .clone()
213                .unwrap_or_else(|| "host returned failure".to_owned());
214            Err(CommandError::CommandFailure {
215                diagnostic,
216                payload: response.payload.clone(),
217            })
218        }
219    }
220}
221
222/// JSON payload describing a command issued to the host.
223#[derive(Debug, Clone, Serialize, Deserialize)]
224pub struct CommandRequest {
225    /// Command verb recognized by the Workers sidecar.
226    pub command: String,
227    /// Structured JSON payload to accompany the command (defaults to `null`).
228    #[serde(default)]
229    pub payload: serde_json::Value,
230}
231
232impl CommandRequest {
233    /// Creates a new request with the provided command name and payload.
234    pub fn new(command: impl Into<String>, payload: serde_json::Value) -> Self {
235        Self {
236            command: command.into(),
237            payload,
238        }
239    }
240
241    /// Creates a request whose payload is `null`.
242    pub fn empty(command: impl Into<String>) -> Self {
243        Self::new(command, serde_json::Value::Null)
244    }
245}
246
247/// Response returned by the host for a previously issued command.
248#[derive(Debug, Clone, Serialize, Deserialize)]
249pub struct CommandResponse {
250    /// Indicates whether the host executed the command successfully.
251    pub ok: bool,
252    /// JSON payload returned by the host (defaults to `null`).
253    #[serde(default)]
254    pub payload: serde_json::Value,
255    /// Optional diagnostic string supplied by the host when `ok == false`.
256    #[serde(default)]
257    pub diagnostic: Option<String>,
258}
259
260impl CommandResponse {
261    /// Constructs a success response with an empty payload.
262    pub fn ok() -> Self {
263        Self {
264            ok: true,
265            payload: serde_json::Value::Null,
266            diagnostic: None,
267        }
268    }
269}
270/// Errors emitted by [`CommandClient`] when transport or payload handling fails.
271#[derive(Debug, Error)]
272pub enum CommandError {
273    #[error("command failed: {diagnostic}")]
274    CommandFailure { diagnostic: String, payload: Value },
275    #[error("command transport closed")]
276    TransportClosed,
277    #[error("command timed out after {0:?}")]
278    Timeout(Duration),
279    #[error("io error: {0}")]
280    Io(#[from] io::Error),
281    #[error("invalid command payload: {0}")]
282    Serialization(#[from] serde_json::Error),
283}
284
285#[derive(Debug)]
286enum CommandWriter {
287    Stdio(Mutex<tokio::io::Stdout>),
288    Tcp(Mutex<TcpOwnedWriteHalf>),
289    #[cfg(unix)]
290    Unix(Mutex<UnixOwnedWriteHalf>),
291}
292
293#[derive(Debug)]
294enum CommandReader {
295    Stdio(Mutex<BufReader<tokio::io::Stdin>>),
296    Tcp(Mutex<BufReader<TcpOwnedReadHalf>>),
297    #[cfg(unix)]
298    Unix(Mutex<BufReader<UnixOwnedReadHalf>>),
299}
300
301impl CommandWriter {
302    async fn send(&self, request: &CommandRequest) -> Result<(), CommandError> {
303        let line = serde_json::to_string(request)?;
304        match self {
305            CommandWriter::Stdio(writer) => Self::write_line(writer, &line).await,
306            CommandWriter::Tcp(writer) => Self::write_line(writer, &line).await,
307            #[cfg(unix)]
308            CommandWriter::Unix(writer) => Self::write_line(writer, &line).await,
309        }
310    }
311
312    async fn write_line<W>(writer: &Mutex<W>, line: &str) -> Result<(), CommandError>
313    where
314        W: AsyncWrite + Unpin + Send,
315    {
316        let mut guard = writer.lock().await;
317        guard.write_all(line.as_bytes()).await?;
318        guard.write_all(b"\n").await?;
319        guard.flush().await?;
320        Ok(())
321    }
322}
323
324impl CommandReader {
325    async fn read(&self) -> Result<CommandResponse, CommandError> {
326        match self {
327            CommandReader::Stdio(reader) => Self::read_line(reader).await,
328            CommandReader::Tcp(reader) => Self::read_line(reader).await,
329            #[cfg(unix)]
330            CommandReader::Unix(reader) => Self::read_line(reader).await,
331        }
332    }
333
334    async fn read_line<R>(reader: &Mutex<BufReader<R>>) -> Result<CommandResponse, CommandError>
335    where
336        R: AsyncRead + Unpin + Send,
337    {
338        let mut guard = reader.lock().await;
339        let mut buf = String::new();
340        let read = guard.read_line(&mut buf).await?;
341        if read == 0 {
342            return Err(CommandError::TransportClosed);
343        }
344        let response = serde_json::from_str(&buf)?;
345        Ok(response)
346    }
347}