Skip to main content

rivet_core/
application_builder.rs

1use std::path::PathBuf;
2use std::sync::Arc;
3use std::{
4    net::{SocketAddr, ToSocketAddrs},
5    str::FromStr,
6};
7
8use rivet_foundation::{
9    Config, ConfigRepository, ConfigValue, ContainerRegistrationExt, InMemoryContainer,
10    SharedConfigRepository,
11};
12use rivet_http::Method;
13use rivet_logger::LogService;
14use rivet_routing::Registry;
15
16use crate::builder::{Builder, ProviderFactory};
17use crate::error::RivetError;
18use crate::middleware::Middleware;
19use crate::module::RivetModule;
20use crate::Application;
21
22type ApiRoutes = dyn Fn(&mut Registry) -> Result<(), RivetError> + Send + Sync + 'static;
23
24#[derive(Default)]
25pub struct RoutingConfig {
26    route_files: Vec<(String, Box<ApiRoutes>)>,
27    health_path: Option<String>,
28}
29
30impl RoutingConfig {
31    pub fn web<F>(&mut self, routes: F) -> &mut Self
32    where
33        F: Fn(&mut Registry) -> Result<(), RivetError> + Send + Sync + 'static,
34    {
35        self.file("web", routes)
36    }
37
38    pub fn api<F>(&mut self, routes: F) -> &mut Self
39    where
40        F: Fn(&mut Registry) -> Result<(), RivetError> + Send + Sync + 'static,
41    {
42        self.file("api", routes)
43    }
44
45    pub fn file<F>(&mut self, name: impl Into<String>, routes: F) -> &mut Self
46    where
47        F: Fn(&mut Registry) -> Result<(), RivetError> + Send + Sync + 'static,
48    {
49        self.route_files.push((name.into(), Box::new(routes)));
50        self
51    }
52
53    pub fn health(&mut self, path: impl Into<String>) -> &mut Self {
54        self.health_path = Some(path.into());
55        self
56    }
57}
58
59#[derive(Default)]
60pub struct MiddlewareConfig {
61    entries: Vec<Arc<dyn Middleware>>,
62}
63
64impl MiddlewareConfig {
65    pub fn append<M>(&mut self, middleware: M) -> &mut Self
66    where
67        M: Middleware + 'static,
68    {
69        self.entries.push(Arc::new(middleware));
70        self
71    }
72
73    pub fn append_arc(&mut self, middleware: Arc<dyn Middleware>) -> &mut Self {
74        self.entries.push(middleware);
75        self
76    }
77}
78
/// Exception-reporting settings gathered through `ApplicationBuilder::with_exceptions`.
///
/// During bootstrap, the number of registered report labels is published to the
/// configuration as `app.exceptions.count`.
#[derive(Debug, Default, Clone)]
pub struct ExceptionConfig {
    /// Report labels in registration order.
    reports: Vec<String>,
}
83
84impl ExceptionConfig {
85    pub fn report(&mut self, label: impl Into<String>) -> &mut Self {
86        self.reports.push(label.into());
87        self
88    }
89}
90
/// How often a scheduled command is registered to run.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Frequency {
    /// Produced by `ScheduleCommandBuilder::hourly`.
    Hourly,
}
95
96#[derive(Debug, Clone, PartialEq, Eq)]
97pub struct ScheduledCommand {
98    command: String,
99    frequency: Frequency,
100}
101
102#[derive(Default, Clone)]
103pub struct ScheduleConfig {
104    commands: Vec<ScheduledCommand>,
105}
106
107pub struct ScheduleCommandBuilder<'a> {
108    schedule: &'a mut ScheduleConfig,
109    command: String,
110}
111
112impl<'a> ScheduleCommandBuilder<'a> {
113    pub fn hourly(self) {
114        self.schedule.commands.push(ScheduledCommand {
115            command: self.command,
116            frequency: Frequency::Hourly,
117        });
118    }
119}
120
121impl ScheduleConfig {
122    pub fn command(&mut self, command: impl Into<String>) -> ScheduleCommandBuilder<'_> {
123        ScheduleCommandBuilder {
124            schedule: self,
125            command: command.into(),
126        }
127    }
128}
129
130pub struct ApplicationBuilder {
131    base_path: PathBuf,
132    config: SharedConfigRepository,
133    modules: Vec<Box<dyn RivetModule>>,
134    provider_factories: Vec<ProviderFactory>,
135    routing: RoutingConfig,
136    middleware: MiddlewareConfig,
137    exceptions: ExceptionConfig,
138    schedule: ScheduleConfig,
139}
140
141impl ApplicationBuilder {
142    pub fn configure(base_path: impl Into<PathBuf>, config: ConfigRepository) -> Self {
143        Self {
144            base_path: base_path.into(),
145            config: SharedConfigRepository::new(config),
146            modules: Vec::new(),
147            provider_factories: Vec::new(),
148            routing: RoutingConfig::default(),
149            middleware: MiddlewareConfig::default(),
150            exceptions: ExceptionConfig::default(),
151            schedule: ScheduleConfig::default(),
152        }
153    }
154
155    pub fn with_module(mut self, module: Box<dyn RivetModule>) -> Self {
156        self.modules.push(module);
157        self
158    }
159
160    pub fn with_provider(mut self, provider: ProviderFactory) -> Self {
161        self.provider_factories.push(provider);
162        self
163    }
164
165    pub fn with_providers(mut self, providers: &[ProviderFactory]) -> Self {
166        self.provider_factories.extend(providers.iter().copied());
167        self
168    }
169
170    pub fn with_routing<F>(mut self, configure: F) -> Self
171    where
172        F: FnOnce(&mut RoutingConfig),
173    {
174        configure(&mut self.routing);
175        self
176    }
177
178    pub fn with_middleware<F>(mut self, configure: F) -> Self
179    where
180        F: FnOnce(&mut MiddlewareConfig),
181    {
182        configure(&mut self.middleware);
183        self
184    }
185
186    pub fn with_exceptions<F>(mut self, configure: F) -> Self
187    where
188        F: FnOnce(&mut ExceptionConfig),
189    {
190        configure(&mut self.exceptions);
191        self
192    }
193
194    pub fn with_schedule<F>(mut self, configure: F) -> Self
195    where
196        F: FnOnce(&mut ScheduleConfig),
197    {
198        configure(&mut self.schedule);
199        self
200    }
201
202    pub fn build(self) -> Result<Application, RivetError> {
203        let container = Arc::new(InMemoryContainer::new());
204        container
205            .singleton(self.config.clone())
206            .map_err(|err| RivetError::Build(err.to_string()))?;
207        container
208            .singleton(LogService)
209            .map_err(|err| RivetError::Build(err.to_string()))?;
210        container
211            .bind_names("log", core::any::type_name::<LogService>())
212            .map_err(|err| RivetError::Build(err.to_string()))?;
213
214        let mut builder = Builder::new(
215            Arc::clone(&container) as Arc<dyn rivet_foundation::Container>,
216            Box::new(self.config),
217        );
218
219        for module in self.modules {
220            builder = builder.with_module(module);
221        }
222        builder = builder.with_providers(&self.provider_factories);
223
224        builder = builder.with_module(Box::new(BootstrapModule {
225            base_path: self.base_path,
226            routing: self.routing,
227            middleware: self.middleware.entries,
228            exceptions: self.exceptions,
229            schedule: self.schedule,
230        }));
231
232        builder.build()
233    }
234}
235
236struct BootstrapModule {
237    base_path: PathBuf,
238    routing: RoutingConfig,
239    middleware: Vec<Arc<dyn Middleware>>,
240    exceptions: ExceptionConfig,
241    schedule: ScheduleConfig,
242}
243
244impl RivetModule for BootstrapModule {
245    fn name(&self) -> &'static str {
246        "application-bootstrap"
247    }
248
249    fn configure(&self, config: &mut dyn Config) -> Result<(), RivetError> {
250        Self::ensure_server_addr(config)?;
251
252        config
253            .set(
254                "app.base_path",
255                ConfigValue::String(self.base_path.to_string_lossy().to_string()),
256            )
257            .map_err(|err| RivetError::Build(err.to_string()))?;
258        config
259            .set(
260                "app.exceptions.count",
261                ConfigValue::Integer(self.exceptions.reports.len() as i64),
262            )
263            .map_err(|err| RivetError::Build(err.to_string()))?;
264        config
265            .set(
266                "app.schedule.count",
267                ConfigValue::Integer(self.schedule.commands.len() as i64),
268            )
269            .map_err(|err| RivetError::Build(err.to_string()))?;
270
271        Ok(())
272    }
273
274    fn routes(&self, routes: &mut Registry) -> Result<(), RivetError> {
275        if let Some(path) = &self.routing.health_path {
276            routes
277                .add_route(Method::Get, path.clone())
278                .map_err(|err| RivetError::Routing(err.to_string()))?;
279        }
280
281        for (name, route_file) in &self.routing.route_files {
282            route_file(routes).map_err(|err| {
283                RivetError::Routing(format!("route file '{name}' failed: {err}"))
284            })?;
285        }
286
287        Ok(())
288    }
289
290    fn middleware(&self) -> Vec<Arc<dyn Middleware>> {
291        self.middleware.clone()
292    }
293}
294
295impl BootstrapModule {
296    fn ensure_server_addr(config: &mut dyn Config) -> Result<(), RivetError> {
297        if let Some(addr) = Self::optional_string(config, "app.server.addr")? {
298            SocketAddr::from_str(&addr).map_err(|err| {
299                RivetError::Build(format!("invalid 'app.server.addr': {err}"))
300            })?;
301            Self::set_server_url_if_missing(config, &addr)?;
302            return Ok(());
303        }
304
305        if let Some(server_url) = Self::optional_string(config, "app.server.url")? {
306            let authority = Self::url_authority(&server_url);
307            let addr = Self::resolve_addr(authority).ok_or_else(|| {
308                RivetError::Build(format!("invalid 'app.server.url' value '{server_url}'"))
309            })?;
310            config
311                .set("app.server.addr", ConfigValue::String(addr.to_string()))
312                .map_err(|err| RivetError::Build(err.to_string()))?;
313            return Ok(());
314        }
315
316        let app_url = match config.get("app.url") {
317            Some(ConfigValue::String(value)) => value,
318            Some(other) => {
319                return Err(RivetError::Build(format!(
320                    "expected 'app.url' string config, got {other:?}"
321                )))
322            }
323            None => {
324                return Err(RivetError::Build(
325                    "missing 'app.url' config".to_string(),
326                ))
327            }
328        };
329
330        if let Some(addr) = Self::resolve_addr(&app_url) {
331            config
332                .set("app.server.addr", ConfigValue::String(addr.to_string()))
333                .map_err(|err| RivetError::Build(err.to_string()))?;
334            Self::set_server_url_if_missing(config, &addr.to_string())?;
335            return Ok(());
336        }
337
338        let authority = Self::url_authority(&app_url);
339        if let Some(addr) = Self::resolve_addr(authority) {
340            config
341                .set("app.server.addr", ConfigValue::String(addr.to_string()))
342                .map_err(|err| RivetError::Build(err.to_string()))?;
343            Self::set_server_url_if_missing(config, &addr.to_string())?;
344            return Ok(());
345        }
346
347        let port = match config.get("app.port") {
348            Some(ConfigValue::Integer(value)) => u16::try_from(value).map_err(|_| {
349                RivetError::Build(format!(
350                    "expected 'app.port' integer in range 0..=65535, got {value}"
351                ))
352            })?,
353            Some(ConfigValue::String(value)) => value.parse::<u16>().map_err(|err| {
354                RivetError::Build(format!("invalid 'app.port' string value '{value}': {err}"))
355            })?,
356            Some(other) => {
357                return Err(RivetError::Build(format!(
358                    "expected 'app.port' integer/string config, got {other:?}"
359                )))
360            }
361            None => 7057,
362        };
363
364        let host = Self::strip_port(authority);
365        let candidate = format!("{host}:{port}");
366        let addr = Self::resolve_addr(&candidate).ok_or_else(|| {
367            RivetError::Build(format!(
368                "invalid 'app.url'/'app.port' combination: '{candidate}'"
369            ))
370        })?;
371
372        config
373            .set("app.server.addr", ConfigValue::String(addr.to_string()))
374            .map_err(|err| RivetError::Build(err.to_string()))?;
375        Self::set_server_url_if_missing(config, &addr.to_string())
376    }
377
378    fn resolve_addr(value: &str) -> Option<SocketAddr> {
379        SocketAddr::from_str(value)
380            .ok()
381            .or_else(|| value.to_socket_addrs().ok()?.next())
382    }
383
384    fn url_authority(url: &str) -> &str {
385        let without_scheme = url.split_once("://").map_or(url, |(_, rest)| rest);
386        without_scheme
387            .split_once(['/', '?', '#'])
388            .map_or(without_scheme, |(authority, _)| authority)
389    }
390
391    fn strip_port(authority: &str) -> &str {
392        if authority.starts_with('[') {
393            if let Some(end) = authority.find(']') {
394                return &authority[..=end];
395            }
396            return authority;
397        }
398
399        if let Some((host, port)) = authority.rsplit_once(':') {
400            if port.parse::<u16>().is_ok() {
401                return host;
402            }
403        }
404
405        authority
406    }
407
408    fn optional_string(config: &dyn Config, key: &str) -> Result<Option<String>, RivetError> {
409        match config.get(key) {
410            Some(ConfigValue::String(value)) => Ok(Some(value)),
411            Some(other) => Err(RivetError::Build(format!(
412                "expected '{key}' string config, got {other:?}"
413            ))),
414            None => Ok(None),
415        }
416    }
417
418    fn set_server_url_if_missing(
419        config: &mut dyn Config,
420        addr: &str,
421    ) -> Result<(), RivetError> {
422        match config.get("app.server.url") {
423            Some(ConfigValue::String(_)) => Ok(()),
424            Some(other) => Err(RivetError::Build(format!(
425                "expected 'app.server.url' string config, got {other:?}"
426            ))),
427            None => config
428                .set(
429                    "app.server.url",
430                    ConfigValue::String(format!("http://{addr}")),
431                )
432                .map_err(|err| RivetError::Build(err.to_string())),
433        }
434    }
435}