use crate::{
environment::net::{
ip::{header, transport, Category, Header, Packet, Segments},
monitor::List as Monitors,
socket::{self, RecvOptions, RecvResult, SendOptions},
},
ext::*,
net::{
monitor::{SocketRead, SocketWrite},
SocketAddr,
},
queue::Pushable,
sync::channel,
};
use core::task::{Context, Poll};
use std::{io, sync::Mutex};
/// A UDP-style datagram socket backed by in-process channels.
///
/// The send and receive halves are each guarded by their own `Mutex`, so the
/// `&self` methods of the `socket::Socket` implementation can mutate them.
pub struct Socket {
/// Send half; locked by `sendmsg` for the duration of each call.
sender: Mutex<Sender>,
/// Receive half; locked by `recvmsg` for the duration of each call.
receiver: Mutex<Receiver>,
/// Address reported by `local_addr` and used as the source of outgoing packets.
local_addr: SocketAddr,
/// Peer recorded by `poll_connect`; `None` until the socket has been connected.
peer_addr: Mutex<Option<SocketAddr>>,
/// Monitors handed to the send and receive halves on every message.
monitors: Monitors,
}
/// Locks a `Mutex` and evaluates to its guard.
///
/// If the lock is poisoned, the poison error's message is wrapped in an
/// `io::Error` of kind `Other` and propagated with `?`, so the macro may only
/// be used in functions whose return type accepts an `io::Error` through `?`.
macro_rules! lock {
    ($lock:expr) => {
        $lock
            .lock()
            .map_err(|poisoned| io::Error::new(io::ErrorKind::Other, poisoned.to_string()))?
    };
}
impl Socket {
    /// Builds an unconnected socket bound to `local_addr`.
    ///
    /// `sender` carries outgoing segments and `receiver` yields incoming
    /// packets; each is wrapped in its own mutex-guarded half. `monitors` is
    /// stored and consulted on every send and receive.
    pub fn new(
        sender: channel::Sender<Segments>,
        receiver: channel::Receiver<Packet>,
        local_addr: SocketAddr,
        monitors: Monitors,
    ) -> Self {
        Self {
            sender: Mutex::new(Sender::new(sender)),
            receiver: Mutex::new(Receiver::new(receiver)),
            local_addr,
            // No peer until `poll_connect` is called.
            peer_addr: Mutex::new(None),
            monitors,
        }
    }
}
/// Builds the `Unsupported` error returned by operations this socket does not implement.
fn unsupported(message: &'static str) -> io::Error {
    io::Error::new(io::ErrorKind::Unsupported, message)
}

impl socket::Socket for Socket {
    /// Records `peer_addr` as the connected peer and completes immediately.
    fn poll_connect(&self, _cx: &mut Context, peer_addr: SocketAddr) -> Poll<io::Result<()>> {
        let mut connected = lock!(self.peer_addr);
        *connected = Some(peer_addr);
        Poll::Ready(Ok(()))
    }

    /// Returns the connected peer, or a `NotConnected` error if none was set.
    fn peer_addr(&self) -> io::Result<SocketAddr> {
        let connected = *lock!(self.peer_addr);
        connected.ok_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Socket not connected"))
    }

    /// Returns the address this socket was created with.
    fn local_addr(&self) -> io::Result<SocketAddr> {
        Ok(self.local_addr)
    }

    /// Not implemented: always resolves to an `Unsupported` error.
    fn poll_writable(&self, _cx: &mut Context) -> Poll<io::Result<()>> {
        Poll::Ready(Err(unsupported("poll_writable isn't implemented")))
    }

    /// Not implemented: always resolves to an `Unsupported` error.
    fn poll_readable(&self, _cx: &mut Context) -> Poll<io::Result<()>> {
        Poll::Ready(Err(unsupported("poll_readable isn't implemented")))
    }

    /// Not implemented: always returns an `Unsupported` error.
    fn set_sockopt(&self, _option: super::Sockopt, _value: usize) -> io::Result<()> {
        Err(unsupported("set_sockopt isn't implemented"))
    }

    /// Not implemented: always returns an `Unsupported` error.
    fn get_sockopt(&self, _option: super::Sockopt) -> io::Result<usize> {
        Err(unsupported("get_sockopt isn't implemented"))
    }

