Skip to main content

aea_tools/
reader.rs

1use crate::{
2    cluster_header::ClusterHeader,
3    crypto::{
4        aes_aead_decrypt, derive_cluster_header_encryption_key, derive_cluster_key,
5        derive_main_key, derive_segment_key,
6    },
7    dictionary::AeaDictionary,
8    prologue::AeaPrologue,
9    root_header::RootHeader,
10};
11use anyhow::Result;
12use lzfse_rust::LzfseRingDecoder;
13use sha2::{Digest, Sha256};
14use std::{
15    collections::HashMap,
16    io::{Read, Seek},
17};
18
19pub struct AeaReader<S>
20where
21    S: Read + Seek + Unpin,
22{
23    stream: S,
24    start_pos: u64,
25    dictionary: AeaDictionary,
26    runtime_data: RuntimeData,
27}
28
29impl<S> AeaReader<S>
30where
31    S: Read + Seek + Unpin,
32{
33    pub fn new(external_key: &[u8], mut stream: S) -> Result<Self> {
34        let start_pos = stream.stream_position()?;
35
36        Ok(Self {
37            stream,
38            start_pos,
39            dictionary: AeaDictionary::default(),
40            runtime_data: RuntimeData::new(external_key),
41        })
42    }
43
44    fn ensure_prologue_loaded(&mut self) -> Result<()> {
45        if self.runtime_data.prologue.is_some() {
46            return Ok(());
47        }
48        self.stream.seek(std::io::SeekFrom::Start(self.start_pos))?;
49        let prologue = AeaPrologue::decode(&mut self.stream)?;
50        self.dictionary.prologue_range = Some((self.start_pos, prologue.length() as u64));
51        self.runtime_data.prologue = Some(prologue);
52        Ok(())
53    }
54
55    fn prologue(&self) -> Result<&AeaPrologue> {
56        self.runtime_data
57            .prologue
58            .as_ref()
59            .ok_or_else(|| anyhow::anyhow!("Prologue not loaded"))
60    }
61
62    fn prologue_mut(&mut self) -> Result<&mut AeaPrologue> {
63        self.runtime_data
64            .prologue
65            .as_mut()
66            .ok_or_else(|| anyhow::anyhow!("Prologue not loaded"))
67    }
68
69    pub fn get_prologue(&mut self) -> Result<&AeaPrologue> {
70        self.ensure_prologue_loaded()?;
71        self.prologue()
72    }
73
74    pub fn get_main_key(&mut self) -> Result<[u8; 32]> {
75        if let Some(amk) = self.runtime_data.amk {
76            return Ok(amk);
77        }
78
79        self.ensure_prologue_loaded()?;
80        let prologue = self.prologue()?;
81        let amk = derive_main_key(
82            &prologue.salt,
83            &self.runtime_data.external_key,
84            &prologue.profile_id,
85        )?;
86
87        self.runtime_data.amk = Some(amk);
88        Ok(amk)
89    }
90
91    pub fn get_root_header(&mut self) -> Result<&RootHeader> {
92        let amk = self.get_main_key()?;
93        self.ensure_prologue_loaded()?;
94        let prologue = self.prologue_mut()?;
95        let root_header = prologue.get_decrypted_root_header(&amk)?;
96
97        Ok(root_header)
98    }
99
100    fn ensure_cluster_header_loaded(&mut self, cluster_index: u32) -> Result<()> {
101        if self
102            .runtime_data
103            .cluster_headers
104            .contains_key(&cluster_index)
105        {
106            return Ok(());
107        }
108
109        let segments_per_cluster = self.get_root_header()?.segments_per_cluster;
110        if let Some((offset, _length, chek, hmac)) =
111            self.dictionary.cluster_map.get(&cluster_index).cloned()
112        {
113            self.stream.seek(std::io::SeekFrom::Start(offset))?;
114            let cluster_header =
115                ClusterHeader::decode(&mut self.stream, &chek, &hmac, segments_per_cluster)?;
116            self.runtime_data
117                .cluster_headers
118                .insert(cluster_index, cluster_header);
119            return Ok(());
120        }
121
122        let amk = self.get_main_key()?;
123        let ck = derive_cluster_key(&amk, cluster_index)?;
124        self.runtime_data.ck.insert(cluster_index, ck);
125        let chek = derive_cluster_header_encryption_key(&ck);
126
127        if cluster_index == 0 {
128            self.ensure_prologue_loaded()?;
129            let prologue_range = self
130                .dictionary
131                .prologue_range
132                .ok_or_else(|| anyhow::anyhow!("Prologue range not found in dictionary"))?;
133
134            let offset = prologue_range.0 + prologue_range.1;
135            self.stream.seek(std::io::SeekFrom::Start(offset))?;
136
137            let first_cluster_hmac = self.prologue()?.first_cluster_hmac;
138            let cluster_header = ClusterHeader::decode(
139                &mut self.stream,
140                &chek,
141                &first_cluster_hmac,
142                segments_per_cluster,
143            )?;
144
145            self.dictionary.cluster_map.insert(
146                cluster_index,
147                (
148                    offset,
149                    cluster_header.encoded_len() as u64,
150                    chek,
151                    first_cluster_hmac,
152                ),
153            );
154            self.runtime_data
155                .cluster_headers
156                .insert(cluster_index, cluster_header);
157
158            return Ok(());
159        }
160
161        let previous_cluster_header = self.get_cluster_header(cluster_index - 1)?;
162        let hmac = previous_cluster_header.next_cluster_hmac;
163
164        let segment_info = &previous_cluster_header.segment_info;
165        let segment_offset = segment_info
166            .iter()
167            .map(|info| info.compressed_size as u64)
168            .sum::<u64>();
169        let header_offset = self
170            .dictionary
171            .cluster_map
172            .get(&(cluster_index - 1))
173            .map(|(offset, length, _, _)| *offset + *length)
174            .ok_or_else(|| anyhow::anyhow!("Previous cluster header not found in dictionary"))?;
175
176        let offset = header_offset + segment_offset;
177        self.stream.seek(std::io::SeekFrom::Start(offset))?;
178
179        let cluster_header =
180            ClusterHeader::decode(&mut self.stream, &chek, &hmac, segments_per_cluster)?;
181        self.dictionary.cluster_map.insert(
182            cluster_index,
183            (offset, cluster_header.encoded_len() as u64, chek, hmac),
184        );
185        self.runtime_data
186            .cluster_headers
187            .insert(cluster_index, cluster_header);
188
189        Ok(())
190    }
191
192    pub fn get_cluster_header(&mut self, cluster_index: u32) -> Result<&ClusterHeader> {
193        self.ensure_cluster_header_loaded(cluster_index)?;
194        self.runtime_data
195            .cluster_headers
196            .get(&cluster_index)
197            .ok_or_else(|| anyhow::anyhow!("Cluster header not found"))
198    }
199
200    pub fn get_segment(&mut self, cluster_index: u32, segment_index: u32) -> Result<Vec<u8>> {
201        if let Some((offset, length, key, hmac)) = self
202            .dictionary
203            .segment_map
204            .get(&(cluster_index, segment_index))
205            .cloned()
206        {
207            self.stream.seek(std::io::SeekFrom::Start(offset))?;
208            let mut encrypted_segment_data = vec![0u8; length as usize];
209            self.stream.read_exact(&mut encrypted_segment_data)?;
210            let segment_data = aes_aead_decrypt(&key, &encrypted_segment_data, &[], &hmac)?;
211
212            let decoder = &mut self.runtime_data.lzfse_decoder;
213            let decompressed = if segment_data.starts_with(b"bvx2") {
214                let mut out = Vec::new();
215                decoder.decode_bytes(&segment_data, &mut out)?;
216                out
217            } else {
218                segment_data
219            };
220
221            return Ok(decompressed);
222        }
223
224        let (segment_offset, segment_info, segment_hmac) = {
225            let cluster_header = self.get_cluster_header(cluster_index)?;
226            let offset = cluster_header
227                .segment_info
228                .iter()
229                .take(segment_index as usize)
230                .map(|info| info.compressed_size as u64)
231                .sum::<u64>();
232            let segment_info = cluster_header
233                .segment_info
234                .get(segment_index as usize)
235                .ok_or_else(|| {
236                    anyhow::anyhow!(
237                        "Segment index {} out of bounds for cluster {}",
238                        segment_index,
239                        cluster_index
240                    )
241                })?
242                .clone();
243            let segment_hmac = *cluster_header
244                .segment_hmacs
245                .get(segment_index as usize)
246                .ok_or_else(|| {
247                    anyhow::anyhow!(
248                        "Segment index {} out of bounds for cluster {}",
249                        segment_index,
250                        cluster_index
251                    )
252                })?;
253
254            let segment_offset = self
255                .dictionary
256                .cluster_map
257                .get(&cluster_index)
258                .map(|(offset, length, _, _)| *offset + *length)
259                .ok_or_else(|| anyhow::anyhow!("Cluster header not found in dictionary"))?
260                + offset;
261
262            (segment_offset, segment_info, segment_hmac)
263        };
264
265        let ck = self.runtime_data.ck.get(&cluster_index).ok_or_else(|| {
266            anyhow::anyhow!("Cluster key not found for cluster {}", cluster_index)
267        })?;
268        let sk = derive_segment_key(ck, segment_index);
269
270        self.dictionary.segment_map.insert(
271            (cluster_index, segment_index),
272            (
273                segment_offset,
274                segment_info.compressed_size as u64,
275                sk,
276                segment_hmac,
277            ),
278        );
279
280        self.stream.seek(std::io::SeekFrom::Start(segment_offset))?;
281        let mut encrypted_segment_data = vec![0u8; segment_info.compressed_size as usize];
282        self.stream.read_exact(&mut encrypted_segment_data)?;
283        let segment_data = aes_aead_decrypt(&sk, &encrypted_segment_data, &[], &segment_hmac)?;
284
285        let decoder = &mut self.runtime_data.lzfse_decoder;
286        let decompressed = if segment_data.starts_with(b"bvx2") {
287            let mut out = Vec::new();
288            decoder.decode_bytes(&segment_data, &mut out)?;
289            out
290        } else {
291            segment_data
292        };
293
294        let expected_size = segment_info.decompressed_size as usize;
295        let actual_size = decompressed.len();
296        if expected_size != actual_size {
297            return Err(anyhow::anyhow!(
298                "Size mismatch: expected {}, got {}",
299                expected_size,
300                actual_size
301            ));
302        }
303        if actual_size == 0 {
304            return Ok(Vec::new());
305        }
306        let expected_checksum = segment_info.checksum;
307        let actual_checksum = Sha256::digest(&decompressed);
308        if expected_checksum != actual_checksum.as_slice() {
309            return Err(anyhow::anyhow!(
310                "Checksum mismatch: expected {:x?}, got {:x?}",
311                expected_checksum,
312                actual_checksum
313            ));
314        }
315
316        Ok(decompressed)
317    }
318
319    pub fn cluster_count(&mut self) -> Result<u32> {
320        let root_header = self.get_root_header()?;
321        let container_size = u64::from_le_bytes(root_header.container_size);
322        let segment_size = u32::from_le_bytes(root_header.segment_size) as u64;
323        let segments_per_cluster = u32::from_le_bytes(root_header.segments_per_cluster) as u64;
324        let cluster_size = segment_size * segments_per_cluster;
325        let cluster_count = container_size.div_ceil(cluster_size) as u32;
326
327        Ok(cluster_count)
328    }
329
330    fn ensure_all_cluster_headers_loaded(&mut self) -> Result<()> {
331        let cluster_count = self.cluster_count()?;
332        for cluster_index in 0..cluster_count {
333            self.ensure_cluster_header_loaded(cluster_index)?;
334        }
335
336        Ok(())
337    }
338
339    pub fn get_all_cluster_headers(&mut self) -> Result<Vec<&ClusterHeader>> {
340        self.ensure_all_cluster_headers_loaded()?;
341        let cluster_headers = self
342            .runtime_data
343            .cluster_headers
344            .values()
345            .collect::<Vec<_>>();
346
347        Ok(cluster_headers)
348    }
349
350    pub fn get_all_segments_from_cluster(&mut self, cluster_index: u32) -> Result<Vec<Vec<u8>>> {
351        let cluster_header = self.get_cluster_header(cluster_index)?;
352        let segment_count = cluster_header.segment_info.len() as u32;
353        let mut segments = Vec::with_capacity(segment_count as usize);
354        for segment_index in 0..segment_count {
355            let segment = self.get_segment(cluster_index, segment_index)?;
356            segments.push(segment);
357        }
358
359        Ok(segments)
360    }
361
362    pub fn get_decompressed_length(&mut self) -> Result<u64> {
363        let raw_size = self.get_root_header()?.raw_size;
364        let total_length = u64::from_le_bytes(raw_size);
365        Ok(total_length)
366    }
367}
368
369struct RuntimeData {
370    pub external_key: Vec<u8>,
371    pub prologue: Option<AeaPrologue>,
372    pub amk: Option<[u8; 32]>,
373    pub ck: HashMap<u32, [u8; 32]>,
374    pub cluster_headers: HashMap<u32, ClusterHeader>,
375    pub lzfse_decoder: LzfseRingDecoder,
376}
377
378impl RuntimeData {
379    fn new(external_key: &[u8]) -> Self {
380        Self {
381            external_key: external_key.to_vec(),
382            prologue: None,
383            amk: None,
384            ck: HashMap::new(),
385            cluster_headers: HashMap::new(),
386            lzfse_decoder: LzfseRingDecoder::default(),
387        }
388    }
389}