use crate::capsule::{CapsuleError, CellIterator, SpanTag};
use serde::Serialize;
use std::{
io::{Read, Write},
sync::{Arc, Mutex},
};
/// Escape byte of the OBS framing; a literal occurrence in payload is doubled.
const CAPSULE_DELIMITER: u8 = u8::MAX;
/// Byte that, following an unpaired delimiter, marks the end of a frame.
const CAPSULE_END_OF_FRAME: u8 = u8::MIN;
/// Size, in bytes, of the staging window allocated by OBS framing adapters.
const OBS_WINDOW_SIZE: usize = 8 << 10;
/// A source that can throw away the rest of its current frame.
pub trait Discard {
    /// Consumes and drops data up to the end of the current frame.
    ///
    /// Implementations return a count of the bytes they skipped.
    ///
    /// # Errors
    /// Returns any I/O error raised while consuming the frame.
    fn skip_frame(&mut self) -> Result<usize, std::io::Error>;
}
/// A [`Read`] adapter that runs `callback` with the number of bytes read so far
/// each time the wrapped reader reports end of stream (`Ok(0)`).
///
/// An error returned by the callback becomes the result of that `read` call.
pub struct EOFCallbackReader<R: Read, F: Fn(usize) -> Result<(), std::io::Error>> {
    /// Wrapped source of bytes.
    input: R,
    /// Hook run on every end-of-stream read; receives `total_bytes_read`.
    callback: F,
    /// Running count of bytes returned to callers.
    total_bytes_read: usize,
}

impl<R: Read, F: Fn(usize) -> Result<(), std::io::Error>> EOFCallbackReader<R, F> {
    /// Wraps `input` so that `callback` runs when it reaches end of stream.
    pub fn new(input: R, callback: F) -> Self {
        Self {
            input,
            callback,
            total_bytes_read: 0,
        }
    }
}

impl<R: Read, F: Fn(usize) -> Result<(), std::io::Error>> Read for EOFCallbackReader<R, F> {
    /// Forwards to the wrapped reader, counting bytes and running the callback
    /// on end of stream.
    ///
    /// A zero-length `buf` returns `Ok(0)` without touching the wrapped reader
    /// or the callback. In that case `Ok(0)` does not mean end of stream, so it
    /// must not fire the end-of-stream hook with a partial count.
    ///
    /// # Errors
    /// Propagates errors from the wrapped reader and from `callback`.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        match self.input.read(buf)? {
            0 => {
                (self.callback)(self.total_bytes_read)?;
                Ok(0)
            }
            n => {
                self.total_bytes_read += n;
                Ok(n)
            }
        }
    }
}
/// A [`Write`] adapter that passes the running count of bytes accepted by the
/// wrapped writer to `callback` every time the stream is flushed.
pub struct EOFCallbackWriter<W: Write, F: FnMut(usize) -> Result<(), std::io::Error>> {
    /// Wrapped destination.
    output: W,
    /// Hook run at the start of every `flush`; receives `total_bytes_written`.
    callback: F,
    /// Running count of bytes the wrapped writer has accepted.
    total_bytes_written: usize,
}

impl<W: Write, F: FnMut(usize) -> Result<(), std::io::Error>> EOFCallbackWriter<W, F> {
    /// Wraps `output` so that `callback` sees the byte total on each flush.
    pub fn new(output: W, callback: F) -> Self {
        Self {
            output,
            callback,
            total_bytes_written: 0,
        }
    }
}

