#![allow(clippy::module_name_repetitions)]
use alloc::vec::Vec;
use embedded_io::{ErrorType, Read, Write};
use core::net::SocketAddr;
use psp::sys;
use core::ffi::c_void;
use crate::traits::io::{EasySocket, Open, OptionType};
use crate::traits::SocketBuffer;
use crate::types::{SocketOptions, SocketRecvFlags, SocketSendFlags};
use super::super::netc;
use super::error::SocketError;
use super::sce::SocketFileDescriptor;
use super::state::{Connected, SocketState, Unbound};
use super::ToSockaddr;
/// A TCP socket built on the PSP `sceNetInet*` socket calls.
///
/// The `S` type parameter tracks the socket state at the type level; it is
/// only stored as `PhantomData` and defaults to [`Unbound`]. The `B`
/// parameter is the buffer type used to stage outgoing bytes and defaults to
/// `Vec<u8>`.
///
/// NOTE(review): `Clone` is derived, so a clone carries a copy of the same
/// descriptor wrapper — confirm how `SocketFileDescriptor` handles
/// duplication and closing.
#[repr(C)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TcpSocket<S: SocketState = Unbound, B: SocketBuffer = Vec<u8>> {
/// Underlying socket file descriptor.
pub(super) fd: SocketFileDescriptor,
/// Outgoing bytes that have been queued but not yet reported as sent.
buffer: B,
/// Flags passed to every send call.
send_flags: SocketSendFlags,
/// Flags passed to every receive call.
recv_flags: SocketRecvFlags,
/// Zero-sized marker carrying the state type `S`.
_marker: core::marker::PhantomData<S>,
}
impl TcpSocket {
    /// Creates a new, not-yet-connected IPv4 stream socket.
    ///
    /// The socket starts with an empty send buffer and no send/receive flags.
    ///
    /// # Errors
    ///
    /// When `sceNetInetSocket` reports a negative descriptor, returns a
    /// [`SocketError`] built from the current `sceNetInet` errno with the
    /// description `"failed to create socket"`.
    pub fn new() -> Result<TcpSocket<Unbound>, SocketError> {
        // SAFETY: plain FFI call that only takes integer arguments.
        let raw_fd =
            unsafe { sys::sceNetInetSocket(i32::from(netc::AF_INET), netc::SOCK_STREAM, 0) };

        // Guard clause: a negative descriptor signals failure.
        if raw_fd < 0 {
            // SAFETY: plain FFI call without arguments.
            let errno = unsafe { sys::sceNetInetGetErrno() };
            return Err(SocketError::new_errno_with_description(
                errno,
                "failed to create socket",
            ));
        }

        Ok(Self {
            fd: SocketFileDescriptor::new(raw_fd),
            buffer: Vec::new(),
            send_flags: SocketSendFlags::empty(),
            recv_flags: SocketRecvFlags::empty(),
            _marker: core::marker::PhantomData,
        })
    }
}
impl<S: SocketState> TcpSocket<S> {
    /// Returns the raw integer value of the socket file descriptor.
    #[must_use]
    pub fn fd(&self) -> i32 {
        *self.fd
    }

    /// Returns the flags currently passed to send calls.
    #[must_use]
    pub fn send_flags(&self) -> SocketSendFlags {
        self.send_flags
    }

    /// Replaces the flags passed to subsequent send calls.
    pub fn set_send_flags(&mut self, send_flags: SocketSendFlags) {
        self.send_flags = send_flags;
    }

    /// Returns the flags currently passed to receive calls.
    #[must_use]
    pub fn recv_flags(&self) -> SocketRecvFlags {
        self.recv_flags
    }

    /// Replaces the flags passed to subsequent receive calls.
    pub fn set_recv_flags(&mut self, recv_flags: SocketRecvFlags) {
        self.recv_flags = recv_flags;
    }
}
impl TcpSocket<Unbound> {
    /// Re-tags this socket as [`Connected`].
    ///
    /// The descriptor and both flag sets are carried over; the connected
    /// socket starts with a fresh, empty send buffer.
    #[must_use]
    fn transition(self) -> TcpSocket<Connected> {
        TcpSocket {
            fd: self.fd,
            buffer: Vec::default(),
            send_flags: self.send_flags,
            recv_flags: self.recv_flags,
            _marker: core::marker::PhantomData,
        }
    }

    /// Connects the socket to `remote`, consuming it and returning the
    /// connected socket on success.
    ///
    /// # Errors
    ///
    /// - [`SocketError::UnsupportedAddressFamily`] if `remote` is an IPv6
    ///   address.
    /// - [`SocketError::Errno`] carrying the current `sceNetInet` errno if
    ///   `sceNetInetConnect` returns a negative value.
    pub fn connect(self, remote: SocketAddr) -> Result<TcpSocket<Connected>, SocketError> {
        // Only IPv4 targets are handled; bail out early on anything else.
        let SocketAddr::V4(v4) = remote else {
            return Err(SocketError::UnsupportedAddressFamily);
        };

        let sockaddr = v4.to_sockaddr();
        let addr_len = core::mem::size_of::<netc::sockaddr_in>() as u32;

        // SAFETY: `sockaddr` is a live local for the whole call and is passed
        // by reference together with the length given above.
        let status = unsafe { sys::sceNetInetConnect(*self.fd, &sockaddr, addr_len) };
        if status < 0 {
            // SAFETY: plain FFI call without arguments.
            let errno = unsafe { sys::sceNetInetGetErrno() };
            return Err(SocketError::Errno(errno));
        }

        Ok(self.transition())
    }
}
impl TcpSocket<Connected> {
    /// Receives bytes from the socket into `buf`.
    ///
    /// Uses the socket's current receive flags and returns the byte count
    /// reported by `sceNetInetRecv`.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError::Errno`] with the current `sceNetInet` errno
    /// when the receive call reports a negative result.
    pub fn internal_read(&self, buf: &mut [u8]) -> Result<usize, SocketError> {
        // SAFETY: the pointer and length both come from `buf`, which is
        // exclusively borrowed for the duration of the call.
        let result = unsafe {
            sys::sceNetInetRecv(
                *self.fd,
                buf.as_mut_ptr().cast::<c_void>(),
                buf.len(),
                self.recv_flags.as_i32(),
            )
        };
        if result < 0 {
            // SAFETY: plain FFI call without arguments.
            Err(SocketError::Errno(unsafe { sys::sceNetInetGetErrno() }))
        } else {
            Ok(result as usize)
        }
    }

