use crate::capsule::common::{
CapsuleError, CapsuleTag, Column, PolicyDecision, RowReader, SpanTag,
};
use crate::capsule::framer::CellDecoder;
use crate::capsule::policy_enforcer::PolicyEnforcer;
use crate::capsule::{CellIterator, RowIterator};
use crate::session::hook_processor::HookProcessor;
use crate::session::policy_engine::PolicyEngine;
use crate::session::{DataHookInvoker, RUNTIME};
use serde_tuple::{Deserialize_tuple, Serialize_tuple};
use std::collections::{HashMap, VecDeque};
use std::io::{Error, Read, Seek, Write};
use std::marker::{Send, Sync};
use std::sync::{Arc, Mutex, RwLock};
use std::{mem, ops::DerefMut};
use tokio::task::JoinHandle;
/// Number of cell bytes carried by one classified chunk.
const CELL_CLASSIFIER_CHUNK_SIZE: usize = 16 << 10;
/// Extra bytes read past a chunk boundary so that spans straddling it can still be
/// classified; those bytes are re-emitted at the start of the following chunk.
const CELL_CLASSIFIER_OVERLAP: usize = 0x100;
/// Once this many framed bytes are buffered, no new classification requests are issued.
const READ_AHEAD_SIZE: usize = 4 * CELL_CLASSIFIER_CHUNK_SIZE;
/// Frame flag: this chunk is the last one of its cell.
const END_OF_CELL: u8 = 0b001;
/// Frame flag: this chunk completes the last cell of a row.
const END_OF_ROW: u8 = 0b010;
/// Frame flag: this chunk completes the last cell of the last row.
const END_OF_FILE: u8 = 0b100;
/// CBOR header written in front of every classified chunk of cell data.
///
/// Serialized as a tuple, so the field order below is part of the framed format.
#[derive(Deserialize_tuple, Serialize_tuple, Debug)]
pub struct ClassifyingReaderHeader {
/// Number of payload bytes following the header.
pub length: usize,
/// Span tags for the payload; offsets are relative to the start of this chunk.
pub tags: Vec<SpanTag>,
}
/// Policy outcome for the byte range `[start, end)` of a chunk.
#[derive(Clone, Debug)]
pub struct SpanPolicyDecision {
/// First byte of the span (chunk-relative, inclusive).
pub start: usize,
/// End of the span (chunk-relative, exclusive).
pub end: usize,
/// What to do with the bytes in the span.
pub decision: PolicyDecision,
}
/// A classification result for a batch of cells that may still be in progress.
trait Resolver {
/// Returns the classified cells, blocking until they are available if necessary.
fn resolve(&mut self) -> Result<Vec<CellMeta>, Error>;
/// Whether `resolve` can return without waiting.
fn is_finished(&self) -> bool;
}
/// Resolver backed by a classification task spawned on the shared `RUNTIME`.
struct ClassificationResolver {
    handle: JoinHandle<Result<Vec<CellMeta>, Error>>,
}

impl Resolver for ClassificationResolver {
    /// Blocks the calling thread until the classification task ends.
    ///
    /// A join failure (e.g. the task panicked) is reported as an `io::Error` of kind `Other`;
    /// otherwise the task's own result is passed through unchanged.
    fn resolve(&mut self) -> Result<Vec<CellMeta>, Error> {
        match RUNTIME.block_on(&mut self.handle) {
            Ok(task_result) => task_result,
            Err(join_error) => Err(Error::other(format!(
                "failed to join classification result: {}",
                join_error
            ))),
        }
    }

    /// Non-blocking check on whether the task has completed.
    fn is_finished(&self) -> bool {
        self.handle.is_finished()
    }
}
/// Resolver for cells that need no hook call; its result is available immediately.
struct SkipClassificationResolver {
    result: Vec<CellMeta>,
}

impl Resolver for SkipClassificationResolver {
    /// Hands back the stored cells, leaving the resolver empty.
    fn resolve(&mut self) -> Result<Vec<CellMeta>, Error> {
        let cells: Vec<CellMeta> = self.result.drain(..).collect();
        Ok(cells)
    }