impl<W: Write, F: FnMut(usize) -> Result<(), std::io::Error>> Write for EOFCallbackWriter<W, F> {
    /// Forwards to the wrapped writer and adds the accepted byte count to the
    /// running total. Without this, the callback would always receive 0.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let n = self.output.write(buf)?;
        self.total_bytes_written += n;
        Ok(n)
    }

    /// Runs `callback` with the total so far, then flushes the wrapped writer.
    ///
    /// # Errors
    /// Returns the callback's error without flushing, or the wrapped writer's
    /// flush error.
    fn flush(&mut self) -> std::io::Result<()> {
        (self.callback)(self.total_bytes_written)?;
        self.output.flush()
    }
}
/// A [`Read`] adapter that yields all of `input`, followed by the CBOR
/// serialization (via `ciborium`) of the value produced by `get_content`.
///
/// `get_content` is invoked lazily: only once `input` has reported end of
/// stream on a non-empty read. Once it succeeds, it is not called again.
pub struct LazyEvaluatingReader<R: Read, S: Serialize, F: Fn() -> Result<S, std::io::Error>> {
    /// Primary byte source, drained first.
    input: R,
    /// Produces the trailing value once `input` is exhausted.
    get_content: F,
    /// Serialized trailing value; meaningful only when `content_computed`.
    content: Vec<u8>,
    /// How much of `content` has already been returned.
    content_offset: usize,
    /// Whether `content` holds a successfully serialized value.
    content_computed: bool,
    /// `S` otherwise appears only in a bound, which rustc rejects as an unused
    /// type parameter (E0392). `fn() -> S` keeps the reader's auto traits
    /// independent of `S`.
    _content_type: std::marker::PhantomData<fn() -> S>,
}

impl<R: Read, S: Serialize, F: Fn() -> Result<S, std::io::Error>> LazyEvaluatingReader<R, S, F> {
    /// Wraps `input`, deferring `get_content` until `input` is exhausted.
    pub fn new(input: R, get_content: F) -> Self {
        Self {
            input,
            get_content,
            content: Vec::new(),
            content_offset: 0,
            content_computed: false,
            _content_type: std::marker::PhantomData,
        }
    }
}

impl<R: Read, S: Serialize, F: Fn() -> Result<S, std::io::Error>> Read
    for LazyEvaluatingReader<R, S, F>
{
    /// Reads from `input` until it is exhausted, then from the serialized content.
    ///
    /// # Errors
    /// Propagates errors from `input` and `get_content`. A serialization failure
    /// becomes an `ErrorKind::Other` error. After any error from `get_content` or
    /// serialization, nothing is cached, so a later `read` (for example a retry
    /// after `Interrupted`) tries again rather than reporting a silent EOF.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
        // A zero-length read yields Ok(0) without meaning end of stream. It must
        // not trigger `get_content` before `input` has really been drained.
        if buf.is_empty() {
            return Ok(0);
        }
        let n = self.input.read(buf)?;
        if n > 0 {
            return Ok(n);
        }
        if !self.content_computed {
            let content = (self.get_content)()?;
            // Serialize into a scratch buffer so a failure part-way through
            // cannot leave a partial encoding behind for later reads.
            let mut encoded = Vec::new();
            ciborium::into_writer(&content, &mut encoded).map_err(|e| {
                std::io::Error::other(format!("serializing returned content: {}", e))
            })?;
            self.content = encoded;
            self.content_computed = true;
        }
        let remaining = &self.content[self.content_offset..];
        let to_copy = remaining.len().min(buf.len());
        buf[..to_copy].copy_from_slice(&remaining[..to_copy]);
        self.content_offset += to_copy;
        Ok(to_copy)
    }
}
pub struct OBSEscapeReader<R: Read> {
input: R,
delimiter: u8,
end_of_frame: u8,
window: Vec<u8>,
window_size: usize,
window_idx: usize,
window_tail: bool,
escaping: bool,
}
impl<R: Read> OBSEscapeReader<R> {
pub fn new(input: R) -> Self {
Self {
input,
delimiter: CAPSULE_DELIMITER,
end_of_frame: CAPSULE_END_OF_FRAME,
window: vec![0; 4096],
window_size: 0,
window_idx: 0,
window_tail: false,
escaping: false,
}
}
}
impl<R: Read> Read for OBSEscapeReader<R> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
if self.window_size == self.window_idx {
self.window_size = self.input.read(&mut self.window)?;
self.window_idx = 0;
}
if self.window_size == 0 && !self.window_tail {
self.window_tail = true;
self.window[0] = self.delimiter;
self.window[1] = self.end_of_frame;
self.window_size = 2;
}
let mut bytes_written = 0;
let mut iter = self.window[self.window_idx..self.window_size]
.iter()
.peekable();
while bytes_written < buf.len() {
if self.escaping {
buf[bytes_written] = self.delimiter;
bytes_written += 1;
self.escaping = false;
continue;
}
if let Some(&item) = iter.next() {
if item == self.delimiter && !self.window_tail {
self.escaping = true
}
buf[bytes_written] = item;
bytes_written += 1;
self.window_idx += 1;
} else {
break;
}
}
Ok(bytes_written)
}
}
/// A [`Read`] adapter that removes OBS escaping from `input`.
///
/// In the escaped stream, a delimiter followed by the end-of-frame byte marks
/// the end of a frame. A delimiter followed by any other byte stands for that
/// byte; an escaped delimiter is a doubled delimiter.
pub struct OBSReader<R: Read> {
    /// Source of escaped bytes.
    input: R,
    /// Escape byte.
    delimiter: u8,
    /// Byte that, after an unpaired delimiter, ends a frame.
    end_of_frame: u8,
    /// Staging buffer of raw bytes from `input`.
    window: Vec<u8>,
    /// Number of valid bytes in `window`; 0 after `input` reports end of stream.
    window_size: usize,
    /// Next unprocessed position in `window`.
    window_idx: usize,
    /// `true` after an unpaired delimiter has been consumed.
    escaping: bool,
    /// Raw (escaped) bytes pulled from `input`.
    bytes_read: usize,
    /// Unescaped bytes handed out by `read` or discarded by `skip_frame`.
    bytes_written: usize,
}

