use std::{
io::{self, Cursor, Read, Seek, SeekFrom},
ops::Deref,
str::FromStr,
};
use blowfish::{cipher::BlockDecryptMut, cipher::KeyIvInit, Blowfish};
use cbc::cipher::block_padding::NoPadding;
use md5::{Digest, Md5};
use stream_download::{storage::temp::TempStorageProvider, StreamDownload};
use crate::{
error::{Error, Result},
protocol::media::Cipher,
track::{Track, TrackId},
};
/// A [`Read`] + [`Seek`] adapter that yields decrypted track data from a download.
///
/// With [`Cipher::NONE`] all I/O is passed straight through to `download`. With
/// [`Cipher::BF_CBC_STRIPE`] the stream is handled one fixed-size block at a time: a block is
/// loaded into `buffer` (and Blowfish-CBC decrypted when its index calls for it) and reads are
/// served from there.
pub struct Decrypt {
    /// Source of the raw (possibly encrypted) track bytes.
    download: StreamDownload<TempStorageProvider>,
    /// Size of the track file as reported by the track metadata, if known. Needed for
    /// `SeekFrom::End` and for rejecting seeks past the end.
    file_size: Option<u64>,
    /// Encryption scheme of the track; one of the supported ciphers checked in `new`.
    cipher: Cipher,
    /// Per-track key derived from the track id and a salt (see `key_for_track_id`).
    key: Key,
    /// Contents of the currently loaded block; the cursor position is the read offset within it.
    buffer: Cursor<Vec<u8>>,
    /// Index of the block held in `buffer`, or `None` before any block has been loaded.
    block: Option<u64>,
}
/// Length in bytes of a [`Key`].
pub const KEY_LENGTH: usize = 16;
/// Raw key material: exactly [`KEY_LENGTH`] bytes.
pub type RawKey = [u8; KEY_LENGTH];
/// A [`KEY_LENGTH`]-byte key. Used both for the salt passed to [`Decrypt::new`] and for the
/// per-track key returned by [`Decrypt::key_for_track_id`].
///
/// NOTE(review): the derived `Debug` prints the raw key bytes; consider a redacting impl if
/// keys can end up in logs.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct Key(RawKey);
impl FromStr for Key {
type Err = Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
let len = s.len();
if len != KEY_LENGTH {
return Err(Error::out_of_range(format!(
"key length is {len} but should be {KEY_LENGTH}",
)));
}
let bytes = s.as_bytes();
let mut key = [0; KEY_LENGTH];
key.copy_from_slice(bytes);
Ok(Self(key))
}
}
impl Deref for Key {
type Target = RawKey;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Decrypt {
    /// Initialisation vector used for every Blowfish-CBC block.
    const CBC_BF_IV: &[u8; 8] = b"\x00\x01\x02\x03\x04\x05\x06\x07";
    /// Size in bytes of one block of the stream.
    const CBC_BLOCK_SIZE: usize = 2 * 1024;
    /// Only full blocks whose index is a multiple of this value get decrypted.
    const CBC_STRIPE_COUNT: usize = 3;
    /// Ciphers this adapter can handle.
    const SUPPORTED_CIPHERS: [Cipher; 2] = [Cipher::NONE, Cipher::BF_CBC_STRIPE];

    /// Wraps `download` so that it reads as decrypted track data.
    ///
    /// The per-track key is derived from the track id and `salt` via
    /// [`Self::key_for_track_id`].
    ///
    /// # Errors
    ///
    /// Returns an unimplemented error if the track's cipher is neither `NONE` nor
    /// `BF_CBC_STRIPE`.
    pub fn new(
        track: &Track,
        download: StreamDownload<TempStorageProvider>,
        salt: &Key,
    ) -> Result<Self> {
        let cipher = track.cipher();
        if !Self::SUPPORTED_CIPHERS.contains(&cipher) {
            return Err(Error::unimplemented("unsupported encryption algorithm"));
        }

        Ok(Self {
            key: Self::key_for_track_id(track.id(), salt),
            file_size: track.file_size(),
            cipher,
            download,
            buffer: Cursor::new(Vec::new()),
            block: None,
        })
    }

    /// Derives the decryption key for `track_id`.
    ///
    /// The lowercase hex MD5 digest of the id's string form is split into two halves of
    /// [`KEY_LENGTH`] ASCII characters. Each output byte is the XOR of the corresponding
    /// characters of both halves and the matching `salt` byte.
    #[must_use]
    pub fn key_for_track_id(track_id: TrackId, salt: &Key) -> Key {
        let digest = Md5::digest(track_id.to_string());
        let hex = format!("{digest:x}");
        let (front, back) = hex.as_bytes().split_at(KEY_LENGTH);

        let mut raw = RawKey::default();
        for (((out, a), b), s) in raw.iter_mut().zip(front).zip(back).zip(salt.iter()) {
            *out = a ^ b ^ s;
        }
        Key(raw)
    }

    /// Number of bytes of the loaded block that have not been read yet.
    #[must_use]
    fn bytes_on_buffer(&self) -> u64 {
        let consumed = self.buffer.position();
        let total = self.buffer.get_ref().len() as u64;
        total.saturating_sub(consumed)
    }
}
impl Seek for Decrypt {
    /// Moves the read position within the decrypted stream and returns the new position.
    ///
    /// With `Cipher::NONE` this delegates to the download. Otherwise the target is resolved,
    /// the `CBC_BLOCK_SIZE`-byte block containing it is loaded (and decrypted when required)
    /// unless it is already buffered, and the buffer cursor is placed at the offset within it.
    ///
    /// `SeekFrom::End(n)` resolves to `file_size + n`, as in [`std::io::Seek`]. Therefore
    /// `SeekFrom::End(0)` is the end-of-file position. Seeking exactly to end-of-file is allowed,
    /// and subsequent reads return `0`.
    ///
    /// # Errors
    ///
    /// - `Unsupported` for `SeekFrom::End` when the file size is unknown.
    /// - `InvalidInput` when the target would be negative or overflow.
    /// - `UnexpectedEof` when the target lies beyond the known end of the file.
    /// - Any error from seeking/reading the download or from setting up/running the decryption.
    ///
    /// On error the previously loaded block, and the position within it, are left unchanged.
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        // Error for targets that are negative or do not fit in a `u64`.
        fn invalid_position() -> io::Error {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid seek to a negative or overflowing position",
            )
        }

