harn-cli 0.7.24

CLI for the Harn programming language — run, test, REPL, format, and lint
use std::collections::BTreeSet;
use std::sync::Arc;

use axum::body::Body;
use axum::extract::State;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};

/// Allow-list of request origins.
///
/// Either every origin is accepted (`wildcard`), or only origins that appear
/// verbatim in `exact` are accepted.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct OriginAllowList {
    /// When `true`, `allows` accepts any origin and `exact` is left empty.
    wildcard: bool,
    /// Origins accepted by exact (case-sensitive) string comparison.
    exact: BTreeSet<String>,
}

impl OriginAllowList {
    /// Builds an allow-list from a slice of configured origin entries.
    ///
    /// Each entry is trimmed of surrounding whitespace; blank entries are
    /// ignored. The result is the wildcard list when any entry is `*`, or
    /// when no non-blank entry remains (including an empty slice). Otherwise
    /// only the trimmed entries are allowed.
    pub(crate) fn from_manifest(origins: &[String]) -> Self {
        let mut exact = BTreeSet::new();

        let entries = origins
            .iter()
            .map(|raw| raw.trim())
            .filter(|entry| !entry.is_empty());

        for entry in entries {
            // A single `*` overrides every explicit entry.
            if entry == "*" {
                return Self::wildcard();
            }
            exact.insert(entry.to_owned());
        }

        // Nothing usable was configured: fall back to allowing everything.
        if exact.is_empty() {
            return Self::wildcard();
        }

        Self {
            wildcard: false,
            exact,
        }
    }

    /// Returns an allow-list that accepts every origin.
    pub(crate) fn wildcard() -> Self {
        Self {
            exact: BTreeSet::new(),
            wildcard: true,
        }
    }

    /// Reports whether `origin` is accepted: always for a wildcard list,
    /// otherwise only when `origin` exactly equals a stored entry.
    pub(crate) fn allows(&self, origin: &str) -> bool {
        if self.wildcard {
            return true;
        }
        self.exact.contains(origin)
    }
}

/// Middleware that rejects requests whose `Origin` header is not allowed.
///
/// * No `Origin` header: the request is forwarded to `next` unchanged.
/// * `Origin` present and accepted by `allow_list`: forwarded to `next`.
/// * `Origin` present and not accepted: responds `403 Forbidden` with a
///   message naming the offending origin; `next` is not run.
///
/// The header is decoded lossily from its raw bytes rather than with
/// `HeaderValue::to_str`, so a value containing bytes outside visible ASCII
/// is still checked against the allow-list instead of being treated as if no
/// `Origin` header had been sent.
pub(crate) async fn enforce_allowed_origin(
    State(allow_list): State<Arc<OriginAllowList>>,
    request: Request<Body>,
    next: Next,
) -> Response {
    // Build the rejection message (if any) as an owned value so the borrow of
    // `request` ends before the request is handed to `next`.
    let rejection = request
        .headers()
        .get(axum::http::header::ORIGIN)
        .and_then(|value| {
            let origin = String::from_utf8_lossy(value.as_bytes());
            (!allow_list.allows(&origin)).then(|| format!("origin '{origin}' is not allowed"))
        });

    match rejection {
        Some(message) => (StatusCode::FORBIDDEN, message).into_response(),
        None => next.run(request).await,
    }
}

#[cfg(test)]
mod tests {
    use super::OriginAllowList;

    /// An empty origin slice yields a list that accepts an arbitrary origin.
    #[test]
    fn empty_manifest_defaults_to_wildcard() {
        let no_origins: [String; 0] = [];
        let list = OriginAllowList::from_manifest(&no_origins);

        assert!(list.allows("https://example.com"));
    }

    /// With explicit entries, a listed origin passes and an unlisted one does not.
    #[test]
    fn explicit_allow_list_restricts_origins() {
        let configured = ["https://allowed.example", "https://other.example"].map(String::from);
        let list = OriginAllowList::from_manifest(&configured);

        assert!(list.allows("https://allowed.example"));
        assert!(!list.allows("https://blocked.example"));
    }
}