1use std::collections::HashSet;
2use std::io::{BufRead, BufReader, Read, Write};
3use sha2::{Sha256, Sha512, Digest};
4
/// Selects which gem lines (the lines after the `---` separator) are emitted.
///
/// The borrowed sets contain gem names that are compared against the name at
/// the start of each gem line.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FilterMode<'a> {
    /// Emit every non-blank gem line without consulting any list.
    Passthrough,
    /// Emit only gem lines whose name is contained in the set.
    Allow(&'a HashSet<&'a str>),
    /// Emit only gem lines whose name is not contained in the set.
    Block(&'a HashSet<&'a str>),
}
15
/// Controls how each emitted gem line is rendered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VersionOutput {
    /// Write the line exactly as it was read, version list included.
    Preserve,
    /// Replace the version field with `0` (for lines with at least three
    /// whitespace-separated fields) and re-join fields with single spaces.
    Strip,
}
24
/// Hash function used to fingerprint the filtered output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DigestAlgorithm {
    /// SHA-256, rendered as 64 hex characters.
    Sha256,
    /// SHA-512, rendered as 128 hex characters.
    Sha512,
}
33
34enum DigestState {
36 Sha256(Sha256),
37 Sha512(Sha512),
38}
39
40pub struct DigestWriter<'a, W: Write> {
43 inner: &'a mut W,
44 state: DigestState,
45}
46
47impl<'a, W: Write> DigestWriter<'a, W> {
48 pub fn new(inner: &'a mut W, algorithm: DigestAlgorithm) -> Self {
50 let state = match algorithm {
51 DigestAlgorithm::Sha256 => DigestState::Sha256(Sha256::new()),
52 DigestAlgorithm::Sha512 => DigestState::Sha512(Sha512::new()),
53 };
54 DigestWriter { inner, state }
55 }
56
57 pub fn finalize(self) -> String {
59 match self.state {
60 DigestState::Sha256(hasher) => hex::encode(hasher.finalize()),
61 DigestState::Sha512(hasher) => hex::encode(hasher.finalize()),
62 }
63 }
64}
65
66impl<'a, W: Write> Write for DigestWriter<'a, W> {
67 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
68 match &mut self.state {
70 DigestState::Sha256(hasher) => hasher.update(buf),
71 DigestState::Sha512(hasher) => hasher.update(buf),
72 }
73 self.inner.write(buf)
75 }
76
77 fn flush(&mut self) -> std::io::Result<()> {
78 self.inner.flush()
79 }
80}
81pub fn filter_versions_streaming<R: Read, W: Write>(
97 input: R,
98 output: &mut W,
99 mode: FilterMode,
100version_output: VersionOutput,
101 digest_algorithm: Option<DigestAlgorithm>,
102) -> std::io::Result<Option<String>> {
103 let mut reader = BufReader::new(input);
104
105match digest_algorithm {
107 Some(algorithm) => {
108 let mut digest_writer = DigestWriter::new(output, algorithm);
110
111 pass_through_metadata(&mut reader, &mut digest_writer)?;
113
114 match mode {
117 FilterMode::Passthrough => process_passthrough(&mut reader, &mut digest_writer, version_output)?,
118 FilterMode::Allow(allowlist) => process_filtered(&mut reader, &mut digest_writer, allowlist, true, version_output)?,
119 FilterMode::Block(blocklist) => process_filtered(&mut reader, &mut digest_writer, blocklist, false, version_output)?,
120 }
121
122 Ok(Some(digest_writer.finalize()))
124 }
125 None => {
126 pass_through_metadata(&mut reader, output)?;
129
130 match mode {
132 FilterMode::Passthrough => process_passthrough(&mut reader, output, version_output)?,
133 FilterMode::Allow(allowlist) => process_filtered(&mut reader, output, allowlist, true, version_output)?,
134 FilterMode::Block(blocklist) => process_filtered(&mut reader, output, blocklist, false, version_output)?,
135 }
136
137 Ok(None)
138 }
139 }
140}
141
/// Copies the header of the versions file to `output` verbatim, up to and
/// including the first line whose trimmed content is `---`.
///
/// On success, `reader` is left positioned at the first line after the
/// separator.
///
/// # Errors
///
/// Returns `InvalidData` if the input is exhausted before the separator is
/// seen; any lines read before that point have already been written. Read and
/// write errors are propagated.
fn pass_through_metadata<R: Read, W: Write>(
    reader: &mut BufReader<R>,
    output: &mut W,
) -> std::io::Result<()> {
    let mut buf = String::new();
    loop {
        buf.clear();
        if reader.read_line(&mut buf)? == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "No separator found in versions file",
            ));
        }
        output.write_all(buf.as_bytes())?;
        if buf.trim() == "---" {
            return Ok(());
        }
    }
}
168
169fn process_passthrough<R: Read, W: Write>(
171 reader: &mut BufReader<R>,
172 output: &mut W,
173 version_output: VersionOutput,
174) -> std::io::Result<()> {
175 let mut line = String::new();
176
177 loop {
178 line.clear();
179 let n = reader.read_line(&mut line)?;
180 if n == 0 {
181 break; }
183
184 let trimmed = line.trim();
185 if trimmed.is_empty() {
186 continue;
187 }
188
189 match version_output {
190 VersionOutput::Strip => write_gem_line_stripped(trimmed, output)?,
191 VersionOutput::Preserve => output.write_all(line.as_bytes())?,
192 }
193 }
194
195 Ok(())
196}
197
198fn process_filtered<R: Read, W: Write>(
203 reader: &mut BufReader<R>,
204 output: &mut W,
205 gemlist: &HashSet<&str>,
206 include_on_match: bool,
207 version_output: VersionOutput,
208) -> std::io::Result<()> {
209 let mut line = String::new();
210
211 loop {
212 line.clear();
213 let n = reader.read_line(&mut line)?;
214 if n == 0 {
215 break; }
217
218 let trimmed = line.trim();
219 if trimmed.is_empty() {
220 continue;
221 }
222
223 if let Some(gem_name) = extract_gem_name(trimmed) {
225 let is_in_list = gemlist.contains(gem_name);
226 if is_in_list == include_on_match {
227 write_gem_line(trimmed, &line, output, version_output)?;
228 }
229 }
230 }
231
232 Ok(())
233}
234
/// Returns the gem name at the start of `line`, i.e. the text before the
/// first whitespace character, or `None` if `line` contains no whitespace.
///
/// Splitting on any whitespace, rather than only `' '`, keeps the name
/// consistent with `write_gem_line_stripped`, which separates fields with
/// `split_whitespace`. Callers pass an already-trimmed line; a leading
/// whitespace character would yield `Some("")`.
#[inline]
fn extract_gem_name(line: &str) -> Option<&str> {
    line.split_once(char::is_whitespace).map(|(name, _rest)| name)
}
240
241#[inline]
243fn write_gem_line<W: Write>(
244 trimmed: &str,
245 original_line: &str,
246 output: &mut W,
247 version_output: VersionOutput,
248) -> std::io::Result<()> {
249 match version_output {
250 VersionOutput::Strip => write_gem_line_stripped(trimmed, output),
251 VersionOutput::Preserve => output.write_all(original_line.as_bytes()),
252 }
253}
254
/// Writes `trimmed` to `output` with its second whitespace-separated field
/// (the version list) replaced by `0`, followed by `\n`.
///
/// When the line has at least three fields, they are re-joined with single
/// spaces. A line with fewer than three fields is written unchanged plus `\n`.
///
/// This is called once per emitted line, so the fields are walked lazily with
/// an iterator instead of being collected into a `Vec`. That avoids a heap
/// allocation per line while producing exactly the same bytes.
#[inline]
fn write_gem_line_stripped<W: Write>(trimmed: &str, output: &mut W) -> std::io::Result<()> {
    let mut fields = trimmed.split_whitespace();
    match (fields.next(), fields.next(), fields.next()) {
        (Some(name), Some(_versions), Some(first_rest)) => {
            write!(output, "{} 0 {}", name, first_rest)?;
            for field in fields {
                write!(output, " {}", field)?;
            }
            writeln!(output)
        }
        _ => writeln!(output, "{}", trimmed),
    }
}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the filter over in-memory `input` and returns the emitted text
    /// together with the optional digest. Any error panics, failing the test.
    fn run(
        input: &str,
        mode: FilterMode,
        version_output: VersionOutput,
        digest: Option<DigestAlgorithm>,
    ) -> (String, Option<String>) {
        let mut sink = Vec::new();
        let digest =
            filter_versions_streaming(input.as_bytes(), &mut sink, mode, version_output, digest)
                .unwrap();
        (String::from_utf8(sink).unwrap(), digest)
    }

    /// Builds a gem-name set from a slice of names.
    fn names<'a>(list: &[&'a str]) -> HashSet<&'a str> {
        list.iter().cloned().collect()
    }

    #[test]
    fn test_streaming_filter() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
