1use std::path::PathBuf;
2use std::sync::Arc;
3use std::{
4 net::{SocketAddr, ToSocketAddrs},
5 str::FromStr,
6};
7
8use rivet_foundation::{
9 Config, ConfigRepository, ConfigValue, ContainerRegistrationExt, InMemoryContainer,
10 SharedConfigRepository,
11};
12use rivet_http::Method;
13use rivet_logger::LogService;
14use rivet_routing::Registry;
15
16use crate::builder::{Builder, ProviderFactory};
17use crate::error::RivetError;
18use crate::middleware::Middleware;
19use crate::module::RivetModule;
20use crate::Application;
21
22type ApiRoutes = dyn Fn(&mut Registry) -> Result<(), RivetError> + Send + Sync + 'static;
23
24#[derive(Default)]
25pub struct RoutingConfig {
26 route_files: Vec<(String, Box<ApiRoutes>)>,
27 health_path: Option<String>,
28}
29
30impl RoutingConfig {
31 pub fn web<F>(&mut self, routes: F) -> &mut Self
32 where
33 F: Fn(&mut Registry) -> Result<(), RivetError> + Send + Sync + 'static,
34 {
35 self.file("web", routes)
36 }
37
38 pub fn api<F>(&mut self, routes: F) -> &mut Self
39 where
40 F: Fn(&mut Registry) -> Result<(), RivetError> + Send + Sync + 'static,
41 {
42 self.file("api", routes)
43 }
44
45 pub fn file<F>(&mut self, name: impl Into<String>, routes: F) -> &mut Self
46 where
47 F: Fn(&mut Registry) -> Result<(), RivetError> + Send + Sync + 'static,
48 {
49 self.route_files.push((name.into(), Box::new(routes)));
50 self
51 }
52
53 pub fn health(&mut self, path: impl Into<String>) -> &mut Self {
54 self.health_path = Some(path.into());
55 self
56 }
57}
58
59#[derive(Default)]
60pub struct MiddlewareConfig {
61 entries: Vec<Arc<dyn Middleware>>,
62}
63
64impl MiddlewareConfig {
65 pub fn append<M>(&mut self, middleware: M) -> &mut Self
66 where
67 M: Middleware + 'static,
68 {
69 self.entries.push(Arc::new(middleware));
70 self
71 }
72
73 pub fn append_arc(&mut self, middleware: Arc<dyn Middleware>) -> &mut Self {
74 self.entries.push(middleware);
75 self
76 }
77}
78
/// Labels of the exceptions the application should report, in registration order.
#[derive(Default, Clone)]
pub struct ExceptionConfig {
    reports: Vec<String>,
}

impl ExceptionConfig {
    /// Records one more exception label and returns `self` for chaining.
    pub fn report(&mut self, label: impl Into<String>) -> &mut Self {
        self.reports.extend(std::iter::once(label.into()));
        self
    }
}
90
/// How often a scheduled command is run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frequency {
    /// Hourly cadence.
    Hourly,
}

/// A command name paired with the frequency it was scheduled at.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledCommand {
    command: String,
    frequency: Frequency,
}

/// Commands registered through [`ScheduleConfig::command`], in registration order.
#[derive(Debug, Default, Clone)]
pub struct ScheduleConfig {
    commands: Vec<ScheduledCommand>,
}

/// Pending schedule entry returned by [`ScheduleConfig::command`].
///
/// Nothing is recorded until a frequency method such as [`hourly`](Self::hourly) is
/// called; dropping the builder silently discards the command, hence `#[must_use]`.
#[derive(Debug)]
#[must_use = "the command is only scheduled once a frequency such as `.hourly()` is chosen"]
pub struct ScheduleCommandBuilder<'a> {
    schedule: &'a mut ScheduleConfig,
    command: String,
}

impl<'a> ScheduleCommandBuilder<'a> {
    /// Records the command with [`Frequency::Hourly`], consuming the builder.
    pub fn hourly(self) {
        self.schedule.commands.push(ScheduledCommand {
            command: self.command,
            frequency: Frequency::Hourly,
        });
    }
}

