use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use ya_smoltcp::iface::{Route, SocketHandle};
use ya_smoltcp::socket::*;
use ya_smoltcp::time::Instant;
use ya_smoltcp::wire::{IpAddress, IpCidr, IpEndpoint, IpProtocol, IpVersion};
use crate::connection::{Connect, Connection, ConnectionMeta, Disconnect, Send};
use crate::interface::*;
use crate::metrics::ChannelMetrics;
use crate::patch_smoltcp::GetSocketSafe;
use crate::protocol::Protocol;
use crate::socket::*;
use crate::{port, StackConfig};
use crate::{Error, Result};
use ya_relay_util::Payload;
/// Handle to a network stack built on top of a [`CaptureInterface`].
///
/// Every field is reference-counted, so cloning a `Stack` yields another
/// handle to the same interface, metrics table, port allocator and
/// configuration rather than an independent copy.
#[derive(Clone)]
pub struct Stack<'a> {
    /// Shared, interior-mutable network interface that owns the sockets.
    iface: Rc<RefCell<CaptureInterface<'a>>>,
    /// Per-socket traffic metrics, keyed by socket descriptor.
    metrics: Rc<RefCell<HashMap<SocketDesc, ChannelMetrics>>>,
    /// Allocator used to reserve and release local ports.
    ports: Rc<RefCell<port::Allocator>>,
    /// Stack configuration (socket buffer sizes are read from here).
    config: Rc<StackConfig>,
}
impl<'a> Stack<'a> {
pub fn new(iface: CaptureInterface<'a>, config: Rc<StackConfig>) -> Self {
Self {
iface: Rc::new(RefCell::new(iface)),
metrics: Default::default(),
ports: Default::default(),
config,
}
}
pub fn address(&self) -> Result<IpCidr> {
{
let iface = self.iface.borrow();
iface.ip_addrs().iter().next().cloned()
}
.ok_or(Error::NetEmpty)
}
pub fn addresses(&self) -> Vec<IpCidr> {
self.iface.borrow().ip_addrs().to_vec()
}
pub fn add_address(&self, address: IpCidr) {
let mut iface = self.iface.borrow_mut();
add_iface_address(&mut iface, address);
}
pub fn add_route(&self, net_ip: IpCidr, route: Route) {
let mut iface = self.iface.borrow_mut();
add_iface_route(&mut iface, net_ip, route);
}
#[inline]
pub(crate) fn iface(&self) -> Rc<RefCell<CaptureInterface<'a>>> {
self.iface.clone()
}
#[inline]
pub(crate) fn metrics(&self) -> Rc<RefCell<HashMap<SocketDesc, ChannelMetrics>>> {
self.metrics.clone()
}
pub(crate) fn on_sent(&self, desc: &SocketDesc, size: usize) {
let mut metrics = self.metrics.borrow_mut();
metrics.entry(*desc).or_default().tx.push(size as f32);
}
pub(crate) fn on_received(&self, desc: &SocketDesc, size: usize) {
let mut metrics = self.metrics.borrow_mut();
metrics.entry(*desc).or_default().rx.push(size as f32);
}
}
impl<'a> Stack<'a> {
pub fn bind(
&self,
protocol: Protocol,
endpoint: impl Into<SocketEndpoint>,
) -> Result<SocketHandle> {
let endpoint = endpoint.into();
let mut iface = self.iface.borrow_mut();
let handle = match protocol {
Protocol::Tcp => {
if let SocketEndpoint::Ip(ep) = endpoint {
let mut socket = tcp_socket(self.config.tcp_mem.rx, self.config.tcp_mem.tx);
socket.listen(ep).map_err(|e| Error::Other(e.to_string()))?;
socket.set_defaults();
iface.add_socket(socket)
} else {
return Err(Error::Other("Expected an IP endpoint".to_string()));
}
}
Protocol::Udp => {
if let SocketEndpoint::Ip(ep) = endpoint {
let mut socket = udp_socket(self.config.udp_mem.rx, self.config.udp_mem.tx);
socket.bind(ep).map_err(|e| Error::Other(e.to_string()))?;
iface.add_socket(socket)
} else {
return Err(Error::Other("Expected an IP endpoint".to_string()));
}
}
Protocol::Icmp | Protocol::Ipv6Icmp => {
if let SocketEndpoint::Icmp(e) = endpoint {
let mut socket = icmp_socket(self.config.icmp_mem.rx, self.config.icmp_mem.tx);
socket.bind(e).map_err(|e| Error::Other(e.to_string()))?;
iface.add_socket(socket)
} else {
return Err(Error::Other("Expected an ICMP endpoint".to_string()));
}
}
_ => {
let ip_version = {
match endpoint {
SocketEndpoint::Ip(ep) => match ep.addr {
IpAddress::Ipv4(_) => IpVersion::Ipv4,
IpAddress::Ipv6(_) => IpVersion::Ipv6,
_ => return Err(Error::Other(format!("Invalid address: {}", ep.addr))),
},
_ => return Err(Error::Other("Expected an IP endpoint".to_string())),
}
};
let socket = raw_socket(
ip_version,
map_protocol(protocol)?,
self.config.raw_mem.rx,
self.config.raw_mem.tx,
);
iface.add_socket(socket)
}
};
Ok(handle)
}
/// Removes the first socket whose local endpoint equals `endpoint` and
/// returns its handle.
///
/// Only TCP, UDP, ICMP and ICMPv6 are accepted; any other protocol is
/// reported as [`Error::SocketClosed`], as is the case where no socket
/// matches. When the endpoint converts to an IP endpoint, its port is also
/// released in the port allocator.
///
/// NOTE(review): the lookup compares local endpoints only — `protocol` is
/// not checked against the socket that was found. Confirm that two sockets
/// of different protocols can never share a local endpoint.
pub fn unbind(
    &self,
    protocol: Protocol,
    endpoint: impl Into<SocketEndpoint>,
) -> Result<SocketHandle> {
    let endpoint = endpoint.into();
    let mut iface = self.iface.borrow_mut();

    // The protocol gate does not depend on which socket is found.
    let unbindable = matches!(
        protocol,
        Protocol::Tcp | Protocol::Udp | Protocol::Icmp | Protocol::Ipv6Icmp
    );

    let mut sockets = iface.sockets_mut();
    let handle = sockets
        .find(|(_, sock)| sock.local_endpoint() == endpoint)
        .map(|(found, _)| found)
        .filter(|_| unbindable)
        .ok_or(Error::SocketClosed)?;
    // Release the iterator's borrow of the interface before mutating it.
    drop(sockets);

    let _ = endpoint.ip_endpoint().map(|ip_ep| {
        log::trace!("unbinding {} ({})", ip_ep, protocol);
        self.ports.borrow_mut().free(protocol, ip_ep.port);
    });

    iface.remove_socket(handle);
    Ok(handle)
}
/// Opens a TCP connection from the stack's first address to `remote`.
///
/// A local port is reserved from the port allocator, a TCP socket is added
/// to the interface and `connect` is issued on it. The returned
/// [`Connect`] wraps the new connection together with a shared handle to
/// the interface.
///
/// # Errors
///
/// * [`Error::NetEmpty`] when the interface has no address.
/// * The port allocator's error when no port can be reserved.
/// * [`Error::ConnectionError`] when the socket refuses to connect; the
///   socket and the reserved port are released before returning.
pub fn connect(&self, remote: IpEndpoint) -> Result<Connect<'a>> {
    let ip = self.address()?.address();
    let mut iface = self.iface.borrow_mut();
    let mut ports = self.ports.borrow_mut();
    let protocol = Protocol::Tcp;

