1use std::io::{Read, Seek, Write};
2use std::path::Path;
3use std::time::Instant;
4
5use tokio::sync::mpsc;
6
7use crate::Result;
8use crate::customization::Customization;
9use crate::helpers::{DirectIoBuffer, Eject, chan_send, check_token, progress};
10
/// Capacity in bytes of each [`DirectIoBuffer`] used by the copy pipeline (1 MiB).
#[cfg(not(debug_assertions))]
const BUFFER_SIZE: usize = 1024 * 1024;
/// Debug builds use 8 KiB buffers — presumably so tests exercise multi-buffer paths with
/// small fixtures; TODO confirm intent.
#[cfg(debug_assertions)]
const BUFFER_SIZE: usize = 8 * 1024;
16
17fn reader_task(
18 mut img: impl Read,
19 buf_rx: std::sync::mpsc::Receiver<Box<DirectIoBuffer<BUFFER_SIZE>>>,
20 buf_tx: std::sync::mpsc::SyncSender<(Box<DirectIoBuffer<BUFFER_SIZE>>, usize)>,
21 cancel: Option<tokio_util::sync::CancellationToken>,
22) -> Result<()> {
23 while let Ok(mut buf) = buf_rx.recv() {
24 let count = read_aligned(&mut img, buf.as_mut_slice())?;
25 if count == 0 {
26 break;
27 }
28
29 buf_tx
30 .send((buf, count))
31 .map_err(|_| crate::Error::WriterClosed)?;
32 check_token(cancel.as_ref())?;
33 }
34
35 Ok(())
36}
37
38fn writer_task_bmap(
44 bmap: bb_bmap_parser::Bmap,
45 mut sd: impl Write + Seek,
46 mut chan: Option<&mut mpsc::Sender<f32>>,
47 buf_rx: std::sync::mpsc::Receiver<(Box<DirectIoBuffer<BUFFER_SIZE>>, usize)>,
48 buf_tx: std::sync::mpsc::SyncSender<Box<DirectIoBuffer<BUFFER_SIZE>>>,
49 cancel: Option<tokio_util::sync::CancellationToken>,
50) -> Result<()> {
51 let mut pos = 0;
52 let (mut buf, mut count) = buf_rx.recv().unwrap();
53 let img_size = bmap.total_mapped_size();
54 let mut bytes_written = 0u64;
55
56 for b in bmap.block_map() {
57 let end_offset = b.offset() + b.length();
58
59 loop {
60 if pos + (count as u64) > b.offset() && pos < end_offset {
62 sd.seek(std::io::SeekFrom::Start(pos))?;
63 sd.write_all(&buf.as_slice()[..count])?;
64 bytes_written += count as u64;
65 } else if pos >= end_offset {
66 break;
67 }
68
69 pos += count as u64;
70 #[allow(clippy::option_map_or_none)]
72 chan_send(
73 chan.as_mut().map_or(None, |p| Some(p)),
74 progress(bytes_written, img_size),
75 );
76 check_token(cancel.as_ref())?;
77
78 match buf_rx.recv() {
79 Ok((x, y)) => {
80 let _ = buf_tx.send(buf);
81 buf = x;
82 count = y;
83 }
84 Err(_) => break,
85 }
86 }
87 }
88
89 sd.flush().map_err(Into::into)
90}
91
92fn writer_task(
93 img_size: u64,
94 mut sd: impl Write + Seek,
95 mut chan: Option<&mut mpsc::Sender<f32>>,
96 buf_rx: std::sync::mpsc::Receiver<(Box<DirectIoBuffer<BUFFER_SIZE>>, usize)>,
97 buf_tx: std::sync::mpsc::SyncSender<Box<DirectIoBuffer<BUFFER_SIZE>>>,
98 cancel: Option<tokio_util::sync::CancellationToken>,
99) -> Result<()> {
100 let mut pos = 0u64;
101
102 while let Ok((buf, count)) = buf_rx.recv() {
103 sd.write_all(&buf.as_slice()[..count])?;
104
105 pos += count as u64;
106 #[allow(clippy::option_map_or_none)]
108 chan_send(
109 chan.as_mut().map_or(None, |p| Some(p)),
110 progress(pos, img_size),
111 );
112
113 let _ = buf_tx.send(buf);
114 check_token(cancel.as_ref())?;
115 }
116
117 sd.flush().map_err(Into::into)
118}
119
120fn read_aligned(mut img: impl Read, buf: &mut [u8]) -> Result<usize> {
123 const ALIGNMENT: usize = 512;
124
125 let mut pos = 0;
126
127 while pos != buf.len() {
128 let count = img.read(&mut buf[pos..])?;
129 if count == 0 {
130 if pos % ALIGNMENT != 0 {
131 let end = pos - pos % ALIGNMENT + ALIGNMENT;
132 buf[pos..end].fill(0);
133 pos = end;
134 }
135 return Ok(pos);
136 }
137 pos += count;
138 }
139
140 Ok(pos)
141}
142
143fn write_sd(
144 img: impl Read + Send,
145 img_size: u64,
146 bmap: Option<bb_bmap_parser::Bmap>,
147 sd: impl Write + Seek,
148 chan: Option<&mut mpsc::Sender<f32>>,
149 cancel: Option<tokio_util::sync::CancellationToken>,
150) -> Result<()> {
151 const NUM_BUFFERS: usize = 4;
152
153 let (tx1, rx1) = std::sync::mpsc::sync_channel(NUM_BUFFERS);
154 let (tx2, rx2) = std::sync::mpsc::sync_channel(NUM_BUFFERS);
155 let global_start = Instant::now();
156
157 for _ in 0..NUM_BUFFERS {
159 tx1.send(Box::new(DirectIoBuffer::new())).unwrap();
160 }
161
162 std::thread::scope(|s| {
163 let cancle_clone = cancel.clone();
164 let handle = s.spawn(move || reader_task(img, rx1, tx2, cancle_clone));
165
166 match bmap {
167 Some(x) => writer_task_bmap(x, sd, chan, rx2, tx1, cancel),
168 None => writer_task(img_size, sd, chan, rx2, tx1, cancel),
169 }?;
170 tracing::info!("Total Time taken: {:?}", global_start.elapsed());
171
172 handle.join().unwrap()
173 })
174}
175
176pub async fn flash<R: Read + Send + 'static>(
205 img: impl Future<Output = std::io::Result<(R, u64)>>,
206 bmap: Option<impl Future<Output = std::io::Result<Box<str>>>>,
207 dst: Box<Path>,
208 chan: Option<mpsc::Sender<f32>>,
209 customizations: Vec<Customization>,
210 cancel: Option<tokio_util::sync::CancellationToken>,
211) -> Result<()> {
212 tracing::info!("Opening Destination");
213 let dst_clone = dst.to_path_buf();
214 let sd = crate::pal::open(&dst_clone).await?;
215
216 tracing::info!("Resolving Image");
217 let bmap = match bmap {
218 Some(x) => {
219 Some(bb_bmap_parser::Bmap::from_xml(&x.await?).map_err(|_| crate::Error::InvalidBmap)?)
220 }
221 None => None,
222 };
223 let (img, img_size) = img.await?;
224
225 let cancel_child = cancel.as_ref().map(|x| x.child_token());
226 let res = tokio::task::spawn_blocking(move || {
227 flash_internal(img, img_size, bmap, sd, chan, customizations, cancel_child)
228 })
229 .await
230 .unwrap();
231
232 let _drop_guard = cancel.map(|x| x.drop_guard());
234
235 res
236}
237
238fn flash_internal(
239 img: impl Read + Send,
240 img_size: u64,
241 bmap: Option<bb_bmap_parser::Bmap>,
242 sd: impl Read + Write + Seek + Eject + std::fmt::Debug,
243 mut chan: Option<mpsc::Sender<f32>>,
244 customizations: Vec<Customization>,
245 cancel: Option<tokio_util::sync::CancellationToken>,
246) -> Result<()> {
247 chan_send(chan.as_mut(), 0.0);
248
249 let mut sd = crate::helpers::SdCardWrapper::new(sd);
250
251 tracing::info!("Writing to SD Card");
252 write_sd(img, img_size, bmap, &mut sd, chan.as_mut(), cancel.clone())?;
253
254 check_token(cancel.as_ref())?;
255
256 tracing::info!("Applying customization");
257 for c in customizations {
258 let temp = crate::helpers::DeviceWrapper::new(&mut sd).unwrap();
259 c.customize(temp)?;
260 }
261
262 tracing::info!("Ejecting SD Card");
263 let _ = sd.eject();
264
265 Ok(())
266}
267
#[cfg(test)]
mod tests {
    use crate::flashing::{BUFFER_SIZE, read_aligned};

