use std::io::{Read, Seek, Write};
use std::path::Path;
use std::time::Instant;
use tokio::sync::mpsc;
use crate::Result;
use crate::customization::Customization;
use crate::helpers::{DirectIoBuffer, Eject, chan_send, check_token, progress};
/// Size in bytes of each pipeline buffer (`DirectIoBuffer<BUFFER_SIZE>`).
///
/// Release builds use 1 MiB per buffer.
#[cfg(not(debug_assertions))]
const BUFFER_SIZE: usize = 1024 * 1024;
/// Debug builds use a much smaller buffer — presumably so unit tests can span several
/// buffers with small in-memory images (TODO confirm intent).
#[cfg(debug_assertions)]
const BUFFER_SIZE: usize = 8 * 1024;
// `read_aligned` zero-pads the final chunk up to a 512-byte boundary; that boundary only
// lies inside the buffer when the buffer length is itself a multiple of 512.
const _: () = assert!(BUFFER_SIZE % 512 == 0);
/// Reader half of the pipeline started by `write_sd`.
///
/// Takes empty buffers from `buf_rx`, fills each one with `read_aligned`, and forwards
/// `(buffer, byte_count)` to the writer over `buf_tx`. Stops when the image is exhausted
/// (`read_aligned` returns 0), when no more empty buffers can arrive, or when the writer
/// has hung up.
///
/// # Errors
/// Returns read errors from `read_aligned` and cancellation errors from `check_token`.
fn reader_task(
    mut img: impl Read,
    buf_rx: std::sync::mpsc::Receiver<Box<DirectIoBuffer<BUFFER_SIZE>>>,
    buf_tx: std::sync::mpsc::SyncSender<(Box<DirectIoBuffer<BUFFER_SIZE>>, usize)>,
    cancel: Option<tokio_util::sync::CancellationToken>,
) -> Result<()> {
    while let Ok(mut buf) = buf_rx.recv() {
        let count = read_aligned(&mut img, buf.as_mut_slice())?;
        if count == 0 {
            break;
        }
        if buf_tx.send((buf, count)).is_err() {
            // The writer dropped its receiver, so it has already returned. Either it
            // failed (and `write_sd` returns that error without looking at this result),
            // or it finished early: `writer_task_bmap` returns once the last mapped range
            // is written, possibly while unmapped trailing data is still unread. Neither
            // case is a reader failure, so stop instead of reporting `WriterClosed`.
            break;
        }
        check_token(cancel.as_ref())?;
    }
    Ok(())
}
/// Writer half of the pipeline for bmap-guided flashing.
///
/// Walks the image buffer by buffer, tracking each buffer's image offset in `pos`, alongside
/// the mapped ranges of `bmap`.
///
/// - A buffer that overlaps the current range is written in full at its own offset, so
///   unmapped bytes that share a buffer with mapped ones are written too.
/// - Buffers lying entirely in unmapped gaps are skipped.
/// - Every consumed buffer is handed back to the reader over `buf_tx`.
///
/// Returns once the last mapped range has been passed or the reader stops sending. `sd` is
/// flushed before returning.
///
/// Progress is reported as `progress(bytes_written, bmap.total_mapped_size())`.
///
/// # Errors
/// Seek/write/flush errors on `sd`, and cancellation errors from `check_token`.
fn writer_task_bmap(
    bmap: bb_bmap_parser::Bmap,
    mut sd: impl Write + Seek,
    mut chan: Option<&mut mpsc::Sender<f32>>,
    buf_rx: std::sync::mpsc::Receiver<(Box<DirectIoBuffer<BUFFER_SIZE>>, usize)>,
    buf_tx: std::sync::mpsc::SyncSender<Box<DirectIoBuffer<BUFFER_SIZE>>>,
    cancel: Option<tokio_util::sync::CancellationToken>,
) -> Result<()> {
    let img_size = bmap.total_mapped_size();
    let mut bytes_written = 0u64;
    // Image offset of the first byte held in `buf`.
    let mut pos = 0u64;

    // The reader hangs up without sending anything if the image is empty or its very first
    // read fails. There is nothing to write in that case. Any reader error is reported by
    // `write_sd` when it joins the reader thread.
    let Ok((mut buf, mut count)) = buf_rx.recv() else {
        return sd.flush().map_err(Into::into);
    };

    'ranges: for b in bmap.block_map() {
        let end_offset = b.offset() + b.length();
        // Consume buffers until one starts at or beyond the end of this range. That buffer
        // is kept and examined against the next range.
        while pos < end_offset {
            if pos + (count as u64) > b.offset() {
                sd.seek(std::io::SeekFrom::Start(pos))?;
                sd.write_all(&buf.as_slice()[..count])?;
                bytes_written += count as u64;
            }
            pos += count as u64;
            // Whole buffers are counted, so the raw total can overshoot the mapped size.
            chan_send(
                chan.as_deref_mut(),
                progress(bytes_written.min(img_size), img_size),
            );
            check_token(cancel.as_ref())?;
            match buf_rx.recv() {
                Ok((next, next_count)) => {
                    // Recycle the consumed buffer; a send error only means the reader is gone.
                    let _ = buf_tx.send(std::mem::replace(&mut buf, next));
                    count = next_count;
                }
                // The reader is done (or failed). `buf` has already been written or
                // skipped, so stop instead of re-checking it against later ranges with an
                // offset it no longer matches.
                Err(_) => break 'ranges,
            }
        }
    }
    sd.flush().map_err(Into::into)
}
/// Writer half of the pipeline for plain (non-bmap) flashing.
///
/// Writes each filled buffer received on `buf_rx` to `sd` back to back, in arrival order,
/// without seeking. After each write it reports `progress(bytes_written, img_size)` on
/// `chan` and hands the buffer back to the reader over `buf_tx`. Returns once the reader
/// closes its channel; `sd` is flushed last.
///
/// # Errors
/// Write/flush errors on `sd`, and cancellation errors from `check_token`.
fn writer_task(
    img_size: u64,
    mut sd: impl Write + Seek,
    mut chan: Option<&mut mpsc::Sender<f32>>,
    buf_rx: std::sync::mpsc::Receiver<(Box<DirectIoBuffer<BUFFER_SIZE>>, usize)>,
    buf_tx: std::sync::mpsc::SyncSender<Box<DirectIoBuffer<BUFFER_SIZE>>>,
    cancel: Option<tokio_util::sync::CancellationToken>,
) -> Result<()> {
    let mut written = 0u64;
    for (filled, len) in buf_rx.iter() {
        let chunk = &filled.as_slice()[..len];
        sd.write_all(chunk)?;
        written += chunk.len() as u64;
        chan_send(chan.as_deref_mut(), progress(written, img_size));
        // Return the buffer to the pool; a closed reader is not an error here.
        let _ = buf_tx.send(filled);
        check_token(cancel.as_ref())?;
    }
    sd.flush()?;
    Ok(())
}
/// Fills `buf` from `img` and returns how many bytes of `buf` are valid.
///
/// Short reads are retried until `buf` is full or `img` reports end of stream. At end of
/// stream, a partial tail is zero-padded up to the next multiple of 512 bytes, clamped to
/// `buf.len()`. As a result:
///
/// - every non-final chunk is `buf.len()` bytes;
/// - the final chunk is 512-aligned whenever `buf.len()` is;
/// - `0` is returned only when `img` is already exhausted (or `buf` is empty).
///
/// # Errors
/// Propagates any read error except `ErrorKind::Interrupted`, which is retried as the
/// `std::io::Read::read` contract recommends.
fn read_aligned(mut img: impl Read, buf: &mut [u8]) -> Result<usize> {
    const ALIGNMENT: usize = 512;
    let mut pos = 0;
    while pos < buf.len() {
        let count = match img.read(&mut buf[pos..]) {
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            other => other?,
        };
        if count == 0 {
            // Pad to the next alignment boundary, but never past the end of `buf`. The
            // unclamped boundary would panic for buffers whose length is not a multiple
            // of ALIGNMENT.
            let end = pos.next_multiple_of(ALIGNMENT).min(buf.len());
            buf[pos..end].fill(0);
            return Ok(end);
        }
        pos += count;
    }
    Ok(pos)
}
/// Streams `img` onto `sd` using a dedicated reader thread, with the calling thread acting
/// as the writer.
///
/// `NUM_BUFFERS` heap buffers circulate between the two sides:
///
/// - empty buffers travel writer -> reader over one channel;
/// - filled `(buffer, byte_count)` pairs travel reader -> writer over the other.
///
/// With a `bmap`, writing is delegated to `writer_task_bmap`. Otherwise the whole stream is
/// written sequentially by `writer_task`. `chan` and `cancel` are passed through to the
/// writer, and `cancel` is also cloned for the reader.
///
/// # Errors
/// A writer error is returned first. Otherwise the reader thread's result is returned.
///
/// # Panics
/// Panics if the reader thread panics.
fn write_sd(
    img: impl Read + Send,
    img_size: u64,
    bmap: Option<bb_bmap_parser::Bmap>,
    sd: impl Write + Seek,
    chan: Option<&mut mpsc::Sender<f32>>,
    cancel: Option<tokio_util::sync::CancellationToken>,
) -> Result<()> {
    const NUM_BUFFERS: usize = 4;
    let (empty_tx, empty_rx) = std::sync::mpsc::sync_channel(NUM_BUFFERS);
    let (filled_tx, filled_rx) = std::sync::mpsc::sync_channel(NUM_BUFFERS);
    let started_at = Instant::now();

    // Prime the buffer pool. The channel capacity equals NUM_BUFFERS, so these sends never
    // block, and `empty_rx` is still alive, so they cannot fail.
    for _ in 0..NUM_BUFFERS {
        empty_tx.send(Box::new(DirectIoBuffer::new())).unwrap();
    }

    std::thread::scope(|scope| {
        let reader_cancel = cancel.clone();
        let reader = scope.spawn(move || reader_task(img, empty_rx, filled_tx, reader_cancel));

        let written = if let Some(map) = bmap {
            writer_task_bmap(map, sd, chan, filled_rx, empty_tx, cancel)
        } else {
            writer_task(img_size, sd, chan, filled_rx, empty_tx, cancel)
        };
        written?;

        tracing::info!("Total Time taken: {:?}", started_at.elapsed());
        reader.join().unwrap()
    })
}
/// Flashes an image onto the device at `dst`, then applies `customizations` and ejects it.
///
/// - `img` resolves to a reader plus its size in bytes.
/// - `bmap`, if given, resolves to bmap XML that restricts writing to the mapped ranges.
/// - The blocking work runs on tokio's blocking pool.
/// - Progress values are sent on `chan`.
///
/// Once the destination, bmap and image are resolved, a provided `cancel` token is cancelled
/// when flashing ends, or earlier if the returned future is dropped mid-flash. The blocking
/// task observes a child of that token.
///
/// # Errors
/// Fails if the destination cannot be opened, either input future fails, the bmap XML is
/// invalid (`Error::InvalidBmap`), or flashing itself fails.
///
/// # Panics
/// Panics if the blocking flash task panics.
pub async fn flash<R: Read + Send + 'static>(
    img: impl Future<Output = std::io::Result<(R, u64)>>,
    bmap: Option<impl Future<Output = std::io::Result<Box<str>>>>,
    dst: Box<Path>,
    chan: Option<mpsc::Sender<f32>>,
    customizations: Vec<Customization>,
    cancel: Option<tokio_util::sync::CancellationToken>,
) -> Result<()> {
    tracing::info!("Opening Destination");
    let dst_clone = dst.to_path_buf();
    let sd = crate::pal::open(&dst_clone).await?;

    tracing::info!("Resolving Image");
    let bmap = match bmap {
        Some(x) => {
            Some(bb_bmap_parser::Bmap::from_xml(&x.await?).map_err(|_| crate::Error::InvalidBmap)?)
        }
        None => None,
    };
    let (img, img_size) = img.await?;

    let cancel_child = cancel.as_ref().map(|x| x.child_token());
    // Arm the guard *before* awaiting the blocking task. If this future is dropped while
    // flashing, the guard cancels `cancel` and therefore `cancel_child`, so the blocking
    // writer stops at its next token check instead of continuing in the background. On
    // normal completion it still cancels `cancel` when this function returns.
    let _drop_guard = cancel.map(|x| x.drop_guard());
    tokio::task::spawn_blocking(move || {
        flash_internal(img, img_size, bmap, sd, chan, customizations, cancel_child)
    })
    .await
    .unwrap()
}
/// Blocking body of `flash`. Runs these steps in order:
///
/// 1. Writes the image to `sd` (wrapped in `SdCardWrapper`).
/// 2. Checks cancellation.
/// 3. Applies each customization through a fresh `DeviceWrapper`.
/// 4. Ejects the card.
///
/// A `0.0` progress value is sent on `chan` before anything else.
///
/// # Errors
/// Errors from writing, cancellation, or any customization. Eject failures are ignored.
///
/// # Panics
/// Panics if `DeviceWrapper::new` fails.
fn flash_internal(
    img: impl Read + Send,
    img_size: u64,
    bmap: Option<bb_bmap_parser::Bmap>,
    sd: impl Read + Write + Seek + Eject + std::fmt::Debug,
    mut chan: Option<mpsc::Sender<f32>>,
    customizations: Vec<Customization>,
    cancel: Option<tokio_util::sync::CancellationToken>,
) -> Result<()> {
    chan_send(chan.as_mut(), 0.0);
    let mut card = crate::helpers::SdCardWrapper::new(sd);

    tracing::info!("Writing to SD Card");
    write_sd(img, img_size, bmap, &mut card, chan.as_mut(), cancel.clone())?;
    check_token(cancel.as_ref())?;

    tracing::info!("Applying customization");
    for customization in customizations {
        customization.customize(crate::helpers::DeviceWrapper::new(&mut card).unwrap())?;
    }

    tracing::info!("Ejecting SD Card");
    // Best effort: a failed eject does not fail an otherwise successful flash.
    let _ = card.eject();
    Ok(())
}
#[cfg(test)]
mod tests {
    use crate::flashing::{BUFFER_SIZE, read_aligned};