    /// Always ready.
    fn is_finished(&self) -> bool {
        true
    }
}
/// Decides how a chunk of data may be exposed, given the span tags attached to it.
pub trait EnforcePolicy {
/// Evaluates policy for `data`, tagged with `span_tags`, with `redact_tags` as requested
/// by the reader.
///
/// Returns an overall decision for the chunk plus per-span decisions whose `start`/`end`
/// are offsets into `data`. `RedactingReader` fails the read on an overall `DenyRecord` /
/// `DenyCapsule`. It walks the span decisions sequentially, so they are presumably
/// expected sorted by `start` and non-overlapping (NOTE(review): confirm with implementors).
fn enforce(
&mut self,
span_tags: &[SpanTag],
redact_tags: &[CapsuleTag],
data: &[u8],
) -> Result<(PolicyDecision, Vec<SpanPolicyDecision>), CapsuleError>;
}
/// `Read` adapter that decodes framed chunks (CBOR `ClassifyingReaderHeader` followed by the
/// payload) from `input` and emits the payload with policy decisions applied.
///
/// Redacted spans are replaced by a redaction token, and deny decisions fail the read.
pub struct RedactingReader<R: Read, E: EnforcePolicy> {
/// Framed input stream, shared behind a mutex.
input: Arc<Mutex<R>>,
/// Policy enforcer consulted once per chunk.
enforcer: E,
/// Tags the caller asked to have redacted; forwarded to the enforcer for every chunk.
redact_tags: Vec<CapsuleTag>,
/// Header of the chunk currently being emitted.
current_header: ClassifyingReaderHeader,
/// Payload buffer of the current chunk.
current_chunk: Vec<u8>,
/// Index in `current_chunk` of the next byte not yet emitted or skipped.
current_chunk_offset: usize,
/// Number of bytes of the current chunk not yet emitted or skipped.
current_chunk_len: usize,
/// Span decisions the enforcer returned for the current chunk.
current_chunk_decisions: Vec<SpanPolicyDecision>,
/// Index of the next span decision to apply.
current_decision_index: usize,
/// Bytes written in place of a redacted span.
redact_token: Vec<u8>,
/// Non-zero when the redaction token was only partially written; where to resume it.
redact_token_offset: usize,
/// Whether the most recent output was a redaction token, so adjacent redacted spans collapse
/// into a single token.
previously_redacted: bool,
/// Set when a read failed because policy denied the record.
pub is_deny_record: bool,
/// Set when a read failed because policy denied the whole capsule.
pub is_deny_capsule: bool,
/// Span tags of everything read so far, rebased to stream offsets; same-tag spans that
/// overlap or touch are merged.
pub span_tags: Vec<SpanTag>,
/// Sum of the payload lengths of all headers read so far; used to rebase chunk-relative tags.
capsule_byte_offset: usize,
}
impl<R: Read, E: EnforcePolicy> RedactingReader<R, E> {
    /// Creates a reader over the framed stream `input`.
    ///
    /// Each chunk is checked with `enforcer`, and `redact_tags` is forwarded to it every time.
    /// Redacted spans are rendered as `{redacted}`.
    pub fn new(input: Arc<Mutex<R>>, enforcer: E, redact_tags: Vec<CapsuleTag>) -> Self {
        let empty_header = ClassifyingReaderHeader {
            length: 0,
            tags: Vec::new(),
        };
        let chunk_buffer = vec![0u8; CELL_CLASSIFIER_CHUNK_SIZE];
        let token = b"{redacted}".to_vec();
        Self {
            input,
            enforcer,
            redact_tags,
            current_header: empty_header,
            current_chunk: chunk_buffer,
            current_chunk_offset: 0,
            current_chunk_len: 0,
            current_chunk_decisions: Vec::new(),
            current_decision_index: 0,
            redact_token: token,
            redact_token_offset: 0,
            previously_redacted: false,
            is_deny_record: false,
            is_deny_capsule: false,
            span_tags: Vec::new(),
            capsule_byte_offset: 0,
        }
    }
}
impl<R: Read, E: EnforcePolicy> Read for RedactingReader<R, E> {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
let mut bytes_read: usize = 0;
if self.redact_token_offset > 0 {
let to_copy = std::cmp::min(
self.redact_token.len() - self.redact_token_offset,
buf.len(),
);
buf[..to_copy].copy_from_slice(
&self.redact_token[self.redact_token_offset..self.redact_token_offset + to_copy],
);
self.redact_token_offset =
(self.redact_token_offset + to_copy) % self.redact_token.len();
bytes_read += to_copy;
}
if self.current_chunk_len == 0 {
let mut next_byte = [0u8; 1];
if self.input.lock().unwrap().read(&mut next_byte)? == 0 {
return Ok(bytes_read);
}
self.current_header = ciborium::from_reader(
std::io::Cursor::new(next_byte).chain(self.input.lock().unwrap().deref_mut()),
)
.map_err(|e| {
Error::new(
std::io::ErrorKind::Other,
format!("reading header from input: {}", e),
)
})?;
for tag in &self.current_header.tags {
let mut cloned = tag.clone();
cloned.start += self.capsule_byte_offset;
cloned.end += self.capsule_byte_offset;
let mut tag_subsumed = false;
for t in self.span_tags.iter_mut() {
if t.tag == cloned.tag && t.start <= cloned.start && (t.end + 1) >= cloned.start
{
if t.end < cloned.end {
t.end = cloned.end;
}
tag_subsumed = true;
break;
}
}
if !tag_subsumed {
self.span_tags.push(cloned);
}
}
self.capsule_byte_offset += self.current_header.length;
self.current_chunk_len = self.current_header.length;
self.current_chunk_offset = 0;
self.input
.lock()
.unwrap()
.read_exact(&mut self.current_chunk[..self.current_chunk_len])
.map_err(|e| std::io::Error::other(format!("reading next chunk: {}", e)))?;
let (decision, span_decisions) = self
.enforcer
.enforce(
&self.current_header.tags,
&self.redact_tags,
&self.current_chunk[..self.current_chunk_len],
)
.map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("determining policy decision: {}", e),
)
})?;
if decision == PolicyDecision::DenyRecord {
self.is_deny_record = true;
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"record access denied by policy",
));
} else if decision == PolicyDecision::DenyCapsule {
self.is_deny_capsule = true;
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"capsule access denied by policy",
));
}
self.current_chunk_decisions = span_decisions;
}
while self.current_decision_index < self.current_chunk_decisions.len()
&& bytes_read < buf.len()
{
let next_decision = &self.current_chunk_decisions[self.current_decision_index];
let mut to_copy = std::cmp::min(
buf.len() - bytes_read,
next_decision.start - self.current_chunk_offset,
);
if to_copy > 0 {
self.previously_redacted = false;
buf[bytes_read..bytes_read + to_copy].copy_from_slice(
&self.current_chunk
[self.current_chunk_offset..self.current_chunk_offset + to_copy],
);
bytes_read += to_copy;
self.current_chunk_len -= to_copy;
self.current_chunk_offset += to_copy;
}
if bytes_read >= buf.len() {
return Ok(bytes_read);
}
match next_decision.decision {
PolicyDecision::DenyRecord => {
self.is_deny_record = true;
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"record access denied by policy",
));
}
PolicyDecision::DenyCapsule => {
self.is_deny_capsule = true;
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"capsule access denied by policy",
));
}
PolicyDecision::Redact => {
self.current_chunk_len -= next_decision.end - next_decision.start;
self.current_chunk_offset += next_decision.end - next_decision.start;
if !self.previously_redacted {
self.previously_redacted = true;
to_copy = std::cmp::min(self.redact_token.len(), buf.len() - bytes_read);
buf[bytes_read..bytes_read + to_copy]
.copy_from_slice(&self.redact_token[..to_copy]);
bytes_read += to_copy;
if to_copy < self.redact_token.len() {
self.redact_token_offset = to_copy;
}
}
}
_ => {
self.previously_redacted = false;
to_copy = std::cmp::min(
buf.len() - bytes_read,
next_decision.end - next_decision.start,
);
buf[bytes_read..bytes_read + to_copy].copy_from_slice(
&self.current_chunk
[self.current_chunk_offset..self.current_chunk_offset + to_copy],
);
bytes_read += to_copy;
self.current_chunk_len -= to_copy;
self.current_chunk_offset += to_copy;
}
}
self.current_decision_index += 1;
}
let to_copy = std::cmp::min(self.current_chunk_len, buf.len() - bytes_read);
if to_copy > 0 {
buf[bytes_read..bytes_read + to_copy].copy_from_slice(
&self.current_chunk[self.current_chunk_offset..self.current_chunk_offset + to_copy],
);
bytes_read += to_copy;
self.current_chunk_offset += to_copy;
self.current_chunk_len -= to_copy;
}
Ok(bytes_read)
}
}
/// Source of classification results for cell data; one instance serves one column.
pub trait ProcessHooks: Send + Sync {
/// Classifies `content` (a chunk of one cell, overlap included) and returns span tags whose
/// offsets are relative to `content`. `path` receives the cell's column name.
fn get_span_tags(&self, content: &[u8], path: &str) -> Result<Vec<SpanTag>, CapsuleError>;
/// Whether any classification hooks are configured; when `false`, cells bypass classification.
fn has_classification_hooks(&self) -> bool;
}
/// One chunk of a single cell travelling through the classification pipeline.
#[derive(Clone)]
struct CellMeta {
    /// Row index of the cell.
    row: usize,
    /// Column index of the cell.
    col: usize,
    /// Chunk bytes; only the first `length` bytes are meaningful until the buffer is truncated.
    data: Vec<u8>,
    /// Number of valid bytes in `data`.
    length: usize,
    /// Column name, passed to the hook processor as the classification path.
    name: String,
    /// Whether the column opts out of classification.
    skip_classification: bool,
    /// Span tags attached to this chunk (chunk-relative offsets).
    tags: Vec<SpanTag>,
    /// `END_OF_CELL` / `END_OF_ROW` / `END_OF_FILE` bits for the frame.
    flags: u8,
}

