Skip to main content

bb_flasher_sd/
flashing.rs

1use std::io::{Read, Seek, Write};
2use std::path::Path;
3use std::time::Instant;
4
5use tokio::sync::mpsc;
6
7use crate::Result;
8use crate::customization::Customization;
9use crate::helpers::{DirectIoBuffer, Eject, chan_send, check_token, progress};
10
// Size of each transfer buffer shared between the reader and writer threads.
//
// Debug builds use a small buffer: there `Box::new` moves the buffer from the stack to the heap,
// and a 1 MiB buffer overflows the stack.
#[cfg(debug_assertions)]
const BUFFER_SIZE: usize = 8 << 10;
#[cfg(not(debug_assertions))]
const BUFFER_SIZE: usize = 1 << 20;
16
17fn reader_task(
18    mut img: impl Read,
19    buf_rx: std::sync::mpsc::Receiver<Box<DirectIoBuffer<BUFFER_SIZE>>>,
20    buf_tx: std::sync::mpsc::SyncSender<(Box<DirectIoBuffer<BUFFER_SIZE>>, usize)>,
21    cancel: Option<tokio_util::sync::CancellationToken>,
22) -> Result<()> {
23    while let Ok(mut buf) = buf_rx.recv() {
24        let count = read_aligned(&mut img, buf.as_mut_slice())?;
25        if count == 0 {
26            break;
27        }
28
29        buf_tx
30            .send((buf, count))
31            .map_err(|_| crate::Error::WriterClosed)?;
32        check_token(cancel.as_ref())?;
33    }
34
35    Ok(())
36}
37
38/// While writing, a few assumptions should hold:
39/// - All writes should be in buffers multiple of block size (4K).
40/// - All writes should be aligned to block size (4K).
41///
42/// Thus, we will be writing some data that is not strictly present in the bmap.
43fn writer_task_bmap(
44    bmap: bb_bmap_parser::Bmap,
45    mut sd: impl Write + Seek,
46    mut chan: Option<&mut mpsc::Sender<f32>>,
47    buf_rx: std::sync::mpsc::Receiver<(Box<DirectIoBuffer<BUFFER_SIZE>>, usize)>,
48    buf_tx: std::sync::mpsc::SyncSender<Box<DirectIoBuffer<BUFFER_SIZE>>>,
49    cancel: Option<tokio_util::sync::CancellationToken>,
50) -> Result<()> {
51    let mut pos = 0;
52    let (mut buf, mut count) = buf_rx.recv().unwrap();
53    let img_size = bmap.total_mapped_size();
54    let mut bytes_written = 0u64;
55
56    for b in bmap.block_map() {
57        let end_offset = b.offset() + b.length();
58
59        loop {
60            // Write any buffer that lies even partially in the bmap range.
61            if pos + (count as u64) > b.offset() && pos < end_offset {
62                sd.seek(std::io::SeekFrom::Start(pos))?;
63                sd.write_all(&buf.as_slice()[..count])?;
64                bytes_written += count as u64;
65            } else if pos >= end_offset {
66                break;
67            }
68
69            pos += count as u64;
70            // Clippy warning is simply wrong here
71            #[allow(clippy::option_map_or_none)]
72            chan_send(
73                chan.as_mut().map_or(None, |p| Some(p)),
74                progress(bytes_written, img_size),
75            );
76            check_token(cancel.as_ref())?;
77
78            match buf_rx.recv() {
79                Ok((x, y)) => {
80                    let _ = buf_tx.send(buf);
81                    buf = x;
82                    count = y;
83                }
84                Err(_) => break,
85            }
86        }
87    }
88
89    sd.flush().map_err(Into::into)
90}
91
92fn writer_task(
93    img_size: u64,
94    mut sd: impl Write + Seek,
95    mut chan: Option<&mut mpsc::Sender<f32>>,
96    buf_rx: std::sync::mpsc::Receiver<(Box<DirectIoBuffer<BUFFER_SIZE>>, usize)>,
97    buf_tx: std::sync::mpsc::SyncSender<Box<DirectIoBuffer<BUFFER_SIZE>>>,
98    cancel: Option<tokio_util::sync::CancellationToken>,
99) -> Result<()> {
100    let mut pos = 0u64;
101
102    while let Ok((buf, count)) = buf_rx.recv() {
103        sd.write_all(&buf.as_slice()[..count])?;
104
105        pos += count as u64;
106        // Clippy warning is simply wrong here
107        #[allow(clippy::option_map_or_none)]
108        chan_send(
109            chan.as_mut().map_or(None, |p| Some(p)),
110            progress(pos, img_size),
111        );
112
113        let _ = buf_tx.send(buf);
114        check_token(cancel.as_ref())?;
115    }
116
117    sd.flush().map_err(Into::into)
118}
119
120/// A lot of reads from compressed files are not aligned. Since reading even from compressed files
121/// is significantly faster than writing to SD Card, better to do multiple reads.
122fn read_aligned(mut img: impl Read, buf: &mut [u8]) -> Result<usize> {
123    const ALIGNMENT: usize = 512;
124
125    let mut pos = 0;
126
127    while pos != buf.len() {
128        let count = img.read(&mut buf[pos..])?;
129        if count == 0 {
130            if pos % ALIGNMENT != 0 {
131                let end = pos - pos % ALIGNMENT + ALIGNMENT;
132                buf[pos..end].fill(0);
133                pos = end;
134            }
135            return Ok(pos);
136        }
137        pos += count;
138    }
139
140    Ok(pos)
141}
142
143fn write_sd(
144    img: impl Read + Send,
145    img_size: u64,
146    bmap: Option<bb_bmap_parser::Bmap>,
147    sd: impl Write + Seek,
148    chan: Option<&mut mpsc::Sender<f32>>,
149    cancel: Option<tokio_util::sync::CancellationToken>,
150) -> Result<()> {
151    const NUM_BUFFERS: usize = 4;
152
153    let (tx1, rx1) = std::sync::mpsc::sync_channel(NUM_BUFFERS);
154    let (tx2, rx2) = std::sync::mpsc::sync_channel(NUM_BUFFERS);
155    let global_start = Instant::now();
156
157    // Starting buffers
158    for _ in 0..NUM_BUFFERS {
159        tx1.send(Box::new(DirectIoBuffer::new())).unwrap();
160    }
161
162    std::thread::scope(|s| {
163        let cancle_clone = cancel.clone();
164        let handle = s.spawn(move || reader_task(img, rx1, tx2, cancle_clone));
165
166        match bmap {
167            Some(x) => writer_task_bmap(x, sd, chan, rx2, tx1, cancel),
168            None => writer_task(img_size, sd, chan, rx2, tx1, cancel),
169        }?;
170        tracing::info!("Total Time taken: {:?}", global_start.elapsed());
171
172        handle.join().unwrap()
173    })
174}
175
176/// Flash OS image to SD card.
177///
178/// # Customization
179///
180/// Support post flashing customization. Currently only sysconf is supported, which is used by
181/// [BeagleBoard.org].
182///
183/// # Image
184///
185/// Using a resolver function for image and image size. This is to allow downloading the image, or
186/// some kind of lazy loading after SD card permissions have be acquired. This is useful in GUIs
187/// since the user would expect a password prompt at the start of flashing.
188///
189/// Many users might switch task after starting the flashing process, which would make it
190/// frustrating if the prompt occured after downloading.
191///
192/// # Progress
193///
194/// Progress lies between 0 and 1.
195///
196/// # Aborting
197///
198/// The process can be aborted by dropping all strong references to the [`Arc`] that owns the
199/// [`Weak`] passed as `cancel`.
200///
201/// [`Arc`]: std::sync::Arc
202/// [`Weak`]: std::sync::Weak
203/// [BeagleBoard.org]: https://www.beagleboard.org/
204pub async fn flash<R: Read + Send + 'static>(
205    img: impl Future<Output = std::io::Result<(R, u64)>>,
206    bmap: Option<impl Future<Output = std::io::Result<Box<str>>>>,
207    dst: Box<Path>,
208    chan: Option<mpsc::Sender<f32>>,
209    customizations: Vec<Customization>,
210    cancel: Option<tokio_util::sync::CancellationToken>,
211) -> Result<()> {
212    tracing::info!("Opening Destination");
213    let dst_clone = dst.to_path_buf();
214    let sd = crate::pal::open(&dst_clone).await?;
215
216    tracing::info!("Resolving Image");
217    let bmap = match bmap {
218        Some(x) => {
219            Some(bb_bmap_parser::Bmap::from_xml(&x.await?).map_err(|_| crate::Error::InvalidBmap)?)
220        }
221        None => None,
222    };
223    let (img, img_size) = img.await?;
224
225    let cancel_child = cancel.as_ref().map(|x| x.child_token());
226    let res = tokio::task::spawn_blocking(move || {
227        flash_internal(img, img_size, bmap, sd, chan, customizations, cancel_child)
228    })
229    .await
230    .unwrap();
231
232    // Cancel all tasks on drop
233    let _drop_guard = cancel.map(|x| x.drop_guard());
234
235    res
236}
237
238fn flash_internal(
239    img: impl Read + Send,
240    img_size: u64,
241    bmap: Option<bb_bmap_parser::Bmap>,
242    sd: impl Read + Write + Seek + Eject + std::fmt::Debug,
243    mut chan: Option<mpsc::Sender<f32>>,
244    customizations: Vec<Customization>,
245    cancel: Option<tokio_util::sync::CancellationToken>,
246) -> Result<()> {
247    chan_send(chan.as_mut(), 0.0);
248
249    let mut sd = crate::helpers::SdCardWrapper::new(sd);
250
251    tracing::info!("Writing to SD Card");
252    write_sd(img, img_size, bmap, &mut sd, chan.as_mut(), cancel.clone())?;
253
254    check_token(cancel.as_ref())?;
255
256    tracing::info!("Applying customization");
257    for c in customizations {
258        let temp = crate::helpers::DeviceWrapper::new(&mut sd).unwrap();
259        c.customize(temp)?;
260    }
261
262    tracing::info!("Ejecting SD Card");
263    let _ = sd.eject();
264
265    Ok(())
266}
267
#[cfg(test)]
mod tests {
    use crate::flashing::{BUFFER_SIZE, read_aligned};

