// aea_tools/stream.rs

use crate::reader::AeaReader;
use anyhow::{bail, Result};
use std::{
    collections::HashMap,
    io::{Read, Seek, SeekFrom},
    sync::Arc,
};
8
9pub struct AeaStream<S>
10where
11    S: Read + Seek + Unpin,
12{
13    reader: AeaReader<S>,
14    virtual_position: u64,
15    end_position: u64,
16    segment_cache: HashMap<(u32, u32), Arc<[u8]>>,
17
18    segment_index_map: Vec<(u32, u32, u64)>,
19    current_scanned_offset: u64,
20    next_unscanned_cluster_index: u32,
21    total_cluster_count: u32,
22}
23
24impl<S> AeaStream<S>
25where
26    S: Read + Seek + Unpin,
27{
28    pub fn new(mut reader: AeaReader<S>) -> std::io::Result<Self> {
29        let end_position = reader
30            .get_decompressed_length()
31            .map_err(std::io::Error::other)?;
32
33        let total_cluster_count = reader.cluster_count().map_err(std::io::Error::other)?;
34
35        Ok(Self {
36            reader,
37            virtual_position: 0,
38            end_position,
39            segment_cache: HashMap::new(),
40            segment_index_map: Vec::new(),
41            current_scanned_offset: 0,
42            next_unscanned_cluster_index: 0,
43            total_cluster_count,
44        })
45    }
46
47    fn ensure_index_up_to(&mut self, required_offset: u64) -> Result<()> {
48        while self.current_scanned_offset <= required_offset
49            && self.next_unscanned_cluster_index < self.total_cluster_count
50        {
51            let header = self
52                .reader
53                .get_cluster_header(self.next_unscanned_cluster_index)?;
54
55            for (segment_index, segment_info) in header.segment_info.iter().enumerate() {
56                self.segment_index_map.push((
57                    self.next_unscanned_cluster_index,
58                    segment_index as u32,
59                    self.current_scanned_offset,
60                ));
61                self.current_scanned_offset += segment_info.decompressed_size as u64;
62            }
63
64            self.next_unscanned_cluster_index += 1;
65        }
66        Ok(())
67    }
68
69    pub fn get_data_at_decompressed_range(&mut self, offset: u64, length: u64) -> Result<Vec<u8>> {
70        if length == 0 || offset >= self.end_position {
71            return Ok(Vec::new());
72        }
73
74        let range_end = (offset + length).min(self.end_position);
75        self.ensure_index_up_to(range_end.saturating_sub(1))?;
76
77        let start_index = self
78            .segment_index_map
79            .partition_point(|&(_, _, segment_offset)| segment_offset <= offset)
80            .saturating_sub(1);
81
82        let end_index = self
83            .segment_index_map
84            .partition_point(|&(_, _, segment_offset)| segment_offset < range_end)
85            .saturating_sub(1);
86
87        let mut result_data = Vec::with_capacity(length as usize);
88        for map_index in start_index..=end_index {
89            let (cluster_index, segment_index, segment_global_start) =
90                self.segment_index_map[map_index];
91
92            let segment_data =
93                if let Some(cached) = self.segment_cache.get(&(cluster_index, segment_index)) {
94                    Arc::clone(cached)
95                } else {
96                    let data: Vec<u8> = self.reader.get_segment(cluster_index, segment_index)?;
97                    let shared_data: Arc<[u8]> = Arc::from(data);
98                    self.segment_cache
99                        .insert((cluster_index, segment_index), Arc::clone(&shared_data));
100                    shared_data
101                };
102
103            let local_start = if map_index == start_index {
104                (offset - segment_global_start) as usize
105            } else {
106                0
107            };
108
109            let local_end = if map_index == end_index {
110                (range_end - segment_global_start) as usize
111            } else {
112                segment_data.len()
113            };
114
115            result_data.extend_from_slice(&segment_data[local_start..local_end]);
116        }
117
118        Ok(result_data)
119    }
120}
121
122impl<S> Read for AeaStream<S>
123where
124    S: Read + Seek + Unpin,
125{
126    fn read(&mut self, buffer: &mut [u8]) -> std::io::Result<usize> {
127        if buffer.is_empty() || self.virtual_position >= self.end_position {
128            return Ok(0);
129        }
130
131        let offset = self.virtual_position;
132        let length = buffer.len() as u64;
133
134        let data = self
135            .get_data_at_decompressed_range(offset, length)
136            .map_err(std::io::Error::other)?;
137
138        let bytes_read = data.len();
139        buffer[..bytes_read].copy_from_slice(&data);
140        self.virtual_position += bytes_read as u64;
141
142        Ok(bytes_read)
143    }
144}
145
146impl<S> Seek for AeaStream<S>
147where
148    S: Read + Seek + Unpin,
149{
150    fn seek(&mut self, position: SeekFrom) -> std::io::Result<u64> {
151        let new_position = match position {
152            SeekFrom::Start(offset) => offset as i64,
153            SeekFrom::Current(offset) => self.virtual_position as i64 + offset,
154            SeekFrom::End(offset) => self.end_position as i64 + offset,
155        };
156
157        if new_position < 0 {
158            return Err(std::io::Error::new(
159                std::io::ErrorKind::InvalidInput,
160                "invalid seek to a negative position",
161            ));
162        }
163
164        self.virtual_position = new_position as u64;
165        Ok(self.virtual_position)
166    }
167}