gem_index_filter/
filter.rs

1use std::collections::HashSet;
2use std::io::{BufRead, BufReader, Read, Write};
3use sha2::{Sha256, Sha512, Digest};
4
/// Selects which gem lines survive filtering.
///
/// Only the gem name (the first word of a line) is consulted; the name set is
/// borrowed rather than copied, which is why the variants carry references.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FilterMode<'a> {
    /// Keep every gem line; no name lookup is performed.
    Passthrough,
    /// Keep only lines whose gem name is a member of the set (allowlist).
    Allow(&'a HashSet<&'a str>),
    /// Drop lines whose gem name is a member of the set (blocklist).
    Block(&'a HashSet<&'a str>),
}
15
/// Controls how the version field of each kept gem line is emitted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VersionOutput {
    /// Emit the line exactly as it was read, version list included.
    Preserve,
    /// Replace the version list (the second field) with the literal `0`.
    Strip,
}
24
/// Hash function used to checksum the filtered output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DigestAlgorithm {
    /// SHA-256: 32-byte digest, rendered as 64 hex characters.
    Sha256,
    /// SHA-512: 64-byte digest, rendered as 128 hex characters.
    Sha512,
}
33
34/// Internal enum for holding active digest state
35enum DigestState {
36    Sha256(Sha256),
37    Sha512(Sha512),
38}
39
40/// Writer wrapper that computes digest of data as it's written
41/// This enables streaming checksum computation with zero buffering
42pub struct DigestWriter<'a, W: Write> {
43    inner: &'a mut W,
44    state: DigestState,
45}
46
47impl<'a, W: Write> DigestWriter<'a, W> {
48    /// Create a new DigestWriter with the specified algorithm
49    pub fn new(inner: &'a mut W, algorithm: DigestAlgorithm) -> Self {
50        let state = match algorithm {
51            DigestAlgorithm::Sha256 => DigestState::Sha256(Sha256::new()),
52            DigestAlgorithm::Sha512 => DigestState::Sha512(Sha512::new()),
53        };
54        DigestWriter { inner, state }
55    }
56
57    /// Finalize the digest and return the hex-encoded checksum
58    pub fn finalize(self) -> String {
59        match self.state {
60            DigestState::Sha256(hasher) => hex::encode(hasher.finalize()),
61            DigestState::Sha512(hasher) => hex::encode(hasher.finalize()),
62        }
63    }
64}
65
66impl<'a, W: Write> Write for DigestWriter<'a, W> {
67    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
68        // Update digest with the data
69        match &mut self.state {
70            DigestState::Sha256(hasher) => hasher.update(buf),
71            DigestState::Sha512(hasher) => hasher.update(buf),
72        }
73        // Write to underlying writer
74        self.inner.write(buf)
75    }
76
77    fn flush(&mut self) -> std::io::Result<()> {
78        self.inner.flush()
79    }
80}
81/// Stream and filter versions file by first word (gem name) with zero memory retention
82///
83/// This function:
84/// - Reads input line by line
85/// - Passes through metadata until "---" separator
86/// - Applies filtering based on mode (Allow/Block/Passthrough)
87/// - Immediately writes matching lines to output
88/// - Optionally strips version information, replacing with "0"
89/// - Optionally computes a checksum of the filtered output
90/// - Ignores everything after the first word until newline
91/// - Retains only the current line buffer in memory
92///
93/// Returns:
94/// - `Ok(None)` if no digest algorithm was specified
95/// - `Ok(Some(hex_string))` if digest was computed
96pub fn filter_versions_streaming<R: Read, W: Write>(
97    input: R,
98    output: &mut W,
99    mode: FilterMode,
100version_output: VersionOutput,
101    digest_algorithm: Option<DigestAlgorithm>,
102) -> std::io::Result<Option<String>> {
103    let mut reader = BufReader::new(input);
104
105// Wrap output in DigestWriter if checksum is requested
106    match digest_algorithm {
107        Some(algorithm) => {
108            // Wrap output writer to compute digest as data streams through
109            let mut digest_writer = DigestWriter::new(output, algorithm);
110
111            // Pass through metadata until separator "---"
112            pass_through_metadata(&mut reader, &mut digest_writer)?;
113
114            // Branch to specialized filter function based on mode
115            // This hoists the mode check outside the hot loop for performance
116            match mode {
117                FilterMode::Passthrough => process_passthrough(&mut reader, &mut digest_writer, version_output)?,
118                FilterMode::Allow(allowlist) => process_filtered(&mut reader, &mut digest_writer, allowlist, true, version_output)?,
119                FilterMode::Block(blocklist) => process_filtered(&mut reader, &mut digest_writer, blocklist, false, version_output)?,
120            }
121
122            // Finalize digest and return hex string
123            Ok(Some(digest_writer.finalize()))
124        }
125        None => {
126            // No digest requested, use output directly
127            // Pass through metadata until separator "---"
128            pass_through_metadata(&mut reader, output)?;
129
130            // Branch to specialized filter function based on mode
131            match mode {
132                FilterMode::Passthrough => process_passthrough(&mut reader, output, version_output)?,
133                FilterMode::Allow(allowlist) => process_filtered(&mut reader, output, allowlist, true, version_output)?,
134                FilterMode::Block(blocklist) => process_filtered(&mut reader, output, blocklist, false, version_output)?,
135            }
136
137            Ok(None)
138        }
139    }
140}
141
/// Copy header lines verbatim from `reader` to `output`, stopping right after
/// the first line whose trimmed content is `---`.
///
/// On success the reader is positioned at the first gem line.
///
/// # Errors
/// Returns `ErrorKind::InvalidData` if EOF is reached before a separator is
/// seen (lines read so far have already been written). Read/write errors are
/// propagated.
fn pass_through_metadata<R: Read, W: Write>(
    reader: &mut BufReader<R>,
    output: &mut W,
) -> std::io::Result<()> {
    let mut buf = String::new();
    while reader.read_line(&mut buf)? != 0 {
        output.write_all(buf.as_bytes())?;
        if buf.trim() == "---" {
            return Ok(());
        }
        buf.clear();
    }
    Err(std::io::Error::new(
        std::io::ErrorKind::InvalidData,
        "No separator found in versions file",
    ))
}
168
169/// Process all gems without filtering
170fn process_passthrough<R: Read, W: Write>(
171    reader: &mut BufReader<R>,
172    output: &mut W,
173    version_output: VersionOutput,
174) -> std::io::Result<()> {
175    let mut line = String::new();
176
177    loop {
178        line.clear();
179        let n = reader.read_line(&mut line)?;
180        if n == 0 {
181            break; // EOF
182        }
183
184        let trimmed = line.trim();
185        if trimmed.is_empty() {
186            continue;
187        }
188
189        match version_output {
190            VersionOutput::Strip => write_gem_line_stripped(trimmed, output)?,
191            VersionOutput::Preserve => output.write_all(line.as_bytes())?,
192        }
193    }
194
195    Ok(())
196}
197
198/// Process gems with filtering based on gemlist membership
199///
200/// When `include_on_match` is true (Allow mode): includes gems where gemlist.contains(gemname) == true
201/// When `include_on_match` is false (Block mode): includes gems where gemlist.contains(gemname) == false
202fn process_filtered<R: Read, W: Write>(
203    reader: &mut BufReader<R>,
204    output: &mut W,
205    gemlist: &HashSet<&str>,
206    include_on_match: bool,
207    version_output: VersionOutput,
208) -> std::io::Result<()> {
209    let mut line = String::new();
210
211    loop {
212        line.clear();
213        let n = reader.read_line(&mut line)?;
214        if n == 0 {
215            break; // EOF
216        }
217
218        let trimmed = line.trim();
219        if trimmed.is_empty() {
220            continue;
221        }
222
223        // Extract first word (gem name) and check gemlist
224        if let Some(gem_name) = extract_gem_name(trimmed) {
225            let is_in_list = gemlist.contains(gem_name);
226            if is_in_list == include_on_match {
227                write_gem_line(trimmed, &line, output, version_output)?;
228            }
229        }
230    }
231
232    Ok(())
233}
234
/// Extract the gem name (first word) from a trimmed gem line.
///
/// Words are split on any whitespace, the same tokenization
/// `write_gem_line_stripped` uses. As a result, tab-separated lines and
/// single-word lines still yield a name instead of being silently dropped by
/// the Allow/Block filters (Passthrough mode already keeps such lines).
///
/// Returns `None` only when `line` is empty or whitespace-only.
#[inline]
fn extract_gem_name(line: &str) -> Option<&str> {
    line.split_whitespace().next()
}
240
241/// Write a gem line to output, optionally stripping version information
242#[inline]
243fn write_gem_line<W: Write>(
244    trimmed: &str,
245    original_line: &str,
246    output: &mut W,
247    version_output: VersionOutput,
248) -> std::io::Result<()> {
249    match version_output {
250        VersionOutput::Strip => write_gem_line_stripped(trimmed, output),
251        VersionOutput::Preserve => output.write_all(original_line.as_bytes()),
252    }
253}
254
/// Write `trimmed` with its version list (second field) replaced by `0`.
///
/// `gemname versions md5 [extra...]` becomes `gemname 0 md5 [extra...]`. Fields
/// are re-joined with single spaces and the line ends with `\n`. Lines with
/// fewer than three fields are written unchanged (plus `\n`) rather than
/// guessed at.
///
/// Fields are pulled lazily from the iterator instead of being collected into
/// a `Vec`, so this hot per-line path does no heap allocation.
#[inline]
fn write_gem_line_stripped<W: Write>(trimmed: &str, output: &mut W) -> std::io::Result<()> {
    let mut fields = trimmed.split_whitespace();
    // Tuple elements are evaluated left to right, so these are fields 1, 2, 3.
    match (fields.next(), fields.next(), fields.next()) {
        (Some(name), Some(_versions), Some(md5)) => {
            output.write_all(name.as_bytes())?;
            output.write_all(b" 0 ")?;
            output.write_all(md5.as_bytes())?;
            // Preserve any additional fields after the md5.
            for extra in fields {
                output.write_all(b" ")?;
                output.write_all(extra.as_bytes())?;
            }
            output.write_all(b"\n")
        }
        // Malformed line: emit as-is with a newline.
        _ => {
            output.write_all(trimmed.as_bytes())?;
            output.write_all(b"\n")
        }
    }
}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the filter on an in-memory `input`, returning the output text and digest.
    fn run(
        input: &str,
        mode: FilterMode,
        version_output: VersionOutput,
        digest_algorithm: Option<DigestAlgorithm>,
    ) -> (String, Option<String>) {
        let mut output = Vec::new();
        let digest = filter_versions_streaming(
            input.as_bytes(),
            &mut output,
            mode,
            version_output,
            digest_algorithm,
        )
        .expect("filtering an in-memory versions file should succeed");
        let text = String::from_utf8(output).expect("filter output should be valid UTF-8");
        (text, digest)
    }

    /// Builds a gem-name set from string literals.
    fn gem_set<'a>(names: &[&'a str]) -> HashSet<&'a str> {
        names.iter().copied().collect()
    }

    #[test]
    fn test_streaming_filter() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
