antimatter 2.0.13

antimatter.io Rust library for data control
Documentation
use crate::capsule::classifier::RedactingReader;
use crate::capsule::common::{CapsuleError, CapsuleTag, SpanTag};
use crate::capsule::policy_enforcer::{ColumnPolicyEnforcer, PolicyEnforcer};
use crate::capsule::util_readers::MutexReader;
use crate::capsule::CellIterator;
use std::{
    io::{Error, ErrorKind, Read, Write},
    marker::Send,
    sync::{Arc, Mutex},
};

/// Bit 0 of the flags byte in a cell chunk header (output by CellFramer):
/// set when the cell ends with the current chunk.
const END_OF_CELL: u8 = 0b0000_0001;
/// Bit 1 of the flags byte in a cell chunk header (output by CellFramer):
/// set when the cell and the row both end with the current chunk.
const END_OF_ROW: u8 = 0b0000_0010;
/// Bit 2 of the flags byte in a cell chunk header (output by CellFramer):
/// set when the cell, the row, and the capsule all end with the current chunk.
const END_OF_FILE: u8 = 0b0000_0100;

/// Forwards framed cell chunks to `output`, holding back the most recently
/// read chunk so that the END_OF_FILE flag can be decided later.
///
/// Each chunk is a 5-byte header (`u32` little-endian payload length followed
/// by a `u8` flags byte) and then the payload.
pub struct CellFrameWriter<W: Write> {
    // Destination that receives the framed chunks.
    output: W,
    // tail_chunk holds the last chunk consumed from the classifier
    // (header followed by payload). We delay writing this out until either
    // more data arrives through write_stream(), or flush_rows() is called.
    // This allows us to control when the END_OF_FILE flag is added.
    tail_chunk: Vec<u8>,
    // NOTE(review): in the code visible in this file buffer_idx is only ever
    // reset to 0 and never read; it looks like a leftover — confirm.
    buffer_idx: usize,
}

// Size in bytes of a chunk header: | u32 (little endian): payload length | u8: flags |
const CHUNK_HEADER_LEN: usize = 5;
// Offset of the flags byte inside a chunk header.
const CHUNK_FLAGS_OFFSET: usize = 4;

impl<W: Write> CellFrameWriter<W> {
    /// Creates a writer that forwards framed chunks to `writer`.
    ///
    /// Nothing is buffered initially. Within this impl an empty `tail_chunk`
    /// means "no chunk is being held back"; a non-empty one always contains a
    /// complete chunk (header plus payload). Keeping the two states distinct
    /// is what allows zero-length chunks to be forwarded correctly.
    pub fn new(writer: W) -> Result<Self, CapsuleError> {
        Ok(Self {
            output: writer,
            tail_chunk: Vec::new(),
            buffer_idx: 0,
        })
    }

    /// Copies framed chunks from `stream` to the output until a chunk carrying
    /// the END_OF_FILE flag has been consumed.
    ///
    /// That final chunk has its END_OF_FILE flag stripped and is buffered
    /// instead of written: the end of one input stream does not have to be the
    /// true end of the output. The buffered chunk is written by the next call
    /// to `write_stream` (there is more data) or by `flush_rows` (which can
    /// restore the END_OF_FILE flag).
    ///
    /// Returns the number of bytes consumed from `stream`.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from `stream` or the output, including
    /// `UnexpectedEof` when `stream` ends before an END_OF_FILE chunk was seen.
    /// A chunk that could only be read partially is discarded, never buffered.
    pub fn write_stream<R: Read>(&mut self, mut stream: R) -> std::io::Result<u64> {
        let mut bytes_consumed: u64 = 0;
        loop {
            // Write out the previously read chunk, if any. This also covers
            // chunks with an empty payload (e.g. an empty cell), which must
            // not be lost.
            if !self.tail_chunk.is_empty() {
                self.output.write_all(&self.tail_chunk)?;
                self.tail_chunk.clear();
            }

            // Read the header into a local first so that a failed read leaves
            // the writer in the "nothing buffered" state.
            let mut header = [0u8; CHUNK_HEADER_LEN];
            stream.read_exact(&mut header)?;
            bytes_consumed += CHUNK_HEADER_LEN as u64;
            let [len0, len1, len2, len3, flags] = header;
            let chunk_len = u32::from_le_bytes([len0, len1, len2, len3]);

            // Guard against overflow of the buffer size on targets where
            // usize is 32 bits wide.
            let total_len = (chunk_len as usize)
                .checked_add(CHUNK_HEADER_LEN)
                .ok_or_else(|| {
                    Error::new(
                        ErrorKind::InvalidData,
                        format!("chunk length {} is too large to buffer", chunk_len),
                    )
                })?;

            self.tail_chunk.extend_from_slice(&header);
            self.tail_chunk.resize(total_len, 0);
            if let Err(e) = stream.read_exact(&mut self.tail_chunk[CHUNK_HEADER_LEN..]) {
                // Do not keep a half-read chunk: a later call would write it.
                self.tail_chunk.clear();
                return Err(e);
            }
            bytes_consumed += u64::from(chunk_len);

            if flags & END_OF_FILE != 0 {
                // This is the end of the current stream. Strip the EOF flag
                // and keep the chunk for later.
                self.tail_chunk[CHUNK_FLAGS_OFFSET] = flags & !END_OF_FILE;
                break;
            }
        }
        Ok(bytes_consumed)
    }

