use super::error::{PatchExecutorError, Result};
use crate::api_types::PatchPackageInfo;
use base64;
use flate2::read::GzDecoder;
use reqwest::Client;
use sha2::{Digest, Sha256};
use std::path::{Path, PathBuf};
use tar::Archive;
use tempfile::TempDir;
use tokio::fs;
use tokio::io::AsyncWriteExt;
use tracing::{debug, info, warn};
/// Downloads, verifies, and extracts patch packages inside a scoped
/// temporary directory. The directory (and everything downloaded or
/// extracted into it) is removed automatically when the processor is dropped.
pub struct PatchProcessor {
    // Owning handle to the working directory; dropping it deletes the tree.
    temp_dir: TempDir,
    // Shared HTTP client, configured once in `new` and reused per download.
    http_client: Client,
}
impl PatchProcessor {
    /// Creates a processor backed by a fresh temporary working directory and
    /// an HTTP client with a 5-minute request timeout.
    ///
    /// # Errors
    /// Fails if the temporary directory or the HTTP client cannot be created.
    pub fn new() -> Result<Self> {
        let temp_dir = TempDir::new()
            .map_err(|e| PatchExecutorError::custom(format!("Failed to create temp directory: {e}")))?;
        // 300 s allows large patch packages to finish on slow connections.
        let http_client = Client::builder()
            .timeout(std::time::Duration::from_secs(300))
            .build()
            .map_err(|e| PatchExecutorError::custom(format!("Failed to create HTTP client: {e}")))?;
        debug!("Creating patch processor, temp directory: {:?}", temp_dir.path());
        Ok(Self {
            temp_dir,
            http_client,
        })
    }