activerecord 7.0.0 def456
sinatra 3.0.0 ghi789
rails 7.0.1 xyz999
"#;
        let allowlist = gem_set(&["rails", "sinatra"]);
        let (result, digest) =
            run(input, FilterMode::Allow(&allowlist), VersionOutput::Preserve, None);
        assert!(digest.is_none());

        // Should contain metadata
        assert!(result.contains("created_at: 2024-04-01T00:00:05Z"));
        assert!(result.contains("---"));

        // Should contain allowlisted gems
        assert!(result.contains("rails 7.0.0 abc123"));
        assert!(result.contains("sinatra 3.0.0 ghi789"));
        assert!(result.contains("rails 7.0.1 xyz999"));

        // Should NOT contain filtered gem
        assert!(!result.contains("activerecord"));
    }

    #[test]
    fn test_streaming_preserves_exact_format() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
"#;
        let allowlist = gem_set(&["rails"]);
        let (result, _) = run(input, FilterMode::Allow(&allowlist), VersionOutput::Preserve, None);

        // Should be identical for the all-included case
        assert_eq!(result, input);
    }

    #[test]
    fn test_streaming_empty_allowlist() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
sinatra 3.0.0 ghi789
"#;
        let allowlist: HashSet<&str> = HashSet::new();
        let (result, _) = run(input, FilterMode::Allow(&allowlist), VersionOutput::Preserve, None);

        // Should only contain metadata
        assert!(result.contains("created_at"));
        assert!(result.contains("---"));
        assert!(!result.contains("rails"));
        assert!(!result.contains("sinatra"));
    }

    #[test]
    fn test_passthrough_mode() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
