use crate::{
error::Error, ffi::get_last_error, ip::NrfSockAddr, lte_link::LteLink, CancellationToken,
};
use core::{
cell::RefCell,
ops::{BitOr, BitOrAssign, Deref, Neg},
sync::atomic::{AtomicU8, Ordering},
task::{Poll, Waker},
};
use critical_section::Mutex;
use no_std_net::SocketAddr;
use num_enum::{IntoPrimitive, TryFromPrimitive};
/// Number of waker registrations that can be tracked at once.
/// NOTE(review): presumably sized so every modem socket can have both an `In` and an `Out`
/// waiter registered simultaneously — confirm.
const WAKER_SLOTS: usize = (nrfxlib_sys::NRF_MODEM_MAX_SOCKET_COUNT * 2) as usize;

/// One registration: the waker to call, the socket fd it waits on and the direction it waits
/// for. `None` marks a free slot.
type WakerSlot = Option<(Waker, i32, SocketDirection)>;

/// Element initializer for the slot array (array-repeat needs a `const` for non-`Copy` items).
const WAKER_INIT: WakerSlot = None;

/// Tasks waiting for socket readiness. Filled by `register_socket_waker` and drained by
/// `wake_sockets`, which `socket_poll_callback` calls.
static SOCKET_WAKERS: Mutex<RefCell<[WakerSlot; WAKER_SLOTS]>> =
    Mutex::new(RefCell::new([WAKER_INIT; WAKER_SLOTS]));
/// Wakes, and removes from the table, every registration for `socket_fd` whose direction
/// overlaps `socket_dir` according to `SocketDirection::same_direction`.
fn wake_sockets(socket_fd: i32, socket_dir: SocketDirection) {
    critical_section::with(|cs| {
        let mut registrations = SOCKET_WAKERS.borrow_ref_mut(cs);
        for slot in registrations.iter_mut() {
            let interested = match slot.as_ref() {
                Some((_, fd, dir)) => *fd == socket_fd && dir.same_direction(socket_dir),
                None => false,
            };
            if !interested {
                continue;
            }
            // Taking the entry frees the slot; the task re-registers if it needs to wait again.
            if let Some((waker, _, _)) = slot.take() {
                waker.wake();
            }
        }
    });
}
/// Registers `waker` to be woken when `socket_fd` becomes ready in `socket_dir`.
///
/// Slot selection, in order of preference:
/// 1. An existing entry for the same fd and direction whose waker wakes the same task
///    (`Waker::will_wake`) is refreshed in place. Re-polling a future therefore never piles up
///    duplicate entries, even when a free slot precedes the existing one.
/// 2. Otherwise the first free slot is used. An entry for the same fd/direction that belongs
///    to a *different* task (possible with split sockets) is deliberately kept, so that task
///    does not lose its wake-up.
/// 3. If the table is full, slot 0 is evicted and its waker is woken so its task re-polls and
///    registers again.
fn register_socket_waker(waker: Waker, socket_fd: i32, socket_dir: SocketDirection) {
    critical_section::with(|cs| {
        let mut wakers = SOCKET_WAKERS.borrow_ref_mut(cs);

        let same_task = wakers.iter_mut().find(|slot| {
            matches!(
                slot,
                Some((existing, fd, dir))
                    if *fd == socket_fd && *dir == socket_dir && existing.will_wake(&waker)
            )
        });
        if let Some(slot) = same_task {
            *slot = Some((waker, socket_fd, socket_dir));
            return;
        }

        if let Some(slot) = wakers.iter_mut().find(|slot| slot.is_none()) {
            *slot = Some((waker, socket_fd, socket_dir));
            return;
        }

        // Table full: evict slot 0. Its task is woken (a spurious wake-up is harmless) so it
        // re-registers instead of hanging.
        if let Some(first) = wakers.first_mut() {
            if let Some((evicted, _, _)) = first.replace((waker, socket_fd, socket_dir)) {
                evicted.wake();
            }
        }
    });
}
/// Poll callback that `Socket::create` installs on every socket via `NRF_SO_POLLCB`.
///
/// Translates `revents` into a [`SocketDirection`] and wakes the matching waiters on
/// `pollfd.fd`. `NRF_POLLERR`, `NRF_POLLHUP` and `NRF_POLLNVAL` map to `Either`, which wakes
/// `In`, `Out` and `Either` waiters alike so they can observe the failure.
///
/// # Safety
/// `pollfd` must point to a valid `nrf_pollfd`.
/// NOTE(review): presumably only ever invoked by the modem library — confirm the calling
/// context (e.g. interrupt vs. thread).
unsafe extern "C" fn socket_poll_callback(pollfd: *mut nrfxlib_sys::nrf_pollfd) {
    let pollfd = *pollfd;
    let revents = pollfd.revents as u32;

    let readable = (revents & nrfxlib_sys::NRF_POLLIN) != 0;
    let writable = (revents & nrfxlib_sys::NRF_POLLOUT) != 0;
    let failed = (revents
        & (nrfxlib_sys::NRF_POLLERR | nrfxlib_sys::NRF_POLLHUP | nrfxlib_sys::NRF_POLLNVAL))
        != 0;

    let direction = match (readable, writable, failed) {
        (_, _, true) | (true, true, _) => SocketDirection::Either,
        (true, false, false) => SocketDirection::In,
        (false, true, false) => SocketDirection::Out,
        (false, false, false) => SocketDirection::Neither,
    };

    #[cfg(feature = "defmt")]
    defmt::trace!(
        "Socket poll callback. fd: {}, revents: {:X}, direction: {}",
        pollfd.fd,
        pollfd.revents,
        direction
    );

    wake_sockets(pollfd.fd, direction);
}
/// Readiness direction: what a waiter waits for, or what a poll event reported.
///
/// Behaves like a two-bit set: `Neither` is empty, `In`/`Out` are single members and
/// `Either` is the union of both.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
enum SocketDirection {
    Neither,
    In,
    Out,
    Either,
}

