Skip to main content

ash_flare/
distributed.rs

1//! Distributed supervision via TCP/Unix sockets
2//!
3//! Allows supervisors to run in separate processes (local) or on remote machines (network).
4//! Commands are serialized with rkyv for minimal overhead.
5
6#![allow(missing_docs)]
7
8use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
9use serde::{Deserialize, Serialize};
10use std::fmt;
11use std::sync::Arc;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::{TcpListener, TcpStream};
14
15#[cfg(unix)]
16use tokio::net::{UnixListener, UnixStream};
17
18use crate::{ChildInfo as SupervisorChildInfo, ChildType, RestartPolicy, SupervisorHandle, Worker};
19
20/// Remote supervisor address
21#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
22#[rkyv(derive(Debug))]
23pub enum SupervisorAddress {
24    /// TCP socket address (host:port)
25    Tcp(String),
26    /// Unix domain socket path
27    Unix(String),
28}
29
30/// Commands that can be sent to a remote supervisor
31#[allow(missing_docs)]
32#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
33#[rkyv(derive(Debug))]
34#[rkyv(attr(allow(missing_docs)))]
35pub enum RemoteCommand {
36    /// Request supervisor to shut down gracefully
37    Shutdown,
38    /// Get list of all children
39    WhichChildren,
40    /// Terminate a specific child
41    TerminateChild {
42        /// ID of the child to terminate
43        id: String,
44    },
45    /// Get supervisor status
46    Status,
47}
48
49/// Responses from remote supervisor
50#[allow(missing_docs)]
51#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
52#[rkyv(derive(Debug))]
53#[rkyv(attr(allow(missing_docs)))]
54pub enum RemoteResponse {
55    /// Command executed successfully
56    Ok,
57    /// List of child IDs
58    Children(Vec<ChildInfo>),
59    /// Supervisor status information
60    Status(SupervisorStatus),
61    /// Error occurred
62    Error(String),
63}
64
65/// Information about a child process (simplified for serialization)
66#[allow(missing_docs)]
67#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
68#[rkyv(derive(Debug))]
69#[rkyv(attr(allow(missing_docs)))]
70pub struct ChildInfo {
71    /// Unique identifier for the child
72    pub id: String,
73    /// Type of child (Worker or Supervisor)
74    pub child_type: ChildType,
75    /// Restart policy for the child (None for supervisors)
76    pub restart_policy: Option<RestartPolicy>,
77}
78
79impl From<SupervisorChildInfo> for ChildInfo {
80    fn from(info: SupervisorChildInfo) -> Self {
81        Self {
82            id: info.id,
83            child_type: info.child_type,
84            restart_policy: info.restart_policy,
85        }
86    }
87}
88
89/// Status information about a supervisor
90#[allow(missing_docs)]
91#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
92#[rkyv(derive(Debug))]
93#[rkyv(attr(allow(missing_docs)))]
94pub struct SupervisorStatus {
95    /// Name of the supervisor
96    pub name: String,
97    /// Number of children currently running
98    pub children_count: usize,
99    /// Restart strategy being used
100    pub restart_strategy: String,
101    /// Uptime in seconds
102    pub uptime_secs: u64,
103}
104
105/// Handle to communicate with a remote supervisor
106pub struct RemoteSupervisorHandle {
107    address: SupervisorAddress,
108}
109
110impl RemoteSupervisorHandle {
111    /// Create a new handle to a remote supervisor
112    pub fn new(address: SupervisorAddress) -> Self {
113        Self { address }
114    }
115
116    /// Connect to a TCP supervisor
117    pub async fn connect_tcp(addr: impl Into<String>) -> Result<Self, DistributedError> {
118        let address = SupervisorAddress::Tcp(addr.into());
119        Ok(Self { address })
120    }
121
122    /// Connect to a Unix socket supervisor
123    pub async fn connect_unix(path: impl Into<String>) -> Result<Self, DistributedError> {
124        let address = SupervisorAddress::Unix(path.into());
125        Ok(Self { address })
126    }
127
128    /// Send a command to the remote supervisor
129    pub async fn send_command(
130        &self,
131        cmd: RemoteCommand,
132    ) -> Result<RemoteResponse, DistributedError> {
133        match &self.address {
134            SupervisorAddress::Tcp(addr) => {
135                let mut stream = TcpStream::connect(addr).await?;
136                send_message(&mut stream, &cmd).await?;
137                receive_message(&mut stream).await
138            }
139            #[cfg(unix)]
140            SupervisorAddress::Unix(path) => {
141                let mut stream = UnixStream::connect(path).await?;
142                send_message(&mut stream, &cmd).await?;
143                receive_message(&mut stream).await
144            }
145            #[cfg(not(unix))]
146            SupervisorAddress::Unix(_) => Err(DistributedError::Io(std::io::Error::new(
147                std::io::ErrorKind::Unsupported,
148                "Unix sockets are not supported on this platform",
149            ))),
150        }
151    }
152
153    /// Shutdown the remote supervisor
154    pub async fn shutdown(&self) -> Result<(), DistributedError> {
155        self.send_command(RemoteCommand::Shutdown).await?;
156        Ok(())
157    }
158
159    /// Get list of children from remote supervisor
160    pub async fn which_children(&self) -> Result<Vec<ChildInfo>, DistributedError> {
161        match self.send_command(RemoteCommand::WhichChildren).await? {
162            RemoteResponse::Children(children) => Ok(children),
163            RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
164            _ => Err(DistributedError::UnexpectedResponse),
165        }
166    }
167
168    /// Terminate a child on the remote supervisor
169    pub async fn terminate_child(&self, id: &str) -> Result<(), DistributedError> {
170        match self
171            .send_command(RemoteCommand::TerminateChild { id: id.to_owned() })
172            .await?
173        {
174            RemoteResponse::Ok => Ok(()),
175            RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
176            _ => Err(DistributedError::UnexpectedResponse),
177        }
178    }
179
180    /// Get status from remote supervisor
181    pub async fn status(&self) -> Result<SupervisorStatus, DistributedError> {
182        match self.send_command(RemoteCommand::Status).await? {
183            RemoteResponse::Status(status) => Ok(status),
184            RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
185            _ => Err(DistributedError::UnexpectedResponse),
186        }
187    }
188}
189
190/// Server that wraps a SupervisorHandle and accepts remote commands
191pub struct SupervisorServer<W: Worker> {
192    handle: Arc<SupervisorHandle<W>>,
193}
194
195impl<W: Worker> SupervisorServer<W> {
196    /// Create a new supervisor server wrapping a SupervisorHandle
197    pub fn new(handle: SupervisorHandle<W>) -> Self {
198        Self {
199            handle: Arc::new(handle),
200        }
201    }
202
203    /// Start listening on a Unix socket (Unix only)
204    #[cfg(unix)]
205    pub async fn listen_unix(
206        self,
207        path: impl AsRef<std::path::Path>,
208    ) -> Result<(), DistributedError> {
209        let path = path.as_ref();
210        let _ = std::fs::remove_file(path); // Clean up old socket
211
212        let listener = UnixListener::bind(path)?;
213        slog::info!(slog_scope::logger(), "server listening on unix socket";
214            "path" => %path.display()
215        );
216
217        loop {
218            let (mut stream, _) = listener.accept().await?;
219            let handle = Arc::clone(&self.handle);
220
221            tokio::spawn(async move {
222                if let Err(e) = Self::handle_connection(&mut stream, handle).await {
223                    slog::error!(slog_scope::logger(), "connection error";
224                        "error" => %e
225                    );
226                }
227            });
228        }
229    }
230
231    /// Start listening on a TCP socket
232    pub async fn listen_tcp(self, addr: impl AsRef<str>) -> Result<(), DistributedError> {
233        let listener = TcpListener::bind(addr.as_ref()).await?;
234        slog::info!(slog_scope::logger(), "server listening on tcp";
235            "address" => addr.as_ref()
236        );
237
238        loop {
239            let (mut stream, peer) = listener.accept().await?;
240            slog::debug!(slog_scope::logger(), "new connection";
241                "peer" => ?peer
242            );
243            let handle = Arc::clone(&self.handle);
244
245            tokio::spawn(async move {
246                if let Err(e) = Self::handle_connection(&mut stream, handle).await {
247                    slog::error!(slog_scope::logger(), "Connection error";
248                        "error" => %e
249                    );
250                }
251            });
252        }
253    }
254
255    async fn handle_connection<S>(
256        stream: &mut S,
257        handle: Arc<SupervisorHandle<W>>,
258    ) -> Result<(), DistributedError>
259    where
260        S: AsyncReadExt + AsyncWriteExt + Unpin,
261    {
262        let command: RemoteCommand = receive_message(stream).await?;
263        let response = Self::process_command(command, &handle).await;
264        send_message(stream, &response).await?;
265        Ok(())
266    }
267
268    async fn process_command(
269        command: RemoteCommand,
270        handle: &SupervisorHandle<W>,
271    ) -> RemoteResponse {
272        match command {
273            RemoteCommand::Shutdown => match handle.shutdown().await {
274                Ok(()) => RemoteResponse::Ok,
275                Err(e) => RemoteResponse::Error(e.to_string()),
276            },
277            RemoteCommand::WhichChildren => match handle.which_children().await {
278                Ok(children) => {
279                    let children: Vec<ChildInfo> = children.into_iter().map(Into::into).collect();
280                    RemoteResponse::Children(children)
281                }
282                Err(e) => RemoteResponse::Error(e.to_string()),
283            },
284            RemoteCommand::TerminateChild { id } => match handle.terminate_child(&id).await {
285                Ok(()) => RemoteResponse::Ok,
286                Err(e) => RemoteResponse::Error(e.to_string()),
287            },
288            RemoteCommand::Status => {
289                let restart_strategy = handle
290                    .restart_strategy()
291                    .await
292                    .map(|s| format!("{:?}", s))
293                    .unwrap_or_else(|_| "Unknown".to_owned());
294                let uptime_secs = handle.uptime().await.unwrap_or(0);
295
296                RemoteResponse::Status(SupervisorStatus {
297                    name: handle.name().to_owned(),
298                    children_count: handle.which_children().await.map(|c| c.len()).unwrap_or(0),
299                    restart_strategy,
300                    uptime_secs,
301                })
302            }
303        }
304    }
305}
306
307/// Send a message over a stream (length-prefixed rkyv)
308async fn send_message<S, T>(stream: &mut S, msg: &T) -> Result<(), DistributedError>
309where
310    S: AsyncWriteExt + Unpin,
311    T: Serialize,
312    for<'a> T: RkyvSerialize<
313        rkyv::api::high::HighSerializer<
314            rkyv::util::AlignedVec,
315            rkyv::ser::allocator::ArenaHandle<'a>,
316            rkyv::rancor::Error,
317        >,
318    >,
319{
320    let encoded = rkyv::to_bytes::<rkyv::rancor::Error>(msg)?;
321    let len = encoded.len() as u32;
322
323    stream.write_all(&len.to_be_bytes()).await?;
324    stream.write_all(&encoded).await?;
325    stream.flush().await?;
326
327    Ok(())
328}
329
330/// Receive a message from a stream (length-prefixed rkyv)
331async fn receive_message<S, T>(stream: &mut S) -> Result<T, DistributedError>
332where
333    S: AsyncReadExt + Unpin,
334    T: Archive,
335    for<'a> T::Archived: RkyvDeserialize<T, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>,
336{
337    let mut len_bytes = [0u8; 4];
338    stream.read_exact(&mut len_bytes).await?;
339    let len = u32::from_be_bytes(len_bytes) as usize;
340
341    if len > 10_000_000 {
342        return Err(DistributedError::MessageTooLarge(len));
343    }
344
345    let mut buffer = vec![0u8; len];
346    stream.read_exact(&mut buffer).await?;
347
348    // SAFETY: The buffer contains serialized rkyv data from a trusted source (our own code)
349    let decoded: T = unsafe { rkyv::from_bytes_unchecked::<T, rkyv::rancor::Error>(&buffer)? };
350    Ok(decoded)
351}
352
353/// Errors that can occur in distributed operations
354#[derive(Debug)]
355pub enum DistributedError {
356    /// I/O error occurred
357    Io(std::io::Error),
358    /// Serialization error
359    Encode(rkyv::rancor::Error),
360    /// Deserialization error
361    Decode(rkyv::rancor::Error),
362    /// Error from remote supervisor
363    RemoteError(String),
364    /// Received unexpected response type
365    UnexpectedResponse,
366    /// Message size exceeds maximum allowed
367    MessageTooLarge(usize),
368}
369
370impl fmt::Display for DistributedError {
371    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
372        match self {
373            DistributedError::Io(e) => write!(f, "IO error: {}", e),
374            DistributedError::Encode(e) => write!(f, "Encode error: {}", e),
375            DistributedError::Decode(e) => write!(f, "Decode error: {}", e),
376            DistributedError::RemoteError(e) => write!(f, "Remote error: {}", e),
377            DistributedError::UnexpectedResponse => write!(f, "Unexpected response from remote"),
378            DistributedError::MessageTooLarge(size) => {
379                write!(f, "Message too large: {} bytes", size)
380            }
381        }
382    }
383}
384
385impl std::error::Error for DistributedError {}
386
387impl From<std::io::Error> for DistributedError {
388    fn from(e: std::io::Error) -> Self {
389        DistributedError::Io(e)
390    }
391}
392
393impl From<rkyv::rancor::Error> for DistributedError {
394    fn from(e: rkyv::rancor::Error) -> Self {
395        DistributedError::Encode(e)
396    }
397}