use std::fs::{self, File};
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use sha2::{Digest, Sha256};
use super::error::BootstrapError;
use super::manifest::{Manifest, ModelPaths};
/// Buffer size in bytes (64 KiB) shared by the I/O helpers in this file: it sizes
/// the read buffer used while hashing and the capacity of the buffered readers
/// and writers used when concatenating chunks.
const HASH_READ_BUF: usize = 64 * 1024;
/// Hashes the file at `path` with SHA-256 in fixed-size chunks and compares the
/// digest against `expected_hex`.
///
/// The file is never loaded into memory as a whole; it is read through a single
/// `HASH_READ_BUF`-sized buffer. The comparison with `expected_hex` ignores
/// ASCII case, so upper- and lower-case hex digests are both accepted.
///
/// # Errors
///
/// * `BootstrapError::DiskFull` — the file could not be opened or a read failed
///   (this variant is used here for every I/O failure, carrying `path` and the
///   underlying `io::Error`).
/// * `BootstrapError::Sha256Mismatch` — the computed digest differs from
///   `expected_hex`; both values are included in the error.
pub fn verify_sha256_streaming(path: &Path, expected_hex: &str) -> Result<(), BootstrapError> {
    // Read straight from the file: the buffer below is as large as a BufReader's
    // capacity would be, so an extra buffering layer would only be bypassed.
    let mut file = File::open(path).map_err(|e| BootstrapError::DiskFull {
        path: path.to_path_buf(),
        source: e,
    })?;
    let mut hasher = Sha256::new();
    let mut buf = vec![0u8; HASH_READ_BUF];
    loop {
        let n = match file.read(&mut buf) {
            // End of file.
            Ok(0) => break,
            Ok(n) => n,
            // `Interrupted` is documented as non-fatal for `Read::read`; retry
            // instead of failing verification of an otherwise healthy file.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => {
                return Err(BootstrapError::DiskFull {
                    path: path.to_path_buf(),
                    source: e,
                });
            }
        };
        hasher.update(&buf[..n]);
    }
    let actual_hex = hex::encode(hasher.finalize());
    if !actual_hex.eq_ignore_ascii_case(expected_hex) {
        return Err(BootstrapError::Sha256Mismatch {
            expected: expected_hex.to_string(),
            actual: actual_hex,
        });
    }
    Ok(())
}
/// Deletes leftover chunk files (names starting with `.partial.`) from `target_dir`.
///
/// The sweep is best-effort: directory entries that cannot be read are skipped
/// and failures to delete an individual file are ignored. A missing
/// `target_dir` is treated as "nothing to clean" and yields `Ok(())`.
///
/// # Errors
///
/// Returns the `io::Error` from listing `target_dir` when it fails for any
/// reason other than the directory not existing.
pub fn cleanup_partials(target_dir: &Path) -> std::io::Result<()> {
    let listing = match fs::read_dir(target_dir) {
        Ok(listing) => listing,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(()),
        Err(err) => return Err(err),
    };
    listing
        // Unreadable entries are silently skipped.
        .filter_map(Result::ok)
        .filter(|entry| entry.file_name().to_string_lossy().starts_with(".partial."))
        .for_each(|entry| {
            // Deletion failures are deliberately ignored.
            let _ = fs::remove_file(entry.path());
        });
    Ok(())
}
/// Concatenates the chunk files `.partial.0` through `.partial.{chunk_count - 1}`
/// found in `target_dir`, in index order, into `target_dir/output_name`.
///
/// The output file is created first (truncating any existing file of that
/// name) and written through a buffered writer that is flushed before
/// returning. Chunk files are only read; they are not deleted here.
///
/// Returns the path of the assembled file.
///
/// # Errors
///
/// Every I/O failure is reported as `BootstrapError::DiskFull`, carrying the
/// path that was being processed: the output path for create/flush failures,
/// the chunk path for open/copy failures. Whatever was written to the output
/// before the failure is left on disk.
pub fn assemble_chunks(
    target_dir: &Path,
    chunk_count: u32,
    output_name: &str,
) -> Result<PathBuf, BootstrapError> {
    let out_path = target_dir.join(output_name);
    let mut sink = match File::create(&out_path) {
        Ok(file) => BufWriter::with_capacity(HASH_READ_BUF, file),
        Err(source) => {
            return Err(BootstrapError::DiskFull {
                path: out_path,
                source,
            });
        }
    };

    (0..chunk_count).try_for_each(|chunk_idx| {
        let chunk_path = target_dir.join(format!(".partial.{chunk_idx}"));
        // Opening and copying fail the same way, so handle them as one step.
        let appended = File::open(&chunk_path).and_then(|file| {
            let mut chunk = BufReader::with_capacity(HASH_READ_BUF, file);
            std::io::copy(&mut chunk, &mut sink)
        });
        match appended {
            Ok(_) => Ok(()),
            Err(source) => Err(BootstrapError::DiskFull {
                path: chunk_path,
                source,
            }),
        }
    })?;

    // Flush explicitly so write errors surface instead of being lost on drop.
    match sink.flush() {
        Ok(()) => Ok(out_path),
        Err(source) => Err(BootstrapError::DiskFull {
            path: out_path,
            source,
        }),
    }
}
/// Checks whether `target_dir` already holds every file listed in `manifest`
/// with a matching SHA-256 digest.
///
/// Returns `Some(ModelPaths)` only when all of the following hold:
/// * `target_dir` exists,
/// * every manifest file exists under `target_dir` and passes
///   `verify_sha256_streaming` against its manifest digest,
/// * the manifest includes `model_q4f16.onnx`, `tokenizer.json` and
///   `config.json`, whose paths populate the result.
///
/// Any missing file, failed verification, or absent required name yields
/// `None`. Manifest files with other names are verified but not returned.
pub fn check_existing(target_dir: &Path, manifest: &Manifest) -> Option<ModelPaths> {
    if !target_dir.exists() {
        return None;
    }

    let mut onnx = None::<PathBuf>;
    let mut tokenizer = None::<PathBuf>;
    let mut config = None::<PathBuf>;

    for entry in &manifest.files {
        let candidate = target_dir.join(&entry.name);
        // Existence is checked first, so hashing is only attempted on present files.
        let intact =
            candidate.exists() && verify_sha256_streaming(&candidate, &entry.sha256).is_ok();
        if !intact {
            return None;
        }
        // Pick the result slot for the well-known names; other files are only verified.
        let slot = match entry.name.as_str() {
            "model_q4f16.onnx" => &mut onnx,
            "tokenizer.json" => &mut tokenizer,
            "config.json" => &mut config,
            _ => continue,
        };
        *slot = Some(candidate);
    }

    Some(ModelPaths {
        onnx: onnx?,
        tokenizer: tokenizer?,
        config: config?,
    })
}
/// Attempts to delete the file at `path`, discarding the outcome.
///
/// Any error from the removal (for example the file not existing) is ignored,
/// so this function never fails and never panics on I/O errors.
pub fn remove_artifact_best_effort(path: &Path) {
    // Convert the result to an Option and drop it: failure is acceptable here.
    fs::remove_file(path).ok();
}