        if self.cipher == Cipher::NONE {
            return self.download.seek(pos);
        }

        let block_size = Self::CBC_BLOCK_SIZE as u64;

        let target = match pos {
            SeekFrom::Start(pos) => pos,
            SeekFrom::End(pos) => {
                let file_size = self.file_size.ok_or_else(|| {
                    io::Error::new(
                        io::ErrorKind::Unsupported,
                        "cannot seek from the end of a stream with unknown size",
                    )
                })?;
                file_size.checked_add_signed(pos).ok_or_else(invalid_position)?
            }
            SeekFrom::Current(pos) => {
                let current = self
                    .block
                    .map_or(0, |block| block * block_size + self.buffer.position());
                current.checked_add_signed(pos).ok_or_else(invalid_position)?
            }
        };

        // Seeking *to* end-of-file must succeed: `Read::read` re-seeks to the current position
        // to pull in the next block, and at the end of the stream that position is `file_size`.
        if self.file_size.is_some_and(|file_size| target > file_size) {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "seek to a position beyond the end of the file",
            ));
        }

        let block = target / block_size;
        let offset = target % block_size;

        if self.block != Some(block) {
            self.download.seek(SeekFrom::Start(block * block_size))?;

            // A single `read` may legitimately return fewer bytes than requested. Keep reading
            // until the block is complete or the download hits EOF. Otherwise a short read is
            // mistaken for the final, partial block (left undecrypted and truncated).
            let mut data = Vec::with_capacity(Self::CBC_BLOCK_SIZE);
            Read::take(&mut self.download, block_size).read_to_end(&mut data)?;

            // Blocks whose index is a multiple of `CBC_STRIPE_COUNT` are decrypted, but only
            // when complete; a short (final) block is passed through as-is.
            let is_encrypted = block % Self::CBC_STRIPE_COUNT as u64 == 0;
            if is_encrypted && data.len() == Self::CBC_BLOCK_SIZE {
                let cipher =
                    cbc::Decryptor::<Blowfish>::new_from_slices(&*self.key, Self::CBC_BF_IV)
                        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
                cipher
                    .decrypt_padded_mut::<NoPadding>(&mut data)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
            }

            // Commit only after the block loaded successfully. This way a failure never leaves
            // `block` pointing at a block whose data is not what `buffer` holds.
            self.buffer = Cursor::new(data);
            self.block = Some(block);
        }

        self.buffer.set_position(offset);
        Ok(target)
    }
}
impl Read for Decrypt {
    /// Reads decrypted bytes into `buf` and returns how many were copied.
    ///
    /// With `Cipher::NONE` this delegates to the download. Otherwise bytes are served from the
    /// loaded block. Whenever that block is exhausted, the next one is loaded by re-seeking to
    /// the current position (see the [`Seek`] implementation). `0` is returned only when `buf`
    /// is empty or no more data is available.
    ///
    /// # Errors
    ///
    /// Propagates an error from loading the next block only if no bytes were copied during this
    /// call. Per the [`Read::read`] contract, an error must mean nothing was read. So when some
    /// bytes were already copied, their count is returned and a persistent error surfaces again
    /// on the next call.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if self.cipher == Cipher::NONE {
            return self.download.read(buf);
        }

        let mut bytes_read = 0;
        while bytes_read < buf.len() {
            if self.bytes_on_buffer() == 0 {
                // Re-seeking to the current position loads the block that follows the
                // exhausted one.
                if let Err(e) = self.stream_position() {
                    // Returning `Err` here would silently drop the bytes already copied.
                    return if bytes_read > 0 { Ok(bytes_read) } else { Err(e) };
                }
                if self.bytes_on_buffer() == 0 {
                    // No more data available.
                    break;
                }
            }

            let copied = self.buffer.read(&mut buf[bytes_read..])?;
            if copied == 0 {
                // Defensive: never spin if the buffer unexpectedly yields nothing.
                break;
            }
            bytes_read += copied;
        }

        Ok(bytes_read)
    }
}