Skip to main content

ash_flare/
distributed.rs

1//! Distributed supervision via TCP/Unix sockets
2//!
3//! Allows supervisors to run in separate processes (local) or on remote machines (network).
4//! Commands are serialized with rkyv for minimal overhead.
5
6#![allow(missing_docs)]
7
8use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
9use serde::{Deserialize, Serialize};
10use std::fmt;
11use std::sync::Arc;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::{TcpListener, TcpStream};
14
15#[cfg(unix)]
16use tokio::net::{UnixListener, UnixStream};
17
18use crate::{ChildInfo as SupervisorChildInfo, ChildType, RestartPolicy, SupervisorHandle, Worker};
19
20/// Remote supervisor address
21#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
22#[rkyv(derive(Debug))]
23pub enum SupervisorAddress {
24    /// TCP socket address (host:port)
25    Tcp(String),
26    /// Unix domain socket path
27    Unix(String),
28}
29
30/// Commands that can be sent to a remote supervisor
31#[allow(missing_docs)]
32#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
33#[rkyv(derive(Debug))]
34#[rkyv(attr(allow(missing_docs)))]
35pub enum RemoteCommand {
36    /// Request supervisor to shut down gracefully
37    Shutdown,
38    /// Get list of all children
39    WhichChildren,
40    /// Terminate a specific child
41    TerminateChild {
42        /// ID of the child to terminate
43        id: String,
44    },
45    /// Get supervisor status
46    Status,
47}
48
49/// Responses from remote supervisor
50#[allow(missing_docs)]
51#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
52#[rkyv(derive(Debug))]
53#[rkyv(attr(allow(missing_docs)))]
54pub enum RemoteResponse {
55    /// Command executed successfully
56    Ok,
57    /// List of child IDs
58    Children(Vec<ChildInfo>),
59    /// Supervisor status information
60    Status(SupervisorStatus),
61    /// Error occurred
62    Error(String),
63}
64
65/// Information about a child process (simplified for serialization)
66#[allow(missing_docs)]
67#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
68#[rkyv(derive(Debug))]
69#[rkyv(attr(allow(missing_docs)))]
70pub struct ChildInfo {
71    /// Unique identifier for the child
72    pub id: String,
73    /// Type of child (Worker or Supervisor)
74    pub child_type: ChildType,
75    /// Restart policy for the child (None for supervisors)
76    pub restart_policy: Option<RestartPolicy>,
77}
78
79impl From<SupervisorChildInfo> for ChildInfo {
80    fn from(info: SupervisorChildInfo) -> Self {
81        Self {
82            id: info.id,
83            child_type: info.child_type,
84            restart_policy: info.restart_policy,
85        }
86    }
87}
88
89/// Status information about a supervisor
90#[allow(missing_docs)]
91#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
92#[rkyv(derive(Debug))]
93#[rkyv(attr(allow(missing_docs)))]
94pub struct SupervisorStatus {
95    /// Name of the supervisor
96    pub name: String,
97    /// Number of children currently running
98    pub children_count: usize,
99    /// Restart strategy being used
100    pub restart_strategy: String,
101    /// Uptime in seconds
102    pub uptime_secs: u64,
103}
104
105/// Handle to communicate with a remote supervisor
106#[derive(Clone)]
107pub struct RemoteSupervisorHandle {
108    address: SupervisorAddress,
109}
110
111impl RemoteSupervisorHandle {
112    /// Create a new handle to a remote supervisor
113    pub fn new(address: SupervisorAddress) -> Self {
114        Self { address }
115    }
116
117    /// Connect to a TCP supervisor
118    pub async fn connect_tcp(addr: impl Into<String>) -> Result<Self, DistributedError> {
119        let address = SupervisorAddress::Tcp(addr.into());
120        Ok(Self { address })
121    }
122
123    /// Connect to a Unix socket supervisor
124    pub async fn connect_unix(path: impl Into<String>) -> Result<Self, DistributedError> {
125        let address = SupervisorAddress::Unix(path.into());
126        Ok(Self { address })
127    }
128
129    /// Send a command to the remote supervisor
130    pub async fn send_command(
131        &self,
132        cmd: RemoteCommand,
133    ) -> Result<RemoteResponse, DistributedError> {
134        match &self.address {
135            SupervisorAddress::Tcp(addr) => {
136                let mut stream = TcpStream::connect(addr).await?;
137                send_message(&mut stream, &cmd).await?;
138                receive_message(&mut stream).await
139            }
140            #[cfg(unix)]
141            SupervisorAddress::Unix(path) => {
142                let mut stream = UnixStream::connect(path).await?;
143                send_message(&mut stream, &cmd).await?;
144                receive_message(&mut stream).await
145            }
146            #[cfg(not(unix))]
147            SupervisorAddress::Unix(_) => Err(DistributedError::Io(std::io::Error::new(
148                std::io::ErrorKind::Unsupported,
149                "Unix sockets are not supported on this platform",
150            ))),
151        }
152    }
153
154    /// Shutdown the remote supervisor
155    pub async fn shutdown(&self) -> Result<(), DistributedError> {
156        self.send_command(RemoteCommand::Shutdown).await?;
157        Ok(())
158    }
159
160    /// Get list of children from remote supervisor
161    pub async fn which_children(&self) -> Result<Vec<ChildInfo>, DistributedError> {
162        match self.send_command(RemoteCommand::WhichChildren).await? {
163            RemoteResponse::Children(children) => Ok(children),
164            RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
165            _ => Err(DistributedError::UnexpectedResponse),
166        }
167    }
168
169    /// Terminate a child on the remote supervisor
170    pub async fn terminate_child(&self, id: &str) -> Result<(), DistributedError> {
171        match self
172            .send_command(RemoteCommand::TerminateChild { id: id.to_owned() })
173            .await?
174        {
175            RemoteResponse::Ok => Ok(()),
176            RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
177            _ => Err(DistributedError::UnexpectedResponse),
178        }
179    }
180
181    /// Get status from remote supervisor
182    pub async fn status(&self) -> Result<SupervisorStatus, DistributedError> {
183        match self.send_command(RemoteCommand::Status).await? {
184            RemoteResponse::Status(status) => Ok(status),
185            RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
186            _ => Err(DistributedError::UnexpectedResponse),
187        }
188    }
189}
190
191/// Server that wraps a SupervisorHandle and accepts remote commands
192pub struct SupervisorServer<W: Worker> {
193    handle: Arc<SupervisorHandle<W>>,
194}
195
196impl<W: Worker> SupervisorServer<W> {
197    /// Create a new supervisor server wrapping a SupervisorHandle
198    pub fn new(handle: SupervisorHandle<W>) -> Self {
199        Self {
200            handle: Arc::new(handle),
201        }
202    }
203
204    /// Start listening on a Unix socket (Unix only)
205    #[cfg(unix)]
206    pub async fn listen_unix(
207        self,
208        path: impl AsRef<std::path::Path>,
209    ) -> Result<(), DistributedError> {
210        let path = path.as_ref();
211        let _ = std::fs::remove_file(path); // Clean up old socket
212
213        let listener = UnixListener::bind(path)?;
214        tracing::info!(path = %path.display(), "server listening on unix socket");
215
216        loop {
217            let (mut stream, _) = listener.accept().await?;
218            let handle = Arc::clone(&self.handle);
219
220            tokio::spawn(async move {
221                if let Err(e) = Self::handle_connection(&mut stream, handle).await {
222                    tracing::error!(error = %e, "connection error");
223                }
224            });
225        }
226    }
227
228    /// Start listening on a TCP socket
229    pub async fn listen_tcp(self, addr: impl AsRef<str>) -> Result<(), DistributedError> {
230        let listener = TcpListener::bind(addr.as_ref()).await?;
231        tracing::info!(address = addr.as_ref(), "server listening on tcp");
232
233        loop {
234            let (mut stream, peer) = listener.accept().await?;
235            tracing::debug!(peer = ?peer, "new connection");
236            let handle = Arc::clone(&self.handle);
237
238            tokio::spawn(async move {
239                if let Err(e) = Self::handle_connection(&mut stream, handle).await {
240                    tracing::error!(error = %e, "Connection error");
241                }
242            });
243        }
244    }
245
246    async fn handle_connection<S>(
247        stream: &mut S,
248        handle: Arc<SupervisorHandle<W>>,
249    ) -> Result<(), DistributedError>
250    where
251        S: AsyncReadExt + AsyncWriteExt + Unpin,
252    {
253        let command: RemoteCommand = receive_message(stream).await?;
254        let response = Self::process_command(command, &handle).await;
255        send_message(stream, &response).await?;
256        Ok(())
257    }
258
259    async fn process_command(
260        command: RemoteCommand,
261        handle: &SupervisorHandle<W>,
262    ) -> RemoteResponse {
263        match command {
264            RemoteCommand::Shutdown => match handle.shutdown().await {
265                Ok(()) => RemoteResponse::Ok,
266                Err(e) => RemoteResponse::Error(e.to_string()),
267            },
268            RemoteCommand::WhichChildren => match handle.which_children().await {
269                Ok(children) => {
270                    let children: Vec<ChildInfo> = children.into_iter().map(Into::into).collect();
271                    RemoteResponse::Children(children)
272                }
273                Err(e) => RemoteResponse::Error(e.to_string()),
274            },
275            RemoteCommand::TerminateChild { id } => match handle.terminate_child(&id).await {
276                Ok(()) => RemoteResponse::Ok,
277                Err(e) => RemoteResponse::Error(e.to_string()),
278            },
279            RemoteCommand::Status => {
280                let restart_strategy = handle
281                    .restart_strategy()
282                    .await
283                    .map(|s| format!("{:?}", s))
284                    .unwrap_or_else(|_| "Unknown".to_owned());
285                let uptime_secs = handle.uptime().await.unwrap_or(0);
286
287                RemoteResponse::Status(SupervisorStatus {
288                    name: handle.name().to_owned(),
289                    children_count: handle.which_children().await.map(|c| c.len()).unwrap_or(0),
290                    restart_strategy,
291                    uptime_secs,
292                })
293            }
294        }
295    }
296}
297
298/// Send a message over a stream (length-prefixed rkyv)
299async fn send_message<S, T>(stream: &mut S, msg: &T) -> Result<(), DistributedError>
300where
301    S: AsyncWriteExt + Unpin,
302    T: Serialize,
303    for<'a> T: RkyvSerialize<
304        rkyv::api::high::HighSerializer<
305            rkyv::util::AlignedVec,
306            rkyv::ser::allocator::ArenaHandle<'a>,
307            rkyv::rancor::Error,
308        >,
309    >,
310{
311    let encoded = rkyv::to_bytes::<rkyv::rancor::Error>(msg)?;
312    let len = encoded.len() as u32;
313
314    stream.write_all(&len.to_be_bytes()).await?;
315    stream.write_all(&encoded).await?;
316    stream.flush().await?;
317
318    Ok(())
319}
320
321/// Receive a message from a stream (length-prefixed rkyv)
322async fn receive_message<S, T>(stream: &mut S) -> Result<T, DistributedError>
323where
324    S: AsyncReadExt + Unpin,
325    T: Archive,
326    for<'a> T::Archived: RkyvDeserialize<T, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>,
327{
328    let mut len_bytes = [0u8; 4];
329    stream.read_exact(&mut len_bytes).await?;
330    let len = u32::from_be_bytes(len_bytes) as usize;
331
332    if len > 10_000_000 {
333        return Err(DistributedError::MessageTooLarge(len));
334    }
335
336    let mut buffer = vec![0u8; len];
337    stream.read_exact(&mut buffer).await?;
338
339    // SAFETY: The buffer contains serialized rkyv data from a trusted source (our own code)
340    let decoded: T = unsafe { rkyv::from_bytes_unchecked::<T, rkyv::rancor::Error>(&buffer)? };
341    Ok(decoded)
342}
343
344/// Errors that can occur in distributed operations
345#[derive(Debug)]
346pub enum DistributedError {
347    /// I/O error occurred
348    Io(std::io::Error),
349    /// Serialization error
350    Encode(rkyv::rancor::Error),
351    /// Deserialization error
352    Decode(rkyv::rancor::Error),
353    /// Error from remote supervisor
354    RemoteError(String),
355    /// Received unexpected response type
356    UnexpectedResponse,
357    /// Message size exceeds maximum allowed
358    MessageTooLarge(usize),
359}
360
361impl fmt::Display for DistributedError {
362    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
363        match self {
364            DistributedError::Io(e) => write!(f, "IO error: {}", e),
365            DistributedError::Encode(e) => write!(f, "Encode error: {}", e),
366            DistributedError::Decode(e) => write!(f, "Decode error: {}", e),
367            DistributedError::RemoteError(e) => write!(f, "Remote error: {}", e),
368            DistributedError::UnexpectedResponse => write!(f, "Unexpected response from remote"),
369            DistributedError::MessageTooLarge(size) => {
370                write!(f, "Message too large: {} bytes", size)
371            }
372        }
373    }
374}
375
376impl std::error::Error for DistributedError {}
377
378impl From<std::io::Error> for DistributedError {
379    fn from(e: std::io::Error) -> Self {
380        DistributedError::Io(e)
381    }
382}
383
384impl From<rkyv::rancor::Error> for DistributedError {
385    fn from(e: rkyv::rancor::Error) -> Self {
386        DistributedError::Encode(e)
387    }
388}