activerecord 7.0.0 def456
sinatra 3.0.0 ghi789
rails 7.0.1 xyz999
"#;
        let allowlist = names(&["rails", "sinatra"]);
        let (result, digest) =
            run(input, FilterMode::Allow(&allowlist), VersionOutput::Preserve, None);
        assert!(digest.is_none());

        let expected = [
            "created_at: 2024-04-01T00:00:05Z",
            "---",
            "rails 7.0.0 abc123",
            "sinatra 3.0.0 ghi789",
            "rails 7.0.1 xyz999",
        ];
        for needle in &expected {
            assert!(result.contains(*needle), "missing {:?}", needle);
        }
        assert!(!result.contains("activerecord"));
    }

    #[test]
    fn test_streaming_preserves_exact_format() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
"#;
        let allowlist = names(&["rails"]);
        let (result, _) = run(input, FilterMode::Allow(&allowlist), VersionOutput::Preserve, None);
        assert_eq!(result, input);
    }

    #[test]
    fn test_streaming_empty_allowlist() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
sinatra 3.0.0 ghi789
"#;
        let allowlist = HashSet::new();
        let (result, _) = run(input, FilterMode::Allow(&allowlist), VersionOutput::Preserve, None);

        assert!(result.contains("created_at"));
        assert!(result.contains("---"));
        assert!(!result.contains("rails"));
        assert!(!result.contains("sinatra"));
    }

    #[test]
    fn test_passthrough_mode() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
