use nu_protocol::{Record, Span, Value};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::IpAddr;
/// Determines which client address to trust for a request that may have passed through proxies.
///
/// Resolution rules:
/// * If `trusted_proxies` is empty, `remote_ip` is returned unchanged.
/// * If the connected peer (`remote_ip`) lies outside every trusted network, it is returned
///   as-is. An untrusted peer's `X-Forwarded-For` claims are ignored.
/// * A `None` peer is treated as trusted. The tests use this for Unix-socket connections.
/// * Otherwise the `X-Forwarded-For` hops are walked from right (closest to us) to left. The
///   first address outside every trusted network is returned. If every parsable hop is
///   trusted, the leftmost parsable hop is returned. With no parsable hop, `remote_ip` is
///   returned.
///
/// Repeated `X-Forwarded-For` fields are treated as one comma-separated list in arrival order
/// (RFC 9110 §5.3). If any such field is not visible ASCII, the forwarding information is
/// ignored and `remote_ip` is returned.
pub fn resolve_trusted_ip(
    headers: &http::header::HeaderMap,
    remote_ip: Option<IpAddr>,
    trusted_proxies: &[ipnet::IpNet],
) -> Option<IpAddr> {
    if trusted_proxies.is_empty() {
        return remote_ip;
    }

    let is_trusted = |ip: &IpAddr| trusted_proxies.iter().any(|net| net.contains(ip));

    // Only a trusted peer may vouch for the addresses it forwards; a missing peer address is
    // treated as trusted.
    if remote_ip.is_some_and(|ip| !is_trusted(&ip)) {
        return remote_ip;
    }

    // Every X-Forwarded-For field, in the order received. An undecodable field means we cannot
    // interpret the chain, so fall back to the peer address instead of dropping it.
    let fields = match headers
        .get_all("x-forwarded-for")
        .iter()
        .map(|value| value.to_str())
        .collect::<Result<Vec<&str>, _>>()
    {
        Ok(fields) => fields,
        Err(_) => return remote_ip,
    };

    // NOTE(review): hops that do not parse as a bare IP (e.g. `unknown`, `ip:port`, `[v6]`)
    // are skipped, so the walk continues into hops to their left, which may be
    // client-supplied. Confirm this is acceptable for the proxies in use.
    let mut leftmost_trusted = None;
    for ip in fields
        .iter()
        .copied()
        .flat_map(|field| field.split(','))
        .rev()
        .filter_map(|hop| hop.trim().parse::<IpAddr>().ok())
    {
        if !is_trusted(&ip) {
            return Some(ip);
        }
        leftmost_trusted = Some(ip);
    }
    leftmost_trusted.or(remote_ip)
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Request {
pub proto: String,
#[serde(with = "http_serde::method")]
pub method: http::method::Method,
#[serde(skip_serializing_if = "Option::is_none")]
pub authority: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub remote_ip: Option<std::net::IpAddr>,
#[serde(skip_serializing_if = "Option::is_none")]
pub remote_port: Option<u16>,
#[serde(skip_serializing_if = "Option::is_none")]
pub trusted_ip: Option<std::net::IpAddr>,
#[serde(with = "http_serde::header_map")]
pub headers: http::header::HeaderMap,
#[serde(with = "http_serde::uri")]
pub uri: http::Uri,
pub path: String,
pub query: HashMap<String, String>,
}
pub fn request_to_value(request: &Request, span: Span) -> Value {
let mut record = Record::new();
record.push("proto", Value::string(request.proto.clone(), span));
record.push("method", Value::string(request.method.to_string(), span));
record.push("uri", Value::string(request.uri.to_string(), span));
record.push("path", Value::string(request.path.clone(), span));
if let Some(authority) = &request.authority {
record.push("authority", Value::string(authority.clone(), span));
}
if let Some(remote_ip) = &request.remote_ip {
record.push("remote_ip", Value::string(remote_ip.to_string(), span));
}
if let Some(remote_port) = &request.remote_port {
record.push("remote_port", Value::int(*remote_port as i64, span));
}
if let Some(trusted_ip) = &request.trusted_ip {
record.push("trusted_ip", Value::string(trusted_ip.to_string(), span));
}
let mut headers_record = Record::new();
for (key, value) in request.headers.iter() {
headers_record.push(
key.to_string(),
Value::string(value.to_str().unwrap_or_default().to_string(), span),
);
}
record.push("headers", Value::record(headers_record, span));
let mut query_record = Record::new();
for (key, value) in &request.query {
query_record.push(key.clone(), Value::string(value.clone(), span));
}
record.push("query", Value::record(query_record, span));
Value::record(record, span)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header map holding exactly one `X-Forwarded-For` field.
    fn headers_with_xff(xff: &str) -> http::header::HeaderMap {
        let mut map = http::header::HeaderMap::new();
        map.insert("x-forwarded-for", xff.parse().unwrap());
        map
    }

    /// Parses a CIDR literal; test inputs are expected to be well-formed.
    fn parse_cidr(s: &str) -> ipnet::IpNet {
        s.parse().unwrap()
    }

    /// Parses an IP literal; test inputs are expected to be well-formed.
    fn ip(s: &str) -> IpAddr {
        s.parse().unwrap()
    }

    /// Parses a list of CIDR literals into a trusted-proxy list.
    fn cidrs(list: &[&str]) -> Vec<ipnet::IpNet> {
        list.iter().map(|s| parse_cidr(s)).collect()
    }

    /// A header map with no fields at all.
    fn no_headers() -> http::header::HeaderMap {
        http::header::HeaderMap::new()
    }

    #[test]
    fn test_no_trusted_proxies_returns_remote_ip() {
        let remote = ip("1.2.3.4");
        let got = resolve_trusted_ip(&no_headers(), Some(remote), &[]);
        assert_eq!(got, Some(remote));
    }

    #[test]
    fn test_remote_not_trusted_returns_remote_ip() {
        // The peer is outside 10/8, so its forwarded claim must be ignored.
        let remote = ip("1.2.3.4");
        let got = resolve_trusted_ip(
            &headers_with_xff("5.6.7.8"),
            Some(remote),
            &cidrs(&["10.0.0.0/8"]),
        );
        assert_eq!(got, Some(remote));
    }

    #[test]
    fn test_xff_extracts_client_ip() {
        let got = resolve_trusted_ip(
            &headers_with_xff("5.6.7.8, 10.0.0.1"),
            Some(ip("10.0.0.2")),
            &cidrs(&["10.0.0.0/8"]),
        );
        assert_eq!(got, Some(ip("5.6.7.8")));
    }

    #[test]
    fn test_xff_stops_at_first_untrusted() {
        // 1.1.1.1 is further left than the first untrusted hop and must not be chosen.
        let got = resolve_trusted_ip(
            &headers_with_xff("1.1.1.1, 5.6.7.8, 10.0.0.1"),
            Some(ip("10.0.0.2")),
            &cidrs(&["10.0.0.0/8"]),
        );
        assert_eq!(got, Some(ip("5.6.7.8")));
    }

    #[test]
    fn test_all_xff_trusted_returns_leftmost() {
        let got = resolve_trusted_ip(
            &headers_with_xff("10.0.0.5, 10.0.0.1"),
            Some(ip("10.0.0.2")),
            &cidrs(&["10.0.0.0/8"]),
        );
        assert_eq!(got, Some(ip("10.0.0.5")));
    }

    #[test]
    fn test_no_xff_header_returns_remote() {
        let remote = ip("10.0.0.2");
        let got = resolve_trusted_ip(&no_headers(), Some(remote), &cidrs(&["10.0.0.0/8"]));
        assert_eq!(got, Some(remote));
    }

    #[test]
    fn test_multiple_trusted_cidrs() {
        let got = resolve_trusted_ip(
            &headers_with_xff("5.6.7.8, 192.168.1.1"),
            Some(ip("10.0.0.2")),
            &cidrs(&["10.0.0.0/8", "192.168.0.0/16"]),
        );
        assert_eq!(got, Some(ip("5.6.7.8")));
    }

    #[test]
    fn test_ipv6_support() {
        let got = resolve_trusted_ip(
            &headers_with_xff("2001:db8::1, ::ffff:10.0.0.1"),
            Some(ip("::ffff:10.0.0.2")),
            &cidrs(&["::ffff:10.0.0.0/104"]),
        );
        assert_eq!(got, Some(ip("2001:db8::1")));
    }

    #[test]
    fn test_unix_socket_with_xff() {
        // A missing peer address (Unix socket) is treated as a trusted hop.
        let got = resolve_trusted_ip(
            &headers_with_xff("5.6.7.8, 10.0.0.1"),
            None,
            &cidrs(&["10.0.0.0/8"]),
        );
        assert_eq!(got, Some(ip("5.6.7.8")));
    }

    #[test]
    fn test_unix_socket_no_xff() {
        let got = resolve_trusted_ip(&no_headers(), None, &cidrs(&["10.0.0.0/8"]));
        assert_eq!(got, None);
    }

    #[test]
    fn test_trust_all_uses_leftmost_xff() {
        let got = resolve_trusted_ip(
            &headers_with_xff("38.147.250.103"),
            None,
            &cidrs(&["0.0.0.0/0"]),
        );
        assert_eq!(got, Some(ip("38.147.250.103")));
    }
}