artur 0.2.0

Universal config-driven Rust HTTP gateway and package orchestrator
Documentation
use crate::{
    AppConfig,
    config::{EndpointConfig, SecurityTaskConfig},
    error::{ArturError, Result},
    process::{RequestContext, lookup_template_json_value, render_template, run_task},
};
use serde_json::Value;
use std::{collections::BTreeMap, sync::Arc, time::Instant};
use tokio::sync::RwLock;

/// Table of failure records, keyed by the string built in `failure_key`.
type FailureMap = BTreeMap<String, FailureRecord>;

/// Shared store of authorization failures used for temporary key blocking.
///
/// Cloning is cheap: every clone points at the same map through the `Arc`,
/// guarded by an async `RwLock`.
#[derive(Clone, Debug, Default)]
pub struct SecurityState {
    failures: Arc<RwLock<FailureMap>>,
}

/// Failure bookkeeping for a single failure key.
#[derive(Clone, Debug)]
struct FailureRecord {
    /// Start of the current counting window.
    first_seen: Instant,
    /// Number of failures counted in the current window.
    failures: u32,
    /// While this lies in the future, `check_block` rejects the key.
    blocked_until: Option<Instant>,
}

/// Runs every configured security check for `endpoint` against `request`.
///
/// A key that `check_block` reports as blocked is rejected before any other
/// check runs. Otherwise the API-key, challenge and x402 guards run in that
/// order. On full success the key's record is reset via `clear_failure`. On
/// the first error, that error is reported to `record_failure` and returned.
pub async fn authorize_endpoint(
    config: Arc<AppConfig>,
    security: SecurityState,
    endpoint: &EndpointConfig,
    request: &RequestContext,
) -> Result<()> {
    check_block(&security, endpoint, request).await?;

    let outcome = run_security_checks(config, endpoint, request).await;
    if outcome.is_ok() {
        clear_failure(&security, endpoint, request).await;
    } else {
        record_failure(&security, endpoint, request).await;
    }
    outcome
}

/// Evaluates the API-key, challenge and x402 guards in order, stopping at
/// the first failure.
async fn run_security_checks(
    config: Arc<AppConfig>,
    endpoint: &EndpointConfig,
    request: &RequestContext,
) -> Result<()> {
    check_api_key(endpoint, request)?;
    let challenge = endpoint.security.challenge.as_ref();
    check_task_guard(Arc::clone(&config), challenge, request, false).await?;
    let x402 = endpoint.security.x402.as_ref();
    check_task_guard(config, x402, request, true).await
}

/// Verifies the endpoint's static API key, if one is configured.
///
/// The expected value is `api_key.value` rendered against `request`. When a
/// non-blank `scheme` is configured, it is prefixed as `"{scheme} "`. The
/// result is compared in constant time with the request header named by
/// `api_key.header`. The header is looked up by its lower-cased name, and a
/// missing header compares as an empty value.
///
/// # Errors
/// * Propagates any error from rendering the key template.
/// * `ArturError::Forbidden` when the header does not match. It is also
///   returned when the rendered key is blank: the check fails closed, so an
///   unset or empty secret cannot be satisfied by omitting the header.
fn check_api_key(endpoint: &EndpointConfig, request: &RequestContext) -> Result<()> {
    let Some(api_key) = &endpoint.security.api_key else {
        return Ok(());
    };
    let rejection = || {
        ArturError::Forbidden(format!(
            "endpoint {} rejected request: invalid api key",
            endpoint.name
        ))
    };
    let header_name = api_key.header.to_ascii_lowercase();
    let secret = render_template(&api_key.value, request)?;
    // Without this guard, a blank secret (and no scheme) would equal the empty
    // value used for a missing header, authorizing any request.
    if secret.trim().is_empty() {
        return Err(rejection());
    }
    let expected = match &api_key.scheme {
        Some(scheme) if !scheme.trim().is_empty() => format!("{} {}", scheme.trim(), secret),
        _ => secret,
    };
    // Borrow the header value rather than cloning it; a missing header is empty.
    let actual = request
        .headers
        .get(&header_name)
        .map(|value| value.as_bytes())
        .unwrap_or_default();
    if !constant_time_eq(actual, expected.as_bytes()) {
        return Err(rejection());
    }
    Ok(())
}

/// Runs an optional guard task (challenge or x402) and checks its verdict.
///
/// With no guard configured this succeeds immediately. Otherwise the task
/// named by `guard.task` is looked up in `config`, run against `request`, and
/// accepted only when the task reports `ok` and `guard_output_allows`
/// approves its JSON output.
///
/// # Errors
/// * `ArturError::Config` when the named task does not exist.
/// * Any error propagated from `run_task`.
/// * `ArturError::PaymentRequired` (when `payment` is true) or
///   `ArturError::Forbidden` (otherwise) when the guard rejects the request.
async fn check_task_guard(
    config: Arc<AppConfig>,
    guard: Option<&SecurityTaskConfig>,
    request: &RequestContext,
    payment: bool,
) -> Result<()> {
    let guard = match guard {
        Some(guard) => guard,
        None => return Ok(()),
    };
    let Some(task) = config.task_by_name(&guard.task).cloned() else {
        return Err(ArturError::Config(format!(
            "unknown security task {}",
            guard.task
        )));
    };
    let output = run_task(&task, request).await?;
    let allowed = output.ok && guard_output_allows(guard, output.json.as_ref(), payment);
    if allowed {
        return Ok(());
    }
    Err(if payment {
        ArturError::PaymentRequired(format!(
            "x402 payment verification failed for task {}",
            guard.task
        ))
    } else {
        ArturError::Forbidden(format!(
            "challenge verification failed for task {}",
            guard.task
        ))
    })
}