    /// Sends `payload` through the send half, passing along the currently
    /// connected peer (if any) so an unspecified `destination` can fall back to it.
    fn sendmsg(
        &self,
        cx: Option<&mut Context>,
        destination: &SocketAddr,
        payload: &[io::IoSlice],
        opts: SendOptions,
    ) -> io::Result<usize> {
        // Copy the peer out first so the peer lock is released before sending.
        let connected = *lock!(self.peer_addr);
        let mut sender = lock!(self.sender);
        sender.sendmsg(
            cx,
            &self.local_addr,
            connected,
            destination,
            payload,
            opts,
            &self.monitors,
        )
    }

    /// Receives one datagram through the receive half, passing along the
    /// currently connected peer (if any) for source filtering.
    fn recvmsg(
        &self,
        cx: Option<&mut Context>,
        payload: &mut [io::IoSliceMut],
        opts: RecvOptions,
    ) -> io::Result<RecvResult> {
        // Copy the peer out first so the peer lock is released before receiving.
        let connected = *lock!(self.peer_addr);
        let mut receiver = lock!(self.receiver);
        receiver.recvmsg(cx, connected, payload, opts, &self.monitors)
    }

    /// Accepts any shutdown request and does nothing.
    fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
        let _ = how;
        Ok(())
    }
}
/// Send half of the socket: builds IP/UDP packets and pushes them onto a channel.
struct Sender {
/// Channel that outgoing segments are pushed onto.
channel: channel::Sender<Segments>,
/// Value written to the `id` field of IPv4 headers; advanced with wrapping
/// arithmetic by `sendmsg`.
id: u16,
/// Written to the IPv4 `ttl` field or the IPv6 `hop_limit` field.
ttl: u8,
}
impl Sender {
fn new(channel: channel::Sender<Segments>) -> Self {
Self {
channel,
id: 0,
ttl: 64,
}
}
fn sendmsg(
&mut self,
cx: Option<&mut Context>,
local_addr: &SocketAddr,
peer_addr: Option<SocketAddr>,
destination: &SocketAddr,
payload: &[io::IoSlice],
opts: super::SendOptions,
monitors: &Monitors,
) -> io::Result<usize> {
let destination = if destination.is_unspecified() {
peer_addr.as_ref().ok_or_else(|| {
io::Error::new(io::ErrorKind::NotConnected, "Socket not connected")
})?
} else {
destination
};
let id = self.id;
self.id = self.id.wrapping_add(1);
let ttl = self.ttl;
if opts.source.is_some() {
todo!()
}
let header: Header = match (local_addr, destination) {
(SocketAddr::V4(src), SocketAddr::V4(dst)) => header::V4 {
source: *src.ip(),
destination: *dst.ip(),
dscp: 0,
ecn: opts.ecn,
df: true,
id,
ttl,
}
.into(),
(SocketAddr::V6(src), SocketAddr::V4(dst)) => header::V6 {
source: *src.ip(),
destination: dst.ip().to_ipv6_mapped(),
dscp: 0,
ecn: opts.ecn,
flow_label: 0,
hop_limit: ttl,
}
.into(),
(SocketAddr::V6(src), SocketAddr::V6(dst)) => header::V6 {
source: *src.ip(),
destination: *dst.ip(),
dscp: 0,
ecn: opts.ecn,
flow_label: 0,
hop_limit: ttl,
}
.into(),
(SocketAddr::V4(_), SocketAddr::V6(_)) => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"cannot send IPv6 packet on IPv4 socket",
))
}
};
let mut socket_write = SocketWrite {
local_addr,
peer_addr: destination,
transport: transport::Kind::Udp,
payload,
opts: &opts,
};
monitors.on_socket_write(&mut socket_write)?;
let transport = transport::Udp {
source: local_addr.port(),
destination: destination.port(),
payload: Default::default(),
checksum: 0,
};
let mut packet = SendablePacket {
header,
transport,
payload,
len: None,
segment_len: opts.segment_len,
};
if let Some(cx) = cx {
if self.channel.poll_push(cx, &mut packet)?.is_pending() {
return Err(io::ErrorKind::WouldBlock.into());
}
} else {
let mut channel = self.channel.clone();
let packet = packet.produce();
async move {
let _ = channel.push(packet).await;
}
.spawn();
}
Ok(packet.len.unwrap_or(0))
}
}
/// A not-yet-materialized outgoing packet that borrows the caller's payload
/// buffers and is turned into `Segments` by `produce`.
struct SendablePacket<'a> {
/// IP header placed on the produced packet.
header: Header,
/// UDP header template; `produce` clones it and fills in the payload.
transport: transport::Udp,
/// Caller-provided buffers that `produce` concatenates into the payload.
payload: &'a [io::IoSlice<'a>],
/// Total payload length, cached by the first call to `produce`.
len: Option<usize>,
/// Requested segment length; `produce` caps it at the payload length and
/// falls back to the payload length when `None`.
segment_len: Option<usize>,
}
impl Pushable<Segments> for SendablePacket<'_> {
    /// Copies the borrowed payload buffers into an owned packet.
    ///
    /// The total payload length is computed on the first call and cached in
    /// `self.len`. The resulting segment length is the requested one capped at
    /// the payload length, or the payload length itself if none was requested.
    fn produce(&mut self) -> Segments {
        let chunks = self.payload;
        let total = *self
            .len
            .get_or_insert_with(|| chunks.iter().map(|chunk| chunk.len()).sum::<usize>());

        // Flatten the scattered buffers into one contiguous payload.
        let mut bytes: Vec<u8> = Vec::with_capacity(total);
        chunks
            .iter()
            .for_each(|chunk| bytes.extend_from_slice(chunk));

        let mut udp = self.transport.clone();
        udp.payload = bytes.into();

        let segment_len = match self.segment_len {
            Some(requested) => requested.min(total),
            None => total,
        };

        Segments {
            packet: Packet {
                header: self.header,
                transport: udp.into(),
            },
            segment_len,
        }
    }
}
/// Receive half of the socket: pops incoming packets from a channel.
struct Receiver {
/// Channel that `recvmsg` polls for incoming packets.
channel: channel::Receiver<Packet>,
}
impl Receiver {
    /// Creates a receiver that pops incoming packets from `channel`.
    fn new(channel: channel::Receiver<Packet>) -> Self {
        Self { channel }
    }

