hydra/
application.rs

1use std::future::Future;
2
3use serde::Deserialize;
4use serde::Serialize;
5
6use tokio::runtime::Builder;
7use tokio::runtime::Runtime;
8use tokio::sync::oneshot;
9
10use crate::ApplicationConfig;
11use crate::ExitReason;
12use crate::Message;
13use crate::Pid;
14use crate::Process;
15use crate::ProcessFlags;
16use crate::SystemMessage;
17
18#[cfg(feature = "console")]
19use crate::ConsoleServer;
20
21/// Messages used internally by [Application].
22#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
23enum ApplicationMessage {
24    ShutdownTimeout,
25}
26
27/// Main application logic and entry point for a hydra program.
28///
29/// [Application] provides graceful shutdown by allowing you to link a process inside the call to `start`.
30/// The `run` call will only return once that process has terminated. It's recommended to link a supervisor.
31pub trait Application: Sized + Send + 'static {
32    /// Override to change the application configuration defaults.
33    fn config() -> ApplicationConfig {
34        ApplicationConfig::default()
35    }
36
37    /// Called when an application is starting. You should link a process here and return it's [Pid].
38    ///
39    /// The [Application] will wait for that process to exit before returning from `run`.
40    fn start(&self) -> impl Future<Output = Result<Pid, ExitReason>> + Send;
41
42    /// Runs the [Application] to completion.
43    ///
44    /// This method will return when the linked process created in `start` has exited.
45    fn run(self) {
46        use ApplicationMessage::*;
47
48        let config = Self::config();
49
50        #[cfg(feature = "tracing")]
51        if config.tracing_subscribe {
52            use std::sync::Once;
53
54            static TRACING_SUBSCRIBE_ONCE: Once = Once::new();
55
56            TRACING_SUBSCRIBE_ONCE.call_once(|| {
57                tracing_subscriber::fmt::init();
58            });
59        }
60
61        #[allow(unused_mut)]
62        let mut prev_hook: Option<_> = None;
63
64        #[cfg(feature = "tracing")]
65        if config.tracing_panics {
66            prev_hook = Some(std::panic::take_hook());
67
68            std::panic::set_hook(Box::new(panic_hook));
69        }
70
71        let rt = Runtime::new().unwrap();
72
73        rt.block_on(async move {
74            let (tx, rx) = oneshot::channel();
75
76            Process::spawn(async move {
77                Process::set_flags(ProcessFlags::TRAP_EXIT);
78
79                #[cfg(feature="console")]
80                let mut cpid = ConsoleServer::new()
81                                        .start_link()
82                                        .await
83                                        .expect("Failed to start console server!");
84
85                match self.start().await {
86                    Ok(pid) => {
87                        #[cfg(feature = "tracing")]
88                        tracing::info!(supervisor = ?pid, "Application supervisor has started");
89
90                        let spid = if config.graceful_shutdown {
91                            Some(Process::spawn_link(signal_handler()))
92                        } else {
93                            None
94                        };
95
96                        loop {
97                            let message = Process::receive::<ApplicationMessage>().await;
98
99                            match message {
100                                Message::User(ShutdownTimeout) => {
101                                    #[cfg(feature = "tracing")]
102                                    tracing::error!(timeout = ?config.graceful_shutdown_timeout, "Application failed to shutdown gracefully");
103
104                                    Process::exit(pid, ExitReason::Kill);
105                                }
106                                Message::System(SystemMessage::Exit(epid, ereason)) => {
107                                    if epid == pid {
108                                        if ereason.is_custom() && ereason != "shutdown" {
109                                            #[cfg(feature = "tracing")]
110                                            tracing::error!(reason = ?ereason, supervisor = ?pid, "Application supervisor has terminated");
111                                        } else {
112                                            #[cfg(feature = "tracing")]
113                                            tracing::info!(reason = ?ereason, supervisor = ?pid, "Application supervisor has exited");
114                                        }
115                                        break;
116                                    } else if spid.is_some_and(|spid| spid == epid) {
117                                        #[cfg(feature = "tracing")]
118                                        tracing::info!(reason = ?ereason, supervisor = ?pid, timeout = ?config.graceful_shutdown_timeout, "Application starting graceful shutdown");
119
120                                        Process::exit(pid, ExitReason::from("shutdown"));
121                                        Process::send_after(Process::current(), ShutdownTimeout, config.graceful_shutdown_timeout);
122                                    }
123
124                                    #[cfg(feature = "console")]
125                                    if cpid == epid && ereason != "shutdown" {
126                                        cpid = ConsoleServer::new()
127                                                        .start_link()
128                                                        .await
129                                                        .expect("Failed to restart console server!");
130                                    }
131                                }
132                                _ => continue,
133                            }
134                        }
135                    }
136                    Err(reason) => {
137                        #[cfg(feature = "tracing")]
138                        tracing::error!(reason = ?reason, "Application supervisor failed to start");
139
140                        #[cfg(not(feature = "tracing"))]
141                        let _ = reason;
142                    }
143                }
144
145                tx.send(()).unwrap();
146            });
147
148            let _ = rx.await;
149        });
150
151        if let Some(prev_hook) = prev_hook {
152            std::panic::set_hook(prev_hook);
153        }
154    }
155
156    /// Runs the [Application] to completion for tests.
157    ///
158    /// This method will panic if the process doesn't cleanly exit with `normal` or `shutdown` reasons.
159    fn test(self) {
160        let rt = Builder::new_current_thread().enable_all().build().unwrap();
161
162        rt.block_on(async move {
163            let (tx, rx) = oneshot::channel();
164
165            Process::spawn(async move {
166                Process::set_flags(ProcessFlags::TRAP_EXIT);
167
168                match self.start().await {
169                    Ok(pid) => loop {
170                        let message = Process::receive::<()>().await;
171
172                        match message {
173                            Message::System(SystemMessage::Exit(epid, ereason)) => {
174                                if epid == pid {
175                                    tx.send(Some(ereason)).unwrap();
176                                    break;
177                                }
178                            }
179                            _ => continue,
180                        }
181                    },
182                    Err(reason) => {
183                        tx.send(Some(reason)).unwrap();
184                    }
185                }
186            });
187
188            if let Ok(Some(reason)) = rx.await
189                && reason.is_custom()
190                && reason != "shutdown"
191            {
192                panic!("Exited: {:?}", reason);
193            }
194        });
195    }
196}
197
198/// Handles SIGTERM and ctrl+c signals on unix-like platforms.
199#[cfg(unix)]
200async fn signal_handler() {
201    use tokio::signal::unix;
202
203    let mut sigterm =
204        unix::signal(unix::SignalKind::terminate()).expect("Failed to register SIGTERM handler");
205
206    tokio::select! {
207        _ = sigterm.recv() => {
208            Process::exit(Process::current(), ExitReason::from("sigterm"));
209        }
210        _ = tokio::signal::ctrl_c() => {
211            Process::exit(Process::current(), ExitReason::from("ctrl_c"));
212        }
213    }
214}
215
/// Waits for ctrl+c on platforms without unix signals, then terminates the current process
/// with the `"ctrl_c"` exit reason.
#[cfg(not(unix))]
async fn signal_handler() {
    // The listener's result is deliberately discarded: success or failure, the handler
    // proceeds to exit the current process.
    let _ = tokio::signal::ctrl_c().await;

    let this = Process::current();
    Process::exit(this, ExitReason::from("ctrl_c"));
}
223
/// Panic hook that reports panics as an ERROR-level tracing event with target `hydra`.
///
/// The event carries three fields:
/// - `payload`: the panic message, when it is a `&str` or a `String`;
/// - `location`: the panic's source location, when available;
/// - `backtrace`: the captured backtrace, or a hint on enabling one when capture is disabled.
#[cfg(feature = "tracing")]
fn panic_hook(panic_info: &std::panic::PanicHookInfo) {
    use std::backtrace::Backtrace;
    use std::backtrace::BacktraceStatus;

    use tracing::*;

    let raw = panic_info.payload();

    // Literal panic messages arrive as `&str`, formatted ones as `String`; anything else
    // is reported as an absent payload.
    let payload: Option<&str> = raw
        .downcast_ref::<&str>()
        .copied()
        .or_else(|| raw.downcast_ref::<String>().map(String::as_str));

    let location = panic_info.location().map(ToString::to_string);

    let captured = Backtrace::capture();
    let backtrace = match captured.status() {
        BacktraceStatus::Disabled => String::from(
            "run with RUST_BACKTRACE=1 environment variable to display a backtrace",
        ),
        _ => captured.to_string(),
    };

    event!(
        target: "hydra",
        Level::ERROR,
        payload = payload,
        location = location,
        backtrace = ?backtrace,
        "A process has panicked",
    );
}