/// Decides whether a guard task's JSON output grants access.
///
/// * No JSON output always denies.
/// * With `success_path` configured, the output is bound as step `guard` in a
///   scratch request context. The template path `steps.guard.{path}` must then
///   resolve to boolean `true`.
/// * Otherwise the first boolean found among the top-level keys `ok`,
///   `allowed`, `verified` (plus `paid` for payment guards) decides. If none
///   is present, access is denied.
fn guard_output_allows(guard: &SecurityTaskConfig, json: Option<&Value>, payment: bool) -> bool {
    let json = match json {
        Some(json) => json,
        None => return false,
    };
    match &guard.success_path {
        Some(path) => {
            // Scratch context whose only content is the guard output, so the
            // regular template lookup can address it.
            let scratch = RequestContext {
                method: String::new(),
                uri: String::new(),
                path: String::new(),
                params: BTreeMap::new(),
                query: BTreeMap::new(),
                headers: BTreeMap::new(),
                body: String::new(),
                body_json: None,
                steps: BTreeMap::from([(String::from("guard"), json.clone())]),
            };
            let template = format!("steps.guard.{path}");
            lookup_template_json_value(&template, &scratch)
                .and_then(|value| value.as_bool())
                .unwrap_or(false)
        }
        None => {
            let candidates: &[&str] = if payment {
                &["ok", "allowed", "verified", "paid"]
            } else {
                &["ok", "allowed", "verified"]
            };
            candidates
                .iter()
                .find_map(|name| json.get(*name).and_then(Value::as_bool))
                .unwrap_or(false)
        }
    }
}

/// Compares two byte strings without short-circuiting on the first mismatch.
///
/// Every position up to the longer input's length is visited. The shorter
/// input is padded with zero bytes, and a length difference alone already
/// makes the result `false`.
fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
    let span = usize::max(left.len(), right.len());
    let byte_at = |bytes: &[u8], position: usize| bytes.get(position).copied().unwrap_or(0);
    let mismatch = (0..span).fold(left.len() ^ right.len(), |acc, position| {
        acc | usize::from(byte_at(left, position) ^ byte_at(right, position))
    });
    mismatch == 0
}

/// Rejects the request if its failure key is currently blocked.
///
/// A no-op when the endpoint has no `failure_block` configuration.
///
/// # Errors
/// * Propagates template errors from `failure_key`.
/// * `ArturError::TooManyRequests` while the key's `blocked_until` lies in
///   the future.
async fn check_block(
    security: &SecurityState,
    endpoint: &EndpointConfig,
    request: &RequestContext,
) -> Result<()> {
    let block = match &endpoint.security.failure_block {
        Some(block) => block,
        None => return Ok(()),
    };
    let key = failure_key(endpoint, request)?;
    let now = Instant::now();
    let table = security.failures.read().await;
    let blocked = table
        .get(&key)
        .and_then(|record| record.blocked_until)
        .is_some_and(|until| now < until);
    if !blocked {
        return Ok(());
    }
    Err(ArturError::TooManyRequests(format!(
        "endpoint {} temporarily blocked this key after {} failed requests",
        endpoint.name, block.max_failures
    )))
}

async fn record_failure(
    security: &SecurityState,
    endpoint: &EndpointConfig,
    request: &RequestContext,
) {
    let Some(block) = &endpoint.security.failure_block else {
        return;
    };
    let Ok(key) = failure_key(endpoint, request) else {
        return;
    };
    let now = Instant::now();
    let window = std::time::Duration::from_secs(block.window_secs);
    let mut failures = security.failures.write().await;
    let record = failures.entry(key).or_insert(FailureRecord {
        first_seen: now,
        failures: 0,
        blocked_until: None,
    });
    if now.duration_since(record.first_seen) > window {
        record.first_seen = now;
        record.failures = 0;
        record.blocked_until = None;
    }
    record.failures += 1;
    if record.failures >= block.max_failures {
        record.blocked_until = Some(now + std::time::Duration::from_secs(block.block_secs));
    }
}

/// Forgets any recorded failures for the caller's key after a successful
/// authorization.
///
/// A no-op when the endpoint has no `failure_block` configuration or when the
/// key template fails to render.
async fn clear_failure(
    security: &SecurityState,
    endpoint: &EndpointConfig,
    request: &RequestContext,
) {
    let key = endpoint
        .security
        .failure_block
        .as_ref()
        .and_then(|_| failure_key(endpoint, request).ok());
    if let Some(key) = key {
        let mut table = security.failures.write().await;
        table.remove(&key);
    }
}

/// Builds the key under which failures for this endpoint and caller are
/// tracked.
///
/// The `failure_block.key` template is rendered against the request and
/// prefixed with `"{endpoint.name}:"`. A blank rendering becomes
/// `"{endpoint.name}:anonymous"`. Without a `failure_block` configuration the
/// key is the empty string.
///
/// # Errors
/// Propagates template rendering errors.
fn failure_key(endpoint: &EndpointConfig, request: &RequestContext) -> Result<String> {
    let block = match &endpoint.security.failure_block {
        Some(block) => block,
        None => return Ok(String::new()),
    };
    let rendered = render_template(&block.key, request)?;
    let caller: &str = if rendered.trim().is_empty() {
        "anonymous"
    } else {
        &rendered
    };
    Ok(format!("{}:{caller}", endpoint.name))
}