Skip to main content

ash_flare/
distributed.rs

1//! Distributed supervision via TCP/Unix sockets
2//!
3//! Allows supervisors to run in separate processes (local) or on remote machines (network).
4//! Commands are serialized with rkyv for minimal overhead.
5
6#![allow(missing_docs)]
7
8use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
9use serde::{Deserialize, Serialize};
10use std::fmt;
11use std::sync::Arc;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::{TcpListener, TcpStream};
14
15#[cfg(unix)]
16use tokio::net::{UnixListener, UnixStream};
17
18use crate::{ChildInfo as SupervisorChildInfo, ChildType, RestartPolicy, SupervisorHandle, Worker};
19
/// Location of a remote supervisor's listening socket.
///
/// Only the address is stored; `RemoteSupervisorHandle::send_command` opens a new connection
/// of the matching kind for every command it sends.
#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
#[rkyv(derive(Debug))]
pub enum SupervisorAddress {
    /// TCP socket address (host:port), passed as-is to `TcpStream::connect`
    Tcp(String),
    /// Unix domain socket path; on non-Unix platforms commands sent to this address fail
    /// with an I/O error of kind `Unsupported`
    Unix(String),
}
29
/// Commands that can be sent to a remote supervisor.
///
/// The client sends exactly one command per connection and the server answers with exactly
/// one `RemoteResponse`.
///
/// NOTE(review): this type is rkyv-archived onto the wire, so variant order and field layout
/// presumably determine the encoding — append new variants rather than reordering, or peers
/// built from different versions will misinterpret each other.
#[allow(missing_docs)]
#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
#[rkyv(derive(Debug))]
#[rkyv(attr(allow(missing_docs)))]
pub enum RemoteCommand {
    /// Request supervisor to shut down gracefully (answered with `RemoteResponse::Ok` on success)
    Shutdown,
    /// Get list of all children (answered with `RemoteResponse::Children`)
    WhichChildren,
    /// Terminate a specific child (answered with `RemoteResponse::Ok` on success)
    TerminateChild {
        /// ID of the child to terminate
        id: String,
    },
    /// Get supervisor status (answered with `RemoteResponse::Status`)
    Status,
}
48
/// Responses from remote supervisor.
///
/// Exactly one response is written back for each `RemoteCommand` received.
///
/// NOTE(review): rkyv-archived wire type — keep variant order stable (append only).
#[allow(missing_docs)]
#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
#[rkyv(derive(Debug))]
#[rkyv(attr(allow(missing_docs)))]
pub enum RemoteResponse {
    /// Command executed successfully (reply to `Shutdown` and `TerminateChild`)
    Ok,
    /// Information about every child of the supervisor, one entry per child
    Children(Vec<ChildInfo>),
    /// Supervisor status information
    Status(SupervisorStatus),
    /// Error occurred on the remote side; carries the error's `Display` text
    Error(String),
}
64
/// Information about a supervised child, flattened into a serializable form.
///
/// Built from the supervisor's own child info via the `From` impl below, which copies the
/// three fields declared here.
///
/// NOTE(review): rkyv-archived wire type — field order is presumably part of the encoding.
#[allow(missing_docs)]
#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
#[rkyv(derive(Debug))]
#[rkyv(attr(allow(missing_docs)))]
pub struct ChildInfo {
    /// Unique identifier for the child
    pub id: String,
    /// Type of child (Worker or Supervisor)
    pub child_type: ChildType,
    /// Restart policy for the child (None for supervisors)
    pub restart_policy: Option<RestartPolicy>,
}
78
79impl From<SupervisorChildInfo> for ChildInfo {
80    fn from(info: SupervisorChildInfo) -> Self {
81        Self {
82            id: info.id,
83            child_type: info.child_type,
84            restart_policy: info.restart_policy,
85        }
86    }
87}
88
/// Status information about a supervisor, as returned for `RemoteCommand::Status`.
///
/// The server fills each field independently; a field whose lookup fails gets a fallback
/// value instead of failing the whole request.
///
/// NOTE(review): rkyv-archived wire type — field order is presumably part of the encoding.
#[allow(missing_docs)]
#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
#[rkyv(derive(Debug))]
#[rkyv(attr(allow(missing_docs)))]
pub struct SupervisorStatus {
    /// Name of the supervisor
    pub name: String,
    /// Number of children reported by the supervisor (0 if the lookup failed)
    pub children_count: usize,
    /// `Debug` rendering of the restart strategy, or `"Unknown"` if it could not be queried
    pub restart_strategy: String,
    /// Uptime in seconds (0 if it could not be queried)
    pub uptime_secs: u64,
}
104
105/// Handle to communicate with a remote supervisor
106#[derive(Clone)]
107pub struct RemoteSupervisorHandle {
108    address: SupervisorAddress,
109}
110
111impl RemoteSupervisorHandle {
112    /// Create a new handle to a remote supervisor
113    #[must_use]
114    pub fn new(address: SupervisorAddress) -> Self {
115        Self { address }
116    }
117
118    /// Connect to a TCP supervisor
119    ///
120    /// # Errors
121    ///
122    /// Returns an error if the address is invalid.
123    #[allow(clippy::unused_async)]
124    pub async fn connect_tcp(addr: impl Into<String>) -> Result<Self, DistributedError> {
125        let address = SupervisorAddress::Tcp(addr.into());
126        Ok(Self { address })
127    }
128
129    /// Connect to a Unix socket supervisor
130    ///
131    /// # Errors
132    ///
133    /// Returns an error if the path is invalid.
134    #[allow(clippy::unused_async)]
135    pub async fn connect_unix(path: impl Into<String>) -> Result<Self, DistributedError> {
136        let address = SupervisorAddress::Unix(path.into());
137        Ok(Self { address })
138    }
139
140    /// Send a command to the remote supervisor
141    ///
142    /// # Errors
143    ///
144    /// Returns an error if the connection fails or the command cannot be serialized.
145    pub async fn send_command(
146        &self,
147        cmd: RemoteCommand,
148    ) -> Result<RemoteResponse, DistributedError> {
149        match &self.address {
150            SupervisorAddress::Tcp(addr) => {
151                let mut stream = TcpStream::connect(addr).await?;
152                send_message(&mut stream, &cmd).await?;
153                receive_message(&mut stream).await
154            }
155            #[cfg(unix)]
156            SupervisorAddress::Unix(path) => {
157                let mut stream = UnixStream::connect(path).await?;
158                send_message(&mut stream, &cmd).await?;
159                receive_message(&mut stream).await
160            }
161            #[cfg(not(unix))]
162            SupervisorAddress::Unix(_) => Err(DistributedError::Io(std::io::Error::new(
163                std::io::ErrorKind::Unsupported,
164                "Unix sockets are not supported on this platform",
165            ))),
166        }
167    }
168
169    /// Shutdown the remote supervisor
170    ///
171    /// # Errors
172    ///
173    /// Returns an error if the remote connection fails.
174    pub async fn shutdown(&self) -> Result<(), DistributedError> {
175        self.send_command(RemoteCommand::Shutdown).await?;
176        Ok(())
177    }
178
179    /// Get list of children from remote supervisor
180    ///
181    /// # Errors
182    ///
183    /// Returns an error if the remote connection fails or returns an unexpected response.
184    pub async fn which_children(&self) -> Result<Vec<ChildInfo>, DistributedError> {
185        match self.send_command(RemoteCommand::WhichChildren).await? {
186            RemoteResponse::Children(children) => Ok(children),
187            RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
188            _ => Err(DistributedError::UnexpectedResponse),
189        }
190    }
191
192    /// Terminate a child on the remote supervisor
193    ///
194    /// # Errors
195    ///
196    /// Returns an error if the remote connection fails or returns an unexpected response.
197    pub async fn terminate_child(&self, id: &str) -> Result<(), DistributedError> {
198        match self
199            .send_command(RemoteCommand::TerminateChild { id: id.to_owned() })
200            .await?
201        {
202            RemoteResponse::Ok => Ok(()),
203            RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
204            _ => Err(DistributedError::UnexpectedResponse),
205        }
206    }
207
208    /// Get status from remote supervisor
209    ///
210    /// # Errors
211    ///
212    /// Returns an error if the remote connection fails or returns an unexpected response.
213    pub async fn status(&self) -> Result<SupervisorStatus, DistributedError> {
214        match self.send_command(RemoteCommand::Status).await? {
215            RemoteResponse::Status(status) => Ok(status),
216            RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
217            _ => Err(DistributedError::UnexpectedResponse),
218        }
219    }
220}
221
/// Server that wraps a `SupervisorHandle` and accepts remote commands.
///
/// Each accepted connection carries exactly one `RemoteCommand`, which is executed against
/// the wrapped handle and answered with one `RemoteResponse`.
pub struct SupervisorServer<W: Worker> {
    /// Shared (via `Arc::clone`) with every connection task spawned by the listen methods.
    handle: Arc<SupervisorHandle<W>>,
}
226
227impl<W: Worker> SupervisorServer<W> {
228    /// Create a new supervisor server wrapping a `SupervisorHandle`
229    #[must_use]
230    pub fn new(handle: SupervisorHandle<W>) -> Self {
231        Self {
232            handle: Arc::new(handle),
233        }
234    }
235
236    /// Start listening on a Unix socket (Unix only)
237    ///
238    /// # Errors
239    ///
240    /// Returns an error if the socket cannot be bound or a connection fails.
241    #[cfg(unix)]
242    pub async fn listen_unix(
243        self,
244        path: impl AsRef<std::path::Path>,
245    ) -> Result<(), DistributedError> {
246        let socket_path = path.as_ref();
247        let _remove_result = std::fs::remove_file(socket_path); // Clean up old socket
248
249        let listener = UnixListener::bind(socket_path)?;
250        tracing::info!(path = %socket_path.display(), "server listening on unix socket");
251
252        loop {
253            let (mut stream, _) = listener.accept().await?;
254            let handle = Arc::clone(&self.handle);
255
256            tokio::spawn(async move {
257                if let Err(e) = Self::handle_connection(&mut stream, handle).await {
258                    tracing::error!(error = %e, "connection error");
259                }
260            });
261        }
262    }
263
264    /// Start listening on a TCP socket
265    ///
266    /// # Errors
267    ///
268    /// Returns an error if the socket cannot be bound or a connection fails.
269    pub async fn listen_tcp(self, addr: impl AsRef<str>) -> Result<(), DistributedError> {
270        let listener = TcpListener::bind(addr.as_ref()).await?;
271        tracing::info!(address = addr.as_ref(), "server listening on tcp");
272
273        loop {
274            let (mut stream, peer) = listener.accept().await?;
275            tracing::debug!(peer = ?peer, "new connection");
276            let handle = Arc::clone(&self.handle);
277
278            tokio::spawn(async move {
279                if let Err(e) = Self::handle_connection(&mut stream, handle).await {
280                    tracing::error!(error = %e, "Connection error");
281                }
282            });
283        }
284    }
285
286    async fn handle_connection<S>(
287        stream: &mut S,
288        handle: Arc<SupervisorHandle<W>>,
289    ) -> Result<(), DistributedError>
290    where
291        S: AsyncReadExt + AsyncWriteExt + Unpin,
292    {
293        let command: RemoteCommand = receive_message(stream).await?;
294        let response = Self::process_command(command, &handle).await;
295        send_message(stream, &response).await?;
296        Ok(())
297    }
298
299    async fn process_command(
300        command: RemoteCommand,
301        handle: &SupervisorHandle<W>,
302    ) -> RemoteResponse {
303        match command {
304            RemoteCommand::Shutdown => match handle.shutdown().await {
305                Ok(()) => RemoteResponse::Ok,
306                Err(e) => RemoteResponse::Error(e.to_string()),
307            },
308            RemoteCommand::WhichChildren => match handle.which_children().await {
309                Ok(children) => {
310                    let child_list: Vec<ChildInfo> = children.into_iter().map(Into::into).collect();
311                    RemoteResponse::Children(child_list)
312                }
313                Err(e) => RemoteResponse::Error(e.to_string()),
314            },
315            RemoteCommand::TerminateChild { id } => match handle.terminate_child(&id).await {
316                Ok(()) => RemoteResponse::Ok,
317                Err(e) => RemoteResponse::Error(e.to_string()),
318            },
319            RemoteCommand::Status => {
320                let restart_strategy = handle
321                    .restart_strategy()
322                    .await
323                    .map_or_else(|_| "Unknown".to_owned(), |s| format!("{s:?}"));
324                let uptime_secs = handle.uptime().await.unwrap_or(0);
325
326                RemoteResponse::Status(SupervisorStatus {
327                    name: handle.name().to_owned(),
328                    children_count: handle.which_children().await.map(|c| c.len()).unwrap_or(0),
329                    restart_strategy,
330                    uptime_secs,
331                })
332            }
333        }
334    }
335}
336
337/// Send a message over a stream (length-prefixed rkyv)
338async fn send_message<S, T>(stream: &mut S, msg: &T) -> Result<(), DistributedError>
339where
340    S: AsyncWriteExt + Unpin,
341    T: Serialize,
342    for<'a> T: RkyvSerialize<
343        rkyv::api::high::HighSerializer<
344            rkyv::util::AlignedVec,
345            rkyv::ser::allocator::ArenaHandle<'a>,
346            rkyv::rancor::Error,
347        >,
348    >,
349{
350    let encoded = rkyv::to_bytes::<rkyv::rancor::Error>(msg)?;
351    let len = u32::try_from(encoded.len())
352        .map_err(|_| DistributedError::MessageTooLarge(encoded.len()))?;
353
354    stream.write_all(&len.to_be_bytes()).await?;
355    stream.write_all(&encoded).await?;
356    stream.flush().await?;
357
358    Ok(())
359}
360
361/// Receive a message from a stream (length-prefixed rkyv)
362#[allow(clippy::as_conversions)]
363async fn receive_message<S, T>(stream: &mut S) -> Result<T, DistributedError>
364where
365    S: AsyncReadExt + Unpin,
366    T: Archive,
367    for<'a> T::Archived: RkyvDeserialize<T, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>,
368{
369    let mut len_bytes = [0u8; 4];
370    stream.read_exact(&mut len_bytes).await?;
371    let len = u32::from_be_bytes(len_bytes) as usize;
372
373    if len > 10_000_000 {
374        return Err(DistributedError::MessageTooLarge(len));
375    }
376
377    let mut buffer = vec![0u8; len];
378    stream.read_exact(&mut buffer).await?;
379
380    // SAFETY: The buffer contains serialized rkyv data from a trusted source (our own code)
381    let decoded: T = unsafe { rkyv::from_bytes_unchecked::<T, rkyv::rancor::Error>(&buffer)? };
382    Ok(decoded)
383}
384
/// Errors that can occur in distributed operations.
///
/// Covers local transport failures (`Io`), rkyv codec failures (`Encode`/`Decode`), and
/// protocol-level problems (`RemoteError`, `UnexpectedResponse`, `MessageTooLarge`).
///
/// NOTE(review): the blanket `From<rkyv::rancor::Error>` impl maps every rkyv error to
/// `Encode`. A decode failure propagated with a bare `?` is therefore labelled `Encode` —
/// confirm decode sites map to `Decode` explicitly.
#[derive(Debug)]
pub enum DistributedError {
    /// I/O error occurred (connect, bind, accept, read or write)
    Io(std::io::Error),
    /// Serialization error
    Encode(rkyv::rancor::Error),
    /// Deserialization error
    Decode(rkyv::rancor::Error),
    /// Error reported by the remote supervisor, carrying its error text
    RemoteError(String),
    /// Received a response variant that does not match the command sent
    UnexpectedResponse,
    /// Message size exceeds maximum allowed; carries the offending size in bytes
    MessageTooLarge(usize),
}
401
402impl fmt::Display for DistributedError {
403    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404        match self {
405            DistributedError::Io(e) => write!(f, "IO error: {e}"),
406            DistributedError::Encode(e) => write!(f, "Encode error: {e}"),
407            DistributedError::Decode(e) => write!(f, "Decode error: {e}"),
408            DistributedError::RemoteError(e) => write!(f, "Remote error: {e}"),
409            DistributedError::UnexpectedResponse => write!(f, "Unexpected response from remote"),
410            DistributedError::MessageTooLarge(size) => {
411                write!(f, "Message too large: {size} bytes")
412            }
413        }
414    }
415}
416
417impl std::error::Error for DistributedError {}
418
419impl From<std::io::Error> for DistributedError {
420    fn from(e: std::io::Error) -> Self {
421        DistributedError::Io(e)
422    }
423}
424
425impl From<rkyv::rancor::Error> for DistributedError {
426    fn from(e: rkyv::rancor::Error) -> Self {
427        DistributedError::Encode(e)
428    }
429}