use crate::error::{LoftyError, Result};
use crate::util::math::F80;
use std::collections::VecDeque;
use std::fs::File;
use std::io::{Cursor, Read, Seek, SeekFrom, Write};
/// Extension of [`Seek`] that can report the total length of a stream.
pub(crate) trait SeekStreamLen: Seek {
    /// Returns the length of the stream in bytes.
    ///
    /// The length is measured by seeking to the end of the stream; afterwards the
    /// stream is moved back to the position it had when the call started.
    ///
    /// # Errors
    ///
    /// Any error from querying the position or from either seek is returned. If the
    /// final seek fails, the stream is left at its end.
    fn stream_len_hack(&mut self) -> crate::error::Result<u64> {
        use std::io::SeekFrom;

        // Remember where we are before jumping to the end
        let resume_at = SeekFrom::Start(self.stream_position()?);
        let total = self.seek(SeekFrom::End(0))?;
        self.seek(resume_at)?;

        Ok(total)
    }
}
// Every seekable type gets `SeekStreamLen` for free.
impl<T: Seek> SeekStreamLen for T {}
/// Storage that can be cut down to a given length.
///
/// Part of the bound set of [`FileLike`].
pub trait Truncate {
/// Error produced when truncation fails; must be convertible into [`LoftyError`].
type Error: Into<LoftyError>;
/// Truncates the storage to `new_len`.
///
/// # Errors
///
/// Returns `Self::Error` when the underlying storage cannot be truncated.
fn truncate(&mut self, new_len: u64) -> std::result::Result<(), Self::Error>;
}
/// Truncation of an on-disk file, delegated to [`File::set_len`].
impl Truncate for File {
    type Error = std::io::Error;

    /// Sets the length of the file to `new_len`.
    ///
    /// # Errors
    ///
    /// Whatever [`File::set_len`] reports.
    fn truncate(&mut self, new_len: u64) -> std::result::Result<(), Self::Error> {
        File::set_len(self, new_len)
    }
}
/// In-memory truncation of a byte vector; cannot fail.
impl Truncate for Vec<u8> {
    type Error = std::convert::Infallible;

    /// Shortens the vector to `new_len` bytes.
    ///
    /// A `new_len` that is not smaller than the current length leaves the vector
    /// untouched; the vector is never grown.
    fn truncate(&mut self, new_len: u64) -> std::result::Result<(), Self::Error> {
        // Compare in `u64` first: a bare `new_len as usize` wraps on targets where
        // `usize` is narrower than 64 bits and could wipe the buffer by accident.
        if new_len < Vec::len(self) as u64 {
            // Lossless: `new_len` is below the current length, which fits in `usize`
            Vec::truncate(self, new_len as usize);
        }

        Ok(())
    }
}
/// In-memory truncation of a byte deque; cannot fail.
impl Truncate for VecDeque<u8> {
    type Error = std::convert::Infallible;

    /// Shortens the deque to its first `new_len` bytes.
    ///
    /// A `new_len` that is not smaller than the current length leaves the deque
    /// untouched; the deque is never grown.
    fn truncate(&mut self, new_len: u64) -> std::result::Result<(), Self::Error> {
        // Compare in `u64` first: a bare `new_len as usize` wraps on targets where
        // `usize` is narrower than 64 bits and could wipe the buffer by accident.
        if new_len < VecDeque::len(self) as u64 {
            // Lossless: `new_len` is below the current length, which fits in `usize`
            VecDeque::truncate(self, new_len as usize);
        }

        Ok(())
    }
}
/// A cursor is truncated by truncating the storage it wraps.
impl<T: Truncate> Truncate for Cursor<T> {
    type Error = <T as Truncate>::Error;

    /// Forwards to the wrapped value's [`Truncate::truncate`].
    ///
    /// Only the wrapped storage is touched; this method does not adjust the
    /// cursor's position.
    fn truncate(&mut self, new_len: u64) -> std::result::Result<(), Self::Error> {
        let inner = self.get_mut();
        T::truncate(inner, new_len)
    }
}
/// A boxed value is truncated by truncating the value it owns.
impl<T: Truncate> Truncate for Box<T> {
    type Error = <T as Truncate>::Error;

