1#![allow(missing_docs)]
7
8use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
9use serde::{Deserialize, Serialize};
10use std::fmt;
11use std::sync::Arc;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::{TcpListener, TcpStream};
14
15#[cfg(unix)]
16use tokio::net::{UnixListener, UnixStream};
17
18use crate::{ChildInfo as SupervisorChildInfo, ChildType, RestartPolicy, SupervisorHandle, Worker};
19
20#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
22#[rkyv(derive(Debug))]
23pub enum SupervisorAddress {
24 Tcp(String),
26 Unix(String),
28}
29
30#[allow(missing_docs)]
32#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
33#[rkyv(derive(Debug))]
34#[rkyv(attr(allow(missing_docs)))]
35pub enum RemoteCommand {
36 Shutdown,
38 WhichChildren,
40 TerminateChild {
42 id: String,
44 },
45 Status,
47}
48
49#[allow(missing_docs)]
51#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
52#[rkyv(derive(Debug))]
53#[rkyv(attr(allow(missing_docs)))]
54pub enum RemoteResponse {
55 Ok,
57 Children(Vec<ChildInfo>),
59 Status(SupervisorStatus),
61 Error(String),
63}
64
65#[allow(missing_docs)]
67#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
68#[rkyv(derive(Debug))]
69#[rkyv(attr(allow(missing_docs)))]
70pub struct ChildInfo {
71 pub id: String,
73 pub child_type: ChildType,
75 pub restart_policy: Option<RestartPolicy>,
77}
78
79impl From<SupervisorChildInfo> for ChildInfo {
80 fn from(info: SupervisorChildInfo) -> Self {
81 Self {
82 id: info.id,
83 child_type: info.child_type,
84 restart_policy: info.restart_policy,
85 }
86 }
87}
88
89#[allow(missing_docs)]
91#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
92#[rkyv(derive(Debug))]
93#[rkyv(attr(allow(missing_docs)))]
94pub struct SupervisorStatus {
95 pub name: String,
97 pub children_count: usize,
99 pub restart_strategy: String,
101 pub uptime_secs: u64,
103}
104
105#[derive(Clone)]
107pub struct RemoteSupervisorHandle {
108 address: SupervisorAddress,
109}
110
111impl RemoteSupervisorHandle {
112 #[must_use]
114 pub fn new(address: SupervisorAddress) -> Self {
115 Self { address }
116 }
117
118 #[allow(clippy::unused_async)]
124 pub async fn connect_tcp(addr: impl Into<String>) -> Result<Self, DistributedError> {
125 let address = SupervisorAddress::Tcp(addr.into());
126 Ok(Self { address })
127 }
128
129 #[allow(clippy::unused_async)]
135 pub async fn connect_unix(path: impl Into<String>) -> Result<Self, DistributedError> {
136 let address = SupervisorAddress::Unix(path.into());
137 Ok(Self { address })
138 }
139
140 pub async fn send_command(
146 &self,
147 cmd: RemoteCommand,
148 ) -> Result<RemoteResponse, DistributedError> {
149 match &self.address {
150 SupervisorAddress::Tcp(addr) => {
151 let mut stream = TcpStream::connect(addr).await?;
152 send_message(&mut stream, &cmd).await?;
153 receive_message(&mut stream).await
154 }
155 #[cfg(unix)]
156 SupervisorAddress::Unix(path) => {
157 let mut stream = UnixStream::connect(path).await?;
158 send_message(&mut stream, &cmd).await?;
159 receive_message(&mut stream).await
160 }
161 #[cfg(not(unix))]
162 SupervisorAddress::Unix(_) => Err(DistributedError::Io(std::io::Error::new(
163 std::io::ErrorKind::Unsupported,
164 "Unix sockets are not supported on this platform",
165 ))),
166 }
167 }
168
169 pub async fn shutdown(&self) -> Result<(), DistributedError> {
175 self.send_command(RemoteCommand::Shutdown).await?;
176 Ok(())
177 }
178
179 pub async fn which_children(&self) -> Result<Vec<ChildInfo>, DistributedError> {
185 match self.send_command(RemoteCommand::WhichChildren).await? {
186 RemoteResponse::Children(children) => Ok(children),
187 RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
188 _ => Err(DistributedError::UnexpectedResponse),
189 }
190 }
191
192 pub async fn terminate_child(&self, id: &str) -> Result<(), DistributedError> {
198 match self
199 .send_command(RemoteCommand::TerminateChild { id: id.to_owned() })
200 .await?
201 {
202 RemoteResponse::Ok => Ok(()),
203 RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
204 _ => Err(DistributedError::UnexpectedResponse),
205 }
206 }
207
208 pub async fn status(&self) -> Result<SupervisorStatus, DistributedError> {
214 match self.send_command(RemoteCommand::Status).await? {
215 RemoteResponse::Status(status) => Ok(status),
216 RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
217 _ => Err(DistributedError::UnexpectedResponse),
218 }
219 }
220}
221
222pub struct SupervisorServer<W: Worker> {
224 handle: Arc<SupervisorHandle<W>>,
225}
226
227impl<W: Worker> SupervisorServer<W> {
228 #[must_use]
230 pub fn new(handle: SupervisorHandle<W>) -> Self {
231 Self {
232 handle: Arc::new(handle),
233 }
234 }
235
236 #[cfg(unix)]
242 pub async fn listen_unix(
243 self,
244 path: impl AsRef<std::path::Path>,
245 ) -> Result<(), DistributedError> {
246 let socket_path = path.as_ref();
247 let _remove_result = std::fs::remove_file(socket_path); let listener = UnixListener::bind(socket_path)?;
250 tracing::info!(path = %socket_path.display(), "server listening on unix socket");
251
252 loop {
253 let (mut stream, _) = listener.accept().await?;
254 let handle = Arc::clone(&self.handle);
255
256 tokio::spawn(async move {
257 if let Err(e) = Self::handle_connection(&mut stream, handle).await {
258 tracing::error!(error = %e, "connection error");
259 }
260 });
261 }
262 }
263
264 pub async fn listen_tcp(self, addr: impl AsRef<str>) -> Result<(), DistributedError> {
270 let listener = TcpListener::bind(addr.as_ref()).await?;
271 tracing::info!(address = addr.as_ref(), "server listening on tcp");
272
273 loop {
274 let (mut stream, peer) = listener.accept().await?;
275 tracing::debug!(peer = ?peer, "new connection");
276 let handle = Arc::clone(&self.handle);
277
278 tokio::spawn(async move {
279 if let Err(e) = Self::handle_connection(&mut stream, handle).await {
280 tracing::error!(error = %e, "Connection error");
281 }
282 });
283 }
284 }
285
286 async fn handle_connection<S>(
287 stream: &mut S,
288 handle: Arc<SupervisorHandle<W>>,
289 ) -> Result<(), DistributedError>
290 where
291 S: AsyncReadExt + AsyncWriteExt + Unpin,
292 {
293 let command: RemoteCommand = receive_message(stream).await?;
294 let response = Self::process_command(command, &handle).await;
295 send_message(stream, &response).await?;
296 Ok(())
297 }
298
299 async fn process_command(
300 command: RemoteCommand,
301 handle: &SupervisorHandle<W>,
302 ) -> RemoteResponse {
303 match command {
304 RemoteCommand::Shutdown => match handle.shutdown().await {
305 Ok(()) => RemoteResponse::Ok,
306 Err(e) => RemoteResponse::Error(e.to_string()),
307 },
308 RemoteCommand::WhichChildren => match handle.which_children().await {
309 Ok(children) => {
310 let child_list: Vec<ChildInfo> = children.into_iter().map(Into::into).collect();
311 RemoteResponse::Children(child_list)
312 }
313 Err(e) => RemoteResponse::Error(e.to_string()),
314 },
315 RemoteCommand::TerminateChild { id } => match handle.terminate_child(&id).await {
316 Ok(()) => RemoteResponse::Ok,
317 Err(e) => RemoteResponse::Error(e.to_string()),
318 },
319 RemoteCommand::Status => {
320 let restart_strategy = handle
321 .restart_strategy()
322 .await
323 .map_or_else(|_| "Unknown".to_owned(), |s| format!("{s:?}"));
324 let uptime_secs = handle.uptime().await.unwrap_or(0);
325
326 RemoteResponse::Status(SupervisorStatus {
327 name: handle.name().to_owned(),
328 children_count: handle.which_children().await.map(|c| c.len()).unwrap_or(0),
329 restart_strategy,
330 uptime_secs,
331 })
332 }
333 }
334 }
335}
336
337async fn send_message<S, T>(stream: &mut S, msg: &T) -> Result<(), DistributedError>
339where
340 S: AsyncWriteExt + Unpin,
341 T: Serialize,
342 for<'a> T: RkyvSerialize<
343 rkyv::api::high::HighSerializer<
344 rkyv::util::AlignedVec,
345 rkyv::ser::allocator::ArenaHandle<'a>,
346 rkyv::rancor::Error,
347 >,
348 >,
349{
350 let encoded = rkyv::to_bytes::<rkyv::rancor::Error>(msg)?;
351 let len = u32::try_from(encoded.len())
352 .map_err(|_| DistributedError::MessageTooLarge(encoded.len()))?;
353
354 stream.write_all(&len.to_be_bytes()).await?;
355 stream.write_all(&encoded).await?;
356 stream.flush().await?;
357
358 Ok(())
359}
360
361#[allow(clippy::as_conversions)]
363async fn receive_message<S, T>(stream: &mut S) -> Result<T, DistributedError>
364where
365 S: AsyncReadExt + Unpin,
366 T: Archive,
367 for<'a> T::Archived: RkyvDeserialize<T, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>,
368{
369 let mut len_bytes = [0u8; 4];
370 stream.read_exact(&mut len_bytes).await?;
371 let len = u32::from_be_bytes(len_bytes) as usize;
372
373 if len > 10_000_000 {
374 return Err(DistributedError::MessageTooLarge(len));
375 }
376
377 let mut buffer = vec![0u8; len];
378 stream.read_exact(&mut buffer).await?;
379
380 let decoded: T = unsafe { rkyv::from_bytes_unchecked::<T, rkyv::rancor::Error>(&buffer)? };
382 Ok(decoded)
383}
384
385#[derive(Debug)]
387pub enum DistributedError {
388 Io(std::io::Error),
390 Encode(rkyv::rancor::Error),
392 Decode(rkyv::rancor::Error),
394 RemoteError(String),
396 UnexpectedResponse,
398 MessageTooLarge(usize),
400}
401
402impl fmt::Display for DistributedError {
403 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404 match self {
405 DistributedError::Io(e) => write!(f, "IO error: {e}"),
406 DistributedError::Encode(e) => write!(f, "Encode error: {e}"),
407 DistributedError::Decode(e) => write!(f, "Decode error: {e}"),
408 DistributedError::RemoteError(e) => write!(f, "Remote error: {e}"),
409 DistributedError::UnexpectedResponse => write!(f, "Unexpected response from remote"),
410 DistributedError::MessageTooLarge(size) => {
411 write!(f, "Message too large: {size} bytes")
412 }
413 }
414 }
415}
416
417impl std::error::Error for DistributedError {}
418
419impl From<std::io::Error> for DistributedError {
420 fn from(e: std::io::Error) -> Self {
421 DistributedError::Io(e)
422 }
423}
424
425impl From<rkyv::rancor::Error> for DistributedError {
426 fn from(e: rkyv::rancor::Error) -> Self {
427 DistributedError::Encode(e)
428 }
429}