dfu_core/
sync.rs

1use super::*;
2use std::convert::TryFrom;
3use std::io::Cursor;
4use std::prelude::v1::*;
5
/// Fixed-capacity staging buffer placed in front of a [`std::io::Read`] source.
///
/// Only the first `level` bytes of `buf` hold data that has been read from
/// `reader` but not yet consumed.
struct Buffer<R>
where
    R: std::io::Read,
{
    /// Source the buffer is refilled from.
    reader: R,
    /// Backing storage; its length is the buffer's capacity.
    buf: Box<[u8]>,
    /// Number of valid, not-yet-consumed bytes at the front of `buf`.
    level: usize,
}
11
12impl<R: std::io::Read> Buffer<R> {
13    fn new(size: usize, reader: R) -> Self {
14        Self {
15            reader,
16            buf: vec![0; size].into_boxed_slice(),
17            level: 0,
18        }
19    }
20
21    fn fill_buf(&mut self) -> Result<&[u8], std::io::Error> {
22        while self.level < self.buf.len() {
23            let dst = &mut self.buf[self.level..];
24            let r = self.reader.read(dst)?;
25            if r == 0 {
26                break;
27            } else {
28                self.level += r;
29            }
30        }
31        Ok(&self.buf[0..self.level])
32    }
33
34    fn consume(&mut self, amt: usize) {
35        if amt >= self.level {
36            self.level = 0;
37        } else {
38            self.buf.copy_within(amt..self.level, 0);
39            self.level -= amt;
40        }
41    }
42}
43
44/// Generic synchronous implementation of DFU.
45#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
46pub struct DfuSync<IO, E>
47where
48    IO: DfuIo<Read = usize, Write = usize, Reset = (), Error = E>,
49    E: From<std::io::Error> + From<Error>,
50{
51    io: IO,
52    dfu: DfuSansIo,
53    buffer: Vec<u8>,
54    progress: Option<Box<dyn FnMut(usize)>>,
55}
56
57impl<IO, E> DfuSync<IO, E>
58where
59    IO: DfuIo<Read = usize, Write = usize, Reset = (), Error = E>,
60    E: From<std::io::Error> + From<Error>,
61{
62    /// Create a new instance of a generic synchronous implementation of DFU.
63    pub fn new(io: IO) -> Self {
64        let transfer_size = io.functional_descriptor().transfer_size as usize;
65        let descriptor = *io.functional_descriptor();
66
67        Self {
68            io,
69            dfu: DfuSansIo::new(descriptor),
70            buffer: vec![0x00; transfer_size],
71            progress: None,
72        }
73    }
74
75    /// Override the address onto which the firmware is downloaded.
76    ///
77    /// This address is only used if the device uses the DfuSe protocol.
78    pub fn override_address(&mut self, address: u32) -> &mut Self {
79        self.dfu.set_address(address);
80        self
81    }
82
83    /// Use this closure to show progress.
84    pub fn with_progress(&mut self, progress: impl FnMut(usize) + 'static) -> &mut Self {
85        self.progress = Some(Box::new(progress));
86        self
87    }
88
89    /// Consume the object and return its [`DfuIo`]
90    pub fn into_inner(self) -> IO {
91        self.io
92    }
93}
94
95impl<IO, E> DfuSync<IO, E>
96where
97    IO: DfuIo<Read = usize, Write = usize, Reset = (), Error = E>,
98    E: From<std::io::Error> + From<Error>,
99{
100    /// Download a firmware into the device from a slice.
101    pub fn download_from_slice(&mut self, slice: &[u8]) -> Result<(), IO::Error> {
102        let length = slice.len();
103        let cursor = Cursor::new(slice);
104
105        self.download(
106            cursor,
107            u32::try_from(length).map_err(|_| Error::OutOfCapabilities)?,
108        )
109    }
110
111    /// Download a firmware into the device from a reader.
112    pub fn download<R: std::io::Read>(&mut self, reader: R, length: u32) -> Result<(), IO::Error> {
113        let transfer_size = self.io.functional_descriptor().transfer_size as usize;
114        let mut reader = Buffer::new(transfer_size, reader);
115        let buffer = reader.fill_buf()?;
116        if buffer.is_empty() {
117            return Ok(());
118        }
119
120        macro_rules! wait_status {
121            ($cmd:expr) => {{
122                let mut cmd = $cmd;
123                loop {
124                    cmd = match cmd.next() {
125                        get_status::Step::Break(cmd) => break cmd,
126                        get_status::Step::Wait(cmd, poll_timeout) => {
127                            std::thread::sleep(std::time::Duration::from_millis(poll_timeout));
128                            let (cmd, mut control) = cmd.get_status(&mut self.buffer);
129                            let n = control.execute(&self.io)?;
130                            cmd.chain(&self.buffer[..n as usize])??
131                        }
132                    };
133                }
134            }};
135        }
136
137        let cmd = self.dfu.download(self.io.protocol(), length)?;
138        let (cmd, mut control) = cmd.get_status(&mut self.buffer);
139        let n = control.execute(&self.io)?;
140        let (cmd, control) = cmd.chain(&self.buffer[..n])?;
141        if let Some(control) = control {
142            control.execute(&self.io)?;
143        }
144        let (cmd, mut control) = cmd.get_status(&mut self.buffer);
145        let n = control.execute(&self.io)?;
146        let mut download_loop = cmd.chain(&self.buffer[..n])??;
147
148        loop {
149            download_loop = match download_loop.next() {
150                download::Step::Break => break,
151                download::Step::Erase(cmd) => {
152                    let (cmd, control) = cmd.erase()?;
153                    control.execute(&self.io)?;
154                    wait_status!(cmd)
155                }
156                download::Step::SetAddress(cmd) => {
157                    let (cmd, control) = cmd.set_address();
158                    control.execute(&self.io)?;
159                    wait_status!(cmd)
160                }
161                download::Step::DownloadChunk(cmd) => {
162                    let chunk = reader.fill_buf()?;
163                    let (cmd, control) = cmd.download(chunk)?;
164                    let n = control.execute(&self.io)?;
165                    reader.consume(n);
166                    if let Some(progress) = self.progress.as_mut() {
167                        progress(n);
168                    }
169                    wait_status!(cmd)
170                }
171                download::Step::UsbReset => {
172                    log::trace!("Device reset");
173                    self.io.usb_reset()?;
174                    break;
175                }
176            }
177        }
178
179        Ok(())
180    }
181
182    /// Download a firmware into the device.
183    ///
184    /// The length is guest from the reader.
185    pub fn download_all<R: std::io::Read + std::io::Seek>(
186        &mut self,
187        mut reader: R,
188    ) -> Result<(), IO::Error> {
189        let length = u32::try_from(reader.seek(std::io::SeekFrom::End(0))?)
190            .map_err(|_| Error::MaximumTransferSizeExceeded)?;
191        reader.seek(std::io::SeekFrom::Start(0))?;
192        self.download(reader, length)
193    }
194
195    /// Send a Detach request to the device
196    pub fn detach(&self) -> Result<(), IO::Error> {
197        self.dfu.detach().execute(&self.io)?;
198        Ok(())
199    }
200
201    /// Reset the USB device
202    pub fn usb_reset(&self) -> Result<IO::Reset, IO::Error> {
203        self.io.usb_reset()
204    }
205
206    /// Returns whether the device is will detach if requested
207    pub fn will_detach(&self) -> bool {
208        self.io.functional_descriptor().will_detach
209    }
210
211    /// Returns whether the device is manifestation tolerant
212    pub fn manifestation_tolerant(&self) -> bool {
213        self.io.functional_descriptor().manifestation_tolerant
214    }
215}