// fast_pull/cache/merge.rs

1use crate::{ProgressEntry, Pusher};
2use bytes::{Bytes, BytesMut};
3use std::collections::{BTreeMap, btree_map::Entry};
4
5/// 优先选择大块调用 push,并且会把大块合并成一个 Bytes
6#[derive(Debug)]
7pub struct CacheMergePusher<P> {
8    inner: P,
9    cache: BTreeMap<u64, Bytes>,
10    cache_size: usize,
11    high_watermark: usize,
12    low_watermark: usize,
13}
14
15impl<P: Pusher> CacheMergePusher<P> {
16    pub const fn new(inner: P, high_watermark: usize, low_watermark: usize) -> Self {
17        Self {
18            inner,
19            cache: BTreeMap::new(),
20            cache_size: 0,
21            high_watermark,
22            low_watermark,
23        }
24    }
25
26    fn evict_until(&mut self, target_size: usize) -> Result<(), P::Error> {
27        if self.cache_size <= target_size {
28            return Ok(());
29        }
30
31        let mut runs: Vec<(u64, usize)> = Vec::with_capacity(self.cache.len());
32        let mut curr_start = None;
33        let mut curr_len = 0;
34        let mut expected_next = 0;
35
36        for (&start, bytes) in &self.cache {
37            let len = bytes.len();
38            if let Some(c_start) = curr_start {
39                if start == expected_next {
40                    curr_len += len;
41                    expected_next += len as u64;
42                } else {
43                    runs.push((c_start, curr_len));
44                    curr_start = Some(start);
45                    curr_len = len;
46                    expected_next = start + len as u64;
47                }
48            } else {
49                curr_start = Some(start);
50                curr_len = len;
51                expected_next = start + len as u64;
52            }
53        }
54        if let Some(c_start) = curr_start {
55            runs.push((c_start, curr_len));
56        }
57        runs.sort_unstable_by_key(|&(_, len)| std::cmp::Reverse(len));
58
59        let mut curr_buf = BytesMut::with_capacity(self.cache_size);
60        let mut err = None;
61        for (start, total_len) in runs {
62            let need_push = err.is_none() && self.cache_size > target_size;
63            let first_bytes = match self.cache.entry(start) {
64                Entry::Occupied(entry) => {
65                    let is_merged = entry.get().len() == total_len;
66                    if !need_push && is_merged {
67                        continue;
68                    }
69                    entry.remove()
70                }
71                Entry::Vacant(_) => unreachable!(),
72            };
73            let chunk = if first_bytes.len() == total_len {
74                first_bytes
75            } else {
76                curr_buf.extend_from_slice(&first_bytes);
77                let mut curr_key = start + first_bytes.len() as u64;
78                let end = start + total_len as u64;
79                while curr_key < end {
80                    let bytes = self.cache.remove(&curr_key).unwrap();
81                    curr_buf.extend_from_slice(&bytes);
82                    curr_key += bytes.len() as u64;
83                }
84                curr_buf.split().freeze()
85            };
86            if need_push {
87                let end = start + total_len as u64;
88                let range = start..end;
89                self.cache_size -= total_len;
90                if let Err((e, ret_bytes)) = self.inner.push(&range, chunk) {
91                    err = Some(e);
92                    if !ret_bytes.is_empty() {
93                        self.cache_size += ret_bytes.len();
94                        let retry_start = start + (total_len - ret_bytes.len()) as u64;
95                        self.cache.insert(retry_start, ret_bytes);
96                    }
97                }
98            } else {
99                self.cache.insert(start, chunk);
100            }
101        }
102        err.map_or(Ok(()), Err)
103    }
104}
105
106impl<P: Pusher> Pusher for CacheMergePusher<P> {
107    type Error = P::Error;
108
109    fn push(&mut self, range: &ProgressEntry, bytes: Bytes) -> Result<(), (Self::Error, Bytes)> {
110        if bytes.is_empty() {
111            return Ok(());
112        }
113
114        self.cache_size += bytes.len();
115        if let Some(old_bytes) = self.cache.insert(range.start, bytes) {
116            self.cache_size -= old_bytes.len();
117        }
118
119        if self.cache_size >= self.high_watermark
120            && let Err(e) = self.evict_until(self.low_watermark)
121        {
122            return Err((e, Bytes::new()));
123        }
124
125        Ok(())
126    }
127
128    fn flush(&mut self) -> Result<(), Self::Error> {
129        self.evict_until(0)?;
130        self.inner.flush()
131    }
132}