wireband-edge 0.4.1

Lightweight Wire.Band client — semantic data middleware for any domain (IoT, AI/ML, DeFi, legal, geospatial, supply chain, and more)
Documentation
//! OTA firmware update pipeline.
//!
//! Downloads a firmware image in streaming chunks, verifies its SHA-256,
//! then atomically applies it. Emits `EDGE_OTA_*` symbols at each phase
//! so the Wire.Band backend can track fleet update progress.
//!
//! # Update lifecycle
//!
//! ```text
//! EDGE_OTA_START  →  download (chunks)  →  EDGE_OTA_VERIFY  →  EDGE_OTA_APPLY
//! ```

use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};

use sha2::{Digest, Sha256};
use tokio::fs;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{debug, error, info, warn};

use crate::client::WireBandClient;
use crate::error::{Result, WireBandError};
use crate::frame;
use crate::symbols::{EDGE_OTA_APPLY, EDGE_OTA_CHUNK, EDGE_OTA_START, EDGE_OTA_VERIFY};

/// Current wall-clock time as fractional seconds since the Unix epoch.
///
/// If the system clock reports a time before 1970, this returns `0.0`
/// instead of failing.
fn unix_ts() -> f64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs_f64(),
        Err(_) => 0.0,
    }
}

/// A single firmware update request: where to fetch the image, where to
/// install it, how to check it, and which version it represents.
#[derive(Debug, Clone)]
pub struct OtaUpdate {
    /// Source URL for the firmware image.
    pub url: String,
    /// Destination path of the installed binary on this device.
    pub target_path: PathBuf,
    /// Hex-encoded SHA-256 the downloaded image must match (compared
    /// case-insensitively). `None` disables verification, which is not
    /// recommended outside development.
    pub expected_sha256: Option<String>,
    /// Version label attached to every OTA event for this update.
    pub version: String,
}

/// Manages firmware OTA updates with hash verification and atomic apply.
///
/// Holds a reusable HTTP client and the directory in which images are staged
/// before being moved to their final location.
#[derive(Debug)]
pub struct OtaManager {
    /// HTTP client used to fetch firmware images.
    http: reqwest::Client,
    /// Directory that receives the in-progress download (`ota_staged.bin`).
    staging_dir: PathBuf,
}

impl OtaManager {
    /// Create a new OTA manager. `staging_dir` is where partial downloads are buffered.
    pub fn new(staging_dir: impl Into<PathBuf>) -> Self {
        Self {
            http:        reqwest::Client::new(),
            staging_dir: staging_dir.into(),
        }
    }

    /// Run the full OTA pipeline: download → verify → apply.
    ///
    /// Emits `EDGE_OTA_START`, `EDGE_OTA_CHUNK` (per chunk), `EDGE_OTA_VERIFY`,
    /// and `EDGE_OTA_APPLY` to the Wire.Band backend.
    pub async fn run(&self, update: OtaUpdate, client: &WireBandClient) -> Result<()> {
        info!(version = %update.version, url = %update.url, "OTA update starting");
        self.emit(client, EDGE_OTA_START, &update.version, "start").await;

        // Download
        let staged = self.download(&update, client).await?;

        // Verify
        if let Some(ref expected) = update.expected_sha256 {
            if !self.verify(&staged, expected).await? {
                error!(version = %update.version, "OTA hash verification failed");
                return Err(WireBandError::Connection(format!(
                    "OTA hash mismatch for {} (expected {})",
                    update.url, expected
                )));
            }
            debug!(version = %update.version, "OTA hash verified");
        } else {
            warn!(version = %update.version, "OTA: skipping hash verification (no expected_sha256)");
        }
        self.emit(client, EDGE_OTA_VERIFY, &update.version, "verified").await;

        // Apply
        self.apply(&staged, &update.target_path).await?;
        self.emit(client, EDGE_OTA_APPLY, &update.version, "applied").await;

        info!(version = %update.version, target = %update.target_path.display(), "OTA complete");
        Ok(())
    }

    // -----------------------------------------------------------------------
    // Internal
    // -----------------------------------------------------------------------

    async fn download(&self, update: &OtaUpdate, client: &WireBandClient) -> Result<PathBuf> {
        fs::create_dir_all(&self.staging_dir).await
            .map_err(|e| WireBandError::Connection(e.to_string()))?;

        let staging_path = self.staging_dir.join("ota_staged.bin");
        let mut file = fs::File::create(&staging_path).await
            .map_err(|e| WireBandError::Connection(e.to_string()))?;

        let mut resp = self.http.get(&update.url).send().await?;
        let total = resp.content_length().unwrap_or(0);
        let mut downloaded: u64 = 0;
        let mut chunk_count: u64 = 0;

        while let Some(chunk) = resp.chunk().await? {
            file.write_all(&chunk).await
                .map_err(|e| WireBandError::Connection(e.to_string()))?;

            downloaded  += chunk.len() as u64;
            chunk_count += 1;

            // Emit a chunk event every 64 chunks to avoid flooding the backend
            if chunk_count % 64 == 0 {
                let data = serde_json::json!({
                    "version":    update.version,
                    "downloaded": downloaded,
                    "total":      total,
                    "chunks":     chunk_count,
                });
                let topic   = format!("ota/chunk/{}", update.version);
                let encoded = frame::encode(EDGE_OTA_CHUNK, &topic, &data);
                client.buffer_event(topic, EDGE_OTA_CHUNK, encoded, unix_ts()).await;
            }
        }

        file.flush().await
            .map_err(|e| WireBandError::Connection(e.to_string()))?;

        debug!(bytes = downloaded, chunks = chunk_count, "OTA download complete");
        Ok(staging_path)
    }

