gephyr 1.16.18

Gephyr is a headless local AI relay/proxy API handling OpenAI, Claude, and Gemini-compatible APIs
Documentation
use axum::extract::Request;
use std::net::{IpAddr, SocketAddr};
use std::sync::{OnceLock, RwLock};

/// Settings consulted when working out which IP address a request came from.
#[derive(Clone, Debug, Default)]
struct ClientIpResolverConfig {
    /// Raw trusted-proxy entries; each one is later interpreted as either an
    /// exact IP address or a `network/prefix` CIDR range.
    trusted_proxies: Vec<String>,
}

fn resolver_config() -> &'static RwLock<ClientIpResolverConfig> {
    static CONFIG: OnceLock<RwLock<ClientIpResolverConfig>> = OnceLock::new();
    CONFIG.get_or_init(|| RwLock::new(ClientIpResolverConfig::default()))
}

/// Returns an owned snapshot of the currently configured trusted-proxy entries.
///
/// If the configuration lock is poisoned the error is discarded and an empty
/// list is returned instead.
fn current_trusted_proxies() -> Vec<String> {
    match resolver_config().read() {
        Ok(guard) => guard.trusted_proxies.clone(),
        Err(_) => Vec::new(),
    }
}

/// Replaces the configured trusted-proxy list.
///
/// Every entry is trimmed of surrounding whitespace and entries that end up
/// empty are dropped before the list is stored. If the configuration lock is
/// poisoned the update is skipped without reporting an error.
pub(crate) fn set_trusted_proxies(trusted_proxies: Vec<String>) {
    let cleaned: Vec<String> = trusted_proxies
        .into_iter()
        .filter_map(|raw| {
            let entry = raw.trim();
            (!entry.is_empty()).then(|| entry.to_string())
        })
        .collect();

    let Ok(mut guard) = resolver_config().write() else {
        return;
    };
    guard.trusted_proxies = cleaned;
}

/// Resolves the client IP for `request` and returns it in textual form.
///
/// Returns `None` when the request carries no `ConnectInfo<SocketAddr>`
/// extension, i.e. when the peer socket address is unknown. Otherwise the peer
/// address is resolved against the currently configured trusted proxies, so
/// forwarding headers are only honoured for trusted peers.
pub(crate) fn extract_client_ip(request: &Request) -> Option<String> {
    let connect_info = request
        .extensions()
        .get::<axum::extract::ConnectInfo<SocketAddr>>()?;
    let peer_ip = connect_info.0.ip();

    // Snapshot the configuration once so the lock is not held while resolving.
    let proxies = current_trusted_proxies();
    Some(resolve_client_ip(peer_ip, request, &proxies).to_string())
}

/// Picks the effective client address for a request.
///
/// Forwarding headers are consulted only when `socket_ip` matches one of
/// `trusted_proxies`; in every other case, and when no usable forwarded
/// address is present, `socket_ip` itself is returned.
fn resolve_client_ip(socket_ip: IpAddr, request: &Request, trusted_proxies: &[String]) -> IpAddr {
    if !is_trusted_proxy(&socket_ip, trusted_proxies) {
        return socket_ip;
    }
    extract_forwarded_ip(request).unwrap_or(socket_ip)
}

/// Reports whether `socket_ip` matches at least one entry of `trusted_proxies`.
///
/// An empty list trusts nobody.
fn is_trusted_proxy(socket_ip: &IpAddr, trusted_proxies: &[String]) -> bool {
    for entry in trusted_proxies {
        if ip_matches_pattern(socket_ip, entry) {
            return true;
        }
    }
    false
}

/// Reports whether `ip` is covered by a single trusted-proxy `pattern`.
///
/// `pattern` is either an exact address (`192.168.10.20`, `::1`) or a CIDR
/// range (`10.0.0.0/8`, `fd00::/8`); surrounding whitespace is ignored. Empty
/// or unparsable patterns never match, and neither do prefixes longer than the
/// address family allows.
///
/// The address is first matched as-is. If that fails and `ip` is an
/// IPv4-mapped IPv6 address (`::ffff:a.b.c.d`), the embedded IPv4 address is
/// matched as well, so that an IPv4 peer reported in mapped form (as can happen
/// on a dual-stack listener) is still recognised by IPv4 entries such as
/// `10.0.0.0/8`.
fn ip_matches_pattern(ip: &IpAddr, pattern: &str) -> bool {
    let pattern = pattern.trim();
    if pattern.is_empty() {
        return false;
    }

    if ip_matches_pattern_in_family(ip, pattern) {
        return true;
    }

    // Fallback only: anything that matched without conversion has already
    // returned above, so explicit `::ffff:…` patterns keep working unchanged.
    match ip {
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(v4) => ip_matches_pattern_in_family(&IpAddr::V4(v4), pattern),
            None => false,
        },
        IpAddr::V4(_) => false,
    }
}