impl Default for CellMeta {
    /// An empty chunk for cell (0, 0) whose zeroed buffer can hold a full chunk plus overlap.
    fn default() -> Self {
        let buffer_len = CELL_CLASSIFIER_CHUNK_SIZE + CELL_CLASSIFIER_OVERLAP;
        CellMeta {
            row: 0,
            col: 0,
            data: vec![0; buffer_len],
            length: 0,
            name: String::new(),
            skip_classification: false,
            tags: Vec::new(),
            flags: 0,
        }
    }
}
/// `Read` adapter that turns rows of cell readers into a framed stream of classified chunks.
///
/// Cells are read ahead in row-major order and cut into chunks of at most
/// `CELL_CLASSIFIER_CHUNK_SIZE` bytes. Each chunk is classified by its column's hook processor,
/// with several requests kept in flight. Chunks are emitted in order as frames: a length
/// prefix, a flags byte, the CBOR `ClassifyingReaderHeader`, then the payload.
pub struct ClassifyingReader<I: ProcessHooks> {
/// Input rows; each holds one cell reader per column.
rows: Vec<RowReader>,
/// Column index of the cell currently being read ahead.
ra_cell_col: usize,
/// Row index of the cell currently being read ahead.
ra_cell_row: usize,
/// Offset within the current cell at which the next chunk starts.
current_cell_idx: usize,
/// Column definitions (name, column tags, skip-classification flag).
columns: Vec<Column>,
/// Total payload bytes framed so far.
input_bytes_read: usize,
/// Hook processors, indexed by column.
processors: Vec<Arc<RwLock<I>>>,
/// Remainders of span tags that crossed the previous chunk boundary; applied to the next chunk.
next_chunk_tags: Vec<SpanTag>,
/// Maximum number of outstanding classification requests.
maximum_in_flight: usize,
/// Outstanding classification requests, in emission order.
in_flight: VecDeque<Box<dyn Resolver + Send>>,
/// Framed output waiting to be read.
read_ahead_buffer: Vec<u8>,
/// Set once every cell has been read ahead.
eof: bool,
/// Bytes past the chunk boundary of the last chunk; prepended to the next chunk.
overlap: Vec<u8>,
}
impl<I: ProcessHooks + 'static> ClassifyingReader<I> {
/// Creates a reader that frames and classifies `rows`, laid out according to `columns`.
///
/// `processors[i]` classifies column `i`. Empty input (no rows or no columns) yields an
/// empty stream; previously the first read indexed out of bounds.
pub fn new(
    rows: Vec<RowReader>,
    columns: Vec<Column>,
    processors: Vec<Arc<RwLock<I>>>,
) -> Self {
    // With nothing to read, start at end of input so no cell is ever indexed.
    let nothing_to_read = rows.is_empty() || columns.is_empty();
    Self {
        rows,
        columns,
        ra_cell_col: 0,
        ra_cell_row: 0,
        input_bytes_read: 0,
        processors,
        next_chunk_tags: Vec::new(),
        maximum_in_flight: 32,
        in_flight: VecDeque::new(),
        read_ahead_buffer: Vec::new(),
        eof: nothing_to_read,
        overlap: Vec::new(),
        current_cell_idx: 0,
    }
}
/// Moves the read-ahead cursor to the next cell in row-major order.
///
/// Resets the within-cell offset. Returns `END_OF_ROW` when the cursor wraps to a new row,
/// plus `END_OF_FILE` (also setting `eof`) once it moves past the last row.
fn increment_cell(&mut self) -> u8 {
    self.current_cell_idx = 0;
    self.ra_cell_col += 1;
    let wrapped_row = self.ra_cell_col >= self.columns.len();
    if wrapped_row {
        self.ra_cell_col = 0;
        self.ra_cell_row += 1;
    }
    let past_last_row = self.ra_cell_row >= self.rows.len();
    if past_last_row {
        self.eof = true;
    }
    let row_bit = if wrapped_row { END_OF_ROW } else { 0 };
    let file_bit = if past_last_row { END_OF_FILE } else { 0 };
    row_bit | file_bit
}
/// Rebases a cell's user tags onto the chunk currently being built.
///
/// `user_tags` are positioned relative to the start of the cell. The chunk covers cell bytes
/// `[current_cell_idx, current_cell_idx + bytes_read)`. Tags overlapping that window are
/// clipped to it and shifted to chunk-relative offsets; tags outside it are dropped.
fn add_user_tags(&self, user_tags: &[SpanTag], bytes_read: usize) -> Vec<SpanTag> {
    let window_start = self.current_cell_idx;
    let window_end = window_start + bytes_read;
    user_tags
        .iter()
        .filter(|tag| tag.start < window_end && tag.end > window_start)
        .map(|tag| {
            let mut chunk_tag = tag.clone();
            chunk_tag.start = tag.start.max(window_start) - window_start;
            chunk_tag.end = tag.end.min(window_end) - window_start;
            chunk_tag
        })
        .collect()
}
fn build_request_batch(&mut self) -> Result<Vec<CellMeta>, Error> {
let mut bytes_read: usize = 0;
let mut cells_in_flight: Vec<CellMeta> = Vec::new();
let target_chunk_size = CELL_CLASSIFIER_CHUNK_SIZE + CELL_CLASSIFIER_OVERLAP;
let mut current_cell = CellMeta {
row: self.ra_cell_row,
col: self.ra_cell_col,
data: vec![0; CELL_CLASSIFIER_CHUNK_SIZE + CELL_CLASSIFIER_OVERLAP],
length: self.overlap.len(),
skip_classification: self.columns[self.ra_cell_col].skip_classification,
name: self.columns[self.ra_cell_col].name.clone(),
tags: vec![],
flags: 0,
};
if current_cell.length > 0 {
bytes_read += current_cell.length;
current_cell.data[..bytes_read].copy_from_slice(&self.overlap[..]);
self.overlap = Vec::new();
}
while bytes_read < target_chunk_size && !self.eof {
match self.rows[self.ra_cell_row].cells[self.ra_cell_col]
.read(&mut current_cell.data[current_cell.length..])
{
Ok(0) => {
current_cell.data.truncate(current_cell.length);
current_cell.tags.append(
&mut self.add_user_tags(
&self.rows[self.ra_cell_row].cells[self.ra_cell_col]
.tags
.clone(),
current_cell.length,
),
);
current_cell.tags.extend(capsule_tag_to_span_tags(
self.columns[self.ra_cell_col].tags.clone(),
0,
current_cell.length,
));
current_cell.tags.extend(capsule_tag_to_span_tags(
self.rows[self.ra_cell_row].tags.clone(),
0,
current_cell.length,
));
self.current_cell_idx += current_cell.length;
if current_cell.length > CELL_CLASSIFIER_CHUNK_SIZE {
self.overlap
.extend_from_slice(¤t_cell.data[CELL_CLASSIFIER_CHUNK_SIZE..]);
self.current_cell_idx -= current_cell.length - CELL_CLASSIFIER_CHUNK_SIZE;
} else {
current_cell.flags |= END_OF_CELL;
current_cell.flags |= self.increment_cell();
}
cells_in_flight.push(mem::take(&mut current_cell));
if bytes_read > CELL_CLASSIFIER_CHUNK_SIZE
|| current_cell.length > CELL_CLASSIFIER_CHUNK_SIZE
|| self.eof
{
break;
}
current_cell.col = self.ra_cell_col;
current_cell.row = self.ra_cell_row;
current_cell.skip_classification =
self.columns[current_cell.col].skip_classification;
current_cell.name = self.columns[current_cell.col].name.clone();
current_cell.tags = self.rows[self.ra_cell_row].cells[self.ra_cell_col]
.tags
.clone();
}
Ok(n) => {
current_cell.length += n;
bytes_read += n;
if bytes_read >= target_chunk_size {
current_cell.data.truncate(current_cell.length);
current_cell.tags.append(
&mut self.add_user_tags(
&self.rows[self.ra_cell_row].cells[self.ra_cell_col]
.tags
.clone(),
current_cell.length,
),
);
current_cell.tags.extend(capsule_tag_to_span_tags(
self.columns[self.ra_cell_col].tags.clone(),
0,
current_cell.length,
));
current_cell.tags.extend(capsule_tag_to_span_tags(
self.rows[self.ra_cell_row].tags.clone(),
0,
current_cell.length,
));
self.current_cell_idx += current_cell.length;
if current_cell.length > CELL_CLASSIFIER_CHUNK_SIZE {
self.overlap.extend_from_slice(
¤t_cell.data[CELL_CLASSIFIER_CHUNK_SIZE..],
);
self.current_cell_idx -=
current_cell.length - CELL_CLASSIFIER_CHUNK_SIZE;
}
cells_in_flight.push(current_cell);
break;
}
}
Err(e) => return Err(e),
}
}
Ok(cells_in_flight)
}
/// Schedules classification of `cell` with the hook processor of its column.
///
/// If that processor has no classification hooks, the cell is passed through like
/// `sow_without_classification`. Otherwise the hooks run on `RUNTIME`'s blocking pool against
/// the whole chunk, overlap included. The returned span tags are appended to the cell before
/// it is cut back to `CELL_CLASSIFIER_CHUNK_SIZE` bytes.
fn sow_with_classification(&mut self, mut cell: CellMeta) -> Box<dyn Resolver + Send> {
    let processor = Arc::clone(&self.processors[cell.col]);
    let has_hooks = processor.read().unwrap().has_classification_hooks();
    if !has_hooks {
        return self.sow_without_classification(cell);
    }
    Box::new(ClassificationResolver {
        handle: RUNTIME.spawn_blocking(move || {
            // The task owns the cell, so its buffers can be borrowed directly; copying the
            // chunk per request is unnecessary.
            let tags = processor
                .read()
                .unwrap()
                .get_span_tags(&cell.data, &cell.name)
                .map_err(|e| {
                    Error::new(
                        std::io::ErrorKind::Other,
                        format!("getting span tags: {}", e),
                    )
                })?;
            cell.tags.extend(tags);
            if cell.length > CELL_CLASSIFIER_CHUNK_SIZE {
                cell.length = CELL_CLASSIFIER_CHUNK_SIZE;
                cell.data.truncate(CELL_CLASSIFIER_CHUNK_SIZE);
            }
            Ok(vec![cell])
        }),
    })
}
/// Wraps `cell` in an already-resolved result, trimmed to `CELL_CLASSIFIER_CHUNK_SIZE` bytes.
fn sow_without_classification(&mut self, mut cell: CellMeta) -> Box<dyn Resolver + Send> {
    let limit = CELL_CLASSIFIER_CHUNK_SIZE;
    if cell.length > limit {
        cell.data.truncate(limit);
        cell.length = limit;
    }
    let ready = SkipClassificationResolver { result: vec![cell] };
    Box::new(ready)
}
/// Dispatches `cell`, with or without classification as its column requests, and queues
/// the resulting resolver behind those already in flight.
fn send_request(&mut self, cell: CellMeta) {
    let resolver = if cell.skip_classification {
        self.sow_without_classification(cell)
    } else {
        self.sow_with_classification(cell)
    };
    self.in_flight.push_back(resolver);
}
/// Issues classification requests for new batches while there is room.
///
/// Requests stop when `maximum_in_flight` requests are outstanding, at end of input, or when
/// `READ_AHEAD_SIZE` bytes are already buffered. At end of input, any leftover `overlap` is
/// sent as a final chunk of the last cell, flagged as the end of cell, row and file.
fn send_requests(&mut self) -> Result<(), std::io::Error> {
    loop {
        let can_issue = self.in_flight.len() < self.maximum_in_flight
            && !self.eof
            && self.read_ahead_buffer.len() < READ_AHEAD_SIZE;
        if !can_issue {
            break;
        }
        for cell in self.build_request_batch()? {
            self.send_request(cell);
        }
    }
    if !self.eof || self.overlap.is_empty() {
        return Ok(());
    }

    let tail_len = self.overlap.len();
    let row = self.rows.len() - 1;
    let col = self.columns.len() - 1;
    let tail_data = mem::take(&mut self.overlap);
    let user_tags = self.rows[row].cells[col].tags.clone();
    let mut tags = self.add_user_tags(&user_tags, tail_len);
    tags.extend(capsule_tag_to_span_tags(
        self.columns[col].tags.clone(),
        0,
        tail_len,
    ));
    tags.extend(capsule_tag_to_span_tags(
        self.rows[row].tags.clone(),
        0,
        tail_len,
    ));
    let tail = CellMeta {
        row,
        col,
        data: tail_data,
        length: tail_len,
        name: self.columns[col].name.clone(),
        skip_classification: self.columns[col].skip_classification,
        tags,
        flags: END_OF_CELL | END_OF_ROW | END_OF_FILE,
    };
    self.send_request(tail);
    Ok(())
}
/// Produces the final tag list for a chunk of `data_len` bytes.
///
/// Tag remainders carried from the previous chunk are included first. Tags starting at or
/// past `CELL_CLASSIFIER_CHUNK_SIZE` (inside the overlap region) are dropped. Tags crossing
/// that boundary are cut there, and their remainder is carried to the next chunk. All other
/// tags are clamped to `data_len`. Also adds `data_len` to `input_bytes_read`.
fn collect_tags(&mut self, mut span_tags: Vec<SpanTag>, data_len: usize) -> Vec<SpanTag> {
    span_tags.append(&mut self.next_chunk_tags);
    let mut kept = Vec::with_capacity(span_tags.len());
    for tag in span_tags {
        if tag.start >= CELL_CLASSIFIER_CHUNK_SIZE {
            continue;
        }
        if tag.end > CELL_CLASSIFIER_CHUNK_SIZE {
            self.next_chunk_tags.push(SpanTag {
                start: 0,
                end: tag.end - CELL_CLASSIFIER_CHUNK_SIZE,
                tag: tag.tag.clone(),
            });
            kept.push(SpanTag {
                start: tag.start,
                end: CELL_CLASSIFIER_CHUNK_SIZE,
                tag: tag.tag,
            });
        } else {
            kept.push(SpanTag {
                start: tag.start,
                end: tag.end.min(data_len),
                tag: tag.tag,
            });
        }
    }
    self.input_bytes_read += data_len;
    kept
}
/// Frames resolved cells into `read_ahead_buffer`.
///
/// Each frame is a little-endian `u32` holding the size of header plus payload, then one
/// flags byte, then the CBOR `ClassifyingReaderHeader`, then the payload bytes.
///
/// # Errors
/// Fails if the header cannot be serialized or a frame does not fit the `u32` length prefix.
fn pack_reaped_response(&mut self, cells: Vec<CellMeta>) -> Result<(), Error> {
    for mut cell in cells {
        let length = cell.data.len();
        let tags = self.collect_tags(mem::take(&mut cell.tags), length);
        let mut header = Vec::new();
        ciborium::into_writer(&ClassifyingReaderHeader { length, tags }, &mut header)
            .map_err(|e| Error::other(format!("failed to serialize chunk header: {}", e)))?;
        let frame_len = header.len() + length;
        // Guard the length prefix instead of silently truncating it with `as`.
        if frame_len > u32::MAX as usize {
            return Err(Error::other(format!(
                "chunk frame too large: {} bytes",
                frame_len
            )));
        }
        self.read_ahead_buffer
            .extend_from_slice(&(frame_len as u32).to_le_bytes());
        self.read_ahead_buffer.push(cell.flags);
        self.read_ahead_buffer.append(&mut header);
        // Move the payload rather than cloning it; the cell is discarded afterwards.
        self.read_ahead_buffer.append(&mut cell.data);
    }
    Ok(())
}
/// Moves completed classification results into the read-ahead buffer, in request order.
///
/// Waits for the oldest request if necessary, then keeps taking requests only while they
/// are already finished. Stops once the buffer reaches a quarter of `READ_AHEAD_SIZE`.
fn reap_in_flight(&mut self) -> Result<(), Error> {
    let low_water = READ_AHEAD_SIZE >> 2;
    let mut waited = false;
    while self.read_ahead_buffer.len() < low_water {
        match self.in_flight.front() {
            None => break,
            Some(oldest) if waited && !oldest.is_finished() => break,
            Some(_) => {}
        }
        waited = true;
        let mut resolver = self
            .in_flight
            .pop_front()
            .expect("front() just returned a resolver");
        let cells = resolver.resolve()?;
        self.pack_reaped_response(cells)?;
    }
    Ok(())
}
}
impl<I: ProcessHooks + 'static> Read for ClassifyingReader<I> {
    /// Serves framed, classified chunks while keeping classification requests in flight
    /// ahead of the consumer.
    ///
    /// Returns `Ok(0)` once every cell has been framed and the buffer has been drained.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
        if self.in_flight.is_empty() {
            self.send_requests()?;
        }
        let drained =
            self.eof && self.in_flight.is_empty() && self.read_ahead_buffer.is_empty();
        if drained {
            return Ok(0);
        }
        self.reap_in_flight()?;
        self.send_requests()?;
        let n = buf.len().min(self.read_ahead_buffer.len());
        buf[..n].copy_from_slice(&self.read_ahead_buffer[..n]);
        self.read_ahead_buffer.drain(..n);
        Ok(n)
    }
}
/// Applies every capsule-level tag to the byte range `[start, end)`.
///
/// Produces one span tag per input tag, preserving order.
fn capsule_tag_to_span_tags(
    capsule_tags: Vec<CapsuleTag>,
    start: usize,
    end: usize,
) -> Vec<SpanTag> {
    capsule_tags
        .into_iter()
        .map(|tag| SpanTag { tag, start, end })
        .collect()
}
/// Row source that classifies its input once, buffers the framed result in a temporary file,
/// and then decodes it row by row with policy enforcement.
pub struct ClassifyAndRedact<P: PolicyEnforcer + 'static> {
/// Column layout of the buffered data.
columns: Vec<Column>,
/// Deduplicated capsule tags: the caller-supplied ones plus those collated by the hooks.
capsule_tags: Vec<CapsuleTag>,
/// Enforcer shared with every row decoder.
enforcer: Arc<Mutex<P>>,
/// Temporary file holding the framed stream produced by `ClassifyingReader`.
buffer_file: Arc<Mutex<std::fs::File>>,
}
impl<P: PolicyEnforcer + 'static> ClassifyAndRedact<P> {
    /// Classifies every cell of `data` up front and buffers the framed result in an anonymous
    /// temporary file. Then builds the enforcer used when rows are read back.
    ///
    /// The enforcer receives `capsule_tags` plus the tags each hook processor collated during
    /// classification, deduplicated by (name, source, hook_version). The first occurrence of
    /// each is kept.
    ///
    /// # Errors
    /// Returns `CapsuleError::Generic` when buffering or classification fails. Tag-conversion
    /// and enforcer-initialisation errors are propagated.
    pub fn new<I: DataHookInvoker + 'static>(
        columns: Vec<Column>,
        mut capsule_tags: Vec<CapsuleTag>,
        hook_processors: Vec<Arc<RwLock<HookProcessor<I>>>>,
        read_parameters: HashMap<String, String>,
        policy_engine: Arc<Mutex<PolicyEngine>>,
        data: Vec<RowReader>,
    ) -> Result<Self, CapsuleError> {
        let mut classifier =
            ClassifyingReader::new(data, columns.clone(), hook_processors.clone());
        let mut buffer_file = tempfile::tempfile().map_err(|e| {
            CapsuleError::Generic(format!("creating temporary file for buffering: {}", e))
        })?;
        std::io::copy(&mut classifier, &mut buffer_file)
            .map_err(|e| CapsuleError::Generic(format!("classifying input data: {}", e)))?;
        buffer_file
            .flush()
            .map_err(|e| CapsuleError::Generic(format!("flushing buffer file: {}", e)))?;
        buffer_file.seek(std::io::SeekFrom::Start(0)).map_err(|e| {
            CapsuleError::Generic(format!("seeking to beginning of buffer file: {}", e))
        })?;

        // The hooks collate capsule-level tags while classifying, so gather them afterwards.
        for processor in &hook_processors {
            let processor = processor.read().unwrap();
            let collated = processor.collated_capsule_tags.lock().unwrap();
            let converted = collated
                .iter()
                .map(|tag| CapsuleTag::from_tag(tag))
                .collect::<Result<Vec<CapsuleTag>, CapsuleError>>()?;
            capsule_tags.extend(converted);
        }

        let mut unique_capsule_tags: Vec<CapsuleTag> = Vec::with_capacity(capsule_tags.len());
        for candidate in capsule_tags {
            let already_seen = unique_capsule_tags.iter().any(|kept: &CapsuleTag| {
                kept.name == candidate.name
                    && kept.source == candidate.source
                    && kept.hook_version == candidate.hook_version
            });
            if !already_seen {
                unique_capsule_tags.push(candidate);
            }
        }

        let column_tags = columns.iter().map(|column| column.tags.clone()).collect();
        let enforcer = P::init_enforcer(
            Some(policy_engine),
            unique_capsule_tags.clone(),
            column_tags,
            read_parameters,
            HashMap::new(),
        )?;
        Ok(Self {
            columns,
            capsule_tags: unique_capsule_tags,
            enforcer: Arc::new(Mutex::new(enforcer)),
            buffer_file: Arc::new(Mutex::new(buffer_file)),
        })
    }
}
impl<P: PolicyEnforcer + 'static> RowIterator for ClassifyAndRedact<P> {
fn for_each_row(
&mut self,
redact_tags: &[CapsuleTag],
f: &mut dyn FnMut(&mut dyn CellIterator) -> Result<(), CapsuleError>,
) -> Result<(), CapsuleError> {
self.buffer_file
.lock()
.unwrap()
.seek(std::io::SeekFrom::Start(0))
.map_err(|e| {
CapsuleError::Generic(format!("seeking to beginning of buffer file: {}", e))
})?;
self.for_each_row_default(redact_tags, f)
}
fn next_row(
&mut self,
redact_tags: Vec<CapsuleTag>,
) -> Result<Box<dyn CellIterator + 'static>, CapsuleError> {
{
let mut buffer_file = self.buffer_file.lock().unwrap();
let current_position = buffer_file
.seek(std::io::SeekFrom::Current(0))
.map_err(|e| CapsuleError::Generic(format!("seeking file: {}", e)))?;
let mut buffer = [0u8; 1];
let bytes_read = buffer_file
.read(&mut buffer)
.map_err(|e| CapsuleError::Generic(format!("reading next byte: {}", e)))?;
if bytes_read == 0 {
return Err(CapsuleError::EndOfCapsule);
}
buffer_file
.seek(std::io::SeekFrom::Start(current_position))
.map_err(|e| CapsuleError::Generic(format!("seeking file: {}", e)))?;
}
Ok(Box::new(CellDecoder::new(
self.buffer_file.clone(),
Some(self.enforcer.clone()),
redact_tags,
)?))
}
fn domain_id(&self) -> String {
"unknown".to_string()
}
fn extra_data(&self) -> String {
"".to_string()
}
fn capsule_ids(&self) -> Vec<String> {
Vec::new()
}
fn capsule_tags(&self) -> Vec<CapsuleTag> {
self.capsule_tags.clone()
}
fn columns(&self) -> Vec<Column> {
self.columns.clone()
}
fn open_failures(&self) -> Vec<String> {
Vec::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::capsule::classifier::{ClassifyingReader, ClassifyingReaderHeader, ProcessHooks};
use crate::capsule::common::{CellReader, Column, TagType, KEY_SIZE, NONCE_SIZE};
use crate::capsule::policy_enforcer::DefaultPolicyEnforcer;
use crate::session::session::SessionError;
use crate::session::DataHookInvoker;
use antimatter_api::apis::configuration;
use antimatter_api::models::{DataTaggingHookInput, DataTaggingHookResponse};
use rand::distributions::{Alphanumeric, DistString};
use rand::RngCore;
use std::io::Cursor;
use std::sync::RwLock;
struct MockHookInvoker {
    pub results: Vec<Result<DataTaggingHookResponse, SessionError>>,
    pub current_index: RwLock<usize>,
}
impl DataHookInvoker for MockHookInvoker {
    /// Replays the canned `results` in order, one per call. Once they run out, every call
    /// returns an empty response with version "0.0.0".
    fn invoke_hook(
        &self,
        _configuration: &configuration::Configuration,
        _domain_id: &str,
        _write_context_name: Option<&str>,
        _hook_name: &str,
        _data_tagging_hook_input: DataTaggingHookInput,
    ) -> Result<DataTaggingHookResponse, SessionError> {
        let mut cursor = self.current_index.write().unwrap();
        match self.results.get(*cursor) {
            None => Ok(DataTaggingHookResponse {
                records: vec![],
                version: "0.0.0".to_string(),
            }),
            Some(canned) => {
                *cursor += 1;
                canned.clone()
            }
        }
    }
}
struct MockHookProcessor {
    pub results: Vec<Result<Vec<SpanTag>, CapsuleError>>,
    pub current_index: RwLock<usize>,
}
impl ProcessHooks for MockHookProcessor {
    /// Replays the canned `results` in order, one per call; afterwards returns no tags.
    fn get_span_tags(
        &self,
        _content: &[u8],
        _path: &str,
    ) -> Result<Vec<SpanTag>, CapsuleError> {
        let mut cursor = self.current_index.write().unwrap();
        match self.results.get(*cursor) {
            None => Ok(Vec::new()),
            Some(canned) => {
                *cursor += 1;
                canned.clone()
            }
        }
    }
    /// Always claims hooks are configured, so cells go through classification.
    fn has_classification_hooks(&self) -> bool {
        true
    }
}
struct MockPolicyEnforcer {
    pub results: Vec<Result<(PolicyDecision, Vec<SpanPolicyDecision>), CapsuleError>>,
    pub current_index: usize,
}
impl EnforcePolicy for MockPolicyEnforcer {
    /// Replays the canned `results` in order, one per chunk; afterwards allows everything
    /// with no span decisions.
    fn enforce(
        &mut self,
        _span_tags: &[SpanTag],
        _redact_tags: &[CapsuleTag],
        _data: &[u8],
    ) -> Result<(PolicyDecision, Vec<SpanPolicyDecision>), CapsuleError> {
        match self.results.get(self.current_index) {
            None => Ok((PolicyDecision::Allow, Vec::new())),
            Some(canned) => {
                self.current_index += 1;
                canned.clone()
            }
        }
    }
}
/// Decodes a framed stream into rows of in-memory cell readers.
///
/// With `strip_classifier`, each cell's CBOR header is parsed and only its payload is kept;
/// otherwise the cell's raw bytes (header included) are kept. Panics on any decoder error
/// other than end of row / end of capsule.
fn decode_to_readers<R: Read + Send + 'static>(
    reader: Arc<Mutex<R>>,
    strip_classifier: bool,
) -> Vec<Vec<CellReader>> {
    let mut rows = Vec::<Vec<CellReader>>::new();
    loop {
        let mut cells = Vec::<CellReader>::new();
        let mut decoder =
            CellDecoder::<_, DefaultPolicyEnforcer>::new(reader.clone(), None, vec![])
                .expect("failed to create decoder");
        let outcome = decoder.for_each_cell(&mut |cell: &mut dyn Read| {
            let mut data: Vec<u8> = Vec::new();
            if strip_classifier {
                let mut taker = cell.take(50);
                let header: ClassifyingReaderHeader =
                    ciborium::from_reader(&mut taker).unwrap();
                drop(taker);
                data.resize(header.length, 0);
                cell.read_exact(&mut data).unwrap();
            } else {
                let _ = cell.read_to_end(&mut data);
            }
            cells.push(
                CellReader::new(Vec::new(), Box::new(Cursor::new(data)))
                    .expect("failed to push row data"),
            );
            Ok(())
        });
        let finished = match outcome {
            Err(CapsuleError::EndOfCapsule) => true,
            Err(CapsuleError::EndOfRow) | Ok(()) => false,
            Err(e) => panic!("unexpected error from for_each_cell: {}", e),
        };
        rows.push(cells);
        drop(decoder);
        if finished {
            break;
        }
    }
    rows
}
/// Classifies `data` as a single one-cell row, then runs the framed cell through a
/// `RedactingReader` driven by `enforcer_results`. Asserts that the output equals
/// `expected_result`.
fn simple_redaction_test_helper(
    data: Vec<u8>,
    enforcer_results: Vec<Result<(PolicyDecision, Vec<SpanPolicyDecision>), CapsuleError>>,
    expected_result: &str,
) {
    let processors = vec![Arc::new(RwLock::new(MockHookProcessor {
        current_index: RwLock::new(0),
        results: Vec::new(),
    }))];
    let rows = vec![RowReader {
        tags: vec![],
        cells: vec![CellReader {
            data: Box::new(Cursor::new(data)),
            tags: vec![],
        }],
    }];
    let columns = vec![Column {
        name: "col 0".to_string(),
        tags: vec![],
        skip_classification: false,
    }];
    let classifying_reader = Arc::new(Mutex::new(ClassifyingReader::new(
        rows, columns, processors,
    )));
    let mut decoded = decode_to_readers(classifying_reader, false);
    // One row with one cell: take ownership of that cell's framed stream.
    let framed_cell = decoded.remove(0).remove(0);
    let mut redacting_reader = RedactingReader::new(
        Arc::new(Mutex::new(framed_cell)),
        MockPolicyEnforcer {
            current_index: 0,
            results: enforcer_results,
        },
        Vec::new(),
    );
    let mut redact_output = Vec::<u8>::new();
    redacting_reader
        .read_to_end(&mut redact_output)
        .expect("failed to read from redacting reader");
    let actual =
        String::from_utf8(redact_output).expect("failed to convert redacted output to string");
    // assert_eq! reports both values on failure.
    assert_eq!(actual, expected_result, "output does not match");
}
/// Builds a random table of 1-50 rows by 1-50 columns, each cell holding 1-1024 random bytes.
///
/// Returns the cell readers, the column definitions, and the raw bytes for comparison.
pub fn random_input() -> (Vec<RowReader>, Vec<Column>, Vec<Vec<Vec<u8>>>) {
    use rand::Rng;
    let mut rng = rand::thread_rng();
    let row_count = 1 + rng.gen::<usize>() % 50;
    let cell_count = 1 + rng.gen::<usize>() % 50;
    let mut raw_input = Vec::<Vec<Vec<u8>>>::with_capacity(row_count);
    let mut input = Vec::<RowReader>::with_capacity(row_count);
    for _ in 0..row_count {
        let mut raw_row = Vec::<Vec<u8>>::with_capacity(cell_count);
        let mut readers = Vec::<CellReader>::with_capacity(cell_count);
        for _ in 0..cell_count {
            let len = 1 + rng.gen::<usize>() % 1024;
            let bytes: Vec<u8> = (0..len).map(|_| rng.gen()).collect();
            readers.push(
                CellReader::new(Vec::new(), Box::new(Cursor::new(bytes.clone())))
                    .expect("unexpected error pushing to row"),
            );
            raw_row.push(bytes);
        }
        raw_input.push(raw_row);
        input.push(RowReader {
            tags: vec![],
            cells: readers,
        });
    }
    let columns: Vec<Column> = (0..cell_count)
        .map(|_| Column {
            name: Alphanumeric.sample_string(&mut rand::thread_rng(), 16),
            tags: vec![],
            skip_classification: false,
        })
        .collect();
    (input, columns, raw_input)
}
/// Asserts that `output` decodes to exactly `raw_input`: the same number of rows, the same
/// number of cells per row, and byte-identical cell contents.
///
/// Panics if a cell cannot be read; a read error is no longer silently treated as a short cell.
fn check_equal(raw_input: &[Vec<Vec<u8>>], output: &mut [Vec<CellReader>]) {
    assert_eq!(output.len(), raw_input.len(), "unexpected output length");
    for (i, (expected_row, actual_row)) in raw_input.iter().zip(output.iter_mut()).enumerate() {
        assert_eq!(
            actual_row.len(),
            expected_row.len(),
            "unexpected number of cells in row {}",
            i
        );
        for (j, (expected, cell)) in expected_row.iter().zip(actual_row.iter_mut()).enumerate() {
            let mut data = Vec::<u8>::new();
            cell.data
                .read_to_end(&mut data)
                .expect("failed to read decoded cell");
            assert_eq!(
                data.len(),
                expected.len(),
                "length of data does not match raw input for index [{}][{}]",
                i,
                j
            );
            assert_eq!(&data, expected, "unexpected data in cell [{}][{}]", i, j);
        }
    }
}
#[test]
fn table_encrypt_and_decrypt_decode_small_plaintext() {
    use crate::capsule::streaming_aead::{DecryptingAEAD, EncryptingAEADReader};
    let rows = vec![RowReader {
        tags: vec![],
        cells: vec![
            CellReader::new(Vec::new(), Box::new(Cursor::new(vec![1, 2, 3, 4])))
                .expect("failed to create CellReader"),
        ],
    }];
    let processors = vec![Arc::new(RwLock::new(MockHookProcessor {
        current_index: RwLock::new(0),
        results: Vec::new(),
    }))];
    let columns = vec![Column {
        name: "col 0".to_string(),
        tags: vec![],
        skip_classification: false,
    }];
    let encrypted = Arc::new(Mutex::new(
        EncryptingAEADReader::new(
            [0u8; NONCE_SIZE],
            &[0u8; KEY_SIZE],
            ClassifyingReader::new(rows, columns, processors),
        )
        .expect("failed creating EncryptingAEAD"),
    ));
    let mut decrypted = DecryptingAEAD::new(&[0u8; KEY_SIZE], encrypted)
        .expect("failed creating DecryptingAEAD");
    let mut decrypted_data: Vec<u8> = Vec::new();
    decrypted
        .read_to_end(&mut decrypted_data)
        .expect("failed to read decrypted data");
    let mut decoded = decode_to_readers(Arc::new(Mutex::new(Cursor::new(decrypted_data))), true);
    // Previously the decoded rows were never checked, so this test could not fail on bad output.
    check_equal(&vec![vec![vec![1, 2, 3, 4]]], &mut decoded);
}
#[test]
fn table_encrypt_and_decrypt_decode() {
    use crate::capsule::streaming_aead::{DecryptingAEAD, EncryptingAEADReader};
    let (cells, columns, raw_input) = random_input();
    let processors: Vec<_> = (0..columns.len())
        .map(|_| {
            Arc::new(RwLock::new(MockHookProcessor {
                current_index: RwLock::new(0),
                results: Vec::new(),
            }))
        })
        .collect();
    let classifier = ClassifyingReader::new(cells, columns, processors);
    let encrypting = EncryptingAEADReader::new([0u8; NONCE_SIZE], &[0u8; KEY_SIZE], classifier)
        .expect("failed creating EncryptingAEAD");
    let decrypting = DecryptingAEAD::new(&[0u8; KEY_SIZE], Arc::new(Mutex::new(encrypting)))
        .expect("failed creating DecryptingAEAD");
    let mut decoded = decode_to_readers(Arc::new(Mutex::new(decrypting)), true);
    check_equal(&raw_input, &mut decoded);
}
#[test]
fn classify_and_decode() {
    let (cells, columns, raw_input) = random_input();
    let processors: Vec<_> = (0..columns.len())
        .map(|_| {
            Arc::new(RwLock::new(MockHookProcessor {
                current_index: RwLock::new(0),
                results: Vec::new(),
            }))
        })
        .collect();
    // Trailing bytes after the capsule must not disturb decoding.
    let trailing = Cursor::new(vec![1, 2, 3, 4]);
    let source = ClassifyingReader::new(cells, columns, processors).chain(trailing);
    let mut decoded = decode_to_readers(Arc::new(Mutex::new(source)), true);
    check_equal(&raw_input, &mut decoded);
}
#[test]
fn classifier_with_empty_cells() {
    let empty_cell = || {
        CellReader::new(Vec::new(), Box::new(std::io::empty()))
            .expect("failed to construct CellReader")
    };
    let rows = vec![
        RowReader {
            tags: vec![],
            cells: vec![empty_cell(), empty_cell()],
        },
        RowReader {
            tags: vec![],
            cells: vec![
                CellReader::new(Vec::new(), Box::new(Cursor::new(vec![1, 2, 3, 4])))
                    .expect("failed to construct CellReader"),
                empty_cell(),
            ],
        },
    ];
    let columns: Vec<Column> = ["a", "b"]
        .iter()
        .map(|name| Column {
            name: name.to_string(),
            tags: vec![],
            skip_classification: false,
        })
        .collect();
    let processors: Vec<_> = (0..columns.len())
        .map(|_| {
            Arc::new(RwLock::new(MockHookProcessor {
                current_index: RwLock::new(0),
                results: Vec::new(),
            }))
        })
        .collect();
    let reader = Arc::new(Mutex::new(ClassifyingReader::new(
        rows, columns, processors,
    )));
    let mut decoded = decode_to_readers(reader, true);
    check_equal(
        &vec![vec![vec![], vec![]], vec![vec![1, 2, 3, 4], vec![]]],
        &mut decoded,
    );
}
#[test]
fn simple_allow() {
    // A single span explicitly allowed by policy: the output must equal the input.
    let input = b"this is some test input".to_vec();
    let allowed_span = SpanPolicyDecision {
        start: 13,
        end: 16,
        decision: PolicyDecision::Allow,
    };
    let decisions = vec![Ok((PolicyDecision::Allow, vec![allowed_span]))];
    simple_redaction_test_helper(input, decisions, "this is some test input");
}
#[test]
fn redact_multiple_sections() {
    // Several spans with mixed decisions; the expected output replaces each Redact span
    // with a marker and keeps the Allow and Tokenize spans verbatim.
    let span = |start, end, decision| SpanPolicyDecision {
        start,
        end,
        decision,
    };
    let spans = vec![
        span(0, 4, PolicyDecision::Redact),
        span(4, 8, PolicyDecision::Allow),
        span(8, 12, PolicyDecision::Redact),
        span(12, 16, PolicyDecision::Tokenize),
        span(18, 23, PolicyDecision::Redact),
    ];
    simple_redaction_test_helper(
        b"this is some test input".to_vec(),
        vec![Ok((PolicyDecision::Allow, spans))],
        "{redacted} is {redacted} test {redacted}",
    );
}
#[test]
fn simple_redact() {
    // Redacting bytes 13..17 ("test") replaces exactly that word with the marker.
    let input = b"this is some test input".to_vec();
    let redacted_span = SpanPolicyDecision {
        start: 13,
        end: 17,
        decision: PolicyDecision::Redact,
    };
    let decisions = vec![Ok((PolicyDecision::Allow, vec![redacted_span]))];
    simple_redaction_test_helper(input, decisions, "this is some {redacted} input");
}
#[test]
fn redact_adjacent_sections() {
    // Two touching Redact spans (0..4 and 4..7) are expected to yield one marker.
    let span = |start, end| SpanPolicyDecision {
        start,
        end,
        decision: PolicyDecision::Redact,
    };
    let decisions = vec![Ok((PolicyDecision::Allow, vec![span(0, 4), span(4, 7)]))];
    simple_redaction_test_helper(
        b"this is some test input".to_vec(),
        decisions,
        "{redacted} some test input",
    );
}
#[test]
fn classify_and_redact() {
    use rand::Rng;
    // A random-length ASCII cell (0..=10 whole chunks plus a partial one) goes through the
    // classifier and then a RedactingReader whose mock enforcer has no results; the
    // redacted output must be byte-for-byte identical to the input.
    let mut rng = rand::thread_rng();
    let whole_chunks = rng.gen::<u8>() % 11;
    let remainder = rng.gen::<usize>() % CELL_CLASSIFIER_CHUNK_SIZE;
    let data_size = usize::from(whole_chunks) * CELL_CLASSIFIER_CHUNK_SIZE + remainder;
    let data: Vec<u8> = (0..data_size).map(|_| rng.gen_range(0x00..=0x7F)).collect();

    let processors = vec![Arc::new(RwLock::new(MockHookProcessor {
        results: Vec::new(),
        current_index: RwLock::new(0),
    }))];
    let only_cell = CellReader {
        data: Box::new(Cursor::new(data.clone())),
        tags: vec![],
    };
    let rows: Vec<RowReader> = vec![RowReader {
        cells: vec![only_cell],
        tags: vec![],
    }];
    let columns = vec![Column {
        name: "col 0".to_string(),
        tags: vec![],
        skip_classification: false,
    }];
    let classifier = Arc::new(Mutex::new(ClassifyingReader::new(rows, columns, processors)));
    let mut readers = decode_to_readers(classifier, false);

    // Move the single decoded cell out, leaving an empty placeholder behind.
    let placeholder = CellReader {
        data: Box::new(std::io::empty()),
        tags: vec![],
    };
    let decoded_cell = mem::replace(&mut readers[0][0], placeholder);

    let enforcer = MockPolicyEnforcer {
        results: Vec::new(),
        current_index: 0,
    };
    let mut redacting_reader =
        RedactingReader::new(Arc::new(Mutex::new(decoded_cell)), enforcer, Vec::new());
    let mut redact_output = Vec::<u8>::new();
    redacting_reader
        .read_to_end(&mut redact_output)
        .expect("failed to read from redacting reader");

    let (got, want) = (redact_output.len(), data.len());
    assert!(
        got == want,
        "length {} is not equal to expected length {}",
        got,
        want
    );
    assert!(
        redact_output == data,
        "redacted output does not match input"
    );
}
#[test]
fn classify_and_redact_row_iterator() {
    use rand::Rng;
    // End-to-end ClassifyAndRedact over one random ASCII cell using the
    // `allow_all.wasm` policy fixture: expects no span tags and the data unchanged.
    let mut rng = rand::thread_rng();
    let whole_chunks = rng.gen::<u8>() % 11;
    let remainder = rng.gen::<usize>() % CELL_CLASSIFIER_CHUNK_SIZE;
    let data_size = usize::from(whole_chunks) * CELL_CLASSIFIER_CHUNK_SIZE + remainder;
    let data: Vec<u8> = (0..data_size).map(|_| rng.gen_range(0x00..=0x7F)).collect();

    let processor = Arc::new(RwLock::new(HookProcessor::new(
        "test".to_string(),
        None,
        configuration::Configuration::new(),
        &mut Column {
            name: "test".to_string(),
            tags: vec![],
            skip_classification: false,
        },
        &vec![],
        Arc::new(MockHookInvoker {
            results: vec![],
            current_index: 0.into(),
        }),
        &vec![],
    )));

    let wasm_path =
        std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("static/fixtures/allow_all.wasm");
    let wasm_bytes = std::fs::read(wasm_path).expect("unable to read WASM file");

    let columns = vec![Column {
        name: "test".to_string(),
        tags: vec![],
        skip_classification: false,
    }];
    let engine = RUNTIME
        .block_on(PolicyEngine::new(&wasm_bytes))
        .expect("failed to create PolicyEngine");
    let rows = vec![RowReader {
        tags: vec![],
        cells: vec![CellReader::new(vec![], Cursor::new(data.clone()))
            .expect("failed to create CellReader")],
    }];

    let mut it = ClassifyAndRedact::<DefaultPolicyEnforcer>::new(
        columns,
        vec![],
        vec![processor],
        HashMap::new(),
        Arc::new(Mutex::new(engine)),
        rows,
    )
    .expect("failed to create ClassifyAndRedact");

    let (span_tags, result_data) = it.read_all(&vec![]).expect("read_all failed");
    assert_eq!(span_tags.len(), 1);
    assert_eq!(span_tags[0].len(), 1);
    assert_eq!(span_tags[0][0].len(), 0);
    assert_eq!(data, result_data[0][0]);
}
#[test]
fn split_user_tags() {
    // `add_user_tags` is exercised at three cell offsets (`current_cell_idx`) with
    // different chunk lengths. The asserts below pin which tags survive and their
    // resulting start/end values relative to that offset.
    let mut reader = ClassifyingReader::<MockHookProcessor>::new(vec![], vec![], vec![]);

    // All fixture tags differ only in name and span, so build them with one closure
    // instead of seven copy-pasted literals.
    let span_tag = |name: &str, start, end| SpanTag {
        tag: CapsuleTag {
            name: name.to_string(),
            tag_type: TagType::Unary,
            value: String::new(),
            source: String::new(),
            hook_version: (0, 0, 0),
        },
        start,
        end,
    };
    let tags = vec![
        span_tag("tag0", 0, 5000),
        span_tag("tag1", 0, 5),
        span_tag("tag2", 5, 500),
        span_tag("tag3", 10, 50),
        span_tag("tag4", 50, 5000),
        span_tag("tag5", 100, 101),
        span_tag("tag6", 11, 50),
    ];

    reader.current_cell_idx = 0;
    let new_tags = reader.add_user_tags(&tags, 50);
    assert_eq!(new_tags.len(), 5);
    for tag in new_tags {
        // Compare against `&str` directly; no need to allocate a String (clippy: cmp_owned).
        if tag.tag.name == "tag0" {
            assert_eq!(tag.end, 50);
            assert_eq!(tag.start, 0);
        }
        assert!(tag.end <= 50);
    }

    reader.current_cell_idx = 5;
    let new_tags = reader.add_user_tags(&tags, 50);
    assert_eq!(new_tags.len(), 5);
    for tag in new_tags {
        match tag.tag.name.as_str() {
            "tag0" | "tag2" => {
                assert_eq!(tag.end, 50);
                assert_eq!(tag.start, 0);
            }
            "tag3" => {
                assert_eq!(tag.end, 45);
                assert_eq!(tag.start, 5);
            }
            "tag4" => {
                assert_eq!(tag.end, 50);
                assert_eq!(tag.start, 45);
            }
            "tag6" => {
                assert_eq!(tag.end, 45);
                assert_eq!(tag.start, 6);
            }
            x => panic!("unknown tag in response: {}", x),
        }
    }

    reader.current_cell_idx = 100;
    let new_tags = reader.add_user_tags(&tags, 10000);
    assert_eq!(new_tags.len(), 4);
    for tag in new_tags {
        match tag.tag.name.as_str() {
            "tag0" | "tag4" => {
                assert_eq!(tag.end, 4900);
                assert_eq!(tag.start, 0);
            }
            "tag2" => {
                assert_eq!(tag.end, 400);
                assert_eq!(tag.start, 0);
            }
            "tag5" => {
                assert_eq!(tag.end, 1);
                assert_eq!(tag.start, 0);
            }
            x => panic!("unknown tag in response: {}", x),
        }
    }
}
#[test]
fn build_request_batch_large() {
    // One row, two columns: column "a" spans several classifier chunks and column "b"
    // fits in one. The first two batches must each hold a single chunk of column "a"
    // plus the overlap window, and consecutive chunks must share that overlap.
    // Stripping the overlap from every chunk must then reassemble both columns exactly.
    let columns = vec![
        Column {
            name: "a".to_string(),
            tags: vec![],
            skip_classification: false,
        },
        Column {
            name: "b".to_string(),
            tags: vec![],
            skip_classification: false,
        },
    ];
    let mut data_a = vec![0; CELL_CLASSIFIER_CHUNK_SIZE * 5 - 500];
    let mut data_b = vec![0; 100];
    let mut rng = rand::thread_rng();
    rng.fill_bytes(&mut data_a);
    rng.fill_bytes(&mut data_b);
    let rows = vec![RowReader {
        tags: vec![],
        cells: vec![
            CellReader {
                data: Box::new(Cursor::new(data_a.clone())),
                tags: vec![],
            },
            CellReader {
                data: Box::new(Cursor::new(data_b.clone())),
                tags: vec![],
            },
        ],
    }];
    let mut reader = ClassifyingReader::<MockHookProcessor>::new(rows, columns, vec![]);

    let mut batch_a = reader.build_request_batch().expect("failed to build batch");
    assert_eq!(batch_a.len(), 1);
    let meta_a = &batch_a[0];
    assert_eq!(
        meta_a.data.len(),
        CELL_CLASSIFIER_CHUNK_SIZE + CELL_CLASSIFIER_OVERLAP
    );
    assert_eq!(meta_a.row, 0);
    assert_eq!(meta_a.col, 0);
    assert_eq!(
        meta_a.length,
        CELL_CLASSIFIER_CHUNK_SIZE + CELL_CLASSIFIER_OVERLAP
    );
    assert_eq!(reader.overlap, meta_a.data[CELL_CLASSIFIER_CHUNK_SIZE..]);

    let mut batch_b = reader.build_request_batch().expect("failed to build batch");
    // Previously re-checked `batch_a`, leaving the second batch's size unverified.
    assert_eq!(batch_b.len(), 1);
    let meta_b = &batch_b[0];
    assert_eq!(
        meta_b.data.len(),
        CELL_CLASSIFIER_CHUNK_SIZE + CELL_CLASSIFIER_OVERLAP
    );
    assert_eq!(meta_b.row, 0);
    assert_eq!(meta_b.col, 0);
    assert_eq!(
        meta_b.length,
        CELL_CLASSIFIER_CHUNK_SIZE + CELL_CLASSIFIER_OVERLAP
    );
    assert_eq!(reader.overlap, meta_b.data[CELL_CLASSIFIER_CHUNK_SIZE..]);
    assert_ne!(reader.overlap, meta_a.data[CELL_CLASSIFIER_CHUNK_SIZE..]);
    // The tail overlap of chunk A is the head of chunk B.
    assert_eq!(
        meta_a.data[CELL_CLASSIFIER_CHUNK_SIZE..],
        meta_b.data[..CELL_CLASSIFIER_OVERLAP]
    );

    // Drop the overlap window so concatenated chunks reproduce the original bytes.
    batch_a[0].data.truncate(CELL_CLASSIFIER_CHUNK_SIZE);
    batch_b[0].data.truncate(CELL_CLASSIFIER_CHUNK_SIZE);
    let mut result_buffers: Vec<Vec<u8>> = vec![Vec::new()];
    result_buffers[0].extend_from_slice(&batch_a[0].data);
    result_buffers[0].extend_from_slice(&batch_b[0].data);
    while !reader.eof {
        let batch = reader.build_request_batch().expect("failed to build batch");
        for mut cell in batch {
            // Index by the cell's own column rather than tracking column transitions.
            if result_buffers.len() <= cell.col {
                result_buffers.resize_with(cell.col + 1, Vec::new);
            }
            cell.data.truncate(CELL_CLASSIFIER_CHUNK_SIZE);
            result_buffers[cell.col].extend_from_slice(&cell.data);
        }
    }
    assert_eq!(result_buffers.len(), 2);
    assert_eq!(result_buffers[0], data_a);
    assert_eq!(result_buffers[1], data_b);
}
#[test]
fn build_request_batch_row_tags() {
    // Several rows of two sub-chunk-sized cells, each row carrying its own row tag.
    // Every classification chunk must carry the tag of the row it came from, and
    // reassembled cell data must match the input.
    const NUM_ROWS: usize = 5;
    let columns = vec![
        Column {
            name: "a".to_string(),
            tags: vec![],
            skip_classification: false,
        },
        Column {
            name: "b".to_string(),
            tags: vec![],
            skip_classification: false,
        },
    ];
    let mut rng = rand::thread_rng();
    let mut data: Vec<Vec<Vec<u8>>> = Vec::new();
    let mut rows: Vec<RowReader> = Vec::new();
    for i in 0..NUM_ROWS {
        let mut data_col0 = vec![0; 5000];
        let mut data_col1 = vec![0; 5090];
        rng.fill_bytes(&mut data_col0);
        rng.fill_bytes(&mut data_col1);
        data.push(vec![data_col0.clone(), data_col1.clone()]);
        rows.push(RowReader {
            cells: vec![
                CellReader {
                    data: Box::new(Cursor::new(data_col0)),
                    tags: vec![],
                },
                CellReader {
                    data: Box::new(Cursor::new(data_col1)),
                    tags: vec![],
                },
            ],
            tags: vec![CapsuleTag {
                name: format!("row_tag_{}", i),
                tag_type: TagType::Unary,
                value: format!("{}", i),
                source: "".to_string(),
                hook_version: (0, 0, 0),
            }],
        })
    }
    let mut reader = ClassifyingReader::<MockHookProcessor>::new(rows, columns, vec![]);

    // result_*[row][col], grown on demand from each chunk's own (row, col) coordinates.
    // Indexing directly avoids the old transition tracking, which could index out of
    // bounds whenever a row change did not also change the column index.
    let mut result_buffers: Vec<Vec<Vec<u8>>> = Vec::new();
    let mut result_tags = Vec::new();
    while !reader.eof {
        let batch = reader.build_request_batch().expect("failed to build batch");
        for mut cell in batch {
            let (row, col) = (cell.row, cell.col);
            if result_buffers.len() <= row {
                result_buffers.resize_with(row + 1, Vec::new);
                result_tags.resize_with(row + 1, Vec::new);
            }
            if result_buffers[row].len() <= col {
                result_buffers[row].resize_with(col + 1, Vec::new);
                result_tags[row].resize_with(col + 1, Vec::new);
            }
            cell.data.truncate(CELL_CLASSIFIER_CHUNK_SIZE);
            result_buffers[row][col].extend_from_slice(&cell.data);
            result_tags[row][col].extend_from_slice(&cell.tags);
        }
    }
    assert_eq!(result_buffers.len(), NUM_ROWS);
    for row in 0..NUM_ROWS {
        for col in 0..2 {
            assert_eq!(result_buffers[row][col], data[row][col]);
            let tag = &result_tags[row][col][0];
            assert_eq!(tag.tag.name, format!("row_tag_{}", row));
        }
    }
}
}