payload_dumper 0.7.6

A fast and efficient Android OTA payload dumper written in Rust
use anyhow::{Result, anyhow};
use async_compression::tokio::bufread::{BzDecoder, XzDecoder, ZstdDecoder};
use async_trait::async_trait;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use std::pin::Pin;
use std::sync::Arc;
use tokio::fs::File;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, BufReader};

pub use crate::PartitionUpdate;
use crate::args::Args;
use crate::{InstallOperation, install_operation};


/// Capacity, in bytes, of the `BufReader` wrapped around a payload stream before it is
/// handed to a decompressor (64 KB).
const BUFREADER_SIZE: usize = 64 * 1024;
/// Size, in bytes, of the scratch buffer `copy_with_buffer` allocates for direct copy
/// operations (128 KB).
const COPY_BUFFER_SIZE: usize = 128 * 1024;

/// A payload source from which asynchronous byte streams can be opened.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait AsyncPayloadRead: Send + Sync {
    /// Opens an asynchronous reader for the region described by `offset` and `length`.
    ///
    /// NOTE(review): the caller in this file passes the payload data offset plus the
    /// operation's own data offset, and the operation's data length. Presumably the
    /// returned stream yields `length` bytes starting at `offset` — confirm against
    /// the implementors.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementor reports when the stream cannot be opened.
    async fn stream_from(
        &self,
        offset: u64,
        length: u64,
    ) -> Result<Pin<Box<dyn AsyncRead + Send>>>;
}

/// Forwards [`AsyncPayloadRead`] through an `Arc<T>` to the wrapped reader.
#[async_trait]
impl<T: AsyncPayloadRead> AsyncPayloadRead for Arc<T> {
    /// Delegates to `T::stream_from` with the same `offset` and `length`.
    async fn stream_from(
        &self,
        offset: u64,
        length: u64,
    ) -> Result<Pin<Box<dyn AsyncRead + Send>>> {
        // Bind the pointee with an explicit type so the call below resolves to `T`'s
        // implementation rather than recursing into this `Arc<T>` implementation.
        let inner: &T = self;
        inner.stream_from(offset, length).await
    }
}

/// Forwards [`AsyncPayloadRead`] through a type-erased `Arc<dyn AsyncPayloadRead>`.
#[async_trait]
impl AsyncPayloadRead for Arc<dyn AsyncPayloadRead> {
    /// Delegates to the trait object's `stream_from` with the same `offset` and `length`.
    async fn stream_from(
        &self,
        offset: u64,
        length: u64,
    ) -> Result<Pin<Box<dyn AsyncRead + Send>>> {
        // Borrow the trait object out of the `Arc` first; calling `stream_from` on the
        // `Arc` itself would re-enter this implementation.
        let inner = self.as_ref();
        inner.stream_from(offset, length).await
    }
}

/// Copies everything `reader` yields into `writer` through a `COPY_BUFFER_SIZE` scratch buffer.
///
/// Reading stops at the first zero-length read. The writer is not flushed here.
///
/// # Errors
///
/// Propagates the first read or write error encountered.
///
/// # Returns
///
/// The total number of bytes written.
async fn copy_with_buffer<R, W>(reader: &mut R, writer: &mut W) -> Result<u64>
where
    R: AsyncRead + Unpin,
    W: AsyncWriteExt + Unpin,
{
    let mut chunk = vec![0u8; COPY_BUFFER_SIZE];
    let mut copied: u64 = 0;

    loop {
        match reader.read(&mut chunk).await? {
            // A zero-length read marks the end of the stream.
            0 => return Ok(copied),
            filled => {
                writer.write_all(&chunk[..filled]).await?;
                copied += filled as u64;
            }
        }
    }
}