    /// Streams the patch package at `patch_info.url` into
    /// `<temp_dir>/patch.tar.gz`, logging progress when the server reports a
    /// content length. Returns the path of the downloaded file.
    ///
    /// # Errors
    /// Fails on HTTP transport errors, a non-success status code, or local
    /// I/O errors while writing the file.
    pub async fn download_patch(&self, patch_info: &PatchPackageInfo) -> Result<PathBuf> {
        info!("Starting to download patch package: {}", patch_info.url);
        let patch_path = self.temp_dir.path().join("patch.tar.gz");
        let response = self
            .http_client
            .get(&patch_info.url)
            .send()
            .await
            .map_err(|e| PatchExecutorError::download_failed(format!("HTTP request failed: {e}")))?;
        if !response.status().is_success() {
            return Err(PatchExecutorError::download_failed(format!(
                "HTTP status code error: {}",
                response.status()
            )));
        }
        // 0 means the server did not send a Content-Length header; progress
        // logging is skipped in that case.
        let total_size = response.content_length().unwrap_or(0);
        debug!("Patch package size: {} bytes", total_size);
        let mut file = fs::File::create(&patch_path).await?;
        let mut downloaded = 0u64;
        let mut stream = response.bytes_stream();
        use futures_util::StreamExt;
        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result
                .map_err(|e| PatchExecutorError::download_failed(format!("Failed to download data chunk: {e}")))?;
            file.write_all(&chunk).await?;
            downloaded += chunk.len() as u64;
            if total_size > 0 {
                let progress = (downloaded as f64 / total_size as f64) * 100.0;
                debug!("Download progress: {:.1}%", progress);
            }
        }
        // Ensure buffered bytes reach the OS before the path is handed out.
        file.flush().await?;
        info!("Patch package download completed: {:?} ({} bytes)", patch_path, downloaded);
        Ok(patch_path)
    }

    /// Verifies the downloaded patch file: existence, then (when provided in
    /// `patch_info`) its SHA-256 hash and its digital signature.
    ///
    /// # Errors
    /// Fails if the file is missing, the hash does not match, or the
    /// signature check fails.
    pub async fn verify_patch_integrity(
        &self,
        patch_path: &Path,
        patch_info: &PatchPackageInfo,
    ) -> Result<()> {
        info!("Verifying patch integrity: {:?}", patch_path);
        if !patch_path.exists() {
            return Err(PatchExecutorError::verification_failed("Patch file does not exist"));
        }
        if let Some(hash) = &patch_info.hash {
            self.verify_hash(patch_path, hash).await?;
        }
        if let Some(signature) = &patch_info.signature {
            self.verify_signature(patch_path, signature).await?;
        }
        info!("Patch integrity verification passed");
        Ok(())
    }

    /// Compares the SHA-256 digest of `file_path` against `expected_hash`.
    ///
    /// Accepts the digest with or without a `sha256:` prefix and compares
    /// case-insensitively, since hex digests are valid in either case.
    ///
    /// # Errors
    /// Fails on I/O errors reading the file or when the digests differ.
    async fn verify_hash(&self, file_path: &Path, expected_hash: &str) -> Result<()> {
        debug!("Verifying file hash: {:?}", file_path);
        // strip_prefix replaces the previous manual `[7..]` slicing.
        let expected_hash = expected_hash.strip_prefix("sha256:").unwrap_or(expected_hash);
        // NOTE(review): reads the whole file into memory; assumes patch
        // packages are modest in size — switch to streaming hashing if not.
        let file_content = fs::read(file_path).await?;
        let mut hasher = Sha256::new();
        hasher.update(&file_content);
        let actual_hash = format!("{:x}", hasher.finalize());
        // Case-insensitive: the previous strict `!=` rejected valid
        // uppercase hex digests supplied by some servers/tools.
        if !actual_hash.eq_ignore_ascii_case(expected_hash) {
            return Err(PatchExecutorError::hash_mismatch(
                expected_hash.to_string(),
                actual_hash,
            ));
        }
        debug!("Hash verification passed: {}", actual_hash);
        Ok(())
    }

    /// Performs a simplified signature check: an empty signature is accepted
    /// with a warning, otherwise the signature must decode as valid base64.
    /// No cryptographic verification of the file content is performed here.
    ///
    /// # Errors
    /// Fails when a non-empty signature is not valid base64.
    async fn verify_signature(&self, _file_path: &Path, signature: &str) -> Result<()> {
        debug!("Verifying digital signature: {}", signature);
        if signature.is_empty() {
            warn!("Digital signature is empty, skipping verification");
            return Ok(());
        }
        use base64::{Engine as _, engine::general_purpose};
        if general_purpose::STANDARD.decode(signature).is_err() {
            return Err(PatchExecutorError::signature_verification_failed(
                "Signature is not a valid base64 format",
            ));
        }
        debug!("Digital signature verification passed (simplified verification)");
        Ok(())
    }

    /// Extracts the tar.gz archive at `patch_path` into
    /// `<temp_dir>/extracted`, returning that directory's path.
    ///
    /// The blocking decompression runs on tokio's blocking thread pool so it
    /// does not stall the async runtime.
    ///
    /// # Errors
    /// Fails on I/O errors, a panicked/cancelled blocking task, or any unsafe
    /// entry path inside the archive.
    pub async fn extract_patch(&self, patch_path: &Path) -> Result<PathBuf> {
        info!("Extracting patch package: {:?}", patch_path);
        let extract_dir = self.temp_dir.path().join("extracted");
        fs::create_dir_all(&extract_dir).await?;
        let patch_path_clone = patch_path.to_owned();
        let extract_dir_clone = extract_dir.clone();
        tokio::task::spawn_blocking(move || {
            Self::extract_tar_gz(&patch_path_clone, &extract_dir_clone)
        })
        .await
        .map_err(|e| PatchExecutorError::extraction_failed(format!("Extraction task failed: {e}")))??;
        info!("Patch package extracted: {:?}", extract_dir);
        Ok(extract_dir)
    }

    /// Synchronously unpacks `archive_path` (gzip-compressed tar) into
    /// `extract_to`, rejecting absolute paths and `..` components to prevent
    /// path traversal.
    ///
    /// # Errors
    /// Fails on I/O errors, malformed archive entries, or unsafe entry paths.
    fn extract_tar_gz(archive_path: &Path, extract_to: &Path) -> Result<()> {
        let file = std::fs::File::open(archive_path)?;
        let decoder = GzDecoder::new(file);
        let mut archive = Archive::new(decoder);
        for entry_result in archive.entries()? {
            let mut entry = entry_result
                .map_err(|e| PatchExecutorError::extraction_failed(format!("Failed to read entry: {e}")))?;
            let path = entry.path().map_err(|e| {
                PatchExecutorError::extraction_failed(format!("Failed to get file path: {e}"))
            })?;
            let path_buf = path.to_path_buf();
            // Reject absolute paths and any `..` component so an entry cannot
            // escape the extraction directory.
            if path_buf.is_absolute()
                || path_buf
                    .components()
                    .any(|c| c == std::path::Component::ParentDir)
            {
                return Err(PatchExecutorError::extraction_failed(format!(
                    "Unsafe file path: {path_buf:?}"
                )));
            }
            let extract_path = extract_to.join(&path_buf);
            if let Some(parent) = extract_path.parent() {
                std::fs::create_dir_all(parent)?;
            }
            // NOTE(review): per-entry `unpack` does not apply the archive-wide
            // symlink containment of `Archive::unpack`; a symlink entry could
            // still point outside `extract_to` — confirm patch sources are
            // trusted or harden this if they are not.
            entry.unpack(&extract_path).map_err(|e| {
                PatchExecutorError::extraction_failed(format!("Failed to unpack file {path_buf:?}: {e}"))
            })?;
            debug!("Extracting file: {:?} -> {:?}", path_buf, extract_path);
        }
        Ok(())
    }

    /// Returns the processor's temporary working directory.
    pub fn temp_dir(&self) -> &Path {
        self.temp_dir.path()
    }

    /// Lists the regular files directly inside `<temp_dir>/extracted`,
    /// returned as paths relative to that directory. Returns an empty vector
    /// if nothing has been extracted yet.
    ///
    /// NOTE(review): non-recursive — files in subdirectories are not listed;
    /// confirm callers only expect a flat layout.
    ///
    /// # Errors
    /// Fails on I/O errors while reading the directory.
    pub async fn list_extracted_files(&self) -> Result<Vec<PathBuf>> {
        let extract_dir = self.temp_dir.path().join("extracted");
        if !extract_dir.exists() {
            return Ok(Vec::new());
        }
        let mut files = Vec::new();
        let mut read_dir = fs::read_dir(&extract_dir).await?;
        while let Some(entry) = read_dir.next_entry().await? {
            let path = entry.path();
            if path.is_file() {
                if let Ok(relative_path) = path.strip_prefix(&extract_dir) {
                    files.push(relative_path.to_owned());
                }
            }
        }
        Ok(files)
    }

    /// Checks that every path in `required_files` (relative to
    /// `<temp_dir>/extracted`) exists.
    ///
    /// # Errors
    /// Fails naming the first missing file.
    pub async fn validate_extracted_structure(&self, required_files: &[String]) -> Result<()> {
        let extract_dir = self.temp_dir.path().join("extracted");
        for required_file in required_files {
            let file_path = extract_dir.join(required_file);
            if !file_path.exists() {
                return Err(PatchExecutorError::verification_failed(format!(
                    "Required file does not exist: {required_file}"
                )));
            }
        }
        debug!("Extracted file structure verification passed");
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::fs;

    /// Construction should succeed and yield a usable processor.
    #[tokio::test]
    async fn test_patch_processor_creation() {
        assert!(PatchProcessor::new().is_ok());
    }

    /// The temporary directory must exist and be a directory.
    #[tokio::test]
    async fn test_temp_dir_access() {
        let processor = PatchProcessor::new().unwrap();
        let work_dir = processor.temp_dir();
        assert!(work_dir.exists());
        assert!(work_dir.is_dir());
    }

    /// A matching `sha256:`-prefixed digest passes; a wrong digest fails.
    #[tokio::test]
    async fn test_hash_verification() {
        let processor = PatchProcessor::new().unwrap();
        let sample = processor.temp_dir().join("test.txt");
        let payload = b"hello world";
        fs::write(&sample, payload).await.unwrap();

        let mut digest = Sha256::new();
        digest.update(payload);
        let good_hash = format!("sha256:{:x}", digest.finalize());
        assert!(processor.verify_hash(&sample, &good_hash).await.is_ok());

        assert!(processor
            .verify_hash(&sample, "sha256:wronghash")
            .await
            .is_err());
    }

    /// Valid base64 passes, garbage fails, and an empty signature is skipped.
    #[tokio::test]
    async fn test_signature_verification() {
        let processor = PatchProcessor::new().unwrap();
        let sample = processor.temp_dir().join("test.txt");
        fs::write(&sample, b"test").await.unwrap();

        use base64::{Engine as _, engine::general_purpose};
        let good_sig = general_purpose::STANDARD.encode("test signature");
        assert!(processor.verify_signature(&sample, &good_sig).await.is_ok());

        assert!(processor
            .verify_signature(&sample, "invalid!@#$%")
            .await
            .is_err());

        // Empty signature means "nothing to verify" and succeeds.
        assert!(processor.verify_signature(&sample, "").await.is_ok());
    }

    /// A freshly built archive extracts and yields its single entry.
    #[tokio::test]
    async fn test_tar_gz_extraction() {
        let processor = PatchProcessor::new().unwrap();
        let archive_path = processor.temp_dir().join("test.tar.gz");
        let target_dir = processor.temp_dir().join("extract_test");
        fs::create_dir_all(&target_dir).await.unwrap();

        create_test_tar_gz(&archive_path).unwrap();
        assert!(PatchProcessor::extract_tar_gz(&archive_path, &target_dir).is_ok());
        assert!(target_dir.join("test.txt").exists());
    }

    /// Listing should report exactly the files placed in `extracted/`.
    #[tokio::test]
    async fn test_list_extracted_files() {
        let processor = PatchProcessor::new().unwrap();
        let extracted = processor.temp_dir().join("extracted");
        fs::create_dir_all(&extracted).await.unwrap();
        for (name, body) in [("file1.txt", "content1"), ("file2.txt", "content2")] {
            fs::write(extracted.join(name), body).await.unwrap();
        }

        let listing = processor.list_extracted_files().await.unwrap();
        assert_eq!(listing.len(), 2);
        for expected in ["file1.txt", "file2.txt"] {
            assert!(listing.iter().any(|f| f.file_name().unwrap() == expected));
        }
    }

    /// Validation passes when required files exist and fails when one is
    /// missing.
    #[tokio::test]
    async fn test_validate_extracted_structure() {
        let processor = PatchProcessor::new().unwrap();
        let extracted = processor.temp_dir().join("extracted");
        fs::create_dir_all(&extracted).await.unwrap();
        fs::write(extracted.join("required1.txt"), "content")
            .await
            .unwrap();
        fs::write(extracted.join("required2.txt"), "content")
            .await
            .unwrap();

        let present = vec!["required1.txt".to_string(), "required2.txt".to_string()];
        assert!(processor.validate_extracted_structure(&present).await.is_ok());

        let absent = vec!["missing.txt".to_string()];
        assert!(processor
            .validate_extracted_structure(&absent)
            .await
            .is_err());
    }

    /// Builds a minimal gzip-compressed tar archive containing one entry,
    /// `test.txt`, with the 12-byte body "hello world\n".
    fn create_test_tar_gz(output_path: &Path) -> std::io::Result<()> {
        use flate2::Compression;
        use flate2::write::GzEncoder;

        let encoder = GzEncoder::new(std::fs::File::create(output_path)?, Compression::default());
        let mut builder = tar::Builder::new(encoder);

        let mut header = tar::Header::new_gnu();
        header.set_path("test.txt")?;
        header.set_size(12); // byte length of "hello world\n"
        header.set_cksum();
        builder.append(&header, "hello world\n".as_bytes())?;
        builder.finish()?;
        Ok(())
    }
}