1pub(crate) mod auth;
2pub(crate) mod auth_middleware;
3pub(crate) mod auth_store;
4pub mod errors;
5pub(crate) mod extractors;
6mod handlers;
7mod logging;
8mod model;
9mod paginated;
10
11use std::sync::Arc;
12use std::time::Instant;
13
14use auth::Authenticator;
15use axum::extract::MatchedPath;
16use axum::http::{Request, StatusCode};
17use axum::middleware::{self, Next};
18use axum::response::IntoResponse;
19use axum::routing::get;
20use axum::Router;
21use lib::clients::dispatcher_client::ScopedDispatcherClient;
22use lib::clients::scheduler_client::ScopedSchedulerClient;
23use lib::config::Config;
24use lib::database::Database;
25use lib::grpc_client_provider::{GrpcClientFactory, GrpcClientProvider};
26use lib::model::ValidShardedId;
27use lib::prelude::*;
28use lib::types::{ProjectId, RequestId};
29use lib::{netutils, service};
30use metrics::{histogram, increment_counter};
31use thiserror::Error;
32use tokio::select;
33use tower_http::cors::{AllowOrigin, CorsLayer};
34use tower_http::trace::TraceLayer;
35use tracing::{error, info, warn};
36
37use crate::auth_store::SqlAuthStore;
38use crate::logging::{trace_request_response, ApiMakeSpan};
39
40#[derive(Debug, Error)]
41pub enum AppStateError {
42 #[error(transparent)]
43 ConnectError(#[from] tonic::transport::Error),
44 #[error("Internal data routing error: {0}")]
45 RoutingError(String),
46 #[error("Database error: {0}")]
47 DatabaseError(String),
48}
49
50pub struct AppState {
51 pub _context: service::ServiceContext,
52 pub config: Config,
53 pub authenicator: Authenticator,
54 pub scheduler_clients:
55 Box<dyn GrpcClientFactory<ClientType = ScopedSchedulerClient>>,
56 pub dispatcher_clients:
57 Box<dyn GrpcClientFactory<ClientType = ScopedDispatcherClient>>,
58}
59
60async fn fallback() -> (StatusCode, &'static str) {
61 (StatusCode::NOT_FOUND, "Not Found")
62}
63
64#[tracing::instrument(skip_all, fields(service = context.service_name()))]
65pub async fn start_api_server(
66 mut context: service::ServiceContext,
67) -> anyhow::Result<()> {
68 let config = context.load_config();
69 let addr =
70 netutils::parse_addr(&config.api.address, config.api.port).unwrap();
71
72 let db = Database::connect(&config.api.database_uri).await?;
73 db.migrate().await?;
74
75 let shared_state = Arc::new(AppState {
76 _context: context.clone(),
77 config: config.clone(),
78 authenicator: Authenticator::new(Box::new(SqlAuthStore::new(
79 db.clone(),
80 ))),
81 scheduler_clients: Box::new(GrpcClientProvider::new(context.clone())),
82 dispatcher_clients: Box::new(GrpcClientProvider::new(context.clone())),
83 });
84
85 let service_name = context.service_name().to_string();
86 let app = Router::new()
88 .route("/", get(root))
90 .nest("/v1", handlers::routes(shared_state.clone()))
91 .layer(middleware::from_fn_with_state(
92 Arc::new(config.clone()),
93 trace_request_response,
94 ))
95 .layer(
96 CorsLayer::new()
97 .allow_origin(AllowOrigin::any())
98 .allow_headers([
99 axum::http::header::CONTENT_TYPE,
100 axum::http::header::AUTHORIZATION,
101 ]),
102 )
103 .layer(
104 TraceLayer::new_for_http()
105 .make_span_with(ApiMakeSpan::new(service_name)),
106 )
107 .route_layer(middleware::from_fn(inject_request_id))
108 .route_layer(middleware::from_fn(track_metrics))
109 .fallback(fallback);
110
111 let app = app.fallback(handler_404);
113
114 let mut context_clone = context.clone();
115 info!("Starting '{}' on {:?}", context.service_name(), addr);
116 let server = axum::Server::try_bind(&addr)?;
117
118 let server = server
119 .serve(app.into_make_service())
120 .with_graceful_shutdown(context.recv_shutdown_signal());
121
122 select! {
124 _ = context_clone.recv_shutdown_signal() => {
125 warn!("Received shutdown signal!");
126 },
127 res = server => {
128 if let Err(e) = res {
129 error!(
130 "Service '{}' failed and will trigger system shutdown: {e}",
131 context.service_name()
132 );
133 context.broadcast_shutdown();
134 }
135 }
136 };
137 Ok(())
138}
139
/// Handler for `GET /`: a plain-text pointer to the project website.
async fn root() -> &'static str {
    const GREETING: &str = "Hey, better visit https://cronback.me";
    GREETING
}
144
145async fn inject_request_id<B>(
146 mut req: Request<B>,
147 next: Next<B>,
148) -> impl IntoResponse {
149 let request_id = RequestId::new();
150 req.extensions_mut().insert(request_id.clone());
153 let mut response = next.run(req).await;
155 response
157 .headers_mut()
158 .insert(REQUEST_ID_HEADER, request_id.to_string().parse().unwrap());
159
160 if let Some(project_id) = response
162 .extensions()
163 .get::<ValidShardedId<ProjectId>>()
164 .cloned()
165 {
166 response
167 .headers_mut()
168 .insert(PROJECT_ID_HEADER, project_id.to_string().parse().unwrap());
169 }
170 response
171}
172
173async fn track_metrics<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
174 let start = Instant::now();
175 let path = if let Some(matched_path) = req.extensions().get::<MatchedPath>()
176 {
177 matched_path.as_str().to_owned()
178 } else {
179 req.uri().path().to_owned()
180 };
181 let method = req.method().clone();
182
183 let response = next.run(req).await;
184
185 let latency = start.elapsed().as_secs_f64();
186 let status = response.status().as_u16().to_string();
187
188 let labels = [
189 ("method", method.to_string()),
190 ("path", path),
191 ("status", status),
192 ];
193
194 increment_counter!("cronback.api.http_requests_total", &labels);
195 histogram!(
196 "cronback.api.http_requests_duration_seconds",
197 latency,
198 &labels
199 );
200
201 response
202}
203
204async fn handler_404() -> impl IntoResponse {
206 (StatusCode::NOT_FOUND, "Are you lost, mate?")
207}