use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use tokio::net::UdpSocket;
use tracing::{debug, info, warn};
use crate::esp::{self, OutboundFlavor};
pub use crate::session::{AuthChallenge, AuthDecision, AuthFn};
use crate::session::{
Config, EstablishedSession, InboundChanFlavor, InboundMsg, OutboundMsg, ReplyFlavor, Server,
};
/// Framing spoken on a bound UDP port.
///
/// `Ike500` sockets carry bare IKE messages. `Ike4500` sockets multiplex IKE
/// (behind a 4-byte all-zero marker) with ESP; the first such bind is also the
/// socket outbound ESP is sent from.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PortRole {
    /// Plain IKE traffic, no marker.
    Ike500,
    /// IKE behind a zero marker, plus ESP.
    Ike4500,
}
/// Configuration consumed by `JkispecServer::start`.
pub struct JkispecConfig {
    /// UDP addresses to bind, each tagged with the framing its port uses.
    /// `start` requires at least one `PortRole::Ike4500` entry and panics otherwise.
    pub binds: Vec<(String, PortRole)>,
    /// Passed to the session server as `Config::local_ip`.
    pub public_ip: IpAddr,
    /// Passed to the session server as `Config::local_port`.
    pub public_port: u16,
    /// Our IKE identity, passed to the session server as `Config::our_identity`.
    pub identity: String,
    /// Passed to the session server as `Config::virtual_ip`
    /// (presumably the address handed to the client inside the tunnel — confirm).
    pub virtual_ip: Ipv4Addr,
    /// Passed to the session server as `Config::gateway_ip`.
    pub gateway_ip: Ipv4Addr,
    /// Passed to the session server as `Config::virtual_dns`.
    pub virtual_dns: Ipv4Addr,
    /// Authentication callback, passed to the session server as `Config::auth`.
    pub auth: AuthFn,
}
/// Handle to a running server: owns the spawned UDP/session tasks and yields
/// authenticated clients through `JkispecServer::accept`.
pub struct JkispecServer {
    /// Established sessions produced by the session-server task.
    client_rx: crossfire::AsyncRx<crate::session::SessionChanFlavor>,
    /// Sender into the session server; cloned into every `Client` for disconnect notices.
    inbound_tx: crossfire::MTx<InboundChanFlavor>,
    /// Reply queue of the first port-4500 socket; cloned into every `Client`.
    reply_tx_4500: crossfire::MTx<ReplyFlavor>,
    /// First port-4500 socket; every client's outbound ESP is sent from it.
    esp_socket: Arc<UdpSocket>,
    /// Handles of the spawned tasks. NOTE(review): dropping a tokio `JoinHandle`
    /// detaches the task rather than aborting it, so these tasks outlive this value.
    _tasks: Vec<tokio::task::JoinHandle<()>>,
}
/// An authenticated peer, with a `jktcp` adapter running over its tunnel.
///
/// Dropping it (or calling `Client::disconnect`) sends a disconnect notice to the
/// session server and aborts the client's outbound ESP task.
pub struct Client {
    /// Peer identity bytes as delivered by the session layer (`peer_identity`).
    identity_raw: Vec<u8>,
    /// UDP address of the peer; outbound ESP is sent here.
    peer: SocketAddr,
    /// Initiator SPI of the session, used to identify it in the disconnect notice.
    initiator_spi: u64,
    /// Handle used by `Client::connect` to open streams over the tunnel.
    adapter_handle: jktcp::handle::AdapterHandle,
    /// Sender into the session server, used for the disconnect notice.
    inbound_tx: crossfire::MTx<InboundChanFlavor>,
    /// Reply queue of the first port-4500 socket, included in the disconnect notice.
    reply_tx: crossfire::MTx<ReplyFlavor>,
    /// This client's outbound ESP task; aborted in `Drop`.
    _outbound_task: tokio::task::JoinHandle<()>,
}
impl JkispecServer {
    /// Binds every configured UDP socket, spawns the per-socket reader/writer tasks and
    /// the session-server task, and returns a handle for accepting authenticated clients.
    ///
    /// The first `PortRole::Ike4500` bind is also used as the ESP socket for all clients,
    /// and its reply queue is the one clients use when they disconnect.
    ///
    /// # Panics
    ///
    /// Panics if `config.binds` contains no `PortRole::Ike4500` entry, or if any bind
    /// fails. Both happen before any task is spawned. Dropped tokio `JoinHandle`s detach
    /// rather than abort, so panicking after spawning would leave tasks running on the
    /// already-bound sockets.
    pub async fn start(config: JkispecConfig) -> Self {
        // Fail fast on misconfiguration, before any socket or task exists.
        assert!(
            config
                .binds
                .iter()
                .any(|(_, role)| *role == PortRole::Ike4500),
            "need at least one Ike4500 bind for ESP"
        );
        crossfire::detect_backoff_cfg();
        let (mut server, inbound_tx, client_rx) = Server::new(Config {
            auth: config.auth,
            local_ip: config.public_ip,
            local_port: config.public_port,
            our_identity: config.identity,
            virtual_ip: config.virtual_ip,
            gateway_ip: config.gateway_ip,
            virtual_dns: config.virtual_dns,
        });

        // Bind everything first so a bind failure panics before any task is spawned.
        let mut sockets = Vec::with_capacity(config.binds.len());
        for (addr, role) in &config.binds {
            let socket = UdpSocket::bind(addr)
                .await
                .unwrap_or_else(|e| panic!("bind {addr} failed: {e}"));
            info!(addr, ?role, "listening");
            sockets.push((Arc::new(socket), *role));
        }

        let mut tasks = Vec::with_capacity(2 * sockets.len() + 1);
        // Socket and reply queue of the first port-4500 bind: used for ESP and teardown.
        let mut esp_path: Option<(Arc<UdpSocket>, crossfire::MTx<ReplyFlavor>)> = None;
        for (socket, role) in sockets {
            let (rtx, rrx) = crossfire::mpsc::unbounded_async::<OutboundMsg>();
            if role == PortRole::Ike4500 && esp_path.is_none() {
                esp_path = Some((Arc::clone(&socket), rtx.clone()));
            }
            let role_internal = match role {
                PortRole::Ike500 => InternalRole::Ike500,
                PortRole::Ike4500 => InternalRole::Ike4500,
            };
            tasks.push(tokio::spawn(read_udp(
                role_internal,
                Arc::clone(&socket),
                inbound_tx.clone(),
                rtx,
            )));
            tasks.push(tokio::spawn(write_udp(role_internal, socket, rrx)));
        }
        // Guaranteed by the assertion at the top.
        let (esp_socket, reply_tx_4500) =
            esp_path.expect("need at least one Ike4500 bind for ESP");

        tasks.push(tokio::spawn(async move {
            server.run().await;
        }));
        Self {
            client_rx,
            inbound_tx,
            reply_tx_4500,
            esp_socket,
            _tasks: tasks,
        }
    }

