Skip to main content

wasefire_protocol_usb/
device.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use alloc::boxed::Box;
16use alloc::collections::VecDeque;
17use alloc::vec::Vec;
18use core::marker::PhantomData;
19
20use usb_device::class_prelude::{
21    ControlIn, ControlOut, InterfaceNumber, StringIndex, UsbBus, UsbBusAllocator, UsbClass,
22};
23use usb_device::descriptor::{BosWriter, DescriptorWriter};
24use usb_device::endpoint::{EndpointAddress, EndpointIn, EndpointOut};
25use usb_device::{LangID, UsbError};
26use wasefire_board_api::Error;
27use wasefire_board_api::platform::protocol::{Api, Event};
28use wasefire_error::Code;
29use wasefire_logger as log;
30
31use crate::common::{Decoder, Encoder};
32
/// Implementation of the platform protocol [`Api`] on top of the USB [`Rpc`] class.
///
/// This type can never be constructed (it holds a value of the never type `!`). It is only used
/// at the type level: `T` provides access to the [`Rpc`] instance and to the vendor handler.
pub struct Impl<'a, B: UsbBus, T: HasRpc<'a, B>> {
    _never: !,
    _phantom: PhantomData<(&'a (), B, T)>,
}
37
/// Gives [`Impl`] access to an [`Rpc`] instance and to a vendor request handler.
///
/// NOTE(review): where the `Rpc` instance lives (e.g. a global shared with the USB interrupt) is
/// up to the implementor and not visible here — confirm against implementations.
pub trait HasRpc<'a, B: UsbBus> {
    /// Runs `f` with mutable access to the [`Rpc`] instance and returns its result.
    fn with_rpc<R>(f: impl FnOnce(&mut Rpc<'a, B>) -> R) -> R;
    /// Handles a vendor-specific `request` and returns its response (used by `Api::vendor`).
    fn vendor(request: &[u8]) -> Result<Box<[u8]>, Error>;
}
42
43impl<'a, B: UsbBus, T: HasRpc<'a, B>> Api for Impl<'a, B, T> {
44    fn read() -> Result<Option<Box<[u8]>>, Error> {
45        T::with_rpc(|x| x.read())
46    }
47
48    fn write(response: &[u8]) -> Result<(), Error> {
49        T::with_rpc(|x| x.write(response))
50    }
51
52    fn enable() -> Result<(), Error> {
53        T::with_rpc(|x| x.enable())
54    }
55
56    fn vendor(request: &[u8]) -> Result<Box<[u8]>, Error> {
57        T::vendor(request)
58    }
59}
60
/// USB class carrying the request/response protocol over a pair of bulk endpoints.
pub struct Rpc<'a, B: UsbBus> {
    /// Interface number allocated for this class.
    interface: InterfaceNumber,
    /// Bulk OUT endpoint on which request packets are received.
    read_ep: EndpointOut<'a, B>,
    /// Bulk IN endpoint on which response packets are sent.
    write_ep: EndpointIn<'a, B>,
    /// Current state of the protocol state machine.
    state: State,
}
67
68impl<'a, B: UsbBus> Rpc<'a, B> {
69    pub fn new(usb_bus: &'a UsbBusAllocator<B>) -> Self {
70        let interface = usb_bus.interface();
71        let read_ep = usb_bus.bulk(MAX_PACKET_SIZE);
72        let write_ep = usb_bus.bulk(MAX_PACKET_SIZE);
73        Rpc { interface, read_ep, write_ep, state: State::Disabled }
74    }
75
76    pub fn read(&mut self) -> Result<Option<Box<[u8]>>, Error> {
77        let result = self.state.read()?;
78        match &result {
79            #[cfg(not(feature = "defmt"))]
80            Some(result) => log::debug!("Reading {:02x?}", result),
81            #[cfg(feature = "defmt")]
82            Some(result) => log::debug!("Reading {=[u8]:02x}", result),
83            None => log::debug!("Reading (no message)"),
84        }
85        Ok(result)
86    }
87
88    pub fn write(&mut self, response: &[u8]) -> Result<(), Error> {
89        #[cfg(not(feature = "defmt"))]
90        log::debug!("Writing {:02x?}", response);
91        #[cfg(feature = "defmt")]
92        log::debug!("Writing {=[u8]:02x}", response);
93        self.state.write(response, &self.write_ep)
94    }
95
96    pub fn enable(&mut self) -> Result<(), Error> {
97        match self.state {
98            State::Disabled => {
99                self.state = WaitRequest;
100                Ok(())
101            }
102            _ => Err(Error::user(Code::InvalidState)),
103        }
104    }
105
106    pub fn tick(&mut self, push: impl FnOnce(Event)) {
107        if self.state.notify() {
108            push(Event);
109        }
110    }
111}
112
/// Size in bytes of the bulk endpoints' packets. A packet of any other length received or sent
/// aborts the current message.
const MAX_PACKET_SIZE: u16 = 64;
114
115enum State {
116    Disabled,
117    WaitRequest,
118    ReceiveRequest { decoder: Decoder },
119    RequestReady { notified: bool, request: Vec<u8> },
120    WaitResponse,
121    SendResponse { packets: VecDeque<[u8; 64]> },
122}
123use State::*;
124
125impl State {
126    fn read(&mut self) -> Result<Option<Box<[u8]>>, Error> {
127        match self {
128            RequestReady { request, .. } => {
129                let request = core::mem::take(request);
130                log::debug!("Received a message of {} bytes.", request.len());
131                *self = WaitResponse;
132                Ok(Some(request.into_boxed_slice()))
133            }
134            WaitRequest | ReceiveRequest { .. } | SendResponse { .. } => Ok(None),
135            WaitResponse | Disabled => Err(Error::user(Code::InvalidState)),
136        }
137    }
138
139    fn write<B: UsbBus>(&mut self, response: &[u8], ep: &EndpointIn<B>) -> Result<(), Error> {
140        if !matches!(self, WaitResponse) {
141            return Err(Error::user(Code::InvalidState));
142        }
143        let packets: VecDeque<_> = Encoder::new(response).collect();
144        log::debug!("Sending a message of {} bytes in {} packets.", response.len(), packets.len());
145        *self = SendResponse { packets };
146        self.send(ep);
147        Ok(())
148    }
149
150    fn receive<B: UsbBus>(&mut self, ep: &EndpointOut<B>) {
151        let decoder = match self {
152            ReceiveRequest { decoder } => decoder,
153            Disabled => {
154                log::error!("Not receiving data while disabled.");
155                return;
156            }
157            _ => {
158                *self = ReceiveRequest { decoder: Decoder::default() };
159                match self {
160                    ReceiveRequest { decoder } => decoder,
161                    _ => unreachable!(),
162                }
163            }
164        };
165        let mut packet = [0; MAX_PACKET_SIZE as usize];
166        let len = ep.read(&mut packet).unwrap();
167        if len != MAX_PACKET_SIZE as usize {
168            log::warn!("Received a packet of {} bytes instead of 64.", len);
169            *self = WaitRequest;
170            return;
171        }
172        match core::mem::take(decoder).push(&packet) {
173            None => {
174                log::warn!("Received invalid packet 0x{:02x}", packet[0]);
175                *self = WaitRequest;
176            }
177            Some(Ok(request)) => {
178                log::trace!("Received a message of {} bytes.", request.len());
179                *self = RequestReady { notified: false, request };
180            }
181            Some(Err(x)) => {
182                log::trace!("Received a packet.");
183                *decoder = x;
184            }
185        }
186    }
187
188    fn send<B: UsbBus>(&mut self, ep: &EndpointIn<B>) {
189        let packets = match self {
190            Disabled => {
191                log::error!("Not sending data while disabled.");
192                return;
193            }
194            SendResponse { packets } => packets,
195            _ => return,
196        };
197        let packet = match packets.pop_front() {
198            Some(x) => x,
199            None => {
200                log::warn!("Invalid state: SendResponse with no packets.");
201                *self = WaitRequest;
202                return;
203            }
204        };
205        let len = match ep.write(&packet) {
206            Err(UsbError::WouldBlock) => {
207                log::warn!("Failed to send packet, retrying later.");
208                packets.push_front(packet);
209                return;
210            }
211            x => x.unwrap(),
212        };
213        if len != MAX_PACKET_SIZE as usize {
214            log::warn!("Sent a packet of {} bytes instead of 64.", len);
215            *self = WaitRequest;
216            return;
217        }
218        let remaining = packets.len();
219        if packets.is_empty() {
220            *self = WaitRequest;
221        }
222        log::trace!("Sent the next packet ({} remaining).", remaining);
223    }
224
225    fn notify(&mut self) -> bool {
226        match self {
227            RequestReady { notified, .. } => !core::mem::replace(notified, true),
228            _ => false,
229        }
230    }
231}
232
233impl<B: UsbBus> UsbClass<B> for Rpc<'_, B> {
234    fn get_configuration_descriptors(
235        &self, writer: &mut DescriptorWriter,
236    ) -> usb_device::Result<()> {
237        writer.iad(self.interface, 1, 0xff, 0x58, 0x01, None)?;
238        writer.interface(self.interface, 0xff, 0x58, 0x01)?;
239        writer.endpoint(&self.write_ep)?;
240        writer.endpoint(&self.read_ep)?;
241        Ok(())
242    }
243
244    fn get_bos_descriptors(&self, writer: &mut BosWriter) -> usb_device::Result<()> {
245        // Advertise WebUSB.
246        let mut data = Vec::with_capacity(24);
247        data.push(0); // bReserved
248        // PlatformCapabilityUUID
249        data.extend_from_slice(b"\x38\xb6\x08\x34\xa9\x09\xa0\x47\x8b\xfd\xa0\x76\x88\x15\xb6\x65");
250        data.extend_from_slice(&[0x00, 0x01]); // bcdVersion
251        data.push(WEBUSB_VENDOR_CODE); // bVendorCode
252        data.push(WEBUSB_URL_DESC.is_some() as u8); // iLandingPage
253        // bDevCapabilityType = PLATFORM
254        writer.capability(0x05, &data)
255    }
256
257    fn get_string(&self, _: StringIndex, _id: LangID) -> Option<&str> {
258        // We don't have strings.
259        None
260    }
261
262    fn reset(&mut self) {
263        self.state = match self.state {
264            State::Disabled => State::Disabled,
265            _ => State::WaitRequest,
266        };
267    }
268
269    fn poll(&mut self) {
270        // We probably don't need to do anything here.
271    }
272
273    fn control_out(&mut self, _: ControlOut<B>) {
274        // We probably don't need to do anything here.
275    }
276
277    fn control_in(&mut self, xfer: ControlIn<B>) {
278        let req = xfer.request();
279        if req.request_type != usb_device::control::RequestType::Vendor
280            || req.recipient != usb_device::control::Recipient::Device
281            || req.request != WEBUSB_VENDOR_CODE
282        {
283            return; // Only handle WebUSB requests.
284        }
285        // Stall on invalid requests.
286        let Some(descriptor) = WEBUSB_URL_DESC else { return xfer.reject().unwrap() };
287        const GET_URL: u16 = 2;
288        if req.index != GET_URL || req.value != 1 {
289            return xfer.reject().unwrap();
290        }
291        xfer.accept_with_static(descriptor).unwrap();
292    }
293
294    fn endpoint_setup(&mut self, _: EndpointAddress) {
295        // We probably don't need to do anything here.
296    }
297
298    fn endpoint_out(&mut self, addr: EndpointAddress) {
299        if self.read_ep.address() != addr {
300            return;
301        }
302        self.state.receive(&self.read_ep);
303    }
304
305    fn endpoint_in_complete(&mut self, addr: EndpointAddress) {
306        if self.write_ep.address() != addr {
307            return;
308        }
309        self.state.send(&self.write_ep);
310    }
311}
312
/// Vendor request code advertised in the WebUSB platform capability (`bVendorCode`) and expected
/// as `bRequest` of WebUSB control requests.
const WEBUSB_VENDOR_CODE: u8 = 1;
314
/// WebUSB URL descriptor of the landing page, built at compile time from the
/// `WASEFIRE_WEBUSB_URL` environment variable (`None` when the variable is not set).
const WEBUSB_URL_DESC: Option<&[u8]> = {
    // Scheme code and remaining URL bytes. This is a `const` so that its length can be used in
    // the const generic argument below.
    const SPLIT: (Option<u8>, &[u8]) = split_webusb_url(option_env!("WASEFIRE_WEBUSB_URL"));
    match SPLIT.0 {
        None => None,
        // The descriptor is a 3-byte header followed by the URL bytes.
        Some(scheme) => Some(&make_webusb_url::<{ 3 + SPLIT.1.len() }>(scheme, SPLIT.1)),
    }
};
322
/// Builds a WebUSB URL descriptor of `LEN` bytes: a 3-byte header (`bLength`, `bDescriptorType`,
/// `bScheme`) followed by the first `LEN - 3` bytes of `data`.
///
/// Panics if `LEN` does not fit the one-byte `bLength` field (or if `data` is too short).
const fn make_webusb_url<const LEN: usize>(scheme: u8, data: &[u8]) -> [u8; LEN] {
    assert!(LEN < 256);
    let mut descriptor = [0u8; LEN];
    descriptor[0] = LEN as u8; // bLength
    descriptor[1] = 3; // bDescriptorType = WEBUSB_URL
    descriptor[2] = scheme; // bScheme
    let mut pos = 3;
    while pos < LEN {
        descriptor[pos] = data[pos - 3];
        pos += 1;
    }
    descriptor
}
336
337const fn split_webusb_url(url: Option<&'static str>) -> (Option<u8>, &'static [u8]) {
338    let Some(url) = url else { return (None, &[]) };
339    let url = url.as_bytes();
340    let (scheme, data) = if let Some(data) = strip_prefix(url, b"http://") {
341        (0, data)
342    } else if let Some(data) = strip_prefix(url, b"https://") {
343        (1, data)
344    } else {
345        (255, url)
346    };
347    (Some(scheme), data)
348}
349
/// Returns `data` without its leading `prefix`, or `None` if `data` does not start with `prefix`.
///
/// This is a `const` counterpart of `<[u8]>::strip_prefix`, which is not a `const fn`.
const fn strip_prefix(data: &'static [u8], prefix: &'static [u8]) -> Option<&'static [u8]> {
    if data.len() < prefix.len() {
        return None;
    }
    // `split_at` is const and safe, so no raw pointer arithmetic is needed.
    let (head, tail) = data.split_at(prefix.len());
    let mut i = 0;
    while i < prefix.len() {
        if head[i] != prefix[i] {
            return None;
        }
        i += 1;
    }
    Some(tail)
}