Skip to main content

aea_tools/
stream.rs

1use crate::reader::AeaReader;
2use anyhow::Result;
3use lru::LruCache;
4use std::{
5    io::{Read, Seek, SeekFrom},
6    num::NonZeroUsize,
7    sync::Arc,
8};
9
10pub struct AeaStream<S>
11where
12    S: Read + Seek + Unpin,
13{
14    reader: AeaReader<S>,
15    virtual_position: u64,
16    end_position: u64,
17    segment_cache: LruCache<(u32, u32), Arc<[u8]>>,
18
19    segment_index_map: Vec<(u32, u32, u64)>,
20    current_scanned_offset: u64,
21    next_unscanned_cluster_index: u32,
22    total_cluster_count: u32,
23}
24
25impl<S> AeaStream<S>
26where
27    S: Read + Seek + Unpin,
28{
29    pub fn new(mut reader: AeaReader<S>) -> Result<Self> {
30        let end_position = reader
31            .get_decompressed_length()
32            .map_err(std::io::Error::other)?;
33
34        let total_cluster_count = reader.cluster_count()?;
35        let cache_cap = NonZeroUsize::new(128).unwrap();
36
37        Ok(Self {
38            reader,
39            virtual_position: 0,
40            end_position,
41            segment_cache: LruCache::new(cache_cap),
42            segment_index_map: Vec::new(),
43            current_scanned_offset: 0,
44            next_unscanned_cluster_index: 0,
45            total_cluster_count,
46        })
47    }
48
49    fn ensure_index_up_to(&mut self, required_offset: u64) -> Result<()> {
50        while self.current_scanned_offset <= required_offset
51            && self.next_unscanned_cluster_index < self.total_cluster_count
52        {
53            let header = self
54                .reader
55                .get_cluster_header(self.next_unscanned_cluster_index)?;
56
57            for (segment_index, segment_info) in header.segment_info.iter().enumerate() {
58                self.segment_index_map.push((
59                    self.next_unscanned_cluster_index,
60                    segment_index as u32,
61                    self.current_scanned_offset,
62                ));
63                self.current_scanned_offset += segment_info.decompressed_size as u64;
64            }
65
66            self.next_unscanned_cluster_index += 1;
67        }
68        Ok(())
69    }
70
71    pub fn get_data_at_decompressed_range(&mut self, offset: u64, length: u64) -> Result<Vec<u8>> {
72        if length == 0 || offset >= self.end_position {
73            return Ok(Vec::new());
74        }
75
76        let range_end = (offset + length).min(self.end_position);
77        self.ensure_index_up_to(range_end.saturating_sub(1))?;
78
79        let start_index = self
80            .segment_index_map
81            .partition_point(|&(_, _, segment_offset)| segment_offset <= offset)
82            .saturating_sub(1);
83
84        let end_index = self
85            .segment_index_map
86            .partition_point(|&(_, _, segment_offset)| segment_offset < range_end)
87            .saturating_sub(1);
88
89        let mut result_data = Vec::with_capacity(length as usize);
90        for map_index in start_index..=end_index {
91            let (cluster_index, segment_index, segment_global_start) =
92                self.segment_index_map[map_index];
93
94            let segment_data =
95                if let Some(cached) = self.segment_cache.get(&(cluster_index, segment_index)) {
96                    Arc::clone(cached)
97                } else {
98                    let data: Vec<u8> = self.reader.get_segment(cluster_index, segment_index)?;
99                    let shared_data: Arc<[u8]> = Arc::from(data);
100                    self.segment_cache
101                        .put((cluster_index, segment_index), Arc::clone(&shared_data));
102                    shared_data
103                };
104
105            let local_start = if map_index == start_index {
106                (offset - segment_global_start) as usize
107            } else {
108                0
109            };
110
111            let local_end = if map_index == end_index {
112                (range_end - segment_global_start) as usize
113            } else {
114                segment_data.len()
115            };
116
117            result_data.extend_from_slice(&segment_data[local_start..local_end]);
118        }
119
120        Ok(result_data)
121    }
122}
123
124impl<S> Read for AeaStream<S>
125where
126    S: Read + Seek + Unpin,
127{
128    fn read(&mut self, buffer: &mut [u8]) -> std::io::Result<usize> {
129        if buffer.is_empty() || self.virtual_position >= self.end_position {
130            return Ok(0);
131        }
132
133        let offset = self.virtual_position;
134        let length = buffer.len() as u64;
135
136        let data = self
137            .get_data_at_decompressed_range(offset, length)
138            .map_err(std::io::Error::other)?;
139
140        let bytes_read = data.len();
141        buffer[..bytes_read].copy_from_slice(&data);
142        self.virtual_position += bytes_read as u64;
143
144        Ok(bytes_read)
145    }
146}
147
148impl<S> Seek for AeaStream<S>
149where
150    S: Read + Seek + Unpin,
151{
152    fn seek(&mut self, position: SeekFrom) -> std::io::Result<u64> {
153        let new_position = match position {
154            SeekFrom::Start(offset) => offset as i64,
155            SeekFrom::Current(offset) => self.virtual_position as i64 + offset,
156            SeekFrom::End(offset) => self.end_position as i64 + offset,
157        };
158
159        if new_position < 0 {
160            return Err(std::io::Error::new(
161                std::io::ErrorKind::InvalidInput,
162                "invalid seek to a negative position",
163            ));
164        }
165
166        self.virtual_position = new_position as u64;
167        Ok(self.virtual_position)
168    }
169}