Skip to main content

ref_solver/utils/
validation.rs

//! Centralized validation and helper functions.
2
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use base64::Engine;
5use sha2::{Digest, Sha512};
6
7use crate::web::format_detection::FileFormat;
8use std::collections::HashSet;
9
/// Maximum number of contigs allowed in a single file (DoS protection).
///
/// Enforced by [`check_contig_limit`], which reports an error as soon as the
/// current contig count has reached this value.
pub const MAX_CONTIGS: usize = 100_000;

/// Security-related constants for input validation.
///
/// Upper bound on an uploaded filename; [`validate_filename`] compares it
/// against the byte length (`str::len`) of the raw, unsanitized name.
pub const MAX_FILENAME_LENGTH: usize = 255;
/// Minimum number of bytes an upload must contain; [`validate_file_content`]
/// rejects anything shorter (i.e. empty content).
pub const MIN_FILE_CONTENT_SIZE: usize = 1;
16
/// Report whether `s` has the shape of an MD5 checksum.
///
/// A valid checksum is exactly 32 ASCII hexadecimal digits; both upper- and
/// lowercase digits are accepted.
///
/// # Examples
///
/// ```
/// use ref_solver::utils::validation::is_valid_md5;
///
/// assert!(is_valid_md5("6aef897c3d6ff0c78aff06ac189178dd"));
/// assert!(!is_valid_md5("not-an-md5"));
/// assert!(!is_valid_md5("6aef897c3d6ff0c78aff06ac189178d")); // 31 chars
/// ```
#[must_use]
pub fn is_valid_md5(s: &str) -> bool {
    const MD5_HEX_LEN: usize = 32;

    if s.len() != MD5_HEX_LEN {
        return false;
    }

    // Any byte of a multi-byte UTF-8 character is >= 0x80 and therefore never
    // a hex digit, so checking bytes is equivalent to checking chars.
    s.bytes().all(|byte| byte.is_ascii_hexdigit())
}
32
33/// Normalize an MD5 string to lowercase.
34/// Returns None if the input is not a valid MD5.
35#[must_use]
36pub fn normalize_md5(s: &str) -> Option<String> {
37    if is_valid_md5(s) {
38        Some(s.to_lowercase())
39    } else {
40        None
41    }
42}
43
44/// Compute the GA4GH sha512t24u digest for a sequence.
45///
46/// Algorithm: SHA-512 the sequence bytes, truncate to the first 24 bytes,
47/// then base64url-encode without padding, producing a 32-character string.
48///
49/// The input sequence should already be uppercased.
50///
51/// # Examples
52///
53/// ```
54/// use ref_solver::utils::validation::compute_sha512t24u;
55///
56/// let digest = compute_sha512t24u(b"ACGT");
57/// assert_eq!(digest.len(), 32);
58/// ```
59#[must_use]
60pub fn compute_sha512t24u(sequence: &[u8]) -> String {
61    let hash = Sha512::digest(sequence);
62    URL_SAFE_NO_PAD.encode(&hash[..24])
63}
64
/// Report whether `s` has the shape of a `sha512t24u` digest.
///
/// A valid digest is exactly 32 bytes long and uses only the base64url
/// alphabet: ASCII letters, ASCII digits, `-` and `_`.
///
/// # Examples
///
/// ```
/// use ref_solver::utils::validation::is_valid_sha512t24u;
///
/// assert!(is_valid_sha512t24u("aKF498dAxcJAqme6QYQ7EZ07-fiw8Kw2"));
/// assert!(!is_valid_sha512t24u("too-short"));
/// ```
#[must_use]
pub fn is_valid_sha512t24u(s: &str) -> bool {
    let in_base64url_alphabet =
        |byte: u8| matches!(byte, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_');

    s.len() == 32 && s.bytes().all(in_base64url_alphabet)
}
81
82/// Compute a signature hash from a set of MD5 checksums.
83///
84/// The signature is computed by:
85/// 1. Sorting the MD5s alphabetically
86/// 2. Joining them with commas
87/// 3. Computing MD5 of the concatenated string
88///
89/// This provides a deterministic identifier for a set of contigs.
90#[must_use]
91#[allow(clippy::implicit_hasher)] // Default hasher is fine for this use case
92pub fn compute_signature(md5s: &HashSet<String>) -> String {
93    if md5s.is_empty() {
94        return String::new();
95    }
96
97    let mut sorted: Vec<&str> = md5s.iter().map(std::string::String::as_str).collect();
98    sorted.sort_unstable();
99    let concatenated = sorted.join(",");
100    let digest = md5::compute(concatenated.as_bytes());
101    format!("{digest:x}")
102}
103
104/// Check if adding another contig would exceed the maximum allowed.
105///
106/// Call this with the current count BEFORE adding a new contig.
107/// Returns an error message if adding would exceed the limit, None if safe to add.
108///
109/// # Example
110/// ```ignore
111/// if check_contig_limit(contigs.len()).is_some() {
112///     return Err(...);
113/// }
114/// contigs.push(new_contig); // Safe to add
115/// ```
116#[must_use]
117pub fn check_contig_limit(count: usize) -> Option<String> {
118    if count >= MAX_CONTIGS {
119        Some(format!(
120            "Too many contigs: adding another would exceed maximum of {MAX_CONTIGS}"
121        ))
122    } else {
123        None
124    }
125}
126
/// Security validation error types.
///
/// Returned by [`validate_filename`], [`validate_file_content`] and
/// [`validate_upload`]. The `Display` text of each variant comes from its
/// `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum ValidationError {
    /// The raw filename is longer than `MAX_FILENAME_LENGTH` (the check in
    /// `validate_filename` uses the byte length of the name).
    #[error("Filename too long: exceeds {MAX_FILENAME_LENGTH} characters")]
    FilenameTooLong,
    /// The filename contains `..`, a path separator or a control character,
    /// is empty after sanitization, or is a hidden file without a known
    /// extension.
    #[error("Invalid filename: contains path traversal or invalid characters")]
    InvalidFilename,
    /// The filename is empty or consists only of whitespace.
    #[error("Empty filename provided")]
    EmptyFilename,
    /// The content is too small, looks binary where text was expected, or is
    /// not valid UTF-8 where text was expected.
    #[error("File content appears malformed or invalid")]
    InvalidFileContent,
    /// The content does not match the expected file format.
    #[error("File format validation failed")]
    FormatValidationFailed,
}
141
142/// Secure filename validation to prevent directory traversal and other attacks
143///
144/// Validates and sanitizes filenames by:
145/// - Checking length limits
146/// - Preventing directory traversal (../, ..\\)
147/// - Removing potentially dangerous characters
148/// - Ensuring filename is not empty after sanitization
149///
150/// # Errors
151///
152/// Returns `ValidationError::EmptyFilename` if the filename is empty,
153/// `ValidationError::FilenameTooLong` if it exceeds the limit, or
154/// `ValidationError::InvalidFilename` if it contains invalid characters.
155pub fn validate_filename(filename: &str) -> Result<String, ValidationError> {
156    // Check if filename is empty
157    if filename.trim().is_empty() {
158        return Err(ValidationError::EmptyFilename);
159    }
160
161    // Check length limit
162    if filename.len() > MAX_FILENAME_LENGTH {
163        return Err(ValidationError::FilenameTooLong);
164    }
165
166    // Prevent directory traversal attacks
167    if filename.contains("..") || filename.contains('/') || filename.contains('\\') {
168        return Err(ValidationError::InvalidFilename);
169    }
170
171    // Check for null bytes and other dangerous characters
172    if filename.contains('\0') || filename.chars().any(|c| ('\x01'..='\x1F').contains(&c)) {
173        return Err(ValidationError::InvalidFilename);
174    }
175
176    // Sanitize filename by keeping only safe characters
177    let sanitized = filename
178        .chars()
179        .filter(|c| c.is_ascii_alphanumeric() || *c == '.' || *c == '-' || *c == '_' || *c == ' ')
180        .collect::<String>();
181
182    // Ensure sanitized filename is not empty
183    if sanitized.trim().is_empty() {
184        return Err(ValidationError::InvalidFilename);
185    }
186
187    // Prevent hidden files (starting with .) unless it's a known extension
188    if sanitized.starts_with('.') && !has_known_extension(&sanitized) {
189        return Err(ValidationError::InvalidFilename);
190    }
191
192    Ok(sanitized)
193}
194
/// Check if filename has a known safe extension.
///
/// The comparison is case-insensitive: the filename is lowercased once and
/// then tested against each entry of the allow-list with `ends_with`.
/// Returns `true` as soon as any extension matches.
fn has_known_extension(filename: &str) -> bool {
    const SAFE_EXTENSIONS: &[&str] = &[
        ".sam",
        ".bam",
        ".cram",
        ".dict",
        ".vcf",
        ".txt",
        ".tsv",
        ".csv",
        ".gz",
        ".assembly_report.txt",
    ];

    // Lowercase once up front instead of allocating a fresh lowercase copy
    // for every candidate extension.
    let lowered = filename.to_lowercase();

    SAFE_EXTENSIONS.iter().any(|ext| lowered.ends_with(ext))
}
214
215/// Validate file content using magic numbers for known binary formats
216///
217/// Performs format validation by checking magic numbers (file signatures)
218/// to prevent format confusion attacks and ensure file integrity
219#[must_use]
220pub fn validate_file_format(content: &[u8], expected_format: FileFormat) -> bool {
221    if content.is_empty() {
222        return false;
223    }
224
225    match expected_format {
226        FileFormat::Bam => {
227            // BAM files start with "BAM\x01"
228            content.len() >= 4 && content.starts_with(b"BAM\x01")
229        }
230        FileFormat::Cram => {
231            // CRAM files start with "CRAM"
232            content.len() >= 4 && content.starts_with(b"CRAM")
233        }
234        FileFormat::Vcf => {
235            // VCF files should start with "##fileformat=VCF"
236            let content_str = std::str::from_utf8(content).unwrap_or("");
237            content_str.starts_with("##fileformat=VCF")
238        }
239        FileFormat::Sam => {
240            // SAM files are text-based, check for header indicators
241            let content_str = std::str::from_utf8(content).unwrap_or("");
242            content_str.contains("@SQ")
243                || content_str.contains("@HD")
244                || content_str.contains("SN:")
245                || content_str.contains("LN:")
246        }
247        FileFormat::Dict => {
248            // Picard dictionary files have @HD and @SQ headers
249            let content_str = std::str::from_utf8(content).unwrap_or("");
250            content_str.contains("@HD") && content_str.contains("@SQ")
251        }
252        FileFormat::NcbiReport => {
253            // NCBI assembly reports have specific column headers
254            let content_str = std::str::from_utf8(content).unwrap_or("");
255            content_str.contains("Sequence-Name") || content_str.contains("Sequence-Role")
256        }
257        FileFormat::Tsv => {
258            // TSV files should have tab-separated content
259            let content_str = std::str::from_utf8(content).unwrap_or("");
260            content_str.contains('\t')
261                && (content_str.to_lowercase().contains("length")
262                    || content_str.to_lowercase().contains("sequence")
263                    || content_str.to_lowercase().contains("size"))
264        }
265        FileFormat::Fai => {
266            // FAI files have 5 tab-separated columns per line
267            let content_str = std::str::from_utf8(content).unwrap_or("");
268            content_str.lines().take(5).any(|line| {
269                let fields: Vec<&str> = line.split('\t').collect();
270                fields.len() == 5 && fields[1..].iter().all(|f| f.parse::<u64>().is_ok())
271            })
272        }
273        FileFormat::Fasta => {
274            // FASTA files start with '>' or are gzip compressed (0x1f 0x8b)
275            content.starts_with(b">")
276                || (content.len() >= 2 && content[0] == 0x1f && content[1] == 0x8b)
277        }
278        FileFormat::Auto => {
279            // Auto-detection always passes initial validation
280            true
281        }
282    }
283}
284
285/// Validate that file content is not malicious or malformed
286///
287/// Basic security checks for file content integrity:
288/// - Minimum size requirements
289/// - Binary content detection for text formats
290/// - Basic malformation checks
291///
292/// # Errors
293///
294/// Returns `ValidationError::InvalidFileContent` if the content is too small,
295/// contains unexpected binary data for text formats, or fails UTF-8 validation.
296pub fn validate_file_content(content: &[u8], expected_text: bool) -> Result<(), ValidationError> {
297    // Check minimum content size
298    if content.len() < MIN_FILE_CONTENT_SIZE {
299        return Err(ValidationError::InvalidFileContent);
300    }
301
302    // If we expect text content, validate it's not binary
303    if expected_text {
304        // Check for excessive non-printable characters
305        let non_printable_count = content
306            .iter()
307            .filter(|&&b| b < 9 || (b > 13 && b < 32) || b > 126)
308            .count();
309
310        // Allow up to 5% non-printable characters for text files
311        if content.len() > 100 && non_printable_count > content.len() / 20 {
312            return Err(ValidationError::InvalidFileContent);
313        }
314
315        // Basic UTF-8 validation for text content
316        if std::str::from_utf8(content).is_err() {
317            return Err(ValidationError::InvalidFileContent);
318        }
319    }
320
321    Ok(())
322}
323
324/// Comprehensive input validation combining filename and content checks
325///
326/// Performs complete security validation for file uploads:
327/// - Filename sanitization and security checks
328/// - File format validation via magic numbers
329/// - Content integrity validation
330///
331/// # Errors
332///
333/// Returns a `ValidationError` if filename validation fails, the file format
334/// doesn't match the expected format, or content validation fails.
335pub fn validate_upload(
336    filename: Option<&str>,
337    content: &[u8],
338    expected_format: FileFormat,
339) -> Result<Option<String>, ValidationError> {
340    // Validate filename if provided
341    let validated_filename = if let Some(name) = filename {
342        Some(validate_filename(name)?)
343    } else {
344        None
345    };
346
347    // Validate content integrity
348    let is_text_format = matches!(
349        expected_format,
350        FileFormat::Sam
351            | FileFormat::Dict
352            | FileFormat::Vcf
353            | FileFormat::NcbiReport
354            | FileFormat::Tsv
355            | FileFormat::Auto
356    );
357
358    validate_file_content(content, is_text_format)?;
359
360    // Validate file format - even for auto-detection, check for obvious mismatches
361    if expected_format == FileFormat::Auto {
362        // For auto-detection, at least verify it's not a malformed binary file
363        // Check if it looks like a known binary format but is malformed
364        if content.len() >= 4 {
365            let starts_with_bam = content.starts_with(b"BAM");
366            let starts_with_cram = content.starts_with(b"CRAM");
367
368            // If it looks like it should be BAM/CRAM but isn't valid, reject it
369            if starts_with_bam && !validate_file_format(content, FileFormat::Bam) {
370                return Err(ValidationError::FormatValidationFailed);
371            }
372            if starts_with_cram && !validate_file_format(content, FileFormat::Cram) {
373                return Err(ValidationError::FormatValidationFailed);
374            }
375        }
376    } else if !validate_file_format(content, expected_format) {
377        return Err(ValidationError::FormatValidationFailed);
378    }
379
380    Ok(validated_filename)
381}
382
/// Unit tests for the checksum helpers, the contig limit and the upload
/// security validation in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Accepts 32 hex digits in either case; rejects wrong lengths and
    /// non-hex characters.
    #[test]
    fn test_is_valid_md5() {
        assert!(is_valid_md5("6aef897c3d6ff0c78aff06ac189178dd"));
        assert!(is_valid_md5("AABBCCDD11223344556677889900AABB")); // uppercase ok
        assert!(!is_valid_md5("not-an-md5"));
        assert!(!is_valid_md5("6aef897c3d6ff0c78aff06ac189178d")); // 31 chars
        assert!(!is_valid_md5("6aef897c3d6ff0c78aff06ac189178ddd")); // 33 chars
        assert!(!is_valid_md5("")); // empty
        assert!(!is_valid_md5("6aef897c3d6ff0c78aff06ac189178dg")); // invalid char
    }

    /// The digest is 32 characters, deterministic, input-dependent and
    /// matches a pinned value for "ACGT".
    #[test]
    fn test_compute_sha512t24u() {
        // "ACGT" -> known sha512t24u digest
        let digest = compute_sha512t24u(b"ACGT");
        assert_eq!(digest.len(), 32);
        assert!(is_valid_sha512t24u(&digest));

        // Deterministic: same input -> same output
        assert_eq!(digest, compute_sha512t24u(b"ACGT"));

        // Different input -> different output
        assert_ne!(digest, compute_sha512t24u(b"TGCA"));

        // Verify against known value (SHA-512 of "ACGT", truncated to 24 bytes, base64url no-pad)
        assert_eq!(digest, "aKF498dAxcJAqme6QYQ7EZ07-fiw8Kw2");
    }

    /// Accepts a 32-character base64url string; rejects wrong lengths and
    /// characters outside the alphabet.
    #[test]
    fn test_is_valid_sha512t24u() {
        assert!(is_valid_sha512t24u("aKF498dAxcJAqme6QYQ7EZ07-fiw8Kw2"));
        assert!(!is_valid_sha512t24u("too-short"));
        assert!(!is_valid_sha512t24u(""));
        // 33 chars - too long
        assert!(!is_valid_sha512t24u("aKF498dAxcJAqme6QYQ7EZ07-fiw8Kw2X"));
        // Invalid character (space)
        assert!(!is_valid_sha512t24u("aKF498dAxcJAqme6QYQ7EZ07-fiw8Kw "));
    }

    /// Valid checksums are lowercased; invalid input yields `None`.
    #[test]
    fn test_normalize_md5() {
        assert_eq!(
            normalize_md5("6AEF897C3D6FF0C78AFF06AC189178DD"),
            Some("6aef897c3d6ff0c78aff06ac189178dd".to_string())
        );
        assert_eq!(normalize_md5("invalid"), None);
    }

    /// The signature is a 32-character hash, stable across calls, and empty
    /// for an empty set.
    #[test]
    fn test_compute_signature() {
        let mut md5s = HashSet::new();
        md5s.insert("aaaa".repeat(8)); // fake MD5
        md5s.insert("bbbb".repeat(8));

        let sig = compute_signature(&md5s);
        assert_eq!(sig.len(), 32);

        // Same input should give same output
        let sig2 = compute_signature(&md5s);
        assert_eq!(sig, sig2);

        // Empty set gives empty string
        let empty: HashSet<String> = HashSet::new();
        assert_eq!(compute_signature(&empty), "");
    }

    /// The limit triggers exactly when the count reaches `MAX_CONTIGS`.
    #[test]
    fn test_check_contig_limit() {
        assert!(check_contig_limit(100).is_none());
        assert!(check_contig_limit(MAX_CONTIGS - 1).is_none());
        assert!(check_contig_limit(MAX_CONTIGS).is_some());
        assert!(check_contig_limit(MAX_CONTIGS + 1).is_some());
    }

    // Security validation tests

    /// Ordinary filenames (including ones with spaces) are accepted.
    #[test]
    fn test_validate_filename_safe() {
        assert!(validate_filename("test.sam").is_ok());
        assert!(validate_filename("my-file.bam").is_ok());
        assert!(validate_filename("data_file.txt").is_ok());
        assert!(validate_filename("sample 123.vcf").is_ok());
    }

    /// Traversal attempts, control characters, over-long, empty and hidden
    /// filenames are all rejected.
    #[test]
    fn test_validate_filename_dangerous() {
        // Directory traversal attempts
        assert!(validate_filename("../etc/passwd").is_err());
        assert!(validate_filename("..\\windows\\system32").is_err());
        assert!(validate_filename("test/../../secret").is_err());

        // Null bytes and control characters
        assert!(validate_filename("test\0.txt").is_err());
        assert!(validate_filename("test\x01.txt").is_err());

        // Too long filename
        let long_name = "a".repeat(300);
        assert!(validate_filename(&long_name).is_err());

        // Empty or whitespace-only
        assert!(validate_filename("").is_err());
        assert!(validate_filename("   ").is_err());

        // Hidden files without known extensions
        assert!(validate_filename(".hidden").is_err());
    }

    /// Unsafe characters are stripped from the returned name while safe
    /// ones are preserved.
    #[test]
    fn test_validate_filename_sanitization() {
        // Should remove dangerous characters but keep safe ones
        let result = validate_filename("test@#$%file.txt").unwrap();
        assert_eq!(result, "testfile.txt");

        // Should preserve safe characters
        let result = validate_filename("my-file_123.sam").unwrap();
        assert_eq!(result, "my-file_123.sam");
    }

    /// BAM content must begin with the `BAM\x01` magic number.
    #[test]
    fn test_validate_file_format_bam() {
        let bam_content = b"BAM\x01test_content";
        assert!(validate_file_format(bam_content, FileFormat::Bam));

        let invalid_bam = b"NOTBAM\x01";
        assert!(!validate_file_format(invalid_bam, FileFormat::Bam));
    }

    /// CRAM content must begin with `CRAM`.
    #[test]
    fn test_validate_file_format_cram() {
        let cram_content = b"CRAMtest_content";
        assert!(validate_file_format(cram_content, FileFormat::Cram));

        let invalid_cram = b"NOTCRAM";
        assert!(!validate_file_format(invalid_cram, FileFormat::Cram));
    }

    /// VCF content must begin with `##fileformat=VCF`; a SAM header does not
    /// qualify.
    #[test]
    fn test_validate_file_format_vcf() {
        let vcf_content = b"##fileformat=VCFv4.2\n##contig=<ID=chr1>";
        assert!(validate_file_format(vcf_content, FileFormat::Vcf));

        let invalid_vcf = b"@SQ\tSN:chr1\tLN:123";
        assert!(!validate_file_format(invalid_vcf, FileFormat::Vcf));
    }

    /// SAM content is recognized by either an `@SQ` or an `@HD` header line.
    #[test]
    fn test_validate_file_format_sam() {
        let sam_content = b"@SQ\tSN:chr1\tLN:123456";
        assert!(validate_file_format(sam_content, FileFormat::Sam));

        let sam_content2 = b"@HD\tVN:1.0\tSO:coordinate";
        assert!(validate_file_format(sam_content2, FileFormat::Sam));
    }

    /// Text expectations: plain headers pass, mostly-binary and empty
    /// content fail.
    #[test]
    fn test_validate_file_content_text() {
        let valid_text = b"@SQ\tSN:chr1\tLN:123456\n@SQ\tSN:chr2\tLN:654321";
        assert!(validate_file_content(valid_text, true).is_ok());

        // Too much binary data for text format
        let binary_data = vec![0u8; 1000];
        assert!(validate_file_content(&binary_data, true).is_err());

        // Empty content
        assert!(validate_file_content(b"", true).is_err());
    }

    /// Binary expectations: arbitrary bytes pass, but empty content still
    /// fails.
    #[test]
    fn test_validate_file_content_binary() {
        let binary_data = vec![0xABu8; 100];
        assert!(validate_file_content(&binary_data, false).is_ok());

        // Empty content still invalid for binary
        assert!(validate_file_content(b"", false).is_err());
    }

    /// End-to-end upload validation with and without a filename, with a bad
    /// filename, and with content that does not match the expected format.
    #[test]
    fn test_validate_upload_complete() {
        let sam_content = b"@SQ\tSN:chr1\tLN:123456";

        // Valid upload with filename
        let result = validate_upload(Some("test.sam"), sam_content, FileFormat::Sam);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().unwrap(), "test.sam");

        // Valid upload without filename
        let result = validate_upload(None, sam_content, FileFormat::Sam);
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());

        // Invalid filename
        let result = validate_upload(Some("../etc/passwd"), sam_content, FileFormat::Sam);
        assert!(result.is_err());

        // Format mismatch
        let bam_content = b"BAM\x01test";
        let result = validate_upload(Some("test.sam"), bam_content, FileFormat::Sam);
        assert!(result.is_err());
    }

    /// Allow-listed extensions match (including compound ones); others do
    /// not.
    #[test]
    fn test_has_known_extension() {
        assert!(has_known_extension(".sam"));
        assert!(has_known_extension(".bam"));
        assert!(has_known_extension(".vcf.gz"));
        assert!(has_known_extension("test.assembly_report.txt"));

        assert!(!has_known_extension(".exe"));
        assert!(!has_known_extension(".hidden"));
        assert!(!has_known_extension(".config"));
    }
}