1use super::*;
2use std::convert::TryFrom;
3use std::io::Cursor;
4use std::prelude::v1::*;
5
/// Read-ahead buffer used by `DfuSync::download` to hand the protocol layer
/// chunks of data.
///
/// Unlike `std::io::BufReader`, `fill_buf` keeps reading until the buffer is
/// completely full or the reader reports EOF. Its result is therefore exactly
/// `buf.len()` bytes long unless the end of the input has been reached.
struct Buffer<R: std::io::Read> {
    /// Underlying data source.
    reader: R,
    /// Fixed-capacity storage; its length is the maximum chunk size.
    buf: Box<[u8]>,
    /// Number of valid bytes at the start of `buf`.
    level: usize,
}

impl<R: std::io::Read> Buffer<R> {
    /// Creates an empty buffer of `size` bytes wrapping `reader`.
    fn new(size: usize, reader: R) -> Self {
        Self {
            reader,
            buf: vec![0; size].into_boxed_slice(),
            level: 0,
        }
    }

    /// Tops the buffer up from the reader and returns every buffered byte.
    ///
    /// Reads repeatedly until the buffer is full or the reader returns
    /// `Ok(0)` (EOF). Bytes that are already buffered are kept in front.
    ///
    /// # Errors
    ///
    /// Any I/O error from the reader other than `ErrorKind::Interrupted`.
    /// Interrupted reads are retried, as the `std::io::Read` contract requires.
    fn fill_buf(&mut self) -> Result<&[u8], std::io::Error> {
        while self.level < self.buf.len() {
            match self.reader.read(&mut self.buf[self.level..]) {
                Ok(0) => break,
                Ok(read) => self.level += read,
                // Interrupted is transient by definition; try the read again.
                Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(err) => return Err(err),
            }
        }
        Ok(&self.buf[..self.level])
    }