    async fn verify(&self, path: &Path, expected: &str) -> Result<bool> {
        let data = fs::read(path).await
            .map_err(|e| WireBandError::Connection(e.to_string()))?;
        let mut hasher = Sha256::new();
        hasher.update(&data);
        let hash = format!("{:x}", hasher.finalize());
        Ok(hash == expected.to_ascii_lowercase())
    }

    /// Atomic rename: staged → target. Works within the same filesystem.
    /// Falls back to copy+delete when crossing filesystem boundaries.
    async fn apply(&self, staged: &Path, target: &Path) -> Result<()> {
        if let Some(parent) = target.parent() {
            fs::create_dir_all(parent).await
                .map_err(|e| WireBandError::Connection(e.to_string()))?;
        }

        // Try atomic rename first
        match fs::rename(staged, target).await {
            Ok(()) => Ok(()),
            Err(_) => {
                // Cross-filesystem: copy then delete
                fs::copy(staged, target).await
                    .map_err(|e| WireBandError::Connection(e.to_string()))?;
                let _ = fs::remove_file(staged).await;
                Ok(())
            }
        }
    }

    async fn emit(&self, client: &WireBandClient, symbol: u16, version: &str, phase: &str) {
        let topic   = format!("ota/{phase}");
        let data    = serde_json::json!({ "version": version, "phase": phase });
        let encoded = frame::encode(symbol, &topic, &data);
        client.buffer_event(topic, symbol, encoded, unix_ts()).await;
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Lowercase hex SHA-256 of `content`, the format `verify` compares against.
    fn sha256_hex(content: &[u8]) -> String {
        let mut h = Sha256::new();
        h.update(content);
        format!("{:x}", h.finalize())
    }

    #[tokio::test]
    async fn verify_correct_hash() {
        let dir: TempDir = tempfile::tempdir().unwrap();
        let mgr = OtaManager::new(dir.path());
        let content = b"hello world";
        let path = dir.path().join("test.bin");
        fs::write(&path, content).await.unwrap();

        assert!(mgr.verify(&path, &sha256_hex(content)).await.unwrap());
    }

    #[tokio::test]
    async fn verify_accepts_uppercase_expected() {
        let dir: TempDir = tempfile::tempdir().unwrap();
        let mgr = OtaManager::new(dir.path());
        let content = b"hello world";
        let path = dir.path().join("test.bin");
        fs::write(&path, content).await.unwrap();

        let upper = sha256_hex(content).to_ascii_uppercase();
        assert!(mgr.verify(&path, &upper).await.unwrap());
    }

    #[tokio::test]
    async fn verify_empty_file_matches_known_digest() {
        let dir: TempDir = tempfile::tempdir().unwrap();
        let mgr = OtaManager::new(dir.path());
        let path = dir.path().join("empty.bin");
        fs::write(&path, b"").await.unwrap();

        let empty_sha256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
        assert!(mgr.verify(&path, empty_sha256).await.unwrap());
    }

    #[tokio::test]
    async fn verify_wrong_hash() {
        let dir: TempDir = tempfile::tempdir().unwrap();
        let mgr = OtaManager::new(dir.path());
        let path = dir.path().join("test.bin");
        fs::write(&path, b"hello world").await.unwrap();
        assert!(!mgr.verify(&path, "deadbeef").await.unwrap());
    }

    #[tokio::test]
    async fn verify_missing_file_is_error() {
        let dir: TempDir = tempfile::tempdir().unwrap();
        let mgr = OtaManager::new(dir.path());
        let path = dir.path().join("absent.bin");
        assert!(mgr.verify(&path, "00").await.is_err());
    }

    #[tokio::test]
    async fn apply_renames_file() {
        let dir: TempDir = tempfile::tempdir().unwrap();
        let mgr = OtaManager::new(dir.path());
        let staged = dir.path().join("staged.bin");
        let target = dir.path().join("firmware.bin");
        fs::write(&staged, b"firmware").await.unwrap();
        mgr.apply(&staged, &target).await.unwrap();
        assert!(target.exists());
        assert!(!staged.exists());
    }

    #[tokio::test]
    async fn apply_creates_missing_parent_dirs() {
        let dir: TempDir = tempfile::tempdir().unwrap();
        let mgr = OtaManager::new(dir.path());
        let staged = dir.path().join("staged.bin");
        let target = dir.path().join("nested").join("deeper").join("firmware.bin");
        fs::write(&staged, b"firmware").await.unwrap();
        mgr.apply(&staged, &target).await.unwrap();
        assert_eq!(fs::read(&target).await.unwrap(), b"firmware");
        assert!(!staged.exists());
    }

    #[tokio::test]
    async fn apply_replaces_existing_target() {
        let dir: TempDir = tempfile::tempdir().unwrap();
        let mgr = OtaManager::new(dir.path());
        let staged = dir.path().join("staged.bin");
        let target = dir.path().join("firmware.bin");
        fs::write(&target, b"old firmware").await.unwrap();
        fs::write(&staged, b"new").await.unwrap();
        mgr.apply(&staged, &target).await.unwrap();
        assert_eq!(fs::read(&target).await.unwrap(), b"new");
        assert!(!staged.exists());
    }
}