use std::env;
use axum::{
extract::{FromRequestParts, State},
http::{request::Parts, StatusCode},
Json,
};
use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq;
use tracing::{debug, info, warn};
use crate::db::DbPool;
use crate::error::AppResult;
use crate::services::internal as svc;
const TOKEN_ENV: &str = "NOETL_INTERNAL_API_TOKEN";
#[derive(Debug)]
pub struct RequireInternalApiToken;
impl<S> FromRequestParts<S> for RequireInternalApiToken
where
S: Send + Sync,
{
type Rejection = (StatusCode, String);
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let expected = match env::var(TOKEN_ENV) {
Ok(value) if !value.trim().is_empty() => value,
_ => {
warn!(
"Internal API called but {} is not set; rejecting with 503.",
TOKEN_ENV
);
return Err((
StatusCode::SERVICE_UNAVAILABLE,
format!(
"Internal API not configured: {} env var unset on the server. \
Set it to the system worker pool's ServiceAccount token before \
calling /api/internal/* endpoints.",
TOKEN_ENV
),
));
}
};
let header = match parts.headers.get("authorization") {
Some(v) => v,
None => {
return Err((
StatusCode::FORBIDDEN,
"Internal API requires Authorization header with Bearer token.".to_string(),
));
}
};
let header_value = header.to_str().map_err(|_| {
(
StatusCode::FORBIDDEN,
"Internal API Authorization header is not valid ASCII.".to_string(),
)
})?;
let mut parts_iter = header_value.splitn(2, ' ');
let scheme = parts_iter.next().unwrap_or("");
let token = parts_iter.next().unwrap_or("").trim();
if !scheme.eq_ignore_ascii_case("bearer") || token.is_empty() {
return Err((
StatusCode::FORBIDDEN,
"Internal API requires 'Bearer <token>' Authorization scheme.".to_string(),
));
}
let provided = token.as_bytes();
let expected_bytes = expected.as_bytes();
if provided.len() != expected_bytes.len()
|| !bool::from(provided.ct_eq(expected_bytes))
{
return Err((
StatusCode::FORBIDDEN,
"Invalid service-account token for /api/internal/*.".to_string(),
));
}
Ok(RequireInternalApiToken)
}
}
/// Request body for the outbox claim handler.
///
/// `limit` falls back to [`default_claim_limit`] both when the field is
/// omitted from the JSON body and when the value is built via [`Default`].
#[derive(Debug, Deserialize)]
pub struct OutboxClaimRequest {
    /// Requested batch size, forwarded to `svc::claim_batch`.
    #[serde(default = "default_claim_limit")]
    pub limit: i64,
}

// A derived `Default` would set `limit` to 0, disagreeing with the serde
// default used for a missing field; keep both construction paths identical.
impl Default for OutboxClaimRequest {
    fn default() -> Self {
        Self {
            limit: default_claim_limit(),
        }
    }
}

/// Batch size used when the claim request does not specify `limit`.
fn default_claim_limit() -> i64 {
    100
}
/// Response body of the outbox claim handler.
#[derive(Serialize, Debug)]
pub struct OutboxClaimResponse {
    /// Rows returned by `svc::claim_batch` for this call.
    pub rows: Vec<svc::OutboxRow>,
    /// Number of entries in `rows`.
    pub claimed: i64,
}

/// Request body listing outbox rows to mark as published.
#[derive(Deserialize, Debug)]
pub struct OutboxMarkPublishedRequest {
    /// Outbox row ids; the handler rejects an empty list.
    pub outbox_ids: Vec<i64>,
}

/// Response body of the mark-published handler.
#[derive(Serialize, Debug)]
pub struct OutboxMarkPublishedResponse {
    /// Count returned by `svc::mark_published_batch`.
    pub marked: i64,
}
/// Value used for [`OutboxMarkFailedRequest::attempts`] when the field is omitted.
fn default_attempts() -> i32 {
    1
}

/// Value used for [`OutboxMarkFailedRequest::max_delay_seconds`] when the field is omitted.
fn default_max_delay_seconds() -> i32 {
    300
}

