use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr};
use std::time::Duration;
use jkipsec::api::{AuthDecision, JkispecConfig, JkispecServer, PortRole};
use jkipsec::crypto::derive_auth_key;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{info, warn};
/// Demo / probe entry point.
///
/// Starts a jkipsec server configured from `JKIPSEC_*` environment variables. For
/// every authenticated client it attempts `client.connect(port)` (logged as a TCP
/// connect) for each port in `JKIPSEC_PROBE_PORTS`, writes `JKIPSEC_PROBE_MSG`, and
/// logs the first reply received within 3 seconds.
///
/// Authentication: the client's identity is looked up in `JKIPSEC_PSK_MAP`
/// (`id=psk,id=psk,...`). If it is not listed, the key derived from `JKIPSEC_PSK`
/// is used as a shared fallback. With neither available, the client is rejected.
///
/// # Panics
///
/// Panics at startup if an address/port variable is set to a value that does not
/// parse, or if `JKIPSEC_BINDS` yields no usable `addr/role` entry.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info,jkipsec=debug")),
        )
        .init();

    // NOTE(review): the default listens on 8500/4501 while the advertised public
    // port defaults to 500 — presumably an external port forward maps them; confirm.
    let binds = parse_binds(
        &std::env::var("JKIPSEC_BINDS")
            .unwrap_or_else(|_| "0.0.0.0:8500/ike500,0.0.0.0:4501/ike4500".into()),
    );
    // parse_binds silently skips malformed entries; refuse to start with none
    // instead of quietly running with no configured binds.
    assert!(
        !binds.is_empty(),
        "JKIPSEC_BINDS contains no valid `addr/role` entries (known roles: ike500, ike4500)"
    );

    let public_ip: IpAddr = env_parse("JKIPSEC_PUBLIC_IP", "0.0.0.0");
    let public_port: u16 = env_parse("JKIPSEC_PUBLIC_PORT", "500");
    let identity = std::env::var("JKIPSEC_IDENTITY").unwrap_or_else(|_| "vpn@local".into());
    let virtual_ip: Ipv4Addr = env_parse("JKIPSEC_VIRTUAL_IP", "10.8.0.2");
    let virtual_dns: Ipv4Addr = env_parse("JKIPSEC_VIRTUAL_DNS", "1.1.1.1");
    let gateway_ip: Ipv4Addr = env_parse("JKIPSEC_GATEWAY_IP", "10.8.0.1");

    // `JKIPSEC_PSK_MAP` is `id=psk,id=psk,...`. Entries without `=` are ignored,
    // and a later duplicate id overrides an earlier one. Keys are derived once here
    // rather than on every authentication attempt.
    let auth_key_map: HashMap<Vec<u8>, [u8; 32]> = std::env::var("JKIPSEC_PSK_MAP")
        .unwrap_or_default()
        .split(',')
        .filter_map(|entry| entry.split_once('='))
        .map(|(id, psk)| {
            (
                id.trim().as_bytes().to_vec(),
                derive_auth_key(psk.trim().as_bytes()),
            )
        })
        .collect();
    let fallback_auth_key = std::env::var("JKIPSEC_PSK")
        .ok()
        .map(|s| derive_auth_key(s.as_bytes()));

    let server = JkispecServer::start(JkispecConfig {
        binds,
        public_ip,
        public_port,
        identity,
        virtual_ip,
        gateway_ip,
        virtual_dns,
        auth: Box::new(move |challenge| {
            // Resolve the key synchronously. The map is only read, so there is no
            // need to clone it into every returned future; only the Copy key moves in.
            let key = auth_key_map
                .get(challenge.identity())
                .copied()
                .or(fallback_auth_key);
            Box::pin(async move {
                match key {
                    Some(k) => challenge.approve_with(&k),
                    None => AuthDecision::Reject,
                }
            })
        }),
    })
    .await;

    // Unparseable port entries are skipped rather than rejected.
    let ports: Vec<u16> = std::env::var("JKIPSEC_PROBE_PORTS")
        .unwrap_or_else(|_| "49152,62078,12345".into())
        .split(',')
        .filter_map(|s| s.trim().parse().ok())
        .collect();
    let probe_msg = std::env::var("JKIPSEC_PROBE_MSG").unwrap_or_else(|_| "hello\n".into());

    while let Some(mut client) = server.accept().await {
        info!(
            id = client.identity_str(),
            peer = %client.peer(),
            "client authenticated"
        );
        let ports = ports.clone();
        let probe_msg = probe_msg.clone();
        tokio::spawn(async move {
            // Successfully probed streams are kept open alongside the client.
            let mut open_streams = Vec::new();
            for port in ports {
                info!(id = client.identity_str(), port, "attempting TCP connect");
                match client.connect(port).await {
                    Ok(mut stream) => {
                        info!(port, "TCP connected");
                        if let Err(e) = stream.write_all(probe_msg.as_bytes()).await {
                            warn!(port, "write: {e}");
                            continue;
                        }
                        let mut buf = [0u8; 256];
                        match tokio::time::timeout(Duration::from_secs(3), stream.read(&mut buf))
                            .await
                        {
                            Ok(Ok(0)) => info!(port, "peer closed"),
                            Ok(Ok(n)) => {
                                info!(port, n, "got {:?}", String::from_utf8_lossy(&buf[..n]))
                            }
                            Ok(Err(e)) => warn!(port, "read: {e}"),
                            Err(_) => info!(port, "connected, no response in 3s"),
                        }
                        open_streams.push(stream);
                    }
                    Err(e) => warn!(port, "connect failed: {e}"),
                }
            }
            // Keep the client session and every probe stream open indefinitely.
            // NOTE(review): nothing ever releases them, even after the peer goes
            // away, so resources grow with every client — confirm this is intended.
            let _held = (client, open_streams);
            std::future::pending::<()>().await;
        });
    }
}

/// Reads environment variable `name` and parses it as `T`. Falls back to `default`
/// when the variable is unset or not valid Unicode.
///
/// # Panics
///
/// Panics with a message of the form `"<name>: <parse error>"` if the value does not
/// parse. `#[track_caller]` makes the panic location point at the call site.
#[track_caller]
fn env_parse<T>(name: &str, default: &str) -> T
where
    T: std::str::FromStr,
    T::Err: std::fmt::Debug,
{
    std::env::var(name)
        .as_deref()
        .unwrap_or(default)
        .parse()
        .expect(name)
}
/// Parses a comma-separated list of `addr/role` bind specs, for example
/// `0.0.0.0:500/ike500,0.0.0.0:4500/ike4500`.
///
/// - A missing `/role` defaults to `ike500`.
/// - Whitespace around entries, addresses and roles is ignored.
/// - Empty entries (an empty string, a trailing comma) are skipped.
/// - Entries with an empty address are skipped.
/// - Entries with an unrecognised role are skipped without any diagnostic, so
///   callers should check whether the result is empty.
fn parse_binds(s: &str) -> Vec<(String, PortRole)> {
    s.split(',')
        .filter_map(|entry| {
            let entry = entry.trim();
            let (addr, role) = entry.split_once('/').unwrap_or((entry, "ike500"));
            let addr = addr.trim();
            // An empty entry would otherwise turn into a bind on the address "".
            if addr.is_empty() {
                return None;
            }
            let role = match role.trim() {
                "ike500" => PortRole::Ike500,
                "ike4500" => PortRole::Ike4500,
                _ => return None,
            };
            Some((addr.to_string(), role))
        })
        .collect()
}