collet 0.1.0

Relentless agentic coding orchestrator with zero-drop agent loops
Documentation
//! Token-based auth middleware for the web server.
//!
//! Flow:
//! 1. Client POSTs `{ "username": "...", "password": "..." }` to `/api/auth/login`
//! 2. Server validates and returns `{ "token": "<random>" }`
//! 3. Client sends `Authorization: Bearer <token>` on subsequent requests
//! 4. SSE endpoint accepts `?token=<token>` query param (EventSource limitation)
//!
//! When no password is configured, all endpoints are open.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use axum::extract::{Query, State};
use axum::http::{Request, StatusCode, header};
use axum::middleware::Next;
use axum::response::{IntoResponse, Json, Response};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;

/// Token time-to-live: 24 hours. Tokens at least this old are purged by
/// `AuthState::issue_token` and rejected by `AuthState::validate_token`.
const TOKEN_TTL: Duration = Duration::from_secs(24 * 60 * 60);
/// Maximum login attempts within the rate-limit window. Once this many
/// attempts are recorded inside `LOGIN_WINDOW`, further attempts are refused.
const MAX_LOGIN_ATTEMPTS: u32 = 5;
/// Rate-limit window: 1 minute. Recorded attempts older than this are
/// forgotten the next time the rate limit is checked.
const LOGIN_WINDOW: Duration = Duration::from_secs(60);

/// Shared auth state: the configured credentials, the issued session tokens,
/// and the bookkeeping used to rate-limit login attempts.
pub struct AuthState {
    /// Username a login request must present.
    pub username: String,
    /// Password a login request must present.
    pub password: String,
    /// Maps token → issued-at timestamp for TTL enforcement.
    pub tokens: RwLock<HashMap<String, Instant>>,
    /// Global login attempt timestamps for rate limiting. This is a single
    /// shared list, so the limit applies across all callers combined.
    login_attempts: RwLock<Vec<Instant>>,
}

impl AuthState {
    pub fn new(username: String, password: String) -> Self {
        Self {
            username,
            password,
            tokens: RwLock::new(HashMap::new()),
            login_attempts: RwLock::new(Vec::new()),
        }
    }

    /// Generate a cryptographically random token, store it with a timestamp,
    /// and lazily purge expired tokens.
    pub async fn issue_token(&self) -> String {
        let bytes: [u8; 32] = rand::random();
        let token =
            base64::Engine::encode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, bytes);
        let now = Instant::now();
        let mut tokens = self.tokens.write().await;
        // Lazy purge: remove expired tokens on each issue.
        tokens.retain(|_, issued_at| now.duration_since(*issued_at) < TOKEN_TTL);
        tokens.insert(token.clone(), now);
        token
    }

    /// Validate a token: must exist and not be expired.
    pub async fn validate_token(&self, token: &str) -> bool {
        let tokens = self.tokens.read().await;
        match tokens.get(token) {
            Some(issued_at) => issued_at.elapsed() < TOKEN_TTL,
            None => false,
        }
    }

    /// Check and record a login attempt. Returns `true` if within rate limit.
    async fn check_rate_limit(&self) -> bool {
        let now = Instant::now();
        let mut attempts = self.login_attempts.write().await;
        // Remove attempts outside the current window.
        attempts.retain(|t| now.duration_since(*t) < LOGIN_WINDOW);
        if attempts.len() >= MAX_LOGIN_ATTEMPTS as usize {
            return false;
        }
        attempts.push(now);
        true
    }
}

// ── Login endpoint ──────────────────────────────────────────────────────

/// JSON body accepted by the `login` handler. Both fields are required.
#[derive(Deserialize)]
pub struct LoginRequest {
    /// Username to check against the configured one.
    username: String,
    /// Password to check against the configured one.
    password: String,
}

/// JSON body returned by `login`, and by `require_auth` when it rejects a
/// request. Every response built in this file sets exactly one of the two
/// fields.
#[derive(Serialize)]
pub struct LoginResponse {
    /// Newly issued session token; set only on a successful login.
    token: Option<String>,
    /// Human-readable failure reason; set only when the request is refused.
    error: Option<String>,
}