/// Request body reporting a failed publish attempt for a single outbox row.
#[derive(Deserialize, Debug)]
pub struct OutboxMarkFailedRequest {
    /// Id of the outbox row whose publish failed.
    pub outbox_id: i64,
    /// Failure description; the handler rejects an empty value.
    pub error: String,
    /// Attempt count forwarded to `svc::mark_failed_row` (defaults to 1).
    #[serde(default = "default_attempts")]
    pub attempts: i32,
    /// Forwarded to `svc::mark_failed_row` (defaults to 300); presumably caps
    /// the retry delay in seconds. NOTE(review): confirm against the service.
    #[serde(default = "default_max_delay_seconds")]
    pub max_delay_seconds: i32,
}
/// Response body of the mark-failed handler.
#[derive(Serialize, Debug)]
pub struct OutboxMarkFailedResponse {
    /// Always `true` when the handler returns successfully.
    pub marked: bool,
    /// Delay returned by `svc::mark_failed_row` (logged as `delay_seconds`).
    pub available_at_in: i64,
}

/// Response body of the pending-count handler.
#[derive(Serialize, Debug)]
pub struct OutboxPendingCountResponse {
    /// Value returned by `svc::pending_count`.
    pub pending: i64,
}

/// Request body carrying a batch of events to project.
#[derive(Deserialize, Debug)]
pub struct EventsProjectRequest {
    /// Envelopes handed to `svc::project_events`; must not be empty.
    pub events: Vec<svc::EventEnvelope>,
}