activerecord 7.0.0 def456
sinatra 3.0.0 ghi789
"#;
        let (result, _) = run(input, FilterMode::Passthrough, VersionOutput::Preserve, None);

        let expected = [
            "created_at: 2024-04-01T00:00:05Z",
            "---",
            "rails 7.0.0 abc123",
            "activerecord 7.0.0 def456",
            "sinatra 3.0.0 ghi789",
        ];
        for needle in &expected {
            assert!(result.contains(*needle), "missing {:?}", needle);
        }
    }

    #[test]
    fn test_block_mode() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
activerecord 7.0.0 def456
sinatra 3.0.0 ghi789
puma 5.0.0 xyz999
"#;
        let blocklist = names(&["activerecord", "puma"]);
        let (result, _) = run(input, FilterMode::Block(&blocklist), VersionOutput::Preserve, None);

        let expected = [
            "created_at: 2024-04-01T00:00:05Z",
            "---",
            "rails 7.0.0 abc123",
            "sinatra 3.0.0 ghi789",
        ];
        for needle in &expected {
            assert!(result.contains(*needle), "missing {:?}", needle);
        }
        for blocked in &["activerecord", "puma"] {
            assert!(!result.contains(*blocked), "unexpected {:?}", blocked);
        }
    }

    #[test]
    fn test_block_mode_with_strip_versions() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0,7.0.1 abc123
activerecord 7.0.0 def456
sinatra 3.0.0 ghi789
"#;
        let blocklist = names(&["activerecord"]);
        let (result, _) = run(input, FilterMode::Block(&blocklist), VersionOutput::Strip, None);

        assert!(result.contains("rails 0 abc123"));
        assert!(result.contains("sinatra 0 ghi789"));
        assert!(!result.contains("activerecord"));
    }

    #[test]
    fn test_strip_versions_preserves_extra_fields() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123 extra1 extra2
