pub(crate) mod app_packaging;
pub(crate) mod client_api;
pub(crate) mod errors;
mod home_page;
pub(crate) mod path_handlers;
use std::collections::HashSet;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
use dashmap::DashMap;
use freenet_stdlib::{
client_api::{ClientError, ClientRequest, HostResponse},
prelude::*,
};
use axum::{Extension, response::IntoResponse};
use client_api::HttpClientApi;
use tower_http::trace::TraceLayer;
use crate::{
client_events::{AuthToken, BoxedClient, ClientId, HostResult, websocket::WebSocketProxy},
config::{GlobalExecutor, WebsocketApiConfig},
};
pub use app_packaging::WebApp;
pub use client_api::{OriginContract, OriginContractMap};
/// Version of the client API a request was issued against.
///
/// `V1` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub(crate) enum ApiVersion {
    #[default]
    V1,
    V2,
}

impl ApiVersion {
    /// Short textual identifier of this version: `"v1"` or `"v2"`.
    pub fn prefix(self) -> &'static str {
        use ApiVersion::{V1, V2};
        match self {
            V1 => "v1",
            V2 => "v2",
        }
    }
}
/// A client-connection event: either a newly opened client connection or a request
/// issued by an already-known client.
#[derive(Debug)]
// NOTE(review): presumably allowed because `Request` is much larger than
// `NewConnection` even with the request boxed — confirm before removing.
#[allow(clippy::large_enum_variant)]
pub(crate) enum ClientConnection {
    /// A new client connection.
    NewConnection {
        /// Sender used to deliver `HostCallbackResult` values back to this client.
        callbacks: tokio::sync::mpsc::UnboundedSender<HostCallbackResult>,
        /// Auth token paired with a contract instance id, if one was assigned.
        assigned_token: Option<(AuthToken, ContractInstanceId)>,
    },
    /// A request issued by an existing client.
    Request {
        /// The client that issued the request.
        client_id: ClientId,
        /// The request itself (boxed).
        req: Box<ClientRequest<'static>>,
        /// Auth token presented with the request, if any.
        auth_token: Option<AuthToken>,
        /// Contract instance associated with the request's origin, if known.
        origin_contract: Option<ContractInstanceId>,
        /// API version the request was made against.
        // The `dead_code` allowance suggests this field is not read anywhere yet.
        #[allow(dead_code)]
        api_version: ApiVersion,
    },
}
/// Values sent back to a client through the `callbacks` channel of
/// `ClientConnection::NewConnection`.
#[derive(Debug)]
pub(crate) enum HostCallbackResult {
    /// Identifier assigned to a client.
    NewId {
        id: ClientId,
    },
    /// Outcome of a request issued by client `id`.
    Result {
        id: ClientId,
        result: Result<HostResponse, ClientError>,
    },
    /// A receiver for client `id` tied to contract instance `key` — presumably
    /// carrying subscription updates for that contract (confirm with the producer).
    SubscriptionChannel {
        id: ClientId,
        key: ContractInstanceId,
        callback: tokio::sync::mpsc::Receiver<HostResult>,
    },
}
/// Binds a fresh listener on `socket` and serves `router` on it in the background.
async fn serve(socket: SocketAddr, router: axum::Router) -> std::io::Result<()> {
    let pre_bound: Option<std::net::TcpListener> = None;
    serve_with_listener(socket, router, pre_bound).await
}
/// Starts serving `router` on a background task.
///
/// When `pre_bound` is given, that listener is used as-is. Otherwise a new listener is
/// bound to `socket` via [`bind_client_api_listener`]. Either way the listener is switched
/// to non-blocking mode and handed to tokio.
///
/// The function returns as soon as the server task has been spawned on `GlobalExecutor`.
/// Errors the server hits afterwards are only logged.
///
/// # Errors
/// Returns I/O errors raised while binding or preparing the listener.
async fn serve_with_listener(
    socket: SocketAddr,
    router: axum::Router,
    pre_bound: Option<std::net::TcpListener>,
) -> std::io::Result<()> {
    let std_listener = match pre_bound {
        Some(std_listener) => std_listener,
        None => bind_client_api_listener(socket)?,
    };
    // tokio's `from_std` requires the caller to put the listener in non-blocking mode.
    std_listener.set_nonblocking(true)?;
    let listener = tokio::net::TcpListener::from_std(std_listener)?;

    // Log the address actually bound. A pre-bound listener, or a request for port 0, can
    // differ from `socket`.
    let local_addr = listener.local_addr().unwrap_or(socket);
    tracing::info!("HTTP client API listening on {}", local_addr);

    GlobalExecutor::spawn(async move {
        axum::serve(
            listener,
            router.into_make_service_with_connect_info::<SocketAddr>(),
        )
        .await
        .map_err(|e| {
            tracing::error!("Error while running HTTP client API server: {e}");
        })
    });
    Ok(())
}