activerecord 7.0.0 def456
sinatra 3.0.0 ghi789
"#;
        let (result, _) = run(input, FilterMode::Passthrough, VersionOutput::Preserve, None);

        // Should contain metadata
        assert!(result.contains("created_at: 2024-04-01T00:00:05Z"));
        assert!(result.contains("---"));

        // Should contain all gems
        assert!(result.contains("rails 7.0.0 abc123"));
        assert!(result.contains("activerecord 7.0.0 def456"));
        assert!(result.contains("sinatra 3.0.0 ghi789"));
    }

    #[test]
    fn test_block_mode() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
activerecord 7.0.0 def456
sinatra 3.0.0 ghi789
puma 5.0.0 xyz999
"#;
        let blocklist = gem_set(&["activerecord", "puma"]);
        let (result, _) = run(input, FilterMode::Block(&blocklist), VersionOutput::Preserve, None);

        // Should contain metadata
        assert!(result.contains("created_at: 2024-04-01T00:00:05Z"));
        assert!(result.contains("---"));

        // Should contain non-blocked gems
        assert!(result.contains("rails 7.0.0 abc123"));
        assert!(result.contains("sinatra 3.0.0 ghi789"));

        // Should NOT contain blocked gems
        assert!(!result.contains("activerecord"));
        assert!(!result.contains("puma"));
    }

    #[test]
    fn test_block_mode_with_strip_versions() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0,7.0.1 abc123
