// barbacane_lib/hot_reload.rs

//! Hot-reload functionality for zero-downtime artifact updates.
//!
//! This module provides the pieces used to download, verify, and apply new
//! artifacts while the gateway continues serving requests.
5
6use futures_util::StreamExt;
7use sha2::{Digest, Sha256};
8use std::path::{Path, PathBuf};
9use tokio::io::AsyncWriteExt;
10use uuid::Uuid;
11
12/// Result of a hot-reload attempt.
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub enum HotReloadResult {
15    /// Hot-reload completed successfully.
16    Success { artifact_id: Uuid },
17    /// Hot-reload failed with an error.
18    Failed { artifact_id: Uuid, error: String },
19}
20
21/// Lock to prevent concurrent hot-reloads.
22pub static HOT_RELOAD_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());
23
24/// Download an artifact from the control plane and verify its checksum.
25///
26/// # Arguments
27/// * `http_client` - The HTTP client to use for downloading
28/// * `download_url` - URL to download the artifact from
29/// * `expected_sha256` - Expected SHA256 hash of the artifact (hex-encoded)
30/// * `artifact_dir` - Directory to store the downloaded artifact
31///
32/// # Returns
33/// The path to the downloaded and verified artifact, or an error message.
34pub async fn download_artifact(
35    http_client: &reqwest::Client,
36    download_url: &str,
37    expected_sha256: &str,
38    artifact_dir: &Path,
39) -> Result<PathBuf, String> {
40    // Create temp file path
41    let temp_filename = format!("artifact-{}.bca.tmp", Uuid::new_v4());
42    let temp_path = artifact_dir.join(&temp_filename);
43
44    tracing::info!(
45        download_url = %download_url,
46        temp_path = %temp_path.display(),
47        "Downloading artifact"
48    );
49
50    // Start download
51    let response = http_client
52        .get(download_url)
53        .send()
54        .await
55        .map_err(|e| format!("download request failed: {}", e))?;
56
57    if !response.status().is_success() {
58        return Err(format!(
59            "download failed with status: {}",
60            response.status()
61        ));
62    }
63
64    // Create temp file
65    let mut file = tokio::fs::File::create(&temp_path)
66        .await
67        .map_err(|e| format!("failed to create temp file: {}", e))?;
68
69    // Stream to file while computing SHA256
70    let mut hasher = Sha256::new();
71    let mut stream = response.bytes_stream();
72    let mut total_bytes = 0u64;
73
74    while let Some(chunk_result) = stream.next().await {
75        let chunk = chunk_result.map_err(|e| format!("download stream error: {}", e))?;
76        hasher.update(&chunk);
77        file.write_all(&chunk)
78            .await
79            .map_err(|e| format!("write error: {}", e))?;
80        total_bytes += chunk.len() as u64;
81    }
82
83    file.flush()
84        .await
85        .map_err(|e| format!("flush error: {}", e))?;
86    drop(file);
87
88    // Verify checksum
89    let computed_sha256 = hex::encode(hasher.finalize());
90    if computed_sha256 != expected_sha256 {
91        let _ = tokio::fs::remove_file(&temp_path).await;
92        return Err(format!(
93            "checksum mismatch: expected {}, got {}",
94            expected_sha256, computed_sha256
95        ));
96    }
97
98    // Atomic rename to final path
99    let final_filename = format!("artifact-{}.bca", Uuid::new_v4());
100    let final_path = artifact_dir.join(&final_filename);
101    tokio::fs::rename(&temp_path, &final_path)
102        .await
103        .map_err(|e| format!("rename failed: {}", e))?;
104
105    tracing::info!(
106        final_path = %final_path.display(),
107        size_bytes = total_bytes,
108        sha256 = %computed_sha256,
109        "Artifact downloaded and verified"
110    );
111
112    Ok(final_path)
113}
114
115/// Compute SHA256 hash of data (hex-encoded).
116pub fn compute_sha256(data: &[u8]) -> String {
117    let mut hasher = Sha256::new();
118    hasher.update(data);
119    hex::encode(hasher.finalize())
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;
    use tokio::io::AsyncReadExt;
    use tokio::net::TcpListener;

    /// Reason phrase for the status codes these tests serve, so the fake
    /// server's status line is well-formed (e.g. `404 Not Found`, not `404 OK`).
    fn reason_phrase(status: u16) -> &'static str {
        match status {
            200 => "OK",
            404 => "Not Found",
            _ => "Unknown",
        }
    }

    /// Lists the paths currently present in `dir`.
    fn dir_entries(dir: &Path) -> Vec<PathBuf> {
        std::fs::read_dir(dir)
            .unwrap()
            .map(|entry| entry.unwrap().path())
            .collect()
    }

    /// Simple single-shot HTTP server for testing downloads.
    ///
    /// Accepts one connection and reads (and ignores) up to 1 KiB of the
    /// request. It then writes a response with `status` and `content` and
    /// drops the socket.
    async fn start_test_server(
        content: Vec<u8>,
        status: u16,
    ) -> (SocketAddr, tokio::task::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let handle = tokio::spawn(async move {
            if let Ok((mut socket, _)) = listener.accept().await {
                // Read the request (we don't care about the content)
                let mut buf = [0u8; 1024];
                let _ = socket.read(&mut buf).await;

                // Send response
                let response = format!(
                    "HTTP/1.1 {} {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
                    status,
                    reason_phrase(status),
                    content.len()
                );
                let _ = socket.write_all(response.as_bytes()).await;
                let _ = socket.write_all(&content).await;
            }
        });

        (addr, handle)
    }

    #[tokio::test]
    async fn test_download_artifact_success() {
        // Create test content
        let content = b"test artifact content for download verification";
        let expected_sha256 = compute_sha256(content);

        // Start test server
        let (addr, server_handle) = start_test_server(content.to_vec(), 200).await;
        let url = format!("http://{}/artifact.bca", addr);

        // Create temp directory
        let temp_dir = tempfile::tempdir().unwrap();

        // Download artifact
        let client = reqwest::Client::new();
        let result = download_artifact(&client, &url, &expected_sha256, temp_dir.path()).await;

        // Verify success
        assert!(result.is_ok(), "download should succeed: {:?}", result);
        let artifact_path = result.unwrap();
        assert!(artifact_path.exists(), "artifact file should exist");
        assert_eq!(
            artifact_path.extension().and_then(|ext| ext.to_str()),
            Some("bca"),
            "artifact should have the final .bca extension"
        );

        // The temp file must have been renamed away, leaving only the artifact.
        assert_eq!(
            dir_entries(temp_dir.path()),
            vec![artifact_path.clone()],
            "only the final artifact should remain"
        );

        // Verify content
        let downloaded_content = tokio::fs::read(&artifact_path).await.unwrap();
        assert_eq!(downloaded_content, content);

        // Cleanup
        server_handle.abort();
    }

    #[tokio::test]
    async fn test_download_artifact_checksum_mismatch() {
        // Create test content
        let content = b"test artifact content";
        let wrong_sha256 = "0000000000000000000000000000000000000000000000000000000000000000";

        // Start test server
        let (addr, server_handle) = start_test_server(content.to_vec(), 200).await;
        let url = format!("http://{}/artifact.bca", addr);

        // Create temp directory
        let temp_dir = tempfile::tempdir().unwrap();

        // Download artifact
        let client = reqwest::Client::new();
        let result = download_artifact(&client, &url, wrong_sha256, temp_dir.path()).await;

        // Verify failure
        assert!(
            result.is_err(),
            "download should fail with checksum mismatch"
        );
        let error = result.unwrap_err();
        assert!(
            error.contains("checksum mismatch"),
            "error should mention checksum: {}",
            error
        );

        // Verify temp file was cleaned up
        assert!(
            dir_entries(temp_dir.path()).is_empty(),
            "temp file should be cleaned up"
        );

        // Cleanup
        server_handle.abort();
    }

    #[tokio::test]
    async fn test_download_artifact_http_error() {
        // Start test server that returns 404
        let (addr, server_handle) = start_test_server(vec![], 404).await;
        let url = format!("http://{}/artifact.bca", addr);

        // Create temp directory
        let temp_dir = tempfile::tempdir().unwrap();

        // Download artifact
        let client = reqwest::Client::new();
        let result = download_artifact(&client, &url, "dummy", temp_dir.path()).await;

        // Verify failure
        assert!(result.is_err(), "download should fail with HTTP error");
        let error = result.unwrap_err();
        assert!(
            error.contains("404") || error.contains("status"),
            "error should mention status: {}",
            error
        );

        // Nothing should be written when the server rejects the request.
        assert!(
            dir_entries(temp_dir.path()).is_empty(),
            "no file should be left behind after an HTTP error"
        );

        // Cleanup
        server_handle.abort();
    }

    #[tokio::test]
    async fn test_download_artifact_connection_refused() {
        // Use a URL that will fail to connect
        let url = "http://127.0.0.1:1/artifact.bca";

        // Create temp directory
        let temp_dir = tempfile::tempdir().unwrap();

        // Download artifact
        let client = reqwest::Client::new();
        let result = download_artifact(&client, url, "dummy", temp_dir.path()).await;

        // Verify failure
        assert!(
            result.is_err(),
            "download should fail with connection error"
        );
        let error = result.unwrap_err();
        assert!(
            error.contains("download request failed"),
            "error should mention request failed: {}",
            error
        );

        // Nothing should be written when the connection cannot be made.
        assert!(
            dir_entries(temp_dir.path()).is_empty(),
            "no file should be left behind after a connection error"
        );
    }

    #[tokio::test]
    async fn test_hot_reload_lock_prevents_concurrent_reloads() {
        // Acquire the lock
        let guard = HOT_RELOAD_LOCK.try_lock();
        assert!(guard.is_ok(), "first lock should succeed");

        // Try to acquire again
        let second_guard = HOT_RELOAD_LOCK.try_lock();
        assert!(second_guard.is_err(), "second lock should fail");

        // Drop first lock
        drop(guard);

        // Now should succeed
        let third_guard = HOT_RELOAD_LOCK.try_lock();
        assert!(third_guard.is_ok(), "third lock should succeed after drop");
    }

    #[test]
    fn test_compute_sha256() {
        let data = b"hello world";
        let hash = compute_sha256(data);
        // Known SHA256 of "hello world"
        assert_eq!(
            hash,
            "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"
        );
    }

    #[test]
    fn test_compute_sha256_empty() {
        let data = b"";
        let hash = compute_sha256(data);
        // Known SHA256 of empty string
        assert_eq!(
            hash,
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[test]
    fn test_hot_reload_result_equality() {
        let id = Uuid::new_v4();

        let success1 = HotReloadResult::Success { artifact_id: id };
        let success2 = HotReloadResult::Success { artifact_id: id };
        assert_eq!(success1, success2);

        let failed1 = HotReloadResult::Failed {
            artifact_id: id,
            error: "test error".to_string(),
        };
        let failed2 = HotReloadResult::Failed {
            artifact_id: id,
            error: "test error".to_string(),
        };
        assert_eq!(failed1, failed2);

        assert_ne!(success1, failed1);
    }
}