1#![allow(missing_docs)]
7
8use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
9use serde::{Deserialize, Serialize};
10use std::fmt;
11use std::sync::Arc;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::{TcpListener, TcpStream};
14
15#[cfg(unix)]
16use tokio::net::{UnixListener, UnixStream};
17
18use crate::{ChildInfo as SupervisorChildInfo, ChildType, RestartPolicy, SupervisorHandle, Worker};
19
20#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
22#[rkyv(derive(Debug))]
23pub enum SupervisorAddress {
24 Tcp(String),
26 Unix(String),
28}
29
30#[allow(missing_docs)]
32#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
33#[rkyv(derive(Debug))]
34#[rkyv(attr(allow(missing_docs)))]
35pub enum RemoteCommand {
36 Shutdown,
38 WhichChildren,
40 TerminateChild {
42 id: String,
44 },
45 Status,
47}
48
49#[allow(missing_docs)]
51#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
52#[rkyv(derive(Debug))]
53#[rkyv(attr(allow(missing_docs)))]
54pub enum RemoteResponse {
55 Ok,
57 Children(Vec<ChildInfo>),
59 Status(SupervisorStatus),
61 Error(String),
63}
64
65#[allow(missing_docs)]
67#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
68#[rkyv(derive(Debug))]
69#[rkyv(attr(allow(missing_docs)))]
70pub struct ChildInfo {
71 pub id: String,
73 pub child_type: ChildType,
75 pub restart_policy: Option<RestartPolicy>,
77}
78
79impl From<SupervisorChildInfo> for ChildInfo {
80 fn from(info: SupervisorChildInfo) -> Self {
81 Self {
82 id: info.id,
83 child_type: info.child_type,
84 restart_policy: info.restart_policy,
85 }
86 }
87}
88
89#[allow(missing_docs)]
91#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
92#[rkyv(derive(Debug))]
93#[rkyv(attr(allow(missing_docs)))]
94pub struct SupervisorStatus {
95 pub name: String,
97 pub children_count: usize,
99 pub restart_strategy: String,
101 pub uptime_secs: u64,
103}
104
105pub struct RemoteSupervisorHandle {
107 address: SupervisorAddress,
108}
109
110impl RemoteSupervisorHandle {
111 pub fn new(address: SupervisorAddress) -> Self {
113 Self { address }
114 }
115
116 pub async fn connect_tcp(addr: impl Into<String>) -> Result<Self, DistributedError> {
118 let address = SupervisorAddress::Tcp(addr.into());
119 Ok(Self { address })
120 }
121
122 pub async fn connect_unix(path: impl Into<String>) -> Result<Self, DistributedError> {
124 let address = SupervisorAddress::Unix(path.into());
125 Ok(Self { address })
126 }
127
128 pub async fn send_command(
130 &self,
131 cmd: RemoteCommand,
132 ) -> Result<RemoteResponse, DistributedError> {
133 match &self.address {
134 SupervisorAddress::Tcp(addr) => {
135 let mut stream = TcpStream::connect(addr).await?;
136 send_message(&mut stream, &cmd).await?;
137 receive_message(&mut stream).await
138 }
139 #[cfg(unix)]
140 SupervisorAddress::Unix(path) => {
141 let mut stream = UnixStream::connect(path).await?;
142 send_message(&mut stream, &cmd).await?;
143 receive_message(&mut stream).await
144 }
145 #[cfg(not(unix))]
146 SupervisorAddress::Unix(_) => Err(DistributedError::Io(std::io::Error::new(
147 std::io::ErrorKind::Unsupported,
148 "Unix sockets are not supported on this platform",
149 ))),
150 }
151 }
152
153 pub async fn shutdown(&self) -> Result<(), DistributedError> {
155 self.send_command(RemoteCommand::Shutdown).await?;
156 Ok(())
157 }
158
159 pub async fn which_children(&self) -> Result<Vec<ChildInfo>, DistributedError> {
161 match self.send_command(RemoteCommand::WhichChildren).await? {
162 RemoteResponse::Children(children) => Ok(children),
163 RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
164 _ => Err(DistributedError::UnexpectedResponse),
165 }
166 }
167
168 pub async fn terminate_child(&self, id: &str) -> Result<(), DistributedError> {
170 match self
171 .send_command(RemoteCommand::TerminateChild { id: id.to_owned() })
172 .await?
173 {
174 RemoteResponse::Ok => Ok(()),
175 RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
176 _ => Err(DistributedError::UnexpectedResponse),
177 }
178 }
179
180 pub async fn status(&self) -> Result<SupervisorStatus, DistributedError> {
182 match self.send_command(RemoteCommand::Status).await? {
183 RemoteResponse::Status(status) => Ok(status),
184 RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
185 _ => Err(DistributedError::UnexpectedResponse),
186 }
187 }
188}
189
190pub struct SupervisorServer<W: Worker> {
192 handle: Arc<SupervisorHandle<W>>,
193}
194
195impl<W: Worker> SupervisorServer<W> {
196 pub fn new(handle: SupervisorHandle<W>) -> Self {
198 Self {
199 handle: Arc::new(handle),
200 }
201 }
202
203 #[cfg(unix)]
205 pub async fn listen_unix(
206 self,
207 path: impl AsRef<std::path::Path>,
208 ) -> Result<(), DistributedError> {
209 let path = path.as_ref();
210 let _ = std::fs::remove_file(path); let listener = UnixListener::bind(path)?;
213 slog::info!(slog_scope::logger(), "server listening on unix socket";
214 "path" => %path.display()
215 );
216
217 loop {
218 let (mut stream, _) = listener.accept().await?;
219 let handle = Arc::clone(&self.handle);
220
221 tokio::spawn(async move {
222 if let Err(e) = Self::handle_connection(&mut stream, handle).await {
223 slog::error!(slog_scope::logger(), "connection error";
224 "error" => %e
225 );
226 }
227 });
228 }
229 }
230
231 pub async fn listen_tcp(self, addr: impl AsRef<str>) -> Result<(), DistributedError> {
233 let listener = TcpListener::bind(addr.as_ref()).await?;
234 slog::info!(slog_scope::logger(), "server listening on tcp";
235 "address" => addr.as_ref()
236 );
237
238 loop {
239 let (mut stream, peer) = listener.accept().await?;
240 slog::debug!(slog_scope::logger(), "new connection";
241 "peer" => ?peer
242 );
243 let handle = Arc::clone(&self.handle);
244
245 tokio::spawn(async move {
246 if let Err(e) = Self::handle_connection(&mut stream, handle).await {
247 slog::error!(slog_scope::logger(), "Connection error";
248 "error" => %e
249 );
250 }
251 });
252 }
253 }
254
255 async fn handle_connection<S>(
256 stream: &mut S,
257 handle: Arc<SupervisorHandle<W>>,
258 ) -> Result<(), DistributedError>
259 where
260 S: AsyncReadExt + AsyncWriteExt + Unpin,
261 {
262 let command: RemoteCommand = receive_message(stream).await?;
263 let response = Self::process_command(command, &handle).await;
264 send_message(stream, &response).await?;
265 Ok(())
266 }
267
268 async fn process_command(
269 command: RemoteCommand,
270 handle: &SupervisorHandle<W>,
271 ) -> RemoteResponse {
272 match command {
273 RemoteCommand::Shutdown => match handle.shutdown().await {
274 Ok(()) => RemoteResponse::Ok,
275 Err(e) => RemoteResponse::Error(e.to_string()),
276 },
277 RemoteCommand::WhichChildren => match handle.which_children().await {
278 Ok(children) => {
279 let children: Vec<ChildInfo> = children.into_iter().map(Into::into).collect();
280 RemoteResponse::Children(children)
281 }
282 Err(e) => RemoteResponse::Error(e.to_string()),
283 },
284 RemoteCommand::TerminateChild { id } => match handle.terminate_child(&id).await {
285 Ok(()) => RemoteResponse::Ok,
286 Err(e) => RemoteResponse::Error(e.to_string()),
287 },
288 RemoteCommand::Status => {
289 let restart_strategy = handle
290 .restart_strategy()
291 .await
292 .map(|s| format!("{:?}", s))
293 .unwrap_or_else(|_| "Unknown".to_owned());
294 let uptime_secs = handle.uptime().await.unwrap_or(0);
295
296 RemoteResponse::Status(SupervisorStatus {
297 name: handle.name().to_owned(),
298 children_count: handle.which_children().await.map(|c| c.len()).unwrap_or(0),
299 restart_strategy,
300 uptime_secs,
301 })
302 }
303 }
304 }
305}
306
307async fn send_message<S, T>(stream: &mut S, msg: &T) -> Result<(), DistributedError>
309where
310 S: AsyncWriteExt + Unpin,
311 T: Serialize,
312 for<'a> T: RkyvSerialize<
313 rkyv::api::high::HighSerializer<
314 rkyv::util::AlignedVec,
315 rkyv::ser::allocator::ArenaHandle<'a>,
316 rkyv::rancor::Error,
317 >,
318 >,
319{
320 let encoded = rkyv::to_bytes::<rkyv::rancor::Error>(msg)?;
321 let len = encoded.len() as u32;
322
323 stream.write_all(&len.to_be_bytes()).await?;
324 stream.write_all(&encoded).await?;
325 stream.flush().await?;
326
327 Ok(())
328}
329
330async fn receive_message<S, T>(stream: &mut S) -> Result<T, DistributedError>
332where
333 S: AsyncReadExt + Unpin,
334 T: Archive,
335 for<'a> T::Archived: RkyvDeserialize<T, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>,
336{
337 let mut len_bytes = [0u8; 4];
338 stream.read_exact(&mut len_bytes).await?;
339 let len = u32::from_be_bytes(len_bytes) as usize;
340
341 if len > 10_000_000 {
342 return Err(DistributedError::MessageTooLarge(len));
343 }
344
345 let mut buffer = vec![0u8; len];
346 stream.read_exact(&mut buffer).await?;
347
348 let decoded: T = unsafe { rkyv::from_bytes_unchecked::<T, rkyv::rancor::Error>(&buffer)? };
350 Ok(decoded)
351}
352
353#[derive(Debug)]
355pub enum DistributedError {
356 Io(std::io::Error),
358 Encode(rkyv::rancor::Error),
360 Decode(rkyv::rancor::Error),
362 RemoteError(String),
364 UnexpectedResponse,
366 MessageTooLarge(usize),
368}
369
370impl fmt::Display for DistributedError {
371 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
372 match self {
373 DistributedError::Io(e) => write!(f, "IO error: {}", e),
374 DistributedError::Encode(e) => write!(f, "Encode error: {}", e),
375 DistributedError::Decode(e) => write!(f, "Decode error: {}", e),
376 DistributedError::RemoteError(e) => write!(f, "Remote error: {}", e),
377 DistributedError::UnexpectedResponse => write!(f, "Unexpected response from remote"),
378 DistributedError::MessageTooLarge(size) => {
379 write!(f, "Message too large: {} bytes", size)
380 }
381 }
382 }
383}
384
385impl std::error::Error for DistributedError {}
386
387impl From<std::io::Error> for DistributedError {
388 fn from(e: std::io::Error) -> Self {
389 DistributedError::Io(e)
390 }
391}
392
393impl From<rkyv::rancor::Error> for DistributedError {
394 fn from(e: rkyv::rancor::Error) -> Self {
395 DistributedError::Encode(e)
396 }
397}