1use crate::{ProgressEntry, Pusher};
2use bytes::{Bytes, BytesMut};
3use std::collections::{BTreeMap, btree_map::Entry};
4
/// A [`Pusher`] decorator that buffers incoming chunks in an in-memory
/// cache, merges byte-adjacent chunks into contiguous runs, and forwards
/// those runs to the wrapped pusher once the cache grows past a high
/// watermark (or on [`Pusher::flush`]).
#[derive(Debug)]
pub struct CacheMergePusher<P> {
    // The wrapped pusher that ultimately receives the merged chunks.
    inner: P,
    // Cached chunks keyed by their starting byte offset; the BTreeMap's
    // ordering is what makes adjacent-run detection a single in-order scan.
    cache: BTreeMap<u64, Bytes>,
    // Total number of payload bytes currently held in `cache`.
    cache_size: usize,
    // Once `cache_size` reaches this, eviction is triggered on push.
    high_watermark: usize,
    // Eviction stops once `cache_size` drops to this level.
    low_watermark: usize,
}
14
impl<P: Pusher> CacheMergePusher<P> {
    /// Creates a merging cache wrapper around `inner`.
    ///
    /// `high_watermark` is the cached-byte total at which a push triggers
    /// eviction; `low_watermark` is the level eviction drains down to.
    /// NOTE(review): nothing here enforces `low_watermark <= high_watermark`
    /// — presumably guaranteed by callers; confirm.
    pub const fn new(inner: P, high_watermark: usize, low_watermark: usize) -> Self {
        Self {
            inner,
            cache: BTreeMap::new(),
            cache_size: 0,
            high_watermark,
            low_watermark,
        }
    }

    /// Evicts cached data to the inner pusher until `cache_size` is at most
    /// `target_size`, pushing the largest contiguous runs first.
    ///
    /// Runs that are not pushed (target already met, or a prior push failed)
    /// are still merged into a single cache entry, so the merging work is
    /// never wasted. On error, eviction of further runs stops but merging
    /// continues; the first error is returned after the loop.
    fn evict_until(&mut self, target_size: usize) -> Result<(), P::Error> {
        // Fast path: already at or below the target.
        if self.cache_size <= target_size {
            return Ok(());
        }

        // Pass 1: scan the offset-ordered cache and coalesce byte-adjacent
        // entries into runs of (start_offset, total_len).
        let mut runs: Vec<(u64, usize)> = Vec::with_capacity(self.cache.len());
        let mut curr_start = None;
        let mut curr_len = 0;
        let mut expected_next = 0;

        for (&start, bytes) in &self.cache {
            let len = bytes.len();
            if let Some(c_start) = curr_start {
                if start == expected_next {
                    // Entry extends the current run seamlessly.
                    curr_len += len;
                    expected_next += len as u64;
                } else {
                    // Gap found: close the current run, start a new one.
                    // NOTE(review): assumes cached ranges never overlap
                    // (start > expected_next on a gap) — confirm against
                    // the Pusher contract.
                    runs.push((c_start, curr_len));
                    curr_start = Some(start);
                    curr_len = len;
                    expected_next = start + len as u64;
                }
            } else {
                // First entry opens the first run.
                curr_start = Some(start);
                curr_len = len;
                expected_next = start + len as u64;
            }
        }
        // Close the trailing run, if any.
        if let Some(c_start) = curr_start {
            runs.push((c_start, curr_len));
        }
        // Largest runs first, so eviction frees the most bytes per push.
        runs.sort_unstable_by_key(|&(_, len)| std::cmp::Reverse(len));

        // Scratch buffer reused across runs; `split().freeze()` hands out
        // the accumulated bytes while keeping the allocation for reuse.
        let mut curr_buf = BytesMut::with_capacity(self.cache_size);
        let mut err = None;
        for (start, total_len) in runs {
            // Push only while no error has occurred and we are still above
            // the target; afterwards, runs are merely merged in place.
            let need_push = err.is_none() && self.cache_size > target_size;
            let first_bytes = match self.cache.entry(start) {
                Entry::Occupied(entry) => {
                    // A run whose first entry spans the whole run is already
                    // a single merged chunk.
                    let is_merged = entry.get().len() == total_len;
                    if !need_push && is_merged {
                        continue;
                    }
                    entry.remove()
                }
                // Every run start came from the scan above, and runs are
                // disjoint, so the key must still be present.
                Entry::Vacant(_) => unreachable!(),
            };
            let chunk = if first_bytes.len() == total_len {
                // Single-entry run: push/reinsert it as-is.
                first_bytes
            } else {
                // Multi-entry run: remove each constituent chunk in offset
                // order and concatenate into one contiguous buffer.
                curr_buf.extend_from_slice(&first_bytes);
                let mut curr_key = start + first_bytes.len() as u64;
                let end = start + total_len as u64;
                while curr_key < end {
                    // unwrap: the scan proved an entry exists at every
                    // offset inside the run.
                    let bytes = self.cache.remove(&curr_key).unwrap();
                    curr_buf.extend_from_slice(&bytes);
                    curr_key += bytes.len() as u64;
                }
                curr_buf.split().freeze()
            };
            if need_push {
                let end = start + total_len as u64;
                let range = start..end;
                // Account the run as evicted up front; partially restored
                // below if the inner push hands bytes back.
                self.cache_size -= total_len;
                if let Err((e, ret_bytes)) = self.inner.push(&range, chunk) {
                    err = Some(e);
                    if !ret_bytes.is_empty() {
                        // NOTE(review): assumes the inner pusher consumed a
                        // prefix and returned the unpushed suffix, so the
                        // retry offset is start + consumed — confirm the
                        // Pusher::push contract.
                        self.cache_size += ret_bytes.len();
                        let retry_start = start + (total_len - ret_bytes.len()) as u64;
                        self.cache.insert(retry_start, ret_bytes);
                    }
                }
            } else {
                // Not pushing: keep the (now merged) run cached under its
                // run start; cache_size is unchanged.
                self.cache.insert(start, chunk);
            }
        }
        // Surface the first push error, if any.
        err.map_or(Ok(()), Err)
    }
}
105
106impl<P: Pusher> Pusher for CacheMergePusher<P> {
107 type Error = P::Error;
108
109 fn push(&mut self, range: &ProgressEntry, bytes: Bytes) -> Result<(), (Self::Error, Bytes)> {
110 if bytes.is_empty() {
111 return Ok(());
112 }
113
114 self.cache_size += bytes.len();
115 if let Some(old_bytes) = self.cache.insert(range.start, bytes) {
116 self.cache_size -= old_bytes.len();
117 }
118
119 if self.cache_size >= self.high_watermark
120 && let Err(e) = self.evict_until(self.low_watermark)
121 {
122 return Err((e, Bytes::new()));
123 }
124
125 Ok(())
126 }
127
128 fn flush(&mut self) -> Result<(), Self::Error> {
129 self.evict_until(0)?;
130 self.inner.flush()
131 }
132}