frigg 0.4.5

Local-first MCP server for code understanding.
Documentation
use std::error::Error;
use std::io;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};

use axum::Router;
use axum::extract::{Request, State};
use axum::http::{StatusCode, header};
use axum::middleware::{self, Next};
use axum::response::{IntoResponse, Response};
use frigg::mcp::FriggMcpServer;
use frigg::settings::RuntimeTransportKind;
use rmcp::transport::StreamableHttpServerConfig;
use tracing::{info, warn};

use crate::Cli;

/// Resolved settings for serving MCP over streamable HTTP.
#[derive(Debug, Clone)]
pub(super) struct HttpRuntimeConfig {
    /// Socket address the TCP listener binds to.
    pub bind_addr: SocketAddr,
    /// Bearer token clients must present in `Authorization`; `None` disables bearer auth.
    pub auth_token: Option<String>,
    /// Accepted `Host`/`Origin` authorities (as built by `allowed_authorities_for_bind`:
    /// lowercased, with and without the port); `None` disables the allowlist check.
    pub allowed_authorities: Option<Vec<String>>,
}

/// State handed to `bearer_auth_middleware` for every request.
#[derive(Clone)]
struct HttpAuthState {
    /// Complete expected `Authorization` header value (`"Bearer <token>"`), or `None`
    /// when bearer auth is disabled.
    expected_bearer_header: Option<String>,
    /// Host/Origin allowlist; `None` means any authority is accepted.
    allowed_authorities: Option<Vec<String>>,
}

impl HttpRuntimeConfig {
    /// Classifies the transport by bind address: loopback binds are `LoopbackHttp`,
    /// everything else (including unspecified addresses) is `RemoteHttp`.
    pub(super) fn transport_kind(&self) -> RuntimeTransportKind {
        let bound_to_loopback = self.bind_addr.ip().is_loopback();
        if !bound_to_loopback {
            return RuntimeTransportKind::RemoteHttp;
        }
        RuntimeTransportKind::LoopbackHttp
    }
}

/// Builds the HTTP transport configuration from CLI flags.
///
/// * With `--mcp-http-port`, HTTP is served on that port.
/// * Without a port, HTTP is served on the default port 37444 only when `serve_requested`
///   is true. Otherwise `Ok(None)` is returned, unless HTTP-related flags were given.
///
/// The bind host defaults to `127.0.0.1`. The explicit-port and serve-default paths share
/// the same validation, so a non-loopback bind always needs `--allow-remote-http` and a
/// non-blank `--mcp-http-auth-token`.
///
/// # Errors
///
/// Returns an error when:
/// * HTTP-related flags (host, remote opt-in, auth token) are given without a port and
///   serving was not requested;
/// * `--mcp-http-auth-token` is blank after trimming;
/// * the bind host is not loopback and `--allow-remote-http` was not passed;
/// * the bind host is not loopback and no auth token was provided.
pub(super) fn resolve_http_runtime_config(
    cli: &Cli,
    serve_requested: bool,
) -> Result<Option<HttpRuntimeConfig>, Box<dyn Error>> {
    // Port used when serving is requested without an explicit `--mcp-http-port`.
    const DEFAULT_SERVE_PORT: u16 = 37_444;

    let port = match cli.mcp_http_port {
        Some(port) => port,
        // The serve default must go through the same token / remote-bind checks as an
        // explicit port; returning early here would let a non-loopback host bind with no
        // auth and would silently drop a provided token.
        None if serve_requested => DEFAULT_SERVE_PORT,
        None => {
            let has_http_related_flags = cli.mcp_http_host.is_some()
                || cli.allow_remote_http
                || cli.mcp_http_auth_token.is_some();
            if has_http_related_flags {
                return Err(Box::new(io::Error::other(
                    "HTTP transport flags require --mcp-http-port",
                )));
            }
            return Ok(None);
        }
    };

    let host = cli.mcp_http_host.unwrap_or(IpAddr::V4(Ipv4Addr::LOCALHOST));
    let bind_addr = SocketAddr::new(host, port);

    let auth_token = match cli.mcp_http_auth_token.as_deref() {
        Some(raw) if raw.trim().is_empty() => {
            return Err(Box::new(io::Error::other(
                "--mcp-http-auth-token must not be blank",
            )));
        }
        Some(raw) => Some(raw.trim().to_owned()),
        None => None,
    };

    if !host.is_loopback() && !cli.allow_remote_http {
        return Err(Box::new(io::Error::other(format!(
            "refusing non-loopback HTTP bind at {bind_addr}; pass --allow-remote-http and set --mcp-http-auth-token"
        ))));
    }

    if !host.is_loopback() && auth_token.is_none() {
        return Err(Box::new(io::Error::other(
            "HTTP mode requires --mcp-http-auth-token for non-loopback binds",
        )));
    }

    let allowed_authorities = allowed_authorities_for_bind(bind_addr);

    Ok(Some(HttpRuntimeConfig {
        bind_addr,
        auth_token,
        allowed_authorities,
    }))
}

