1use crate::{
2 cluster_header::ClusterHeader,
3 crypto::{
4 aes_aead_decrypt, derive_cluster_header_encryption_key, derive_cluster_key,
5 derive_main_key, derive_segment_key,
6 },
7 dictionary::AeaDictionary,
8 prologue::AeaPrologue,
9 root_header::RootHeader,
10};
11use anyhow::Result;
12use lzfse_rust::LzfseRingDecoder;
13use sha2::{Digest, Sha256};
14use std::{
15 collections::HashMap,
16 io::{Read, Seek},
17};
18
19pub struct AeaReader<S>
20where
21 S: Read + Seek + Unpin,
22{
23 stream: S,
24 start_pos: u64,
25 dictionary: AeaDictionary,
26 runtime_data: RuntimeData,
27}
28
29impl<S> AeaReader<S>
30where
31 S: Read + Seek + Unpin,
32{
33 pub fn new(external_key: &[u8], mut stream: S) -> Result<Self> {
34 let start_pos = stream.stream_position()?;
35
36 Ok(Self {
37 stream,
38 start_pos,
39 dictionary: AeaDictionary::default(),
40 runtime_data: RuntimeData::new(external_key),
41 })
42 }
43
44 fn ensure_prologue_loaded(&mut self) -> Result<()> {
45 if self.runtime_data.prologue.is_some() {
46 return Ok(());
47 }
48 self.stream.seek(std::io::SeekFrom::Start(self.start_pos))?;
49 let prologue = AeaPrologue::decode(&mut self.stream)?;
50 self.dictionary.prologue_range = Some((self.start_pos, prologue.length() as u64));
51 self.runtime_data.prologue = Some(prologue);
52 Ok(())
53 }
54
55 fn prologue(&self) -> Result<&AeaPrologue> {
56 self.runtime_data
57 .prologue
58 .as_ref()
59 .ok_or_else(|| anyhow::anyhow!("Prologue not loaded"))
60 }
61
62 fn prologue_mut(&mut self) -> Result<&mut AeaPrologue> {
63 self.runtime_data
64 .prologue
65 .as_mut()
66 .ok_or_else(|| anyhow::anyhow!("Prologue not loaded"))
67 }
68
69 pub fn get_prologue(&mut self) -> Result<&AeaPrologue> {
70 self.ensure_prologue_loaded()?;
71 self.prologue()
72 }
73
74 pub fn get_main_key(&mut self) -> Result<[u8; 32]> {
75 if let Some(amk) = self.runtime_data.amk {
76 return Ok(amk);
77 }
78
79 self.ensure_prologue_loaded()?;
80 let prologue = self.prologue()?;
81 let amk = derive_main_key(
82 &prologue.salt,
83 &self.runtime_data.external_key,
84 &prologue.profile_id,
85 )?;
86
87 self.runtime_data.amk = Some(amk);
88 Ok(amk)
89 }
90
91 pub fn get_root_header(&mut self) -> Result<&RootHeader> {
92 let amk = self.get_main_key()?;
93 self.ensure_prologue_loaded()?;
94 let prologue = self.prologue_mut()?;
95 let root_header = prologue.get_decrypted_root_header(&amk)?;
96
97 Ok(root_header)
98 }
99
100 fn ensure_cluster_header_loaded(&mut self, cluster_index: u32) -> Result<()> {
101 if self
102 .runtime_data
103 .cluster_headers
104 .contains_key(&cluster_index)
105 {
106 return Ok(());
107 }
108
109 let segments_per_cluster = self.get_root_header()?.segments_per_cluster;
110 if let Some((offset, _length, chek, hmac)) =
111 self.dictionary.cluster_map.get(&cluster_index).cloned()
112 {
113 self.stream.seek(std::io::SeekFrom::Start(offset))?;
114 let cluster_header =
115 ClusterHeader::decode(&mut self.stream, &chek, &hmac, segments_per_cluster)?;
116 self.runtime_data
117 .cluster_headers
118 .insert(cluster_index, cluster_header);
119 return Ok(());
120 }
121
122 let amk = self.get_main_key()?;
123 let ck = derive_cluster_key(&amk, cluster_index)?;
124 self.runtime_data.ck.insert(cluster_index, ck);
125 let chek = derive_cluster_header_encryption_key(&ck);
126
127 if cluster_index == 0 {
128 self.ensure_prologue_loaded()?;
129 let prologue_range = self
130 .dictionary
131 .prologue_range
132 .ok_or_else(|| anyhow::anyhow!("Prologue range not found in dictionary"))?;
133
134 let offset = prologue_range.0 + prologue_range.1;
135 self.stream.seek(std::io::SeekFrom::Start(offset))?;
136
137 let first_cluster_hmac = self.prologue()?.first_cluster_hmac;
138 let cluster_header = ClusterHeader::decode(
139 &mut self.stream,
140 &chek,
141 &first_cluster_hmac,
142 segments_per_cluster,
143 )?;
144
145 self.dictionary.cluster_map.insert(
146 cluster_index,
147 (
148 offset,
149 cluster_header.encoded_len() as u64,
150 chek,
151 first_cluster_hmac,
152 ),
153 );
154 self.runtime_data
155 .cluster_headers
156 .insert(cluster_index, cluster_header);
157
158 return Ok(());
159 }
160
161 let previous_cluster_header = self.get_cluster_header(cluster_index - 1)?;
162 let hmac = previous_cluster_header.next_cluster_hmac;
163
164 let segment_info = &previous_cluster_header.segment_info;
165 let segment_offset = segment_info
166 .iter()
167 .map(|info| info.compressed_size as u64)
168 .sum::<u64>();
169 let header_offset = self
170 .dictionary
171 .cluster_map
172 .get(&(cluster_index - 1))
173 .map(|(offset, length, _, _)| *offset + *length)
174 .ok_or_else(|| anyhow::anyhow!("Previous cluster header not found in dictionary"))?;
175
176 let offset = header_offset + segment_offset;
177 self.stream.seek(std::io::SeekFrom::Start(offset))?;
178
179 let cluster_header =
180 ClusterHeader::decode(&mut self.stream, &chek, &hmac, segments_per_cluster)?;
181 self.dictionary.cluster_map.insert(
182 cluster_index,
183 (offset, cluster_header.encoded_len() as u64, chek, hmac),
184 );
185 self.runtime_data
186 .cluster_headers
187 .insert(cluster_index, cluster_header);
188
189 Ok(())
190 }
191
192 pub fn get_cluster_header(&mut self, cluster_index: u32) -> Result<&ClusterHeader> {
193 self.ensure_cluster_header_loaded(cluster_index)?;
194 self.runtime_data
195 .cluster_headers
196 .get(&cluster_index)
197 .ok_or_else(|| anyhow::anyhow!("Cluster header not found"))
198 }
199
200 pub fn get_segment(&mut self, cluster_index: u32, segment_index: u32) -> Result<Vec<u8>> {
201 if let Some((offset, length, key, hmac)) = self
202 .dictionary
203 .segment_map
204 .get(&(cluster_index, segment_index))
205 .cloned()
206 {
207 self.stream.seek(std::io::SeekFrom::Start(offset))?;
208 let mut encrypted_segment_data = vec![0u8; length as usize];
209 self.stream.read_exact(&mut encrypted_segment_data)?;
210 let segment_data = aes_aead_decrypt(&key, &encrypted_segment_data, &[], &hmac)?;
211
212 let decoder = &mut self.runtime_data.lzfse_decoder;
213 let decompressed = if segment_data.starts_with(b"bvx2") {
214 let mut out = Vec::new();
215 decoder.decode_bytes(&segment_data, &mut out)?;
216 out
217 } else {
218 segment_data
219 };
220
221 return Ok(decompressed);
222 }
223
224 let (segment_offset, segment_info, segment_hmac) = {
225 let cluster_header = self.get_cluster_header(cluster_index)?;
226 let offset = cluster_header
227 .segment_info
228 .iter()
229 .take(segment_index as usize)
230 .map(|info| info.compressed_size as u64)
231 .sum::<u64>();
232 let segment_info = cluster_header
233 .segment_info
234 .get(segment_index as usize)
235 .ok_or_else(|| {
236 anyhow::anyhow!(
237 "Segment index {} out of bounds for cluster {}",
238 segment_index,
239 cluster_index
240 )
241 })?
242 .clone();
243 let segment_hmac = *cluster_header
244 .segment_hmacs
245 .get(segment_index as usize)
246 .ok_or_else(|| {
247 anyhow::anyhow!(
248 "Segment index {} out of bounds for cluster {}",
249 segment_index,
250 cluster_index
251 )
252 })?;
253
254 let segment_offset = self
255 .dictionary
256 .cluster_map
257 .get(&cluster_index)
258 .map(|(offset, length, _, _)| *offset + *length)
259 .ok_or_else(|| anyhow::anyhow!("Cluster header not found in dictionary"))?
260 + offset;
261
262 (segment_offset, segment_info, segment_hmac)
263 };
264
265 let ck = self.runtime_data.ck.get(&cluster_index).ok_or_else(|| {
266 anyhow::anyhow!("Cluster key not found for cluster {}", cluster_index)
267 })?;
268 let sk = derive_segment_key(ck, segment_index);
269
270 self.dictionary.segment_map.insert(
271 (cluster_index, segment_index),
272 (
273 segment_offset,
274 segment_info.compressed_size as u64,
275 sk,
276 segment_hmac,
277 ),
278 );
279
280 self.stream.seek(std::io::SeekFrom::Start(segment_offset))?;
281 let mut encrypted_segment_data = vec![0u8; segment_info.compressed_size as usize];
282 self.stream.read_exact(&mut encrypted_segment_data)?;
283 let segment_data = aes_aead_decrypt(&sk, &encrypted_segment_data, &[], &segment_hmac)?;
284
285 let decoder = &mut self.runtime_data.lzfse_decoder;
286 let decompressed = if segment_data.starts_with(b"bvx2") {
287 let mut out = Vec::new();
288 decoder.decode_bytes(&segment_data, &mut out)?;
289 out
290 } else {
291 segment_data
292 };
293
294 let expected_size = segment_info.decompressed_size as usize;
295 let actual_size = decompressed.len();
296 if expected_size != actual_size {
297 return Err(anyhow::anyhow!(
298 "Size mismatch: expected {}, got {}",
299 expected_size,
300 actual_size
301 ));
302 }
303 if actual_size == 0 {
304 return Ok(Vec::new());
305 }
306 let expected_checksum = segment_info.checksum;
307 let actual_checksum = Sha256::digest(&decompressed);
308 if expected_checksum != actual_checksum.as_slice() {
309 return Err(anyhow::anyhow!(
310 "Checksum mismatch: expected {:x?}, got {:x?}",
311 expected_checksum,
312 actual_checksum
313 ));
314 }
315
316 Ok(decompressed)
317 }
318
319 pub fn cluster_count(&mut self) -> Result<u32> {
320 let root_header = self.get_root_header()?;
321 let container_size = u64::from_le_bytes(root_header.container_size);
322 let segment_size = u32::from_le_bytes(root_header.segment_size) as u64;
323 let segments_per_cluster = u32::from_le_bytes(root_header.segments_per_cluster) as u64;
324 let cluster_size = segment_size * segments_per_cluster;
325 let cluster_count = container_size.div_ceil(cluster_size) as u32;
326
327 Ok(cluster_count)
328 }
329
330 fn ensure_all_cluster_headers_loaded(&mut self) -> Result<()> {
331 let cluster_count = self.cluster_count()?;
332 for cluster_index in 0..cluster_count {
333 self.ensure_cluster_header_loaded(cluster_index)?;
334 }
335
336 Ok(())
337 }
338
339 pub fn get_all_cluster_headers(&mut self) -> Result<Vec<&ClusterHeader>> {
340 self.ensure_all_cluster_headers_loaded()?;
341 let cluster_headers = self
342 .runtime_data
343 .cluster_headers
344 .values()
345 .collect::<Vec<_>>();
346
347 Ok(cluster_headers)
348 }
349
350 pub fn get_all_segments_from_cluster(&mut self, cluster_index: u32) -> Result<Vec<Vec<u8>>> {
351 let cluster_header = self.get_cluster_header(cluster_index)?;
352 let segment_count = cluster_header.segment_info.len() as u32;
353 let mut segments = Vec::with_capacity(segment_count as usize);
354 for segment_index in 0..segment_count {
355 let segment = self.get_segment(cluster_index, segment_index)?;
356 segments.push(segment);
357 }
358
359 Ok(segments)
360 }
361
362 pub fn get_decompressed_length(&mut self) -> Result<u64> {
363 let raw_size = self.get_root_header()?.raw_size;
364 let total_length = u64::from_le_bytes(raw_size);
365 Ok(total_length)
366 }
367}
368
369struct RuntimeData {
370 pub external_key: Vec<u8>,
371 pub prologue: Option<AeaPrologue>,
372 pub amk: Option<[u8; 32]>,
373 pub ck: HashMap<u32, [u8; 32]>,
374 pub cluster_headers: HashMap<u32, ClusterHeader>,
375 pub lzfse_decoder: LzfseRingDecoder,
376}
377
378impl RuntimeData {
379 fn new(external_key: &[u8]) -> Self {
380 Self {
381 external_key: external_key.to_vec(),
382 prologue: None,
383 amk: None,
384 ck: HashMap::new(),
385 cluster_headers: HashMap::new(),
386 lzfse_decoder: LzfseRingDecoder::default(),
387 }
388 }
389}