/// Binds a new TCP listener on `socket` with `SO_REUSEADDR` set and a backlog of 128.
///
/// IPv6 sockets have `IPV6_V6ONLY` disabled, so IPv4 peers can also connect through the
/// same socket.
///
/// # Errors
/// Returns socket creation, option, bind or listen failures. `AddrInUse` is rewritten into
/// a message that names the port and suggests how to resolve the conflict.
fn bind_client_api_listener(socket: SocketAddr) -> std::io::Result<std::net::TcpListener> {
    let is_ipv6 = socket.is_ipv6();
    let domain = if is_ipv6 {
        socket2::Domain::IPV6
    } else {
        socket2::Domain::IPV4
    };
    let sock = socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))
        .map_err(|e| std::io::Error::new(e.kind(), format!("Failed to create socket: {e}")))?;
    if is_ipv6 {
        sock.set_only_v6(false)?;
    }
    sock.set_reuse_address(true)?;
    sock.bind(&socket.into()).map_err(|e| {
        if e.kind() == std::io::ErrorKind::AddrInUse {
            std::io::Error::new(
                std::io::ErrorKind::AddrInUse,
                format!(
                    "Port {} is already in use. Another freenet process may be running. \
                     Use 'pkill freenet' to stop it, or specify a different port with --ws-api-port.",
                    socket.port()
                ),
            )
        } else {
            e
        }
    })?;
    sock.listen(128)?;
    Ok(std::net::TcpListener::from(sock))
}
/// Reports whether `ip` is a loopback, private, link-local or unspecified address.
///
/// * IPv4: `127.0.0.0/8`, the RFC 1918 ranges, `169.254.0.0/16` and `0.0.0.0`.
/// * IPv6: `::1`, `::`, link-local `fe80::/10` and unique-local `fc00::/7`.
/// * IPv4-mapped IPv6 (`::ffff:a.b.c.d`) is classified by its embedded IPv4 address.
///
/// Shared address space such as CGNAT `100.64.0.0/10` is *not* treated as private.
pub fn is_private_ip(ip: &IpAddr) -> bool {
    let v6 = match ip {
        IpAddr::V4(v4) => {
            return v4.is_private()
                || v4.is_loopback()
                || v4.is_link_local()
                || v4.is_unspecified();
        }
        IpAddr::V6(v6) => v6,
    };
    if let Some(embedded) = v6.to_ipv4_mapped() {
        return is_private_ip(&IpAddr::V4(embedded));
    }
    is_ipv6_ula(v6) || is_ipv6_link_local(v6) || v6.is_loopback() || v6.is_unspecified()
}

/// Link-local unicast `fe80::/10`: the first ten bits are `1111_1110_10`.
fn is_ipv6_link_local(addr: &std::net::Ipv6Addr) -> bool {
    let [first, second, ..] = addr.octets();
    first == 0xfe && (second & 0xc0) == 0x80
}

/// Unique local address `fc00::/7`: the first seven bits are `1111_110`.
fn is_ipv6_ula(addr: &std::net::Ipv6Addr) -> bool {
    (addr.octets()[0] & 0xfe) == 0xfc
}
/// Standalone mode that serves the HTTP and WebSocket client APIs directly on top of a
/// local [`Executor`].
pub mod local_node {
    use freenet_stdlib::client_api::{ClientRequest, ErrorKind};
    use std::net::SocketAddr;
    use tower_http::trace::TraceLayer;

    use crate::{
        client_events::{ClientEventsProxy, OpenRequest, websocket::WebSocketProxy},
        contract::{Executor, ExecutorError},
    };

    use super::{client_api::HttpClientApi, serve};