/// Matches `ip` against an already trimmed, non-empty `pattern` without any
/// conversion between address families.
fn ip_matches_pattern_in_family(ip: &IpAddr, pattern: &str) -> bool {
    // A pattern that parses as a bare address must match exactly.
    if let Ok(exact_ip) = pattern.parse::<IpAddr>() {
        return &exact_ip == ip;
    }

    // Otherwise it has to be `network/prefix_len`.
    let Some((network, prefix_str)) = pattern.split_once('/') else {
        return false;
    };

    let Ok(network_ip) = network.trim().parse::<IpAddr>() else {
        return false;
    };
    let Ok(prefix_len) = prefix_str.trim().parse::<u8>() else {
        return false;
    };

    // Address and network must belong to the same family, and the prefix must
    // fit that family's bit width.
    match (*ip, network_ip) {
        (IpAddr::V4(ipv4), IpAddr::V4(netv4)) if prefix_len <= 32 => {
            cidr_bytes_match(&ipv4.octets(), &netv4.octets(), prefix_len, 32)
        }
        (IpAddr::V6(ipv6), IpAddr::V6(netv6)) if prefix_len <= 128 => {
            cidr_bytes_match(&ipv6.octets(), &netv6.octets(), prefix_len, 128)
        }
        _ => false,
    }
}

/// Compares the leading `prefix_len` bits of two big-endian byte strings.
///
/// * `ip` / `network` – address bytes, most significant byte first.
/// * `prefix_len` – number of leading bits that must be identical.
/// * `total_bits` – width of the address family; a longer prefix never matches.
///
/// A zero-length prefix matches everything.
fn cidr_bytes_match(ip: &[u8], network: &[u8], prefix_len: u8, total_bits: u8) -> bool {
    if prefix_len > total_bits {
        return false;
    }
    if prefix_len == 0 {
        return true;
    }

    let whole_bytes = usize::from(prefix_len / 8);
    let tail_bits = prefix_len % 8;

    // Bytes fully covered by the prefix must be identical.
    let leading_equal = ip.get(..whole_bytes) == network.get(..whole_bytes);
    if !leading_equal {
        return false;
    }
    if tail_bits == 0 {
        return true;
    }

    // Keep only the top `tail_bits` bits of the partially covered byte.
    let mask = !(0xFFu8 >> tail_bits);
    match (ip.get(whole_bytes), network.get(whole_bytes)) {
        (Some(ip_byte), Some(net_byte)) => ((ip_byte ^ net_byte) & mask) == 0,
        _ => false,
    }
}

/// Reads a forwarded client address from the request headers.
///
/// The first comma-separated `x-forwarded-for` entry that parses as an address
/// wins. When that header is missing, is not valid visible ASCII, or contains
/// no parseable entry, `x-real-ip` is tried instead. Returns `None` if neither
/// header yields an address.
fn extract_forwarded_ip(request: &Request) -> Option<IpAddr> {
    let headers = request.headers();

    let from_forwarded_for = headers
        .get("x-forwarded-for")
        .and_then(|raw| raw.to_str().ok())
        .and_then(|list| list.split(',').find_map(parse_ip_candidate));
    if from_forwarded_for.is_some() {
        return from_forwarded_for;
    }

    let real_ip = headers.get("x-real-ip")?.to_str().ok()?;
    parse_ip_candidate(real_ip)
}

/// Parses one header token into an IP address.
///
/// Surrounding whitespace is ignored. The token may be a bare address
/// (`203.0.113.10`, `2001:db8::1`) or an address with a port
/// (`203.0.113.10:8080`, `[2001:db8::1]:443`), in which case the port is
/// discarded. Anything else, including an empty token, yields `None`.
fn parse_ip_candidate(candidate: &str) -> Option<IpAddr> {
    let token = candidate.trim();

    // An empty token fails both parsers, so it needs no separate check.
    token
        .parse::<IpAddr>()
        .ok()
        .or_else(|| token.parse::<SocketAddr>().ok().map(|addr| addr.ip()))
}