impl BitOrAssign for SocketDirection {
    fn bitor_assign(&mut self, rhs: Self) {
        let combined = BitOr::bitor(*self, rhs);
        *self = combined;
    }
}

impl BitOr for SocketDirection {
    type Output = SocketDirection;

    /// Set union of two directions.
    fn bitor(self, rhs: Self) -> Self::Output {
        match (self, rhs) {
            (Self::Neither, other) | (other, Self::Neither) => other,
            (a, b) if a == b => a,
            _ => Self::Either,
        }
    }
}

impl SocketDirection {
    /// Returns whether a waiter registered for `self` is interested in an event in `other`.
    ///
    /// `Neither` never overlaps anything; `Either` overlaps every non-empty direction.
    fn same_direction(&self, other: Self) -> bool {
        match (*self, other) {
            (Self::Neither, _) | (_, Self::Neither) => false,
            (Self::Either, _) | (_, Self::Either) => true,
            (a, b) => a == b,
        }
    }
}
/// An nRF modem socket whose operations are driven asynchronously by the modem poll callback.
///
/// Owns the modem descriptor `fd` and closes it on drop, unless `split` is set because a
/// split counterpart still shares the descriptor.
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Socket {
    /// Modem socket descriptor returned by `nrf_socket`.
    fd: i32,
    /// Address family the socket was created with; also written into addresses returned by
    /// `receive_from`.
    family: SocketFamily,
    /// LTE link acquired at creation; only `None` after `deactivate` has taken it.
    link: Option<LteLink>,
    /// `true` while the descriptor is shared with a split counterpart; suppresses `nrf_close`
    /// in `Drop`.
    split: bool,
}
impl Socket {
pub async fn create(
family: SocketFamily,
s_type: SocketType,
protocol: SocketProtocol,
) -> Result<Self, Error> {
#[cfg(feature = "defmt")]
defmt::debug!(
"Creating socket with family: {}, type: {}, protocol: {}",
family as u32 as i32,
s_type as u32 as i32,
protocol as u32 as i32
);
if unsafe { !nrfxlib_sys::nrf_modem_is_initialized() } {
return Err(Error::ModemNotInitialized);
}
let link = LteLink::new().await?;
let fd = unsafe {
nrfxlib_sys::nrf_socket(
family as u32 as i32,
s_type as u32 as i32,
protocol as u32 as i32,
)
};
if fd == -1 {
return Err(Error::NrfError(get_last_error()));
}
unsafe {
let result = nrfxlib_sys::nrf_fcntl(
fd,
nrfxlib_sys::NRF_F_SETFL as _,
nrfxlib_sys::NRF_O_NONBLOCK as _,
);
if result == -1 {
return Err(Error::NrfError(get_last_error()));
}
}
let poll_callback = nrfxlib_sys::nrf_modem_pollcb {
callback: Some(socket_poll_callback),
events: (nrfxlib_sys::NRF_POLLIN | nrfxlib_sys::NRF_POLLOUT) as _, oneshot: false,
};
unsafe {
let result = nrfxlib_sys::nrf_setsockopt(
fd,
nrfxlib_sys::NRF_SOL_SOCKET as _,
nrfxlib_sys::NRF_SO_POLLCB as _,
(&poll_callback as *const nrfxlib_sys::nrf_modem_pollcb).cast(),
core::mem::size_of::<nrfxlib_sys::nrf_modem_pollcb>() as u32,
);
if result == -1 {
return Err(Error::NrfError(get_last_error()));
}
}
Ok(Socket {
fd,
family,
link: Some(link),
split: false,
})
}
pub fn as_raw_fd(&self) -> i32 {
self.fd
}
pub async fn split(mut self) -> Result<(SplitSocketHandle, SplitSocketHandle), Error> {
let index = SplitSocketHandle::get_new_spot();
self.split = true;
Ok((
SplitSocketHandle {
inner: Some(Socket {
fd: self.fd,
family: self.family,
link: Some(LteLink::new().await?),
split: true,
}),
index,
},
SplitSocketHandle {
inner: Some(self),
index,
},
))
}
pub async unsafe fn connect(
&self,
address: SocketAddr,
token: &CancellationToken,
) -> Result<(), Error> {
#[cfg(feature = "defmt")]
defmt::debug!(
"Connecting socket {} to {:?}",
self.fd,
defmt::Debug2Format(&address)
);
token.bind_to_current_task().await;
self.link
.as_ref()
.unwrap()
.wait_for_link_with_cancellation(token)
.await?;
core::future::poll_fn(|cx| {
#[cfg(feature = "defmt")]
defmt::trace!("Connecting socket {}", self.fd);
if token.is_cancelled() {
return Poll::Ready(Err(Error::OperationCancelled));
}
let address = NrfSockAddr::from(address);
register_socket_waker(cx.waker().clone(), self.fd, SocketDirection::Either);
let mut connect_result = unsafe {
nrfxlib_sys::nrf_connect(self.fd, address.as_ptr(), address.size() as u32)
} as isize;
const NRF_EINPROGRESS: isize = nrfxlib_sys::NRF_EINPROGRESS as isize;
const NRF_EALREADY: isize = nrfxlib_sys::NRF_EALREADY as isize;
const NRF_EISCONN: isize = nrfxlib_sys::NRF_EISCONN as isize;
if connect_result == -1 {
connect_result = get_last_error();
}
#[cfg(feature = "defmt")]
defmt::trace!("Connect result {}", connect_result);
match connect_result {
0 => Poll::Ready(Ok(())),
NRF_EISCONN => Poll::Ready(Ok(())),
NRF_EINPROGRESS | NRF_EALREADY => Poll::Pending,
error => Poll::Ready(Err(Error::NrfError(error))),
}
})
.await?;
Ok(())
}
pub async unsafe fn bind(
&self,
address: SocketAddr,
token: &CancellationToken,
) -> Result<(), Error> {
#[cfg(feature = "defmt")]
defmt::debug!(
"Binding socket {} to {:?}",
self.fd,
defmt::Debug2Format(&address)
);
token.bind_to_current_task().await;
self.link
.as_ref()
.unwrap()
.wait_for_link_with_cancellation(token)
.await?;
core::future::poll_fn(|cx| {
#[cfg(feature = "defmt")]
defmt::trace!("Binding socket {}", self.fd);
if token.is_cancelled() {
return Poll::Ready(Err(Error::OperationCancelled));
}
let address = NrfSockAddr::from(address);
register_socket_waker(cx.waker().clone(), self.fd, SocketDirection::Either);
let mut bind_result =
unsafe { nrfxlib_sys::nrf_bind(self.fd, address.as_ptr(), address.size() as u32) }
as isize;
const NRF_EINPROGRESS: isize = nrfxlib_sys::NRF_EINPROGRESS as isize;
const NRF_EALREADY: isize = nrfxlib_sys::NRF_EALREADY as isize;
const NRF_EISCONN: isize = nrfxlib_sys::NRF_EISCONN as isize;
if bind_result == -1 {
bind_result = get_last_error();
}
#[cfg(feature = "defmt")]
defmt::trace!("Bind result {}", bind_result);
match bind_result {
0 => Poll::Ready(Ok(())),
NRF_EISCONN => Poll::Ready(Ok(())),
NRF_EINPROGRESS | NRF_EALREADY => Poll::Pending,
error => Poll::Ready(Err(Error::NrfError(error))),
}
})
.await?;
Ok(())
}
pub async fn write(&self, buffer: &[u8], token: &CancellationToken) -> Result<usize, Error> {
token.bind_to_current_task().await;
core::future::poll_fn(|cx| {
#[cfg(feature = "defmt")]
defmt::trace!("Sending with socket {}", self.fd);
if token.is_cancelled() {
return Poll::Ready(Err(Error::OperationCancelled));
}
register_socket_waker(cx.waker().clone(), self.fd, SocketDirection::Out);
let mut send_result = unsafe {
nrfxlib_sys::nrf_send(self.fd, buffer.as_ptr() as *const _, buffer.len(), 0)
};
if send_result == -1 {
send_result = get_last_error().abs().neg();
}
#[cfg(feature = "defmt")]
defmt::trace!("Send result {}", send_result);
const NRF_EWOULDBLOCK: isize = -(nrfxlib_sys::NRF_EWOULDBLOCK as isize);
const NRF_ENOTCONN: isize = -(nrfxlib_sys::NRF_ENOTCONN as isize);
match send_result {
0 if !buffer.is_empty() => Poll::Ready(Err(Error::Disconnected)),
NRF_ENOTCONN => Poll::Ready(Err(Error::Disconnected)),
bytes_sent @ 0.. => Poll::Ready(Ok(bytes_sent as usize)),
NRF_EWOULDBLOCK => Poll::Pending,
error => Poll::Ready(Err(Error::NrfError(error))),
}
})
.await
}
pub async fn receive(
&self,
buffer: &mut [u8],
token: &CancellationToken,
) -> Result<usize, Error> {
token.bind_to_current_task().await;
core::future::poll_fn(|cx| {
#[cfg(feature = "defmt")]
defmt::trace!("Receiving with socket {}", self.fd);
if token.is_cancelled() {
return Poll::Ready(Err(Error::OperationCancelled));
}
register_socket_waker(cx.waker().clone(), self.fd, SocketDirection::In);
let mut receive_result = unsafe {
nrfxlib_sys::nrf_recv(self.fd, buffer.as_mut_ptr() as *mut _, buffer.len(), 0)
};
if receive_result == -1 {
receive_result = get_last_error().abs().neg();
}
#[cfg(feature = "defmt")]
defmt::trace!("Receive result {}", receive_result);
const NRF_EWOULDBLOCK: isize = -(nrfxlib_sys::NRF_EWOULDBLOCK as isize);
const NRF_ENOTCONN: isize = -(nrfxlib_sys::NRF_ENOTCONN as isize);
match receive_result {
0 if !buffer.is_empty() => Poll::Ready(Err(Error::Disconnected)),
NRF_ENOTCONN => Poll::Ready(Err(Error::Disconnected)),
bytes_received @ 0.. => Poll::Ready(Ok(bytes_received as usize)),
NRF_EWOULDBLOCK => Poll::Pending,
error => Poll::Ready(Err(Error::NrfError(error))),
}
})
.await
}
pub async fn receive_from(
&self,
buffer: &mut [u8],
token: &CancellationToken,
) -> Result<(usize, SocketAddr), Error> {
token.bind_to_current_task().await;
core::future::poll_fn(|cx| {
#[cfg(feature = "defmt")]
defmt::trace!("Receiving with socket {}", self.fd);
if token.is_cancelled() {
return Poll::Ready(Err(Error::OperationCancelled));
}
let mut socket_addr_store =
[0u8; core::mem::size_of::<nrfxlib_sys::nrf_sockaddr_in6>()];
let socket_addr_ptr = socket_addr_store.as_mut_ptr() as *mut nrfxlib_sys::nrf_sockaddr;
let mut socket_addr_len = 0u32;
register_socket_waker(cx.waker().clone(), self.fd, SocketDirection::In);
let mut receive_result = unsafe {
nrfxlib_sys::nrf_recvfrom(
self.fd,
buffer.as_mut_ptr() as *mut _,
buffer.len(),
0,
socket_addr_ptr,
&mut socket_addr_len as *mut u32,
)
};
if receive_result == -1 {
receive_result = get_last_error().abs().neg();
}
#[cfg(feature = "defmt")]
defmt::trace!("Receive result {}", receive_result);
const NRF_EWOULDBLOCK: isize = -(nrfxlib_sys::NRF_EWOULDBLOCK as isize);
const NRF_ENOTCONN: isize = -(nrfxlib_sys::NRF_ENOTCONN as isize);
match receive_result {
0 if !buffer.is_empty() => Poll::Ready(Err(Error::Disconnected)),
NRF_ENOTCONN => Poll::Ready(Err(Error::Disconnected)),
bytes_received @ 0.. => Poll::Ready(Ok((bytes_received as usize, {
unsafe { (*socket_addr_ptr).sa_family = self.family as u32 as i32 }
NrfSockAddr::from(socket_addr_ptr as *const _).into()
}))),
NRF_EWOULDBLOCK => Poll::Pending,
error => Poll::Ready(Err(Error::NrfError(error))),
}
})
.await
}
pub async fn send_to(
&self,
buffer: &[u8],
address: SocketAddr,
token: &CancellationToken,
) -> Result<usize, Error> {
token.bind_to_current_task().await;
core::future::poll_fn(|cx| {
#[cfg(feature = "defmt")]
defmt::trace!("Sending with socket {}", self.fd);
if token.is_cancelled() {
return Poll::Ready(Err(Error::OperationCancelled));
}
let addr = NrfSockAddr::from(address);
register_socket_waker(cx.waker().clone(), self.fd, SocketDirection::Out);
let mut send_result = unsafe {
nrfxlib_sys::nrf_sendto(
self.fd,
buffer.as_ptr() as *mut _,
buffer.len(),
0,
addr.as_ptr(),
addr.size() as u32,
)
};
if send_result == -1 {
send_result = get_last_error().abs().neg();
}
#[cfg(feature = "defmt")]
defmt::trace!("Sending result {}", send_result);
const NRF_EWOULDBLOCK: isize = -(nrfxlib_sys::NRF_EWOULDBLOCK as isize);
const NRF_ENOTCONN: isize = -(nrfxlib_sys::NRF_ENOTCONN as isize);
match send_result {
0 if !buffer.is_empty() => Poll::Ready(Err(Error::Disconnected)),
NRF_ENOTCONN => Poll::Ready(Err(Error::Disconnected)),
bytes_received @ 0.. => Poll::Ready(Ok(bytes_received as usize)),
NRF_EWOULDBLOCK => Poll::Pending,
error => Poll::Ready(Err(Error::NrfError(error))),
}
})
.await
}
pub fn set_option<'a>(&'a self, option: SocketOption<'a>) -> Result<(), SocketOptionError> {
let length = option.get_length();
let result = unsafe {
nrfxlib_sys::nrf_setsockopt(
self.fd,
nrfxlib_sys::NRF_SOL_SECURE.try_into().unwrap(),
option.get_name(),
option.get_value(),
length,
)
};
if result < 0 {
Err(result.into())
} else {
Ok(())
}
}
pub async fn deactivate(mut self) -> Result<(), Error> {
self.link.take().unwrap().deactivate().await?;
Ok(())
}
}
/// Closes the modem descriptor, unless a split counterpart still shares it.
///
/// # Panics
/// Panics if `nrf_close` fails.
impl Drop for Socket {
    fn drop(&mut self) {
        if self.split {
            // The other split handle still uses `fd`; whichever handle is last closes it.
            return;
        }

        // SAFETY: `fd` belongs to this socket and is never used again after drop.
        let close_result = unsafe { nrfxlib_sys::nrf_close(self.fd) };
        if close_result == -1 {
            panic!("{:?}", Error::NrfError(get_last_error()));
        }
    }
}
impl PartialEq for Socket {
fn eq(&self, other: &Self) -> bool {
self.fd == other.fd
}
}
impl Eq for Socket {}
/// Address family passed to `nrf_socket` by `Socket::create`. It is also written into the
/// `sa_family` of addresses returned by `Socket::receive_from`.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive, IntoPrimitive)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum SocketFamily {
    /// `NRF_AF_UNSPEC`.
    Unspecified = nrfxlib_sys::NRF_AF_UNSPEC,
    /// `NRF_AF_INET` (IPv4).
    Ipv4 = nrfxlib_sys::NRF_AF_INET,
    /// `NRF_AF_INET6` (IPv6).
    Ipv6 = nrfxlib_sys::NRF_AF_INET6,
    /// `NRF_AF_PACKET`.
    Raw = nrfxlib_sys::NRF_AF_PACKET,
}
/// Socket type passed to `nrf_socket` by `Socket::create`.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive, IntoPrimitive)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum SocketType {
    /// `NRF_SOCK_STREAM`.
    Stream = nrfxlib_sys::NRF_SOCK_STREAM,
    /// `NRF_SOCK_DGRAM`.
    Datagram = nrfxlib_sys::NRF_SOCK_DGRAM,
    /// `NRF_SOCK_RAW`.
    Raw = nrfxlib_sys::NRF_SOCK_RAW,
}
/// Protocol passed to `nrf_socket` by `Socket::create`.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive, IntoPrimitive)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum SocketProtocol {
    /// `NRF_IPPROTO_IP`.
    IP = nrfxlib_sys::NRF_IPPROTO_IP,
    /// `NRF_IPPROTO_TCP`.
    Tcp = nrfxlib_sys::NRF_IPPROTO_TCP,
    /// `NRF_IPPROTO_UDP`.
    Udp = nrfxlib_sys::NRF_IPPROTO_UDP,
    /// `NRF_IPPROTO_IPV6`.
    Ipv6 = nrfxlib_sys::NRF_IPPROTO_IPV6,
    /// `NRF_IPPROTO_RAW`.
    Raw = nrfxlib_sys::NRF_IPPROTO_RAW,
    /// `NRF_IPPROTO_ALL`.
    All = nrfxlib_sys::NRF_IPPROTO_ALL,
    /// TLS 1.2 (`NRF_SPROTO_TLS1v2`).
    Tls1v2 = nrfxlib_sys::NRF_SPROTO_TLS1v2,
    /// DTLS 1.2 (`NRF_SPROTO_DTLS1v2`).
    DTls1v2 = nrfxlib_sys::NRF_SPROTO_DTLS1v2,
}
/// A TLS-level (`NRF_SOL_SECURE`) socket option accepted by `Socket::set_option`.
///
/// The borrowed variants keep the referenced data alive for the duration of the call.
// Every variant is intentionally `Tls`-prefixed, which trips clippy's variant-name lint.
#[allow(clippy::enum_variant_names)]
#[derive(Debug)]
pub enum SocketOption<'a> {
    /// `NRF_SO_SEC_HOSTNAME`; passed as raw bytes of length `str::len` (no NUL terminator).
    TlsHostName(&'a str),
    /// `NRF_SO_SEC_PEER_VERIFY`.
    TlsPeerVerify(nrfxlib_sys::nrf_sec_peer_verify_t),
    /// `NRF_SO_SEC_SESSION_CACHE`.
    TlsSessionCache(nrfxlib_sys::nrf_sec_session_cache_t),
    /// `NRF_SO_SEC_TAG_LIST`; the whole slice is passed.
    TlsTagList(&'a [nrfxlib_sys::nrf_sec_tag_t]),
}

impl SocketOption<'_> {
    /// Option name (`optname`) for `nrf_setsockopt`.
    pub(crate) fn get_name(&self) -> i32 {
        match self {
            Self::TlsHostName(_) => nrfxlib_sys::NRF_SO_SEC_HOSTNAME as i32,
            Self::TlsPeerVerify(_) => nrfxlib_sys::NRF_SO_SEC_PEER_VERIFY as i32,
            Self::TlsSessionCache(_) => nrfxlib_sys::NRF_SO_SEC_SESSION_CACHE as i32,
            Self::TlsTagList(_) => nrfxlib_sys::NRF_SO_SEC_TAG_LIST as i32,
        }
    }

    /// Pointer to the option value. It points either into `self` or into data borrowed by
    /// `self`, so it is valid only while `self` is alive and unmoved.
    pub(crate) fn get_value(&self) -> *const core::ffi::c_void {
        match self {
            Self::TlsHostName(host) => host.as_ptr().cast(),
            Self::TlsPeerVerify(verify) => {
                (verify as *const nrfxlib_sys::nrf_sec_peer_verify_t).cast()
            }
            Self::TlsSessionCache(cache) => {
                (cache as *const nrfxlib_sys::nrf_sec_session_cache_t).cast()
            }
            Self::TlsTagList(tags) => tags.as_ptr().cast(),
        }
    }

    /// Length in bytes of the value returned by `get_value`.
    pub(crate) fn get_length(&self) -> u32 {
        let byte_len = match self {
            Self::TlsHostName(host) => host.len(),
            Self::TlsPeerVerify(verify) => core::mem::size_of_val(verify),
            Self::TlsSessionCache(cache) => core::mem::size_of_val(cache),
            Self::TlsTagList(tags) => core::mem::size_of_val(*tags),
        };
        byte_len as u32
    }
}
/// Reasons `Socket::set_option` can fail, decoded from the modem errno.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum SocketOptionError {
    /// `NRF_EBADF`.
    InvalidFileDescriptor,
    /// `NRF_EDOM`: a timeout value is too large for the option.
    TimeoutTooBig,
    /// `NRF_EINVAL`.
    InvalidOption,
    /// `NRF_EISCONN`.
    AlreadyConnected,
    /// `NRF_ENOPROTOOPT`.
    UnsupportedOption,
    /// `NRF_ENOTSOCK`.
    NotASocket,
    /// `NRF_ENOMEM`.
    OutOfMemory,
    /// `NRF_ENOBUFS`.
    OutOfResources,
}