    /// Waits for the next session established by the session server and wraps it in a
    /// `Client`.
    ///
    /// Starts a `jktcp` adapter over the session's tunnel and spawns the client's
    /// outbound ESP task on the shared port-4500 socket. Returns `None` when receiving
    /// from the session channel fails (presumably because the session server has
    /// exited — confirm against crossfire's `recv` semantics).
    pub async fn accept(&self) -> Option<Client> {
        let session: EstablishedSession =
            crossfire::AsyncRxTrait::recv(&self.client_rx).await.ok()?;
        info!(
            peer = %session.peer,
            ispi = format_args!("{:016x}", session.initiator_spi),
            "client authenticated"
        );
        let adapter = jktcp::adapter::Adapter::new(
            Box::new(session.tunnel),
            IpAddr::V4(session.gateway_ip),
            IpAddr::V4(session.virtual_ip),
        );
        let adapter_handle = adapter.to_async_handle();
        let outbound_task = tokio::spawn(run_outbound(
            Arc::clone(&self.esp_socket),
            session.peer,
            session.suite,
            session.outbound_spi,
            session.outbound_key,
            session.outbound_salt,
            session.outbound_integ,
            session.outbound_rx,
        ));
        Some(Client {
            identity_raw: session.peer_identity,
            peer: session.peer,
            initiator_spi: session.initiator_spi,
            adapter_handle,
            inbound_tx: self.inbound_tx.clone(),
            reply_tx: self.reply_tx_4500.clone(),
            _outbound_task: outbound_task,
        })
    }
}
impl Client {
    /// Identity bytes exactly as delivered by the session layer.
    pub fn identity_raw(&self) -> &[u8] {
        &self.identity_raw
    }

