use crate::decrypt::{DecryptKeys, decrypt_sectors};
use crate::error::Result;
use super::SectorSource;
/// A [`SectorSource`] adapter that decrypts the data read from an inner source.
///
/// Every read is delegated to `inner`; the bytes it reports as filled are then
/// passed through `decrypt_sectors` in place, using `keys` and `unit_key_idx`.
pub struct DecryptingSectorSource<S>
where
    S: SectorSource,
{
    /// The wrapped source that supplies the raw sector bytes.
    inner: S,
    /// Key material handed to `decrypt_sectors` on every read.
    keys: DecryptKeys,
    /// Index forwarded to `decrypt_sectors`; `new` initialises it to `0`.
    unit_key_idx: usize,
}
impl<S> DecryptingSectorSource<S>
where
    S: SectorSource,
{
    /// Wraps `inner` so that its reads are decrypted with `keys`.
    ///
    /// The unit key index starts at `0`; override it with
    /// [`Self::with_unit_key_idx`].
    pub fn new(inner: S, keys: DecryptKeys) -> Self {
        Self { inner, keys, unit_key_idx: 0 }
    }

    /// Builder-style setter: returns the wrapper with its unit key index set to `idx`.
    pub fn with_unit_key_idx(self, idx: usize) -> Self {
        let mut this = self;
        this.unit_key_idx = idx;
        this
    }

    /// Replaces the key material used by subsequent reads.
    pub fn set_keys(&mut self, keys: DecryptKeys) {
        self.keys = keys;
    }

    /// Consumes the wrapper and hands back the inner source.
    pub fn into_inner(self) -> S {
        self.inner
    }

    /// Borrows the inner source.
    pub fn inner(&self) -> &S {
        &self.inner
    }

    /// Mutably borrows the inner source.
    pub fn inner_mut(&mut self) -> &mut S {
        &mut self.inner
    }
}
impl<S> SectorSource for DecryptingSectorSource<S>
where
    S: SectorSource,
{
    /// Reports the inner source's capacity unchanged.
    fn capacity_sectors(&self) -> u32 {
        self.inner.capacity_sectors()
    }

    /// Reads through the inner source, then decrypts the filled prefix of `buf` in place.
    ///
    /// Returns the byte count reported by the inner source.
    ///
    /// # Errors
    /// A failed inner read is returned as-is, and no decryption is attempted.
    /// A failure from `decrypt_sectors` is propagated with `?`.
    ///
    /// # Panics
    /// Panics if the inner source reports more bytes than `buf` can hold.
    fn read_sectors(
        &mut self,
        lba: u32,
        count: u16,
        buf: &mut [u8],
        recovery: bool,
    ) -> Result<usize> {
        let filled = self.inner.read_sectors(lba, count, buf, recovery)?;
        // Only the bytes the inner source actually wrote are decrypted.
        let plaintext = &mut buf[..filled];
        decrypt_sectors(plaintext, &self.keys, self.unit_key_idx)?;
        Ok(filled)
    }

    /// Forwards the speed request to the inner source.
    fn set_speed(&mut self, kbs: u16) {
        self.inner.set_speed(kbs);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result;

    /// Sector size, in bytes, assumed by these fixtures.
    const SECTOR_BYTES: usize = 2048;

    /// In-memory source whose bytes are a deterministic function of their absolute offset.
    struct PatternedSource {
        capacity: u32,
    }

    impl PatternedSource {
        /// Writes the pattern for `count` sectors starting at `lba` into the front of `buf`.
        fn fill(lba: u32, count: u16, buf: &mut [u8]) {
            let len = usize::from(count) * SECTOR_BYTES;
            let base = u64::from(lba) * SECTOR_BYTES as u64;
            for (offset, byte) in (0u64..).zip(&mut buf[..len]) {
                let hashed = (base + offset).wrapping_mul(2654435761) >> 16;
                *byte = (hashed & 0xff) as u8;
            }
        }
    }

    impl SectorSource for PatternedSource {
        fn capacity_sectors(&self) -> u32 {
            self.capacity
        }

        fn read_sectors(
            &mut self,
            lba: u32,
            count: u16,
            buf: &mut [u8],
            _recovery: bool,
        ) -> Result<usize> {
            Self::fill(lba, count, buf);
            Ok(usize::from(count) * SECTOR_BYTES)
        }
    }

    #[test]
    fn passthrough_with_no_keys() {
        const START: u32 = 3;
        const SECTORS: u16 = 4;
        let len = usize::from(SECTORS) * SECTOR_BYTES;

        let source = PatternedSource { capacity: 16 };
        let mut wrapped = DecryptingSectorSource::new(source, DecryptKeys::None);
        assert_eq!(wrapped.capacity_sectors(), 16);

        let mut expected = vec![0u8; len];
        PatternedSource::fill(START, SECTORS, &mut expected);

        let mut got = vec![0u8; len];
        let n = wrapped.read_sectors(START, SECTORS, &mut got, false).unwrap();
        assert_eq!(n, len);
        assert_eq!(got, expected);
    }

    #[test]
    fn passthrough_set_speed_delegates() {
        /// Stub that only remembers the last speed it was asked for.
        #[derive(Default)]
        struct SpeedRecorder {
            last: Option<u16>,
        }

        impl SectorSource for SpeedRecorder {
            fn capacity_sectors(&self) -> u32 {
                0
            }

            fn read_sectors(
                &mut self,
                _lba: u32,
                _count: u16,
                _buf: &mut [u8],
                _recovery: bool,
            ) -> Result<usize> {
                Ok(0)
            }

            fn set_speed(&mut self, kbs: u16) {
                self.last = Some(kbs);
            }
        }

        let mut wrapped =
            DecryptingSectorSource::new(SpeedRecorder::default(), DecryptKeys::None);
        wrapped.set_speed(7200);
        assert_eq!(wrapped.inner().last, Some(7200));
    }
}