use std::{
cell::OnceCell,
io::{self, BufRead, Read, Seek, SeekFrom},
ops::Deref,
str::FromStr,
};
use blowfish::{Blowfish, cipher::BlockDecryptMut, cipher::KeyIvInit};
use cbc::cipher::block_padding::NoPadding;
use md5::{Digest, Md5};
use crate::{
audio_file::ReadSeek,
error::{Error, Result},
protocol::media::Cipher,
track::{Track, TrackId},
};
/// Seekable, buffered reader that serves the contents of an encrypted source
/// `R` one block of `CBC_BLOCK_SIZE` bytes at a time.
///
/// A single block is held in memory. Blocks whose index is a multiple of
/// `CBC_STRIPE_COUNT` are decrypted with Blowfish in CBC mode when they are
/// loaded; all other blocks are passed through as read.
pub struct Decrypt<R>
where
R: ReadSeek,
{
// Underlying source the blocks are read from.
file: R,
// Total size in bytes as reported by the track, if it is known.
file_size: Option<u64>,
// Track-specific decryption key (see `key_for_track_id`).
key: Key,
// Storage for the block that is currently loaded.
buffer: [u8; CBC_BLOCK_SIZE],
// Number of valid bytes at the start of `buffer`.
buffer_len: usize,
// Read offset inside the current block.
pos: u64,
// Index of the block held in `buffer`; `None` until a block has been loaded.
block: Option<u64>,
}
/// Length of a [`Key`] in bytes.
pub const KEY_LENGTH: usize = 16;
/// Raw byte array backing a [`Key`].
pub type RawKey = [u8; KEY_LENGTH];
/// Fixed-size key of [`KEY_LENGTH`] bytes.
///
/// The same type is used for the secret stored with `set_bf_secret` and for
/// the per-track key produced by `Decrypt::key_for_track_id`. The derived
/// `Default` is the all-zero key.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct Key(RawKey);
impl FromStr for Key {
type Err = Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
let len = s.len();
if len != KEY_LENGTH {
return Err(Error::out_of_range(format!(
"key length is {len} but should be {KEY_LENGTH}",
)));
}
let bytes = s.as_bytes();
let mut key = [0; KEY_LENGTH];
key.copy_from_slice(bytes);
Ok(Self(key))
}
}
impl Deref for Key {
type Target = RawKey;
#[inline]
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// Initialization vector handed to the Blowfish CBC decryptor for every block.
const CBC_BF_IV: &[u8; 8] = b"\x00\x01\x02\x03\x04\x05\x06\x07";
/// Size in bytes of one block: the unit of buffering, seeking and decryption.
const CBC_BLOCK_SIZE: usize = 2 * 1024;
/// Stripe period: only blocks whose index is a multiple of this are decrypted.
const CBC_STRIPE_COUNT: usize = 3;
/// Ciphers that `Decrypt::new` accepts.
const SUPPORTED_CIPHERS: [Cipher; 1] = [Cipher::BF_CBC_STRIPE];
// Secret that is mixed into every per-track key.
//
// The cell is write-once and thread-local: each thread has its own slot, so a
// secret stored on one thread is not visible from another.
thread_local! {
static BF_SECRET: OnceCell<Key> = const { OnceCell::new() };
}
pub fn set_bf_secret(secret: Key) -> Result<()> {
BF_SECRET.with(|cell| {
cell.set(secret)
.map_err(|_| Error::unimplemented("decryption key already set"))
})
}
fn bf_secret() -> Result<Key> {
BF_SECRET.with(|cell| {
cell.get()
.copied()
.ok_or_else(|| Error::permission_denied("decryption key not set"))
})
}
impl<R> Decrypt<R>
where
    R: ReadSeek,
{
    /// Creates a decrypting reader for `track` on top of `file`.
    ///
    /// No data is read here; the first block is loaded by the first seek or
    /// read. The reported track size, when available, is remembered for
    /// bounds checks while seeking.
    ///
    /// # Errors
    ///
    /// * invalid argument: the track is not encrypted;
    /// * unimplemented: the track's cipher is not in `SUPPORTED_CIPHERS`;
    /// * permission denied: no secret has been stored with `set_bf_secret`
    ///   on the current thread.
    pub fn new(track: &Track, file: R) -> Result<Self>
    where
        R: ReadSeek,
    {
        if !track.is_encrypted() {
            return Err(Error::invalid_argument(format!("{track} is not encrypted")));
        }

        let cipher = track.cipher();
        if !SUPPORTED_CIPHERS.contains(&cipher) {
            return Err(Error::unimplemented(format!(
                "unsupported encryption algorithm {}",
                cipher
            )));
        }

        let secret = bf_secret()?;
        let track_key = Self::key_for_track_id(track.id(), &secret);

        Ok(Self {
            file,
            file_size: track.file_size(),
            key: track_key,
            buffer: [0; CBC_BLOCK_SIZE],
            buffer_len: 0,
            pos: 0,
            block: None,
        })
    }

    /// Derives the decryption key for a track.
    ///
    /// The track ID is rendered as a string and hashed with MD5; the hash is
    /// written as lowercase hexadecimal (32 ASCII characters). Byte `i` of
    /// the key is the XOR of hex character `i`, hex character
    /// `i + KEY_LENGTH` and byte `i` of `salt`.
    #[must_use]
    pub fn key_for_track_id(track_id: TrackId, salt: &Key) -> Key {
        let digest = format!("{:x}", Md5::digest(track_id.to_string()));
        let hex = digest.as_bytes();

        Key(std::array::from_fn(|i| {
            hex[i] ^ hex[i + KEY_LENGTH] ^ salt[i]
        }))
    }
}
impl<R> Seek for Decrypt<R>
where
    R: ReadSeek,
{
    /// Moves the read position and makes sure the block containing it is
    /// loaded (and decrypted where applicable).
    ///
    /// Position resolution:
    ///
    /// * `SeekFrom::Start(n)` resolves to `n`.
    /// * `SeekFrom::End(n)` resolves to `size + n - 1`, so `End(0)` addresses
    ///   the last byte. It requires the size to be known.
    ///   NOTE(review): this is one less than the usual `Seek` convention; it
    ///   is kept as is because positions at or past the end are rejected
    ///   below and callers may depend on it.
    /// * `SeekFrom::Current(n)` is relative to `block * CBC_BLOCK_SIZE + pos`.
    ///
    /// A block is only (re)loaded when the target lies in a different block
    /// than the one currently buffered. Loading is transactional: the new
    /// block is read and decrypted in a scratch buffer and the reader's state
    /// is only updated once that succeeded. After a failed seek the reader
    /// still reports its previous position and buffer, so the operation can
    /// be retried without skipping or repeating data.
    ///
    /// # Errors
    ///
    /// * `Unsupported`: seeking from the end while the size is unknown;
    /// * `InvalidInput`: the target is negative or overflows, or the cipher
    ///   could not be initialised;
    /// * `UnexpectedEof`: the size is known and the target is at or past it,
    ///   or the source ends before a block that should be complete;
    /// * `InvalidData`: decryption of the block failed;
    /// * any error of the underlying source.
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        const BLOCK_SIZE: u64 = CBC_BLOCK_SIZE as u64;

        fn invalid_position() -> io::Error {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid seek to negative or overflowing position",
            )
        }