    /// Identity bytes with the leading 4-byte prefix removed.
    ///
    /// NOTE(review): the prefix looks like the IKEv2 ID payload header (ID Type plus
    /// three reserved bytes) — confirm against the session layer. A raw value shorter
    /// than 4 bytes is returned unchanged. A value of exactly 4 bytes yields an empty
    /// slice, consistent with stripping the prefix from longer values.
    pub fn identity(&self) -> &[u8] {
        self.identity_raw
            .get(4..)
            .unwrap_or(self.identity_raw.as_slice())
    }

    /// `identity()` decoded as UTF-8, with invalid sequences replaced.
    pub fn identity_str(&self) -> String {
        String::from_utf8_lossy(self.identity()).into()
    }

    /// UDP address of the peer.
    pub fn peer(&self) -> SocketAddr {
        self.peer
    }

    /// Opens a stream to `port` through the tunnel via the `jktcp` adapter.
    ///
    /// # Errors
    ///
    /// Propagates whatever `AdapterHandle::connect` returns.
    pub async fn connect(
        &mut self,
        port: u16,
    ) -> Result<jktcp::handle::StreamHandle, std::io::Error> {
        self.adapter_handle.connect(port).await
    }

    /// Tears the session down.
    ///
    /// The `Drop` impl already queues `InboundMsg::Disconnect` and aborts the outbound
    /// ESP task. This method just drops `self`; also sending the notice here would
    /// queue it twice.
    pub fn disconnect(self) {
        drop(self);
    }
}
impl Drop for Client {
    /// Tells the session server this client is gone, then stops its outbound ESP task.
    fn drop(&mut self) {
        use crossfire::BlockingTxTrait;
        let notice = InboundMsg::Disconnect {
            initiator_spi: self.initiator_spi,
            reply_tx: self.reply_tx.clone(),
        };
        // Best effort: if the server is already gone there is nobody left to notify.
        let _ = self.inbound_tx.send(notice);
        self._outbound_task.abort();
    }
}
impl std::fmt::Debug for Client {
    /// Shows the decoded identity and peer address; channels and task handles are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let identity = self.identity_str();
        let mut out = f.debug_struct("Client");
        out.field("identity", &identity);
        out.field("peer", &self.peer);
        out.finish()
    }
}
/// Crate-private copy of `PortRole` passed to the UDP reader/writer tasks.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum InternalRole {
    /// Bare IKE on this socket.
    Ike500,
    /// Zero-marker IKE multiplexed with ESP on this socket.
    Ike4500,
}
/// Receive loop for one bound UDP socket.
///
/// On port 4500, a datagram consisting of the single byte `0xFF` is a NAT-T keepalive
/// and is dropped. Every other datagram goes through `classify`:
/// - IKE is forwarded as `InboundMsg::Ike`, tagged with this socket's `reply_tx`.
/// - ESP is forwarded as `InboundMsg::Esp`.
/// - Anything too short to classify is logged and dropped.
///
/// `recv_from` errors are logged and the loop continues. The loop ends once the session
/// server's inbound channel rejects a message, because nothing would consume further
/// datagrams.
async fn read_udp(
    role: InternalRole,
    socket: Arc<UdpSocket>,
    inbound_tx: crossfire::MTx<InboundChanFlavor>,
    reply_tx: crossfire::MTx<ReplyFlavor>,
) {
    use crossfire::BlockingTxTrait;
    // One reusable buffer large enough for any UDP payload.
    let mut buf = vec![0u8; 65_535];
    loop {
        let (n, peer) = match socket.recv_from(&mut buf).await {
            Ok(v) => v,
            Err(e) => {
                warn!(?role, "recv_from failed: {e}");
                continue;
            }
        };
        let datagram = &buf[..n];
        if role == InternalRole::Ike4500 && datagram == [0xFFu8] {
            debug!(%peer, "NAT-T keepalive");
            continue;
        }
        let msg = match classify(role, datagram) {
            (DgKind::Ike, ike_bytes) => InboundMsg::Ike {
                data: ike_bytes.to_vec(),
                peer,
                reply_tx: reply_tx.clone(),
            },
            (DgKind::Esp, _) => InboundMsg::Esp {
                data: datagram.to_vec(),
            },
            (DgKind::Garbage, _) => {
                warn!(?role, %peer, len = n, "too short to classify");
                continue;
            }
        };
        // NOTE(review): assumes crossfire's `send` only fails once the receiver (owned by
        // the session server) has been dropped. In that case, keep reading only to discard.
        if inbound_tx.send(msg).is_err() {
            warn!(?role, "inbound channel closed; stopping UDP reader");
            break;
        }
    }
}
async fn write_udp(
role: InternalRole,
socket: Arc<UdpSocket>,
rx: crossfire::AsyncRx<ReplyFlavor>,
) {
while let Ok(msg) = crossfire::AsyncRxTrait::recv(&rx).await {
let payload = match role {
InternalRole::Ike500 => msg.data,
InternalRole::Ike4500 => {
let mut out = Vec::with_capacity(4 + msg.data.len());
out.extend_from_slice(&[0, 0, 0, 0]);
out.extend_from_slice(&msg.data);
out
}
};
let _ = socket.send_to(&payload, msg.peer).await;
}
}
/// Per-client outbound ESP loop: encrypts each IP packet from `rx` with the SA
/// parameters and sends it to `peer` on `socket`.
///
/// Empty packets, and packets whose version nibble is neither 4 nor 6, are skipped. The
/// loop ends when:
/// - `rx` closes,
/// - a send fails, or
/// - the 32-bit ESP sequence-number space is used up.
// The SA fields arrive individually from `EstablishedSession`, hence the long list.
#[allow(clippy::too_many_arguments)]
async fn run_outbound(
    socket: Arc<UdpSocket>,
    peer: SocketAddr,
    suite: crate::crypto::Suite,
    spi: u32,
    key: Vec<u8>,
    salt: Vec<u8>,
    integ: Vec<u8>,
    rx: crossfire::AsyncRx<OutboundFlavor>,
) {
    // The first packet sent on an SA uses sequence number 1 (RFC 4303 §3.3.3).
    let mut seq: u32 = 1;
    while let Ok(ip_packet) = crossfire::AsyncRxTrait::recv(&rx).await {
        if ip_packet.is_empty() {
            continue;
        }
        // ESP Next Header from the IP version nibble: 4 = IPv4, 41 = IPv6.
        let next_header: u8 = match ip_packet[0] >> 4 {
            4 => 4,
            6 => 41,
            _ => continue,
        };
        let esp_pkt = esp::encrypt(
            suite,
            &key,
            &salt,
            &integ,
            spi,
            seq,
            &ip_packet,
            next_header,
        );
        if let Err(e) = socket.send_to(&esp_pkt, peer).await {
            warn!(%peer, "outbound ESP failed: {e}");
            break;
        }
        // RFC 4303 §3.3.3: the sender MUST NOT let the counter cycle on an SA.
        // - Wrapping to 0 would reuse sequence numbers the peer's anti-replay window rejects.
        // - It would also reuse a per-packet IV if `esp::encrypt` derives one from `seq`
        //   (not visible here).
        seq = match seq.checked_add(1) {
            Some(next) => next,
            None => {
                warn!(
                    %peer,
                    spi = format_args!("{:08x}", spi),
                    "ESP sequence number space exhausted; SA must be rekeyed"
                );
                break;
            }
        };
    }
}
/// What `classify` decided a received datagram is.
#[derive(Debug)]
enum DgKind {
    /// An IKE message; on port 4500 the zero marker has already been stripped.
    Ike,
    /// A UDP-encapsulated ESP packet, forwarded whole.
    Esp,
    /// Too short to classify on port 4500; dropped by the reader.
    Garbage,
}
/// Decides what a received datagram is and which bytes to hand onward.
///
/// Port-500 traffic is always IKE and is returned unchanged. On port 4500:
/// - A datagram starting with four zero bytes is IKE, returned without that marker.
/// - Any other datagram of at least four bytes is ESP, returned whole.
/// - Anything shorter is `Garbage`.
fn classify(role: InternalRole, datagram: &[u8]) -> (DgKind, &[u8]) {
    match role {
        InternalRole::Ike500 => (DgKind::Ike, datagram),
        InternalRole::Ike4500 => match datagram {
            [0, 0, 0, 0, ike @ ..] => (DgKind::Ike, ike),
            [_, _, _, _, ..] => (DgKind::Esp, datagram),
            _ => (DgKind::Garbage, datagram),
        },
    }
}