    /// Pops the next acceptable packet and copies its payload into `payload`.
    ///
    /// When `peer_addr` is set, packets from any other source are counted as
    /// `peer_mismatch` and skipped. The monitors are notified of the read
    /// before the result is returned.
    ///
    /// # Errors
    ///
    /// * `Unsupported` - `opts.peek` is set.
    /// * `WouldBlock` - no context was supplied, or the channel has no packet ready.
    /// * Any error returned by the channel or by the monitors.
    fn recvmsg(
        &mut self,
        cx: Option<&mut Context>,
        peer_addr: Option<SocketAddr>,
        payload: &mut [io::IoSliceMut],
        opts: RecvOptions,
        monitors: &Monitors,
    ) -> io::Result<RecvResult> {
        if opts.peek {
            return Err(io::Error::new(
                io::ErrorKind::Unsupported,
                "peek is not currently implemented",
            ));
        }

        // Receiving is only possible with a task context to poll the channel.
        let Some(cx) = cx else {
            return Err(io::Error::new(
                io::ErrorKind::WouldBlock,
                "recvmsg without context is not currently implemented",
            ));
        };

        loop {
            let packet = match self.channel.poll_pop(&mut *cx)? {
                Poll::Ready(packet) => packet,
                Poll::Pending => return Err(io::ErrorKind::WouldBlock.into()),
            };

            let destination = packet.destination();
            let source = packet.source();

            // A connected socket only accepts packets from its peer.
            if matches!(peer_addr, Some(expected) if source != expected) {
                count!("peer_mismatch");
                continue;
            }

            let (copied_len, truncation_len) = packet.transport.copy_payload_into(payload);

            let mut result = RecvResult {
                peer_addr: source,
                local_addr: destination,
                ecn: packet.header.ecn(),
                len: copied_len,
                segment_len: copied_len,
                truncation_len,
            };

            monitors.on_socket_read(&mut SocketRead {
                result: &mut result,
                payload: &mut *payload,
            })?;

            return Ok(result);
        }
    }
}