sinatra 3.0.0 def456
puma 5.0.0 ghi789 extra_field
"#;
        let allowlist = names(&["rails", "puma"]);
        let (result, _) = run(input, FilterMode::Allow(&allowlist), VersionOutput::Strip, None);

        assert!(result.contains("rails 0 abc123 extra1 extra2"));
        assert!(result.contains("puma 0 ghi789 extra_field"));
        assert!(!result.contains("sinatra"));
    }

    #[test]
    fn test_strip_versions() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0,7.0.1,7.0.2 abc123def456
activerecord 7.0.0 def456
sinatra 3.0.0,3.0.1 123456789abc
rails 7.0.3,7.0.4 updated999888
"#;
        let allowlist = names(&["rails", "sinatra"]);
        let (result, _) = run(input, FilterMode::Allow(&allowlist), VersionOutput::Strip, None);

        let present = [
            "created_at: 2024-04-01T00:00:05Z",
            "---",
            "rails 0 abc123def456",
            "rails 0 updated999888",
            "sinatra 0 123456789abc",
        ];
        for needle in &present {
            assert!(result.contains(*needle), "missing {:?}", needle);
        }

        let absent = ["7.0.0,7.0.1,7.0.2", "7.0.3,7.0.4", "3.0.0,3.0.1", "activerecord"];
        for needle in &absent {
            assert!(!result.contains(*needle), "unexpected {:?}", needle);
        }
    }

    #[test]
    fn test_strip_versions_preserves_order() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
zebra 1.0.0 aaa111
apple 1.0.0 bbb222
mango 1.0.0 ccc333
banana 1.0.0 ddd444
"#;
        let allowlist = names(&["banana", "zebra", "mango"]);
        let (result, _) = run(input, FilterMode::Allow(&allowlist), VersionOutput::Strip, None);

        // Skip the header line and the separator.
        let gem_lines: Vec<&str> = result.lines().skip(2).collect();
        assert_eq!(
            gem_lines,
            vec!["zebra 0 aaa111", "mango 0 ccc333", "banana 0 ddd444"]
        );
    }

    #[test]
    fn test_digest_sha256() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
sinatra 3.0.0 ghi789
"#;
        let allowlist = names(&["rails"]);
        let (result, digest) = run(
            input,
            FilterMode::Allow(&allowlist),
            VersionOutput::Preserve,
            Some(DigestAlgorithm::Sha256),
        );

        let digest_hex = digest.expect("a digest was requested");
        assert_eq!(digest_hex.len(), 64);
        assert!(digest_hex.chars().all(|c| c.is_ascii_hexdigit()));

        assert!(result.contains("rails 7.0.0 abc123"));
        assert!(!result.contains("sinatra"));
    }

    #[test]
    fn test_digest_sha512() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
"#;
        let (_, digest) = run(
            input,
            FilterMode::Passthrough,
            VersionOutput::Preserve,
            Some(DigestAlgorithm::Sha512),
        );

        let digest_hex = digest.expect("a digest was requested");
        assert_eq!(digest_hex.len(), 128);
        assert!(digest_hex.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_digest_with_strip_versions() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0,7.0.1,7.0.2 abc123
sinatra 3.0.0 def456
"#;
        let allowlist = names(&["rails"]);

        let (stripped, stripped_digest) = run(
            input,
            FilterMode::Allow(&allowlist),
            VersionOutput::Strip,
            Some(DigestAlgorithm::Sha256),
        );
        assert!(stripped_digest.is_some());
        assert!(stripped.contains("rails 0 abc123"));

        let (_, preserved_digest) = run(
            input,
            FilterMode::Allow(&allowlist),
            VersionOutput::Preserve,
            Some(DigestAlgorithm::Sha256),
        );

        // Different output bytes must yield different digests.
        assert_ne!(stripped_digest.unwrap(), preserved_digest.unwrap());
    }

    #[test]
    fn test_digest_consistency() {
        let input = r#"created_at: 2024-04-01T00:00:05Z
---
rails 7.0.0 abc123
"#;
        let first = run(
            input,
            FilterMode::Passthrough,
            VersionOutput::Preserve,
            Some(DigestAlgorithm::Sha256),
        );
        let second = run(
            input,
            FilterMode::Passthrough,
            VersionOutput::Preserve,
            Some(DigestAlgorithm::Sha256),
        );

        assert_eq!(first.1, second.1);
        assert_eq!(first.0, second.0);
    }
}