impl ScheduleConfig {
    /// Starts scheduling `command`; finish by choosing a frequency (e.g. `.hourly()`).
    pub fn command(&mut self, command: impl Into<String>) -> ScheduleCommandBuilder<'_> {
        ScheduleCommandBuilder {
            schedule: self,
            command: command.into(),
        }
    }
}
129
130pub struct ApplicationBuilder {
131 base_path: PathBuf,
132 config: SharedConfigRepository,
133 modules: Vec<Box<dyn RivetModule>>,
134 provider_factories: Vec<ProviderFactory>,
135 routing: RoutingConfig,
136 middleware: MiddlewareConfig,
137 exceptions: ExceptionConfig,
138 schedule: ScheduleConfig,
139}
140
141impl ApplicationBuilder {
142 pub fn configure(base_path: impl Into<PathBuf>, config: ConfigRepository) -> Self {
143 Self {
144 base_path: base_path.into(),
145 config: SharedConfigRepository::new(config),
146 modules: Vec::new(),
147 provider_factories: Vec::new(),
148 routing: RoutingConfig::default(),
149 middleware: MiddlewareConfig::default(),
150 exceptions: ExceptionConfig::default(),
151 schedule: ScheduleConfig::default(),
152 }
153 }
154
155 pub fn with_module(mut self, module: Box<dyn RivetModule>) -> Self {
156 self.modules.push(module);
157 self
158 }
159
160 pub fn with_provider(mut self, provider: ProviderFactory) -> Self {
161 self.provider_factories.push(provider);
162 self
163 }
164
165 pub fn with_providers(mut self, providers: &[ProviderFactory]) -> Self {
166 self.provider_factories.extend(providers.iter().copied());
167 self
168 }
169
170 pub fn with_routing<F>(mut self, configure: F) -> Self
171 where
172 F: FnOnce(&mut RoutingConfig),
173 {
174 configure(&mut self.routing);
175 self
176 }
177
178 pub fn with_middleware<F>(mut self, configure: F) -> Self
179 where
180 F: FnOnce(&mut MiddlewareConfig),
181 {
182 configure(&mut self.middleware);
183 self
184 }
185
186 pub fn with_exceptions<F>(mut self, configure: F) -> Self
187 where
188 F: FnOnce(&mut ExceptionConfig),
189 {
190 configure(&mut self.exceptions);
191 self
192 }
193
194 pub fn with_schedule<F>(mut self, configure: F) -> Self
195 where
196 F: FnOnce(&mut ScheduleConfig),
197 {
198 configure(&mut self.schedule);
199 self
200 }
201
202 pub fn build(self) -> Result<Application, RivetError> {
203 let container = Arc::new(InMemoryContainer::new());
204 container
205 .singleton(self.config.clone())
206 .map_err(|err| RivetError::Build(err.to_string()))?;
207 container
208 .singleton(LogService)
209 .map_err(|err| RivetError::Build(err.to_string()))?;
210 container
211 .bind_names("log", core::any::type_name::<LogService>())
212 .map_err(|err| RivetError::Build(err.to_string()))?;
213
214 let mut builder = Builder::new(
215 Arc::clone(&container) as Arc<dyn rivet_foundation::Container>,
216 Box::new(self.config),
217 );
218
219 for module in self.modules {
220 builder = builder.with_module(module);
221 }
222 builder = builder.with_providers(&self.provider_factories);
223
224 builder = builder.with_module(Box::new(BootstrapModule {
225 base_path: self.base_path,
226 routing: self.routing,
227 middleware: self.middleware.entries,
228 exceptions: self.exceptions,
229 schedule: self.schedule,
230 }));
231
232 builder.build()
233 }
234}
235
236struct BootstrapModule {
237 base_path: PathBuf,
238 routing: RoutingConfig,
239 middleware: Vec<Arc<dyn Middleware>>,
240 exceptions: ExceptionConfig,
241 schedule: ScheduleConfig,
242}
243
244impl RivetModule for BootstrapModule {
245 fn name(&self) -> &'static str {
246 "application-bootstrap"
247 }
248
249 fn configure(&self, config: &mut dyn Config) -> Result<(), RivetError> {
250 Self::ensure_server_addr(config)?;
251
252 config
253 .set(
254 "app.base_path",
255 ConfigValue::String(self.base_path.to_string_lossy().to_string()),
256 )
257 .map_err(|err| RivetError::Build(err.to_string()))?;
258 config
259 .set(
260 "app.exceptions.count",
261 ConfigValue::Integer(self.exceptions.reports.len() as i64),
262 )
263 .map_err(|err| RivetError::Build(err.to_string()))?;
264 config
265 .set(
266 "app.schedule.count",
267 ConfigValue::Integer(self.schedule.commands.len() as i64),
268 )
269 .map_err(|err| RivetError::Build(err.to_string()))?;
270
271 Ok(())
272 }
273
274 fn routes(&self, routes: &mut Registry) -> Result<(), RivetError> {
275 if let Some(path) = &self.routing.health_path {
276 routes
277 .add_route(Method::Get, path.clone())
278 .map_err(|err| RivetError::Routing(err.to_string()))?;
279 }
280
281 for (name, route_file) in &self.routing.route_files {
282 route_file(routes).map_err(|err| {
283 RivetError::Routing(format!("route file '{name}' failed: {err}"))
284 })?;
285 }
286
287 Ok(())
288 }
289
290 fn middleware(&self) -> Vec<Arc<dyn Middleware>> {
291 self.middleware.clone()
292 }
293}
294
295impl BootstrapModule {
296 fn ensure_server_addr(config: &mut dyn Config) -> Result<(), RivetError> {
297 if let Some(addr) = Self::optional_string(config, "app.server.addr")? {
298 SocketAddr::from_str(&addr).map_err(|err| {
299 RivetError::Build(format!("invalid 'app.server.addr': {err}"))
300 })?;
301 Self::set_server_url_if_missing(config, &addr)?;
302 return Ok(());
303 }
304
305 if let Some(server_url) = Self::optional_string(config, "app.server.url")? {
306 let authority = Self::url_authority(&server_url);
307 let addr = Self::resolve_addr(authority).ok_or_else(|| {
308 RivetError::Build(format!("invalid 'app.server.url' value '{server_url}'"))
309 })?;
310 config
311 .set("app.server.addr", ConfigValue::String(addr.to_string()))
312 .map_err(|err| RivetError::Build(err.to_string()))?;
313 return Ok(());
314 }
315
316 let app_url = match config.get("app.url") {
317 Some(ConfigValue::String(value)) => value,
318 Some(other) => {
319 return Err(RivetError::Build(format!(
320 "expected 'app.url' string config, got {other:?}"
321 )))
322 }
323 None => {
324 return Err(RivetError::Build(
325 "missing 'app.url' config".to_string(),
326 ))
327 }
328 };
329
330 if let Some(addr) = Self::resolve_addr(&app_url) {
331 config
332 .set("app.server.addr", ConfigValue::String(addr.to_string()))
333 .map_err(|err| RivetError::Build(err.to_string()))?;
334 Self::set_server_url_if_missing(config, &addr.to_string())?;
335 return Ok(());
336 }
337
338 let authority = Self::url_authority(&app_url);
339 if let Some(addr) = Self::resolve_addr(authority) {
340 config
341 .set("app.server.addr", ConfigValue::String(addr.to_string()))
342 .map_err(|err| RivetError::Build(err.to_string()))?;
343 Self::set_server_url_if_missing(config, &addr.to_string())?;
344 return Ok(());
345 }
346
347 let port = match config.get("app.port") {
348 Some(ConfigValue::Integer(value)) => u16::try_from(value).map_err(|_| {
349 RivetError::Build(format!(
350 "expected 'app.port' integer in range 0..=65535, got {value}"
351 ))
352 })?,
353 Some(ConfigValue::String(value)) => value.parse::<u16>().map_err(|err| {
354 RivetError::Build(format!("invalid 'app.port' string value '{value}': {err}"))
355 })?,
356 Some(other) => {
357 return Err(RivetError::Build(format!(
358 "expected 'app.port' integer/string config, got {other:?}"
359 )))
360 }
361 None => 7057,
362 };
363
364 let host = Self::strip_port(authority);
365 let candidate = format!("{host}:{port}");
366 let addr = Self::resolve_addr(&candidate).ok_or_else(|| {
367 RivetError::Build(format!(
368 "invalid 'app.url'/'app.port' combination: '{candidate}'"
369 ))
370 })?;
371
372 config
373 .set("app.server.addr", ConfigValue::String(addr.to_string()))
374 .map_err(|err| RivetError::Build(err.to_string()))?;
375 Self::set_server_url_if_missing(config, &addr.to_string())
376 }
377
378 fn resolve_addr(value: &str) -> Option<SocketAddr> {
379 SocketAddr::from_str(value)
380 .ok()
381 .or_else(|| value.to_socket_addrs().ok()?.next())
382 }
383
384 fn url_authority(url: &str) -> &str {
385 let without_scheme = url.split_once("://").map_or(url, |(_, rest)| rest);
386 without_scheme
387 .split_once(['/', '?', '#'])
388 .map_or(without_scheme, |(authority, _)| authority)
389 }
390
391 fn strip_port(authority: &str) -> &str {
392 if authority.starts_with('[') {
393 if let Some(end) = authority.find(']') {
394 return &authority[..=end];
395 }
396 return authority;
397 }
398
399 if let Some((host, port)) = authority.rsplit_once(':') {
400 if port.parse::<u16>().is_ok() {
401 return host;
402 }
403 }
404
405 authority
406 }
407
408 fn optional_string(config: &dyn Config, key: &str) -> Result<Option<String>, RivetError> {
409 match config.get(key) {
410 Some(ConfigValue::String(value)) => Ok(Some(value)),
411 Some(other) => Err(RivetError::Build(format!(
412 "expected '{key}' string config, got {other:?}"
413 ))),
414 None => Ok(None),
415 }
416 }
417
418 fn set_server_url_if_missing(
419 config: &mut dyn Config,
420 addr: &str,
421 ) -> Result<(), RivetError> {
422 match config.get("app.server.url") {
423 Some(ConfigValue::String(_)) => Ok(()),
424 Some(other) => Err(RivetError::Build(format!(
425 "expected 'app.server.url' string config, got {other:?}"
426 ))),
427 None => config
428 .set(
429 "app.server.url",
430 ConfigValue::String(format!("http://{addr}")),
431 )
432 .map_err(|err| RivetError::Build(err.to_string())),
433 }
434 }
435}