activerecord 7.0.0 def456
sinatra 3.0.0 ghi789
"#;
        let blocklist = gem_set(&["activerecord"]);
        let (result, _) = run(input, FilterMode::Block(&blocklist), VersionOutput::Strip, None);

        // Should contain stripped versions for non-blocked gems
        assert!(result.contains("rails 0 abc123"));
        assert!(result.contains("sinatra 0 ghi789"));

        // Should NOT contain blocked gem
        assert!(!result.contains("activerecord"));
    }

    #[test]
    fn test_strip_versions_preserves_extra_fields() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123 extra1 extra2
sinatra 3.0.0 def456
puma 5.0.0 ghi789 extra_field
"#;
        let allowlist = gem_set(&["rails", "puma"]);
        let (result, _) = run(input, FilterMode::Allow(&allowlist), VersionOutput::Strip, None);

        // Should preserve extra fields after md5 hash
        assert!(result.contains("rails 0 abc123 extra1 extra2"));
        assert!(result.contains("puma 0 ghi789 extra_field"));

        // Should NOT contain filtered gem
        assert!(!result.contains("sinatra"));
    }

    #[test]
    fn test_strip_versions() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0,7.0.1,7.0.2 abc123def456
activerecord 7.0.0 def456
sinatra 3.0.0,3.0.1 123456789abc
rails 7.0.3,7.0.4 updated999888
"#;
        let allowlist = gem_set(&["rails", "sinatra"]);
        let (result, _) = run(input, FilterMode::Allow(&allowlist), VersionOutput::Strip, None);

        // Should contain metadata
        assert!(result.contains("created_at: 2024-04-01T00:00:05Z"));
        assert!(result.contains("---"));

        // Should contain stripped versions (0 instead of version list)
        assert!(result.contains("rails 0 abc123def456"));
        assert!(result.contains("rails 0 updated999888"));
        assert!(result.contains("sinatra 0 123456789abc"));

        // Should NOT contain original version strings
        assert!(!result.contains("7.0.0,7.0.1,7.0.2"));
        assert!(!result.contains("7.0.3,7.0.4"));
        assert!(!result.contains("3.0.0,3.0.1"));

        // Should NOT contain filtered gem
        assert!(!result.contains("activerecord"));
    }

    #[test]
    fn test_strip_versions_preserves_order() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
zebra 1.0.0 aaa111
apple 1.0.0 bbb222
mango 1.0.0 ccc333
banana 1.0.0 ddd444
"#;
        let allowlist = gem_set(&["banana", "zebra", "mango"]);
        let (result, _) = run(input, FilterMode::Allow(&allowlist), VersionOutput::Strip, None);

        // Skip the metadata line and the separator.
        let gem_lines: Vec<&str> = result.lines().skip(2).collect();

        // Verify original order is preserved with stripped versions
        assert_eq!(gem_lines, vec!["zebra 0 aaa111", "mango 0 ccc333", "banana 0 ddd444"]);
    }

    #[test]
    fn test_digest_sha256() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
