embassy_net_esp_hosted/
lib.rs

1#![no_std]
2#![doc = include_str!("../README.md")]
3#![warn(missing_docs)]
4
5use embassy_futures::select::{select4, Either4};
6use embassy_net_driver_channel as ch;
7use embassy_net_driver_channel::driver::LinkState;
8use embassy_time::{Duration, Instant, Timer};
9use embedded_hal::digital::{InputPin, OutputPin};
10use embedded_hal_async::digital::Wait;
11use embedded_hal_async::spi::SpiDevice;
12
13use crate::ioctl::{PendingIoctl, Shared};
14use crate::proto::{CtrlMsg, CtrlMsgPayload};
15
16mod proto;
17
18// must be first
19mod fmt;
20
21mod control;
22mod ioctl;
23
24pub use control::*;
25
/// Frame size parameter shared by the network channel types in this file
/// (`ch::State`, `ch::Runner` and `ch::Device`).
// NOTE(review): 1514 is presumably a 1500-byte Ethernet payload plus the 14-byte
// Ethernet header — confirm.
const MTU: usize = 1514;
27
/// Implements raw byte-level conversions for the struct `$t`.
///
/// Generates `SIZE`, `to_bytes`, `from_bytes` and `from_bytes_mut`. The conversions
/// reinterpret memory directly, so the byte layout is the in-memory layout of `$t`.
// NOTE(review): only sound for types without padding whose every bit pattern is a
// valid value (plain integer fields); the macro itself does not enforce that.
macro_rules! impl_bytes {
    ($t:ident) => {
        impl $t {
            /// Size of the type in bytes.
            pub const SIZE: usize = core::mem::size_of::<Self>();

            /// Returns a copy of `self` as its in-memory bytes.
            #[allow(unused)]
            pub fn to_bytes(&self) -> [u8; Self::SIZE] {
                // SAFETY: source and destination have the same size by definition
                // of `SIZE`, and a byte array has no validity requirements.
                unsafe { core::mem::transmute::<Self, [u8; Self::SIZE]>(*self) }
            }

            /// Reinterprets `bytes` as a shared reference to `Self`.
            ///
            /// Panics if `bytes` is not aligned for `Self`.
            #[allow(unused)]
            pub fn from_bytes(bytes: &[u8; Self::SIZE]) -> &Self {
                let ptr = bytes.as_ptr();
                assert_eq!(
                    ptr.align_offset(core::mem::align_of::<Self>()),
                    0,
                    "{} is not aligned",
                    core::any::type_name::<Self>()
                );
                // SAFETY: the pointer is non-null, covers exactly `SIZE` bytes and
                // was just checked for alignment; the borrow of `bytes` bounds the
                // returned lifetime.
                unsafe { &*ptr.cast::<Self>() }
            }

            /// Reinterprets `bytes` as an exclusive reference to `Self`.
            ///
            /// Panics if `bytes` is not aligned for `Self`.
            #[allow(unused)]
            pub fn from_bytes_mut(bytes: &mut [u8; Self::SIZE]) -> &mut Self {
                let ptr = bytes.as_mut_ptr();
                assert_eq!(
                    ptr.align_offset(core::mem::align_of::<Self>()),
                    0,
                    "{} is not aligned",
                    core::any::type_name::<Self>()
                );
                // SAFETY: same reasoning as `from_bytes`; exclusivity is inherited
                // from the `&mut` borrow of `bytes`.
                unsafe { &mut *ptr.cast::<Self>() }
            }
        }
    };
}
65
/// Header placed at the start of every SPI frame, in both directions.
///
/// The layout is fixed (`repr(C, packed)`), so field order and types must not change.
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default)]
struct PayloadHeader {
    /// InterfaceType on lower 4 bits, number on higher 4 bits.
    if_type_and_num: u8,

    /// Flags.
    ///
    /// bit 0: more fragments.
    flags: u8,

    /// Length in bytes of the payload that follows the header.
    len: u16,
    /// Offset of the payload from the start of the frame; this driver always sends
    /// the header size here and ignores received frames with any other value.
    offset: u16,
    /// Wrapping 16-bit sum of all header and payload bytes, computed with this
    /// field set to zero.
    checksum: u16,
    /// Sequence number; the runner increments it for every frame it builds.
    seq_num: u16,
    /// Reserved; left at its default of zero by this driver.
    reserved2: u8,