    /// Forwards to the boxed value's [`Truncate::truncate`].
    fn truncate(&mut self, new_len: u64) -> std::result::Result<(), Self::Error> {
        T::truncate(&mut **self, new_len)
    }
}
/// A mutable reference is truncated by truncating the value it points to.
impl<T: Truncate> Truncate for &mut T {
    type Error = <T as Truncate>::Error;

    /// Forwards to the referenced value's [`Truncate::truncate`].
    fn truncate(&mut self, new_len: u64) -> std::result::Result<(), Self::Error> {
        T::truncate(&mut **self, new_len)
    }
}
/// Storage that can report its own length.
///
/// Part of the bound set of [`FileLike`].
pub trait Length {
/// Error produced when the length cannot be determined; must be convertible into [`LoftyError`].
type Error: Into<LoftyError>;
/// Returns the length of the storage.
///
/// # Errors
///
/// Returns `Self::Error` when the underlying storage cannot report its length.
fn len(&self) -> std::result::Result<u64, Self::Error>;
}
/// Length of an on-disk file, taken from its metadata.
impl Length for File {
    type Error = std::io::Error;

    /// Returns the file size reported by [`File::metadata`].
    ///
    /// # Errors
    ///
    /// Fails when the metadata cannot be queried.
    fn len(&self) -> std::result::Result<u64, Self::Error> {
        let metadata = self.metadata()?;
        Ok(metadata.len())
    }
}
/// Length of an in-memory byte vector; cannot fail.
impl Length for Vec<u8> {
    type Error = std::convert::Infallible;

    /// Returns the number of bytes currently stored in the vector.
    fn len(&self) -> std::result::Result<u64, Self::Error> {
        let byte_count = Vec::len(self);
        Ok(byte_count as u64)
    }
}
/// Length of an in-memory byte deque; cannot fail.
impl Length for VecDeque<u8> {
    type Error = std::convert::Infallible;