    /// Writes out the buffered chunk, adding the END_OF_FILE flag to it when
    /// `end_of_file` is true.
    ///
    /// When no chunk is buffered, a zero-length chunk header is written
    /// instead (carrying END_OF_FILE if requested), so a repeated flush never
    /// re-emits the length or flags of a chunk that was already written.
    pub fn flush_rows(&mut self, end_of_file: bool) -> std::io::Result<()> {
        if self.tail_chunk.is_empty() {
            let flags = if end_of_file { END_OF_FILE } else { 0 };
            self.output.write_all(&[0, 0, 0, 0, flags])?;
        } else {
            if end_of_file {
                self.tail_chunk[CHUNK_FLAGS_OFFSET] |= END_OF_FILE;
            }
            self.output.write_all(&self.tail_chunk)?;
            self.tail_chunk.clear();
        }
        self.buffer_idx = 0;
        Ok(())
    }
}

// Allows us to do a passthrough write to the underlying writer. Bytes written
// this way go straight to `output`; they are not framed and do not touch the
// buffered tail chunk.
impl<W: Write> Write for CellFrameWriter<W> {
    /// Hands `buf` directly to the wrapped writer and reports how many bytes
    /// it accepted.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        W::write(&mut self.output, buf)
    }

    /// Flushes the wrapped writer.
    fn flush(&mut self) -> std::io::Result<()> {
        W::flush(&mut self.output)
    }
}

// InternalCellReader is an implementation of Read which is able to read
// a Cell that was framed by the ClassifyingReader. This is needed since
// the cells are serialized by the output writer with additional context
// data. InternalCellReader will strip the cell frame headers as it
// encounters them, and will return EOF when the end of a cell is
// reached. Once it returns EOF, the end_of_row flag will also be set
// correctly to indicate whether the end of this cell also corresponds
// with the end of a row, which allows CellDecoder to correctly reconstruct
// the original table encoded using ClassifyingReader.
struct InternalCellReader<R> {
    // Underlying framed stream, shared with the owning decoder and with the
    // readers of the other cells.
    reader: Arc<Mutex<R>>,
    // Set once a chunk header carrying END_OF_CELL has been read.
    end_of_cell: bool,
    // Set once a chunk header carrying END_OF_ROW has been read.
    pub end_of_row: bool,
    // Set once a chunk header carrying END_OF_FILE has been read.
    pub end_of_file: bool,
    // Payload bytes of the current chunk that have not been handed out yet;
    // 0 means the next bytes in `reader` are a chunk header.
    chunk_bytes_remaining: usize,
}

impl<R: Read> InternalCellReader<R> {
    /// Wraps the shared stream `input` in a reader for a single cell.
    ///
    /// The new reader starts with every end-of-* flag cleared and with no
    /// chunk in progress, so its first read expects a chunk header.
    fn new(input: Arc<Mutex<R>>) -> Self {
        let reader = input;
        Self {
            chunk_bytes_remaining: 0,
            end_of_file: false,
            end_of_row: false,
            end_of_cell: false,
            reader,
        }
    }
}