pub(super) async fn serve_http(
    runtime: HttpRuntimeConfig,
    server: FriggMcpServer,
) -> Result<(), Box<dyn Error>> {
    let listener = tokio::net::TcpListener::bind(runtime.bind_addr).await?;
    let mut config = StreamableHttpServerConfig::default();
    config.stateful_mode = true;
    let shutdown = config.cancellation_token.clone();
    let service = server.streamable_http_service(config);

    info!(
        bind_addr = %runtime.bind_addr,
        "serving MCP over streamable HTTP at /mcp"
    );

    if let Some(authorities) = runtime.allowed_authorities.as_ref() {
        info!(
            ?authorities,
            "HTTP origin/host allowlist enabled for MCP endpoint"
        );
    } else {
        warn!("HTTP origin/host allowlist disabled because bind host is unspecified");
    }

    if runtime.auth_token.is_some() {
        info!("HTTP bearer token auth enabled for MCP endpoint");
    } else {
        warn!("HTTP bearer token auth disabled for loopback MCP endpoint");
    }

    let router = Router::new()
        .nest_service("/mcp", service)
        .layer(middleware::from_fn_with_state(
            HttpAuthState {
                expected_bearer_header: runtime.auth_token.map(|token| format!("Bearer {token}")),
                allowed_authorities: runtime.allowed_authorities,
            },
            bearer_auth_middleware,
        ));

    axum::serve(listener, router)
        .with_graceful_shutdown(async move {
            let _ = tokio::signal::ctrl_c().await;
            shutdown.cancel();
        })
        .await?;

    Ok(())
}

/// Axum middleware guarding the MCP endpoint.
///
/// Checks, in order:
/// 1. the `Host` header must be present and match the allowlist (403 otherwise);
/// 2. an `Origin` header, if present, must match the allowlist (403 otherwise);
/// 3. when a bearer token is configured, `Authorization` must carry it (401 otherwise).
///
/// The auth scheme is compared case-insensitively, as RFC 7235 §2.1 requires, and one or
/// more spaces may separate it from the credentials. The credentials are compared with
/// `constant_time_equals`.
async fn bearer_auth_middleware(
    State(state): State<HttpAuthState>,
    request: Request,
    next: Next,
) -> Response {
    if !host_header_allowed(request.headers(), &state.allowed_authorities) {
        return typed_access_denied_response(StatusCode::FORBIDDEN, "unauthorized host header");
    }

    if !origin_header_allowed(request.headers(), &state.allowed_authorities) {
        return typed_access_denied_response(StatusCode::FORBIDDEN, "unauthorized origin header");
    }

    let Some(expected_bearer_header) = state.expected_bearer_header.as_deref() else {
        return next.run(request).await;
    };

    let provided = request
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .unwrap_or("");

    let authorized = match (
        split_authorization(provided),
        split_authorization(expected_bearer_header),
    ) {
        (Some((scheme, credentials)), Some((expected_scheme, expected_credentials))) => {
            // Evaluate both parts so the rejection path does not depend on which one mismatched.
            let scheme_matches = scheme.eq_ignore_ascii_case(expected_scheme);
            let credentials_match = constant_time_equals(credentials, expected_credentials);
            scheme_matches & credentials_match
        }
        // Missing header or a value without credentials: fail closed.
        _ => false,
    };

    if !authorized {
        return typed_access_denied_response(
            StatusCode::UNAUTHORIZED,
            "missing or invalid bearer authorization",
        );
    }

    next.run(request).await
}

/// Splits an `Authorization` header value into `(scheme, credentials)`.
///
/// Leading spaces before the credentials are skipped. Returns `None` when the value has
/// no credentials part.
fn split_authorization(value: &str) -> Option<(&str, &str)> {
    let (scheme, credentials) = value.trim().split_once(' ')?;
    let credentials = credentials.trim_start_matches(' ');
    (!credentials.is_empty()).then_some((scheme, credentials))
}

pub(super) fn allowed_authorities_for_bind(bind_addr: SocketAddr) -> Option<Vec<String>> {
    if bind_addr.ip().is_unspecified() {
        return None;
    }

    let mut authorities = Vec::new();
    let port = bind_addr.port();

    match bind_addr {
        SocketAddr::V4(addr) => {
            push_authority_variants(&mut authorities, &addr.ip().to_string(), port);
            if addr.ip().is_loopback() {
                push_authority_variants(&mut authorities, "localhost", port);
            }
        }
        SocketAddr::V6(addr) => {
            push_authority_variants(&mut authorities, &format!("[{}]", addr.ip()), port);
            if addr.ip().is_loopback() {
                push_authority_variants(&mut authorities, "localhost", port);
            }
        }
    }

    authorities.sort();
    authorities.dedup();
    Some(authorities)
}

