Skip to main content

dfu_core/
asynchronous.rs

1use futures::{io::Cursor, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt};
2
3use super::*;
4use core::future::Future;
5use std::convert::TryFrom;
6use std::prelude::v1::*;
7
8/// Trait to implement lower level communication with a USB device.
9pub trait DfuAsyncIo {
10    /// Return type after calling [`Self::read_control`].
11    type Read;
12    /// Return type after calling [`Self::write_control`].
13    type Write;
14    /// Return type after calling [`Self::usb_reset`].
15    type Reset;
16    /// Error type.
17    type Error: From<Error>;
18    /// Dfuse Memory layout type
19    type MemoryLayout: AsRef<memory_layout::mem>;
20
21    /// Read data using control transfer.
22    fn read_control(
23        &self,
24        request_type: u8,
25        request: u8,
26        value: u16,
27        buffer: &mut [u8],
28    ) -> impl Future<Output = Result<Self::Read, Self::Error>> + Send;
29
30    /// Write data using control transfer.
31    fn write_control(
32        &self,
33        request_type: u8,
34        request: u8,
35        value: u16,
36        buffer: &[u8],
37    ) -> impl Future<Output = Result<Self::Write, Self::Error>> + Send;
38
39    /// Triggers a USB reset.
40    fn usb_reset(self) -> impl Future<Output = Result<Self::Reset, Self::Error>> + Send;
41
42    /// Sleep for this duration of time.
43    fn sleep(&self, duration: std::time::Duration) -> impl Future<Output = ()> + Send;
44
45    /// Returns the protocol of the device
46    fn protocol(&self) -> &DfuProtocol<Self::MemoryLayout>;
47
48    /// Returns the functional descriptor of the device.
49    fn functional_descriptor(&self) -> &functional_descriptor::FunctionalDescriptor;
50}
51
52impl UsbReadControl<'_> {
53    /// Execute usb write using io
54    pub async fn execute_async<IO: DfuAsyncIo>(&mut self, io: &IO) -> Result<IO::Read, IO::Error> {
55        io.read_control(self.request_type, self.request, self.value, self.buffer)
56            .await
57    }
58}
59
60impl<D> UsbWriteControl<D>
61where
62    D: AsRef<[u8]>,
63{
64    /// Execute usb write using io
65    pub async fn execute_async<IO: DfuAsyncIo>(&self, io: &IO) -> Result<IO::Write, IO::Error> {
66        io.write_control(
67            self.request_type,
68            self.request,
69            self.value,
70            self.buffer.as_ref(),
71        )
72        .await
73    }
74}
75
76struct Buffer<R: AsyncRead + Unpin> {
77    reader: R,
78    buf: Box<[u8]>,
79    level: usize,
80}
81
82impl<R: AsyncRead + Unpin> Buffer<R> {
83    fn new(size: usize, reader: R) -> Self {
84        Self {
85            reader,
86            buf: vec![0; size].into_boxed_slice(),
87            level: 0,
88        }
89    }
90
91    async fn fill_buf(&mut self) -> Result<&[u8], std::io::Error> {
92        while self.level < self.buf.len() {
93            let dst = &mut self.buf[self.level..];
94            let r = self.reader.read(dst).await?;
95            if r == 0 {
96                break;
97            } else {
98                self.level += r;
99            }
100        }
101        Ok(&self.buf[0..self.level])
102    }
103
104    fn consume(&mut self, amt: usize) {
105        if amt >= self.level {
106            self.level = 0;
107        } else {
108            self.buf.copy_within(amt..self.level, 0);
109            self.level -= amt;
110        }
111    }
112}
113
114/// Generic asynchronous implementation of DFU.
115#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
116pub struct DfuAsync<IO, E>
117where
118    IO: DfuAsyncIo<Read = usize, Write = usize, Reset = (), Error = E>,
119    E: From<std::io::Error> + From<Error>,
120{
121    io: IO,
122    dfu: DfuSansIo,
123    buffer: Vec<u8>,
124    progress: Option<Box<dyn FnMut(usize) + Send>>,
125}
126
127impl<IO, E> DfuAsync<IO, E>
128where
129    IO: DfuAsyncIo<Read = usize, Write = usize, Reset = (), Error = E>,
130    E: From<std::io::Error> + From<Error>,
131{
132    /// Create a new instance of a generic synchronous implementation of DFU.
133    pub fn new(io: IO) -> Self {
134        let transfer_size = io.functional_descriptor().transfer_size as usize;
135        let descriptor = *io.functional_descriptor();
136
137        Self {
138            io,
139            dfu: DfuSansIo::new(descriptor),
140            buffer: vec![0x00; transfer_size],
141            progress: None,
142        }
143    }
144
145    /// Override the address onto which the firmware is downloaded.
146    ///
147    /// This address is only used if the device uses the DfuSe protocol.
148    pub fn override_address(&mut self, address: u32) -> &mut Self {
149        self.dfu.set_address(address);
150        self
151    }
152
153    /// Use this closure to show progress.
154    pub fn with_progress(&mut self, progress: impl FnMut(usize) + Send + 'static) -> &mut Self {
155        self.progress = Some(Box::new(progress));
156        self
157    }
158
159    /// Consume the object and return its [`DfuIo`]
160    pub fn into_inner(self) -> IO {
161        self.io
162    }
163}
164
165impl<IO, E> DfuAsync<IO, E>
166where
167    IO: DfuAsyncIo<Read = usize, Write = usize, Reset = (), Error = E>,
168    E: From<std::io::Error> + From<Error>,
169{
170    /// Download a firmware into the device from a slice.
171    ///
172    /// Returns `Some(Self)` if the device stayed on the bus (manifestation tolerant, no USB reset
173    /// occurred) or `None` if a USB reset was performed.
174    pub async fn download_from_slice(self, slice: &[u8]) -> Result<Option<Self>, IO::Error> {
175        let length = slice.len();
176        let cursor = Cursor::new(slice);
177        self.download(
178            cursor,
179            u32::try_from(length).map_err(|_| Error::OutOfCapabilities)?,
180        )
181        .await
182    }
183
184    /// Download a firmware into the device from a reader.
185    ///
186    /// Returns `Some(Self)` if the device stayed on the bus (manifestation tolerant, no USB reset
187    /// occurred) or `None` if a USB reset was performed.
188    pub async fn download<R: AsyncReadExt + Unpin>(
189        mut self,
190        reader: R,
191        length: u32,
192    ) -> Result<Option<Self>, IO::Error> {
193        let transfer_size = self.io.functional_descriptor().transfer_size as usize;
194        let mut reader = Buffer::new(transfer_size, reader);
195        let buffer = reader.fill_buf().await?;
196        if buffer.is_empty() {
197            return Ok(Some(self));
198        }
199
200        macro_rules! wait_status {
201            ($cmd:expr) => {{
202                let mut cmd = $cmd;
203                loop {
204                    cmd = match cmd.next() {
205                        get_status::Step::Break(cmd) => break cmd,
206                        get_status::Step::Wait(cmd, poll_timeout) => {
207                            self.io
208                                .sleep(std::time::Duration::from_millis(poll_timeout))
209                                .await;
210                            let (cmd, mut control) = cmd.get_status(&mut self.buffer);
211                            let n = control.execute_async(&self.io).await?;
212                            cmd.chain(&self.buffer[..n as usize])??
213                        }
214                    };
215                }
216            }};
217        }
218
219        let cmd = self.dfu.download(self.io.protocol(), length)?;
220        let (cmd, mut control) = cmd.get_status(&mut self.buffer);
221        let n = control.execute_async(&self.io).await?;
222        let (cmd, control) = cmd.chain(&self.buffer[..n])?;
223        if let Some(control) = control {
224            control.execute_async(&self.io).await?;
225        }
226        let (cmd, mut control) = cmd.get_status(&mut self.buffer);
227        let n = control.execute_async(&self.io).await?;
228        let mut download_loop = cmd.chain(&self.buffer[..n])??;
229
230        loop {
231            download_loop = match download_loop.next() {
232                download::Step::Break => break Ok(Some(self)),
233                download::Step::Erase(cmd) => {
234                    let (cmd, control) = cmd.erase()?;
235                    control.execute_async(&self.io).await?;
236                    wait_status!(cmd)
237                }
238                download::Step::SetAddress(cmd) => {
239                    let (cmd, control) = cmd.set_address();
240                    control.execute_async(&self.io).await?;
241                    wait_status!(cmd)
242                }
243                download::Step::DownloadChunk(cmd) => {
244                    let chunk = reader.fill_buf().await?;
245                    let (cmd, control) = cmd.download(chunk)?;
246                    let n = control.execute_async(&self.io).await?;
247                    reader.consume(n);
248                    if let Some(progress) = self.progress.as_mut() {
249                        progress(n);
250                    }
251                    wait_status!(cmd)
252                }
253                download::Step::UsbReset => {
254                    log::trace!("Device reset");
255                    self.io.usb_reset().await?;
256                    break Ok(None);
257                }
258            }
259        }
260    }
261
262    /// Download a firmware into the device.
263    ///
264    /// The length is inferred from the reader. Returns `Some(Self)` if the device stayed on the
265    /// bus (manifestation tolerant, no USB reset occurred) or `None` if a USB reset was performed.
266    pub async fn download_all<R: AsyncReadExt + Unpin + AsyncSeek>(
267        self,
268        mut reader: R,
269    ) -> Result<Option<Self>, IO::Error> {
270        let length = u32::try_from(reader.seek(std::io::SeekFrom::End(0)).await?)
271            .map_err(|_| Error::MaximumTransferSizeExceeded)?;
272        reader.seek(std::io::SeekFrom::Start(0)).await?;
273        self.download(reader, length).await
274    }
275
276    /// Send a Detach request to the device
277    pub async fn detach(&self) -> Result<(), IO::Error> {
278        self.dfu.detach().execute_async(&self.io).await?;
279        Ok(())
280    }
281
282    /// Reset the USB device
283    pub async fn usb_reset(self) -> Result<IO::Reset, IO::Error> {
284        self.io.usb_reset().await
285    }
286
287    /// Returns whether the device will detach if requested
288    pub fn will_detach(&self) -> bool {
289        self.io.functional_descriptor().will_detach
290    }
291
292    /// Returns whether the device is manifestation tolerant
293    pub fn manifestation_tolerant(&self) -> bool {
294        self.io.functional_descriptor().manifestation_tolerant
295    }
296}