sinatra 3.0.0 ghi789
"#;
        let allowlist = gem_set(&["rails"]);
        let (result, digest) = run(
            input,
            FilterMode::Allow(&allowlist),
            VersionOutput::Preserve,
            Some(DigestAlgorithm::Sha256),
        );

        let digest_hex = digest.expect("a digest was requested");
        // SHA-256 produces 64 hex characters (32 bytes)
        assert_eq!(digest_hex.len(), 64);
        assert!(digest_hex.chars().all(|c| c.is_ascii_hexdigit()));

        // Output content is unaffected by digesting
        assert!(result.contains("rails 7.0.0 abc123"));
        assert!(!result.contains("sinatra"));
    }

    #[test]
    fn test_digest_sha512() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
"#;
        let (_, digest) = run(
            input,
            FilterMode::Passthrough,
            VersionOutput::Preserve,
            Some(DigestAlgorithm::Sha512),
        );

        let digest_hex = digest.expect("a digest was requested");
        // SHA-512 produces 128 hex characters (64 bytes)
        assert_eq!(digest_hex.len(), 128);
        assert!(digest_hex.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_digest_with_strip_versions() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0,7.0.1,7.0.2 abc123
sinatra 3.0.0 def456
"#;
        let allowlist = gem_set(&["rails"]);

        let (stripped, stripped_digest) = run(
            input,
            FilterMode::Allow(&allowlist),
            VersionOutput::Strip,
            Some(DigestAlgorithm::Sha256),
        );
        assert!(stripped.contains("rails 0 abc123"));

        let (_, preserved_digest) = run(
            input,
            FilterMode::Allow(&allowlist),
            VersionOutput::Preserve,
            Some(DigestAlgorithm::Sha256),
        );

        // Different output bytes must give different digests.
        assert!(stripped_digest.is_some());
        assert_ne!(stripped_digest, preserved_digest);
    }

    #[test]
    fn test_digest_consistency() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
"#;
        // Run twice with the same input; output and digest must match.
        let first = run(
            input,
            FilterMode::Passthrough,
            VersionOutput::Preserve,
            Some(DigestAlgorithm::Sha256),
        );
        let second = run(
            input,
            FilterMode::Passthrough,
            VersionOutput::Preserve,
            Some(DigestAlgorithm::Sha256),
        );
        assert_eq!(first, second);
    }

    #[test]
    fn test_digest_matches_hash_of_output() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0,7.0.1 abc123
sinatra 3.0.0 def456
"#;
        let allowlist = gem_set(&["rails"]);

        // The streamed digest must equal a one-shot hash of the emitted bytes.
        let (out256, digest256) = run(
            input,
            FilterMode::Allow(&allowlist),
            VersionOutput::Strip,
            Some(DigestAlgorithm::Sha256),
        );
        let mut hasher256 = Sha256::new();
        hasher256.update(out256.as_bytes());
        assert_eq!(digest256, Some(hex::encode(hasher256.finalize())));

        let (out512, digest512) = run(
            input,
            FilterMode::Allow(&allowlist),
            VersionOutput::Strip,
            Some(DigestAlgorithm::Sha512),
        );
        let mut hasher512 = Sha512::new();
        hasher512.update(out512.as_bytes());
        assert_eq!(digest512, Some(hex::encode(hasher512.finalize())));
    }

    #[test]
    fn test_missing_separator_is_invalid_data() {
        let mut output = Vec::new();
        let err = filter_versions_streaming(
            "created_at: 2024-04-01T00:00:05Z\nrails 7.0.0 abc123\n".as_bytes(),
            &mut output,
            FilterMode::Passthrough,
            VersionOutput::Preserve,
            None,
        )
        .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_blank_gem_lines_are_skipped() {
        let input =
            "created_at: 2024-04-01T00:00:05Z\n---\n\nrails 7.0.0 abc123\n   \nsinatra 3.0.0 ghi789\n";
        let (result, _) = run(input, FilterMode::Passthrough, VersionOutput::Preserve, None);
        assert_eq!(
            result,
            "created_at: 2024-04-01T00:00:05Z\n---\nrails 7.0.0 abc123\nsinatra 3.0.0 ghi789\n"
        );
    }
}