// The implementation of Read relies on chunk_bytes_remaining in order to
// keep track of where the next frame header is expected. If a read doesn't
// consume an entire chunk, chunk_bytes_remaining is set to indicate that
// more bytes should be read directly from self.reader on the next call to
// read.
//
// Ok(0) is only returned when the caller's buffer is empty, when the end of
// the cell has been reached, when a zero-length chunk carries END_OF_FILE,
// or when the underlying stream ends cleanly on a chunk boundary. A stream
// that ends inside a header or inside a chunk payload is reported as an
// error instead of being mistaken for the end of the cell.
impl<R: Read> Read for InternalCellReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
        if self.chunk_bytes_remaining > 0 {
            // finish reading the current chunk if we haven't already. don't
            // bother going on to the next chunk until this one has been
            // consumed entirely.
            //
            // TODO(performance): continue on to populate buf from the
            // rest of self.reader when there is space available.
            let to_copy = std::cmp::min(buf.len(), self.chunk_bytes_remaining);
            let bytes_read = self.reader.lock().unwrap().read(&mut buf[..to_copy])?;
            if bytes_read == 0 && to_copy > 0 {
                // The header promised more payload than the stream contains.
                return Err(std::io::Error::new(
                    ErrorKind::UnexpectedEof,
                    "stream ended in the middle of a cell chunk",
                ));
            }
            self.chunk_bytes_remaining -= bytes_read;
            return Ok(bytes_read);
        }

        loop {
            if self.end_of_cell {
                return Ok(0);
            }

            // read header: | u32 (little endian): length | u8: flags |
            let mut header_buf = [0u8; 5];
            let mut header_bytes_read: usize = 0;
            while header_bytes_read < header_buf.len() {
                let result = self
                    .reader
                    .lock()
                    .unwrap()
                    .read(&mut header_buf[header_bytes_read..]);
                match result {
                    // clean end of the underlying stream on a chunk boundary
                    Ok(0) if header_bytes_read == 0 => return Ok(0),
                    Ok(0) => {
                        return Err(std::io::Error::new(
                            ErrorKind::UnexpectedEof,
                            "stream ended in the middle of a cell chunk header",
                        ))
                    }
                    Ok(n) => header_bytes_read += n,
                    // Retry here: returning would throw away the header bytes
                    // that were already consumed from the stream.
                    Err(e) if e.kind() == ErrorKind::Interrupted => continue,
                    Err(e) => return Err(e),
                }
            }

            let [len0, len1, len2, len3, flags] = header_buf;
            let chunk_len = u32::from_le_bytes([len0, len1, len2, len3]) as usize;

            if flags & END_OF_CELL > 0 {
                self.end_of_cell = true;
            }
            if flags & END_OF_ROW > 0 {
                self.end_of_row = true;
            }
            if flags & END_OF_FILE > 0 {
                self.end_of_file = true;
            }

            if chunk_len == 0 && !self.end_of_cell && !self.end_of_file {
                // An empty chunk in the middle of a cell carries no data.
                // Returning Ok(0) for it would look like the end of the cell
                // to the caller, so move on to the next chunk instead.
                continue;
            }

            // read as much of the chunk as fits and remember the rest in
            // chunk_bytes_remaining
            let to_read: usize = std::cmp::min(buf.len(), chunk_len);
            self.reader
                .lock()
                .unwrap()
                .read_exact(&mut buf[..to_read])
                .map_err(|e| {
                    std::io::Error::other(format!("partial read of next chunk: {}", e))
                })?;
            self.chunk_bytes_remaining = chunk_len - to_read;
            return Ok(to_read);
        }
    }
}

// CellDecoder is responsible for decoding the cells in a row. It hands out
// one reader per cell from a shared framed stream and, when a policy
// enforcer is supplied, wraps each cell in a RedactingReader.
pub struct CellDecoder<R, P>
where
    R: Read + Send + 'static,
    P: PolicyEnforcer + 'static,
{
    // Underlying framed stream; every per-cell reader shares it.
    reader: Arc<Mutex<R>>,
    // Redacting wrapper around the cell most recently returned by
    // next_cell(); None until the first cell, or when no enforcer is set.
    redacting_reader:
        Option<Arc<Mutex<RedactingReader<InternalCellReader<R>, ColumnPolicyEnforcer<P>>>>>,
    // Optional policy enforcer; redaction is only applied when this is Some.
    enforcer: Option<Arc<Mutex<P>>>,
    // Tags passed on to every RedactingReader that is created.
    redact_tags: Vec<CapsuleTag>,
    // Reader for the cell most recently returned by next_cell().
    cell_reader: Option<Arc<Mutex<InternalCellReader<R>>>>,
    // Index of the current column; incremented by next_cell() each time it
    // moves past a previously returned cell.
    column: usize,
    // Set once a chunk carrying END_OF_FILE has been seen (or, in cleanup(),
    // when the capsule is denied).
    pub end_of_file: bool,
    // Incremented each time a deny-record condition is observed.
    pub filtered_records: usize,
    // Incremented by next_cell() each time the previously returned cell was
    // neither capsule-denied nor record-denied.
    // NOTE(review): this counts per next_cell() call rather than per row —
    // confirm that is the intended meaning of "records".
    pub allowed_records: usize,
    // Set when a deny-capsule or deny-record condition is observed. The
    // is_deny_record() trait method does not read this field; it queries the
    // redacting reader instead.
    is_deny_record: bool,
    // Span tags collected from the redacting reader, one Vec per cell.
    span_tags: Vec<Vec<SpanTag>>,
}