impl<R: Read> OBSReader<R> {
    /// Wraps `input`, using the module's capsule delimiter and end-of-frame bytes.
    pub fn new(input: R) -> Self {
        Self {
            input,
            delimiter: CAPSULE_DELIMITER,
            end_of_frame: CAPSULE_END_OF_FRAME,
            window: vec![0; OBS_WINDOW_SIZE],
            window_size: 0,
            window_idx: 0,
            escaping: false,
            bytes_read: 0,
            bytes_written: 0,
        }
    }

    /// Unescapes into `buf`, refilling the window from `input` as needed.
    ///
    /// Returns `(bytes_written, end_of_frame)`. For a non-empty `buf`, the
    /// result is `(0, false)` only when `input` is at end of stream.
    ///
    /// # Errors
    /// Propagates errors from `input`.
    fn filtered_read(&mut self, buf: &mut [u8]) -> std::io::Result<(usize, bool)> {
        loop {
            if self.window_size == self.window_idx {
                self.window_size = self.input.read(&mut self.window)?;
                self.bytes_read += self.window_size;
                self.window_idx = 0;
            }
            let (written, end_of_frame) = self.unescape_window(buf);
            // A window drained by a lone delimiter produces no output. Refill
            // rather than report Ok(0), which callers would treat as end of stream.
            let drained_without_progress =
                written == 0 && !end_of_frame && !buf.is_empty() && self.window_size != 0;
            if !drained_without_progress {
                return Ok((written, end_of_frame));
            }
        }
    }

    /// Unescapes bytes from the current window into `buf` without touching
    /// `input`. Returns `(bytes_written, end_of_frame)`.
    fn unescape_window(&mut self, buf: &mut [u8]) -> (usize, bool) {
        let mut written = 0;
        while written < buf.len() && self.window_idx < self.window_size {
            let byte = self.window[self.window_idx];
            self.window_idx += 1;
            if self.escaping {
                // Whatever follows an unpaired delimiter is consumed here, so the
                // escape state is cleared on the frame-end path too.
                self.escaping = false;
                if byte == self.end_of_frame {
                    return (written, true);
                }
                buf[written] = byte;
                written += 1;
            } else if byte == self.delimiter {
                self.escaping = true;
            } else {
                buf[written] = byte;
                written += 1;
            }
        }
        // If `buf` filled up just before an end-of-frame marker, consume the
        // marker now so the frame end is reported with this data. Skip this while
        // a delimiter is pending, because the next byte pairs with that delimiter.
        if !self.escaping
            && self.window_size - self.window_idx >= 2
            && self.window[self.window_idx] == self.delimiter
            && self.window[self.window_idx + 1] == self.end_of_frame
        {
            self.window_idx += 2;
            return (written, true);
        }
        (written, false)
    }
}