    use super::write_sd;

    /// Builds an in-memory image of `len` bytes holding a repeating 0..=254 byte pattern,
    /// so misplaced or missing writes show up as content mismatches.
    fn test_file(len: usize) -> std::io::Cursor<Box<[u8]>> {
        let data: Vec<u8> = (0..len)
            .map(|x| x % 255)
            .map(|x| u8::try_from(x).unwrap())
            .collect();
        std::io::Cursor::new(data.into())
    }

    #[test]
    fn sd_write() {
        const FILE_LEN: usize = 12 * 1024;
        let dummy_file = test_file(FILE_LEN);
        let mut sd = std::io::Cursor::new(Vec::<u8>::new());
        write_sd(
            dummy_file.clone(),
            FILE_LEN as u64,
            None,
            &mut sd,
            None,
            None,
        )
        .unwrap();
        assert_eq!(sd.get_ref().as_slice(), dummy_file.get_ref().as_ref());
    }

    #[test]
    fn sd_write_bmap() {
        // One bmap block per I/O buffer. The image length is derived from a fixed block
        // count so the test is identical in debug (8 KiB buffers) and release (1 MiB)
        // builds. With a fixed 32 KiB image, release gives BLOCKS == 0, `BLOCKS - 1` fails
        // const evaluation, and `cargo test --release` does not compile.
        const BLOCK_LEN: u64 = BUFFER_SIZE as u64;
        const BLOCKS: u64 = 4;
        const FILE_LEN: usize = (BLOCK_LEN * BLOCKS) as usize;
        const MAPPED_BLOCKS: &[u64] = &[0, 2, BLOCKS - 1];

        let dummy_file = test_file(FILE_LEN);
        let mut sd = std::io::Cursor::new(vec![0u8; FILE_LEN]);
        let mut bmap = bb_bmap_parser::Bmap::builder();
        bmap.image_size(FILE_LEN as u64)
            .block_size(BLOCK_LEN)
            .blocks(BLOCKS)
            .mapped_blocks(MAPPED_BLOCKS.len() as u64)
            .checksum_type(bb_bmap_parser::HashType::Sha256);
        for i in MAPPED_BLOCKS {
            bmap.add_block_range(
                *i,
                *i,
                bb_bmap_parser::HashValue::Sha256(Default::default()),
            );
        }
        let bmap = bmap.build().unwrap();
        write_sd(
            dummy_file.clone(),
            FILE_LEN as u64,
            Some(bmap),
            &mut sd,
            None,
            None,
        )
        .unwrap();

        let block_len = BLOCK_LEN as usize;
        for i in 0..(BLOCKS as usize) {
            let start = i * block_len;
            let end = start + block_len;
            let written = &sd.get_ref().as_slice()[start..end];
            if MAPPED_BLOCKS.contains(&(i as u64)) {
                assert_eq!(written, &dummy_file.get_ref().as_ref()[start..end]);
            } else {
                // Compare without a `[0u8; BLOCK_LEN]` literal, which would put 1 MiB on
                // the test thread's stack in release builds.
                assert!(
                    written.iter().all(|&b| b == 0),
                    "unmapped block {i} was written"
                );
            }
        }
    }

    /// Reader that hands out at most 3 bytes per `read` call, to exercise short reads.
    struct UnalignedReader(std::io::Cursor<Box<[u8]>>);

    impl UnalignedReader {
        const fn as_slice(&self) -> &[u8] {
            self.0.get_ref()
        }
    }

    impl std::io::Read for UnalignedReader {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            let count = std::cmp::min(self.0.get_ref().len() - self.0.position() as usize, 3);
            let count = std::cmp::min(count, buf.len());
            self.0.read(&mut buf[..count])
        }
    }

    #[test]
    fn aligned_read() {
        const FILE_LEN: usize = 12 * 1024;
        let mut dummy_file = UnalignedReader(test_file(FILE_LEN));
        let mut buf = [0u8; 1024];
        let mut pos = 0;
        loop {
            let count = read_aligned(&mut dummy_file, &mut buf).unwrap();
            if count == 0 {
                break;
            }
            assert_eq!(count % 512, 0);
            assert_eq!(buf[..count], dummy_file.as_slice()[pos..(pos + count)]);
            pos += count;
        }
        assert_eq!(pos, FILE_LEN);
    }
}