1use futures::{io::Cursor, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt};
2
3use super::*;
4use core::future::Future;
5use std::convert::TryFrom;
6use std::prelude::v1::*;
7
/// Asynchronous I/O backend driven by [`DfuAsync`].
///
/// An implementor performs the USB control transfers, the USB reset and the
/// timed waits that the DFU logic asks for, and exposes the device's DFU
/// protocol and functional descriptor.
pub trait DfuAsyncIo {
    /// Success value of [`read_control`](Self::read_control).
    ///
    /// [`DfuAsync`] requires this to be `usize` and uses it as the number of
    /// valid bytes placed in the read buffer.
    type Read;
    /// Success value of [`write_control`](Self::write_control).
    ///
    /// [`DfuAsync`] requires this to be `usize` and uses it as the number of
    /// payload bytes written (it is also what the progress callback receives).
    type Write;
    /// Success value of [`usb_reset`](Self::usb_reset).
    type Reset;
    /// Error type of the backend; this crate's [`Error`] must convert into it.
    type Error: From<Error>;
    /// Memory layout type carried by the [`DfuProtocol`] returned from
    /// [`protocol`](Self::protocol).
    type MemoryLayout: AsRef<memory_layout::mem>;

    /// Perform a "read" control request with the given `request_type`,
    /// `request` and `value`, handing `buffer` to the backend for the
    /// response data.
    fn read_control(
        &self,
        request_type: u8,
        request: u8,
        value: u16,
        buffer: &mut [u8],
    ) -> impl Future<Output = Result<Self::Read, Self::Error>> + Send;

    /// Perform a "write" control request with the given `request_type`,
    /// `request` and `value`, using `buffer` as the payload.
    fn write_control(
        &self,
        request_type: u8,
        request: u8,
        value: u16,
        buffer: &[u8],
    ) -> impl Future<Output = Result<Self::Write, Self::Error>> + Send;

    /// Issue a USB reset. Takes `self` by value, so the backend is consumed.
    fn usb_reset(self) -> impl Future<Output = Result<Self::Reset, Self::Error>> + Send;

    /// Wait asynchronously for `duration`.
    ///
    /// [`DfuAsync`] uses this to honour the poll timeout between status requests.
    fn sleep(&self, duration: std::time::Duration) -> impl Future<Output = ()> + Send;

    /// DFU protocol description of the device.
    fn protocol(&self) -> &DfuProtocol<Self::MemoryLayout>;

