use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
use sha2::{Digest, Sha256};
use tokio::fs;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{debug, error, info, warn};
use crate::client::WireBandClient;
use crate::error::{Result, WireBandError};
use crate::frame;
use crate::symbols::{EDGE_OTA_APPLY, EDGE_OTA_CHUNK, EDGE_OTA_START, EDGE_OTA_VERIFY};
/// Wall-clock time as fractional seconds since the Unix epoch.
///
/// If the system clock reports a moment before the epoch, `0.0` is returned
/// instead of failing, so callers always get a usable timestamp.
fn unix_ts() -> f64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs_f64(),
        Err(_) => 0.0,
    }
}
/// Description of a single over-the-air update to fetch and install.
#[derive(Debug, Clone)]
pub struct OtaUpdate {
    /// HTTP(S) location the image is downloaded from.
    pub url: String,
    /// Final on-disk location the image is installed to.
    pub target_path: PathBuf,
    /// Hex-encoded SHA-256 of the image, compared case-insensitively.
    /// When `None`, integrity verification is skipped (a warning is logged).
    pub expected_sha256: Option<String>,
    /// Version label used in logs and telemetry topics.
    pub version: String,
}
/// Downloads, verifies and installs OTA images.
///
/// Images are streamed into `staging_dir` first and only moved to their
/// final location afterwards, so a failed download never touches the target.
pub struct OtaManager {
    /// HTTP client used for image downloads.
    http: reqwest::Client,
    /// Directory in which images are staged before installation.
    staging_dir: PathBuf,
}
impl OtaManager {
pub fn new(staging_dir: impl Into<PathBuf>) -> Self {
Self {
http: reqwest::Client::new(),
staging_dir: staging_dir.into(),
}
}
pub async fn run(&self, update: OtaUpdate, client: &WireBandClient) -> Result<()> {
info!(version = %update.version, url = %update.url, "OTA update starting");
self.emit(client, EDGE_OTA_START, &update.version, "start").await;
let staged = self.download(&update, client).await?;
if let Some(ref expected) = update.expected_sha256 {
if !self.verify(&staged, expected).await? {
error!(version = %update.version, "OTA hash verification failed");
return Err(WireBandError::Connection(format!(
"OTA hash mismatch for {} (expected {})",
update.url, expected
)));
}
debug!(version = %update.version, "OTA hash verified");
} else {
warn!(version = %update.version, "OTA: skipping hash verification (no expected_sha256)");
}
self.emit(client, EDGE_OTA_VERIFY, &update.version, "verified").await;
self.apply(&staged, &update.target_path).await?;
self.emit(client, EDGE_OTA_APPLY, &update.version, "applied").await;
info!(version = %update.version, target = %update.target_path.display(), "OTA complete");
Ok(())
}
async fn download(&self, update: &OtaUpdate, client: &WireBandClient) -> Result<PathBuf> {
fs::create_dir_all(&self.staging_dir).await
.map_err(|e| WireBandError::Connection(e.to_string()))?;
let staging_path = self.staging_dir.join("ota_staged.bin");
let mut file = fs::File::create(&staging_path).await
.map_err(|e| WireBandError::Connection(e.to_string()))?;
let mut resp = self.http.get(&update.url).send().await?;
let total = resp.content_length().unwrap_or(0);
let mut downloaded: u64 = 0;
let mut chunk_count: u64 = 0;
while let Some(chunk) = resp.chunk().await? {
file.write_all(&chunk).await
.map_err(|e| WireBandError::Connection(e.to_string()))?;
downloaded += chunk.len() as u64;
chunk_count += 1;
if chunk_count % 64 == 0 {
let data = serde_json::json!({
"version": update.version,
"downloaded": downloaded,
"total": total,
"chunks": chunk_count,
});
let topic = format!("ota/chunk/{}", update.version);
let encoded = frame::encode(EDGE_OTA_CHUNK, &topic, &data);
client.buffer_event(topic, EDGE_OTA_CHUNK, encoded, unix_ts()).await;
}
}
file.flush().await
.map_err(|e| WireBandError::Connection(e.to_string()))?;
debug!(bytes = downloaded, chunks = chunk_count, "OTA download complete");
Ok(staging_path)
}
async fn verify(&self, path: &Path, expected: &str) -> Result<bool> {
let data = fs::read(path).await
.map_err(|e| WireBandError::Connection(e.to_string()))?;
let mut hasher = Sha256::new();
hasher.update(&data);
let hash = format!("{:x}", hasher.finalize());
Ok(hash == expected.to_ascii_lowercase())
}
async fn apply(&self, staged: &Path, target: &Path) -> Result<()> {
if let Some(parent) = target.parent() {
fs::create_dir_all(parent).await
.map_err(|e| WireBandError::Connection(e.to_string()))?;
}
match fs::rename(staged, target).await {
Ok(()) => Ok(()),
Err(_) => {
fs::copy(staged, target).await
.map_err(|e| WireBandError::Connection(e.to_string()))?;
let _ = fs::remove_file(staged).await;
Ok(())
}
}
}
async fn emit(&self, client: &WireBandClient, symbol: u16, version: &str, phase: &str) {
let topic = format!("ota/{phase}");
let data = serde_json::json!({ "version": version, "phase": phase });
let encoded = frame::encode(symbol, &topic, &data);
client.buffer_event(topic, symbol, encoded, unix_ts()).await;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Fresh temp dir plus a manager staging into it. Keep the returned
    /// `TempDir` alive for the whole test: dropping it deletes the directory.
    fn setup() -> (TempDir, OtaManager) {
        let dir = tempfile::tempdir().unwrap();
        let mgr = OtaManager::new(dir.path());
        (dir, mgr)
    }

    /// Lowercase hex SHA-256 of `bytes`.
    fn sha256_hex(bytes: &[u8]) -> String {
        let mut h = Sha256::new();
        h.update(bytes);
        format!("{:x}", h.finalize())
    }

    #[tokio::test]
    async fn verify_correct_hash() {
        let (dir, mgr) = setup();
        let content = b"hello world";
        let path = dir.path().join("test.bin");
        fs::write(&path, content).await.unwrap();
        let expected = sha256_hex(content);
        assert!(mgr.verify(&path, &expected).await.unwrap());
    }

    #[tokio::test]
    async fn verify_accepts_uppercase_hash() {
        let (dir, mgr) = setup();
        let content = b"hello world";
        let path = dir.path().join("test.bin");
        fs::write(&path, content).await.unwrap();
        let expected = sha256_hex(content).to_ascii_uppercase();
        assert!(mgr.verify(&path, &expected).await.unwrap());
    }

    #[tokio::test]
    async fn verify_wrong_hash() {
        let (dir, mgr) = setup();
        let path = dir.path().join("test.bin");
        fs::write(&path, b"hello world").await.unwrap();
        assert!(!mgr.verify(&path, "deadbeef").await.unwrap());
    }

    #[tokio::test]
    async fn apply_renames_file() {
        let (dir, mgr) = setup();
        let staged = dir.path().join("staged.bin");
        let target = dir.path().join("firmware.bin");
        fs::write(&staged, b"firmware").await.unwrap();
        mgr.apply(&staged, &target).await.unwrap();
        assert!(target.exists());
        assert!(!staged.exists());
    }

    #[tokio::test]
    async fn apply_replaces_existing_target() {
        let (dir, mgr) = setup();
        let staged = dir.path().join("staged.bin");
        let target = dir.path().join("firmware.bin");
        fs::write(&target, b"old").await.unwrap();
        fs::write(&staged, b"new").await.unwrap();
        mgr.apply(&staged, &target).await.unwrap();
        assert_eq!(fs::read(&target).await.unwrap(), b"new");
        assert!(!staged.exists());
    }

    #[tokio::test]
    async fn apply_creates_missing_parent_dirs() {
        let (dir, mgr) = setup();
        let staged = dir.path().join("staged.bin");
        let target = dir.path().join("nested").join("deeper").join("firmware.bin");
        fs::write(&staged, b"firmware").await.unwrap();
        mgr.apply(&staged, &target).await.unwrap();
        assert_eq!(fs::read(&target).await.unwrap(), b"firmware");
    }
}