1#![no_std]
17
18use core::{
19 convert::TryFrom,
20 net::{IpAddr, Ipv4Addr, SocketAddr},
21};
22pub use embedded_nal;
23use nanorand::{Rng, SeedableRng};
24pub use smoltcp;
25
26use embedded_nal::{TcpClientStack, UdpClientStack, UdpFullStack};
27use embedded_time::duration::Milliseconds;
28use smoltcp::{
29 iface::SocketHandle,
30 socket::dhcpv4,
31 wire::{IpAddress, IpCidr, IpEndpoint, Ipv4Address, Ipv4Cidr},
32};
33
34use heapless::Vec;
35use nanorand::wyrand::WyRand;
36
37#[cfg(feature = "shared-stack")]
38pub mod shared;
39
/// First port of the dynamic/private port range (49152 = 0xC000); ephemeral
/// local ports are drawn from this value up to `u16::MAX`.
const TCP_PORT_DYNAMIC_RANGE_START: u16 = 0xC000;
42
/// Errors originating from the underlying smoltcp interface.
#[derive(Clone, Copy, Debug)]
pub enum SmoltcpError {
    /// The interface routing table had no room for a new route.
    RouteTableFull,
}
47
48#[derive(Debug, Copy, Clone)]
49pub enum NetworkError {
50 NoSocket,
51 DnsStart(smoltcp::socket::dns::StartQueryError),
52 DnsFailure,
53 UdpConnectionFailure(smoltcp::socket::udp::BindError),
54 TcpConnectionFailure(smoltcp::socket::tcp::ConnectError),
55 TcpReadFailure(smoltcp::socket::tcp::RecvError),
56 TcpWriteFailure(smoltcp::socket::tcp::SendError),
57 UdpReadFailure(smoltcp::socket::udp::RecvError),
58 UdpWriteFailure(smoltcp::socket::udp::SendError),
59 Unsupported,
60 NotConnected,
61 NoAddress,
62}
63
64impl embedded_nal::TcpError for NetworkError {
65 fn kind(&self) -> embedded_nal::TcpErrorKind {
66 match self {
67 NetworkError::TcpReadFailure(_) => embedded_nal::TcpErrorKind::PipeClosed,
68 NetworkError::TcpWriteFailure(_) => embedded_nal::TcpErrorKind::PipeClosed,
69 _ => embedded_nal::TcpErrorKind::Other,
70 }
71 }
72}
73
74impl From<smoltcp::iface::RouteTableFull> for SmoltcpError {
75 fn from(_: smoltcp::iface::RouteTableFull) -> SmoltcpError {
76 SmoltcpError::RouteTableFull
77 }
78}
79
80#[derive(Debug)]
82pub enum Error {
83 Network(SmoltcpError),
84 Time(embedded_time::TimeError),
85}
86
87impl From<embedded_time::TimeError> for Error {
88 fn from(e: embedded_time::TimeError) -> Self {
89 Error::Time(e)
90 }
91}
92
93impl From<embedded_time::clock::Error> for Error {
94 fn from(e: embedded_time::clock::Error) -> Self {
95 Error::Time(e.into())
96 }
97}
98
99impl From<embedded_time::ConversionError> for Error {
100 fn from(e: embedded_time::ConversionError) -> Self {
101 Error::Time(e.into())
102 }
103}
104
105#[derive(Debug)]
106pub struct UdpSocket {
107 handle: SocketHandle,
108 destination: Option<IpEndpoint>,
109}
110
111pub struct NetworkStack<'a, Device, Clock>
113where
114 Device: smoltcp::phy::Device,
115 Clock: embedded_time::Clock,
116 u32: From<Clock::T>,
117{
118 network_interface: smoltcp::iface::Interface,
119 device: Device,
120 sockets: smoltcp::iface::SocketSet<'a>,
121 dhcp_handle: Option<SocketHandle>,
122 dns_handle: Option<SocketHandle>,
123 dns_lookups: heapless::LinearMap<heapless::String<255>, smoltcp::socket::dns::QueryHandle, 2>,
124 unused_tcp_handles: Vec<SocketHandle, 16>,
125 unused_udp_handles: Vec<SocketHandle, 16>,
126 clock: Clock,
127 last_poll: Option<embedded_time::Instant<Clock>>,
128 stack_time: smoltcp::time::Instant,
129 rand: WyRand,
130}
131
132impl<'a, Device, Clock> NetworkStack<'a, Device, Clock>
133where
134 Device: smoltcp::phy::Device,
135 Clock: embedded_time::Clock,
136 u32: From<Clock::T>,
137{
138 pub fn new(
155 stack: smoltcp::iface::Interface,
156 device: Device,
157 sockets: smoltcp::iface::SocketSet<'a>,
158 clock: Clock,
159 ) -> Self {
160 let mut unused_tcp_handles: Vec<SocketHandle, 16> = Vec::new();
161 let mut unused_udp_handles: Vec<SocketHandle, 16> = Vec::new();
162 let mut dhcp_handle: Option<SocketHandle> = None;
163 let mut dns_handle: Option<SocketHandle> = None;
164
165 for (handle, socket) in sockets.iter() {
166 match socket {
167 smoltcp::socket::Socket::Tcp(_) => {
168 unused_tcp_handles.push(handle).ok();
169 }
170 smoltcp::socket::Socket::Udp(_) => {
171 unused_udp_handles.push(handle).ok();
172 }
173 smoltcp::socket::Socket::Dhcpv4(_) => {
174 dhcp_handle.replace(handle);
175 }
176 smoltcp::socket::Socket::Dns(_) => {
177 dns_handle.replace(handle);
178 }
179
180 #[allow(unreachable_patterns)]
184 _ => {}
185 }
186 }
187
188 NetworkStack {
189 network_interface: stack,
190 sockets,
191 device,
192 dhcp_handle,
193 dns_handle,
194 unused_tcp_handles,
195 unused_udp_handles,
196 last_poll: None,
197 dns_lookups: heapless::LinearMap::new(),
198 clock,
199 stack_time: smoltcp::time::Instant::from_secs(0),
200 rand: WyRand::new_seed(0),
201 }
202 }
203
204 pub fn seed_random_port(&mut self, seed: &[u8]) {
209 let mut s = [0; 8];
210 let n = seed.len().min(s.len());
211 s[..n].copy_from_slice(&seed[..n]);
212 self.rand.reseed(s);
213 }
214
215 pub fn poll(&mut self) -> Result<bool, Error> {
220 let now = self.clock.try_now()?;
221
222 if self.last_poll.is_none() {
226 self.last_poll.replace(now);
227 }
228
229 let elapsed_system_time = now - *self.last_poll.as_ref().unwrap();
231
232 let elapsed_ms: Milliseconds<u32> = Milliseconds::try_from(elapsed_system_time)?;
233
234 if elapsed_ms.0 > 0 {
235 self.stack_time += smoltcp::time::Duration::from_millis(elapsed_ms.0.into());
236
237 self.last_poll.replace(self.last_poll.unwrap() + elapsed_ms);
245 }
246
247 let updated =
248 self.network_interface
249 .poll(self.stack_time, &mut self.device, &mut self.sockets);
250
251 if let Some(handle) = self.dhcp_handle {
253 let mut close_sockets = false;
254 let mut dns_server = None;
255
256 if let Some(event) = self.sockets.get_mut::<dhcpv4::Socket>(handle).poll() {
257 match event {
258 dhcpv4::Event::Configured(config) => {
259 if config.address.address().is_unicast()
260 && self.network_interface.ipv4_addr().unwrap()
261 != config.address.address()
262 {
263 close_sockets = true;
264 Self::set_ipv4_addr(&mut self.network_interface, config.address);
265 }
266
267 if let Some(server) = config
268 .dns_servers
269 .iter()
270 .next()
271 .map(|ipv4| smoltcp::wire::IpAddress::Ipv4(*ipv4))
272 {
273 dns_server.replace(server);
274 }
275
276 if let Some(route) = config.router {
277 self.network_interface
280 .routes_mut()
281 .add_default_ipv4_route(route)
282 .map_err(|e| Error::Network(e.into()))?;
283 } else {
284 self.network_interface
285 .routes_mut()
286 .remove_default_ipv4_route();
287 }
288 }
289 dhcpv4::Event::Deconfigured => {
290 self.network_interface
291 .routes_mut()
292 .remove_default_ipv4_route();
293 Self::set_ipv4_addr(
294 &mut self.network_interface,
295 Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0),
296 );
297 }
298 }
299 }
300
301 if close_sockets {
302 self.close_sockets();
303 }
304
305 if let Some((server, handle)) = dns_server.zip(self.dns_handle) {
306 let dns = self.sockets.get_mut::<smoltcp::socket::dns::Socket>(handle);
307
308 for (_query, handle) in self.dns_lookups.iter() {
310 dns.cancel_query(*handle);
311 }
312 self.dns_lookups.clear();
313
314 dns.update_servers(&[server]);
315 }
316 }
317
318 Ok(updated)
319 }
320
321 pub fn close_sockets(&mut self) {
323 for (_handle, socket) in self.sockets.iter_mut() {
325 match socket {
326 smoltcp::socket::Socket::Udp(sock) => {
327 sock.close();
328 }
329 smoltcp::socket::Socket::Tcp(sock) => {
330 sock.abort();
331 }
332
333 _ => {}
334 }
335 }
336 }
337
338 fn set_ipv4_addr(interface: &mut smoltcp::iface::Interface, address: Ipv4Cidr) {
339 interface.update_ip_addrs(|addrs| {
340 match addrs
342 .iter_mut()
343 .find(|cidr| matches!(cidr.address(), IpAddress::Ipv4(_)))
344 {
345 Some(addr) => *addr = IpCidr::Ipv4(address),
346 None => addrs.push(IpCidr::Ipv4(address)).unwrap(),
347 }
348 });
349 }
350
351 pub fn handle_link_reset(&mut self) {
353 self.close_sockets();
355
356 if let Some(handle) = self.dhcp_handle {
358 self.sockets.get_mut::<dhcpv4::Socket>(handle).reset();
359
360 self.network_interface.update_ip_addrs(|addrs| {
361 if let Some(addr) = addrs.iter_mut().next() {
362 *addr = IpCidr::Ipv4(Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0));
363 };
364 });
365 }
366 }
367
368 pub fn interface(&self) -> &smoltcp::iface::Interface {
370 &self.network_interface
371 }
372
373 pub fn interface_mut(&mut self) -> &mut smoltcp::iface::Interface {
380 &mut self.network_interface
381 }
382
383 fn is_port_in_use(&mut self, port: u16) -> bool {
388 for (_handle, socket) in self.sockets.iter_mut() {
389 match socket {
390 smoltcp::socket::Socket::Tcp(sock) => {
391 if sock
392 .local_endpoint()
393 .map(|endpoint| endpoint.port == port)
394 .unwrap_or(false)
395 {
396 return true;
397 }
398 }
399 smoltcp::socket::Socket::Udp(sock) => {
400 let endpoint = sock.endpoint();
401 if endpoint.is_specified() && endpoint.port == port {
402 return true;
403 }
404 }
405 _ => {}
406 }
407 }
408
409 false
410 }
411
412 fn get_ephemeral_port(&mut self) -> u16 {
414 loop {
415 let random_offset = {
418 let random_data = self.rand.rand();
419 u16::from_be_bytes([random_data[0], random_data[1]])
420 };
421
422 let port = TCP_PORT_DYNAMIC_RANGE_START
423 + random_offset % (u16::MAX - TCP_PORT_DYNAMIC_RANGE_START);
424 if !self.is_port_in_use(port) {
425 return port;
426 }
427 }
428 }
429}
430
431impl<Device, Clock> TcpClientStack for NetworkStack<'_, Device, Clock>
432where
433 Device: smoltcp::phy::Device,
434 Clock: embedded_time::Clock,
435 u32: From<Clock::T>,
436{
437 type Error = NetworkError;
438 type TcpSocket = SocketHandle;
439
440 fn socket(&mut self) -> Result<SocketHandle, NetworkError> {
441 match self.unused_tcp_handles.pop() {
442 Some(handle) => {
443 let internal_socket: &mut smoltcp::socket::tcp::Socket =
445 self.sockets.get_mut(handle);
446 internal_socket.abort();
447
448 Ok(handle)
449 }
450 None => Err(NetworkError::NoSocket),
451 }
452 }
453
454 fn connect(
455 &mut self,
456 socket: &mut SocketHandle,
457 remote: SocketAddr,
458 ) -> embedded_nal::nb::Result<(), NetworkError> {
459 let dest_addr = match remote.ip() {
460 IpAddr::V4(addr) => {
461 let octets = addr.octets();
462 smoltcp::wire::Ipv4Address::new(octets[0], octets[1], octets[2], octets[3])
463 }
464
465 _ => return Err(embedded_nal::nb::Error::Other(NetworkError::Unsupported)),
467 };
468
469 let local_port = self.get_ephemeral_port();
470 let internal_socket = self
471 .sockets
472 .get_mut::<smoltcp::socket::tcp::Socket>(*socket);
473
474 if !internal_socket.is_open() {
475 let context = self.network_interface.context();
476 internal_socket
477 .connect(context, (dest_addr, remote.port()), local_port)
478 .map_err(|e| {
479 embedded_nal::nb::Error::Other(NetworkError::TcpConnectionFailure(e))
480 })?;
481 }
482
483 if internal_socket.state() == smoltcp::socket::tcp::State::Established {
484 Ok(())
485 } else {
486 Err(embedded_nal::nb::Error::WouldBlock)
487 }
488 }
489
490 fn send(
491 &mut self,
492 socket: &mut SocketHandle,
493 buffer: &[u8],
494 ) -> embedded_nal::nb::Result<usize, NetworkError> {
495 let socket: &mut smoltcp::socket::tcp::Socket = self.sockets.get_mut(*socket);
496 socket
497 .send_slice(buffer)
498 .map_err(|e| embedded_nal::nb::Error::Other(NetworkError::TcpWriteFailure(e)))
499 }
500
501 fn receive(
502 &mut self,
503 socket: &mut SocketHandle,
504 buffer: &mut [u8],
505 ) -> embedded_nal::nb::Result<usize, NetworkError> {
506 let socket: &mut smoltcp::socket::tcp::Socket = self.sockets.get_mut(*socket);
507 socket
508 .recv_slice(buffer)
509 .map_err(|e| embedded_nal::nb::Error::Other(NetworkError::TcpReadFailure(e)))
510 }
511
512 fn close(&mut self, socket: SocketHandle) -> Result<(), NetworkError> {
513 let internal_socket: &mut smoltcp::socket::tcp::Socket = self.sockets.get_mut(socket);
514
515 internal_socket.close();
516 self.unused_tcp_handles.push(socket).unwrap();
517 Ok(())
518 }
519}
520
521impl<Device, Clock> UdpClientStack for NetworkStack<'_, Device, Clock>
522where
523 Device: smoltcp::phy::Device,
524 Clock: embedded_time::Clock,
525 u32: From<Clock::T>,
526{
527 type Error = NetworkError;
528 type UdpSocket = UdpSocket;
529
530 fn socket(&mut self) -> Result<UdpSocket, NetworkError> {
531 let handle = self
532 .unused_udp_handles
533 .pop()
534 .ok_or(NetworkError::NoSocket)?;
535
536 let internal_socket: &mut smoltcp::socket::udp::Socket = self.sockets.get_mut(handle);
538 internal_socket.close();
539
540 Ok(UdpSocket {
541 handle,
542 destination: None,
543 })
544 }
545
546 fn connect(&mut self, socket: &mut UdpSocket, remote: SocketAddr) -> Result<(), NetworkError> {
547 match remote {
549 SocketAddr::V4(addr) => {
550 let octets = addr.ip().octets();
551 socket.destination.replace(IpEndpoint::new(
552 IpAddress::v4(octets[0], octets[1], octets[2], octets[3]),
553 addr.port(),
554 ));
555 }
556
557 _ => return Err(NetworkError::Unsupported),
559 }
560
561 let local_port = self.get_ephemeral_port();
563
564 let Some(cidr) = self
565 .network_interface
566 .ip_addrs()
567 .iter()
568 .find(|item| matches!(item, smoltcp::wire::IpCidr::Ipv4(_)))
569 else {
570 return Err(NetworkError::NoAddress);
571 };
572
573 let local_endpoint = IpEndpoint::new(cidr.address(), local_port);
574
575 let internal_socket: &mut smoltcp::socket::udp::Socket =
576 self.sockets.get_mut(socket.handle);
577 internal_socket
578 .bind(local_endpoint)
579 .map_err(NetworkError::UdpConnectionFailure)?;
580
581 Ok(())
582 }
583
584 fn send(
585 &mut self,
586 socket: &mut UdpSocket,
587 buffer: &[u8],
588 ) -> embedded_nal::nb::Result<(), NetworkError> {
589 let internal_socket: &mut smoltcp::socket::udp::Socket =
590 self.sockets.get_mut(socket.handle);
591 let destination = socket.destination.ok_or(NetworkError::NotConnected)?;
592 internal_socket
593 .send_slice(buffer, destination)
594 .map_err(|e| embedded_nal::nb::Error::Other(NetworkError::UdpWriteFailure(e)))
595 }
596
597 fn receive(
598 &mut self,
599 socket: &mut UdpSocket,
600 buffer: &mut [u8],
601 ) -> embedded_nal::nb::Result<(usize, SocketAddr), NetworkError> {
602 let internal_socket: &mut smoltcp::socket::udp::Socket =
603 self.sockets.get_mut(socket.handle);
604 let (size, source) = internal_socket
605 .recv_slice(buffer)
606 .map_err(|e| embedded_nal::nb::Error::Other(NetworkError::UdpReadFailure(e)))?;
607
608 let source = {
609 let octets = source.endpoint.addr.as_bytes();
610
611 SocketAddr::new(
612 IpAddr::V4(Ipv4Addr::new(octets[0], octets[1], octets[2], octets[3])),
613 source.endpoint.port,
614 )
615 };
616
617 Ok((size, source))
618 }
619
620 fn close(&mut self, socket: UdpSocket) -> Result<(), NetworkError> {
621 let internal_socket: &mut smoltcp::socket::udp::Socket =
622 self.sockets.get_mut(socket.handle);
623
624 internal_socket.close();
625
626 self.unused_udp_handles.push(socket.handle).unwrap();
628
629 Ok(())
630 }
631}
632
633impl<Device, Clock> UdpFullStack for NetworkStack<'_, Device, Clock>
634where
635 Device: smoltcp::phy::Device,
636 Clock: embedded_time::Clock,
637 u32: From<Clock::T>,
638{
639 fn bind(&mut self, socket: &mut UdpSocket, local_port: u16) -> Result<(), NetworkError> {
641 let Some(cidr) = self
642 .network_interface
643 .ip_addrs()
644 .iter()
645 .find(|item| matches!(item, smoltcp::wire::IpCidr::Ipv4(_)))
646 else {
647 return Err(NetworkError::NoAddress);
648 };
649
650 let local_endpoint = IpEndpoint::new(cidr.address(), local_port);
651
652 let internal_socket: &mut smoltcp::socket::udp::Socket =
653 self.sockets.get_mut(socket.handle);
654 internal_socket
655 .bind(local_endpoint)
656 .map_err(NetworkError::UdpConnectionFailure)?;
657
658 Ok(())
659 }
660
661 fn send_to(
663 &mut self,
664 socket: &mut Self::UdpSocket,
665 remote: SocketAddr,
666 buffer: &[u8],
667 ) -> embedded_nal::nb::Result<(), NetworkError> {
668 let destination = match remote {
669 SocketAddr::V4(addr) => {
670 let octets = addr.ip().octets();
671 IpEndpoint::new(
672 IpAddress::v4(octets[0], octets[1], octets[2], octets[3]),
673 addr.port(),
674 )
675 }
676 _ => return Err(embedded_nal::nb::Error::Other(NetworkError::Unsupported)),
678 };
679
680 let internal_socket: &mut smoltcp::socket::udp::Socket =
681 self.sockets.get_mut(socket.handle);
682 internal_socket
683 .send_slice(buffer, destination)
684 .map_err(|e| embedded_nal::nb::Error::Other(NetworkError::UdpWriteFailure(e)))
685 }
686}
687
688impl<Device, Clock> embedded_nal::Dns for NetworkStack<'_, Device, Clock>
689where
690 Device: smoltcp::phy::Device,
691 Clock: embedded_time::Clock,
692 u32: From<Clock::T>,
693{
694 type Error = NetworkError;
695 fn get_host_by_name(
696 &mut self,
697 hostname: &str,
698 _addr_type: embedded_nal::AddrType,
699 ) -> embedded_nal::nb::Result<IpAddr, Self::Error> {
700 let handle = self.dns_handle.ok_or(NetworkError::Unsupported)?;
701 let dns_socket: &mut smoltcp::socket::dns::Socket = self.sockets.get_mut(handle);
702 let context = self.network_interface.context();
703 let key = heapless::String::from(hostname);
704
705 if let Some(handle) = self.dns_lookups.get(&key) {
706 match dns_socket.get_query_result(*handle) {
707 Ok(addrs) => {
708 self.dns_lookups.remove(&key);
709 let addr = addrs.iter().next().ok_or(NetworkError::DnsFailure)?;
710 let smoltcp::wire::IpAddress::Ipv4(addr) = addr else {
711 panic!("Unexpected address return type");
712 };
713 return Ok(IpAddr::V4(addr.0.into()));
714 }
715 Err(smoltcp::socket::dns::GetQueryResultError::Pending) => {}
716 Err(smoltcp::socket::dns::GetQueryResultError::Failed) => {
717 self.dns_lookups.remove(&key);
718 return Err(embedded_nal::nb::Error::Other(NetworkError::DnsFailure));
719 }
720 }
721 } else {
722 let dns_query = dns_socket
724 .start_query(context, hostname, smoltcp::wire::DnsQueryType::A)
725 .map_err(NetworkError::DnsStart)?;
726 if self.dns_lookups.insert(key, dns_query).is_err() {
727 dns_socket.cancel_query(dns_query);
728 return Err(embedded_nal::nb::Error::Other(NetworkError::Unsupported));
729 }
730 }
731
732 Err(embedded_nal::nb::Error::WouldBlock)
733 }
734
735 fn get_host_by_address(
736 &mut self,
737 _addr: IpAddr,
738 _: &mut [u8],
739 ) -> Result<usize, embedded_nal::nb::Error<Self::Error>> {
740 unimplemented!()
741 }
742}