use core::ffi::{c_int, c_void};
use core::mem::size_of;
use core::net::{Ipv4Addr, SocketAddrV4};
use crate::{ErrorKind, OrtResult, Read, Write, ort_error, syscall, utils};
/// An IPv4 TCP socket wrapping a raw file descriptor returned by `syscall::socket`.
///
/// NOTE(review): no `Drop` impl (and no `close` call) is visible for this type
/// here, so the descriptor is not closed by this code — confirm it is closed
/// elsewhere.
pub struct TcpSocket {
    fd: i32,
}

impl TcpSocket {
    /// Creates a new `AF_INET` / `SOCK_STREAM` socket with `SOCK_CLOEXEC` set,
    /// then attempts (best-effort) to enable `TCP_FASTOPEN_CONNECT` on it.
    ///
    /// # Errors
    /// Returns `ErrorKind::SocketCreateFailed` if the `socket` syscall fails.
    pub fn new() -> OrtResult<Self> {
        let fd = syscall::socket(
            syscall::AF_INET,
            syscall::SOCK_STREAM | syscall::SOCK_CLOEXEC,
            0,
        );
        // Any negative return is a failure. Checking only `== -1` would be correct
        // for the libc convention (-1 + errno), but the raw-syscall convention
        // (-errno, which `read`/`write` in this file assume) would then let every
        // error other than EPERM through as a "valid" descriptor.
        if fd < 0 {
            return Err(ort_error(ErrorKind::SocketCreateFailed, ""));
        }
        set_tcp_fastopen(fd);
        Ok(TcpSocket { fd })
    }

    /// Connects this socket to the IPv4 address `addr`.
    ///
    /// # Errors
    /// Returns `ErrorKind::SocketConnectFailed` if the `connect` syscall returns
    /// a negative value.
    pub fn connect(&self, addr: &SocketAddrV4) -> OrtResult<()> {
        let c_addr = socket_addr_v4_to_c(addr);
        let len = size_of::<syscall::sockaddr_in>() as syscall::socklen_t;
        let res = syscall::connect(self.fd, &c_addr as *const _ as *const syscall::sockaddr, len);
        // Negative-means-error works for both the -1 and the -errno return conventions.
        if res < 0 {
            return Err(ort_error(ErrorKind::SocketConnectFailed, ""));
        }
        Ok(())
    }
}
impl super::AsFd for TcpSocket {
    /// Returns the raw socket descriptor; the socket keeps holding it.
    fn as_fd(&self) -> i32 {
        let Self { fd } = *self;
        fd
    }
}
impl Read for TcpSocket {
    /// Reads up to `buf.len()` bytes from the socket into `buf`.
    ///
    /// Returns the number of bytes read. A `0` from the syscall is passed through
    /// unchanged (for a TCP socket, `read(2)` returns 0 at end of stream).
    ///
    /// # Errors
    /// - `ErrorKind::WouldBlock` when the syscall reports `EAGAIN`.
    /// - `ErrorKind::SocketReadFailed` for any other negative return; the error
    ///   code is printed before returning.
    fn read(&mut self, buf: &mut [u8]) -> OrtResult<usize> {
        let bytes_read = syscall::read(self.fd, buf.as_mut_ptr() as *mut c_void, buf.len());
        if bytes_read >= 0 {
            return Ok(bytes_read as usize);
        }
        // A negative return carries the negated errno (hence `-bytes_read` below).
        // Comparing that negative value against a positive `EAGAIN` alone can never
        // succeed, so accept either sign. This stays correct whether
        // `syscall::EAGAIN` is defined as the positive errno or pre-negated.
        if bytes_read == syscall::EAGAIN || bytes_read == -syscall::EAGAIN {
            return Err(ort_error(ErrorKind::WouldBlock, ""));
        }
        let err_code = utils::num_to_string(-bytes_read);
        utils::print_string(c"socket read err: ", &err_code);
        Err(ort_error(ErrorKind::SocketReadFailed, "syscall read error"))
    }
}
impl Write for TcpSocket {
fn write(&mut self, buf: &[u8]) -> OrtResult<usize> {
let bytes_written = syscall::write(self.fd, buf.as_ptr() as *const c_void, buf.len());
if bytes_written < 0 {
let err_code = utils::num_to_string(-bytes_written);
utils::print_string(c"socket write err: ", &err_code);
Err(ort_error(
ErrorKind::SocketWriteFailed,
"syscall write error",
))
} else {
Ok(bytes_written as usize)
}
}
fn flush(&mut self) -> OrtResult<()> {
Ok(())
}
}
fn set_tcp_fastopen(fd: i32) {
let optval: c_int = 1; syscall::setsockopt(
fd,
syscall::IPPROTO_TCP,
syscall::TCP_FASTOPEN_CONNECT,
&optval as *const _ as *const core::ffi::c_void,
size_of::<i32>() as u32,
);
}
/// Builds the C `sockaddr_in` for `addr`, as passed to `connect`.
///
/// The port is stored in network (big-endian) byte order. Every field not set
/// explicitly below is left zeroed.
fn socket_addr_v4_to_c(addr: &SocketAddrV4) -> syscall::sockaddr_in {
    // SAFETY: assumes `syscall::sockaddr_in` mirrors the C struct (integer
    // fields only), so the all-zero bit pattern is a valid value — confirm if
    // the definition changes.
    let mut c_addr: syscall::sockaddr_in = unsafe { core::mem::zeroed() };
    c_addr.sin_family = syscall::AF_INET as syscall::sa_family_t;
    c_addr.sin_port = addr.port().to_be();
    c_addr.sin_addr = ip_v4_addr_to_c(addr.ip());
    c_addr
}
/// Converts `addr` into a C `in_addr` whose `s_addr` bytes, as laid out in
/// memory, are the address octets in network order.
fn ip_v4_addr_to_c(addr: &Ipv4Addr) -> syscall::in_addr {
    // `u32::from` reads the octets big-endian into a host-order value;
    // `to_be` then lays it out in network order in memory. The result matches
    // reinterpreting the octets with `from_ne_bytes`.
    let network_order = u32::from(*addr).to_be();
    syscall::in_addr {
        s_addr: network_order,
    }
}