Skip to main content

ort_openrouter_cli/net/
socket.rs

1//! ort: Open Router CLI
2//! https://github.com/grahamking/ort
3//!
4//! MIT License
5//! Copyright (c) 2025 Graham King
6
7use core::ffi::{c_int, c_void};
8use core::mem::size_of;
9use core::net::{Ipv4Addr, SocketAddrV4};
10
11use crate::{ErrorKind, OrtResult, Read, Write, ort_error, syscall, utils};
12
/// IPv4 TCP socket wrapping a raw file descriptor obtained from `socket(2)`.
///
/// Create it with [`TcpSocket::new`], then call [`TcpSocket::connect`]. Reads and
/// writes go directly to the `read`/`write` syscalls with no buffering.
///
/// NOTE(review): this file has no `Drop` impl, so the descriptor is never closed
/// here — confirm it is released elsewhere, or add one.
pub struct TcpSocket {
    /// Raw kernel file descriptor returned by `socket(2)`.
    fd: c_int,
}
16
17impl TcpSocket {
18    pub fn new() -> OrtResult<Self> {
19        let fd = syscall::socket(syscall::AF_INET, syscall::SOCK_STREAM | syscall::SOCK_CLOEXEC, 0);
20        if fd == -1 {
21            return Err(ort_error(ErrorKind::SocketCreateFailed, ""));
22        }
23        set_tcp_fastopen(fd);
24        Ok(TcpSocket { fd })
25    }
26
27    pub fn connect(&self, addr: &SocketAddrV4) -> OrtResult<()> {
28        let c_addr = socket_addr_v4_to_c(addr);
29        let len = size_of::<syscall::sockaddr_in>() as syscall::socklen_t;
30        let res = syscall::connect(self.fd, &c_addr as *const _ as *const syscall::sockaddr, len);
31        if res == -1 {
32            return Err(ort_error(ErrorKind::SocketConnectFailed, ""));
33        }
34        Ok(())
35    }
36}
37
38impl super::AsFd for TcpSocket {
39    fn as_fd(&self) -> i32 {
40        self.fd
41    }
42}
43
44impl Read for TcpSocket {
45    fn read(&mut self, buf: &mut [u8]) -> OrtResult<usize> {
46        let bytes_read = syscall::read(self.fd, buf.as_mut_ptr() as *mut c_void, buf.len());
47        if bytes_read < 0 {
48            if bytes_read == syscall::EAGAIN {
49                return Err(ort_error(ErrorKind::WouldBlock, ""));
50            }
51            // see /usr/include/asm-generic/errno.h to translate the codes
52            let err_code = utils::num_to_string(-bytes_read);
53            utils::print_string(c"socket read err: ", &err_code);
54            Err(ort_error(ErrorKind::SocketReadFailed, "syscall read error"))
55        } else {
56            Ok(bytes_read as usize)
57        }
58    }
59}
60
61impl Write for TcpSocket {
62    fn write(&mut self, buf: &[u8]) -> OrtResult<usize> {
63        let bytes_written = syscall::write(self.fd, buf.as_ptr() as *const c_void, buf.len());
64        if bytes_written < 0 {
65            // see /usr/include/asm-generic/errno.h to translate the codes
66            let err_code = utils::num_to_string(-bytes_written);
67            utils::print_string(c"socket write err: ", &err_code);
68            Err(ort_error(
69                ErrorKind::SocketWriteFailed,
70                "syscall write error",
71            ))
72        } else {
73            Ok(bytes_written as usize)
74        }
75    }
76
77    fn flush(&mut self) -> OrtResult<()> {
78        Ok(())
79    }
80}
81
82fn set_tcp_fastopen(fd: i32) {
83    let optval: c_int = 1; // Enable
84    syscall::setsockopt(
85        fd,
86        syscall::IPPROTO_TCP,
87        syscall::TCP_FASTOPEN_CONNECT,
88        &optval as *const _ as *const core::ffi::c_void,
89        size_of::<i32>() as u32,
90    );
91}
92
93fn socket_addr_v4_to_c(addr: &SocketAddrV4) -> syscall::sockaddr_in {
94    syscall::sockaddr_in {
95        sin_family: syscall::AF_INET as syscall::sa_family_t,
96        sin_port: addr.port().to_be(),
97        sin_addr: ip_v4_addr_to_c(addr.ip()),
98        ..unsafe { core::mem::zeroed() }
99    }
100}
101fn ip_v4_addr_to_c(addr: &Ipv4Addr) -> syscall::in_addr {
102    // `s_addr` is stored as BE on all machines and the array is in BE order.
103    // So the native endian conversion method is used so that it's never swapped.
104    syscall::in_addr {
105        s_addr: u32::from_ne_bytes(addr.octets()),
106    }
107}