impl From<i32> for SocketOptionError {
    /// Maps an nRF errno value (either sign is accepted) to a [`SocketOptionError`].
    ///
    /// # Panics
    /// Panics on an errno that has no corresponding variant.
    fn from(errno: i32) -> Self {
        match errno.unsigned_abs() {
            nrfxlib_sys::NRF_EBADF => SocketOptionError::InvalidFileDescriptor,
            // Previously unmapped, leaving `TimeoutTooBig` unreachable; POSIX `setsockopt`
            // reports oversized timeouts as EDOM.
            nrfxlib_sys::NRF_EDOM => SocketOptionError::TimeoutTooBig,
            nrfxlib_sys::NRF_EINVAL => SocketOptionError::InvalidOption,
            nrfxlib_sys::NRF_EISCONN => SocketOptionError::AlreadyConnected,
            nrfxlib_sys::NRF_ENOPROTOOPT => SocketOptionError::UnsupportedOption,
            nrfxlib_sys::NRF_ENOTSOCK => SocketOptionError::NotASocket,
            nrfxlib_sys::NRF_ENOMEM => SocketOptionError::OutOfMemory,
            nrfxlib_sys::NRF_ENOBUFS => SocketOptionError::OutOfResources,
            _ => panic!("Unknown error code: {}", errno),
        }
    }
}
// `AtomicU8` is not `Copy`, so the array-repeat initializer below needs a `const`. The const
// is only ever used as that initializer, so clippy's interior-mutable-const lint does not apply.
#[allow(clippy::declare_interior_mutable_const)]
const ATOMIC_U8_INIT: AtomicU8 = AtomicU8::new(0);
/// Per-spot count of live `SplitSocketHandle`s. `0` means the spot is free.
/// `SplitSocketHandle::get_new_spot` claims a spot by setting it to `2`. Each handle
/// decrements it once, when deactivated or dropped.
static ACTIVE_SPLIT_SOCKETS: [AtomicU8; nrfxlib_sys::NRF_MODEM_MAX_SOCKET_COUNT as usize] =
    [ATOMIC_U8_INIT; nrfxlib_sys::NRF_MODEM_MAX_SOCKET_COUNT as usize];