    /// Serves the HTTP and WebSocket client APIs on `socket` and dispatches incoming
    /// requests to `executor`.
    ///
    /// Contract and delegate operations are forwarded to the executor. `Disconnect` drops
    /// the disconnecting client's auth tokens without replying. Every other request kind is
    /// answered with a "not supported" error. Each reply goes back through the proxy
    /// (WebSocket or HTTP gateway) the request arrived on.
    ///
    /// # Errors
    /// Fails immediately if `socket` is not a private/loopback address (see
    /// `is_private_ip`) or the server cannot be started. Later it fails when receiving
    /// from, or replying to, a client fails. The loop has no other exit.
    pub async fn run_local_node(mut executor: Executor, socket: SocketAddr) -> anyhow::Result<()> {
        if !super::is_private_ip(&socket.ip()) {
            anyhow::bail!(
                "invalid ip: {}, only loopback and private network addresses are allowed",
                socket.ip()
            )
        }
        let (mut gw, gw_router) = HttpClientApi::as_router(&socket);
        let (mut ws_proxy, ws_router) = WebSocketProxy::create_router(gw_router);
        serve(socket, ws_router.layer(TraceLayer::new_for_http())).await?;

        // Which proxy the current request came from, so the reply is sent back the same way.
        enum Receiver {
            Ws,
            Gw,
        }
        let mut receiver;
        loop {
            let req = crate::deterministic_select! {
                req = ws_proxy.recv() => {
                    receiver = Receiver::Ws;
                    req?
                },
                req = gw.recv() => {
                    receiver = Receiver::Gw;
                    req?
                },
            };
            let OpenRequest {
                client_id: id,
                request,
                notification_channel,
                token,
                ..
            } = req;
            tracing::trace!(cli_id = %id, "got request -> {request}");
            let res = match *request {
                ClientRequest::ContractOp(op) => {
                    executor
                        .contract_requests(op, id, notification_channel)
                        .await
                }
                ClientRequest::DelegateOp(op) => {
                    // Resolve the contract the caller's auth token maps to (if any) and pass
                    // it along with the delegate request.
                    let origin_contract = token.and_then(|token| {
                        gw.origin_contracts
                            .get(&token)
                            .map(|entry| entry.contract_id)
                    });
                    executor.delegate_request(op, origin_contract.as_ref(), None)
                }
                ClientRequest::Disconnect { cause } => {
                    if let Some(cause) = cause {
                        tracing::info!("disconnecting cause: {cause}");
                    }
                    // Drop every token issued to this client, not only the first one found.
                    gw.origin_contracts
                        .retain(|_, origin| origin.client_id != id);
                    continue;
                }
                // `Authenticate`, `NodeQueries`, `Close` and any other request kind are not
                // handled by the standalone local node.
                _ => Err(ExecutorError::other(anyhow::anyhow!("not supported"))),
            };
            match res {
                Ok(res) => {
                    match receiver {
                        Receiver::Ws => ws_proxy.send(id, Ok(res)).await?,
                        Receiver::Gw => gw.send(id, Ok(res)).await?,
                    };
                }
                Err(err) if err.is_request() => {
                    let err = ErrorKind::RequestError(err.unwrap_request());
                    match receiver {
                        Receiver::Ws => {
                            ws_proxy.send(id, Err(err.into())).await?;
                        }
                        Receiver::Gw => {
                            gw.send(id, Err(err.into())).await?;
                        }
                    };
                }
                Err(err) => {
                    tracing::error!("{err}");
                    let err = Err(ErrorKind::Unhandled {
                        cause: format!("{err}").into(),
                    }
                    .into());
                    match receiver {
                        Receiver::Ws => {
                            ws_proxy.send(id, err).await?;
                        }
                        Receiver::Gw => {
                            gw.send(id, err).await?;
                        }
                    };
                }
            }
        }
    }
}
/// Starts the client API server on the configured address.
///
/// Returns the HTTP gateway and the WebSocket proxy, in that order, as boxed clients.
pub async fn serve_client_api(config: WebsocketApiConfig) -> std::io::Result<[BoxedClient; 2]> {
    let (http_gateway, websocket) = serve_client_api_in_impl(config, None).await?;
    let clients: [BoxedClient; 2] = [Box::new(http_gateway), Box::new(websocket)];
    Ok(clients)
}
/// Like `serve_client_api`, but serves on an already-bound `listener` instead of binding
/// the configured address.
pub async fn serve_client_api_with_listener(
    config: WebsocketApiConfig,
    listener: std::net::TcpListener,
) -> std::io::Result<[BoxedClient; 2]> {
    let pre_bound = Some(listener);
    let (http_gateway, websocket) = serve_client_api_in_impl(config, pre_bound).await?;
    let clients: [BoxedClient; 2] = [Box::new(http_gateway), Box::new(websocket)];
    Ok(clients)
}
/// Like `serve_client_api_with_listener`, but also returns a handle to the origin-contract
/// (auth token) map shared by the gateway and the WebSocket proxy.
pub async fn serve_client_api_with_listener_and_contracts(
    config: WebsocketApiConfig,
    listener: std::net::TcpListener,
) -> std::io::Result<([BoxedClient; 2], OriginContractMap)> {
    let (http_gateway, websocket) = serve_client_api_in_impl(config, Some(listener)).await?;
    // Take the shared map handle before the gateway is boxed behind a trait object.
    let contracts = http_gateway.origin_contracts.clone();
    let clients: [BoxedClient; 2] = [Box::new(http_gateway), Box::new(websocket)];
    Ok((clients, contracts))
}
pub async fn serve_client_api_for_test(
config: WebsocketApiConfig,
) -> std::io::Result<(
client_api::HttpClientApi,
crate::client_events::websocket::WebSocketProxy,
)> {
serve_client_api_in_impl(config, None).await
}
/// Crate-internal entry point: binds the configured address, starts the client API server
/// and returns the concrete gateway and WebSocket proxy.
pub(crate) async fn serve_client_api_in(
    config: WebsocketApiConfig,
) -> std::io::Result<(HttpClientApi, WebSocketProxy)> {
    let no_pre_bound_listener: Option<std::net::TcpListener> = None;
    serve_client_api_in_impl(config, no_pre_bound_listener).await
}
/// Shared allowlist of `Host` header values accepted by the WebSocket API.
pub(crate) type AllowedHosts = Arc<HashSet<String>>;

/// Extra source networks trusted by the local API, on top of private and loopback
/// addresses. Empty by default.
#[derive(Clone, Default)]
pub(crate) struct AllowedSourceCidrs(pub Arc<Vec<ipnet::IpNet>>);

impl AllowedSourceCidrs {
    /// Returns `true` if any configured network contains `ip`.
    ///
    /// `ip` is matched as given. IPv4-mapped IPv6 normalization is done by
    /// `is_source_allowed`, not here.
    pub(crate) fn contains_ip(&self, ip: &IpAddr) -> bool {
        for net in self.0.iter() {
            if net.contains(ip) {
                return true;
            }
        }
        false
    }