#[cfg(test)]
mod tests {
    use super::{current_trusted_proxies, extract_client_ip, set_trusted_proxies};
    use axum::{body::Body, extract::ConnectInfo, http::Request};
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::sync::{Mutex, OnceLock, PoisonError};

    /// Puts the saved trusted-proxy list back when dropped, so the global
    /// configuration is restored even if the test body panics.
    struct RestoreTrustedProxies(Vec<String>);

    impl Drop for RestoreTrustedProxies {
        fn drop(&mut self) {
            set_trusted_proxies(std::mem::take(&mut self.0));
        }
    }

    /// Runs `test_fn` with the global trusted-proxy list set to
    /// `trusted_proxies`, restoring the previous list afterwards.
    ///
    /// The configuration is process-wide, so the tests are serialised through
    /// a shared mutex. A poisoned mutex (left behind by an earlier failing
    /// test) is recovered rather than treated as fatal, so one failure does
    /// not cascade into every other test.
    fn with_trusted_proxies<F>(trusted_proxies: Vec<String>, test_fn: F)
    where
        F: FnOnce(),
    {
        static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        let _guard = TEST_LOCK
            .get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(PoisonError::into_inner);

        // Declared after `_guard`, hence dropped before it: the configuration
        // is restored while the test lock is still held.
        let _restore = RestoreTrustedProxies(current_trusted_proxies());
        set_trusted_proxies(trusted_proxies);
        test_fn();
    }

    /// With no trusted proxies configured the socket address is returned.
    #[test]
    fn returns_socket_ip_when_connect_info_present() {
        with_trusted_proxies(vec![], || {
            let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 1, 2, 3)), 8080);
            let mut req = Request::builder()
                .uri("/v1/messages")
                .body(Body::empty())
                .expect("request build");
            req.extensions_mut().insert(ConnectInfo(socket));

            assert_eq!(extract_client_ip(&req), Some("10.1.2.3".to_string()));
        });
    }

    /// Forwarding headers from a peer outside the trusted range are ignored.
    #[test]
    fn untrusted_proxy_ignores_forwarded_headers() {
        with_trusted_proxies(vec!["10.0.0.0/8".to_string()], || {
            let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 10, 20)), 9000);
            let mut req = Request::builder()
                .uri("/v1/messages")
                .header("x-forwarded-for", "203.0.113.10, 198.51.100.4")
                .header("x-real-ip", "203.0.113.11")
                .body(Body::empty())
                .expect("request build");
            req.extensions_mut().insert(ConnectInfo(socket));

            assert_eq!(extract_client_ip(&req), Some("192.168.10.20".to_string()));
        });
    }

    /// A trusted peer's first parseable `x-forwarded-for` entry takes
    /// precedence over `x-real-ip`.
    #[test]
    fn trusted_proxy_uses_forwarded_for_first_valid_ip() {
        with_trusted_proxies(vec!["192.168.10.20".to_string()], || {
            let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 10, 20)), 9000);
            let mut req = Request::builder()
                .uri("/v1/messages")
                .header("x-forwarded-for", "203.0.113.10, 198.51.100.4")
                .header("x-real-ip", "203.0.113.11")
                .body(Body::empty())
                .expect("request build");
            req.extensions_mut().insert(ConnectInfo(socket));

            assert_eq!(extract_client_ip(&req), Some("203.0.113.10".to_string()));
        });
    }

    /// A CIDR entry trusts the peer, and `x-real-ip` is the fallback when
    /// `x-forwarded-for` holds nothing parseable.
    #[test]
    fn trusted_proxy_cidr_uses_real_ip_header_when_forwarded_for_invalid() {
        with_trusted_proxies(vec!["192.168.10.0/24".to_string()], || {
            let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 10, 20)), 9000);
            let mut req = Request::builder()
                .uri("/v1/messages")
                .header("x-forwarded-for", "not-an-ip")
                .header("x-real-ip", "203.0.113.11")
                .body(Body::empty())
                .expect("request build");
            req.extensions_mut().insert(ConnectInfo(socket));

            assert_eq!(extract_client_ip(&req), Some("203.0.113.11".to_string()));
        });
    }

    /// Without connection info no address is reported, whatever the headers say.
    #[test]
    fn returns_none_without_connect_info() {
        with_trusted_proxies(vec!["192.168.10.20".to_string()], || {
            let req = Request::builder()
                .uri("/v1/messages")
                .header("x-forwarded-for", "203.0.113.10")
                .body(Body::empty())
                .expect("request build");

            assert_eq!(extract_client_ip(&req), None);
        });
    }
}