/// Response body of the events projection handler.
#[derive(Serialize, Debug)]
pub struct EventsProjectResponse {
    /// First count returned by `svc::project_events`.
    pub projected: i64,
    /// Second count returned by `svc::project_events`.
    pub duplicates: i64,
}
/// Claims up to `request.limit` outbox rows via `svc::claim_batch` and returns
/// them together with their count.
///
/// NOTE(review): `limit` is forwarded unvalidated (negative values included);
/// confirm `svc::claim_batch` bounds it.
#[tracing::instrument(skip(pool, _token), fields(limit = request.limit))]
pub async fn outbox_claim(
    State(pool): State<DbPool>,
    _token: RequireInternalApiToken,
    Json(request): Json<OutboxClaimRequest>,
) -> AppResult<Json<OutboxClaimResponse>> {
    let rows = svc::claim_batch(&pool, request.limit).await?;
    let response = OutboxClaimResponse {
        claimed: rows.len() as i64,
        rows,
    };
    debug!(claimed = response.claimed, "outbox/claim done");
    Ok(Json(response))
}
#[tracing::instrument(skip(pool, _token), fields(count = request.outbox_ids.len()))]
pub async fn outbox_mark_published(
State(pool): State<DbPool>,
_token: RequireInternalApiToken,
Json(request): Json<OutboxMarkPublishedRequest>,
) -> AppResult<Json<OutboxMarkPublishedResponse>> {
if request.outbox_ids.is_empty() {
return Err(crate::error::AppError::BadRequest(
"outbox_ids must not be empty".to_string(),
));
}
let marked = svc::mark_published_batch(&pool, &request.outbox_ids).await?;
debug!(marked, "outbox/mark-published done");
Ok(Json(OutboxMarkPublishedResponse { marked }))
}
/// Records a failed publish attempt for one outbox row.
///
/// Rejects a blank `error` with `BadRequest`. Otherwise delegates to
/// `svc::mark_failed_row` and returns the delay it computed as `available_at_in`.
///
/// The span records `outbox_id` and `attempts` explicitly. The request itself is
/// skipped so the possibly long error text is not Debug-copied into every span.
#[tracing::instrument(
    skip(pool, _token, request),
    fields(outbox_id = request.outbox_id, attempts = request.attempts)
)]
pub async fn outbox_mark_failed(
    State(pool): State<DbPool>,
    _token: RequireInternalApiToken,
    Json(request): Json<OutboxMarkFailedRequest>,
) -> AppResult<Json<OutboxMarkFailedResponse>> {
    // A whitespace-only message carries no diagnostic value; treat it as empty.
    if request.error.trim().is_empty() {
        return Err(crate::error::AppError::BadRequest(
            "error must not be empty".to_string(),
        ));
    }
    let delay = svc::mark_failed_row(
        &pool,
        request.outbox_id,
        &request.error,
        request.attempts,
        request.max_delay_seconds,
    )
    .await?;
    info!(delay_seconds = delay, "outbox/mark-failed applied");
    Ok(Json(OutboxMarkFailedResponse {
        marked: true,
        available_at_in: delay,
    }))
}
/// Returns the count reported by `svc::pending_count`.
#[tracing::instrument(skip(pool, _token))]
pub async fn outbox_pending_count(
    State(pool): State<DbPool>,
    _token: RequireInternalApiToken,
) -> AppResult<Json<OutboxPendingCountResponse>> {
    let count = svc::pending_count(&pool).await?;
    let body = OutboxPendingCountResponse { pending: count };
    Ok(Json(body))
}
/// Projects a batch of event envelopes via `svc::project_events`.
///
/// Rejects an empty batch with `BadRequest`. Otherwise returns the projected
/// and duplicate counts reported by the service.
///
/// Only the batch size is recorded on the span. The request is skipped so the
/// full event payloads are not Debug-formatted into every span (and logs).
#[tracing::instrument(skip(pool, _token, request), fields(batch_size = request.events.len()))]
pub async fn events_project(
    State(pool): State<DbPool>,
    _token: RequireInternalApiToken,
    Json(request): Json<EventsProjectRequest>,
) -> AppResult<Json<EventsProjectResponse>> {
    if request.events.is_empty() {
        return Err(crate::error::AppError::BadRequest(
            "events must not be empty".to_string(),
        ));
    }
    let (projected, duplicates) = svc::project_events(&pool, &request.events).await?;
    info!(projected, duplicates, "events/project done");
    Ok(Json(EventsProjectResponse {
        projected,
        duplicates,
    }))
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::extract::FromRequestParts;
    use axum::http::{Request, StatusCode};

    /// Runs the extractor against a synthetic `GET /test` request.
    ///
    /// `env_token` sets (`Some`) or clears (`None`) `TOKEN_ENV` before the call.
    /// `header` is sent verbatim as the `authorization` header when present.
    /// The variable is always removed again before returning.
    async fn run_extractor(
        env_token: Option<&str>,
        header: Option<&str>,
    ) -> Result<RequireInternalApiToken, (StatusCode, String)> {
        // SAFETY: mutating the process environment is unsound if another thread
        // reads it at the same time. Every test calling this helper is marked
        // `#[serial_test::serial]`, so these tests never overlap each other.
        // NOTE(review): tests elsewhere in the binary are not serialized with these.
        match env_token {
            Some(v) => unsafe { env::set_var(TOKEN_ENV, v) },
            None => unsafe { env::remove_var(TOKEN_ENV) },
        }
        let mut builder = Request::builder().method("GET").uri("/test");
        if let Some(h) = header {
            builder = builder.header("authorization", h);
        }
        let req = builder.body(Body::empty()).unwrap();
        let (mut parts, _body) = req.into_parts();
        let result = <RequireInternalApiToken as FromRequestParts<()>>::from_request_parts(
            &mut parts,
            &(),
        )
        .await;
        // SAFETY: same serialization argument as the set/remove above.
        unsafe { env::remove_var(TOKEN_ENV) };
        result
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn rejects_when_env_unset() {
        let err = run_extractor(None, Some("Bearer foo")).await.unwrap_err();
        assert_eq!(err.0, StatusCode::SERVICE_UNAVAILABLE);
        assert!(err.1.contains(TOKEN_ENV));
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn rejects_when_env_blank() {
        let err = run_extractor(Some(" "), Some("Bearer foo"))
            .await
            .unwrap_err();
        assert_eq!(err.0, StatusCode::SERVICE_UNAVAILABLE);
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn rejects_when_no_authorization_header() {
        let err = run_extractor(Some("secret-123"), None).await.unwrap_err();
        assert_eq!(err.0, StatusCode::FORBIDDEN);
        assert!(err.1.contains("Bearer"));
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn rejects_non_bearer_scheme() {
        let err = run_extractor(Some("secret-123"), Some("Basic secret-123"))
            .await
            .unwrap_err();
        assert_eq!(err.0, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn rejects_wrong_token() {
        let err = run_extractor(Some("secret-123"), Some("Bearer wrong"))
            .await
            .unwrap_err();
        assert_eq!(err.0, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn rejects_token_that_is_a_prefix_of_the_secret() {
        let err = run_extractor(Some("secret-123"), Some("Bearer secret-12"))
            .await
            .unwrap_err();
        assert_eq!(err.0, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn accepts_valid_token() {
        let result = run_extractor(Some("secret-123"), Some("Bearer secret-123")).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn accepts_valid_token_case_insensitive_scheme() {
        let result = run_extractor(Some("secret-123"), Some("bearer secret-123")).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn rejects_empty_token_after_bearer() {
        let err = run_extractor(Some("secret-123"), Some("Bearer "))
            .await
            .unwrap_err();
        assert_eq!(err.0, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn rejects_bare_scheme_without_token() {
        let err = run_extractor(Some("secret-123"), Some("Bearer"))
            .await
            .unwrap_err();
        assert_eq!(err.0, StatusCode::FORBIDDEN);
    }
}