    use super::write_sd;

    /// Build an in-memory image of `len` bytes filled with a repeating `0..=254` byte pattern.
    fn test_file(len: usize) -> std::io::Cursor<Box<[u8]>> {
        let data: Vec<u8> = (0..len)
            .map(|x| x % 255)
            .map(|x| u8::try_from(x).unwrap())
            .collect();
        std::io::Cursor::new(data.into())
    }

    #[test]
    fn sd_write() {
        const FILE_LEN: usize = 12 * 1024;

        let dummy_file = test_file(FILE_LEN);
        let mut sd = std::io::Cursor::new(Vec::<u8>::new());

        write_sd(
            dummy_file.clone(),
            FILE_LEN as u64,
            None,
            &mut sd,
            None,
            None,
        )
        .unwrap();

        assert_eq!(sd.get_ref().as_slice(), dummy_file.get_ref().as_ref());
    }

    #[test]
    fn sd_write_bmap() {
        // One bmap block per transfer buffer. The image length is derived from `BUFFER_SIZE` so
        // the block count is the same in debug and release builds. A fixed length would give
        // zero blocks with release-sized buffers, and `BLOCKS - 1` would fail const evaluation.
        const BLOCKS: u64 = 4;
        const BLOCK_LEN: u64 = BUFFER_SIZE as u64;
        const FILE_LEN: usize = BUFFER_SIZE * (BLOCKS as usize);
        const MAPPED_BLOCKS: &[u64] = &[0, 2, BLOCKS - 1];

        let dummy_file = test_file(FILE_LEN);
        let mut sd = std::io::Cursor::new(vec![0u8; FILE_LEN]);

        let mut bmap = bb_bmap_parser::Bmap::builder();
        bmap.image_size(FILE_LEN as u64)
            .block_size(BLOCK_LEN)
            .blocks(BLOCKS)
            .mapped_blocks(MAPPED_BLOCKS.len() as u64)
            .checksum_type(bb_bmap_parser::HashType::Sha256);

        for i in MAPPED_BLOCKS {
            bmap.add_block_range(
                *i,
                *i,
                bb_bmap_parser::HashValue::Sha256(Default::default()),
            );
        }

        let bmap = bmap.build().unwrap();

        write_sd(
            dummy_file.clone(),
            FILE_LEN as u64,
            Some(bmap.clone()),
            &mut sd,
            None,
            None,
        )
        .unwrap();

        for i in 0..(BLOCKS as usize) {
            let start = i * (BLOCK_LEN as usize);
            let end = start + (BLOCK_LEN as usize);
            if MAPPED_BLOCKS.contains(&(i as u64)) {
                assert_eq!(
                    sd.get_ref().as_slice()[start..end],
                    dummy_file.get_ref().as_ref()[start..end]
                );
            } else {
                // Check without a `BLOCK_LEN`-sized stack array, which is 1 MiB in release.
                assert!(
                    sd.get_ref()[start..end].iter().all(|&b| b == 0),
                    "unmapped block {i} was written"
                );
            }
        }
    }

    /// Reader that returns at most 3 bytes per call, to exercise `read_aligned`'s re-reading.
    struct UnalignedReader(std::io::Cursor<Box<[u8]>>);

    impl UnalignedReader {
        const fn as_slice(&self) -> &[u8] {
            self.0.get_ref()
        }
    }

    impl std::io::Read for UnalignedReader {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            let count = std::cmp::min(self.0.get_ref().len() - self.0.position() as usize, 3);
            let count = std::cmp::min(count, buf.len());
            self.0.read(&mut buf[..count])
        }
    }

    #[test]
    fn aligned_read() {
        const FILE_LEN: usize = 12 * 1024;

        let mut dummy_file = UnalignedReader(test_file(FILE_LEN));
        let mut buf = [0u8; 1024];
        let mut pos = 0;

        loop {
            let count = read_aligned(&mut dummy_file, &mut buf).unwrap();
            if count == 0 {
                break;
            }

            assert_eq!(count % 512, 0);
            assert_eq!(buf[..count], dummy_file.as_slice()[pos..(pos + count)]);
            pos += count;
        }

        assert_eq!(pos, FILE_LEN);
    }
}