    // Reserve the port before creating the socket: if the allocation fails
    // we return early, and nothing must be left registered on the interface.
    let port = ports.next(protocol)?;
    let local: IpEndpoint = (ip, port).into();
    let handle = iface.add_socket(tcp_socket(self.config.tcp_mem.rx, self.config.tcp_mem.tx));

    log::trace!("connecting to {} ({})", remote, protocol);

    match {
        let (socket, ctx) = iface.get_socket_and_context::<TcpSocket>(handle);
        socket.connect(ctx, remote, local).map(|_| socket)
    } {
        Ok(socket) => socket.set_defaults(),
        Err(e) => {
            // Roll back both resources acquired above.
            iface.remove_socket(handle);
            ports.free(protocol, port);
            return Err(Error::ConnectionError(e.to_string()));
        }
    }

    let meta = ConnectionMeta {
        protocol,
        local,
        remote,
    };
    Ok(Connect::new(
        Connection { handle, meta },
        self.iface.clone(),
    ))
}
pub fn disconnect(&self, handle: SocketHandle) -> Disconnect<'a> {
let mut iface = self.iface.borrow_mut();
if let Ok(sock) = iface.get_socket_safe::<TcpSocket>(handle) {
sock.close();
}
Disconnect::new(handle, self.iface.clone())
}
/// Calls `abort` on the TCP socket behind `handle`; does nothing when the
/// socket lookup fails.
pub(crate) fn abort(&self, handle: SocketHandle) {
    if let Ok(tcp) = self
        .iface
        .borrow_mut()
        .get_socket_safe::<TcpSocket>(handle)
    {
        tcp.abort();
    }
}
/// Tears down the connection behind `handle`: drops its metrics entry,
/// removes the socket from the interface and frees the local port recorded
/// in `meta`.
///
/// Does nothing — the port is not freed either — when the interface holds
/// no socket with this handle.
pub(crate) fn remove(&self, meta: &ConnectionMeta, handle: SocketHandle) {
    let mut iface = self.iface.borrow_mut();
    let mut metrics = self.metrics.borrow_mut();
    let mut ports = self.ports.borrow_mut();

    // Capture the descriptor up front so that no borrow of the socket
    // outlives the lookup and the interface can be mutated afterwards.
    let desc = match iface.sockets().find(|(h, _)| *h == handle) {
        Some((_, socket)) => socket.desc(),
        None => return,
    };

    log::trace!(
        "removing connection: {}:{}:{}",
        meta.protocol,
        meta.local,
        meta.remote,
    );

    metrics.remove(&desc);
    iface.remove_socket(handle);
    ports.free(meta.protocol, meta.local.port);
}
#[inline]
pub fn send<B: Into<Payload>, F: Fn() + 'static>(
&self,
data: B,
conn: Connection,
f: F,
) -> Send<'a> {
Send::new(data.into(), conn, self.iface.clone(), f)
}
/// Passes `data` to the interface device's `phy_rx`.
#[inline]
pub fn receive<B: Into<Payload>>(&self, data: B) {
    self.iface
        .borrow_mut()
        .device_mut()
        .phy_rx(data.into());
}
#[inline]
pub fn poll(&self) -> Result<bool> {
match {
let mut iface = self.iface.borrow_mut();
iface.poll(Instant::now())
} {
Err(ya_smoltcp::Error::Unrecognized) => self.poll(),
Err(err) => Err(Error::Other(err.to_string())),
Ok(val) => Ok(val),
}
}
}
fn map_protocol(protocol: Protocol) -> Result<IpProtocol> {
match protocol {
Protocol::HopByHop => Ok(IpProtocol::HopByHop),
Protocol::Icmp => Ok(IpProtocol::Icmp),
Protocol::Igmp => Ok(IpProtocol::Igmp),
Protocol::Tcp => Ok(IpProtocol::Tcp),
Protocol::Udp => Ok(IpProtocol::Udp),
Protocol::Ipv6Route => Ok(IpProtocol::Ipv6Route),
Protocol::Ipv6Frag => Ok(IpProtocol::Ipv6Frag),
Protocol::Ipv6Icmp => Ok(IpProtocol::Icmpv6),
Protocol::Ipv6NoNxt => Ok(IpProtocol::Ipv6NoNxt),
Protocol::Ipv6Opts => Ok(IpProtocol::Ipv6Opts),
_ => Err(Error::ProtocolNotSupported(protocol.to_string())),
}
}