fn push_authority_variants(authorities: &mut Vec<String>, host: &str, port: u16) {
    authorities.push(host.to_ascii_lowercase());
    authorities.push(format!("{host}:{port}").to_ascii_lowercase());
}

/// Returns whether the request's `Host` header names an allowed authority.
///
/// A missing, non-UTF-8/visible-ASCII, or empty `Host` header is always rejected. When it
/// is present, the normalized value is checked with `authority_allowed`.
pub(super) fn host_header_allowed(
    headers: &axum::http::HeaderMap,
    allowed_authorities: &Option<Vec<String>>,
) -> bool {
    headers
        .get(header::HOST)
        .and_then(|raw| raw.to_str().ok())
        .and_then(parse_host_authority)
        .is_some_and(|normalized| authority_allowed(&normalized, allowed_authorities))
}

/// Returns whether the request's `Origin` header (if any) names an allowed authority.
///
/// Requests without `Origin` (typically non-browser clients) pass. An `Origin` that cannot
/// be parsed, including the opaque `null` origin, is rejected.
pub(super) fn origin_header_allowed(
    headers: &axum::http::HeaderMap,
    allowed_authorities: &Option<Vec<String>>,
) -> bool {
    match headers.get(header::ORIGIN) {
        None => true,
        Some(raw) => raw
            .to_str()
            .ok()
            .and_then(parse_origin_authority)
            .is_some_and(|normalized| authority_allowed(&normalized, allowed_authorities)),
    }
}

/// Normalizes a `Host` header value: trims surrounding whitespace, strips trailing dots,
/// and lowercases it. Returns `None` when nothing remains.
pub(super) fn parse_host_authority(raw: &str) -> Option<String> {
    let normalized = raw.trim().trim_end_matches('.');
    (!normalized.is_empty()).then(|| normalized.to_ascii_lowercase())
}

/// Extracts the normalized authority (`host[:port]`, lowercased, trailing dots removed)
/// from an `Origin` header value.
///
/// Returns `None` for empty or `null` origins, for values without a `scheme://` prefix,
/// and when the authority part is empty.
pub(super) fn parse_origin_authority(raw: &str) -> Option<String> {
    let origin = raw.trim();
    // An empty origin also fails the `://` split below; `null` needs an explicit check.
    if origin.eq_ignore_ascii_case("null") {
        return None;
    }

    let (_scheme, after_scheme) = origin.split_once("://")?;
    let authority_part = after_scheme
        .split_once('/')
        .map_or(after_scheme, |(head, _path)| head);
    let normalized = authority_part.trim().trim_end_matches('.');

    if normalized.is_empty() {
        None
    } else {
        Some(normalized.to_ascii_lowercase())
    }
}

/// Returns whether `authority` appears in the allowlist.
///
/// `None` means no allowlist is configured, so every authority is accepted.
pub(super) fn authority_allowed(
    authority: &str,
    allowed_authorities: &Option<Vec<String>>,
) -> bool {
    let Some(allowlist) = allowed_authorities else {
        return true;
    };
    allowlist
        .iter()
        .any(|entry| constant_time_equals(entry, authority))
}

/// Compares two strings without short-circuiting on the first differing byte.
///
/// Every byte position up to the longer input's length is visited, with the shorter side
/// padded with zeros. The length difference is folded into the result, so the strings
/// compare equal only when they have identical lengths and bytes.
pub(super) fn constant_time_equals(left: &str, right: &str) -> bool {
    let (lhs, rhs) = (left.as_bytes(), right.as_bytes());
    let byte_or_zero = |bytes: &[u8], index: usize| bytes.get(index).copied().unwrap_or(0);
    let span = lhs.len().max(rhs.len());
    let mismatch = (0..span).fold(lhs.len() ^ rhs.len(), |acc, index| {
        acc | usize::from(byte_or_zero(lhs, index) ^ byte_or_zero(rhs, index))
    });
    mismatch == 0
}

/// Builds a JSON `access_denied` error response with the given status.
///
/// The body is `{"error_code":"access_denied","retryable":false,"message":"<message>"}`.
/// `message` is escaped as JSON string contents (RFC 8259 §7), including every control
/// character below U+0020, so the body is always valid JSON.
pub(super) fn typed_access_denied_response(status: StatusCode, message: &str) -> Response {
    let escaped_message = escape_json_string_contents(message);
    (
        status,
        [(header::CONTENT_TYPE, "application/json")],
        format!(
            r#"{{"error_code":"access_denied","retryable":false,"message":"{escaped_message}"}}"#
        ),
    )
        .into_response()
}

/// Escapes `raw` for embedding between the quotes of a JSON string literal.
fn escape_json_string_contents(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            // JSON forbids raw control characters inside strings; emit the \uXXXX form.
            control if u32::from(control) < 0x20 => {
                escaped.push_str(&format!("\\u{:04x}", u32::from(control)));
            }
            other => escaped.push(other),
        }
    }
    escaped
}