/// One of the two handles produced by `Socket::split`; both share one modem descriptor.
pub struct SplitSocketHandle {
    /// The shared socket; `None` once `deactivate` has taken it.
    inner: Option<Socket>,
    /// Spot in `ACTIVE_SPLIT_SOCKETS` shared by both handles of a split.
    index: usize,
}
impl SplitSocketHandle {
    /// Releases this handle and deactivates the LTE link it holds.
    ///
    /// If this was the last live handle of the split, the socket is un-marked as split, so
    /// dropping it closes the shared descriptor.
    ///
    /// # Errors
    /// Returns the error from `Socket::deactivate`.
    pub async fn deactivate(mut self) -> Result<(), Error> {
        let mut socket = self.inner.take().unwrap();
        let live_before = ACTIVE_SPLIT_SOCKETS[self.index].fetch_sub(1, Ordering::SeqCst);
        if live_before == 1 {
            socket.split = false;
        }
        socket.deactivate().await
    }

    /// Claims the first free spot in `ACTIVE_SPLIT_SOCKETS` for two new handles.
    ///
    /// # Panics
    /// Panics if no spot is free.
    fn get_new_spot() -> usize {
        ACTIVE_SPLIT_SOCKETS
            .iter()
            .position(|live| {
                live.compare_exchange(0, 2, Ordering::SeqCst, Ordering::SeqCst)
                    .is_ok()
            })
            .unwrap_or_else(|| {
                unreachable!(
                    "It should not be possible to have more splits than the maximum socket count"
                )
            })
    }
}
/// Exposes the shared [`Socket`] so every `&self` socket operation is available on either
/// handle.
impl Deref for SplitSocketHandle {
    type Target = Socket;
    // `inner` is only taken by `deactivate`, which consumes the handle, so this `unwrap`
    // cannot fail for callers holding a live handle.
    fn deref(&self) -> &Self::Target {
        self.inner.as_ref().unwrap()
    }
}
/// Releases this handle's share of the split.
///
/// When it is the last live handle, the inner socket is un-marked as split so that its own
/// `Drop` closes the descriptor. If `deactivate` already took `inner`, it has done this
/// bookkeeping itself and nothing happens here.
impl Drop for SplitSocketHandle {
    fn drop(&mut self) {
        if let Some(socket) = &mut self.inner {
            let live_before = ACTIVE_SPLIT_SOCKETS[self.index].fetch_sub(1, Ordering::SeqCst);
            let was_last = live_before == 1;
            if was_last {
                socket.split = false;
            }
        }
    }
}