use crate::aacs;
use crate::css;
use rayon::prelude::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
/// Minimum number of aligned units a buffer must contain before
/// `decrypt_sectors` hands AACS work to the thread pool; buffers with fewer
/// units are processed serially on the calling thread.
const PARALLEL_MIN_UNITS: usize = 8;
/// Upper bound applied to thread counts: `set_decrypt_threads` caps its argument
/// to this value, and `decrypt_threads` caps both the `FREEMKV_THREADS` value
/// and the detected core count to it.
pub const MAX_THREADS: usize = 64;
/// Thread count requested through `set_decrypt_threads`. `0` means no explicit
/// request is in effect, in which case `decrypt_threads` falls back to the
/// environment variable or the detected parallelism.
static DECRYPT_THREADS: AtomicUsize = AtomicUsize::new(0);
/// Cached decrypt thread pool. `decrypt_pool` fills it on first use and
/// `set_decrypt_threads` resets it to `None` so the next use builds a new pool.
static DECRYPT_POOL: RwLock<Option<Arc<rayon::ThreadPool>>> = RwLock::new(None);
/// Sets the number of threads to use for parallel decryption.
///
/// Values above [`MAX_THREADS`] are capped to it. Passing `0` clears the explicit
/// setting, so [`decrypt_threads`] goes back to the `FREEMKV_THREADS` environment
/// variable or the detected core count.
///
/// The cached pool handle is discarded so the next parallel decrypt builds a
/// pool with the new size. If the pool lock is poisoned the cache is left as is.
pub fn set_decrypt_threads(n: usize) {
    let capped = std::cmp::min(n, MAX_THREADS);
    DECRYPT_THREADS.store(capped, Ordering::Relaxed);

    // Publish the new count first, then invalidate the cache, so a pool rebuilt
    // after the invalidation observes the updated value.
    if let Ok(mut cached) = DECRYPT_POOL.write() {
        drop(cached.take());
    }
}
/// Returns the shared decrypt thread pool, building it on first use.
///
/// The pool is sized with [`decrypt_threads`] at build time and cached in
/// `DECRYPT_POOL` until [`set_decrypt_threads`] clears it.
///
/// # Panics
///
/// Panics if the `DECRYPT_POOL` lock is poisoned or if rayon cannot build the
/// pool. A build failure does not poison the lock, so a later call can retry.
fn decrypt_pool() -> Arc<rayon::ThreadPool> {
    // Fast path: a shared lock is enough when the pool already exists.
    if let Ok(cached) = DECRYPT_POOL.read() {
        if let Some(pool) = cached.as_ref() {
            return Arc::clone(pool);
        }
    }

    let mut slot = DECRYPT_POOL
        .write()
        .expect("DECRYPT_POOL RwLock poisoned");
    // Re-check under the exclusive lock: another thread may have built the pool
    // between the read above and acquiring the write lock.
    if let Some(pool) = slot.as_ref() {
        return Arc::clone(pool);
    }

    let built = rayon::ThreadPoolBuilder::new()
        .num_threads(decrypt_threads())
        .thread_name(|i| format!("freemkv-decrypt-{i}"))
        .build();
    let pool = match built {
        Ok(pool) => Arc::new(pool),
        Err(err) => {
            // Release the write guard before panicking. Unwinding with the guard
            // held would poison the lock and turn one failed build into a panic
            // on every later call.
            drop(slot);
            panic!("rayon decrypt pool build failed: {err:?}");
        }
    };
    *slot = Some(Arc::clone(&pool));
    pool
}
/// Returns the number of threads decryption should use.
///
/// Resolution order:
/// 1. the value stored by [`set_decrypt_threads`], when it is non-zero;
/// 2. the `FREEMKV_THREADS` environment variable, when it parses as a non-zero
///    `usize`, capped at [`MAX_THREADS`];
/// 3. [`std::thread::available_parallelism`] (or 2 when that is unavailable),
///    limited to the range `1..=MAX_THREADS`.
pub fn decrypt_threads() -> usize {
    let requested = DECRYPT_THREADS.load(Ordering::Relaxed);
    if requested != 0 {
        return requested;
    }

    // Unset, unparsable and zero values all mean "no override".
    let from_env = std::env::var("FREEMKV_THREADS")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
        .filter(|&count| count > 0);
    if let Some(count) = from_env {
        return count.min(MAX_THREADS);
    }

    std::thread::available_parallelism()
        .map_or(2, |cores| cores.get())
        .clamp(1, MAX_THREADS)
}
/// Key material that selects how `decrypt_sectors` treats a buffer.
#[derive(Clone)]
pub enum DecryptKeys {
    /// No keys: the buffer is left untouched.
    None,
    /// Keys for AACS-protected data.
    Aacs {
        /// Unit keys as `(u32, key)` pairs. `decrypt_sectors` picks one by its
        /// position in this list and does not interpret the `u32`.
        unit_keys: Vec<(u32, [u8; 16])>,
        /// Optional read-data key, forwarded to the AACS unit decryptor.
        read_data_key: Option<[u8; 16]>,
    },
    /// Key for CSS-protected data.
    Css {
        /// Five-byte title key handed to the CSS sector descrambler.
        title_key: [u8; 5],
    },
}
impl DecryptKeys {
pub fn is_encrypted(&self) -> bool {
!matches!(self, DecryptKeys::None)
}
}
pub fn decrypt_sectors(
buf: &mut [u8],
keys: &DecryptKeys,
unit_key_idx: usize,
) -> Result<(), crate::error::Error> {
match keys {
DecryptKeys::None => {}
DecryptKeys::Aacs {
unit_keys,
read_data_key,
} => {
let uk = match unit_keys.get(unit_key_idx) {
Some((_, k)) => *k,
None => {
return Err(crate::error::Error::DecryptFailed);
}
};
let rdk: Option<[u8; 16]> = *read_data_key;
let unit_len = aacs::ALIGNED_UNIT_LEN;
let nthreads = decrypt_threads();
let chunks: Vec<&mut [u8]> = buf.chunks_mut(unit_len).collect();
let nunits = chunks.len();
let decrypt_one = |chunk: &mut [u8]| {
if chunk.len() == unit_len && aacs::is_unit_encrypted(chunk) {
let original: Vec<u8> = chunk.to_vec();
if !aacs::decrypt_unit_full(chunk, &uk, rdk.as_ref()) {
chunk.copy_from_slice(&original);
}
}
};
if nthreads <= 1 || nunits < PARALLEL_MIN_UNITS {
for chunk in chunks {
decrypt_one(chunk);
}
} else {
decrypt_pool().install(|| {
chunks.into_par_iter().for_each(|chunk| {
decrypt_one(chunk);
});
});
}
}
DecryptKeys::Css { title_key } => {
for chunk in buf.chunks_mut(2048) {
css::lfsr::descramble_sector(title_key, chunk);
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single aligned unit that begins with the bytes `MPLS` followed by a
    /// deterministic pattern must be byte-for-byte unchanged after an AACS
    /// decrypt attempt.
    #[test]
    fn nav_file_unit_survives_decrypt_attempt() {
        // Fill the whole unit with the pattern, then stamp the 4-byte header.
        let mut unit: Vec<u8> = (0..aacs::ALIGNED_UNIT_LEN)
            .map(|offset| (offset as u8).wrapping_mul(31))
            .collect();
        unit[..4].copy_from_slice(b"MPLS");
        let expected = unit.clone();

        let keys = DecryptKeys::Aacs {
            unit_keys: vec![(0, [0xAB; 16])],
            read_data_key: None,
        };
        decrypt_sectors(&mut unit, &keys, 0).unwrap();

        assert_eq!(
            unit, expected,
            "non-m2ts unit must be restored after failed decrypt"
        );
    }
}