barbacane_lib/
hot_reload.rs1use futures_util::StreamExt;
7use sha2::{Digest, Sha256};
8use std::path::{Path, PathBuf};
9use tokio::io::AsyncWriteExt;
10use uuid::Uuid;
11
12#[derive(Debug, Clone, PartialEq, Eq)]
14pub enum HotReloadResult {
15 Success { artifact_id: Uuid },
17 Failed { artifact_id: Uuid, error: String },
19}
20
21pub static HOT_RELOAD_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());
23
24pub async fn download_artifact(
35 http_client: &reqwest::Client,
36 download_url: &str,
37 expected_sha256: &str,
38 artifact_dir: &Path,
39) -> Result<PathBuf, String> {
40 let temp_filename = format!("artifact-{}.bca.tmp", Uuid::new_v4());
42 let temp_path = artifact_dir.join(&temp_filename);
43
44 tracing::info!(
45 download_url = %download_url,
46 temp_path = %temp_path.display(),
47 "Downloading artifact"
48 );
49
50 let response = http_client
52 .get(download_url)
53 .send()
54 .await
55 .map_err(|e| format!("download request failed: {}", e))?;
56
57 if !response.status().is_success() {
58 return Err(format!(
59 "download failed with status: {}",
60 response.status()
61 ));
62 }
63
64 let mut file = tokio::fs::File::create(&temp_path)
66 .await
67 .map_err(|e| format!("failed to create temp file: {}", e))?;
68
69 let mut hasher = Sha256::new();
71 let mut stream = response.bytes_stream();
72 let mut total_bytes = 0u64;
73
74 while let Some(chunk_result) = stream.next().await {
75 let chunk = chunk_result.map_err(|e| format!("download stream error: {}", e))?;
76 hasher.update(&chunk);
77 file.write_all(&chunk)
78 .await
79 .map_err(|e| format!("write error: {}", e))?;
80 total_bytes += chunk.len() as u64;
81 }
82
83 file.flush()
84 .await
85 .map_err(|e| format!("flush error: {}", e))?;
86 drop(file);
87
88 let computed_sha256 = hex::encode(hasher.finalize());
90 if computed_sha256 != expected_sha256 {
91 let _ = tokio::fs::remove_file(&temp_path).await;
92 return Err(format!(
93 "checksum mismatch: expected {}, got {}",
94 expected_sha256, computed_sha256
95 ));
96 }
97
98 let final_filename = format!("artifact-{}.bca", Uuid::new_v4());
100 let final_path = artifact_dir.join(&final_filename);
101 tokio::fs::rename(&temp_path, &final_path)
102 .await
103 .map_err(|e| format!("rename failed: {}", e))?;
104
105 tracing::info!(
106 final_path = %final_path.display(),
107 size_bytes = total_bytes,
108 sha256 = %computed_sha256,
109 "Artifact downloaded and verified"
110 );
111
112 Ok(final_path)
113}
114
115pub fn compute_sha256(data: &[u8]) -> String {
117 let mut hasher = Sha256::new();
118 hasher.update(data);
119 hex::encode(hasher.finalize())
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;
    use tokio::io::AsyncReadExt;
    use tokio::net::TcpListener;

    /// Returns the standard reason phrase for the status codes used in these
    /// tests, so the fake server emits a well-formed status line.
    fn reason_phrase(status: u16) -> &'static str {
        match status {
            200 => "OK",
            404 => "Not Found",
            500 => "Internal Server Error",
            _ => "Unknown",
        }
    }

    /// Spawns a one-shot HTTP/1.1 server on an ephemeral `127.0.0.1` port.
    ///
    /// The server accepts a single connection and reads up to 1 KiB of the
    /// request. It then replies with `status`, using `content` as the body, and
    /// closes the connection (`Connection: close`).
    async fn start_test_server(
        content: Vec<u8>,
        status: u16,
    ) -> (SocketAddr, tokio::task::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let handle = tokio::spawn(async move {
            if let Ok((mut socket, _)) = listener.accept().await {
                let mut buf = [0u8; 1024];
                // Consume the request before replying; its contents are not inspected.
                let _ = socket.read(&mut buf).await;

                let response = format!(
                    "HTTP/1.1 {} {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
                    status,
                    reason_phrase(status),
                    content.len()
                );
                let _ = socket.write_all(response.as_bytes()).await;
                let _ = socket.write_all(&content).await;
            }
        });

        (addr, handle)
    }

    #[tokio::test]
    async fn test_download_artifact_success() {
        let content = b"test artifact content for download verification";
        let expected_sha256 = compute_sha256(content);

        let (addr, server_handle) = start_test_server(content.to_vec(), 200).await;
        let url = format!("http://{}/artifact.bca", addr);

        let temp_dir = tempfile::tempdir().unwrap();

        let client = reqwest::Client::new();
        let result = download_artifact(&client, &url, &expected_sha256, temp_dir.path()).await;

        assert!(result.is_ok(), "download should succeed: {:?}", result);
        let artifact_path = result.unwrap();
        assert!(artifact_path.exists(), "artifact file should exist");
        assert_eq!(
            artifact_path.extension().and_then(|ext| ext.to_str()),
            Some("bca"),
            "artifact should carry the final .bca extension"
        );

        let downloaded_content = tokio::fs::read(&artifact_path).await.unwrap();
        assert_eq!(downloaded_content, content);

        // The .tmp file must have been renamed away, leaving only the artifact.
        let entries: Vec<PathBuf> = std::fs::read_dir(temp_dir.path())
            .unwrap()
            .map(|entry| entry.unwrap().path())
            .collect();
        assert_eq!(
            entries,
            vec![artifact_path],
            "only the final artifact should remain in the directory"
        );

        server_handle.abort();
    }

    #[tokio::test]
    async fn test_download_artifact_checksum_mismatch() {
        let content = b"test artifact content";
        let wrong_sha256 = "0000000000000000000000000000000000000000000000000000000000000000";

        let (addr, server_handle) = start_test_server(content.to_vec(), 200).await;
        let url = format!("http://{}/artifact.bca", addr);

        let temp_dir = tempfile::tempdir().unwrap();

        let client = reqwest::Client::new();
        let result = download_artifact(&client, &url, wrong_sha256, temp_dir.path()).await;

        assert!(
            result.is_err(),
            "download should fail with checksum mismatch"
        );
        let error = result.unwrap_err();
        assert!(
            error.contains("checksum mismatch"),
            "error should mention checksum: {}",
            error
        );

        let files: Vec<_> = std::fs::read_dir(temp_dir.path()).unwrap().collect();
        assert!(files.is_empty(), "temp file should be cleaned up");

        server_handle.abort();
    }

    #[tokio::test]
    async fn test_download_artifact_http_error() {
        let (addr, server_handle) = start_test_server(vec![], 404).await;
        let url = format!("http://{}/artifact.bca", addr);

        let temp_dir = tempfile::tempdir().unwrap();

        let client = reqwest::Client::new();
        let result = download_artifact(&client, &url, "dummy", temp_dir.path()).await;

        assert!(result.is_err(), "download should fail with HTTP error");
        let error = result.unwrap_err();
        assert!(
            error.contains("404") || error.contains("status"),
            "error should mention status: {}",
            error
        );

        server_handle.abort();
    }

    #[tokio::test]
    async fn test_download_artifact_connection_refused() {
        // Port 1 is assumed to have no listener on the test host.
        let url = "http://127.0.0.1:1/artifact.bca";

        let temp_dir = tempfile::tempdir().unwrap();

        let client = reqwest::Client::new();
        let result = download_artifact(&client, url, "dummy", temp_dir.path()).await;

        assert!(
            result.is_err(),
            "download should fail with connection error"
        );
        let error = result.unwrap_err();
        assert!(
            error.contains("download request failed"),
            "error should mention request failed: {}",
            error
        );
    }

    #[tokio::test]
    async fn test_hot_reload_lock_prevents_concurrent_reloads() {
        // NOTE(review): HOT_RELOAD_LOCK is a process-wide static shared by every
        // test in the crate. This test assumes nothing else holds it concurrently.
        let guard = HOT_RELOAD_LOCK.try_lock();
        assert!(guard.is_ok(), "first lock should succeed");

        let second_guard = HOT_RELOAD_LOCK.try_lock();
        assert!(second_guard.is_err(), "second lock should fail");

        drop(guard);

        let third_guard = HOT_RELOAD_LOCK.try_lock();
        assert!(third_guard.is_ok(), "third lock should succeed after drop");
    }

    #[test]
    fn test_compute_sha256() {
        let data = b"hello world";
        let hash = compute_sha256(data);
        assert_eq!(
            hash,
            "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"
        );
    }

    #[test]
    fn test_compute_sha256_empty() {
        let data = b"";
        let hash = compute_sha256(data);
        assert_eq!(
            hash,
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[test]
    fn test_hot_reload_result_equality() {
        let id = Uuid::new_v4();

        let success1 = HotReloadResult::Success { artifact_id: id };
        let success2 = HotReloadResult::Success { artifact_id: id };
        assert_eq!(success1, success2);

        let failed1 = HotReloadResult::Failed {
            artifact_id: id,
            error: "test error".to_string(),
        };
        let failed2 = HotReloadResult::Failed {
            artifact_id: id,
            error: "test error".to_string(),
        };
        assert_eq!(failed1, failed2);

        assert_ne!(success1, failed1);
    }
}