    /// Returns the number of bytes currently stored in the deque.
    fn len(&self) -> std::result::Result<u64, Self::Error> {
        let byte_count = VecDeque::len(self);
        Ok(byte_count as u64)
    }
}
impl<T> Length for Cursor<T>
where
T: Length,
{
type Error = <T as Length>::Error;
fn len(&self) -> std::result::Result<u64, Self::Error> {
Length::len(self.get_ref())
}
}
impl<T> Length for Box<T>
where
T: Length,
{
type Error = <T as Length>::Error;
fn len(&self) -> std::result::Result<u64, Self::Error> {
Length::len(self.as_ref())
}
}
impl<T> Length for &T
where
T: Length,
{
type Error = <T as Length>::Error;
fn len(&self) -> std::result::Result<u64, Self::Error> {
Length::len(*self)
}
}
impl<T> Length for &mut T
where
T: Length,
{
type Error = <T as Length>::Error;
fn len(&self) -> std::result::Result<u64, Self::Error> {
Length::len(*self)
}
}
/// Umbrella trait for anything that behaves like a file: readable, writable,
/// seekable, truncatable, and able to report its length.
///
/// It declares no methods of its own; it only bundles the supertraits, with both
/// associated error types convertible into [`LoftyError`].
pub trait FileLike: Read + Write + Seek + Truncate + Length
where
<Self as Truncate>::Error: Into<LoftyError>,
<Self as Length>::Error: Into<LoftyError>,
{
}
// Anything that satisfies the full bound set is automatically `FileLike`.
impl<T: Read + Write + Seek + Truncate + Length> FileLike for T
where
    <T as Truncate>::Error: Into<LoftyError>,
    <T as Length>::Error: Into<LoftyError>,
{
}
/// Crate-internal extension methods for [`Read`].
pub(crate) trait ReadExt: Read {
/// Reads an [`F80`] value from the stream.
fn read_f80(&mut self) -> Result<F80>;
}
impl<R: Read> ReadExt for R {
    /// Reads exactly 10 bytes and decodes them with [`F80::from_be_bytes`].
    ///
    /// # Errors
    ///
    /// Fails if 10 bytes cannot be read from the stream.
    fn read_f80(&mut self) -> Result<F80> {
        let mut raw = [0u8; 10];
        self.read_exact(&mut raw)?;

        let value = F80::from_be_bytes(raw);
        Ok(value)
    }
}
/// Where a [`RevPatternSearcher`] begins its backwards search.
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub(crate) enum RevSearchStart {
/// Start at the end of the stream (the default).
#[default]
FromEnd,
/// Start at the reader's position at the time the search is run.
FromCurrent,
}
/// Where a [`RevPatternSearcher`] stops its backwards search (the lowest offset it examines).
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub(crate) enum RevSearchEnd {
/// Search all the way back to offset 0.
StreamStart,
/// Stop at the reader's position at the time the search is run (the default).
#[default]
FromCurrent,
/// Stop at the given absolute stream offset.
Pos(u64),
}
/// Configurable backwards search for a byte pattern in a seekable reader.
///
/// Constructed through `ReadFindExt::rfind` and run with `search`.
pub(crate) struct RevPatternSearcher<'a, T> {
// Where the search begins
start: RevSearchStart,
// Where the search stops
end: RevSearchEnd,
// Size, in bytes, of the chunks read from the stream
buffer_size: u64,
// The byte sequence to look for
pattern: &'a [u8],
// The stream being searched
reader: &'a mut T,
}
impl<T> RevPatternSearcher<'_, T>
where
    T: Read + Seek,
{
    /// Sets the size, in bytes, of the chunks read from the stream while searching.
    ///
    /// A buffer smaller than the pattern can never hold a match, so
    /// [`search`](Self::search) reports "not found" for such a configuration.
    pub(crate) fn buffer_size(&mut self, buffer_size: u64) -> &mut Self {
        self.buffer_size = buffer_size;
        self
    }

    /// Sets where the search begins (the highest offset that is examined).
    pub(crate) fn start_pos(&mut self, start: RevSearchStart) -> &mut Self {
        self.start = start;
        self
    }

    /// Sets where the search stops (the lowest offset that is examined).
    pub(crate) fn end_pos(&mut self, end: RevSearchEnd) -> &mut Self {
        self.end = end;
        self
    }

    /// Searches backwards for the last occurrence of the pattern that lies entirely
    /// between the configured end and start positions.
    ///
    /// Returns `Ok(true)` with the reader positioned at the first byte of the match,
    /// or `Ok(false)` with the reader restored to the position it had when the
    /// search was started.
    ///
    /// An empty pattern matches trivially: `Ok(true)` is returned without touching
    /// the reader.
    ///
    /// # Errors
    ///
    /// Any seek or read failure is returned as-is; the reader position is then
    /// unspecified.
    pub(crate) fn search(&mut self) -> std::io::Result<bool> {
        if self.pattern.is_empty() {
            return Ok(true);
        }

        let original_pos = self.reader.stream_position()?;
        let pattern_len = self.pattern.len() as u64;

        let start_pos = match self.start {
            RevSearchStart::FromEnd => self.reader.seek(SeekFrom::End(0))?,
            RevSearchStart::FromCurrent => original_pos,
        };
        let end_pos = match self.end {
            RevSearchEnd::StreamStart => 0,
            RevSearchEnd::FromCurrent => original_pos,
            RevSearchEnd::Pos(p) => p,
        };

        // Nothing can match if the range is inverted, shorter than the pattern, or
        // if a single chunk cannot hold the whole pattern.
        let range_len = match start_pos.checked_sub(end_pos) {
            Some(len) if len >= pattern_len && self.buffer_size >= pattern_len => len,
            _ => {
                self.reader.seek(SeekFrom::Start(original_pos))?;
                return Ok(false);
            },
        };

        // Consecutive chunks overlap by `pattern_len - 1` bytes so that a match
        // straddling a chunk boundary is still seen in one piece.
        let overlap_step = self.buffer_size - (pattern_len - 1);

        // Never allocate more than the range that is actually searched
        let chunk_len = std::cmp::min(self.buffer_size, range_len);
        let mut buf = vec![0u8; chunk_len as usize];

        // Invariant: `current_pos - end_pos >= pattern_len` at the top of each iteration
        let mut current_pos = start_pos;
        loop {
            let remaining = current_pos - end_pos;
            let read_size = std::cmp::min(chunk_len, remaining);
            let read_start = current_pos - read_size;

            self.reader.seek(SeekFrom::Start(read_start))?;
            let window = &mut buf[..read_size as usize];
            self.reader.read_exact(window)?;

            // Last match inside this chunk, if any
            let last_match = window
                .windows(self.pattern.len())
                .rposition(|candidate| candidate == self.pattern);
            if let Some(match_offset) = last_match {
                self.reader
                    .seek(SeekFrom::Start(read_start + match_offset as u64))?;
                return Ok(true);
            }

            // This chunk already reached the lower bound; anything below it is out
            // of range and everything above it has been searched.
            if read_start == end_pos {
                break;
            }

            current_pos -= overlap_step;
        }

        // Not found: leave the reader where the caller had it, exactly like the
        // early-exit path above.
        self.reader.seek(SeekFrom::Start(original_pos))?;
        Ok(false)
    }
}
/// Adds pattern searching to any seekable reader.
pub(crate) trait ReadFindExt: Read + Seek + Sized {
    /// Prepares a backwards search for `pattern` in this reader.
    ///
    /// The returned searcher starts at [`RevSearchStart::default`], stops at the
    /// start of the stream, and reads in chunks of 1024 bytes. Nothing is read
    /// until the searcher is run.
    fn rfind<'a>(&'a mut self, pattern: &'a [u8]) -> RevPatternSearcher<'a, Self> {
        let start = RevSearchStart::default();
        // Deliberately the stream start rather than `RevSearchEnd`'s own default
        let end = RevSearchEnd::StreamStart;

        RevPatternSearcher {
            reader: self,
            pattern,
            buffer_size: 1024,
            start,
            end,
        }
    }
}
// Every seekable reader gets `rfind` for free.
impl<T: Read + Seek> ReadFindExt for T {}
// Unit tests: saving tags through the different storage types accepted by
// `save_to`, and the reverse pattern search.
#[cfg(test)]
mod tests {
use crate::config::{ParseOptions, WriteOptions};
use crate::file::AudioFile;
use crate::io::{ReadFindExt, RevSearchEnd, RevSearchStart};
use crate::mpeg::MpegFile;
use crate::tag::Accessor;
use std::io::{Cursor, Read, Seek, SeekFrom, Write};
use std::iter::repeat_n;
use std::ops::Neg;
// Asset used as the starting point (and expected end state) of the save tests
const TEST_ASSET: &str = "tests/files/assets/minimal/full_test.mp3";
// Reads the raw bytes of the test asset
fn test_asset_contents() -> Vec<u8> {
std::fs::read(TEST_ASSET).unwrap()
}
// Parses the test asset from an in-memory copy
fn file() -> MpegFile {
let file_contents = test_asset_contents();
let mut reader = Cursor::new(file_contents);
MpegFile::read_from(&mut reader, ParseOptions::new()).unwrap()
}
// Changes the ID3v2 artist to "Bar artist"
fn alter_tag(file: &mut MpegFile) {
let tag = file.id3v2_mut().unwrap();
tag.set_artist(String::from("Bar artist"));
}
// Sets the ID3v2 artist to "Foo artist"; the tests below expect this to
// reproduce the asset byte-for-byte once saved
fn revert_tag(file: &mut MpegFile) {
let tag = file.id3v2_mut().unwrap();
tag.set_artist(String::from("Foo artist"));
}
// Round-trip through a real (temporary) file: alter, save, re-read, revert,
// save, then compare against the untouched asset.
#[test_log::test]
fn io_save_to_file() {
let mut file = file();
alter_tag(&mut file);
let mut temp_file = tempfile::tempfile().unwrap();
let file_content = std::fs::read(TEST_ASSET).unwrap();
temp_file.write_all(&file_content).unwrap();
temp_file.rewind().unwrap();
file.save_to(&mut temp_file, WriteOptions::new().preferred_padding(0))
.expect("Failed to save to file");
temp_file.rewind().unwrap();
let mut file = MpegFile::read_from(&mut temp_file, ParseOptions::new()).unwrap();
revert_tag(&mut file);
temp_file.rewind().unwrap();
file.save_to(&mut temp_file, WriteOptions::new().preferred_padding(0))
.expect("Failed to save to file");
temp_file.rewind().unwrap();
let mut current_file_contents = Vec::new();
temp_file.read_to_end(&mut current_file_contents).unwrap();
assert_eq!(current_file_contents, test_asset_contents());
}
// Same round-trip, with a `Cursor` that owns its `Vec<u8>` as the storage
#[test_log::test]
fn io_save_to_vec() {
let mut file = file();
alter_tag(&mut file);
let file_content = std::fs::read(TEST_ASSET).unwrap();
let mut reader = Cursor::new(file_content);
file.save_to(&mut reader, WriteOptions::new().preferred_padding(0))
.expect("Failed to save to vec");
reader.rewind().unwrap();
let mut file = MpegFile::read_from(&mut reader, ParseOptions::new()).unwrap();
revert_tag(&mut file);
reader.rewind().unwrap();
file.save_to(&mut reader, WriteOptions::new().preferred_padding(0))
.expect("Failed to save to vec");
let current_file_contents = reader.into_inner();
assert_eq!(current_file_contents, test_asset_contents());
}
// Same round-trip, with cursors over borrowed buffers (`&mut Vec<u8>` for
// saving, `&[u8]` for reading)
#[test_log::test]
fn io_save_using_references() {
struct File {
buf: Vec<u8>,
}
let mut f = File {
buf: std::fs::read(TEST_ASSET).unwrap(),
};
let mut file = file();
alter_tag(&mut file);
{
let mut reader = Cursor::new(&mut f.buf);
file.save_to(&mut reader, WriteOptions::new().preferred_padding(0))
.expect("Failed to save to vec");
}
{
let mut reader = Cursor::new(&f.buf[..]);
file = MpegFile::read_from(&mut reader, ParseOptions::new()).unwrap();
revert_tag(&mut file);
}
{
let mut reader = Cursor::new(&mut f.buf);
file.save_to(&mut reader, WriteOptions::new().preferred_padding(0))
.expect("Failed to save to vec");
}
let current_file_contents = f.buf;
assert_eq!(current_file_contents, test_asset_contents());
}
// Reverse pattern search over several stream layouts
#[test_log::test]
fn rev_search() {
const PAT: &[u8] = b"PATTERN";
// Pattern at the very start, followed by far more data than one buffer holds
let mut data1 = PAT.to_vec();
data1.extend(repeat_n(0, 5000));
let mut stream1 = Cursor::new(data1);
assert!(stream1.rfind(PAT).search().unwrap());
// Stream slightly longer than the default 1024-byte buffer
let mut data2 = PAT.to_vec();
data2.extend(repeat_n(0, 1023));
let mut stream2 = Cursor::new(data2);
assert!(stream2.rfind(PAT).search().unwrap());
// Two occurrences: the later one must be reported
let mut data3 = PAT.to_vec();
let junk_len = 20;
data3.extend(repeat_n(0, junk_len));
data3.extend(PAT);
data3.extend(repeat_n(0, junk_len));
let last_occurence_offset = data3.len() - (junk_len + PAT.len());
let mut stream3 = Cursor::new(data3);
assert!(stream3.rfind(PAT).search().unwrap());
assert_eq!(stream3.position(), last_occurence_offset as u64);
// Three occurrences, searching from a position inside the last one: the
// last is cut off by the start position, so the middle one must be reported
let mut data4 = PAT.to_vec();
data4.extend(repeat_n(0, junk_len));
data4.extend(PAT);
data4.extend(repeat_n(0, junk_len));
data4.extend(PAT);
let middle_match_offset = PAT.len() + junk_len;
let mut stream4 = Cursor::new(data4);
stream4
.seek(SeekFrom::End(((PAT.len() - 3) as i64).neg()))
.unwrap();
assert!(
stream4
.rfind(PAT)
.start_pos(RevSearchStart::FromCurrent)
.end_pos(RevSearchEnd::StreamStart)
.search()
.unwrap()
);
assert_eq!(stream4.position(), middle_match_offset as u64);
}
}