        let target = match pos {
            SeekFrom::Start(offset) => offset,
            SeekFrom::End(delta) => {
                let size = self.file_size.ok_or_else(|| {
                    io::Error::new(
                        io::ErrorKind::Unsupported,
                        "cannot seek from end with unknown size",
                    )
                })?;
                size.checked_add_signed(delta)
                    .and_then(|end| end.checked_sub(1))
                    .ok_or_else(invalid_position)?
            }
            SeekFrom::Current(delta) => self
                .block
                .unwrap_or_default()
                .checked_mul(BLOCK_SIZE)
                .and_then(|start| start.checked_add(self.pos))
                .and_then(|current| current.checked_add_signed(delta))
                .ok_or_else(invalid_position)?,
        };

        if self.file_size.is_some_and(|size| target >= size) {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "seek beyond end of file",
            ));
        }

        // The divisor is a non-zero constant, so these cannot fail.
        let block = target / BLOCK_SIZE;
        let offset = target % BLOCK_SIZE;

        if self.block != Some(block) {
            // `block` was obtained by dividing by `BLOCK_SIZE`, so this
            // multiplication cannot overflow.
            let block_start = block * BLOCK_SIZE;
            self.file.seek(SeekFrom::Start(block_start))?;

            // Work on a scratch copy so that an I/O or decryption error
            // leaves `buffer`, `buffer_len` and `block` untouched.
            let mut scratch = [0_u8; CBC_BLOCK_SIZE];

            let expect_full_block = self
                .file_size
                .is_some_and(|size| size.saturating_sub(block_start) >= BLOCK_SIZE);

            let filled = if expect_full_block {
                self.file.read_exact(&mut scratch)?;
                CBC_BLOCK_SIZE
            } else {
                // Trailing block, or size unknown: a single `read` may return
                // fewer bytes than are available, so keep reading until the
                // block is full or the source reports end of input.
                let mut filled = 0;
                while filled < CBC_BLOCK_SIZE {
                    match self.file.read(&mut scratch[filled..]) {
                        Ok(0) => break,
                        Ok(count) => filled += count,
                        Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
                        Err(e) => return Err(e),
                    }
                }
                filled
            };

            // Only complete blocks on the stripe are encrypted.
            let is_encrypted = block % CBC_STRIPE_COUNT as u64 == 0;
            let is_full_block = filled == CBC_BLOCK_SIZE;
            if is_encrypted && is_full_block {
                let cipher = cbc::Decryptor::<Blowfish>::new_from_slices(&*self.key, CBC_BF_IV)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
                cipher
                    .decrypt_padded_mut::<NoPadding>(&mut scratch[..CBC_BLOCK_SIZE])
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
            }

            // Commit the freshly loaded block.
            self.buffer = scratch;
            self.buffer_len = filled;
            self.block = Some(block);
        }

        self.pos = offset;
        Ok(target)
    }
}
impl<R> BufRead for Decrypt<R>
where
    R: ReadSeek,
{
    /// Returns the unread bytes of the current block, loading the next block
    /// first when the current one is exhausted (or none was loaded yet).
    ///
    /// An empty slice is returned when no bytes are available at the current
    /// position, including when the position lies past the data that could be
    /// read for the block.
    ///
    /// # Errors
    ///
    /// Propagates any error of the internal seek that loads the next block.
    /// When the size is known this includes `UnexpectedEof` once the position
    /// reaches the end.
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        if self.pos >= self.buffer_len as u64 {
            // Seeking to the current position loads the block it falls in.
            let _ = self.stream_position()?;
        }

        let start = usize::try_from(self.pos).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "buffer position would be out of bounds",
            )
        })?;

        // A seek may leave `pos` beyond `buffer_len` when the block turned
        // out shorter than the requested offset (size unknown or inaccurate).
        // Checked slicing yields an empty buffer there instead of panicking.
        Ok(self.buffer.get(start..self.buffer_len).unwrap_or(&[]))
    }

    /// Marks `amt` bytes as read, never advancing past the end of the
    /// buffered data.
    #[inline]
    fn consume(&mut self, amt: usize) {
        self.pos = (self.pos.saturating_add(amt as u64)).min(self.buffer_len as u64);
    }
}
impl<R> Read for Decrypt<R>
where
    R: ReadSeek,
{
    /// Copies decrypted bytes into `buf`, at most up to the end of the
    /// current block, and returns how many bytes were copied.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while filling the internal buffer.
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // Reading from a byte slice copies `min(slice.len(), buf.len())`
        // bytes and cannot fail.
        let mut window = self.fill_buf()?;
        let copied = window.read(buf)?;
        self.consume(copied);
        Ok(copied)
    }
}