impl<R: Read + Send + 'static, P: PolicyEnforcer + 'static> CellDecoder<R, P> {
    /// Creates a decoder over the shared framed stream `reader`.
    ///
    /// When `enforcer` is `Some`, every cell handed out is wrapped in a
    /// redacting reader that is given `redact_tags`; otherwise cells are
    /// returned as they are.
    pub fn new(
        reader: Arc<Mutex<R>>,
        enforcer: Option<Arc<Mutex<P>>>,
        redact_tags: Vec<CapsuleTag>,
    ) -> Result<Self, CapsuleError> {
        Ok(Self {
            reader,
            enforcer,
            redact_tags,
            redacting_reader: None,
            cell_reader: None,
            column: 0,
            end_of_file: false,
            filtered_records: 0,
            allowed_records: 0,
            is_deny_record: false,
            span_tags: Vec::new(),
        })
    }

    /// Discards the remainder of the current row, leaving the shared stream
    /// positioned just after the chunk that carries END_OF_ROW (or
    /// END_OF_FILE, in which case `self.end_of_file` is set).
    ///
    /// Does nothing when no cell has been handed out yet.
    ///
    /// # Errors
    ///
    /// Returns `CapsuleError::Generic` when reading fails, or when the stream
    /// ends before the end of the row was seen. Without the latter check a
    /// truncated stream would make this loop spin forever, because every
    /// fresh cell reader would report an empty cell with no flags set.
    fn consume_row(&mut self) -> Result<(), CapsuleError> {
        let Some(cell_mutex) = self.cell_reader.as_ref() else {
            return Ok(());
        };
        let mut cell_reader = cell_mutex.lock().unwrap();

        while !cell_reader.end_of_row {
            if !cell_reader.end_of_cell {
                // Drain the rest of the cell without buffering it in memory.
                std::io::copy(&mut *cell_reader, &mut std::io::sink()).map_err(|e| {
                    CapsuleError::Generic(format!("reading to end of cell: {}", e))
                })?;
            }
            if cell_reader.end_of_file {
                break;
            }
            if !cell_reader.end_of_cell {
                // The reader signalled EOF without ever seeing END_OF_CELL:
                // the underlying stream is exhausted mid-row.
                return Err(CapsuleError::Generic(
                    "reading to end of cell: stream ended before the end of the row".to_string(),
                ));
            }
            // Move on to the next cell of the row.
            *cell_reader = InternalCellReader::new(self.reader.clone());
        }

        if cell_reader.end_of_file {
            self.end_of_file = true;
        }
        Ok(())
    }
}

// CellIterator implementation: hands out the cells of a row one at a time
// and reports the policy outcome (deny capsule / deny record / span tags)
// observed on each cell once the caller moves past it.
impl<R: Read + Send + 'static, P: PolicyEnforcer + 'static> CellIterator for CellDecoder<R, P> {
    // Returns a copy of the span tags collected so far, one Vec per cell that
    // next_cell() has moved past without a deny outcome.
    fn span_tags(&self) -> Vec<Vec<SpanTag>> {
        self.span_tags.clone()
    }

    // Reports whether the redacting reader of the most recent cell has
    // flagged the record as denied. Always false when no redacting reader
    // exists (no cell handed out yet, or no enforcer configured).
    fn is_deny_record(&self) -> bool {
        if self.redacting_reader.is_some() {
            return self
                .redacting_reader
                .as_ref()
                .unwrap()
                .lock()
                .unwrap()
                .is_deny_record;
        }
        false
    }

