1use super::{Error, Instant, Result, RingBuffer, Socket, SocketHandle, SocketMeta};
2use embedded_nal::SocketAddr;
3use fugit::{ExtU32, SecsDurationU32};
4
/// Byte ring buffer used as a socket's receive buffer (`TcpSocket::rx_buffer`).
///
/// `N` is forwarded to `RingBuffer` — presumably its capacity in bytes; confirm
/// against the `RingBuffer` definition.
pub type SocketBuffer<const N: usize> = RingBuffer<u8, N>;
7
8#[derive(Debug, PartialEq, Eq)]
9pub enum State<const TIMER_HZ: u32> {
10 Created,
12 WaitingForConnect(SocketAddr),
13 Connected(SocketAddr),
15 ShutdownForWrite(Instant<TIMER_HZ>),
17}
18
#[cfg(feature = "defmt")]
impl<const TIMER_HZ: u32> defmt::Format for State<TIMER_HZ> {
    /// Emits only the variant name, e.g. `State::Connected`; the wrapped
    /// endpoint address or shutdown instant is not included in the output.
    fn format(&self, f: defmt::Formatter) {
        match *self {
            Self::Created => defmt::write!(f, "State::Created"),
            Self::WaitingForConnect(..) => defmt::write!(f, "State::WaitingForConnect"),
            Self::Connected(..) => defmt::write!(f, "State::Connected"),
            Self::ShutdownForWrite(..) => defmt::write!(f, "State::ShutdownForWrite"),
        }
    }
}
30
31impl<const TIMER_HZ: u32> Default for State<TIMER_HZ> {
32 fn default() -> Self {
33 State::Created
34 }
35}
36
/// Local bookkeeping for one TCP socket.
///
/// Holds the connection [`State`], a receive ring buffer, and the timing
/// parameters used to rate-limit available-data checks and to decide when a
/// remotely closed socket may be recycled. `TIMER_HZ` is the tick rate of the
/// stored `Instant` timestamps (presumably the system timer frequency — confirm);
/// `L` is forwarded to the receive buffer type.
#[derive(Debug)]
pub struct TcpSocket<const TIMER_HZ: u32, const L: usize> {
    /// Handle metadata; read by `handle` and changed by `update_handle`.
    pub(crate) meta: SocketMeta,
    /// Current connection state.
    state: State<TIMER_HZ>,
    /// Minimum elapsed time between two `true` results of
    /// `should_update_available_data` (set to 15 s by `new`).
    check_interval: SecsDurationU32,
    /// How long the socket may remain in `State::ShutdownForWrite` before
    /// `recycle` returns `true`; `None` disables recycling (set to 15 s by `new`).
    read_timeout: Option<SecsDurationU32>,
    /// Byte count stored via `set_available_data`; zeroed by `reset` and
    /// `closed_by_remote`. NOTE(review): presumably the amount of data pending
    /// on the remote/modem side — confirm against callers.
    available_data: usize,
    /// Locally buffered received bytes.
    rx_buffer: SocketBuffer<L>,
    /// Timestamp of the last `true` result of `should_update_available_data`;
    /// `None` before the first check and after `reset`.
    last_check_time: Option<Instant<TIMER_HZ>>,
}
53
54impl<const TIMER_HZ: u32, const L: usize> TcpSocket<TIMER_HZ, L> {
55 pub fn new(socket_id: u8) -> TcpSocket<TIMER_HZ, L> {
57 TcpSocket {
58 meta: SocketMeta {
59 handle: SocketHandle(socket_id),
60 },
61 state: State::default(),
62 rx_buffer: SocketBuffer::new(),
63 available_data: 0,
64 check_interval: 15.secs(),
65 read_timeout: Some(15.secs()),
66 last_check_time: None,
67 }
68 }
69
70 pub fn handle(&self) -> SocketHandle {
72 self.meta.handle
73 }
74
75 pub fn update_handle(&mut self, handle: SocketHandle) {
76 debug!(
77 "[TCP Socket] [{:?}] Updating handle {:?}",
78 self.handle(),
79 handle
80 );
81 self.meta.update(handle)
82 }
83
84 pub fn endpoint(&self) -> Option<SocketAddr> {
86 match self.state {
87 State::Connected(s) | State::WaitingForConnect(s) => Some(s),
88 _ => None,
89 }
90 }
91
92 pub fn state(&self) -> &State<TIMER_HZ> {
94 &self.state
95 }
96
97 pub fn reset(&mut self) {
98 self.set_state(State::default());
99 self.rx_buffer.clear();
100 self.set_available_data(0);
101 self.last_check_time = None;
102 }
103
104 pub fn should_update_available_data(&mut self, ts: Instant<TIMER_HZ>) -> bool {
105 if !self.is_connected() {
108 return false;
109 }
110
111 let should_update = self
112 .last_check_time
113 .and_then(|last_check_time| ts.checked_duration_since(last_check_time))
114 .map(|dur| dur >= self.check_interval)
115 .unwrap_or(true);
116
117 if should_update {
118 self.last_check_time.replace(ts);
119 }
120
121 should_update
122 }
123
124 pub fn recycle(&self, ts: Instant<TIMER_HZ>) -> bool {
125 if let Some(read_timeout) = self.read_timeout {
126 match self.state {
127 State::Created | State::WaitingForConnect(_) | State::Connected(_) => false,
128 State::ShutdownForWrite(closed_time) => ts
129 .checked_duration_since(closed_time)
130 .map(|dur| dur >= read_timeout)
131 .unwrap_or(false),
132 }
133 } else {
134 false
135 }
136 }
137
138 pub fn closed_by_remote(&mut self, ts: Instant<TIMER_HZ>) {
139 self.set_state(State::ShutdownForWrite(ts));
140 self.set_available_data(0);
141 }
142
143 pub fn set_available_data(&mut self, available_data: usize) {
145 self.available_data = available_data;
146 }
147
148 pub fn get_available_data(&self) -> usize {
150 self.available_data
151 }
152
153 pub fn is_connected(&self) -> bool {
160 matches!(self.state, State::Connected(_))
162 }
163
164 pub fn may_recv(&self) -> bool {
173 match self.state {
174 State::Connected(_) | State::ShutdownForWrite(_) => true,
175 _ if !self.rx_buffer.is_empty() => true,
177 _ => false,
178 }
179 }
180
181 pub fn can_recv(&self) -> bool {
184 if !self.may_recv() {
185 return false;
186 }
187
188 !self.rx_buffer.is_full()
189 }
190
191 fn recv_impl<'b, F, R>(&'b mut self, f: F) -> Result<R>
192 where
193 F: FnOnce(&'b mut SocketBuffer<L>) -> (usize, R),
194 {
195 if !self.may_recv() {
199 return Err(Error::Illegal);
200 }
201
202 let (_size, result) = f(&mut self.rx_buffer);
203 Ok(result)
204 }
205
206 pub fn recv<'b, F, R>(&'b mut self, f: F) -> Result<R>
212 where
213 F: FnOnce(&'b mut [u8]) -> (usize, R),
214 {
215 self.recv_impl(|rx_buffer| rx_buffer.dequeue_many_with(f))
216 }
217
218 pub fn recv_wrapping<'b, F>(&'b mut self, f: F) -> Result<usize>
228 where
229 F: FnOnce(&'b [u8], Option<&'b [u8]>) -> usize,
230 {
231 self.recv_impl(|rx_buffer| {
232 rx_buffer.dequeue_many_with_wrapping(|a, b| {
233 let len = f(a, b);
234 (len, len)
235 })
236 })
237 }
238
239 pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<usize> {
246 self.recv_impl(|rx_buffer| {
247 let size = rx_buffer.dequeue_slice(data);
248 (size, size)
249 })
250 }
251
252 pub fn peek(&mut self, size: usize) -> Result<&[u8]> {
257 if !self.may_recv() {
259 return Err(Error::Illegal);
260 }
261
262 Ok(self.rx_buffer.get_allocated(0, size))
263 }
264
265 pub fn rx_window(&self) -> usize {
266 self.rx_buffer.window()
267 }
268
269 pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<usize> {
274 let buffer = self.peek(data.len())?;
275 let data = &mut data[..buffer.len()];
276 data.copy_from_slice(buffer);
277 Ok(buffer.len())
278 }
279
280 pub fn rx_enqueue_slice(&mut self, data: &[u8]) -> usize {
281 self.rx_buffer.enqueue_slice(data)
282 }
283
284 pub fn recv_queue(&self) -> usize {
288 self.rx_buffer.len()
289 }
290
291 pub fn set_state(&mut self, state: State<TIMER_HZ>) {
292 debug!(
293 "[TCP Socket] [{:?}] state change: {:?} -> {:?}",
294 self.handle(),
295 self.state,
296 state
297 );
298 self.state = state
299 }
300}
301
302impl<const TIMER_HZ: u32, const L: usize> Into<Socket<TIMER_HZ, L>> for TcpSocket<TIMER_HZ, L> {
303 fn into(self) -> Socket<TIMER_HZ, L> {
304 Socket::Tcp(self)
305 }
306}