1use crate::reader::AeaReader;
2use anyhow::Result;
3use std::{
4 collections::HashMap,
5 io::{Read, Seek, SeekFrom},
6 sync::Arc,
7};
8
9pub struct AeaStream<S>
10where
11 S: Read + Seek + Unpin,
12{
13 reader: AeaReader<S>,
14 virtual_position: u64,
15 end_position: u64,
16 segment_cache: HashMap<(u32, u32), Arc<[u8]>>,
17
18 segment_index_map: Vec<(u32, u32, u64)>,
19 current_scanned_offset: u64,
20 next_unscanned_cluster_index: u32,
21 total_cluster_count: u32,
22}
23
24impl<S> AeaStream<S>
25where
26 S: Read + Seek + Unpin,
27{
28 pub fn new(mut reader: AeaReader<S>) -> std::io::Result<Self> {
29 let end_position = reader
30 .get_decompressed_length()
31 .map_err(std::io::Error::other)?;
32
33 let total_cluster_count = reader.cluster_count().map_err(std::io::Error::other)?;
34
35 Ok(Self {
36 reader,
37 virtual_position: 0,
38 end_position,
39 segment_cache: HashMap::new(),
40 segment_index_map: Vec::new(),
41 current_scanned_offset: 0,
42 next_unscanned_cluster_index: 0,
43 total_cluster_count,
44 })
45 }
46
47 fn ensure_index_up_to(&mut self, required_offset: u64) -> Result<()> {
48 while self.current_scanned_offset <= required_offset
49 && self.next_unscanned_cluster_index < self.total_cluster_count
50 {
51 let header = self
52 .reader
53 .get_cluster_header(self.next_unscanned_cluster_index)?;
54
55 for (segment_index, segment_info) in header.segment_info.iter().enumerate() {
56 self.segment_index_map.push((
57 self.next_unscanned_cluster_index,
58 segment_index as u32,
59 self.current_scanned_offset,
60 ));
61 self.current_scanned_offset += segment_info.decompressed_size as u64;
62 }
63
64 self.next_unscanned_cluster_index += 1;
65 }
66 Ok(())
67 }
68
69 pub fn get_data_at_decompressed_range(&mut self, offset: u64, length: u64) -> Result<Vec<u8>> {
70 if length == 0 || offset >= self.end_position {
71 return Ok(Vec::new());
72 }
73
74 let range_end = (offset + length).min(self.end_position);
75 self.ensure_index_up_to(range_end.saturating_sub(1))?;
76
77 let start_index = self
78 .segment_index_map
79 .partition_point(|&(_, _, segment_offset)| segment_offset <= offset)
80 .saturating_sub(1);
81
82 let end_index = self
83 .segment_index_map
84 .partition_point(|&(_, _, segment_offset)| segment_offset < range_end)
85 .saturating_sub(1);
86
87 let mut result_data = Vec::with_capacity(length as usize);
88 for map_index in start_index..=end_index {
89 let (cluster_index, segment_index, segment_global_start) =
90 self.segment_index_map[map_index];
91
92 let segment_data =
93 if let Some(cached) = self.segment_cache.get(&(cluster_index, segment_index)) {
94 Arc::clone(cached)
95 } else {
96 let data: Vec<u8> = self.reader.get_segment(cluster_index, segment_index)?;
97 let shared_data: Arc<[u8]> = Arc::from(data);
98 self.segment_cache
99 .insert((cluster_index, segment_index), Arc::clone(&shared_data));
100 shared_data
101 };
102
103 let local_start = if map_index == start_index {
104 (offset - segment_global_start) as usize
105 } else {
106 0
107 };
108
109 let local_end = if map_index == end_index {
110 (range_end - segment_global_start) as usize
111 } else {
112 segment_data.len()
113 };
114
115 result_data.extend_from_slice(&segment_data[local_start..local_end]);
116 }
117
118 Ok(result_data)
119 }
120}
121
122impl<S> Read for AeaStream<S>
123where
124 S: Read + Seek + Unpin,
125{
126 fn read(&mut self, buffer: &mut [u8]) -> std::io::Result<usize> {
127 if buffer.is_empty() || self.virtual_position >= self.end_position {
128 return Ok(0);
129 }
130
131 let offset = self.virtual_position;
132 let length = buffer.len() as u64;
133
134 let data = self
135 .get_data_at_decompressed_range(offset, length)
136 .map_err(std::io::Error::other)?;
137
138 let bytes_read = data.len();
139 buffer[..bytes_read].copy_from_slice(&data);
140 self.virtual_position += bytes_read as u64;
141
142 Ok(bytes_read)
143 }
144}
145
146impl<S> Seek for AeaStream<S>
147where
148 S: Read + Seek + Unpin,
149{
150 fn seek(&mut self, position: SeekFrom) -> std::io::Result<u64> {
151 let new_position = match position {
152 SeekFrom::Start(offset) => offset as i64,
153 SeekFrom::Current(offset) => self.virtual_position as i64 + offset,
154 SeekFrom::End(offset) => self.end_position as i64 + offset,
155 };
156
157 if new_position < 0 {
158 return Err(std::io::Error::new(
159 std::io::ErrorKind::InvalidInput,
160 "invalid seek to a negative position",
161 ));
162 }
163
164 self.virtual_position = new_position as u64;
165 Ok(self.virtual_position)
166 }
167}