payload_dumper 0.8.0

A fast and efficient Android OTA payload dumper
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2025 rhythmcache
// https://github.com/rhythmcache/payload-dumper-rust
//
// This file is part of payload-dumper-rust. It implements components used for
// extracting and processing Android OTA payloads.

use ahash::AHashMap as HashMap;
use anyhow::{Result, anyhow};
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use tempfile::TempDir;
use tokio::fs::File;
use tokio::io::AsyncWriteExt;
use tokio::sync::Semaphore;

use crate::args::Args;
use crate::http::HttpReader;
use crate::payload::payload_dumper::{AsyncPayloadRead, PayloadReader, dump_partition};
use crate::readers::local_reader::LocalAsyncPayloadReader;
use crate::utils::format_elapsed_time;
use crate::verify::verify_partitions_hash;
use crate::{DeltaArchiveManifest, PartitionUpdate};
use std::path::Path;

/// Information about the data range needed for a partition
///
/// Built by [`calculate_partition_range`]: it describes one contiguous span
/// that covers the data of every operation in the partition that reads
/// payload data.
#[derive(Debug, Clone)]
pub struct PartitionDataRange {
    /// Lowest absolute start offset among the partition's data-carrying
    /// operations, where "absolute" means the caller-supplied `data_offset`
    /// plus the operation's own `data_offset`.
    pub min_offset: u64,
    /// Size in bytes of the span from `min_offset` up to the highest end
    /// offset of any data-carrying operation. This is the length of the whole
    /// span, so any gaps between operations are counted as well.
    pub total_bytes: u64,
}

/// Calculate the contiguous byte span covering the payload data of all
/// operations in a partition.
///
/// Only operations that carry both a `data_offset` and a non-zero
/// `data_length` are considered. Each operation's offset is relative, so
/// `data_offset` (the caller-supplied base) is added to make it absolute.
///
/// Returns `None` when no operation reads payload data, or when any offset
/// computation would overflow `u64` (which means the operation offsets cannot
/// describe a real byte range). Otherwise returns the lowest start offset and
/// the distance from it to the highest end offset.
pub fn calculate_partition_range(
    partition: &PartitionUpdate,
    data_offset: u64,
) -> Option<PartitionDataRange> {
    // (lowest start, highest end) over the operations seen so far.
    let mut bounds: Option<(u64, u64)> = None;

    for op in &partition.operations {
        // only consider operations that actually read from payload data
        let (Some(offset), Some(length)) = (op.data_offset, op.data_length) else {
            continue;
        };
        if length == 0 {
            continue;
        }

        // Checked arithmetic: wrapped offsets would silently produce a bogus
        // range (or panic in debug builds), so treat overflow as "no range".
        let start = data_offset.checked_add(offset)?;
        let end = start.checked_add(length)?;

        bounds = Some(match bounds {
            Some((lowest, highest)) => (lowest.min(start), highest.max(end)),
            None => (start, end),
        });
    }

    let (min_offset, max_offset) = bounds?;

    Some(PartitionDataRange {
        min_offset,
        // `max_offset` is the end of some operation that starts at or after
        // `min_offset`, so this subtraction cannot underflow.
        total_bytes: max_offset - min_offset,
    })
}