impl<R: Read> Read for OBSReader<R> {
    /// Reads unescaped bytes.
    ///
    /// NOTE(review): the end-of-frame flag is not surfaced here. `Ok(0)` marks a
    /// frame boundary only when the marker is the first thing found. When the
    /// marker follows some data, that data is returned and the next call
    /// continues past the marker. Confirm this is intended for multi-frame input.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let (n, _end_of_frame) = self.filtered_read(buf)?;
        self.bytes_written += n;
        Ok(n)
    }
}

impl<R: Read> Discard for OBSReader<R> {
    /// Consumes unescaped data up to and including the next end-of-frame
    /// marker. Returns how many unescaped bytes were discarded.
    ///
    /// # Errors
    /// Returns `UnexpectedEof` if `input` ends before an end-of-frame marker is
    /// seen. Otherwise propagates errors from `input`.
    fn skip_frame(&mut self) -> std::io::Result<usize> {
        let mut scratch = [0u8; 1024];
        let mut skipped = 0;
        loop {
            let (n, end_of_frame) = self.filtered_read(&mut scratch)?;
            skipped += n;
            // Count each chunk once; `filtered_read` already tracks raw input
            // in `bytes_read`.
            self.bytes_written += n;
            if end_of_frame {
                return Ok(skipped);
            }
            if n == 0 && self.window_size == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "input ended before end-of-frame marker",
                ));
            }
        }
    }
}
/// A [`Write`] adapter that applies OBS escaping to everything written.
///
/// Each delimiter byte in the payload is doubled. Every call to `flush`
/// appends a delimiter + end-of-frame trailer before flushing the inner
/// writer, so each flush closes one frame.
pub struct OBSEscapeWriter<W: Write> {
    /// Destination for escaped bytes.
    output: W,
    /// Escape byte; doubled when it appears in the payload.
    delimiter: u8,
    /// Byte that follows a lone delimiter to close a frame.
    end_of_frame: u8,
    /// Staging buffer for escaped bytes awaiting `output`.
    window: Vec<u8>,
    /// Number of staged bytes at the front of `window`.
    window_idx: usize,
    /// Total escaped bytes handed to `output`.
    bytes_written: usize,
    /// `true` when the second copy of a delimiter still has to be staged.
    escaping: bool,
}

impl<W: Write> OBSEscapeWriter<W> {
    /// Wraps `output`, using the module's capsule delimiter and end-of-frame bytes.
    pub fn new(output: W) -> Self {
        Self {
            escaping: false,
            bytes_written: 0,
            window_idx: 0,
            window: vec![0; OBS_WINDOW_SIZE],
            end_of_frame: CAPSULE_END_OF_FRAME,
            delimiter: CAPSULE_DELIMITER,
            output,
        }
    }

    /// Sends the whole staging window to `output` when it has no free slot.
    fn spill_if_full(&mut self) -> std::io::Result<()> {
        if self.window_idx < self.window.len() {
            return Ok(());
        }
        self.output.write_all(&self.window)?;
        self.bytes_written += self.window.len();
        self.window_idx = 0;
        Ok(())
    }

    /// Sends whatever is staged to `output` and empties the window.
    fn spill_staged(&mut self) -> std::io::Result<()> {
        if self.window_idx == 0 {
            return Ok(());
        }
        self.output.write_all(&self.window[..self.window_idx])?;
        self.bytes_written += self.window_idx;
        self.window_idx = 0;
        Ok(())
    }

    /// Places `byte` in the next free slot of the staging window.
    fn stage(&mut self, byte: u8) {
        self.window[self.window_idx] = byte;
        self.window_idx += 1;
    }
}