    /// Packet type for HCI or PRIV interface, reserved otherwise
    hci_priv_packet_type: u8,
}
86impl_bytes!(PayloadHeader);
87
/// Interface selector stored in the low 4 bits of `PayloadHeader::if_type_and_num`.
// Not every variant is constructed by this driver, hence the `allow`.
#[allow(unused)]
#[repr(u8)]
enum InterfaceType {
    /// Network frames exchanged with the network channel.
    Sta = 0,
    // NOTE(review): presumably the soft-AP interface; received frames of this type
    // are reported as unknown by the runner.
    Ap = 1,
    /// Control traffic: ioctl requests/responses and events.
    Serial = 2,
    Hci = 3,
    Priv = 4,
    Test = 5,
}
98
/// Size in bytes of the transmit and receive buffers; every SPI transfer exchanges
/// exactly this many bytes.
const MAX_SPI_BUFFER_SIZE: usize = 1600;
100const HEARTBEAT_MAX_GAP: Duration = Duration::from_secs(20);
101
102/// State for the esp-hosted driver.
103pub struct State {
104    shared: Shared,
105    ch: ch::State<MTU, 4, 4>,
106}
107
108impl State {
109    /// Create a new state.
110    pub fn new() -> Self {
111        Self {
112            shared: Shared::new(),
113            ch: ch::State::new(),
114        }
115    }
116}
117
118/// Type alias for network driver.
119pub type NetDriver<'a> = ch::Device<'a, MTU>;
120
121/// Create a new esp-hosted driver using the provided state, SPI peripheral and pins.
122///
123/// Returns a device handle for interfacing with embassy-net, a control handle for
124/// interacting with the driver, and a runner for communicating with the WiFi device.
125pub async fn new<'a, SPI, IN, OUT>(
126    state: &'a mut State,
127    spi: SPI,
128    handshake: IN,
129    ready: IN,
130    reset: OUT,
131) -> (NetDriver<'a>, Control<'a>, Runner<'a, SPI, IN, OUT>)
132where
133    SPI: SpiDevice,
134    IN: InputPin + Wait,
135    OUT: OutputPin,
136{
137    let (ch_runner, device) = ch::new(&mut state.ch, ch::driver::HardwareAddress::Ethernet([0; 6]));
138    let state_ch = ch_runner.state_runner();
139
140    let runner = Runner {
141        ch: ch_runner,
142        state_ch,
143        shared: &state.shared,
144        next_seq: 1,
145        handshake,
146        ready,
147        reset,
148        spi,
149        heartbeat_deadline: Instant::now() + HEARTBEAT_MAX_GAP,
150    };
151
152    (device, Control::new(state_ch, &state.shared), runner)
153}
154
155/// Runner for communicating with the WiFi device.
156pub struct Runner<'a, SPI, IN, OUT> {
157    ch: ch::Runner<'a, MTU>,
158    state_ch: ch::StateRunner<'a>,
159    shared: &'a Shared,
160
161    next_seq: u16,
162    heartbeat_deadline: Instant,
163
164    spi: SPI,
165    handshake: IN,
166    ready: IN,
167    reset: OUT,
168}
169
170impl<'a, SPI, IN, OUT> Runner<'a, SPI, IN, OUT>
171where
172    SPI: SpiDevice,
173    IN: InputPin + Wait,
174    OUT: OutputPin,
175{
176    /// Run the packet processing.
177    pub async fn run(mut self) -> ! {
178        debug!("resetting...");
179        self.reset.set_low().unwrap();
180        Timer::after_millis(100).await;
181        self.reset.set_high().unwrap();
182        Timer::after_millis(1000).await;
183
184        let mut tx_buf = [0u8; MAX_SPI_BUFFER_SIZE];
185        let mut rx_buf = [0u8; MAX_SPI_BUFFER_SIZE];
186
187        loop {
188            self.handshake.wait_for_high().await.unwrap();
189
190            let ioctl = self.shared.ioctl_wait_pending();
191            let tx = self.ch.tx_buf();
192            let ev = async { self.ready.wait_for_high().await.unwrap() };
193            let hb = Timer::at(self.heartbeat_deadline);
194
195            match select4(ioctl, tx, ev, hb).await {
196                Either4::First(PendingIoctl { buf, req_len }) => {
197                    tx_buf[12..24].copy_from_slice(b"\x01\x08\x00ctrlResp\x02");
198                    tx_buf[24..26].copy_from_slice(&(req_len as u16).to_le_bytes());
199                    tx_buf[26..][..req_len].copy_from_slice(&unsafe { &*buf }[..req_len]);
200
201                    let mut header = PayloadHeader {
202                        if_type_and_num: InterfaceType::Serial as _,
203                        len: (req_len + 14) as _,
204                        offset: PayloadHeader::SIZE as _,
205                        seq_num: self.next_seq,
206                        ..Default::default()
207                    };
208                    self.next_seq = self.next_seq.wrapping_add(1);
209
210                    // Calculate checksum
211                    tx_buf[0..12].copy_from_slice(&header.to_bytes());
212                    header.checksum = checksum(&tx_buf[..26 + req_len]);
213                    tx_buf[0..12].copy_from_slice(&header.to_bytes());
214                }
215                Either4::Second(packet) => {
216                    tx_buf[12..][..packet.len()].copy_from_slice(packet);
217
218                    let mut header = PayloadHeader {
219                        if_type_and_num: InterfaceType::Sta as _,
220                        len: packet.len() as _,
221                        offset: PayloadHeader::SIZE as _,
222                        seq_num: self.next_seq,
223                        ..Default::default()
224                    };
225                    self.next_seq = self.next_seq.wrapping_add(1);
226
227                    // Calculate checksum
228                    tx_buf[0..12].copy_from_slice(&header.to_bytes());
229                    header.checksum = checksum(&tx_buf[..12 + packet.len()]);
230                    tx_buf[0..12].copy_from_slice(&header.to_bytes());
231
232                    self.ch.tx_done();
233                }
234                Either4::Third(()) => {
235                    tx_buf[..PayloadHeader::SIZE].fill(0);
236                }
237                Either4::Fourth(()) => {
238                    panic!("heartbeat from esp32 stopped")
239                }
240            }
241
242            if tx_buf[0] != 0 {
243                trace!("tx: {:02x}", &tx_buf[..40]);
244            }
245
246            self.spi.transfer(&mut rx_buf, &tx_buf).await.unwrap();
247
248            // The esp-hosted firmware deasserts the HANSHAKE pin a few us AFTER ending the SPI transfer
249            // If we check it again too fast, we'll see it's high from the previous transfer, and if we send it
250            // data it will get lost.
251            // Make sure we check it after 100us at minimum.
252            let delay_until = Instant::now() + Duration::from_micros(100);
253            self.handle_rx(&mut rx_buf);
254            Timer::at(delay_until).await;
255        }
256    }
257
258    fn handle_rx(&mut self, buf: &mut [u8]) {
259        trace!("rx: {:02x}", &buf[..40]);
260
261        let buf_len = buf.len();
262        let h = PayloadHeader::from_bytes_mut((&mut buf[..PayloadHeader::SIZE]).try_into().unwrap());
263
264        if h.len == 0 || h.offset as usize != PayloadHeader::SIZE {
265            return;
266        }
267
268        let payload_len = h.len as usize;
269        if buf_len < PayloadHeader::SIZE + payload_len {
270            warn!("rx: len too big");
271            return;
272        }
273
274        let if_type_and_num = h.if_type_and_num;
275        let want_checksum = h.checksum;
276        h.checksum = 0;
277        let got_checksum = checksum(&buf[..PayloadHeader::SIZE + payload_len]);
278        if want_checksum != got_checksum {
279            warn!("rx: bad checksum. Got {:04x}, want {:04x}", got_checksum, want_checksum);
280            return;
281        }
282
283        let payload = &mut buf[PayloadHeader::SIZE..][..payload_len];
284
285        match if_type_and_num & 0x0f {
286            // STA
287            0 => match self.ch.try_rx_buf() {
288                Some(buf) => {
289                    buf[..payload.len()].copy_from_slice(payload);
290                    self.ch.rx_done(payload.len())
291                }
292                None => warn!("failed to push rxd packet to the channel."),
293            },
294            // serial
295            2 => {
296                trace!("serial rx: {:02x}", payload);
297                if payload.len() < 14 {
298                    warn!("serial rx: too short");
299                    return;
300                }
301
302                let is_event = match &payload[..12] {
303                    b"\x01\x08\x00ctrlResp\x02" => false,
304                    b"\x01\x08\x00ctrlEvnt\x02" => true,
305                    _ => {
306                        warn!("serial rx: bad tlv");
307                        return;
308                    }
309                };
310
311                let len = u16::from_le_bytes(payload[12..14].try_into().unwrap()) as usize;
312                if payload.len() < 14 + len {
313                    warn!("serial rx: too short 2");
314                    return;
315                }
316                let data = &payload[14..][..len];
317
318                if is_event {
319                    self.handle_event(data);
320                } else {
321                    self.shared.ioctl_done(data);
322                }
323            }
324            _ => warn!("unknown iftype {}", if_type_and_num),
325        }
326    }
327
328    fn handle_event(&mut self, data: &[u8]) {
329        let Ok(event) = noproto::read::<CtrlMsg>(data) else {
330            warn!("failed to parse event");
331            return;
332        };
333
334        debug!("event: {:?}", &event);
335
336        let Some(payload) = &event.payload else {
337            warn!("event without payload?");
338            return;
339        };
340
341        match payload {
342            CtrlMsgPayload::EventEspInit(_) => self.shared.init_done(),
343            CtrlMsgPayload::EventHeartbeat(_) => self.heartbeat_deadline = Instant::now() + HEARTBEAT_MAX_GAP,
344            CtrlMsgPayload::EventStationDisconnectFromAp(e) => {
345                info!("disconnected, code {}", e.resp);
346                self.state_ch.set_link_state(LinkState::Down);
347            }
348            _ => {}
349        }
350    }
351}
352
/// Returns the wrapping 16-bit sum of all bytes in `buf` (0 for an empty slice).
fn checksum(buf: &[u8]) -> u16 {
    buf.iter()
        .fold(0u16, |sum, &byte| sum.wrapping_add(u16::from(byte)))
}