async fn process_operation_streaming<P: AsyncPayloadRead>(
    operation_index: usize,
    op: &InstallOperation,
    data_offset: u64,
    block_size: u64,
    payload_reader: &P,
    out_file: &mut File,
) -> Result<()> {
    let offset = data_offset + op.data_offset.unwrap_or(0);
    let length = op.data_length.unwrap_or(0);

    match op.r#type() {
        install_operation::Type::Replace => {
            let mut stream = payload_reader.stream_from(offset, length).await?;
            out_file
                .seek(std::io::SeekFrom::Start(
                    op.dst_extents[0].start_block.unwrap_or(0) * block_size,
                ))
                .await?;
            copy_with_buffer(&mut stream, out_file).await?;
        }
        install_operation::Type::ReplaceXz => {
            let stream = payload_reader.stream_from(offset, length).await?;
            let mut decoder = XzDecoder::new(BufReader::with_capacity(BUFREADER_SIZE, stream));
            out_file
                .seek(std::io::SeekFrom::Start(
                    op.dst_extents[0].start_block.unwrap_or(0) * block_size,
                ))
                .await?;
            match copy_with_buffer(&mut decoder, out_file).await {
                Ok(_) => {}
                Err(e) => {
                    println!(
                        "  Warning: Skipping operation {} due to XZ decompression error: {}",
                        operation_index, e
                    );
                    return Ok(());
                }
            }
        }
        install_operation::Type::ReplaceBz => {
            let stream = payload_reader.stream_from(offset, length).await?;
            let mut decoder = BzDecoder::new(BufReader::with_capacity(BUFREADER_SIZE, stream));
            out_file
                .seek(std::io::SeekFrom::Start(
                    op.dst_extents[0].start_block.unwrap_or(0) * block_size,
                ))
                .await?;
            match copy_with_buffer(&mut decoder, out_file).await {
                Ok(_) => {}
                Err(e) => {
                    println!(
                        "  Warning: Skipping operation {} due to BZ2 decompression error: {}",
                        operation_index, e
                    );
                    return Ok(());
                }
            }
        }
        install_operation::Type::Zstd => {
            let stream = payload_reader.stream_from(offset, length).await?;
            let mut decoder = ZstdDecoder::new(BufReader::with_capacity(BUFREADER_SIZE, stream));

            if op.dst_extents.len() == 1 {
                out_file
                    .seek(std::io::SeekFrom::Start(
                        op.dst_extents[0].start_block.unwrap_or(0) * block_size,
                    ))
                    .await?;
                match copy_with_buffer(&mut decoder, out_file).await {
                    Ok(_) => {}
                    Err(e) => {
                        println!(
                            "  Warning: Skipping operation {} due to Zstd decompression error: {}",
                            operation_index, e
                        );
                        return Ok(());
                    }
                }
            } else {
                let mut decompressed = Vec::new();
                match decoder.read_to_end(&mut decompressed).await {
                    Ok(_) => {
                        let mut pos = 0;
                        for ext in &op.dst_extents {
                            let ext_size = (ext.num_blocks.unwrap_or(0) * block_size) as usize;
                            let end_pos = pos + ext_size;
                            if end_pos <= decompressed.len() {
                                out_file
                                    .seek(std::io::SeekFrom::Start(
                                        ext.start_block.unwrap_or(0) * block_size,
                                    ))
                                    .await?;
                                out_file.write_all(&decompressed[pos..end_pos]).await?;
                                pos = end_pos;
                            } else {
                                println!(
                                    "  Warning: Skipping extent in operation {} due to insufficient data.",
                                    operation_index
                                );
                                break;
                            }
                        }
                    }
                    Err(e) => {
                        println!(
                            "  Warning: Skipping operation {} due to Zstd error: {}",
                            operation_index, e
                        );
                        return Ok(());
                    }
                }
            }
        }
        install_operation::Type::Zero => {
            let zeros = vec![0u8; block_size as usize];
            for ext in &op.dst_extents {
                out_file
                    .seek(std::io::SeekFrom::Start(
                        ext.start_block.unwrap_or(0) * block_size,
                    ))
                    .await?;
                for _ in 0..ext.num_blocks.unwrap_or(0) {
                    out_file.write_all(&zeros).await?;
                }
            }
        }
        install_operation::Type::SourceCopy
        | install_operation::Type::SourceBsdiff
        | install_operation::Type::BrotliBsdiff
        | install_operation::Type::Lz4diffBsdiff => {
            return Err(anyhow!(
                "Operation {} is a differential OTA operation which is not supported",
                operation_index
            ));
        }
        _ => {
            println!(
                "  Warning: Skipping operation {} due to unknown operation type",
                operation_index
            );
            return Ok(());
        }
    }
    Ok(())
}

/// Extracts one partition from the payload into `<args.out>/<partition_name>.img`.
///
/// Steps, in order:
/// 1. If `multi_progress` is given, a percentage progress bar for this partition is added to it.
/// 2. The output directory is created unless `args.out` is `-`.
/// 3. The image file is created and, when `new_partition_info` is present, resized to the
///    declared partition size.
/// 4. Every operation of the partition is applied in sequence via
///    `process_operation_streaming`, updating the progress bar after each one.
/// 5. The file is flushed and the progress bar is marked as completed.
///
/// # Errors
///
/// * `new_partition_info` is present but carries no size.
/// * Creating the directory or file, resizing, or flushing fails.
/// * Any operation fails.
///
/// # Panics
///
/// Panics if the hard-coded progress bar template is rejected by `indicatif`.
pub async fn dump_partition<P: AsyncPayloadRead>(
    partition: &PartitionUpdate,
    data_offset: u64,
    block_size: u64,
    args: &Args,
    payload_reader: &P,
    multi_progress: Option<&MultiProgress>,
) -> Result<()> {
    let partition_name = &partition.partition_name;
    let total_ops = partition.operations.len() as u64;

    // The bar runs from 0 to 100 and is driven with a percentage, not an operation count.
    let progress_bar = multi_progress.map(|mp| {
        let bar = mp.add(ProgressBar::new(100));
        let style = ProgressStyle::default_bar()
            .template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/white}] {percent}% - {msg}")
            .unwrap()
            .progress_chars("▰▱ ");
        bar.set_style(style);
        bar.enable_steady_tick(tokio::time::Duration::from_secs(1));
        bar.set_message(format!("Processing {} ({} ops)", partition_name, total_ops));
        bar
    });

    let out_dir = &args.out;
    let wants_directory = out_dir.to_string_lossy() != "-";
    if wants_directory {
        tokio::fs::create_dir_all(out_dir).await?;
    }

    let image_path = out_dir.join(format!("{}.img", partition_name));
    let mut out_file = File::create(&image_path).await?;

    // Resize the image to its declared size before any operation writes into it.
    if let Some(info) = &partition.new_partition_info {
        let size = info
            .size
            .ok_or_else(|| anyhow!("Partition size is missing"))?;
        out_file.set_len(size).await?;
    }

    for (index, operation) in partition.operations.iter().enumerate() {
        process_operation_streaming(
            index,
            operation,
            data_offset,
            block_size,
            payload_reader,
            &mut out_file,
        )
        .await?;

        if let Some(bar) = progress_bar.as_ref() {
            let completed = (index + 1) as f64;
            let percentage = (completed / total_ops as f64 * 100.0) as u64;
            bar.set_position(percentage);
        }
    }

    out_file.flush().await?;

    if let Some(bar) = progress_bar {
        let summary = format!("✓ Completed {} ({} ops)", partition_name, total_ops);
        bar.finish_with_message(summary);
    }

    Ok(())
}