psp_net/socket/
tcp.rs

1#![allow(clippy::module_name_repetitions)]
2
3use alloc::vec::Vec;
4use embedded_io::{ErrorType, Read, Write};
5
6use core::net::SocketAddr;
7use psp::sys;
8
9use core::ffi::c_void;
10
11use crate::traits::io::{EasySocket, Open, OptionType};
12use crate::traits::SocketBuffer;
13use crate::types::{SocketOptions, SocketRecvFlags, SocketSendFlags};
14
15use super::super::netc;
16
17use super::error::SocketError;
18use super::sce::SocketFileDescriptor;
19use super::state::{Connected, SocketState, Unbound};
20use super::ToSockaddr;
21
22/// A TCP socket
23///
24/// # Safety
25/// This is a wrapper around a raw socket file descriptor.
26///
27/// The socket is closed when the struct is dropped.
28/// Closing via drop is best-effort.
29///
30/// # Notes
31/// The structure implements [`EasySocket`]. This allows you to interact with
32/// the socket using a simplified API. However, you are still free to use it
33/// like a normal Linux socket like you would do in C.
34///
35/// Using it as an easy socket allows you to use it in the following way:
36/// ```no_run
37/// use psp::net::TcpSocket;
38///
39/// let socket = TcpSocket::new().unwrap();
40/// let socket_options = SocketOptions{ remote: addr };
41/// let socket = socket.open(socket_options).unwrap();
42/// socket.write(b"hello world").unwrap();
43/// socket.flush().unwrap();
44/// // no need to call close, as drop will do it
45/// ```
46#[repr(C)]
47#[derive(Debug, Clone, PartialEq, Eq, Hash)]
48pub struct TcpSocket<S: SocketState = Unbound, B: SocketBuffer = Vec<u8>> {
49    /// The socket file descriptor
50    pub(super) fd: SocketFileDescriptor,
51    /// The buffer to store data to send
52    buffer: B,
53    /// flags for send calls
54    send_flags: SocketSendFlags,
55    /// flags for recv calls
56    recv_flags: SocketRecvFlags,
57    /// marker for the socket state
58    _marker: core::marker::PhantomData<S>,
59}
60
61impl TcpSocket {
62    /// Create a TCP socket
63    ///
64    /// # Returns
65    /// A new TCP socket
66    ///
67    /// # Errors
68    /// - [`SocketError::ErrnoWithDescription`] if the socket could not be created
69    pub fn new() -> Result<TcpSocket<Unbound>, SocketError> {
70        let fd = unsafe { sys::sceNetInetSocket(i32::from(netc::AF_INET), netc::SOCK_STREAM, 0) };
71        if fd < 0 {
72            Err(SocketError::new_errno_with_description(
73                unsafe { sys::sceNetInetGetErrno() },
74                "failed to create socket",
75            ))
76        } else {
77            let fd = SocketFileDescriptor::new(fd);
78            Ok(TcpSocket {
79                fd,
80                buffer: Vec::with_capacity(0),
81                send_flags: SocketSendFlags::empty(),
82                recv_flags: SocketRecvFlags::empty(),
83                _marker: core::marker::PhantomData,
84            })
85        }
86    }
87}
88
89impl<S: SocketState> TcpSocket<S> {
90    /// Return the underlying socket's file descriptor
91    #[must_use]
92    pub fn fd(&self) -> i32 {
93        *self.fd
94    }
95
96    /// Flags used when sending data
97    #[must_use]
98    pub fn send_flags(&self) -> SocketSendFlags {
99        self.send_flags
100    }
101
102    /// Set the flags used when sending data
103    pub fn set_send_flags(&mut self, send_flags: SocketSendFlags) {
104        self.send_flags = send_flags;
105    }
106
107    /// Flags used when receiving data
108    #[must_use]
109    pub fn recv_flags(&self) -> SocketRecvFlags {
110        self.recv_flags
111    }
112
113    /// Set the flags used when receiving data
114    pub fn set_recv_flags(&mut self, recv_flags: SocketRecvFlags) {
115        self.recv_flags = recv_flags;
116    }
117}
118
119impl TcpSocket<Unbound> {
120    #[must_use]
121    fn transition(self) -> TcpSocket<Connected> {
122        TcpSocket {
123            fd: self.fd,
124            buffer: Vec::default(),
125            send_flags: self.send_flags,
126            recv_flags: self.recv_flags,
127            _marker: core::marker::PhantomData,
128        }
129    }
130
131    /// Connect to a remote host
132    ///
133    /// # Parameters
134    /// - `remote`: The remote host to connect to
135    ///
136    /// # Returns
137    /// - `Ok(())` if the connection was successful
138    /// - `Err(String)` if the connection was unsuccessful.
139    ///
140    /// # Errors
141    /// - [`SocketError::UnsupportedAddressFamily`] if the address family is not supported (only IPv4 is supported)
142    /// - Any other [`SocketError`] if the connection was unsuccessful
143    pub fn connect(self, remote: SocketAddr) -> Result<TcpSocket<Connected>, SocketError> {
144        match remote {
145            SocketAddr::V4(v4) => {
146                let sockaddr = v4.to_sockaddr();
147
148                if unsafe {
149                    sys::sceNetInetConnect(
150                        *self.fd,
151                        &sockaddr,
152                        core::mem::size_of::<netc::sockaddr_in>() as u32,
153                    )
154                } < 0
155                {
156                    let errno = unsafe { sys::sceNetInetGetErrno() };
157                    Err(SocketError::Errno(errno))
158                } else {
159                    Ok(self.transition())
160                }
161            }
162            SocketAddr::V6(_) => Err(SocketError::UnsupportedAddressFamily),
163        }
164    }
165}
166
167impl TcpSocket<Connected> {
168    /// Read from the socket
169    ///
170    /// # Returns
171    /// - `Ok(usize)` if the read was successful. The number of bytes read
172    /// - `Err(SocketError)` if the read was unsuccessful.
173    ///
174    /// # Errors
175    /// - A [`SocketError`] if the read was unsuccessful
176    ///
177    /// # Notes
178    /// "Low level" read function. Read data from the socket and store it in
179    /// the buffer. This should not be used if you want to use this socket
180    /// [`EasySocket`] style.
181    pub fn internal_read(&self, buf: &mut [u8]) -> Result<usize, SocketError> {
182        let result = unsafe {
183            sys::sceNetInetRecv(
184                *self.fd,
185                buf.as_mut_ptr().cast::<c_void>(),
186                buf.len(),
187                self.recv_flags.as_i32(),
188            )
189        };
190        if result < 0 {
191            Err(SocketError::Errno(unsafe { sys::sceNetInetGetErrno() }))
192        } else {
193            Ok(result as usize)
194        }
195    }
196
197    /// Write to the socket
198    ///
199    /// # Errors
200    /// - A [`SocketError`] if the write was unsuccessful
201    pub fn internal_write(&mut self, buf: &[u8]) -> Result<usize, SocketError> {
202        self.buffer.append_buffer(buf);
203        self.send()
204    }
205
206    fn internal_flush(&mut self) -> Result<(), SocketError> {
207        while !self.buffer.is_empty() {
208            self.send()?;
209        }
210        Ok(())
211    }
212
213    fn send(&mut self) -> Result<usize, SocketError> {
214        let result = unsafe {
215            sys::sceNetInetSend(
216                *self.fd,
217                self.buffer.as_slice().as_ptr().cast::<c_void>(),
218                self.buffer.len(),
219                self.send_flags.as_i32(),
220            )
221        };
222        if result < 0 {
223            Err(SocketError::Errno(unsafe { sys::sceNetInetGetErrno() }))
224        } else {
225            self.buffer.shift_left_buffer(result as usize);
226            Ok(result as usize)
227        }
228    }
229}
230
231impl<S: SocketState> ErrorType for TcpSocket<S> {
232    type Error = SocketError;
233}
234
235impl<S: SocketState> OptionType for TcpSocket<S> {
236    type Options<'a> = SocketOptions;
237}
238
239impl Open<'_, '_> for TcpSocket<Unbound> {
240    type Return = TcpSocket<Connected>;
241    /// Return a TCP socket connected to the remote specified in `options`
242    fn open(self, options: &'_ Self::Options<'_>) -> Result<Self::Return, Self::Error>
243    where
244        Self: Sized,
245    {
246        let socket = self.connect(options.remote())?;
247        Ok(socket)
248    }
249}
250
251impl Read for TcpSocket<Connected> {
252    /// Read from the socket
253    ///
254    /// # Parameters
255    /// - `buf`: The buffer where the read data will be stored
256    ///
257    /// # Returns
258    /// - `Ok(usize)` if the read was successful. The number of bytes read
259    /// - `Err(SocketError)` if the read was unsuccessful.
260    ///
261    /// # Errors
262    /// - [`SocketError::NotConnected`] if the socket is not connected
263    /// - A [`SocketError`] if the read was unsuccessful
264    fn read<'m>(&'m mut self, buf: &'m mut [u8]) -> Result<usize, Self::Error> {
265        self.internal_read(buf)
266    }
267}
268
269impl Write for TcpSocket<Connected> {
270    /// Write to the socket
271    ///
272    /// # Errors
273    /// - [`SocketError::NotConnected`] if the socket is not connected
274    /// - A [`SocketError`] if the write was unsuccessful
275    fn write<'m>(&'m mut self, buf: &'m [u8]) -> Result<usize, Self::Error> {
276        self.internal_write(buf)
277    }
278
279    /// Flush the socket
280    ///
281    /// # Errors
282    /// - A [`SocketError`] if the flush was unsuccessful
283    fn flush(&mut self) -> Result<(), SocketError> {
284        self.internal_flush()
285    }
286}
287
/// Only a [`Connected`] TCP socket is an [`EasySocket`]. No trait items are
/// overridden here.
impl EasySocket for TcpSocket<Connected> {}