    /// Returns `true` when no extra networks are configured.
    pub(crate) fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}
/// Shortest IPv4 prefix accepted for an allowlisted source network.
const MIN_IPV4_PREFIX_LEN: u8 = 8;
/// Shortest IPv6 prefix accepted for an allowlisted source network.
const MIN_IPV6_PREFIX_LEN: u8 = 16;

/// Rejects allowlist CIDRs broader than `/8` (IPv4) or `/16` (IPv6).
///
/// # Errors
/// When `net` is too broad, returns a human-readable message that quotes the CIDR, its
/// prefix and the minimum accepted prefix.
pub fn validate_source_cidr(net: &ipnet::IpNet) -> Result<(), String> {
    let min = match net {
        ipnet::IpNet::V4(_) => MIN_IPV4_PREFIX_LEN,
        ipnet::IpNet::V6(_) => MIN_IPV6_PREFIX_LEN,
    };
    let prefix = net.prefix_len();
    if prefix >= min {
        return Ok(());
    }
    Err(format!(
        "CIDR `{net}` has prefix /{prefix}; minimum accepted is /{min}. \
         Shorter prefixes would trust too large a range for a \
         fully-privileged local API."
    ))
}
/// Decides whether a peer at `ip` may use the local API.
///
/// Addresses accepted by `is_private_ip` are always allowed. Any other address must fall
/// inside one of the `allowed` networks.
///
/// IPv4-mapped IPv6 peers are matched by their IPv4 form. This means an IPv4 CIDR also
/// covers peers arriving through a dual-stack socket.
pub(crate) fn is_source_allowed(ip: IpAddr, allowed: &AllowedSourceCidrs) -> bool {
    is_private_ip(&ip) || allowed.contains_ip(&ip.to_canonical())
}
/// Builds the set of `Host` header values the WebSocket API accepts.
///
/// The set always includes:
/// * the loopback names;
/// * the machine's hostname and the addresses it resolves to;
/// * the bind address, unless it is the unspecified address;
/// * every entry of `extra_allowed_hosts`.
///
/// Each entry is stored lowercased, both bare and with `:port` appended. IP address
/// entries that are IPv6 are additionally stored in bracketed form (`[addr]` and
/// `[addr]:port`), which is how an IPv6 literal appears in a `Host` header.
fn build_allowed_hosts(
    bind_addr: IpAddr,
    port: u16,
    extra_allowed_hosts: &[String],
) -> HashSet<String> {
    let mut hosts = HostAllowlistBuilder::new(port);
    hosts.add_localhost();
    hosts.add_machine_hostname();
    if !bind_addr.is_unspecified() {
        hosts.add_ip(bind_addr);
    }
    for host in extra_allowed_hosts {
        hosts.add(host);
    }
    hosts.build()
}

/// Accumulates lowercased host entries, each stored with and without the API port.
struct HostAllowlistBuilder {
    hosts: HashSet<String>,
    port: u16,
}

impl HostAllowlistBuilder {
    fn new(port: u16) -> Self {
        Self {
            hosts: HashSet::new(),
            port,
        }
    }

    /// Adds `host` and `host:port`, both lowercased.
    fn add(&mut self, host: &str) {
        let host_lower = host.to_lowercase();
        self.hosts.insert(format!("{host_lower}:{}", self.port));
        self.hosts.insert(host_lower);
    }

    /// Adds an IP address literal.
    ///
    /// A `Host` header carries IPv6 literals in brackets (`[fe80::1]:7509`), so for IPv6
    /// the bracketed forms are added as well. The bare forms are kept so no entry that was
    /// previously accepted is dropped.
    fn add_ip(&mut self, ip: IpAddr) {
        let literal = ip.to_string();
        self.add(&literal);
        if ip.is_ipv6() {
            self.add(&format!("[{literal}]"));
        }
    }

    /// Adds the loopback names `localhost`, `127.0.0.1` and `[::1]`.
    fn add_localhost(&mut self) {
        self.add("localhost");
        self.add("127.0.0.1");
        self.add("[::1]");
    }

    /// Adds the machine's hostname and its resolved addresses. Does nothing if the
    /// hostname cannot be read or is not valid UTF-8.
    fn add_machine_hostname(&mut self) {
        let Ok(name) = hostname::get() else { return };
        let Some(name_str) = name.to_str() else {
            return;
        };
        self.add(name_str);
        self.resolve_hostname_ips(name_str);
    }

    /// Adds every address `hostname` resolves to. Resolution failures are ignored.
    // NOTE(review): `to_socket_addrs` is a blocking lookup, and this runs from the async
    // `serve_client_api_in_impl`; normally fast for the local hostname, but worth keeping
    // in mind.
    fn resolve_hostname_ips(&mut self, hostname: &str) {
        let Ok(addrs) = std::net::ToSocketAddrs::to_socket_addrs(&(hostname, self.port)) else {
            return;
        };
        for addr in addrs {
            self.add_ip(addr.ip());
        }
    }