    // Finishes the previously returned cell (if any) and returns a reader for
    // the next one.
    //
    // Errors:
    //   - CapsuleAccessDeniedByPolicy when the previous cell's redacting
    //     reader flagged the whole capsule as denied.
    //   - RowAccessDeniedByPolicy when it flagged the record as denied; the
    //     rest of the row has been consumed from the stream in that case.
    //   - EndOfCapsule / EndOfRow when the previous cell was the last one of
    //     the capsule / of the row.
    //   - Generic when draining the previous cell fails.
    fn next_cell(&mut self) -> Result<Box<dyn Read + Send + 'static>, CapsuleError> {
        // clean up and ensure we are at the beginning of the next cell
        if self.cell_reader.is_some() {
            let mut cell_reader = self.cell_reader.as_mut().unwrap().lock().unwrap();

            // collect deny capsule, deny record, and span tags
            if self.redacting_reader.is_some() {
                let mut redacting_reader = self.redacting_reader.as_mut().unwrap().lock().unwrap();
                if redacting_reader.is_deny_capsule {
                    self.is_deny_record = true;
                    return Err(CapsuleError::CapsuleAccessDeniedByPolicy);
                } else if redacting_reader.is_deny_record {
                    self.is_deny_record = true;
                    self.filtered_records += 1;
                    // advance reader to end of this row.
                    // Both guards have to be released first: consume_row()
                    // takes &mut self and locks cell_reader itself.
                    drop(redacting_reader);
                    drop(cell_reader);
                    self.consume_row()?;
                    return Err(CapsuleError::RowAccessDeniedByPolicy);
                } else {
                    // NOTE(review): this runs once per call, i.e. once per
                    // cell, not once per row — confirm intended semantics.
                    self.allowed_records += 1;
                    // Move the tags out of the reader instead of cloning them.
                    self.span_tags
                        .push(std::mem::take(&mut redacting_reader.span_tags));
                }
            }

            // advance to next column
            self.column += 1;

            // read to end of the last cell if the caller didn't
            if !cell_reader.end_of_cell {
                let mut remainder: Vec<u8> = Vec::new();
                cell_reader
                    .read_to_end(&mut remainder)
                    .map_err(|e| CapsuleError::Generic(format!("reading to end of cell: {}", e)))?;
            }
            // The flags below were set by the chunk header(s) of the cell that
            // was just finished.
            if cell_reader.end_of_file {
                self.end_of_file = true;
                return Err(CapsuleError::EndOfCapsule);
            }
            if cell_reader.end_of_row {
                return Err(CapsuleError::EndOfRow);
            }
        }

        // create a new internal cell reader for this cell
        self.cell_reader = Some(Arc::new(Mutex::new(InternalCellReader::new(
            self.reader.clone(),
        ))));

        // redaction, if enabled
        if self.enforcer.is_some() {
            // The redacting reader shares the cell reader (same Arc), so the
            // flags it causes to be set remain visible through
            // self.cell_reader.
            self.redacting_reader = Some(Arc::new(Mutex::new(RedactingReader::<
                InternalCellReader<R>,
                ColumnPolicyEnforcer<P>,
            >::new(
                self.cell_reader.clone().unwrap().clone(),
                ColumnPolicyEnforcer {
                    enforcer: self.enforcer.clone().unwrap().clone(),
                    column: self.column,
                },
                self.redact_tags.clone(),
            ))));

            Ok(Box::new(MutexReader {
                reader: self.redacting_reader.clone().unwrap().clone(),
            }))
        } else {
            Ok(Box::new(MutexReader {
                reader: self.cell_reader.clone().unwrap().clone(),
            }))
        }
    }

    fn cleanup(&mut self) -> Result<(), CapsuleError> {
        // This function is called at the end of a row to clean up in case
        // the read did not end up getting called due to an error or other
        // early exit.
        if self.cell_reader.is_some() {
            // Held only so the cell reader stays locked while the redacting
            // reader is inspected; released explicitly before consume_row().
            let cell_reader = self.cell_reader.as_mut().unwrap().lock().unwrap();

            // collect deny capsule and deny record (span tags are not
            // collected here)
            if self.redacting_reader.is_some() {
                let redacting_reader = self.redacting_reader.as_mut().unwrap().lock().unwrap();
                if redacting_reader.is_deny_capsule {
                    self.is_deny_record = true;
                    self.end_of_file = true;
                } else if redacting_reader.is_deny_record {
                    self.is_deny_record = true;
                    // NOTE(review): if next_cell() already returned
                    // RowAccessDeniedByPolicy for this row, the record was
                    // counted there too — confirm callers do not end up
                    // double-counting.
                    self.filtered_records += 1;
                    // advance reader to end of this row
                    drop(redacting_reader);
                    drop(cell_reader);
                    self.consume_row()?;
                }
            }
        }
        Ok(())
    }
}