    /// DFU functional descriptor of the device.
    ///
    /// [`DfuAsync`] reads `transfer_size`, `will_detach` and
    /// `manifestation_tolerant` from it.
    fn functional_descriptor(&self) -> &functional_descriptor::FunctionalDescriptor;
}
51
52impl UsbReadControl<'_> {
53 pub async fn execute_async<IO: DfuAsyncIo>(&mut self, io: &IO) -> Result<IO::Read, IO::Error> {
55 io.read_control(self.request_type, self.request, self.value, self.buffer)
56 .await
57 }
58}
59
60impl<D> UsbWriteControl<D>
61where
62 D: AsRef<[u8]>,
63{
64 pub async fn execute_async<IO: DfuAsyncIo>(&self, io: &IO) -> Result<IO::Write, IO::Error> {
66 io.write_control(
67 self.request_type,
68 self.request,
69 self.value,
70 self.buffer.as_ref(),
71 )
72 .await
73 }
74}
75
76struct Buffer<R: AsyncRead + Unpin> {
77 reader: R,
78 buf: Box<[u8]>,
79 level: usize,
80}
81
82impl<R: AsyncRead + Unpin> Buffer<R> {
83 fn new(size: usize, reader: R) -> Self {
84 Self {
85 reader,
86 buf: vec![0; size].into_boxed_slice(),
87 level: 0,
88 }
89 }
90
91 async fn fill_buf(&mut self) -> Result<&[u8], std::io::Error> {
92 while self.level < self.buf.len() {
93 let dst = &mut self.buf[self.level..];
94 let r = self.reader.read(dst).await?;
95 if r == 0 {
96 break;
97 } else {
98 self.level += r;
99 }
100 }
101 Ok(&self.buf[0..self.level])
102 }
103
104 fn consume(&mut self, amt: usize) {
105 if amt >= self.level {
106 self.level = 0;
107 } else {
108 self.buf.copy_within(amt..self.level, 0);
109 self.level -= amt;
110 }
111 }
112}
113
/// Asynchronous DFU driver: pairs an I/O backend implementing [`DfuAsyncIo`]
/// with the I/O-free DFU logic in [`DfuSansIo`].
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub struct DfuAsync<IO, E>
where
    IO: DfuAsyncIo<Read = usize, Write = usize, Reset = (), Error = E>,
    E: From<std::io::Error> + From<Error>,
{
    /// Backend that performs the actual USB operations.
    io: IO,
    /// I/O-free DFU state handling.
    dfu: DfuSansIo,
    /// Scratch buffer for status reads; sized to the descriptor's
    /// `transfer_size` when the driver is created.
    buffer: Vec<u8>,
    /// Optional callback receiving the number of bytes written per chunk.
    progress: Option<Box<dyn FnMut(usize) + Send>>,
}
126
127impl<IO, E> DfuAsync<IO, E>
128where
129 IO: DfuAsyncIo<Read = usize, Write = usize, Reset = (), Error = E>,
130 E: From<std::io::Error> + From<Error>,
131{
132 pub fn new(io: IO) -> Self {
134 let transfer_size = io.functional_descriptor().transfer_size as usize;
135 let descriptor = *io.functional_descriptor();
136
137 Self {
138 io,
139 dfu: DfuSansIo::new(descriptor),
140 buffer: vec![0x00; transfer_size],
141 progress: None,
142 }
143 }
144
145 pub fn override_address(&mut self, address: u32) -> &mut Self {
149 self.dfu.set_address(address);
150 self
151 }
152
153 pub fn with_progress(&mut self, progress: impl FnMut(usize) + Send + 'static) -> &mut Self {
155 self.progress = Some(Box::new(progress));
156 self
157 }
158
159 pub fn into_inner(self) -> IO {
161 self.io
162 }
163}
164
165impl<IO, E> DfuAsync<IO, E>
166where
167 IO: DfuAsyncIo<Read = usize, Write = usize, Reset = (), Error = E>,
168 E: From<std::io::Error> + From<Error>,
169{
170 pub async fn download_from_slice(self, slice: &[u8]) -> Result<Option<Self>, IO::Error> {
175 let length = slice.len();
176 let cursor = Cursor::new(slice);
177 self.download(
178 cursor,
179 u32::try_from(length).map_err(|_| Error::OutOfCapabilities)?,
180 )
181 .await
182 }
183
184 pub async fn download<R: AsyncReadExt + Unpin>(
189 mut self,
190 reader: R,
191 length: u32,
192 ) -> Result<Option<Self>, IO::Error> {
193 let transfer_size = self.io.functional_descriptor().transfer_size as usize;
194 let mut reader = Buffer::new(transfer_size, reader);
195 let buffer = reader.fill_buf().await?;
196 if buffer.is_empty() {
197 return Ok(Some(self));
198 }
199
200 macro_rules! wait_status {
201 ($cmd:expr) => {{
202 let mut cmd = $cmd;
203 loop {
204 cmd = match cmd.next() {
205 get_status::Step::Break(cmd) => break cmd,
206 get_status::Step::Wait(cmd, poll_timeout) => {
207 self.io
208 .sleep(std::time::Duration::from_millis(poll_timeout))
209 .await;
210 let (cmd, mut control) = cmd.get_status(&mut self.buffer);
211 let n = control.execute_async(&self.io).await?;
212 cmd.chain(&self.buffer[..n as usize])??
213 }
214 };
215 }
216 }};
217 }
218
219 let cmd = self.dfu.download(self.io.protocol(), length)?;
220 let (cmd, mut control) = cmd.get_status(&mut self.buffer);
221 let n = control.execute_async(&self.io).await?;
222 let (cmd, control) = cmd.chain(&self.buffer[..n])?;
223 if let Some(control) = control {
224 control.execute_async(&self.io).await?;
225 }
226 let (cmd, mut control) = cmd.get_status(&mut self.buffer);
227 let n = control.execute_async(&self.io).await?;
228 let mut download_loop = cmd.chain(&self.buffer[..n])??;
229
230 loop {
231 download_loop = match download_loop.next() {
232 download::Step::Break => break Ok(Some(self)),
233 download::Step::Erase(cmd) => {
234 let (cmd, control) = cmd.erase()?;
235 control.execute_async(&self.io).await?;
236 wait_status!(cmd)
237 }
238 download::Step::SetAddress(cmd) => {
239 let (cmd, control) = cmd.set_address();
240 control.execute_async(&self.io).await?;
241 wait_status!(cmd)
242 }
243 download::Step::DownloadChunk(cmd) => {
244 let chunk = reader.fill_buf().await?;
245 let (cmd, control) = cmd.download(chunk)?;
246 let n = control.execute_async(&self.io).await?;
247 reader.consume(n);
248 if let Some(progress) = self.progress.as_mut() {
249 progress(n);
250 }
251 wait_status!(cmd)
252 }
253 download::Step::UsbReset => {
254 log::trace!("Device reset");
255 self.io.usb_reset().await?;
256 break Ok(None);
257 }
258 }
259 }
260 }
261
262 pub async fn download_all<R: AsyncReadExt + Unpin + AsyncSeek>(
267 self,
268 mut reader: R,
269 ) -> Result<Option<Self>, IO::Error> {
270 let length = u32::try_from(reader.seek(std::io::SeekFrom::End(0)).await?)
271 .map_err(|_| Error::MaximumTransferSizeExceeded)?;
272 reader.seek(std::io::SeekFrom::Start(0)).await?;
273 self.download(reader, length).await
274 }
275
276 pub async fn detach(&self) -> Result<(), IO::Error> {
278 self.dfu.detach().execute_async(&self.io).await?;
279 Ok(())
280 }
281
282 pub async fn usb_reset(self) -> Result<IO::Reset, IO::Error> {
284 self.io.usb_reset().await
285 }
286
287 pub fn will_detach(&self) -> bool {
289 self.io.functional_descriptor().will_detach
290 }
291
292 pub fn manifestation_tolerant(&self) -> bool {
294 self.io.functional_descriptor().manifestation_tolerant
295 }
296}