    fn build(self) -> HashSet<String> {
        self.hosts
    }
}
async fn serve_client_api_in_impl(
config: WebsocketApiConfig,
pre_bound: Option<std::net::TcpListener>,
) -> std::io::Result<(HttpClientApi, WebSocketProxy)> {
let ws_socket = (config.address, config.port).into();
let origin_contracts: OriginContractMap = Arc::new(DashMap::new());
spawn_token_cleanup_task(
origin_contracts.clone(),
config.token_ttl_seconds,
config.token_cleanup_interval_seconds,
);
let (gw, gw_router) = HttpClientApi::as_router_with_origin_contracts(
&ws_socket,
origin_contracts.clone(),
crate::contract::user_input::pending_prompts(),
);
let (ws_proxy, ws_router) =
WebSocketProxy::create_router_with_origin_contracts(gw_router, origin_contracts);
let allowed_hosts: AllowedHosts = Arc::new(build_allowed_hosts(
config.address,
config.port,
&config.allowed_hosts,
));
tracing::info!(?allowed_hosts, "WebSocket Host header allowlist built");
let allowed_source_cidrs = AllowedSourceCidrs(Arc::new(config.allowed_source_cidrs.clone()));
for cidr in allowed_source_cidrs.0.iter() {
tracing::warn!(
%cidr,
"Local API source CIDR enabled: ensure this range is fully under \
your control. Anything reachable in it can access contract \
state and keys."
);
}
let needs_lan_filter = !config.address.is_loopback();
let router = if needs_lan_filter {
ws_router
.layer(Extension(allowed_hosts))
.layer(axum::middleware::from_fn(private_network_filter))
.layer(Extension(allowed_source_cidrs))
.layer(TraceLayer::new_for_http())
} else {
ws_router
.layer(Extension(allowed_hosts))
.layer(Extension(allowed_source_cidrs))
.layer(TraceLayer::new_for_http())
};
serve_with_listener(ws_socket, router, pre_bound).await?;
Ok((gw, ws_proxy))
}
/// Middleware that lets a request through only when the peer address passes
/// [`is_source_allowed`]. Every other request receives `403 Forbidden`.
///
/// Extracts the peer's `ConnectInfo<SocketAddr>` and an `Extension<AllowedSourceCidrs>`.
/// The extension must be layered outside this middleware.
async fn private_network_filter(
    axum::extract::ConnectInfo(peer): axum::extract::ConnectInfo<SocketAddr>,
    Extension(allowed_source_cidrs): Extension<AllowedSourceCidrs>,
    req: axum::http::Request<axum::body::Body>,
    next: axum::middleware::Next,
) -> axum::response::Response {
    let ip = peer.ip();
    if is_source_allowed(ip, &allowed_source_cidrs) {
        return next.run(req).await;
    }
    tracing::warn!(
        remote_ip = %ip,
        "Rejected connection from non-private IP"
    );
    let rejection = (
        axum::http::StatusCode::FORBIDDEN,
        "Only local network connections are allowed",
    );
    rejection.into_response()
}
/// Spawns a background task that periodically evicts expired authentication tokens.
///
/// Every `cleanup_interval_seconds`, the task removes from `origin_contracts` each token
/// whose `last_accessed` is at least `token_ttl_seconds` old. The first sweep happens one
/// full interval after spawning.
///
/// A `cleanup_interval_seconds` of `0` is treated as one second (with a warning), because
/// `tokio::time::interval` panics on a zero period.
fn spawn_token_cleanup_task(
    origin_contracts: OriginContractMap,
    token_ttl_seconds: u64,
    cleanup_interval_seconds: u64,
) {
    let token_ttl = Duration::from_secs(token_ttl_seconds);
    if cleanup_interval_seconds == 0 {
        tracing::warn!("Token cleanup interval of 0 seconds is invalid; using 1 second instead");
    }
    let cleanup_interval = Duration::from_secs(cleanup_interval_seconds.max(1));
    GlobalExecutor::spawn(async move {
        let mut interval = tokio::time::interval(cleanup_interval);
        // The first tick completes immediately; consume it so the first sweep happens one
        // interval from now.
        interval.tick().await;
        loop {
            interval.tick().await;
            let now = Instant::now();
            // Count removals inside `retain` instead of diffing `len()` before and after.
            // Tokens inserted concurrently during the sweep could make the map grow, and
            // `initial - len` would then underflow.
            let mut removed_count: usize = 0;
            origin_contracts.retain(|token, origin| {
                let elapsed = now.duration_since(origin.last_accessed);
                if elapsed < token_ttl {
                    return true;
                }
                tracing::info!(
                    ?token,
                    contract_id = ?origin.contract_id,
                    client_id = ?origin.client_id,
                    elapsed_hours = elapsed.as_secs() / 3600,
                    "Removing expired authentication token"
                );
                removed_count += 1;
                false
            });
            if removed_count > 0 {
                tracing::debug!(
                    removed_count,
                    remaining_count = origin_contracts.len(),
                    "Token cleanup completed"
                );
            }
        }
    });
}
/// Unit tests for address classification, the Host allowlist and the source-CIDR filter.
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    #[test]
    fn test_is_private_ip_v4() {
        // Loopback, RFC 1918, link-local and unspecified addresses count as private.
        assert!(is_private_ip(&IpAddr::V4(Ipv4Addr::LOCALHOST)));
        assert!(is_private_ip(&IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2))));
        assert!(is_private_ip(&IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
        assert!(is_private_ip(&IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1))));
        assert!(is_private_ip(&IpAddr::V4(Ipv4Addr::new(172, 31, 255, 255))));
        assert!(is_private_ip(&IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2))));
        assert!(is_private_ip(&IpAddr::V4(Ipv4Addr::new(169, 254, 1, 1))));
        assert!(is_private_ip(&IpAddr::V4(Ipv4Addr::UNSPECIFIED)));
        // Public addresses, including ones just outside the RFC 1918 ranges.
        assert!(!is_private_ip(&IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))));
        assert!(!is_private_ip(&IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1))));
        assert!(!is_private_ip(&IpAddr::V4(Ipv4Addr::new(172, 32, 0, 1))));
        assert!(!is_private_ip(&IpAddr::V4(Ipv4Addr::new(192, 169, 0, 1))));
    }

    #[test]
    fn test_is_private_ip_v6() {
        assert!(is_private_ip(&IpAddr::V6(Ipv6Addr::LOCALHOST)));
        assert!(is_private_ip(&IpAddr::V6(Ipv6Addr::UNSPECIFIED)));
        // Link-local fe80::/10, including its upper edge; fe40:: lies outside it.
        assert!(is_private_ip(&IpAddr::V6(Ipv6Addr::new(
            0xfe80, 0, 0, 0, 0, 0, 0, 1
        ))));
        assert!(is_private_ip(&IpAddr::V6(Ipv6Addr::new(
            0xfebf, 0xffff, 0, 0, 0, 0, 0, 1
        ))));
        assert!(!is_private_ip(&IpAddr::V6(Ipv6Addr::new(
            0xfe40, 0, 0, 0, 0, 0, 0, 1
        ))));
        // Unique-local fc00::/7.
        assert!(is_private_ip(&IpAddr::V6(Ipv6Addr::new(
            0xfd00, 0, 0, 0, 0, 0, 0, 1
        ))));
        assert!(is_private_ip(&IpAddr::V6(Ipv6Addr::new(
            0xfc00, 0, 0, 0, 0, 0, 0, 1
        ))));
        assert!(is_private_ip(&IpAddr::V6(Ipv6Addr::new(
            0xfdff, 0xffff, 0, 0, 0, 0, 0, 1
        ))));
        // Global unicast.
        assert!(!is_private_ip(&IpAddr::V6(Ipv6Addr::new(
            0x2001, 0xdb8, 0, 0, 0, 0, 0, 1
        ))));
        assert!(!is_private_ip(&IpAddr::V6(Ipv6Addr::new(
            0x2607, 0xf8b0, 0, 0, 0, 0, 0, 1
        ))));
        // IPv4-mapped addresses are classified by their embedded IPv4 address.
        assert!(is_private_ip(
            &"::ffff:127.0.0.1".parse::<IpAddr>().unwrap()
        ));
        assert!(is_private_ip(&"::ffff:10.0.0.1".parse::<IpAddr>().unwrap()));
        assert!(is_private_ip(
            &"::ffff:192.168.1.1".parse::<IpAddr>().unwrap()
        ));
        assert!(!is_private_ip(&"::ffff:8.8.8.8".parse::<IpAddr>().unwrap()));
    }

    #[test]
    fn test_build_allowed_hosts_always_includes_localhost() {
        let hosts = build_allowed_hosts(IpAddr::V4(Ipv4Addr::LOCALHOST), 7509, &[]);
        assert!(hosts.contains("localhost"));
        assert!(hosts.contains("localhost:7509"));
        assert!(hosts.contains("127.0.0.1"));
        assert!(hosts.contains("127.0.0.1:7509"));
        assert!(hosts.contains("[::1]"));
        assert!(hosts.contains("[::1]:7509"));
    }

    #[test]
    fn test_build_allowed_hosts_includes_machine_hostname() {
        let hosts = build_allowed_hosts(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 7509, &[]);
        // Only checkable when the environment exposes a UTF-8 hostname.
        if let Ok(name) = hostname::get() {
            if let Some(name_str) = name.to_str() {
                assert!(hosts.contains(&name_str.to_lowercase()));
            }
        }
    }

    #[test]
    fn test_build_allowed_hosts_custom_hostname() {
        let hosts = build_allowed_hosts(
            IpAddr::V4(Ipv4Addr::UNSPECIFIED),
            7509,
            &["mynode.example.com".to_string()],
        );
        assert!(hosts.contains("mynode.example.com"));
        assert!(hosts.contains("mynode.example.com:7509"));
        assert!(hosts.contains("localhost"));
    }

    #[test]
    fn test_build_allowed_hosts_specific_bind_addr() {
        let hosts = build_allowed_hosts(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 50)), 7509, &[]);
        assert!(hosts.contains("192.168.1.50"));
        assert!(hosts.contains("192.168.1.50:7509"));
        assert!(hosts.contains("localhost"));
    }

    #[test]
    fn test_build_allowed_hosts_excludes_unspecified() {
        let hosts = build_allowed_hosts(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 7509, &[]);
        assert!(!hosts.contains("0.0.0.0"));
        assert!(!hosts.contains("0.0.0.0:7509"));
    }

    /// Parses a list of CIDR strings into an allowlist; panics on malformed input.
    fn cidrs(list: &[&str]) -> AllowedSourceCidrs {
        AllowedSourceCidrs(Arc::new(list.iter().map(|s| s.parse().unwrap()).collect()))
    }

    /// Exercises the filter through a real axum layer stack, with the CIDR extension
    /// layered outside the filter as in `serve_client_api_in_impl`.
    #[tokio::test]
    async fn middleware_layer_stack_allows_configured_source() {
        use axum::{Router, routing::get};
        use tower::ServiceExt;

        async fn handler() -> &'static str {
            "ok"
        }

        let allowed = cidrs(&["100.64.0.0/10"]);
        let app: Router = Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn(private_network_filter))
            .layer(Extension(allowed));

        let req = axum::http::Request::builder()
            .uri("/")
            .extension(axum::extract::ConnectInfo(SocketAddr::from((
                [100, 64, 0, 1],
                12345,
            ))))
            .body(axum::body::Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(
            resp.status(),
            axum::http::StatusCode::OK,
            "allowlisted CGNAT source must reach the handler; \
             500 here means the middleware failed to extract the Extension \
             (check layer ordering in serve_client_api_in_impl)"
        );

        let req = axum::http::Request::builder()
            .uri("/")
            .extension(axum::extract::ConnectInfo(SocketAddr::from((
                [8, 8, 8, 8],
                12345,
            ))))
            .body(axum::body::Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), axum::http::StatusCode::FORBIDDEN);

        // With an empty allowlist, loopback peers are still accepted.
        let empty_app: Router = Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn(private_network_filter))
            .layer(Extension(AllowedSourceCidrs::default()));
        let req = axum::http::Request::builder()
            .uri("/")
            .extension(axum::extract::ConnectInfo(SocketAddr::from((
                [127, 0, 0, 1],
                12345,
            ))))
            .body(axum::body::Body::empty())
            .unwrap();
        let resp = empty_app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), axum::http::StatusCode::OK);
    }

    #[test]
    fn allowed_source_cidrs_empty_rejects_public() {
        let allow = AllowedSourceCidrs::default();
        assert!(!allow.contains_ip(&IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))));
        assert!(!allow.contains_ip(&IpAddr::V4(Ipv4Addr::LOCALHOST)));
    }

    #[test]
    fn allowed_source_cidrs_tailscale_cgnat() {
        let allow = cidrs(&["100.64.0.0/10"]);
        assert!(allow.contains_ip(&IpAddr::V4(Ipv4Addr::new(100, 64, 0, 1))));
        assert!(allow.contains_ip(&IpAddr::V4(Ipv4Addr::new(100, 100, 50, 1))));
        assert!(allow.contains_ip(&IpAddr::V4(Ipv4Addr::new(100, 127, 255, 254))));
        assert!(!allow.contains_ip(&IpAddr::V4(Ipv4Addr::new(100, 128, 0, 1))));
        assert!(!allow.contains_ip(&IpAddr::V4(Ipv4Addr::new(100, 63, 255, 255))));
        assert!(!allow.contains_ip(&IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))));
    }

    #[test]
    fn allowed_source_cidrs_narrow_tailnet() {
        let allow = cidrs(&["100.64.1.0/24"]);
        assert!(allow.contains_ip(&IpAddr::V4(Ipv4Addr::new(100, 64, 1, 5))));
        assert!(!allow.contains_ip(&IpAddr::V4(Ipv4Addr::new(100, 64, 2, 5))));
    }

    #[test]
    fn allowed_source_cidrs_ipv6() {
        let allow = cidrs(&["fd7a:115c:a1e0::/48"]);
        assert!(allow.contains_ip(&IpAddr::V6(Ipv6Addr::new(
            0xfd7a, 0x115c, 0xa1e0, 0, 0, 0, 0, 1
        ))));
        assert!(!allow.contains_ip(&IpAddr::V6(Ipv6Addr::new(
            0xfd7a, 0x115c, 0xa1e1, 0, 0, 0, 0, 1
        ))));
    }

    #[test]
    fn is_source_allowed_accepts_private_with_empty_allowlist() {
        let empty = AllowedSourceCidrs::default();
        assert!(is_source_allowed(IpAddr::V4(Ipv4Addr::LOCALHOST), &empty));
        assert!(is_source_allowed(
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            &empty
        ));
        assert!(is_source_allowed(
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 5)),
            &empty
        ));
        assert!(is_source_allowed(IpAddr::V6(Ipv6Addr::LOCALHOST), &empty));
    }

    #[test]
    fn is_source_allowed_rejects_public_with_empty_allowlist() {
        let empty = AllowedSourceCidrs::default();
        assert!(!is_source_allowed(
            IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
            &empty
        ));
        assert!(!is_source_allowed(
            IpAddr::V4(Ipv4Addr::new(100, 64, 0, 1)),
            &empty
        ));
        assert!(!is_source_allowed(
            IpAddr::V4(Ipv4Addr::new(203, 0, 113, 42)),
            &empty
        ));
    }

    #[test]
    fn is_source_allowed_accepts_configured_tailscale_range() {
        let tailnet = cidrs(&["100.64.0.0/10"]);
        assert!(is_source_allowed(
            IpAddr::V4(Ipv4Addr::new(100, 64, 0, 1)),
            &tailnet
        ));
        assert!(is_source_allowed(
            IpAddr::V4(Ipv4Addr::new(100, 127, 0, 1)),
            &tailnet
        ));
        assert!(!is_source_allowed(
            IpAddr::V4(Ipv4Addr::new(100, 128, 0, 1)),
            &tailnet
        ));
        assert!(!is_source_allowed(
            IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
            &tailnet
        ));
        // Private addresses stay allowed when an allowlist is configured.
        assert!(is_source_allowed(
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 5)),
            &tailnet
        ));
    }

    #[test]
    fn is_source_allowed_normalizes_ipv4_mapped_ipv6_for_cidr_match() {
        let tailnet = cidrs(&["100.64.0.0/10"]);
        let mapped = IpAddr::V6(Ipv4Addr::new(100, 64, 0, 1).to_ipv6_mapped());
        assert!(is_source_allowed(mapped, &tailnet));
        let public_mapped = IpAddr::V6(Ipv4Addr::new(8, 8, 8, 8).to_ipv6_mapped());
        assert!(!is_source_allowed(
            public_mapped,
            &AllowedSourceCidrs::default()
        ));
        assert!(!is_source_allowed(public_mapped, &tailnet));
    }

    #[test]
    fn is_source_allowed_accepts_configured_ipv6_range() {
        let tailnet_v6 = cidrs(&["fd7a:115c:a1e0::/48"]);
        let inside = IpAddr::V6(Ipv6Addr::new(0xfd7a, 0x115c, 0xa1e0, 0x0001, 0, 0, 0, 1));
        assert!(is_source_allowed(inside, &tailnet_v6));
    }

    #[test]
    fn validate_source_cidr_rejects_overly_broad_ipv4() {
        assert!(validate_source_cidr(&"0.0.0.0/0".parse().unwrap()).is_err());
        assert!(validate_source_cidr(&"0.0.0.0/7".parse().unwrap()).is_err());
        assert!(validate_source_cidr(&"10.0.0.0/8".parse().unwrap()).is_ok());
        assert!(validate_source_cidr(&"100.64.0.0/10".parse().unwrap()).is_ok());
        assert!(validate_source_cidr(&"203.0.113.5/32".parse().unwrap()).is_ok());
    }

    #[test]
    fn validate_source_cidr_rejects_overly_broad_ipv6() {
        assert!(validate_source_cidr(&"::/0".parse().unwrap()).is_err());
        assert!(validate_source_cidr(&"::/15".parse().unwrap()).is_err());
        assert!(validate_source_cidr(&"::/16".parse().unwrap()).is_ok());
        assert!(validate_source_cidr(&"fd7a:115c:a1e0::/48".parse().unwrap()).is_ok());
        assert!(validate_source_cidr(&"::1/128".parse().unwrap()).is_ok());
    }

    #[test]
    fn validate_source_cidr_error_message_is_actionable() {
        let err = validate_source_cidr(&"0.0.0.0/0".parse().unwrap()).unwrap_err();
        assert!(err.contains("0.0.0.0/0"), "should quote the offending CIDR");
        assert!(err.contains("/0"), "should state the offending prefix");
        assert!(err.contains("/8"), "should state the minimum accepted");
    }

    #[test]
    fn allowed_source_cidrs_multiple_ranges() {
        let allow = cidrs(&["100.64.0.0/10", "10.100.0.0/16"]);
        assert!(allow.contains_ip(&IpAddr::V4(Ipv4Addr::new(100, 64, 1, 1))));
        assert!(allow.contains_ip(&IpAddr::V4(Ipv4Addr::new(10, 100, 5, 5))));
        assert!(!allow.contains_ip(&IpAddr::V4(Ipv4Addr::new(10, 101, 0, 1))));
    }

    #[test]
    fn allowed_source_cidrs_does_not_accept_public_by_default() {
        let allow = AllowedSourceCidrs::default();
        for ip in [
            // Shared address space (RFC 6598), used by CGNAT and Tailscale.
            Ipv4Addr::new(100, 64, 0, 1),
            // Public DNS resolver.
            Ipv4Addr::new(8, 8, 8, 8),
            // TEST-NET-3 documentation range (RFC 5737).
            Ipv4Addr::new(203, 0, 113, 42),
        ] {
            assert!(
                !allow.contains_ip(&IpAddr::V4(ip)),
                "{ip} must not be trusted by default",
            );
        }
    }

    #[test]
    fn test_build_allowed_hosts_excludes_ipv6_unspecified() {
        let hosts = build_allowed_hosts(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 7509, &[]);
        assert!(!hosts.contains("::"));
        assert!(!hosts.contains("[::]:7509"));
        assert!(hosts.contains("localhost"));
        assert!(hosts.contains("[::1]"));
        assert!(hosts.contains("127.0.0.1"));
    }
}