use axum::body::Body;
use axum::http::{HeaderMap, StatusCode};
use axum::response::Response;
use constant_time_eq::constant_time_eq;
use dashmap::DashMap;
use deadpool_postgres::Pool;
use serde::Serialize;
use std::sync::atomic::AtomicBool;
use std::time::Instant;
use uuid::Uuid;
use crate::metrics::Metrics;
/// JSON error body produced by [`json_error`]: `{"error": <code>, "message": <message>}`.
///
/// Field declaration order determines the key order of the serialised output.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    /// Static, machine-readable error code.
    pub error: &'static str,
    /// Human-readable description of the failure.
    pub message: String,
}
/// Builds a JSON error response `{"error": <code>, "message": <message>}` with `status`.
///
/// If serialising the payload fails, a hand-written body carrying `code` and a generic
/// `"serialisation error"` message is sent instead. NOTE(review): `code` is interpolated
/// into that fallback without JSON escaping, so it should never contain `"` or `\`.
pub fn json_error(code: &'static str, message: impl Into<String>, status: StatusCode) -> Response {
    let payload = ErrorResponse {
        error: code,
        message: message.into(),
    };
    let serialised = match serde_json::to_string(&payload) {
        Ok(json) => json,
        Err(_) => format!(r#"{{"error":"{code}","message":"serialisation error"}}"#),
    };
    Response::builder()
        .status(status)
        .header("content-type", "application/json")
        .body(Body::from(serialised))
        .expect("infallible: hardcoded valid HTTP headers")
}
/// Shared application state read by the request helpers in this module.
///
/// NOTE(review): presumably shared with handlers as axum router state (e.g. behind an
/// `Arc`) — confirm at the construction site. Fields are dropped in declaration order.
pub struct AppState {
/// Postgres connection pool (`deadpool_postgres`).
pub pool: Pool,
/// Token compared by `check_auth`; `None` means that check lets every request through.
pub auth_token: Option<String>,
/// Token compared by `check_auth_write`; when `None`, `auth_token` is used instead.
pub datalog_write_token: Option<String>,
/// NOTE(review): presumably identifies a trusted reverse proxy / forwarding header — confirm.
pub trust_proxy: Option<String>,
/// Service metrics (project-defined `Metrics` type).
pub metrics: Metrics,
/// NOTE(review): presumably records whether a database connection has ever succeeded — confirm against its writers.
pub ever_connected: AtomicBool,
/// NOTE(review): presumably the secret used to sign/verify Arrow Flight tickets — confirm.
pub arrow_flight_secret: Option<String>,
/// Whether unsigned Arrow Flight tickets are accepted. NOTE(review): confirm how this interacts with `arrow_flight_secret`.
pub arrow_unsigned_tickets_allowed: bool,
/// Concurrent map from a `String` key (presumably a ticket nonce) to `(Instant, u64)`.
/// NOTE(review): meaning of the `u64` (expiry? counter?) is not visible here — confirm.
pub arrow_nonce_cache: DashMap<String, (Instant, u64)>,
/// NOTE(review): presumably the maximum number of entries kept in `arrow_nonce_cache` — confirm where it is enforced.
pub arrow_nonce_cache_max: usize,
/// NOTE(review): presumably true when the CORS layer allows any origin — confirm.
pub cors_is_permissive: bool,
}
/// Returns the value of environment variable `key`, or `default` when the variable is
/// unset or its value is not valid Unicode.
pub fn env_or(key: &str, default: &str) -> String {
    match std::env::var(key) {
        Ok(value) => value,
        Err(_) => String::from(default),
    }
}
pub fn redacted_error(category: &str, detail: &str, status: StatusCode) -> Response {
let trace_id = Uuid::new_v4().to_string();
tracing::error!(trace_id = %trace_id, detail = %detail, "query error");
let body = serde_json::json!({
"error": category,
"trace_id": trace_id
});
Response::builder()
.status(status)
.header("content-type", "application/json")
.body(Body::from(body.to_string()))
.expect("infallible: hardcoded valid HTTP headers")
}
/// Authenticates a request against the general API token (`auth_token`).
///
/// When no `auth_token` is configured, every request passes.
///
/// # Errors
/// Returns the `401 Unauthorized` response built by `check_token` when the token does not match.
#[allow(clippy::result_large_err)] // Err is a full `Response`, which clippy flags as large.
pub fn check_auth(state: &AppState, headers: &HeaderMap) -> Result<(), Response> {
    let expected = state.auth_token.as_deref();
    check_token(expected, headers)
}
/// Authenticates a write request.
///
/// Compares against `datalog_write_token` when it is configured and otherwise falls back
/// to `auth_token`. When neither is set, every request passes.
///
/// # Errors
/// Returns the `401 Unauthorized` response built by `check_token` when the token does not match.
#[allow(clippy::result_large_err)] // Err is a full `Response`, which clippy flags as large.
pub fn check_auth_write(state: &AppState, headers: &HeaderMap) -> Result<(), Response> {
    let expected = match state.datalog_write_token.as_deref() {
        Some(write_token) => Some(write_token),
        None => state.auth_token.as_deref(),
    };
    check_token(expected, headers)
}
/// Checks the request's `Authorization` header against `expected`.
///
/// When `expected` is `None`, no token is configured and every request passes.
/// Otherwise the header may carry the token as `Bearer <token>`, `Basic <token>`, or a
/// bare `<token>`. The scheme name is matched case-insensitively (RFC 7235 §2.1), and the
/// credential is compared with `constant_time_eq`. A header that is missing, or whose value
/// `HeaderValue::to_str` rejects, is treated as an empty credential.
///
/// NOTE(review): `Basic` credentials are compared verbatim, without base64 decoding, so a
/// standard `user:pass` Basic header will not match — confirm this is intended.
///
/// # Errors
/// Returns `401 Unauthorized` on mismatch. The response carries the header
/// `WWW-Authenticate: Bearer realm="pg_ripple"` and the JSON body
/// `{"error":"PT401","message":"unauthorized"}`.
#[allow(clippy::result_large_err)] // Err is a full `Response`, which clippy flags as large.
fn check_token(expected: Option<&str>, headers: &HeaderMap) -> Result<(), Response> {
    let Some(expected) = expected else {
        // No token configured: authentication is disabled.
        return Ok(());
    };
    let provided = headers
        .get("authorization")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("");
    let token = strip_auth_scheme(provided);
    if constant_time_eq(token.as_bytes(), expected.as_bytes()) {
        return Ok(());
    }
    let body = serde_json::json!({"error": "PT401", "message": "unauthorized"}).to_string();
    Err(Response::builder()
        .status(StatusCode::UNAUTHORIZED)
        .header("www-authenticate", "Bearer realm=\"pg_ripple\"")
        .header("content-type", "application/json")
        .body(Body::from(body))
        .expect("infallible: hardcoded valid HTTP headers"))
}

/// Returns the credential part of an `Authorization` header value.
///
/// Removes a leading `Bearer` or `Basic` scheme (any ASCII case) and the spaces after it.
/// Any other value is returned unchanged, so a bare token still works.
fn strip_auth_scheme(value: &str) -> &str {
    match value.split_once(' ') {
        Some((scheme, credentials))
            if scheme.eq_ignore_ascii_case("Bearer") || scheme.eq_ignore_ascii_case("Basic") =>
        {
            // RFC 7235 allows one or more spaces between scheme and credentials.
            credentials.trim_start_matches(' ')
        }
        _ => value,
    }
}