async fn download_partition_data_to_path(
    http_reader: &HttpReader,
    range: &PartitionDataRange,
    temp_dir_path: &Path,
    partition_name: &str,
    progress_bar: &ProgressBar,
) -> Result<PathBuf> {
    progress_bar.set_message(format!(
        "Downloading {} ({:.2} MB)",
        partition_name,
        range.total_bytes as f64 / 1024.0 / 1024.0
    ));

    let temp_path = temp_dir_path.join(format!("{}.prefetch", partition_name));
    let mut file = File::create(&temp_path).await?;

    const BUFFER_SIZE: usize = 256 * 1024; // 256 KB buffer for reading
    let mut buffer = vec![0u8; BUFFER_SIZE];
    let mut downloaded = 0u64;
    let total = range.total_bytes;
    let mut current_offset = range.min_offset;

    while downloaded < total {
        let remaining = total - downloaded;
        let chunk_size = remaining.min(BUFFER_SIZE as u64) as usize;

        // Read chunk from HTTP
        http_reader
            .read_at(current_offset, &mut buffer[..chunk_size])
            .await?;

        // Write to file
        file.write_all(&buffer[..chunk_size]).await?;

        downloaded += chunk_size as u64;
        current_offset += chunk_size as u64;

        let percent = (downloaded as f64 / total as f64 * 100.0) as u64;
        progress_bar.set_position(percent);
    }

    file.flush().await?;
    drop(file);

    progress_bar.finish_with_message(format!("✓ Downloaded {}", partition_name));

    Ok(temp_path)
}

/// Reader source backed by a local file that holds only a slice of the
/// payload, starting at `base_offset` of the original payload.
///
/// Readers opened through it accept absolute payload offsets and translate
/// them into positions relative to the start of the local file.
struct OffsetTranslatingReader {
    /// Local reader over the file that holds the slice.
    inner: LocalAsyncPayloadReader,
    /// Absolute payload offset that corresponds to byte 0 of the local file.
    base_offset: u64,
}

impl OffsetTranslatingReader {
    /// Open the file at `path` through `LocalAsyncPayloadReader` and remember
    /// the absolute offset its first byte corresponds to.
    async fn new(path: PathBuf, base_offset: u64) -> Result<Self> {
        let local_reader = LocalAsyncPayloadReader::new(path).await?;
        let translating = Self {
            inner: local_reader,
            base_offset,
        };
        Ok(translating)
    }
}

#[async_trait::async_trait]
impl AsyncPayloadRead for OffsetTranslatingReader {
    /// Open a reader on the wrapped local file and wrap it so that callers
    /// can keep addressing data by absolute payload offset.
    async fn open_reader(&self) -> Result<Box<dyn PayloadReader>> {
        let base_offset = self.base_offset;

        // The underlying reader works with offsets relative to the local file.
        let inner = self.inner.open_reader().await?;

        // The wrapper subtracts `base_offset` from every requested offset.
        let translating = OffsetTranslatingPayloadReader { inner, base_offset };
        Ok(Box::new(translating))
    }
}

/// `PayloadReader` adapter that shifts every requested offset down by
/// `base_offset` before delegating to the wrapped reader.
struct OffsetTranslatingPayloadReader {
    /// Reader that receives the translated (relative) offsets.
    inner: Box<dyn PayloadReader>,
    /// Absolute offset that maps to offset 0 of `inner`.
    base_offset: u64,
}

#[async_trait::async_trait]
impl PayloadReader for OffsetTranslatingPayloadReader {
    /// Read `length` bytes starting at the absolute `offset`.
    ///
    /// Returns an error if `offset` lies before `base_offset`; otherwise the
    /// request is forwarded to the wrapped reader at `offset - base_offset`.
    async fn read_range(
        &mut self,
        offset: u64,
        length: u64,
    ) -> Result<std::pin::Pin<Box<dyn tokio::io::AsyncRead + Send + '_>>> {
        // `checked_sub` yields `None` exactly when `offset < base_offset`.
        let Some(relative_offset) = offset.checked_sub(self.base_offset) else {
            return Err(anyhow!(
                "Offset {} is before base offset {}",
                offset,
                self.base_offset
            ));
        };

        self.inner.read_range(relative_offset, length).await
    }
}

