use std::{net, sync::Arc};
use ana_gotatun::packet::{Packet, PacketBufPool};
use scion_proto::{
address::{Isd, IsdAsn, ScionAddr, SocketAddr},
wire_encoding::WireEncodeVec,
};
use scion_sdk_reqwest_connect_rpc::token_source::TokenSource;
use snap_tun::client::{PACKET_BUF_POOL_SIZE, SnapTunEndpoint};
use socket2::{Domain, Protocol, Socket, Type};
use tokio::net::UdpSocket;
use url::Url;
use x25519_dalek::StaticSecret;
use crate::{
scionstack::{
AsyncUdpUnderlaySocket, DynUnderlayStack, InvalidBindAddressError, ScionSocketBindError,
SnapConnectionError, UnderlaySocket, builder::PreferredUnderlay, scmp_handler::ScmpHandler,
},
underlays::{
discovery::{UnderlayDiscovery, UnderlayInfo},
udp::{LocalIpResolver, UdpAsyncUdpUnderlaySocket, UdpUnderlaySocket},
},
};
pub mod discovery;
pub mod snap;
pub mod udp;
/// Configuration for SNAP underlay sockets.
pub struct SnapSocketConfig {
    /// Token source handed to SNAP sockets and the SNAP tunnel endpoint.
    ///
    /// When `None`, `UnderlayStack::new` creates no SNAP tunnel endpoint and
    /// binding a SNAP socket fails with `SnapTokenSourceMissing`.
    pub snap_token_source: Option<Arc<dyn TokenSource>>,
}
/// Underlay stack that binds sockets over either a SNAP tunnel or plain UDP,
/// choosing the underlay for a requested ISD-AS via an [`UnderlayDiscovery`].
pub struct UnderlayStack {
    /// Underlay kind that is picked first when several are discovered.
    preferred_underlay: PreferredUnderlay,
    /// Source of the available underlays and of the known local ISD-ASes.
    underlay_discovery: Arc<dyn UnderlayDiscovery>,
    /// Provides a local IP when a UDP socket is bound without an explicit address.
    local_ip_resolver: Arc<dyn LocalIpResolver>,
    /// SNAP configuration supplied at construction time.
    snap_socket_config: SnapSocketConfig,
    /// SNAP tunnel endpoint; `None` when no SNAP token source was configured.
    snap_tunnel_manager: Option<SnapTunEndpoint>,
    /// Packet buffer pool; a clone of it is passed to every SNAP socket.
    pool: PacketBufPool<PACKET_BUF_POOL_SIZE>,
}
impl UnderlayStack {
pub fn new(
preferred_underlay: PreferredUnderlay,
underlay_discovery: Arc<dyn UnderlayDiscovery>,
local_ip_resolver: Arc<dyn LocalIpResolver>,
static_identity: StaticSecret,
default_snap_socket_config: SnapSocketConfig,
) -> Self {
let snap_tunnel_manager = default_snap_socket_config
.snap_token_source
.as_ref()
.map(|token_source| SnapTunEndpoint::new(token_source.clone(), static_identity));
Self {
preferred_underlay,
underlay_discovery,
local_ip_resolver,
snap_socket_config: default_snap_socket_config,
snap_tunnel_manager,
pool: PacketBufPool::new(64),
}
}
fn select_underlay(&self, requested_isd_as: IsdAsn) -> Option<(IsdAsn, UnderlayInfo)> {
let underlays = self.underlay_discovery.underlays(requested_isd_as);
match self.preferred_underlay {
PreferredUnderlay::Snap => {
if let Some(underlay) = underlays
.iter()
.find(|(_, underlay)| matches!(underlay, UnderlayInfo::Snap(_)))
{
return Some(underlay.clone());
}
}
PreferredUnderlay::Udp => {
if let Some(underlay) = underlays
.iter()
.find(|(_, underlay)| matches!(underlay, UnderlayInfo::Udp(_)))
{
return Some(underlay.clone());
}
}
}
underlays.into_iter().next()
}
async fn bind_snap_socket(
&self,
requested_addr: Option<scion_proto::address::SocketAddr>,
isd_as: IsdAsn,
cp_url: Url,
) -> Result<snap::SnapUnderlaySocket, ScionSocketBindError> {
let (Some(token_source), Some(snap_tunnel_manager)) = (
self.snap_socket_config.snap_token_source.as_ref(),
self.snap_tunnel_manager.as_ref(),
) else {
return Err(ScionSocketBindError::SnapConnectionError(
SnapConnectionError::SnapTokenSourceMissing,
))?;
};
let local_addr = match requested_addr {
Some(addr) => {
addr.local_address()
.ok_or(ScionSocketBindError::InvalidBindAddress(
InvalidBindAddressError::ServiceAddress(addr),
))?
}
None => {
if let Some(cp_addr) = cp_url
.socket_addrs(|| None)
.ok()
.and_then(|addrs| addrs.first().cloned())
&& let Some(ip) = source_ip_towards(cp_addr).await
{
Ok(net::SocketAddr::new(ip, 0))
} else {
Err(ScionSocketBindError::InvalidBindAddress(
InvalidBindAddressError::NoLocalIpAddressFound,
))
}?
}
};
let bind_addr = SocketAddr::from_std(isd_as, local_addr);
let udp_socket = bind_udp_underlay_socket(local_addr)?;
let socket = snap::SnapUnderlaySocket::new(
bind_addr,
cp_url,
udp_socket,
snap_tunnel_manager,
token_source.clone(),
1024,
self.pool.clone(),
)
.await?;
let assigned_addr = socket.local_addr();
if let Some(requested_addr) = requested_addr
&& requested_addr.isd_asn().matches(assigned_addr.isd_asn())
&& let Some(requested_socket_addr) = requested_addr.local_address()
&& let Some(assigned_socket_addr) = assigned_addr.local_address()
&& ((!requested_socket_addr.ip().is_unspecified() && assigned_socket_addr.ip() != requested_socket_addr.ip())
|| (requested_socket_addr.port() != 0 && assigned_socket_addr.port() != requested_socket_addr.port()))
{
return Err(crate::scionstack::ScionSocketBindError::InvalidBindAddress(
crate::scionstack::InvalidBindAddressError::AddressMismatch {
assigned_addr: SocketAddr::from_std(bind_addr.isd_asn(), requested_socket_addr),
bind_addr,
},
));
}
Ok(socket)
}
async fn resolve_udp_bind_addr(
&self,
isd_as: IsdAsn,
bind_addr: Option<SocketAddr>,
) -> Result<SocketAddr, ScionSocketBindError> {
let bind_addr = match bind_addr {
Some(addr) => {
if addr.is_service() {
return Err(ScionSocketBindError::InvalidBindAddress(
InvalidBindAddressError::ServiceAddress(addr),
));
}
addr
}
None => {
let local_address = *self.local_ip_resolver.local_ips().await.first().ok_or(
ScionSocketBindError::InvalidBindAddress(
InvalidBindAddressError::NoLocalIpAddressFound,
),
)?;
SocketAddr::new(ScionAddr::new(isd_as, local_address.into()), 0)
}
};
Ok(bind_addr)
}
async fn bind_udp_socket(
&self,
isd_as: IsdAsn,
bind_addr: Option<SocketAddr>,
) -> Result<(SocketAddr, UdpSocket), ScionSocketBindError> {
let bind_addr = self.resolve_udp_bind_addr(isd_as, bind_addr).await?;
let local_addr: net::SocketAddr =
bind_addr
.local_address()
.ok_or(ScionSocketBindError::InvalidBindAddress(
InvalidBindAddressError::ServiceAddress(bind_addr),
))?;
let socket = bind_udp_underlay_socket(local_addr)?;
let local_addr = socket.local_addr().map_err(|e| {
ScionSocketBindError::Other(
anyhow::anyhow!("failed to get local address: {e}").into_boxed_dyn_error(),
)
})?;
let bind_addr = SocketAddr::new(
ScionAddr::new(bind_addr.isd_asn(), local_addr.ip().into()),
local_addr.port(),
);
Ok((bind_addr, socket))
}
}
impl DynUnderlayStack for UnderlayStack {
fn bind_socket(
&self,
_kind: crate::scionstack::SocketKind,
bind_addr: Option<scion_proto::address::SocketAddr>,
) -> futures::future::BoxFuture<
'_,
Result<Box<dyn crate::scionstack::UnderlaySocket>, crate::scionstack::ScionSocketBindError>,
> {
Box::pin(async move {
let requested_isd_as = bind_addr
.map(|addr| addr.isd_asn())
.unwrap_or(IsdAsn::WILDCARD);
match self.select_underlay(requested_isd_as) {
Some((isd_as, UnderlayInfo::Snap(cp_url))) => {
Ok(
Box::new(self.bind_snap_socket(bind_addr, isd_as, cp_url).await?)
as Box<dyn UnderlaySocket>,
)
}
Some((isd_as, UnderlayInfo::Udp(_))) => {
let (bind_addr, socket) = self.bind_udp_socket(isd_as, bind_addr).await?;
Ok(Box::new(UdpUnderlaySocket::new(
socket,
bind_addr,
self.underlay_discovery.clone(),
)) as Box<dyn UnderlaySocket>)
}
None => {
Err(
crate::scionstack::ScionSocketBindError::NoUnderlayAvailable(
requested_isd_as.isd(),
),
)
}
}
})
}
fn bind_async_udp_socket(
&self,
bind_addr: Option<scion_proto::address::SocketAddr>,
scmp_handlers: Vec<Box<dyn ScmpHandler>>,
) -> futures::future::BoxFuture<
'_,
Result<
std::sync::Arc<dyn crate::scionstack::AsyncUdpUnderlaySocket>,
crate::scionstack::ScionSocketBindError,
>,
> {
Box::pin(async move {
match self.select_underlay(
bind_addr
.map(|addr| addr.isd_asn())
.unwrap_or(IsdAsn::WILDCARD),
) {
Some((isd_as, UnderlayInfo::Snap(cp_url))) => {
let socket = self.bind_snap_socket(bind_addr, isd_as, cp_url).await?;
let async_udp_socket = snap::SnapAsyncUdpSocket::new(socket, scmp_handlers);
Ok(Arc::new(async_udp_socket) as Arc<dyn AsyncUdpUnderlaySocket + 'static>)
}
Some((isd_as, UnderlayInfo::Udp(_))) => {
let (bind_addr, socket) = self.bind_udp_socket(isd_as, bind_addr).await?;
let async_udp_socket = UdpAsyncUdpUnderlaySocket::new(
bind_addr,
self.underlay_discovery.clone(),
socket,
scmp_handlers,
);
Ok(Arc::new(async_udp_socket) as Arc<dyn AsyncUdpUnderlaySocket + 'static>)
}
None => {
Err(
crate::scionstack::ScionSocketBindError::NoUnderlayAvailable(
bind_addr
.map(|addr| addr.isd_asn().isd())
.unwrap_or(Isd::WILDCARD),
),
)
}
}
})
}
fn local_ases(&self) -> Vec<IsdAsn> {
let mut isd_ases: Vec<IsdAsn> = self.underlay_discovery.isd_ases().into_iter().collect();
isd_ases.sort();
isd_ases
}
}
/// Enables or disables `SO_EXCLUSIVEADDRUSE` on `sock` (Windows only).
///
/// # Errors
/// Returns the last OS error when `setsockopt` reports failure.
#[cfg(windows)]
fn set_exclusive_addr_use(sock: &Socket, enable: bool) -> std::io::Result<()> {
    use std::{mem, os::windows::io::AsRawSocket};
    use windows_sys::Win32::Networking::WinSock;

    let option_value: u32 = u32::from(enable);
    // SAFETY: `sock` is borrowed for the whole call, so its raw handle stays valid.
    // The option pointer and length both describe `option_value`, a local that
    // outlives the call.
    let status = unsafe {
        WinSock::setsockopt(
            sock.as_raw_socket() as usize,
            WinSock::SOL_SOCKET,
            WinSock::SO_EXCLUSIVEADDRUSE,
            (&option_value as *const u32).cast(),
            mem::size_of_val(&option_value) as _,
        )
    };
    match status {
        0 => Ok(()),
        _ => Err(std::io::Error::last_os_error()),
    }
}
fn bind_udp_underlay_socket(
addr: net::SocketAddr,
) -> Result<tokio::net::UdpSocket, ScionSocketBindError> {
let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))
.map_err(|e| ScionSocketBindError::Other(Box::new(e)))?;
socket
.set_nonblocking(true)
.map_err(|e| ScionSocketBindError::Other(Box::new(e)))?;
if addr.is_ipv6()
&& let Err(e) = socket.set_only_v6(false)
{
tracing::debug!(%e, "unable to make socket dual-stack");
}
#[cfg(windows)]
set_exclusive_addr_use(&socket, true).map_err(|e| ScionSocketBindError::Other(Box::new(e)))?;
socket.bind(&addr.into()).map_err(|e| {
match e.kind() {
std::io::ErrorKind::AddrInUse => ScionSocketBindError::PortAlreadyInUse(addr.port()),
std::io::ErrorKind::AddrNotAvailable | std::io::ErrorKind::InvalidInput => {
ScionSocketBindError::InvalidBindAddress(
InvalidBindAddressError::CannotBindToRequestedAddress(
SocketAddr::from_std(IsdAsn::WILDCARD, addr),
format!("Failed to bind socket: {e:#}").into(),
),
)
}
#[cfg(windows)]
std::io::ErrorKind::PermissionDenied => {
ScionSocketBindError::PortAlreadyInUse(addr.port())
}
_ => ScionSocketBindError::Other(Box::new(e)),
}
})?;
tokio::net::UdpSocket::from_std(std::net::UdpSocket::from(socket))
.map_err(|e| ScionSocketBindError::Other(Box::new(e)))
}
/// Wire-encodes `packet` into `target_buf`.
///
/// `temp_buf` is cleared and passed as scratch space to `encode_with`. The parts
/// it returns are copied back-to-back into `target_buf`, which is then truncated
/// to the total encoded length.
///
/// # Errors
/// Propagates the error returned by `encode_with`.
///
/// # Panics
/// Panics if `target_buf`'s mutable slice is shorter than the total encoded
/// length, because the parts are copied in by slice indexing.
#[inline]
pub(crate) fn wire_encode<W, const N: usize>(
    packet: &W,
    temp_buf: &mut Packet,
    target_buf: &mut Packet,
) -> Result<(), W::Error>
where
    W: WireEncodeVec<N>,
{
    temp_buf.truncate(0);
    let parts = packet.encode_with(temp_buf.buf_mut())?;

    let mut offset = 0;
    for part in parts.iter() {
        let end = offset + part.len();
        target_buf.as_mut()[offset..end].copy_from_slice(part);
        offset = end;
    }

    target_buf.truncate(offset);
    Ok(())
}
pub(crate) async fn source_ip_towards(dst: net::SocketAddr) -> Option<net::IpAddr> {
let bind_addr = match dst.ip() {
net::IpAddr::V4(_) => net::Ipv4Addr::UNSPECIFIED.into(),
net::IpAddr::V6(_) => net::Ipv6Addr::UNSPECIFIED.into(),
};
if let Ok(socket) = tokio::net::UdpSocket::bind(net::SocketAddr::new(bind_addr, 0)).await
&& socket.connect(dst).await.is_ok()
&& let Ok(addr) = socket.local_addr()
{
return Some(addr.ip());
}
None
}