containerflare_command/
lib.rs

1use std::str::FromStr;
2use std::sync::Arc;
3use std::time::Duration;
4
5#[cfg(unix)]
6use std::path::PathBuf;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use thiserror::Error;
11use tokio::io::{self, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
12use tokio::sync::Mutex;
13use tokio::time;
14
15#[cfg(unix)]
16use tokio::net::UnixStream;
17use tokio::net::{
18    TcpStream, tcp::OwnedReadHalf as TcpOwnedReadHalf, tcp::OwnedWriteHalf as TcpOwnedWriteHalf,
19};
20#[cfg(unix)]
21use tokio::net::{
22    unix::OwnedReadHalf as UnixOwnedReadHalf, unix::OwnedWriteHalf as UnixOwnedWriteHalf,
23};
24
/// Response timeout applied by [`CommandClient::connect`] and
/// [`CommandClient::unavailable`] when the caller does not supply one.
const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
26
/// Selects the transport the container uses to reach the host command channel.
///
/// The textual forms accepted by the [`FromStr`] implementation are `stdio`,
/// `tcp://<addr>`, `unix://<path>` (Unix targets only), and `disabled` /
/// `unavailable`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum CommandEndpoint {
    /// Exchange commands over the process's stdin/stdout. This is the default.
    #[default]
    Stdio,
    /// Connect to the Unix domain socket at the given filesystem path.
    #[cfg(unix)]
    UnixSocket(PathBuf),
    /// Connect to the given TCP address (the text following `tcp://`).
    Tcp(String),
    /// Marker used when the runtime intentionally disables the command channel.
    Unavailable,
}
38
39impl FromStr for CommandEndpoint {
40    type Err = CommandEndpointParseError;
41
42    fn from_str(s: &str) -> Result<Self, Self::Err> {
43        let value = s.trim();
44        if value.eq_ignore_ascii_case("stdio") {
45            return Ok(CommandEndpoint::Stdio);
46        }
47
48        if value.eq_ignore_ascii_case("disabled") || value.eq_ignore_ascii_case("unavailable") {
49            return Ok(CommandEndpoint::Unavailable);
50        }
51
52        #[cfg(unix)]
53        if let Some(path) = value.strip_prefix("unix://") {
54            return Ok(CommandEndpoint::UnixSocket(PathBuf::from(path)));
55        }
56
57        if let Some(addr) = value.strip_prefix("tcp://") {
58            return Ok(CommandEndpoint::Tcp(addr.to_owned()));
59        }
60
61        Err(CommandEndpointParseError::InvalidCommandEndpoint(
62            value.to_owned(),
63        ))
64    }
65}
66
67/// Errors encountered while parsing a [`CommandEndpoint`] from a string.
68#[derive(Debug, Error, Clone)]
69pub enum CommandEndpointParseError {
70    #[error("invalid command endpoint: {0}")]
71    InvalidCommandEndpoint(String),
72}
73
74/// High-level client that talks to Cloudflare's host-managed command channel (Cloud Run does not expose one).
75///
76/// Commands are framed as JSON lines and travel over stdin/stdout (default), TCP, or
77/// Unix sockets (when enabled). Responses are deserialized back into [`CommandResponse`]
78/// instances and surfaced through async APIs.
79///
80/// # Transport Modes
81/// - `stdio`: bidirectional pipes that the Workers container sidecar keeps open.
82/// - `tcp://host:port`: an explicit TCP socket managed by the sidecar.
83/// - `unix://path` *(Unix only)*: a Unix domain socket exposed by the sidecar.
84///
85/// # Errors
86/// All async constructors and [`CommandClient::send`] return [`CommandError`] when the transport
87/// cannot be established, the host drops the channel, or the host reports a failure.
88///
89/// # Examples
90/// ```ignore
91/// use containerflare_command::{CommandClient, CommandEndpoint, CommandRequest};
92///
93/// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
94/// let client = CommandClient::connect(CommandEndpoint::Stdio).await?;
95/// let response = client.send(CommandRequest::empty("health_check")).await?;
96/// assert!(response.ok);
97/// # Ok(())
98/// # }
99/// ```
100#[derive(Clone, Debug)]
101pub struct CommandClient {
102    inner: Arc<CommandClientInner>,
103}
104
105#[derive(Debug)]
106struct CommandClientInner {
107    endpoint: CommandEndpoint,
108    writer: CommandWriter,
109    reader: CommandReader,
110    timeout: Duration,
111}
112
113impl CommandClient {
114    /// Connects to the configured endpoint using the default timeout.
115    ///
116    /// # Parameters
117    /// * `endpoint` - Transport descriptor (stdio, TCP, or Unix socket).
118    ///
119    /// # Returns
120    /// A connected [`CommandClient`] ready to issue commands.
121    ///
122    /// # Errors
123    /// Returns [`CommandError`] if the underlying transport cannot be opened or is closed
124    /// before the connection is established.
125    ///
126    /// # Panics
127    /// Does not panic.
128    pub async fn connect(endpoint: CommandEndpoint) -> Result<Self, CommandError> {
129        Self::connect_with_timeout(endpoint, DEFAULT_COMMAND_TIMEOUT).await
130    }
131
132    /// Connects to the endpoint and enforces a custom read timeout.
133    ///
134    /// # Parameters
135    /// * `endpoint` - Transport descriptor (stdio, TCP, or Unix socket).
136    /// * `timeout` - Maximum duration to wait for each response before failing.
137    ///
138    /// # Returns
139    /// A connected [`CommandClient`] that enforces the provided timeout for every command.
140    ///
141    /// # Errors
142    /// Returns [`CommandError`] if the underlying transport cannot be opened or the timeout
143    /// elapses while establishing the connection.
144    ///
145    /// # Panics
146    /// Does not panic.
147    pub async fn connect_with_timeout(
148        endpoint: CommandEndpoint,
149        timeout: Duration,
150    ) -> Result<Self, CommandError> {
151        let (writer, reader) = match &endpoint {
152            CommandEndpoint::Stdio => (
153                CommandWriter::Stdio(Mutex::new(tokio::io::stdout())),
154                CommandReader::Stdio(Mutex::new(BufReader::new(tokio::io::stdin()))),
155            ),
156            CommandEndpoint::Tcp(addr) => {
157                let stream = TcpStream::connect(addr).await?;
158                let (read_half, write_half) = stream.into_split();
159                (
160                    CommandWriter::Tcp(Mutex::new(write_half)),
161                    CommandReader::Tcp(Mutex::new(BufReader::new(read_half))),
162                )
163            }
164            #[cfg(unix)]
165            CommandEndpoint::UnixSocket(path) => {
166                let stream = UnixStream::connect(path).await?;
167                let (read_half, write_half) = stream.into_split();
168                (
169                    CommandWriter::Unix(Mutex::new(write_half)),
170                    CommandReader::Unix(Mutex::new(BufReader::new(read_half))),
171                )
172            }
173            CommandEndpoint::Unavailable => {
174                return Err(CommandError::Unavailable(
175                    "command endpoint marked unavailable".into(),
176                ));
177            }
178        };
179
180        Ok(Self {
181            inner: Arc::new(CommandClientInner {
182                endpoint,
183                writer,
184                reader,
185                timeout,
186            }),
187        })
188    }
189
190    /// Creates a [`CommandClient`] that always reports an unavailable channel.
191    ///
192    /// This is useful for runtimes (Google Cloud Run, local testing, etc.) that do not expose
193    /// a host-managed command bus but still want to share the API surface.
194    pub fn unavailable(reason: impl Into<String>) -> Self {
195        let reason = reason.into();
196        let shared = Arc::new(reason);
197        Self {
198            inner: Arc::new(CommandClientInner {
199                endpoint: CommandEndpoint::Unavailable,
200                writer: CommandWriter::Unavailable(shared.clone()),
201                reader: CommandReader::Unavailable(shared),
202                timeout: DEFAULT_COMMAND_TIMEOUT,
203            }),
204        }
205    }
206
207    /// Returns the endpoint backing this client.
208    pub fn endpoint(&self) -> &CommandEndpoint {
209        &self.inner.endpoint
210    }
211
212    /// Sends a command request and waits for a response (or timeout).
213    ///
214    /// # Parameters
215    /// * `request` - Structured command for the Workers sidecar.
216    ///
217    /// # Returns
218    /// The [`CommandResponse`] emitted by the sidecar.
219    ///
220    /// # Errors
221    /// Returns [`CommandError`] if the channel closes, the response payload cannot be
222    /// deserialized, the command reports a failure, or the read timeout elapses.
223    ///
224    /// # Panics
225    /// Does not panic.
226    pub async fn send(&self, request: CommandRequest) -> Result<CommandResponse, CommandError> {
227        self.inner.writer.send(&request).await?;
228
229        let response = time::timeout(self.inner.timeout, self.inner.reader.read()).await;
230        let response = match response {
231            Ok(result) => result?,
232            Err(_) => return Err(CommandError::Timeout(self.inner.timeout)),
233        };
234
235        if response.ok {
236            Ok(response)
237        } else {
238            let diagnostic = response
239                .diagnostic
240                .clone()
241                .unwrap_or_else(|| "host returned failure".to_owned());
242            Err(CommandError::CommandFailure {
243                diagnostic,
244                payload: response.payload.clone(),
245            })
246        }
247    }
248}
249
250/// JSON payload describing a command issued to the host.
251#[derive(Debug, Clone, Serialize, Deserialize)]
252pub struct CommandRequest {
253    /// Command verb recognized by the Workers sidecar.
254    pub command: String,
255    /// Structured JSON payload to accompany the command (defaults to `null`).
256    #[serde(default)]
257    pub payload: serde_json::Value,
258}
259
260impl CommandRequest {
261    /// Creates a new request with the provided command name and payload.
262    pub fn new(command: impl Into<String>, payload: serde_json::Value) -> Self {
263        Self {
264            command: command.into(),
265            payload,
266        }
267    }
268
269    /// Creates a request whose payload is `null`.
270    pub fn empty(command: impl Into<String>) -> Self {
271        Self::new(command, serde_json::Value::Null)
272    }
273}
274
275/// Response returned by the host for a previously issued command.
276#[derive(Debug, Clone, Serialize, Deserialize)]
277pub struct CommandResponse {
278    /// Indicates whether the host executed the command successfully.
279    pub ok: bool,
280    /// JSON payload returned by the host (defaults to `null`).
281    #[serde(default)]
282    pub payload: serde_json::Value,
283    /// Optional diagnostic string supplied by the host when `ok == false`.
284    #[serde(default)]
285    pub diagnostic: Option<String>,
286}
287
288impl CommandResponse {
289    /// Constructs a success response with an empty payload.
290    pub fn ok() -> Self {
291        Self {
292            ok: true,
293            payload: serde_json::Value::Null,
294            diagnostic: None,
295        }
296    }
297}
298/// Errors emitted by [`CommandClient`] when transport or payload handling fails.
299#[derive(Debug, Error)]
300pub enum CommandError {
301    #[error("command failed: {diagnostic}")]
302    CommandFailure { diagnostic: String, payload: Value },
303    #[error("command transport closed")]
304    TransportClosed,
305    #[error("command timed out after {0:?}")]
306    Timeout(Duration),
307    #[error("io error: {0}")]
308    Io(#[from] io::Error),
309    #[error("invalid command payload: {0}")]
310    Serialization(#[from] serde_json::Error),
311    #[error("command channel unavailable: {0}")]
312    Unavailable(String),
313}
314
315#[derive(Debug)]
316enum CommandWriter {
317    Stdio(Mutex<tokio::io::Stdout>),
318    Tcp(Mutex<TcpOwnedWriteHalf>),
319    #[cfg(unix)]
320    Unix(Mutex<UnixOwnedWriteHalf>),
321    Unavailable(Arc<String>),
322}
323
324#[derive(Debug)]
325enum CommandReader {
326    Stdio(Mutex<BufReader<tokio::io::Stdin>>),
327    Tcp(Mutex<BufReader<TcpOwnedReadHalf>>),
328    #[cfg(unix)]
329    Unix(Mutex<BufReader<UnixOwnedReadHalf>>),
330    Unavailable(Arc<String>),
331}
332
333impl CommandWriter {
334    async fn send(&self, request: &CommandRequest) -> Result<(), CommandError> {
335        let line = serde_json::to_string(request)?;
336        match self {
337            CommandWriter::Stdio(writer) => Self::write_line(writer, &line).await,
338            CommandWriter::Tcp(writer) => Self::write_line(writer, &line).await,
339            #[cfg(unix)]
340            CommandWriter::Unix(writer) => Self::write_line(writer, &line).await,
341            CommandWriter::Unavailable(reason) => {
342                Err(CommandError::Unavailable(reason.as_ref().clone()))
343            }
344        }
345    }
346
347    async fn write_line<W>(writer: &Mutex<W>, line: &str) -> Result<(), CommandError>
348    where
349        W: AsyncWrite + Unpin + Send,
350    {
351        let mut guard = writer.lock().await;
352        guard.write_all(line.as_bytes()).await?;
353        guard.write_all(b"\n").await?;
354        guard.flush().await?;
355        Ok(())
356    }
357}
358
359impl CommandReader {
360    async fn read(&self) -> Result<CommandResponse, CommandError> {
361        match self {
362            CommandReader::Stdio(reader) => Self::read_line(reader).await,
363            CommandReader::Tcp(reader) => Self::read_line(reader).await,
364            #[cfg(unix)]
365            CommandReader::Unix(reader) => Self::read_line(reader).await,
366            CommandReader::Unavailable(reason) => {
367                Err(CommandError::Unavailable(reason.as_ref().clone()))
368            }
369        }
370    }
371
372    async fn read_line<R>(reader: &Mutex<BufReader<R>>) -> Result<CommandResponse, CommandError>
373    where
374        R: AsyncRead + Unpin + Send,
375    {
376        let mut guard = reader.lock().await;
377        let mut buf = String::new();
378        let read = guard.read_line(&mut buf).await?;
379        if read == 0 {
380            return Err(CommandError::TransportClosed);
381        }
382        let response = serde_json::from_str(&buf)?;
383        Ok(response)
384    }
385}