1use crate::reader::AeaReader;
2use anyhow::Result;
3use lru::LruCache;
4use std::{
5 io::{Read, Seek, SeekFrom},
6 num::NonZeroUsize,
7 sync::Arc,
8};
9
10pub struct AeaStream<S>
11where
12 S: Read + Seek + Unpin,
13{
14 reader: AeaReader<S>,
15 virtual_position: u64,
16 end_position: u64,
17 segment_cache: LruCache<(u32, u32), Arc<[u8]>>,
18
19 segment_index_map: Vec<(u32, u32, u64)>,
20 current_scanned_offset: u64,
21 next_unscanned_cluster_index: u32,
22 total_cluster_count: u32,
23}
24
25impl<S> AeaStream<S>
26where
27 S: Read + Seek + Unpin,
28{
29 pub fn new(mut reader: AeaReader<S>) -> Result<Self> {
30 let end_position = reader
31 .get_decompressed_length()
32 .map_err(std::io::Error::other)?;
33
34 let total_cluster_count = reader.cluster_count()?;
35 let cache_cap = NonZeroUsize::new(128).unwrap();
36
37 Ok(Self {
38 reader,
39 virtual_position: 0,
40 end_position,
41 segment_cache: LruCache::new(cache_cap),
42 segment_index_map: Vec::new(),
43 current_scanned_offset: 0,
44 next_unscanned_cluster_index: 0,
45 total_cluster_count,
46 })
47 }
48
49 fn ensure_index_up_to(&mut self, required_offset: u64) -> Result<()> {
50 while self.current_scanned_offset <= required_offset
51 && self.next_unscanned_cluster_index < self.total_cluster_count
52 {
53 let header = self
54 .reader
55 .get_cluster_header(self.next_unscanned_cluster_index)?;
56
57 for (segment_index, segment_info) in header.segment_info.iter().enumerate() {
58 self.segment_index_map.push((
59 self.next_unscanned_cluster_index,
60 segment_index as u32,
61 self.current_scanned_offset,
62 ));
63 self.current_scanned_offset += segment_info.decompressed_size as u64;
64 }
65
66 self.next_unscanned_cluster_index += 1;
67 }
68 Ok(())
69 }
70
71 pub fn get_data_at_decompressed_range(&mut self, offset: u64, length: u64) -> Result<Vec<u8>> {
72 if length == 0 || offset >= self.end_position {
73 return Ok(Vec::new());
74 }
75
76 let range_end = (offset + length).min(self.end_position);
77 self.ensure_index_up_to(range_end.saturating_sub(1))?;
78
79 let start_index = self
80 .segment_index_map
81 .partition_point(|&(_, _, segment_offset)| segment_offset <= offset)
82 .saturating_sub(1);
83
84 let end_index = self
85 .segment_index_map
86 .partition_point(|&(_, _, segment_offset)| segment_offset < range_end)
87 .saturating_sub(1);
88
89 let mut result_data = Vec::with_capacity(length as usize);
90 for map_index in start_index..=end_index {
91 let (cluster_index, segment_index, segment_global_start) =
92 self.segment_index_map[map_index];
93
94 let segment_data =
95 if let Some(cached) = self.segment_cache.get(&(cluster_index, segment_index)) {
96 Arc::clone(cached)
97 } else {
98 let data: Vec<u8> = self.reader.get_segment(cluster_index, segment_index)?;
99 let shared_data: Arc<[u8]> = Arc::from(data);
100 self.segment_cache
101 .put((cluster_index, segment_index), Arc::clone(&shared_data));
102 shared_data
103 };
104
105 let local_start = if map_index == start_index {
106 (offset - segment_global_start) as usize
107 } else {
108 0
109 };
110
111 let local_end = if map_index == end_index {
112 (range_end - segment_global_start) as usize
113 } else {
114 segment_data.len()
115 };
116
117 result_data.extend_from_slice(&segment_data[local_start..local_end]);
118 }
119
120 Ok(result_data)
121 }
122}
123
124impl<S> Read for AeaStream<S>
125where
126 S: Read + Seek + Unpin,
127{
128 fn read(&mut self, buffer: &mut [u8]) -> std::io::Result<usize> {
129 if buffer.is_empty() || self.virtual_position >= self.end_position {
130 return Ok(0);
131 }
132
133 let offset = self.virtual_position;
134 let length = buffer.len() as u64;
135
136 let data = self
137 .get_data_at_decompressed_range(offset, length)
138 .map_err(std::io::Error::other)?;
139
140 let bytes_read = data.len();
141 buffer[..bytes_read].copy_from_slice(&data);
142 self.virtual_position += bytes_read as u64;
143
144 Ok(bytes_read)
145 }
146}
147
148impl<S> Seek for AeaStream<S>
149where
150 S: Read + Seek + Unpin,
151{
152 fn seek(&mut self, position: SeekFrom) -> std::io::Result<u64> {
153 let new_position = match position {
154 SeekFrom::Start(offset) => offset as i64,
155 SeekFrom::Current(offset) => self.virtual_position as i64 + offset,
156 SeekFrom::End(offset) => self.end_position as i64 + offset,
157 };
158
159 if new_position < 0 {
160 return Err(std::io::Error::new(
161 std::io::ErrorKind::InvalidInput,
162 "invalid seek to a negative position",
163 ));
164 }
165
166 self.virtual_position = new_position as u64;
167 Ok(self.virtual_position)
168 }
169}