cronback_api_srv/
lib.rs

1pub(crate) mod auth;
2pub(crate) mod auth_middleware;
3pub(crate) mod auth_store;
4pub mod errors;
5pub(crate) mod extractors;
6mod handlers;
7mod logging;
8mod model;
9mod paginated;
10
11use std::sync::Arc;
12use std::time::Instant;
13
14use auth::Authenticator;
15use axum::extract::MatchedPath;
16use axum::http::{Request, StatusCode};
17use axum::middleware::{self, Next};
18use axum::response::IntoResponse;
19use axum::routing::get;
20use axum::Router;
21use lib::clients::dispatcher_client::ScopedDispatcherClient;
22use lib::clients::scheduler_client::ScopedSchedulerClient;
23use lib::config::Config;
24use lib::database::Database;
25use lib::grpc_client_provider::{GrpcClientFactory, GrpcClientProvider};
26use lib::model::ValidShardedId;
27use lib::prelude::*;
28use lib::types::{ProjectId, RequestId};
29use lib::{netutils, service};
30use metrics::{histogram, increment_counter};
31use thiserror::Error;
32use tokio::select;
33use tower_http::cors::{AllowOrigin, CorsLayer};
34use tower_http::trace::TraceLayer;
35use tracing::{error, info, warn};
36
37use crate::auth_store::SqlAuthStore;
38use crate::logging::{trace_request_response, ApiMakeSpan};
39
40#[derive(Debug, Error)]
41pub enum AppStateError {
42    #[error(transparent)]
43    ConnectError(#[from] tonic::transport::Error),
44    #[error("Internal data routing error: {0}")]
45    RoutingError(String),
46    #[error("Database error: {0}")]
47    DatabaseError(String),
48}
49
50pub struct AppState {
51    pub _context: service::ServiceContext,
52    pub config: Config,
53    pub authenicator: Authenticator,
54    pub scheduler_clients:
55        Box<dyn GrpcClientFactory<ClientType = ScopedSchedulerClient>>,
56    pub dispatcher_clients:
57        Box<dyn GrpcClientFactory<ClientType = ScopedDispatcherClient>>,
58}
59
60async fn fallback() -> (StatusCode, &'static str) {
61    (StatusCode::NOT_FOUND, "Not Found")
62}
63
64#[tracing::instrument(skip_all, fields(service = context.service_name()))]
65pub async fn start_api_server(
66    mut context: service::ServiceContext,
67) -> anyhow::Result<()> {
68    let config = context.load_config();
69    let addr =
70        netutils::parse_addr(&config.api.address, config.api.port).unwrap();
71
72    let db = Database::connect(&config.api.database_uri).await?;
73    db.migrate().await?;
74
75    let shared_state = Arc::new(AppState {
76        _context: context.clone(),
77        config: config.clone(),
78        authenicator: Authenticator::new(Box::new(SqlAuthStore::new(
79            db.clone(),
80        ))),
81        scheduler_clients: Box::new(GrpcClientProvider::new(context.clone())),
82        dispatcher_clients: Box::new(GrpcClientProvider::new(context.clone())),
83    });
84
85    let service_name = context.service_name().to_string();
86    // build our application with a route
87    let app = Router::new()
88        // `GET /` goes to `root`
89        .route("/", get(root))
90        .nest("/v1", handlers::routes(shared_state.clone()))
91        .layer(middleware::from_fn_with_state(
92            Arc::new(config.clone()),
93            trace_request_response,
94        ))
95        .layer(
96            CorsLayer::new()
97                .allow_origin(AllowOrigin::any())
98                .allow_headers([
99                    axum::http::header::CONTENT_TYPE,
100                    axum::http::header::AUTHORIZATION,
101                ]),
102        )
103        .layer(
104            TraceLayer::new_for_http()
105                .make_span_with(ApiMakeSpan::new(service_name)),
106        )
107        .route_layer(middleware::from_fn(inject_request_id))
108        .route_layer(middleware::from_fn(track_metrics))
109        .fallback(fallback);
110
111    // Handle 404
112    let app = app.fallback(handler_404);
113
114    let mut context_clone = context.clone();
115    info!("Starting '{}' on {:?}", context.service_name(), addr);
116    let server = axum::Server::try_bind(&addr)?;
117
118    let server = server
119        .serve(app.into_make_service())
120        .with_graceful_shutdown(context.recv_shutdown_signal());
121
122    // Waiting for shutdown signal
123    select! {
124        _ = context_clone.recv_shutdown_signal() => {
125            warn!("Received shutdown signal!");
126        },
127        res = server => {
128        if let Err(e) = res {
129            error!(
130                "Service '{}' failed and will trigger system shutdown: {e}",
131                context.service_name()
132            );
133            context.broadcast_shutdown();
134        }
135        }
136    };
137    Ok(())
138}
139
/// Handler for `GET /`: replies with a fixed pointer to the project website.
async fn root() -> &'static str {
    const GREETING: &str = "Hey, better visit https://cronback.me";
    GREETING
}
144
145async fn inject_request_id<B>(
146    mut req: Request<B>,
147    next: Next<B>,
148) -> impl IntoResponse {
149    let request_id = RequestId::new();
150    // Inject RequestId into extensions. Can be useful if someone wants to
151    // log the request_id
152    req.extensions_mut().insert(request_id.clone());
153    // Run the next layer
154    let mut response = next.run(req).await;
155    // Inject request_id into response headers
156    response
157        .headers_mut()
158        .insert(REQUEST_ID_HEADER, request_id.to_string().parse().unwrap());
159
160    // Inject project_id into response headers
161    if let Some(project_id) = response
162        .extensions()
163        .get::<ValidShardedId<ProjectId>>()
164        .cloned()
165    {
166        response
167            .headers_mut()
168            .insert(PROJECT_ID_HEADER, project_id.to_string().parse().unwrap());
169    }
170    response
171}
172
173async fn track_metrics<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
174    let start = Instant::now();
175    let path = if let Some(matched_path) = req.extensions().get::<MatchedPath>()
176    {
177        matched_path.as_str().to_owned()
178    } else {
179        req.uri().path().to_owned()
180    };
181    let method = req.method().clone();
182
183    let response = next.run(req).await;
184
185    let latency = start.elapsed().as_secs_f64();
186    let status = response.status().as_u16().to_string();
187
188    let labels = [
189        ("method", method.to_string()),
190        ("path", path),
191        ("status", status),
192    ];
193
194    increment_counter!("cronback.api.http_requests_total", &labels);
195    histogram!(
196        "cronback.api.http_requests_duration_seconds",
197        latency,
198        &labels
199    );
200
201    response
202}
203
204// handle 404
205async fn handler_404() -> impl IntoResponse {
206    (StatusCode::NOT_FOUND, "Are you lost, mate?")
207}