    /// Queues `buf` in the internal send buffer and makes one attempt to
    /// send the queued bytes.
    ///
    /// Returns `buf.len()`: every byte of `buf` has been copied into the
    /// internal buffer, so the caller must not submit any of it again. Bytes
    /// that the single send attempt did not transmit stay queued and go out
    /// on a later write or flush.
    ///
    /// Reporting the count from the send attempt instead would be wrong in
    /// both directions: after a partial send it would under-report (callers
    /// that loop until everything is written would queue the tail twice),
    /// and with older bytes still pending it could exceed `buf.len()`.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError::Errno`] if the send attempt fails. The bytes of
    /// `buf` remain queued in that case, since the buffer is only trimmed
    /// after a successful send.
    pub fn internal_write(&mut self, buf: &[u8]) -> Result<usize, SocketError> {
        self.buffer.append_buffer(buf);
        self.send()?;
        Ok(buf.len())
    }

    /// Sends repeatedly until the internal send buffer is empty.
    ///
    /// NOTE(review): if a send ever reports 0 bytes while data is still
    /// queued, this loop makes no progress — confirm the socket is always
    /// used in blocking mode.
    ///
    /// # Errors
    ///
    /// Propagates the first [`SocketError`] returned by a send attempt.
    fn internal_flush(&mut self) -> Result<(), SocketError> {
        while !self.buffer.is_empty() {
            self.send()?;
        }
        Ok(())
    }

    /// Makes one `sceNetInetSend` call over the whole internal buffer and
    /// drops the bytes reported as sent from its front.
    ///
    /// Returns the number of bytes reported as sent.
    ///
    /// # Errors
    ///
    /// Returns [`SocketError::Errno`] with the current `sceNetInet` errno
    /// when the send call reports a negative result; the buffer is left
    /// untouched in that case.
    fn send(&mut self) -> Result<usize, SocketError> {
        // SAFETY: the pointer and length both describe the internal buffer,
        // which is not modified while the call is in progress.
        let result = unsafe {
            sys::sceNetInetSend(
                *self.fd,
                self.buffer.as_slice().as_ptr().cast::<c_void>(),
                self.buffer.len(),
                self.send_flags.as_i32(),
            )
        };
        if result < 0 {
            // SAFETY: plain FFI call without arguments.
            Err(SocketError::Errno(unsafe { sys::sceNetInetGetErrno() }))
        } else {
            self.buffer.shift_left_buffer(result as usize);
            Ok(result as usize)
        }
    }
}
/// Every `TcpSocket`, whatever its state, reports failures as [`SocketError`].
impl<S: SocketState> ErrorType for TcpSocket<S> {
type Error = SocketError;
}
/// Every `TcpSocket`, whatever its state, takes [`SocketOptions`] as its
/// options type.
impl<S: SocketState> OptionType for TcpSocket<S> {
// The lifetime parameter `'a` is not used: `SocketOptions` is named without it.
type Options<'a> = SocketOptions;
}
impl Open<'_, '_> for TcpSocket<Unbound> {
    /// Opening an unbound socket yields a connected one.
    type Return = TcpSocket<Connected>;

    /// Connects this socket to the remote address taken from `options`.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`TcpSocket::connect`] reports.
    fn open(self, options: &'_ Self::Options<'_>) -> Result<Self::Return, Self::Error>
    where
        Self: Sized,
    {
        let remote = options.remote();
        self.connect(remote)
    }
}
impl Read for TcpSocket<Connected> {
    /// Reads from the socket into `buf` by delegating to
    /// [`TcpSocket::internal_read`].
    fn read<'m>(&'m mut self, buf: &'m mut [u8]) -> Result<usize, Self::Error> {
        Self::internal_read(self, buf)
    }
}
impl Write for TcpSocket<Connected> {
    /// Writes `buf` to the socket by delegating to
    /// [`TcpSocket::internal_write`].
    fn write<'m>(&'m mut self, buf: &'m [u8]) -> Result<usize, Self::Error> {
        Self::internal_write(self, buf)
    }

    /// Flushes the socket's send buffer by delegating to the socket's
    /// internal flush routine.
    fn flush(&mut self) -> Result<(), SocketError> {
        Self::internal_flush(self)
    }
}
// Empty impl: nothing is overridden here, so a connected TCP socket uses
// whatever `EasySocket` provides by default.
impl EasySocket for TcpSocket<Connected> {}