    /// Discards the first `amt` buffered bytes and shifts the rest to the front.
    ///
    /// Consuming at least as many bytes as are buffered empties the buffer.
    fn consume(&mut self, amt: usize) {
        if amt >= self.level {
            self.level = 0;
        } else {
            self.buf.copy_within(amt..self.level, 0);
            self.level -= amt;
        }
    }
}
43
44#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
46pub struct DfuSync<IO, E>
47where
48 IO: DfuIo<Read = usize, Write = usize, Reset = (), Error = E>,
49 E: From<std::io::Error> + From<Error>,
50{
51 io: IO,
52 dfu: DfuSansIo,
53 buffer: Vec<u8>,
54 progress: Option<Box<dyn FnMut(usize)>>,
55}
56
57impl<IO, E> DfuSync<IO, E>
58where
59 IO: DfuIo<Read = usize, Write = usize, Reset = (), Error = E>,
60 E: From<std::io::Error> + From<Error>,
61{
62 pub fn new(io: IO) -> Self {
64 let transfer_size = io.functional_descriptor().transfer_size as usize;
65 let descriptor = *io.functional_descriptor();
66
67 Self {
68 io,
69 dfu: DfuSansIo::new(descriptor),
70 buffer: vec![0x00; transfer_size],
71 progress: None,
72 }
73 }
74
75 pub fn override_address(&mut self, address: u32) -> &mut Self {
79 self.dfu.set_address(address);
80 self
81 }
82
83 pub fn with_progress(&mut self, progress: impl FnMut(usize) + 'static) -> &mut Self {
85 self.progress = Some(Box::new(progress));
86 self
87 }
88
89 pub fn into_inner(self) -> IO {
91 self.io
92 }
93}
94
impl<IO, E> DfuSync<IO, E>
where
    IO: DfuIo<Read = usize, Write = usize, Reset = (), Error = E>,
    E: From<std::io::Error> + From<Error>,
{
    /// Downloads the whole of `slice` to the device.
    ///
    /// # Errors
    ///
    /// Returns `Error::OutOfCapabilities`, converted into `IO::Error`, if
    /// `slice.len()` does not fit in a `u32`. Otherwise it returns any error
    /// from `download`.
    pub fn download_from_slice(&mut self, slice: &[u8]) -> Result<(), IO::Error> {
        let length = slice.len();
        let cursor = Cursor::new(slice);

        self.download(
            cursor,
            u32::try_from(length).map_err(|_| Error::OutOfCapabilities)?,
        )
    }

    /// Downloads `length` bytes read from `reader` to the device.
    ///
    /// `reader` is wrapped in a read-ahead buffer sized to the functional
    /// descriptor's `transfer_size`, so each chunk step is offered at most that
    /// many bytes. If `reader` yields no data at all, this returns `Ok(())`
    /// immediately and sends no request to the device.
    ///
    /// The protocol sequence is decided by `DfuSansIo`. This method only
    /// executes the control requests it produces, sleeps for the poll timeout
    /// it reports, and passes the number of bytes written per chunk to the
    /// progress callback (if one was set with `with_progress`).
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from `reader`, transfer errors from `IO`, and
    /// errors returned by the `DfuSansIo` command methods.
    pub fn download<R: std::io::Read>(&mut self, reader: R, length: u32) -> Result<(), IO::Error> {
        let transfer_size = self.io.functional_descriptor().transfer_size as usize;
        let mut reader = Buffer::new(transfer_size, reader);
        // Peek once so that an empty reader short-circuits before any request is sent.
        let buffer = reader.fill_buf()?;
        if buffer.is_empty() {
            return Ok(());
        }

        // Requests the device status via `get_status` until the command reports
        // `Step::Break`, sleeping `poll_timeout` milliseconds before each request.
        // `chain` returns a nested `Result`, hence the double `??`.
        macro_rules! wait_status {
            ($cmd:expr) => {{
                let mut cmd = $cmd;
                loop {
                    cmd = match cmd.next() {
                        get_status::Step::Break(cmd) => break cmd,
                        get_status::Step::Wait(cmd, poll_timeout) => {
                            std::thread::sleep(std::time::Duration::from_millis(poll_timeout));
                            let (cmd, mut control) = cmd.get_status(&mut self.buffer);
                            let n = control.execute(&self.io)?;
                            cmd.chain(&self.buffer[..n as usize])??
                        }
                    };
                }
            }};
        }

        // Start the download and read the first status reply. `chain` may hand
        // back an extra control request that must run before the second status
        // read.
        let cmd = self.dfu.download(self.io.protocol(), length)?;
        let (cmd, mut control) = cmd.get_status(&mut self.buffer);
        let n = control.execute(&self.io)?;
        let (cmd, control) = cmd.chain(&self.buffer[..n])?;
        if let Some(control) = control {
            control.execute(&self.io)?;
        }
        let (cmd, mut control) = cmd.get_status(&mut self.buffer);
        let n = control.execute(&self.io)?;
        let mut download_loop = cmd.chain(&self.buffer[..n])??;

        // Execute whatever step the protocol asks for next until it breaks out
        // or requests a USB reset.
        loop {
            download_loop = match download_loop.next() {
                download::Step::Break => break,
                download::Step::Erase(cmd) => {
                    let (cmd, control) = cmd.erase()?;
                    control.execute(&self.io)?;
                    wait_status!(cmd)
                }
                download::Step::SetAddress(cmd) => {
                    let (cmd, control) = cmd.set_address();
                    control.execute(&self.io)?;
                    wait_status!(cmd)
                }
                download::Step::DownloadChunk(cmd) => {
                    let chunk = reader.fill_buf()?;
                    let (cmd, control) = cmd.download(chunk)?;
                    let n = control.execute(&self.io)?;
                    // Drop only the `n` bytes the backend reports as written;
                    // any remainder stays buffered for the next chunk.
                    reader.consume(n);
                    if let Some(progress) = self.progress.as_mut() {
                        progress(n);
                    }
                    wait_status!(cmd)
                }
                download::Step::UsbReset => {
                    log::trace!("Device reset");
                    self.io.usb_reset()?;
                    // A reset ends the download.
                    break;
                }
            }
        }

        Ok(())
    }

    /// Downloads the entire contents of a seekable `reader`.
    ///
    /// The length is measured by seeking to the end. The reader is then rewound
    /// to the start, so the whole stream is sent regardless of the reader's
    /// position on entry.
    ///
    /// # Errors
    ///
    /// - `Error::MaximumTransferSizeExceeded` if the stream is longer than
    ///   `u32::MAX` bytes.
    /// - Seek errors.
    /// - Any error from `download`.
    ///
    /// NOTE(review): `download_from_slice` reports the same u32-overflow
    /// condition as `Error::OutOfCapabilities`. Consider unifying the two.
    pub fn download_all<R: std::io::Read + std::io::Seek>(
        &mut self,
        mut reader: R,
    ) -> Result<(), IO::Error> {
        let length = u32::try_from(reader.seek(std::io::SeekFrom::End(0))?)
            .map_err(|_| Error::MaximumTransferSizeExceeded)?;
        reader.seek(std::io::SeekFrom::Start(0))?;
        self.download(reader, length)
    }

    /// Executes the request built by `DfuSansIo::detach` and discards its output.
    pub fn detach(&self) -> Result<(), IO::Error> {
        self.dfu.detach().execute(&self.io)?;
        Ok(())
    }

    /// Resets the device through the IO backend's `usb_reset`.
    pub fn usb_reset(&self) -> Result<IO::Reset, IO::Error> {
        self.io.usb_reset()
    }

    /// Returns the `will_detach` flag from the device's functional descriptor.
    ///
    /// This presumably mirrors the DFU `bitWillDetach` attribute; confirm it
    /// against the DFU 1.1 spec.
    pub fn will_detach(&self) -> bool {
        self.io.functional_descriptor().will_detach
    }

    /// Returns the `manifestation_tolerant` flag from the device's functional
    /// descriptor.
    ///
    /// This presumably mirrors the DFU `bitManifestationTolerant` attribute;
    /// confirm it against the DFU 1.1 spec.
    pub fn manifestation_tolerant(&self) -> bool {
        self.io.functional_descriptor().manifestation_tolerant
    }
}