use crate::capsule::classifier::RedactingReader;
use crate::capsule::common::{CapsuleError, CapsuleTag, SpanTag};
use crate::capsule::policy_enforcer::{ColumnPolicyEnforcer, PolicyEnforcer};
use crate::capsule::util_readers::MutexReader;
use crate::capsule::CellIterator;
use std::{
io::{Error, ErrorKind, Read, Write},
marker::Send,
sync::{Arc, Mutex},
};
/// Bit of a frame header's flags byte (byte index 4) that marks the end of a cell.
const END_OF_CELL: u8 = 0b001;
/// Bit of a frame header's flags byte (byte index 4) that marks the end of a row.
const END_OF_ROW: u8 = 0b010;
/// Bit of a frame header's flags byte (byte index 4) that marks the end of the file.
const END_OF_FILE: u8 = 0b100;
/// Copies framed cell chunks from upstream streams onto `output`.
///
/// A frame is a 5-byte header (little-endian `u32` payload length, then a flags byte)
/// followed by the payload. The most recently read frame is held back in `tail_chunk` so its
/// flags can be adjusted before it is emitted. [`write_stream`](Self::write_stream) strips
/// `END_OF_FILE` from a stream's terminal frame, and [`flush_rows`](Self::flush_rows) decides
/// whether to put it back.
pub struct CellFrameWriter<W: Write> {
    output: W,
    /// Header plus payload of the held-back frame (always at least 5 bytes long).
    tail_chunk: Vec<u8>,
    /// `true` while `tail_chunk` holds a frame read by `write_stream` that has not been
    /// written yet. This is tracked explicitly because a zero-length frame is still a frame
    /// that must be forwarded, so `tail_chunk.len()` cannot be used to tell.
    tail_pending: bool,
    // NOTE(review): only ever reset to 0 in this file; appears unused.
    buffer_idx: usize,
}
impl<W: Write> CellFrameWriter<W> {
    /// Size of a frame header: 4-byte little-endian payload length + 1 flags byte.
    const HEADER_LEN: usize = 5;
    /// Offset of the flags byte inside a frame header.
    const FLAGS_IDX: usize = 4;

    /// Creates a writer that emits frames to `writer`.
    ///
    /// # Errors
    /// Never fails at present; the `Result` is kept for API compatibility.
    pub fn new(writer: W) -> Result<Self, CapsuleError> {
        Ok(Self {
            output: writer,
            tail_chunk: vec![0; Self::HEADER_LEN],
            tail_pending: false,
            buffer_idx: 0,
        })
    }

    /// Forwards frames from `stream` to the output. It stops after the first frame whose
    /// flags are exactly `END_OF_CELL | END_OF_ROW | END_OF_FILE`.
    ///
    /// - A frame still held back from a previous call is written first.
    /// - Every frame read here is written unchanged, including zero-length frames.
    /// - The exception is the terminal frame. It is held back with `END_OF_FILE` cleared, so
    ///   the next `write_stream` or `flush_rows` call emits it.
    ///
    /// Returns the number of bytes consumed from `stream`.
    ///
    /// # Errors
    /// Propagates I/O errors from `stream` (e.g. `UnexpectedEof` if it ends before the
    /// terminal frame) and from the output. Returns `InvalidData` if a frame length cannot be
    /// represented in memory. A frame that was only partially read when an error occurred is
    /// discarded rather than emitted later.
    pub fn write_stream<R: Read>(&mut self, mut stream: R) -> std::io::Result<u64> {
        let mut bytes_consumed: u64 = 0;
        loop {
            if self.tail_pending {
                self.output.write_all(&self.tail_chunk)?;
            }
            // From here on `tail_chunk` is being overwritten: it only becomes pending again
            // once a complete frame has been read into it.
            self.tail_pending = false;
            self.tail_chunk.truncate(Self::HEADER_LEN);

            stream.read_exact(&mut self.tail_chunk[..Self::HEADER_LEN])?;
            bytes_consumed += Self::HEADER_LEN as u64;

            let chunk_len = u32::from_le_bytes([
                self.tail_chunk[0],
                self.tail_chunk[1],
                self.tail_chunk[2],
                self.tail_chunk[3],
            ]);
            // Checked so a huge length cannot wrap around on 32-bit targets.
            let frame_len = usize::try_from(chunk_len)
                .ok()
                .and_then(|len| len.checked_add(Self::HEADER_LEN))
                .ok_or_else(|| {
                    Error::new(
                        ErrorKind::InvalidData,
                        format!("chunk length {} does not fit in memory", chunk_len),
                    )
                })?;
            self.tail_chunk.resize(frame_len, 0);
            stream.read_exact(&mut self.tail_chunk[Self::HEADER_LEN..])?;
            bytes_consumed += u64::from(chunk_len);
            self.tail_pending = true;

            if self.tail_chunk[Self::FLAGS_IDX] == (END_OF_CELL | END_OF_ROW | END_OF_FILE) {
                // Hold the terminal frame back without END_OF_FILE; flush_rows may restore it.
                self.tail_chunk[Self::FLAGS_IDX] = END_OF_CELL | END_OF_ROW;
                break;
            }
        }
        Ok(bytes_consumed)
    }