    use super::write_sd;

    /// Builds an in-memory image of `len` bytes filled with the repeating pattern
    /// `0, 1, ..., 254`.
    fn test_file(len: usize) -> std::io::Cursor<Box<[u8]>> {
        let data: Vec<u8> = (0..len)
            .map(|x| x % 255)
            .map(|x| u8::try_from(x).unwrap())
            .collect();
        std::io::Cursor::new(data.into())
    }

    #[test]
    fn sd_write() {
        const FILE_LEN: usize = 12 * 1024;

        let dummy_file = test_file(FILE_LEN);
        let mut sd = std::io::Cursor::new(Vec::<u8>::new());

        write_sd(
            dummy_file.clone(),
            FILE_LEN as u64,
            None,
            &mut sd,
            None,
            None,
        )
        .unwrap();

        assert_eq!(sd.get_ref().as_slice(), dummy_file.get_ref().as_ref());
    }

    #[test]
    fn sd_write_bmap() {
        const BLOCK_LEN: u64 = BUFFER_SIZE as u64;
        // Derive the image length from the block count so the test works in both debug
        // (8 KiB buffers) and release (1 MiB buffers) builds. A fixed 32 KiB image yields
        // zero blocks in release, and `BLOCKS - 1` then fails const evaluation.
        const BLOCKS: u64 = 4;
        const FILE_LEN: usize = (BLOCKS * BLOCK_LEN) as usize;
        const MAPPED_BLOCKS: &[u64] = &[0, 2, BLOCKS - 1];

        let dummy_file = test_file(FILE_LEN);
        let mut sd = std::io::Cursor::new(vec![0u8; FILE_LEN]);

        let mut bmap = bb_bmap_parser::Bmap::builder();
        bmap.image_size(FILE_LEN as u64)
            .block_size(BLOCK_LEN)
            .blocks(BLOCKS)
            .mapped_blocks(MAPPED_BLOCKS.len() as u64)
            .checksum_type(bb_bmap_parser::HashType::Sha256);

        for i in MAPPED_BLOCKS {
            bmap.add_block_range(
                *i,
                *i,
                bb_bmap_parser::HashValue::Sha256(Default::default()),
            );
        }

        let bmap = bmap.build().unwrap();

        write_sd(
            dummy_file.clone(),
            FILE_LEN as u64,
            Some(bmap.clone()),
            &mut sd,
            None,
            None,
        )
        .unwrap();

        for i in 0..(BLOCKS as usize) {
            let start = i * (BLOCK_LEN as usize);
            let end = start + (BLOCK_LEN as usize);
            let written = &sd.get_ref().as_slice()[start..end];
            if MAPPED_BLOCKS.contains(&(i as u64)) {
                assert_eq!(written, &dummy_file.get_ref().as_ref()[start..end]);
            } else {
                // Checked element-wise rather than against a BLOCK_LEN-sized stack array,
                // which would be 1 MiB in release builds.
                assert!(
                    written.iter().all(|&b| b == 0),
                    "unmapped block {i} was written"
                );
            }
        }
    }

    /// Reader that hands out at most 3 bytes per call, to exercise short reads.
    struct UnalignedReader(std::io::Cursor<Box<[u8]>>);

    impl UnalignedReader {
        const fn as_slice(&self) -> &[u8] {
            self.0.get_ref()
        }
    }

    impl std::io::Read for UnalignedReader {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            let count = std::cmp::min(self.0.get_ref().len() - self.0.position() as usize, 3);
            let count = std::cmp::min(count, buf.len());
            self.0.read(&mut buf[..count])
        }
    }

    #[test]
    fn aligned_read() {
        const FILE_LEN: usize = 12 * 1024;

        let mut dummy_file = UnalignedReader(test_file(FILE_LEN));
        let mut buf = [0u8; 1024];
        let mut pos = 0;

        loop {
            let count = read_aligned(&mut dummy_file, &mut buf).unwrap();
            if count == 0 {
                break;
            }

            assert_eq!(count % 512, 0);
            assert_eq!(buf[..count], dummy_file.as_slice()[pos..(pos + count)]);
            pos += count;
        }

        assert_eq!(pos, FILE_LEN);
    }
}
393}