brainwires_network/remote/
attachments.rs1use std::collections::HashMap;
10use std::io::{Read, Write};
11use std::path::PathBuf;
12use std::sync::Arc;
13
14use anyhow::{Context, Result, bail};
15use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
16use sha2::{Digest, Sha256};
17use tokio::sync::RwLock;
18
19use super::protocol::CompressionAlgorithm;
20
/// Attachment chunk size in bytes: 64 KiB.
///
/// NOTE(review): not referenced by the receiver code visible in this file;
/// presumably the sending side splits payloads into chunks of this size —
/// confirm against the sender.
pub const ATTACHMENT_CHUNK_SIZE: usize = 64 * 1024;

/// Compression threshold in bytes: 10 KiB.
///
/// NOTE(review): not referenced in the code visible in this file; presumably
/// payloads at or above this size are compressed before sending — confirm
/// against the sender.
pub const COMPRESSION_THRESHOLD: usize = 10 * 1024;

/// Largest declared attachment size, in bytes, that
/// `AttachmentReceiver::start_upload` accepts: 100 MiB.
pub const MAX_ATTACHMENT_SIZE: u64 = 100 * 1024 * 1024;
29
/// Bookkeeping for one attachment upload that has been started but not yet
/// completed or cancelled.
#[derive(Debug)]
// Several fields (`id`, `expected_size`, `agent_id`, `command_id`) are stored
// when the upload starts but are not read anywhere in this file.
#[allow(dead_code)]
struct PendingAttachment {
    /// Attachment ID; the same value is used as the key in the receiver's
    /// pending map.
    id: String,
    /// Filename as supplied by the caller; it is sanitized before being used
    /// in the output path.
    filename: String,
    /// MIME type as supplied by the caller; only used in the completion log
    /// line.
    mime_type: String,
    /// Size in bytes declared when the upload was started.
    expected_size: u64,
    /// Whether the assembled chunks are run through `decompress` on
    /// completion.
    compressed: bool,
    /// Algorithm handed to `decompress` when `compressed` is set.
    compression_algorithm: Option<CompressionAlgorithm>,
    /// Number of chunks that make up the upload; valid chunk indices are
    /// `0..chunks_total`.
    chunks_total: u32,
    /// Base64-decoded chunk payloads received so far, keyed by chunk index.
    chunks: HashMap<u32, Vec<u8>>,
    /// Running count of decoded bytes received; used as the capacity hint when
    /// the chunks are assembled.
    bytes_received: usize,
    /// Agent ID supplied when the upload was started.
    agent_id: String,
    /// Command ID supplied when the upload was started.
    command_id: String,
}
57
/// Receives chunked attachment uploads, reassembles them and writes the
/// result into an output directory.
///
/// Cloning is cheap: the pending-upload map sits behind an `Arc`, so every
/// clone observes and mutates the same set of in-flight uploads.
#[derive(Clone)]
pub struct AttachmentReceiver {
    /// In-flight uploads keyed by attachment ID, shared between clones.
    pending: Arc<RwLock<HashMap<String, PendingAttachment>>>,
    /// Directory that completed attachments are written into.
    output_dir: PathBuf,
}
66
67impl AttachmentReceiver {
68 pub fn new(output_dir: PathBuf) -> Self {
70 Self {
71 pending: Arc::new(RwLock::new(HashMap::new())),
72 output_dir,
73 }
74 }
75
76 #[allow(clippy::too_many_arguments)]
78 pub async fn start_upload(
79 &self,
80 command_id: String,
81 agent_id: String,
82 attachment_id: String,
83 filename: String,
84 mime_type: String,
85 size: u64,
86 compressed: bool,
87 compression_algorithm: Option<CompressionAlgorithm>,
88 chunks_total: u32,
89 ) -> Result<()> {
90 if size > MAX_ATTACHMENT_SIZE {
92 bail!(
93 "Attachment too large: {} bytes (max: {} bytes)",
94 size,
95 MAX_ATTACHMENT_SIZE
96 );
97 }
98
99 let pending = PendingAttachment {
100 id: attachment_id.clone(),
101 filename,
102 mime_type,
103 expected_size: size,
104 compressed,
105 compression_algorithm,
106 chunks_total,
107 chunks: HashMap::new(),
108 bytes_received: 0,
109 agent_id,
110 command_id,
111 };
112
113 let mut pending_map = self.pending.write().await;
114 pending_map.insert(attachment_id, pending);
115
116 Ok(())
117 }
118
119 pub async fn receive_chunk(
121 &self,
122 attachment_id: &str,
123 chunk_index: u32,
124 data: &str,
125 is_final: bool,
126 ) -> Result<bool> {
127 let decoded = BASE64
129 .decode(data)
130 .context("Failed to decode base64 chunk data")?;
131
132 let mut pending_map = self.pending.write().await;
133 let pending = pending_map
134 .get_mut(attachment_id)
135 .context("Unknown attachment ID")?;
136
137 if chunk_index >= pending.chunks_total {
139 bail!(
140 "Invalid chunk index: {} (expected 0-{})",
141 chunk_index,
142 pending.chunks_total - 1
143 );
144 }
145
146 pending.bytes_received += decoded.len();
148 pending.chunks.insert(chunk_index, decoded);
149
150 let all_received = pending.chunks.len() == pending.chunks_total as usize;
152
153 if is_final && !all_received {
154 tracing::warn!(
155 "Final chunk received but only have {}/{} chunks",
156 pending.chunks.len(),
157 pending.chunks_total
158 );
159 }
160
161 Ok(all_received)
162 }
163
164 pub async fn complete_upload(
166 &self,
167 attachment_id: &str,
168 expected_checksum: &str,
169 ) -> Result<PathBuf> {
170 let pending = {
171 let mut pending_map = self.pending.write().await;
172 pending_map
173 .remove(attachment_id)
174 .context("Unknown attachment ID")?
175 };
176
177 let mut assembled = Vec::with_capacity(pending.bytes_received);
179 for i in 0..pending.chunks_total {
180 let chunk = pending
181 .chunks
182 .get(&i)
183 .context(format!("Missing chunk {}", i))?;
184 assembled.extend_from_slice(chunk);
185 }
186
187 let data = if pending.compressed {
189 decompress(&assembled, pending.compression_algorithm)?
190 } else {
191 assembled
192 };
193
194 let mut hasher = Sha256::new();
196 hasher.update(&data);
197 let actual_checksum = format!("{:x}", hasher.finalize());
198
199 if actual_checksum != expected_checksum {
200 bail!(
201 "Checksum mismatch: expected {}, got {}",
202 expected_checksum,
203 actual_checksum
204 );
205 }
206
207 std::fs::create_dir_all(&self.output_dir)
209 .context("Failed to create attachment output directory")?;
210
211 let safe_filename = sanitize_filename(&pending.filename);
213 let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
214 let output_path = self
215 .output_dir
216 .join(format!("{}_{}", timestamp, safe_filename));
217
218 let mut file =
220 std::fs::File::create(&output_path).context("Failed to create attachment file")?;
221 file.write_all(&data)
222 .context("Failed to write attachment data")?;
223
224 tracing::info!(
225 "Attachment saved: {} ({} bytes, {})",
226 output_path.display(),
227 data.len(),
228 pending.mime_type
229 );
230
231 Ok(output_path)
232 }
233
234 pub async fn cancel_upload(&self, attachment_id: &str) {
236 let mut pending_map = self.pending.write().await;
237 if pending_map.remove(attachment_id).is_some() {
238 tracing::info!("Cancelled attachment upload: {}", attachment_id);
239 }
240 }
241
242 pub async fn get_status(&self, attachment_id: &str) -> Option<(u32, u32, usize)> {
244 let pending_map = self.pending.read().await;
245 pending_map
246 .get(attachment_id)
247 .map(|p| (p.chunks.len() as u32, p.chunks_total, p.bytes_received))
248 }
249}
250
251fn decompress(data: &[u8], algorithm: Option<CompressionAlgorithm>) -> Result<Vec<u8>> {
253 match algorithm {
254 Some(CompressionAlgorithm::Zstd) => {
255 zstd::decode_all(data).context("Failed to decompress zstd data")
256 }
257 Some(CompressionAlgorithm::Gzip) => {
258 let mut decoder = flate2::read::GzDecoder::new(data);
259 let mut decompressed = Vec::new();
260 decoder
261 .read_to_end(&mut decompressed)
262 .context("Failed to decompress gzip data")?;
263 Ok(decompressed)
264 }
265 None => Ok(data.to_vec()),
266 }
267}
268
/// Reduces a caller-supplied filename to a single path component that is safe
/// to join onto the output directory.
///
/// Only the final path component is kept (so directory parts are dropped),
/// and `"attachment"` is used when there is none, e.g. for `""` or `".."`.
/// Within that component, path separators, characters that are reserved on
/// common filesystems, and all Unicode control characters are replaced with
/// `_`.
fn sanitize_filename(filename: &str) -> String {
    let name = std::path::Path::new(filename)
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or("attachment");

    name.chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            // `is_control` also covers the C1 range (U+0080..=U+009F, which
            // includes the CSI escape U+009B) that `is_ascii_control` misses.
            c if c.is_control() => '_',
            c => c,
        })
        .collect()
}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain names pass through unchanged; directory parts are dropped and
    /// reserved characters become underscores.
    #[test]
    fn test_sanitize_filename() {
        let cases = [
            ("file.txt", "file.txt"),
            ("path/to/file.txt", "file.txt"),
            ("file:name.txt", "file_name.txt"),
            ("file<>|name.txt", "file___name.txt"),
        ];

        for (input, expected) in cases {
            assert_eq!(sanitize_filename(input), expected);
        }
    }

    /// Round trip of a single-chunk, uncompressed upload: start, send the
    /// only chunk, complete with the matching checksum, read the file back.
    #[tokio::test]
    async fn test_attachment_receiver() {
        const PAYLOAD: &[u8] = b"Hello, World!";

        let scratch = tempfile::tempdir().unwrap();
        let receiver = AttachmentReceiver::new(scratch.path().to_path_buf());

        let started = receiver
            .start_upload(
                "cmd-1".to_string(),
                "agent-1".to_string(),
                "attach-1".to_string(),
                "test.txt".to_string(),
                "text/plain".to_string(),
                13,
                false,
                None,
                1,
            )
            .await;
        started.unwrap();

        // The one and only chunk, flagged as final.
        let encoded = BASE64.encode(PAYLOAD);
        let complete = receiver
            .receive_chunk("attach-1", 0, &encoded, true)
            .await
            .unwrap();
        assert!(complete);

        // One-shot digest of the same bytes that were uploaded.
        let checksum = format!("{:x}", Sha256::digest(PAYLOAD));

        let saved = receiver
            .complete_upload("attach-1", &checksum)
            .await
            .unwrap();
        assert!(saved.exists());

        let written = std::fs::read_to_string(&saved).unwrap();
        assert_eq!(written, "Hello, World!");
    }
}