    /// Writes the held-back frame to the output, adding `END_OF_FILE` to its flags when
    /// `end_of_file` is `true`.
    ///
    /// If no frame is held back, a zero-length header is written instead; on a fresh writer
    /// that header is all zeros apart from the optional `END_OF_FILE` bit. Afterwards the
    /// writer returns to that same fresh state.
    ///
    /// # Errors
    /// Propagates errors from the output; state is left unchanged in that case.
    pub fn flush_rows(&mut self, end_of_file: bool) -> std::io::Result<()> {
        if end_of_file {
            self.tail_chunk[Self::FLAGS_IDX] |= END_OF_FILE;
        }
        self.output.write_all(&self.tail_chunk)?;
        // Reset to the zeroed header `new` starts with. Merely truncating would keep the old
        // frame's length bytes, and a second flush would then emit a header announcing a
        // payload that never follows.
        self.tail_chunk.clear();
        self.tail_chunk.resize(Self::HEADER_LEN, 0);
        self.tail_pending = false;
        self.buffer_idx = 0;
        Ok(())
    }
}
impl<W: Write> Write for CellFrameWriter<W> {
    /// Writes `buf` straight to the wrapped output, bypassing the held-back `tail_chunk`.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let sink = &mut self.output;
        sink.write(buf)
    }

    /// Flushes the wrapped output. This does not emit the held-back `tail_chunk`.
    fn flush(&mut self) -> std::io::Result<()> {
        Write::flush(&mut self.output)
    }
}
/// Reads the payload of a single cell from a shared framed stream.
struct InternalCellReader<R> {
    /// Shared framed stream; locked for each individual read.
    reader: Arc<Mutex<R>>,
    /// Payload bytes of the current frame not yet handed to the caller.
    chunk_bytes_remaining: usize,
    /// Latched once a frame carrying the end-of-cell flag has been seen.
    end_of_cell: bool,
    /// Latched once a frame carrying the end-of-row flag has been seen.
    pub end_of_row: bool,
    /// Latched once a frame carrying the end-of-file flag has been seen.
    pub end_of_file: bool,
}
impl<R: Read> InternalCellReader<R> {
    /// Creates a reader positioned before the next frame header on `input`.
    ///
    /// No payload is buffered and every end-of-* flag starts cleared.
    fn new(input: Arc<Mutex<R>>) -> Self {
        let (end_of_cell, end_of_row, end_of_file) = (false, false, false);
        Self {
            chunk_bytes_remaining: 0,
            end_of_cell,
            end_of_row,
            end_of_file,
            reader: input,
        }
    }
}
/// Reads one 5-byte frame header (little-endian `u32` payload length, then a flags byte) from
/// the shared `reader`.
///
/// Returns `Ok(None)` when the reader is exhausted before the first header byte.
///
/// Reads that fail with `ErrorKind::Interrupted` are retried here. Propagating them would
/// discard the header bytes already collected and desynchronise the frame stream.
///
/// # Errors
/// `ErrorKind::UnexpectedEof` if the reader ends part-way through the header; any other error
/// from the reader is returned unchanged.
fn read_cell_frame_header<R: Read>(reader: &Mutex<R>) -> std::io::Result<Option<[u8; 5]>> {
    let mut header = [0u8; 5];
    let mut filled = 0;
    while filled < header.len() {
        match reader.lock().unwrap().read(&mut header[filled..]) {
            Ok(0) if filled == 0 => return Ok(None),
            Ok(0) => {
                return Err(Error::new(
                    ErrorKind::UnexpectedEof,
                    "stream ended inside a chunk header",
                ))
            }
            Ok(n) => filled += n,
            Err(e) if e.kind() == ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
    Ok(Some(header))
}

impl<R: Read> Read for InternalCellReader<R> {
    /// Reads payload bytes of the current cell into `buf`.
    ///
    /// Frame headers are consumed from the shared reader on demand. The flags they carry are
    /// latched into `end_of_cell`, `end_of_row` and `end_of_file`.
    ///
    /// Returns `Ok(0)` in two cases:
    /// - the cell is finished: a frame flagged `END_OF_CELL` has been fully read;
    /// - the shared reader is exhausted exactly at a frame boundary.
    ///
    /// Zero-length frames that do not end the cell are skipped, so they are never mistaken
    /// for end of cell.
    ///
    /// # Errors
    /// - `UnexpectedEof` if the shared reader ends inside a frame header or payload.
    /// - Any other error from the shared reader.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
        loop {
            if self.chunk_bytes_remaining > 0 {
                let to_copy = std::cmp::min(buf.len(), self.chunk_bytes_remaining);
                let copied = self.reader.lock().unwrap().read(&mut buf[..to_copy])?;
                if copied == 0 && to_copy > 0 {
                    // The frame header promised more payload than the stream delivered.
                    return Err(Error::new(
                        ErrorKind::UnexpectedEof,
                        "stream ended inside a chunk payload",
                    ));
                }
                self.chunk_bytes_remaining -= copied;
                return Ok(copied);
            }
            if self.end_of_cell {
                return Ok(0);
            }

            let Some(header) = read_cell_frame_header(&self.reader)? else {
                return Ok(0);
            };
            let chunk_len =
                u32::from_le_bytes([header[0], header[1], header[2], header[3]]) as usize;
            let flags = header[4];
            self.end_of_cell |= (flags & END_OF_CELL) != 0;
            self.end_of_row |= (flags & END_OF_ROW) != 0;
            self.end_of_file |= (flags & END_OF_FILE) != 0;

            let to_read = std::cmp::min(buf.len(), chunk_len);
            self.reader
                .lock()
                .unwrap()
                .read_exact(&mut buf[..to_read])
                .map_err(|e| Error::other(format!("partial read of next chunk: {}", e)))?;
            self.chunk_bytes_remaining = chunk_len - to_read;

            if chunk_len == 0 && !self.end_of_cell && !buf.is_empty() {
                // An empty frame that does not end the cell carries nothing. Returning Ok(0)
                // here would look like end of cell to the caller, so go on to the next frame.
                continue;
            }
            return Ok(to_read);
        }
    }
}
/// Hands out one reader per cell of a framed capsule stream. When an enforcer is configured,
/// each cell is read through a `RedactingReader`.
pub struct CellDecoder<R: Read + Send + 'static, P: PolicyEnforcer + 'static> {
    /// Shared framed stream that every cell reader pulls from.
    reader: Arc<Mutex<R>>,
    /// Policy wrapper around the current cell reader; only created when `enforcer` is set.
    redacting_reader:
        Option<Arc<Mutex<RedactingReader<InternalCellReader<R>, ColumnPolicyEnforcer<P>>>>>,
    /// Optional policy enforcer; `None` means cells are returned unredacted.
    enforcer: Option<Arc<Mutex<P>>>,
    /// Tags passed to every `RedactingReader` this decoder creates.
    redact_tags: Vec<CapsuleTag>,
    /// Reader for the cell most recently handed out, if any.
    cell_reader: Option<Arc<Mutex<InternalCellReader<R>>>>,
    /// Span tags taken from the redacting reader of each allowed cell.
    span_tags: Vec<Vec<SpanTag>>,
    /// Column index given to the next `ColumnPolicyEnforcer`.
    column: usize,
    /// Set when the policy has denied the current record or the whole capsule.
    is_deny_record: bool,
    /// `true` once the end-of-file flag has been observed.
    pub end_of_file: bool,
    /// Count of records the policy denied.
    pub filtered_records: usize,
    /// Count incremented each time a cell's redacting reader allowed it.
    pub allowed_records: usize,
}
impl<R: Read + Send + 'static, P: PolicyEnforcer + 'static> CellDecoder<R, P> {
    /// Creates a decoder over the framed stream behind `reader`.
    ///
    /// When `enforcer` is `Some`, each cell is wrapped in a `RedactingReader` configured with
    /// `redact_tags`. Nothing is read until the first `next_cell` call.
    ///
    /// # Errors
    /// Never fails at present; the `Result` is part of the existing API.
    pub fn new(
        reader: Arc<Mutex<R>>,
        enforcer: Option<Arc<Mutex<P>>>,
        redact_tags: Vec<CapsuleTag>,
    ) -> Result<Self, CapsuleError> {
        Ok(Self {
            reader,
            enforcer,
            redact_tags,
            redacting_reader: None,
            cell_reader: None,
            column: 0,
            end_of_file: false,
            filtered_records: 0,
            allowed_records: 0,
            is_deny_record: false,
            span_tags: Vec::new(),
        })
    }

    /// Skips the rest of the current row on the shared stream.
    ///
    /// Drains the current cell, then keeps replacing it with a fresh cell reader and draining
    /// that, until a cell reports `end_of_row`. It sets `self.end_of_file` if the end-of-file
    /// flag is seen. It does nothing if no cell has been opened yet.
    ///
    /// # Errors
    /// - `CapsuleError::Generic` if reading from the shared stream fails.
    /// - `CapsuleError::Generic` if the shared stream is exhausted before the row ends. In
    ///   that case every fresh cell reader would read nothing and set no flag, so without this
    ///   check the loop would never terminate. `self.end_of_file` is set, since no more data
    ///   can follow.
    fn consume_row(&mut self) -> Result<(), CapsuleError> {
        let Some(shared_cell) = self.cell_reader.as_ref() else {
            return Ok(());
        };
        let mut cell_reader = shared_cell.lock().unwrap();
        while !cell_reader.end_of_row {
            if !cell_reader.end_of_cell {
                // Discard the remaining payload without buffering it.
                let drained = std::io::copy(&mut *cell_reader, &mut std::io::sink())
                    .map_err(|e| CapsuleError::Generic(format!("reading to end of cell: {}", e)))?;
                let made_progress = drained > 0
                    || cell_reader.end_of_cell
                    || cell_reader.end_of_row
                    || cell_reader.end_of_file;
                if !made_progress {
                    self.end_of_file = true;
                    return Err(CapsuleError::Generic(
                        "unexpected end of stream while skipping row".to_string(),
                    ));
                }
            }
            if cell_reader.end_of_file {
                self.end_of_file = true;
                return Ok(());
            }
            // Reuse the shared slot so any RedactingReader holding this Arc sees the new cell.
            *cell_reader = InternalCellReader::new(self.reader.clone());
        }
        if cell_reader.end_of_file {
            self.end_of_file = true;
        }
        Ok(())
    }
}
impl<R, P> CellIterator for CellDecoder<R, P>
where
    R: Read + Send + 'static,
    P: PolicyEnforcer + 'static,
{
    /// Returns a copy of the span tags collected from allowed cells so far.
    fn span_tags(&self) -> Vec<Vec<SpanTag>> {
        self.span_tags.to_vec()
    }

    /// Reports the current redacting reader's `is_deny_record` flag.
    ///
    /// Returns `false` when no redacting reader exists.
    fn is_deny_record(&self) -> bool {
        self.redacting_reader
            .as_ref()
            .is_some_and(|redactor| redactor.lock().unwrap().is_deny_record)
    }

    /// Finishes the previous cell (if any) and returns a reader for the next one.
    ///
    /// Before opening a new cell, the previous cell's policy verdict is applied:
    /// - A denied capsule yields `CapsuleAccessDeniedByPolicy`.
    /// - A denied record skips the rest of the row and yields `RowAccessDeniedByPolicy`.
    /// - Otherwise the previous cell is drained. `EndOfCapsule` or `EndOfRow` is returned if
    ///   its flags say so.
    ///
    /// The returned reader goes through a `RedactingReader` when an enforcer is configured.
    fn next_cell(&mut self) -> Result<Box<dyn Read + Send + 'static>, CapsuleError> {
        if let Some(previous) = self.cell_reader.as_ref() {
            let mut previous_cell = previous.lock().unwrap();
            if let Some(redactor) = self.redacting_reader.as_ref() {
                let mut verdict = redactor.lock().unwrap();
                if verdict.is_deny_capsule {
                    self.is_deny_record = true;
                    return Err(CapsuleError::CapsuleAccessDeniedByPolicy);
                }
                if verdict.is_deny_record {
                    self.is_deny_record = true;
                    self.filtered_records += 1;
                    // Release both locks: consume_row locks the cell reader itself.
                    drop(verdict);
                    drop(previous_cell);
                    self.consume_row()?;
                    return Err(CapsuleError::RowAccessDeniedByPolicy);
                }
                self.allowed_records += 1;
                let collected = std::mem::take(&mut verdict.span_tags);
                self.span_tags.push(collected);
            }
            self.column += 1;
            if !previous_cell.end_of_cell {
                let mut leftover: Vec<u8> = Vec::new();
                previous_cell
                    .read_to_end(&mut leftover)
                    .map_err(|e| CapsuleError::Generic(format!("reading to end of cell: {}", e)))?;
            }
            if previous_cell.end_of_file {
                self.end_of_file = true;
                return Err(CapsuleError::EndOfCapsule);
            }
            if previous_cell.end_of_row {
                return Err(CapsuleError::EndOfRow);
            }
        }

        let fresh_cell = Arc::new(Mutex::new(InternalCellReader::new(Arc::clone(
            &self.reader,
        ))));
        self.cell_reader = Some(Arc::clone(&fresh_cell));

        let Some(enforcer) = self.enforcer.as_ref() else {
            return Ok(Box::new(MutexReader { reader: fresh_cell }));
        };
        let redactor = Arc::new(Mutex::new(RedactingReader::<
            InternalCellReader<R>,
            ColumnPolicyEnforcer<P>,
        >::new(
            Arc::clone(&fresh_cell),
            ColumnPolicyEnforcer {
                enforcer: Arc::clone(enforcer),
                column: self.column,
            },
            self.redact_tags.clone(),
        )));
        self.redacting_reader = Some(Arc::clone(&redactor));
        Ok(Box::new(MutexReader { reader: redactor }))
    }

    /// Applies the last cell's policy verdict once iteration of the row stops.
    ///
    /// - A denied capsule marks the decoder as finished (`end_of_file`).
    /// - A denied record counts as filtered, and the rest of its row is skipped.
    fn cleanup(&mut self) -> Result<(), CapsuleError> {
        let Some(cell) = self.cell_reader.as_ref() else {
            return Ok(());
        };
        let cell_guard = cell.lock().unwrap();
        let Some(redactor) = self.redacting_reader.as_ref() else {
            return Ok(());
        };
        let verdict = redactor.lock().unwrap();
        if verdict.is_deny_capsule {
            self.is_deny_record = true;
            self.end_of_file = true;
        } else if verdict.is_deny_record {
            self.is_deny_record = true;
            self.filtered_records += 1;
            // Release both locks: consume_row locks the cell reader itself.
            drop(verdict);
            drop(cell_guard);
            self.consume_row()?;
        }
        Ok(())
    }
}