/// Extract partitions from a remote payload in "prefetch" mode: download each
/// partition's data span into a temporary file, then extract the partition
/// from that local copy.
///
/// One task is spawned per partition that has payload data. Downloads and
/// extractions are each limited by their own semaphore sized from
/// `args.no_parallel` / `args.threads` (falling back to the CPU count, and
/// never less than one). Unless `args.no_verify` is set, partitions that did
/// not fail are hash-verified afterwards.
///
/// # Errors
///
/// Returns an error only if the temporary directory or the `HttpReader`
/// cannot be created. Failures of individual partitions (download,
/// extraction, task panic, hash mismatch) are reported on stderr and in the
/// final summary, but the function still returns `Ok(())`.
pub async fn prefetch_and_extract(
    url: String,
    manifest: DeltaArchiveManifest,
    data_offset: u64,
    args: Arc<Args>,
    partitions_to_extract: Vec<PartitionUpdate>,
    multi_progress: Arc<MultiProgress>,
) -> Result<()> {
    let start_time = Instant::now();

    let main_pb = multi_progress.add(ProgressBar::new_spinner());
    main_pb.set_style(
        ProgressStyle::default_spinner()
            .template("{spinner:.blue} {msg}")
            .unwrap(),
    );
    main_pb.enable_steady_tick(tokio::time::Duration::from_millis(300));

    main_pb.set_message("Initializing prefetch mode...");

    // Create temporary directory; it must stay alive until the end of this
    // function because the prefetch files live inside it.
    let temp_dir = TempDir::new()?;
    main_pb.println(format!(
        "- Created temporary directory: {:?}",
        temp_dir.path()
    ));

    // Calculate ranges for all partitions
    let mut partition_info: HashMap<String, PartitionDataRange> = HashMap::new();
    let mut total_download_size = 0u64;

    for partition in &partitions_to_extract {
        // NOTE(review): a partition for which no range can be computed gets no
        // download/extract task below, yet it is still handed to hash
        // verification — confirm that this is the intended behaviour.
        if let Some(range) = calculate_partition_range(partition, data_offset) {
            // Only used for the informational message, so saturate rather
            // than overflow.
            total_download_size = total_download_size.saturating_add(range.total_bytes);
            partition_info.insert(partition.partition_name.clone(), range);
        }
    }

    main_pb.println(format!(
        "- Total data to download: {:.2} MB across {} partitions",
        total_download_size as f64 / 1024.0 / 1024.0,
        partition_info.len()
    ));

    let requested_threads = if args.no_parallel {
        1
    } else {
        args.threads.unwrap_or_else(num_cpus::get)
    };
    // A semaphore with zero permits would make every task wait forever, so a
    // requested thread count of 0 is treated as 1.
    let thread_count = requested_threads.max(1);

    // Get block size for extraction
    let block_size = manifest.block_size.unwrap_or(4096) as u64;

    // Create a single HttpReader that will be shared
    let http_reader = Arc::new(HttpReader::new(url, args.user_agent.as_deref()).await?);

    // Download and extract partitions as soon as each download completes
    main_pb.set_message("Downloading and extracting partitions...");

    let download_semaphore = Arc::new(Semaphore::new(thread_count));
    let extract_semaphore = Arc::new(Semaphore::new(thread_count));

    // Each handle is stored with its partition name so that a task which
    // panics (and therefore cannot report its own name) is still attributable.
    let mut combined_tasks = Vec::new();

    for partition in &partitions_to_extract {
        let partition_name = partition.partition_name.clone();

        let Some(range) = partition_info.get(&partition_name).cloned() else {
            continue;
        };

        let task_name = partition_name.clone();
        let temp_dir_path = temp_dir.path().to_path_buf();
        let partition = partition.clone();
        let args = Arc::clone(&args);
        let multi_progress = Arc::clone(&multi_progress);
        let http_reader = Arc::clone(&http_reader);

        let download_pb = multi_progress.add(ProgressBar::new(100));
        download_pb.set_style(
            ProgressStyle::default_bar()
                .template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/white}] {percent}% - {msg}")
                .unwrap()
                .progress_chars("▰▱ "),
        );
        download_pb.enable_steady_tick(tokio::time::Duration::from_secs(1));

        let download_semaphore = Arc::clone(&download_semaphore);
        let extract_semaphore = Arc::clone(&extract_semaphore);

        // Spawn combined download + extract task
        let task = tokio::spawn(async move {
            // Download. The permit is scoped to this block so that it is
            // released before the task starts waiting for an extract permit.
            let temp_path = {
                let _permit = download_semaphore
                    .acquire()
                    .await
                    .expect("download semaphore is never closed");

                download_partition_data_to_path(
                    &http_reader,
                    &range,
                    &temp_dir_path,
                    &partition_name,
                    &download_pb,
                )
                .await
                .map_err(|e| (partition_name.clone(), e))?
            };

            // Extract immediately after download completes
            let _permit = extract_semaphore
                .acquire()
                .await
                .expect("extract semaphore is never closed");

            let reader = OffsetTranslatingReader::new(temp_path.clone(), range.min_offset)
                .await
                .map(|r| Arc::new(r) as Arc<dyn AsyncPayloadRead>)
                .map_err(|e| (partition_name.clone(), e))?;

            let outcome = dump_partition(
                &partition,
                data_offset,
                block_size,
                &args,
                &reader,
                Some(&multi_progress),
            )
            .await
            .map_err(|e| (partition_name, e));

            // The prefetched data is no longer needed once extraction has
            // finished. Removing it now keeps peak disk usage down; this is
            // best-effort because the temporary directory is removed as a
            // whole when this function returns.
            drop(reader);
            let _ = tokio::fs::remove_file(&temp_path).await;

            outcome
        });

        combined_tasks.push((task_name, task));
    }

    // Wait for all download+extract tasks to complete. `join_all` yields the
    // results in the order of the handles, so they can be zipped with the names.
    let (task_names, task_handles): (Vec<String>, Vec<_>) = combined_tasks.into_iter().unzip();
    let results = futures::future::join_all(task_handles).await;

    let mut failed_partitions = Vec::new();
    for (task_name, result) in task_names.into_iter().zip(results) {
        match result {
            Ok(Ok(())) => {}
            Ok(Err((partition_name, error))) => {
                eprintln!("Failed to process partition {}: {}", partition_name, error);
                failed_partitions.push(partition_name);
            }
            Err(e) => {
                // The task did not run to completion, so its partition has
                // not been extracted and must count as failed.
                eprintln!("Task for partition {} panicked: {}", task_name, e);
                failed_partitions.push(task_name);
            }
        }
    }

    if failed_partitions.is_empty() {
        main_pb.println("✓ All partitions downloaded and extracted");
    } else {
        main_pb.println(format!(
            "✗ {} partition(s) failed to download or extract",
            failed_partitions.len()
        ));
    }

    if !args.no_verify {
        main_pb.println("- Verifying partition hashes...");

        // Partitions that already failed have no complete output to verify.
        let partitions_to_verify: Vec<&PartitionUpdate> = partitions_to_extract
            .iter()
            .filter(|p| !failed_partitions.contains(&p.partition_name))
            .collect();

        match verify_partitions_hash(&partitions_to_verify, &args, &multi_progress).await {
            Ok(failed_verifications) => {
                if !failed_verifications.is_empty() {
                    eprintln!(
                        "Hash verification failed for {} partitions.",
                        failed_verifications.len()
                    );
                    failed_partitions.extend(failed_verifications);
                }
            }
            Err(e) => {
                eprintln!("Error during hash verification: {}", e);
            }
        }
    } else {
        main_pb.println("- Skipping hash verification");
    }

    let elapsed_time = format_elapsed_time(start_time.elapsed());

    if failed_partitions.is_empty() {
        main_pb.finish_with_message(format!(
            "All partitions extracted successfully! (in {})",
            elapsed_time
        ));
    } else {
        main_pb.finish_with_message(format!(
            "Completed with {} failed partitions. (in {})",
            failed_partitions.len(),
            elapsed_time
        ));
        println!(
            "\nExtraction completed with {} failed partitions in {}. Output directory: {:?}",
            failed_partitions.len(),
            elapsed_time,
            args.out
        );
    }

    Ok(())
}