impl<W: Write> Write for OBSEscapeWriter<W> {
    /// Escapes all of `buf` and passes it on to `output`, returning `buf.len()`.
    ///
    /// NOTE(review): if `output` fails part-way, some of `buf` may already have
    /// been forwarded even though an error is returned.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let mut accepted: usize = 0;
        let mut remaining = buf.iter().copied();
        loop {
            self.spill_if_full()?;
            if self.escaping {
                // Second half of a doubled delimiter.
                self.escaping = false;
                let delimiter = self.delimiter;
                self.stage(delimiter);
                continue;
            }
            let Some(byte) = remaining.next() else {
                break;
            };
            self.escaping = byte == self.delimiter;
            self.stage(byte);
            accepted += 1;
        }
        self.spill_staged()?;
        Ok(accepted)
    }

    /// Writes the delimiter + end-of-frame trailer (preceded by any owed
    /// delimiter copy), then flushes `output`.
    fn flush(&mut self) -> std::io::Result<()> {
        self.window_idx = 0;
        let (delimiter, end_of_frame) = (self.delimiter, self.end_of_frame);
        if self.escaping {
            self.stage(delimiter);
        }
        self.stage(delimiter);
        self.stage(end_of_frame);
        self.spill_staged()?;
        self.output.flush()
    }
}
/// A [`Read`] handle onto a reader shared behind `Arc<Mutex<_>>`.
///
/// Each `read` holds the lock for the duration of that single call.
pub struct MutexReader<R> {
    /// Shared underlying reader.
    pub reader: Arc<Mutex<R>>,
}

impl<R: Read + Send> Read for MutexReader<R> {
    /// Locks the shared reader and forwards one `read` to it.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let mut guard = self.reader.lock().unwrap();
        guard.read(buf)
    }
}
impl<R: Read + Send + Discard> Discard for MutexReader<R> {
    /// Locks the shared reader and forwards `skip_frame` to it.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    fn skip_frame(&mut self) -> std::io::Result<usize> {
        let mut guard = self.reader.lock().unwrap();
        guard.skip_frame()
    }
}
/// A [`CellIterator`] that forwards every call to an iterator shared behind
/// `Arc<Mutex<_>>`. The lock is held only for the duration of each call.
pub struct MutexCellIterator<I: CellIterator> {
    /// Shared underlying iterator.
    pub it: Arc<Mutex<I>>,
}

/// Every method below panics if the mutex is poisoned.
impl<I: CellIterator> CellIterator for MutexCellIterator<I> {
    fn next_cell(&mut self) -> Result<Box<dyn Read + Send + 'static>, CapsuleError> {
        let mut inner = self.it.lock().unwrap();
        inner.next_cell()
    }

    fn is_deny_record(&self) -> bool {
        let inner = self.it.lock().unwrap();
        inner.is_deny_record()
    }

    fn span_tags(&self) -> Vec<Vec<SpanTag>> {
        let inner = self.it.lock().unwrap();
        inner.span_tags()
    }

    fn cleanup(&mut self) -> Result<(), CapsuleError> {
        let mut inner = self.it.lock().unwrap();
        inner.cleanup()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Raw payload shared by both round-trip tests; contains a doubled delimiter.
    const PAYLOAD: [u8; 9] = [1, 2, 3, 255, 255, 1, 4, 5, 1];
    /// `PAYLOAD` after escaping: each 0xFF doubled, then the `0xFF 0x00` trailer.
    const ESCAPED: [u8; 13] = [1, 2, 3, 255, 255, 255, 255, 1, 4, 5, 1, 255, 0];

    /// Unescapes `framed` with `OBSReader` and returns everything it yields.
    fn unescape(framed: Vec<u8>) -> Vec<u8> {
        let mut decoded = Vec::new();
        OBSReader::new(std::io::Cursor::new(framed))
            .read_to_end(&mut decoded)
            .expect("Read failed");
        decoded
    }

    #[test]
    fn test_escaping_reader() {
        let mut escaped = Vec::new();
        OBSEscapeReader::new(std::io::Cursor::new(PAYLOAD))
            .read_to_end(&mut escaped)
            .expect("Read failed");
        assert_eq!(escaped, ESCAPED);
        assert_eq!(unescape(escaped), PAYLOAD);
    }

    #[test]
    fn test_escaping_writer() {
        let mut escaped = Vec::new();
        let mut writer = OBSEscapeWriter::new(&mut escaped);
        writer.write_all(&PAYLOAD).expect("failed to write data");
        writer.flush().expect("failed to flush writer");
        assert_eq!(escaped, ESCAPED);
        assert_eq!(unescape(escaped), PAYLOAD);
    }
}