/// Login handler: exchange a username/password pair for a session token.
///
/// Responses:
/// - `429 Too Many Requests` when the global rate limit is exhausted; the
///   credentials are not inspected in that case.
/// - `200 OK` with `token` set when both username and password match.
/// - `401 Unauthorized` with `error` set when either one does not match.
pub async fn login(
    State(auth): State<Arc<AuthState>>,
    Json(req): Json<LoginRequest>,
) -> (StatusCode, Json<LoginResponse>) {
    // Rate-limit check before credential validation.
    if !auth.check_rate_limit().await {
        return (
            StatusCode::TOO_MANY_REQUESTS,
            Json(LoginResponse {
                token: None,
                error: Some("Too many login attempts. Try again later.".into()),
            }),
        );
    }

    // Run BOTH comparisons unconditionally. A short-circuiting `&&` would
    // skip the password comparison whenever the username is wrong, making
    // the response time depend on whether the username was valid.
    let username_ok = constant_time_eq(req.username.as_bytes(), auth.username.as_bytes());
    let password_ok = constant_time_eq(req.password.as_bytes(), auth.password.as_bytes());

    // Non-lazy `&` on purpose: both operands are already evaluated above.
    if username_ok & password_ok {
        let token = auth.issue_token().await;
        (
            StatusCode::OK,
            Json(LoginResponse {
                token: Some(token),
                error: None,
            }),
        )
    } else {
        (
            StatusCode::UNAUTHORIZED,
            Json(LoginResponse {
                token: None,
                error: Some("Invalid password".into()),
            }),
        )
    }
}

// ── Constant-time comparison ─────────────────────────────────────────────────

/// Equality check for byte slices whose running time does not depend on
/// where the first differing byte is, to prevent timing attacks.
///
/// A length mismatch returns `false` straight away (the length is not treated
/// as secret). For equal-length inputs every byte pair is XORed and OR-ed
/// into one accumulator, so the loop always runs to the end.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    let mut diff = 0u8;
    for (lhs, rhs) in a.iter().zip(b) {
        // Any differing bit sticks in `diff`; there is no early exit.
        diff |= lhs ^ rhs;
    }
    diff == 0
}

// ── Auth check endpoint (no password needed when auth is disabled) ──────

/// Serializable payload carrying a single `auth_required` flag.
// NOTE(review): presumably the body of the `/api/auth/status` response; the
// handler that builds it is not in this part of the file — confirm.
#[derive(Clone, Serialize)]
pub struct AuthStatusResponse {
    /// Whether authentication is required.
    pub auth_required: bool,
}

// ── Middleware ───────────────────────────────────────────────────────────

/// Query-string parameters read by the `require_auth` middleware.
#[derive(Deserialize)]
pub struct TokenQuery {
    /// Optional `?token=` value, used when no usable `Authorization: Bearer`
    /// header is present.
    pub token: Option<String>,
}

/// Middleware: let a request through only if it carries a valid session token.
///
/// `/api/health`, `/api/auth/status` and `/api/auth/login` are public and
/// bypass the check. For every other path the token is taken from the
/// `Authorization: Bearer <token>` header, or from the `?token=` query
/// parameter when that header is absent or not a Bearer header. A missing or
/// invalid token yields `401 Unauthorized` with a JSON error body.
pub async fn require_auth(
    State(auth): State<Arc<AuthState>>,
    Query(query): Query<TokenQuery>,
    request: Request<axum::body::Body>,
    next: Next,
) -> Response {
    // Public endpoints are forwarded without looking at credentials.
    let is_public = matches!(
        request.uri().path(),
        "/api/health" | "/api/auth/status" | "/api/auth/login"
    );
    if is_public {
        return next.run(request).await;
    }

    // The header value is copied out so `request` can be moved into `next`.
    let bearer = request
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("Bearer "))
        .map(str::to_owned);

    // Query parameter is the fallback (EventSource cannot set headers).
    let candidate = bearer.or(query.token);

    let authorized = match candidate.as_deref() {
        Some(presented) => auth.validate_token(presented).await,
        None => false,
    };
    if authorized {
        return next.run(request).await;
    }

    let rejection = LoginResponse {
        token: None,
        error: Some("Invalid or missing token".into()),
    };
    (StatusCode::UNAUTHORIZED, Json(rejection)).into_response()
}