Skip to main content

rivet_core/
application.rs

1use std::any::Any;
2use std::collections::BTreeMap;
3use std::net::SocketAddr;
4use std::path::PathBuf;
5use std::sync::{Arc, OnceLock};
6
7use rivet_foundation::{ConfigValue, ContainerResolveExt, FromConfigValue};
8use rivet_http::{Method, Request, Response};
9use rivet_logger::handlers::build_handler_from_config;
10use rivet_logger::{
11    init_default_tracing, set_channel_handler_build_config, set_channel_handlers, set_handler, Log,
12};
13use rivet_routing::Match;
14use tracing::{debug, warn};
15
16use crate::dispatch::Dispatcher;
17use crate::error::DispatchError;
18use crate::middleware::Middleware;
19use crate::state::Runtime;
20
21/// Built rivet runtime that implements the dispatch contract.
22pub struct Application {
23    state: Runtime,
24}
25
26static CURRENT_APP: OnceLock<Arc<dyn Dispatcher>> = OnceLock::new();
27
28impl Application {
29    pub(crate) fn from_state(state: Runtime) -> Self {
30        Self { state }
31    }
32
33    /// Number of routes registered in the runtime.
34    pub fn route_count(&self) -> usize {
35        self.state.route_count()
36    }
37
38    /// Number of middleware instances in the runtime chain.
39    pub fn middleware_count(&self) -> usize {
40        self.state.middleware_count()
41    }
42
43    /// Resolve a typed service instance from the application container.
44    pub fn make<T: Any + Send + Sync + 'static>(
45        &self,
46    ) -> Result<Arc<T>, crate::error::RivetError> {
47        self.state.container().resolve::<T>().map_err(Into::into)
48    }
49
50    /// Read a configuration value by dot-notation key.
51    pub fn config(&self, key: &str) -> Option<ConfigValue> {
52        self.state.config().get(key)
53    }
54
55    /// Read and convert a configuration value by dot-notation key.
56    pub fn config_typed<T: FromConfigValue>(&self, key: &str) -> Option<T> {
57        self.config(key).and_then(T::from_config_value)
58    }
59
60    /// Initialize rivet logging runtime and configured handlers.
61    pub fn init_logging(&self) {
62        init_default_tracing();
63        self.init_log_handler();
64    }
65
66    /// Log that the application is starting to serve requests.
67    pub fn log_serving(&self, addr: SocketAddr) {
68        self.init_logging();
69        Log::channel("stdout").info(format!("Listening on http://{addr}"));
70    }
71
72    fn init_log_handler(&self) {
73        let Some(log_config) = self.config_typed::<ConfigValue>("log") else {
74            return;
75        };
76
77        let base_path = self
78            .config_typed::<String>("app.base_path")
79            .map(PathBuf::from)
80            .unwrap_or_else(|| PathBuf::from("."));
81
82        match build_handler_from_config(&log_config, &base_path) {
83            Ok(handler) => set_handler(handler),
84            Err(err) => warn!(
85                target: "rivet",
86                error = %err,
87                "failed to initialize configured log handler"
88            ),
89        }
90
91        set_channel_handlers(BTreeMap::new());
92        set_channel_handler_build_config(log_config, base_path);
93    }
94}
95
96pub fn set_current(app: Arc<dyn Dispatcher>) -> Result<(), crate::error::RivetError> {
97    CURRENT_APP.set(app).map_err(|_| {
98        crate::error::RivetError::Build("application global is already initialized".to_string())
99    })
100}
101
102pub fn app() -> Arc<dyn Dispatcher> {
103    CURRENT_APP
104        .get()
105        .cloned()
106        .expect("application global not initialized")
107}
108
109pub fn try_app() -> Option<Arc<dyn Dispatcher>> {
110    CURRENT_APP.get().cloned()
111}
112
113pub fn config<T: FromConfigValue>(key: &str) -> Option<T> {
114    app().config_value(key).and_then(T::from_config_value)
115}
116
117impl Dispatcher for Application {
118    fn dispatch(&self, req: Request) -> Response {
119        let _ = self.state.container();
120        let _ = self.state.config();
121        let method = req.method.clone();
122        let route = req.path.clone();
123
124        debug!(
125            method = ?req.method,
126            path = %req.path,
127            middleware_count = self.state.middleware_count(),
128            "dispatch started"
129        );
130
131        let terminal = self.route_terminal(&req);
132
133        let response = Self::run_middleware_chain(self.state.middleware(), req, terminal.as_ref());
134        Log::info(format!(
135            "{} {} {}",
136            method_label(&method),
137            route,
138            response.status
139        ));
140        debug!(status = response.status, "dispatch completed");
141        response
142    }
143
144    fn config_value(&self, key: &str) -> Option<ConfigValue> {
145        self.config(key)
146    }
147
148    fn resolve_any(
149        &self,
150        type_name: &'static str,
151    ) -> Result<Arc<dyn Any + Send + Sync>, crate::error::RivetError> {
152        self.state
153            .container()
154            .resolve_any(type_name)
155            .map_err(Into::into)
156    }
157}
158
159fn method_label(method: &Method) -> &'static str {
160    match method {
161        Method::Get => "GET",
162        Method::Post => "POST",
163        Method::Put => "PUT",
164        Method::Patch => "PATCH",
165        Method::Delete => "DELETE",
166        Method::Head => "HEAD",
167        Method::Options => "OPTIONS",
168    }
169}
170
171impl Application {
172    fn route_terminal(
173        &self,
174        req: &Request,
175    ) -> Box<dyn Fn(Request) -> Response + Send + Sync + 'static> {
176        match self.state.routes().match_request(&req.method, &req.path) {
177            Match::Matched {
178                route,
179                head_fallback,
180            } => {
181                debug!(
182                    route_method = ?route.method,
183                    route_path = %route.path,
184                    head_fallback,
185                    "route matched"
186                );
187
188                let route = route.clone();
189                Box::new(move |req: Request| {
190                    let mut response = route.invoke(req);
191                    if head_fallback {
192                        response.body.clear();
193                    }
194                    response
195                })
196            }
197            Match::MethodNotAllowed { allow } => {
198                warn!(allow = ?allow, "route method not allowed");
199                let err = DispatchError::MethodNotAllowed { allow };
200                Box::new(move |_req: Request| {
201                    warn!(error = %err, "dispatch terminal returned error");
202                    err.clone().into_response()
203                })
204            }
205            Match::NotFound => {
206                let err = DispatchError::NotFound;
207                debug!("route not found");
208                Box::new(move |_req: Request| {
209                    warn!(error = %err, "dispatch terminal returned error");
210                    err.clone().into_response()
211                })
212            }
213        }
214    }
215
216    fn run_middleware_chain(
217        middlewares: &[Arc<dyn Middleware>],
218        req: Request,
219        terminal: &(dyn Fn(Request) -> Response + Send + Sync),
220    ) -> Response {
221        debug!(
222            middleware_count = middlewares.len(),
223            "executing middleware chain"
224        );
225        Self::run_middleware_at(0, middlewares, req, terminal)
226    }
227
228    fn run_middleware_at(
229        index: usize,
230        middlewares: &[Arc<dyn Middleware>],
231        req: Request,
232        terminal: &(dyn Fn(Request) -> Response + Send + Sync),
233    ) -> Response {
234        if index >= middlewares.len() {
235            debug!("reached terminal middleware boundary");
236            return terminal(req);
237        }
238
239        let middleware = Arc::clone(&middlewares[index]);
240        debug!(middleware_index = index, "running middleware");
241        let next =
242            |next_req: Request| Self::run_middleware_at(index + 1, middlewares, next_req, terminal);
243
244        middleware.handle(req, &next)
245    }
246}