1#![allow(missing_docs)]
7
8use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
9use serde::{Deserialize, Serialize};
10use std::fmt;
11use std::sync::Arc;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::{TcpListener, TcpStream};
14
15#[cfg(unix)]
16use tokio::net::{UnixListener, UnixStream};
17
18use crate::{ChildInfo as SupervisorChildInfo, ChildType, RestartPolicy, SupervisorHandle, Worker};
19
20#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
22#[rkyv(derive(Debug))]
23pub enum SupervisorAddress {
24 Tcp(String),
26 Unix(String),
28}
29
30#[allow(missing_docs)]
32#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
33#[rkyv(derive(Debug))]
34#[rkyv(attr(allow(missing_docs)))]
35pub enum RemoteCommand {
36 Shutdown,
38 WhichChildren,
40 TerminateChild {
42 id: String,
44 },
45 Status,
47}
48
49#[allow(missing_docs)]
51#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
52#[rkyv(derive(Debug))]
53#[rkyv(attr(allow(missing_docs)))]
54pub enum RemoteResponse {
55 Ok,
57 Children(Vec<ChildInfo>),
59 Status(SupervisorStatus),
61 Error(String),
63}
64
65#[allow(missing_docs)]
67#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
68#[rkyv(derive(Debug))]
69#[rkyv(attr(allow(missing_docs)))]
70pub struct ChildInfo {
71 pub id: String,
73 pub child_type: ChildType,
75 pub restart_policy: Option<RestartPolicy>,
77}
78
79impl From<SupervisorChildInfo> for ChildInfo {
80 fn from(info: SupervisorChildInfo) -> Self {
81 Self {
82 id: info.id,
83 child_type: info.child_type,
84 restart_policy: info.restart_policy,
85 }
86 }
87}
88
89#[allow(missing_docs)]
91#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
92#[rkyv(derive(Debug))]
93#[rkyv(attr(allow(missing_docs)))]
94pub struct SupervisorStatus {
95 pub name: String,
97 pub children_count: usize,
99 pub restart_strategy: String,
101 pub uptime_secs: u64,
103}
104
105#[derive(Clone)]
107pub struct RemoteSupervisorHandle {
108 address: SupervisorAddress,
109}
110
111impl RemoteSupervisorHandle {
112 pub fn new(address: SupervisorAddress) -> Self {
114 Self { address }
115 }
116
117 pub async fn connect_tcp(addr: impl Into<String>) -> Result<Self, DistributedError> {
119 let address = SupervisorAddress::Tcp(addr.into());
120 Ok(Self { address })
121 }
122
123 pub async fn connect_unix(path: impl Into<String>) -> Result<Self, DistributedError> {
125 let address = SupervisorAddress::Unix(path.into());
126 Ok(Self { address })
127 }
128
129 pub async fn send_command(
131 &self,
132 cmd: RemoteCommand,
133 ) -> Result<RemoteResponse, DistributedError> {
134 match &self.address {
135 SupervisorAddress::Tcp(addr) => {
136 let mut stream = TcpStream::connect(addr).await?;
137 send_message(&mut stream, &cmd).await?;
138 receive_message(&mut stream).await
139 }
140 #[cfg(unix)]
141 SupervisorAddress::Unix(path) => {
142 let mut stream = UnixStream::connect(path).await?;
143 send_message(&mut stream, &cmd).await?;
144 receive_message(&mut stream).await
145 }
146 #[cfg(not(unix))]
147 SupervisorAddress::Unix(_) => Err(DistributedError::Io(std::io::Error::new(
148 std::io::ErrorKind::Unsupported,
149 "Unix sockets are not supported on this platform",
150 ))),
151 }
152 }
153
154 pub async fn shutdown(&self) -> Result<(), DistributedError> {
156 self.send_command(RemoteCommand::Shutdown).await?;
157 Ok(())
158 }
159
160 pub async fn which_children(&self) -> Result<Vec<ChildInfo>, DistributedError> {
162 match self.send_command(RemoteCommand::WhichChildren).await? {
163 RemoteResponse::Children(children) => Ok(children),
164 RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
165 _ => Err(DistributedError::UnexpectedResponse),
166 }
167 }
168
169 pub async fn terminate_child(&self, id: &str) -> Result<(), DistributedError> {
171 match self
172 .send_command(RemoteCommand::TerminateChild { id: id.to_owned() })
173 .await?
174 {
175 RemoteResponse::Ok => Ok(()),
176 RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
177 _ => Err(DistributedError::UnexpectedResponse),
178 }
179 }
180
181 pub async fn status(&self) -> Result<SupervisorStatus, DistributedError> {
183 match self.send_command(RemoteCommand::Status).await? {
184 RemoteResponse::Status(status) => Ok(status),
185 RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
186 _ => Err(DistributedError::UnexpectedResponse),
187 }
188 }
189}
190
191pub struct SupervisorServer<W: Worker> {
193 handle: Arc<SupervisorHandle<W>>,
194}
195
196impl<W: Worker> SupervisorServer<W> {
197 pub fn new(handle: SupervisorHandle<W>) -> Self {
199 Self {
200 handle: Arc::new(handle),
201 }
202 }
203
204 #[cfg(unix)]
206 pub async fn listen_unix(
207 self,
208 path: impl AsRef<std::path::Path>,
209 ) -> Result<(), DistributedError> {
210 let path = path.as_ref();
211 let _ = std::fs::remove_file(path); let listener = UnixListener::bind(path)?;
214 tracing::info!(path = %path.display(), "server listening on unix socket");
215
216 loop {
217 let (mut stream, _) = listener.accept().await?;
218 let handle = Arc::clone(&self.handle);
219
220 tokio::spawn(async move {
221 if let Err(e) = Self::handle_connection(&mut stream, handle).await {
222 tracing::error!(error = %e, "connection error");
223 }
224 });
225 }
226 }
227
228 pub async fn listen_tcp(self, addr: impl AsRef<str>) -> Result<(), DistributedError> {
230 let listener = TcpListener::bind(addr.as_ref()).await?;
231 tracing::info!(address = addr.as_ref(), "server listening on tcp");
232
233 loop {
234 let (mut stream, peer) = listener.accept().await?;
235 tracing::debug!(peer = ?peer, "new connection");
236 let handle = Arc::clone(&self.handle);
237
238 tokio::spawn(async move {
239 if let Err(e) = Self::handle_connection(&mut stream, handle).await {
240 tracing::error!(error = %e, "Connection error");
241 }
242 });
243 }
244 }
245
246 async fn handle_connection<S>(
247 stream: &mut S,
248 handle: Arc<SupervisorHandle<W>>,
249 ) -> Result<(), DistributedError>
250 where
251 S: AsyncReadExt + AsyncWriteExt + Unpin,
252 {
253 let command: RemoteCommand = receive_message(stream).await?;
254 let response = Self::process_command(command, &handle).await;
255 send_message(stream, &response).await?;
256 Ok(())
257 }
258
259 async fn process_command(
260 command: RemoteCommand,
261 handle: &SupervisorHandle<W>,
262 ) -> RemoteResponse {
263 match command {
264 RemoteCommand::Shutdown => match handle.shutdown().await {
265 Ok(()) => RemoteResponse::Ok,
266 Err(e) => RemoteResponse::Error(e.to_string()),
267 },
268 RemoteCommand::WhichChildren => match handle.which_children().await {
269 Ok(children) => {
270 let children: Vec<ChildInfo> = children.into_iter().map(Into::into).collect();
271 RemoteResponse::Children(children)
272 }
273 Err(e) => RemoteResponse::Error(e.to_string()),
274 },
275 RemoteCommand::TerminateChild { id } => match handle.terminate_child(&id).await {
276 Ok(()) => RemoteResponse::Ok,
277 Err(e) => RemoteResponse::Error(e.to_string()),
278 },
279 RemoteCommand::Status => {
280 let restart_strategy = handle
281 .restart_strategy()
282 .await
283 .map(|s| format!("{:?}", s))
284 .unwrap_or_else(|_| "Unknown".to_owned());
285 let uptime_secs = handle.uptime().await.unwrap_or(0);
286
287 RemoteResponse::Status(SupervisorStatus {
288 name: handle.name().to_owned(),
289 children_count: handle.which_children().await.map(|c| c.len()).unwrap_or(0),
290 restart_strategy,
291 uptime_secs,
292 })
293 }
294 }
295 }
296}
297
298async fn send_message<S, T>(stream: &mut S, msg: &T) -> Result<(), DistributedError>
300where
301 S: AsyncWriteExt + Unpin,
302 T: Serialize,
303 for<'a> T: RkyvSerialize<
304 rkyv::api::high::HighSerializer<
305 rkyv::util::AlignedVec,
306 rkyv::ser::allocator::ArenaHandle<'a>,
307 rkyv::rancor::Error,
308 >,
309 >,
310{
311 let encoded = rkyv::to_bytes::<rkyv::rancor::Error>(msg)?;
312 let len = encoded.len() as u32;
313
314 stream.write_all(&len.to_be_bytes()).await?;
315 stream.write_all(&encoded).await?;
316 stream.flush().await?;
317
318 Ok(())
319}
320
321async fn receive_message<S, T>(stream: &mut S) -> Result<T, DistributedError>
323where
324 S: AsyncReadExt + Unpin,
325 T: Archive,
326 for<'a> T::Archived: RkyvDeserialize<T, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>,
327{
328 let mut len_bytes = [0u8; 4];
329 stream.read_exact(&mut len_bytes).await?;
330 let len = u32::from_be_bytes(len_bytes) as usize;
331
332 if len > 10_000_000 {
333 return Err(DistributedError::MessageTooLarge(len));
334 }
335
336 let mut buffer = vec![0u8; len];
337 stream.read_exact(&mut buffer).await?;
338
339 let decoded: T = unsafe { rkyv::from_bytes_unchecked::<T, rkyv::rancor::Error>(&buffer)? };
341 Ok(decoded)
342}
343
344#[derive(Debug)]
346pub enum DistributedError {
347 Io(std::io::Error),
349 Encode(rkyv::rancor::Error),
351 Decode(rkyv::rancor::Error),
353 RemoteError(String),
355 UnexpectedResponse,
357 MessageTooLarge(usize),
359}
360
361impl fmt::Display for DistributedError {
362 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
363 match self {
364 DistributedError::Io(e) => write!(f, "IO error: {}", e),
365 DistributedError::Encode(e) => write!(f, "Encode error: {}", e),
366 DistributedError::Decode(e) => write!(f, "Decode error: {}", e),
367 DistributedError::RemoteError(e) => write!(f, "Remote error: {}", e),
368 DistributedError::UnexpectedResponse => write!(f, "Unexpected response from remote"),
369 DistributedError::MessageTooLarge(size) => {
370 write!(f, "Message too large: {} bytes", size)
371 }
372 }
373 }
374}
375
376impl std::error::Error for DistributedError {}
377
378impl From<std::io::Error> for DistributedError {
379 fn from(e: std::io::Error) -> Self {
380 DistributedError::Io(e)
381 }
382}
383
384impl From<rkyv::rancor::Error> for DistributedError {
385 fn from(e: rkyv::rancor::Error) -> Self {
386 DistributedError::Encode(e)
387 }
388}