Skip to main content

brainwires_network/remote/
attachments.rs

1//! Attachment handling for remote file uploads
2//!
3//! Manages chunked file uploads from the web UI, including:
4//! - Reassembly of chunks
5//! - Decompression (zstd/gzip)
6//! - Checksum verification
7//! - Temporary file storage
8
9use std::collections::HashMap;
10use std::io::{Read, Write};
11use std::path::PathBuf;
12use std::sync::Arc;
13
14use anyhow::{Context, Result, bail};
15use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
16use sha2::{Digest, Sha256};
17use tokio::sync::RwLock;
18
19use super::protocol::CompressionAlgorithm;
20
/// Default size of a single attachment chunk, in bytes (64 KiB).
pub const ATTACHMENT_CHUNK_SIZE: usize = 64 * 1024;

/// Size threshold, in bytes, above which a file is compressed (10 KiB).
pub const COMPRESSION_THRESHOLD: usize = 10 * 1024;

/// Largest attachment size accepted, in bytes (100 MiB).
pub const MAX_ATTACHMENT_SIZE: u64 = 100 * 1024 * 1024;
29
30/// State of an in-progress attachment upload
31#[derive(Debug)]
32#[allow(dead_code)]
33struct PendingAttachment {
34    /// Unique attachment ID
35    id: String,
36    /// Original filename
37    filename: String,
38    /// MIME type
39    mime_type: String,
40    /// Expected total size (uncompressed)
41    expected_size: u64,
42    /// Whether data is compressed
43    compressed: bool,
44    /// Compression algorithm (if compressed)
45    compression_algorithm: Option<CompressionAlgorithm>,
46    /// Expected number of chunks
47    chunks_total: u32,
48    /// Received chunks (index -> data)
49    chunks: HashMap<u32, Vec<u8>>,
50    /// Total bytes received so far
51    bytes_received: usize,
52    /// Associated agent ID
53    agent_id: String,
54    /// Command ID for response
55    command_id: String,
56}
57
58/// Manages attachment uploads
59#[derive(Clone)]
60pub struct AttachmentReceiver {
61    /// Pending attachments by ID
62    pending: Arc<RwLock<HashMap<String, PendingAttachment>>>,
63    /// Directory to store received attachments
64    output_dir: PathBuf,
65}
66
67impl AttachmentReceiver {
68    /// Create a new attachment receiver
69    pub fn new(output_dir: PathBuf) -> Self {
70        Self {
71            pending: Arc::new(RwLock::new(HashMap::new())),
72            output_dir,
73        }
74    }
75
76    /// Start receiving a new attachment
77    #[allow(clippy::too_many_arguments)]
78    pub async fn start_upload(
79        &self,
80        command_id: String,
81        agent_id: String,
82        attachment_id: String,
83        filename: String,
84        mime_type: String,
85        size: u64,
86        compressed: bool,
87        compression_algorithm: Option<CompressionAlgorithm>,
88        chunks_total: u32,
89    ) -> Result<()> {
90        // Validate size
91        if size > MAX_ATTACHMENT_SIZE {
92            bail!(
93                "Attachment too large: {} bytes (max: {} bytes)",
94                size,
95                MAX_ATTACHMENT_SIZE
96            );
97        }
98
99        let pending = PendingAttachment {
100            id: attachment_id.clone(),
101            filename,
102            mime_type,
103            expected_size: size,
104            compressed,
105            compression_algorithm,
106            chunks_total,
107            chunks: HashMap::new(),
108            bytes_received: 0,
109            agent_id,
110            command_id,
111        };
112
113        let mut pending_map = self.pending.write().await;
114        pending_map.insert(attachment_id, pending);
115
116        Ok(())
117    }
118
119    /// Receive a chunk of attachment data
120    pub async fn receive_chunk(
121        &self,
122        attachment_id: &str,
123        chunk_index: u32,
124        data: &str,
125        is_final: bool,
126    ) -> Result<bool> {
127        // Decode base64 data
128        let decoded = BASE64
129            .decode(data)
130            .context("Failed to decode base64 chunk data")?;
131
132        let mut pending_map = self.pending.write().await;
133        let pending = pending_map
134            .get_mut(attachment_id)
135            .context("Unknown attachment ID")?;
136
137        // Validate chunk index
138        if chunk_index >= pending.chunks_total {
139            bail!(
140                "Invalid chunk index: {} (expected 0-{})",
141                chunk_index,
142                pending.chunks_total - 1
143            );
144        }
145
146        // Store chunk
147        pending.bytes_received += decoded.len();
148        pending.chunks.insert(chunk_index, decoded);
149
150        // Check if we have all chunks
151        let all_received = pending.chunks.len() == pending.chunks_total as usize;
152
153        if is_final && !all_received {
154            tracing::warn!(
155                "Final chunk received but only have {}/{} chunks",
156                pending.chunks.len(),
157                pending.chunks_total
158            );
159        }
160
161        Ok(all_received)
162    }
163
164    /// Complete the attachment upload, verify checksum, and save to disk
165    pub async fn complete_upload(
166        &self,
167        attachment_id: &str,
168        expected_checksum: &str,
169    ) -> Result<PathBuf> {
170        let pending = {
171            let mut pending_map = self.pending.write().await;
172            pending_map
173                .remove(attachment_id)
174                .context("Unknown attachment ID")?
175        };
176
177        // Reassemble chunks in order
178        let mut assembled = Vec::with_capacity(pending.bytes_received);
179        for i in 0..pending.chunks_total {
180            let chunk = pending
181                .chunks
182                .get(&i)
183                .context(format!("Missing chunk {}", i))?;
184            assembled.extend_from_slice(chunk);
185        }
186
187        // Decompress if needed
188        let data = if pending.compressed {
189            decompress(&assembled, pending.compression_algorithm)?
190        } else {
191            assembled
192        };
193
194        // Verify checksum
195        let mut hasher = Sha256::new();
196        hasher.update(&data);
197        let actual_checksum = format!("{:x}", hasher.finalize());
198
199        if actual_checksum != expected_checksum {
200            bail!(
201                "Checksum mismatch: expected {}, got {}",
202                expected_checksum,
203                actual_checksum
204            );
205        }
206
207        // Ensure output directory exists
208        std::fs::create_dir_all(&self.output_dir)
209            .context("Failed to create attachment output directory")?;
210
211        // Generate unique filename
212        let safe_filename = sanitize_filename(&pending.filename);
213        let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
214        let output_path = self
215            .output_dir
216            .join(format!("{}_{}", timestamp, safe_filename));
217
218        // Write to file
219        let mut file =
220            std::fs::File::create(&output_path).context("Failed to create attachment file")?;
221        file.write_all(&data)
222            .context("Failed to write attachment data")?;
223
224        tracing::info!(
225            "Attachment saved: {} ({} bytes, {})",
226            output_path.display(),
227            data.len(),
228            pending.mime_type
229        );
230
231        Ok(output_path)
232    }
233
234    /// Cancel a pending upload
235    pub async fn cancel_upload(&self, attachment_id: &str) {
236        let mut pending_map = self.pending.write().await;
237        if pending_map.remove(attachment_id).is_some() {
238            tracing::info!("Cancelled attachment upload: {}", attachment_id);
239        }
240    }
241
242    /// Get status of a pending upload
243    pub async fn get_status(&self, attachment_id: &str) -> Option<(u32, u32, usize)> {
244        let pending_map = self.pending.read().await;
245        pending_map
246            .get(attachment_id)
247            .map(|p| (p.chunks.len() as u32, p.chunks_total, p.bytes_received))
248    }
249}
250
251/// Decompress data using the specified algorithm
252fn decompress(data: &[u8], algorithm: Option<CompressionAlgorithm>) -> Result<Vec<u8>> {
253    match algorithm {
254        Some(CompressionAlgorithm::Zstd) => {
255            zstd::decode_all(data).context("Failed to decompress zstd data")
256        }
257        Some(CompressionAlgorithm::Gzip) => {
258            let mut decoder = flate2::read::GzDecoder::new(data);
259            let mut decompressed = Vec::new();
260            decoder
261                .read_to_end(&mut decompressed)
262                .context("Failed to decompress gzip data")?;
263            Ok(decompressed)
264        }
265        None => Ok(data.to_vec()),
266    }
267}
268
/// Reduce an uploader-supplied file name to a single safe path component.
///
/// Only the final path component is kept; if there is none (for example an
/// empty string or `..`), `"attachment"` is used. The characters
/// `/ \ : * ? " < > |` and every ASCII control character are then replaced
/// with `_`.
fn sanitize_filename(filename: &str) -> String {
    const RESERVED: [char; 9] = ['/', '\\', ':', '*', '?', '"', '<', '>', '|'];

    // Drop any directory part so the result can never escape the output dir.
    let base = match std::path::Path::new(filename).file_name() {
        Some(component) => component.to_str().unwrap_or("attachment"),
        None => "attachment",
    };

    let mut cleaned = String::with_capacity(base.len());
    for ch in base.chars() {
        if RESERVED.contains(&ch) || ch.is_ascii_control() {
            cleaned.push('_');
        } else {
            cleaned.push(ch);
        }
    }
    cleaned
}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain names pass through, directories are dropped, and reserved
    /// characters become underscores.
    #[test]
    fn test_sanitize_filename() {
        let cases = [
            ("file.txt", "file.txt"),
            ("path/to/file.txt", "file.txt"),
            ("file:name.txt", "file_name.txt"),
            ("file<>|name.txt", "file___name.txt"),
        ];

        for &(input, expected) in &cases {
            assert_eq!(sanitize_filename(input), expected);
        }
    }

    /// Round trip of a single-chunk, uncompressed upload to a temp directory.
    #[tokio::test]
    async fn test_attachment_receiver() {
        const PAYLOAD: &[u8] = b"Hello, World!";

        let dir = tempfile::tempdir().unwrap();
        let rx = AttachmentReceiver::new(dir.path().to_path_buf());

        // Announce a 13-byte, one-chunk upload.
        rx.start_upload(
            String::from("cmd-1"),
            String::from("agent-1"),
            String::from("attach-1"),
            String::from("test.txt"),
            String::from("text/plain"),
            13,
            false,
            None,
            1,
        )
        .await
        .unwrap();

        // Deliver the only chunk; the upload should then be complete.
        let encoded = BASE64.encode(PAYLOAD);
        let complete = rx
            .receive_chunk("attach-1", 0, &encoded, true)
            .await
            .unwrap();
        assert!(complete);

        // SHA-256 of the payload, as lowercase hex.
        let checksum = format!("{:x}", Sha256::digest(PAYLOAD));

        let saved = rx.complete_upload("attach-1", &checksum).await.unwrap();
        assert!(saved.exists());

        let written = std::fs::read_to_string(&saved).unwrap();
        assert_eq!(written, "Hello, World!");
    }
}