use crate::capsule::capsule_v2::SealedV2Capsule;
use crate::capsule::common::CapsuleError;
use crate::capsule::policy_enforcer::PolicyEnforcer;
use crate::capsule::util_readers::MutexReader;
use crate::capsule::{CapsuleTag, CellIterator, Column, RowIterator};
use crate::session::policy_engine::PolicyEngine;
use antimatter_api::models::{CapsuleOpenRequest, CapsuleOpenResponse, NewAccessLogEntry};
use polonius_the_crab::prelude::*;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::{io::Read, marker::Send};
/// Shared, `Send` callback used to emit access-log entries.
///
/// Invoked with a [`NewAccessLogEntry`] followed by three string slices. The meaning of
/// the three string arguments is defined by whoever constructs the callback (not visible
/// here). NOTE(review): document them at the construction site.
pub type AccessLogSender =
Arc<Mutex<Box<dyn Fn(NewAccessLogEntry, &str, &str, &str) -> Result<(), CapsuleError> + Send>>>;
/// Mutable, `Send` callback that services a capsule-open request.
///
/// Receives three string slices, an optional byte buffer and a [`CapsuleOpenRequest`].
/// On success it yields either `None` or a [`CapsuleOpenResponse`] paired with an optional
/// shared [`PolicyEngine`]. The positional arguments' meaning, and what `None` signifies,
/// are defined by the implementer. NOTE(review): confirm and document where the opener is
/// constructed.
pub type CapsuleOpener = Box<
dyn FnMut(
&str,
&str,
&str,
&Option<Vec<u8>>,
CapsuleOpenRequest,
) -> Result<
Option<(CapsuleOpenResponse, Option<Arc<Mutex<PolicyEngine>>>)>,
CapsuleError,
> + Send,
>;
/// Row source over a sequence of V2 capsules read one after another from a shared input.
///
/// Rows are served from `current_capsule`. When that capsule reports
/// `CapsuleError::EndOfCapsule`, several things happen in turn:
/// - its ids and tags are appended to the bundle's totals;
/// - its columns are adopted if none are recorded yet;
/// - the next capsule is opened from `input`.
///
/// The bundle itself reports `EndOfCapsule` once the input stream yields no more bytes.
pub struct V2Bundle<R: Read + Send + 'static, P: PolicyEnforcer + 'static> {
/// Shared input stream; every capsule in the bundle is read from it in turn.
input: Arc<Mutex<R>>,
/// Read context passed to each capsule when it is opened.
read_context: String,
/// Extra data string supplied to `from_reader` and returned by `extra_data()`.
extra: String,
/// Domain identifier passed to each capsule and returned by `domain_id()`.
domain_id: String,
/// Access-log callback handed to each capsule when it is opened.
access_log_sender: AccessLogSender,
/// Capsule-open callback handed to each capsule when it is opened.
open_capsule: Arc<Mutex<CapsuleOpener>>,
/// Read parameters cloned into each capsule when it is opened.
read_params: HashMap<String, String>,
/// Domain identity parameters cloned into each capsule when it is opened.
domain_identity_params: HashMap<String, String>,
/// Capsule currently being read. Its reader is the single byte peeked by `next_capsule`,
/// chained with the remainder of `input`.
current_capsule: SealedV2Capsule<std::io::Chain<std::io::Cursor<[u8; 1]>, MutexReader<R>>, P>,
/// Capsule ids accumulated from every capsule read to exhaustion so far.
capsule_ids: Vec<String>,
/// Capsule tags accumulated from every capsule read to exhaustion so far.
capsule_tags: Vec<CapsuleTag>,
/// Columns of the first capsule read to exhaustion; later capsules' columns are not merged.
columns: Vec<Column>,
/// NOTE(review): never written in this file, so `open_failures()` always returns an empty
/// list as far as this module shows. Confirm whether capsule-open failures should be
/// recorded here.
open_failures: Vec<String>,
}
/// Opens the next capsule from a shared input stream.
///
/// One byte is read up front to detect end of input. That byte is then re-prepended (a
/// one-byte [`std::io::Cursor`] chained with a [`MutexReader`] over `input`), so
/// [`SealedV2Capsule::from_reader`] sees the stream unchanged.
///
/// # Errors
///
/// * [`CapsuleError::EndOfCapsule`] if the stream is already at end of input.
/// * [`CapsuleError::Generic`] if reading the peek byte fails, or if the input mutex is
///   poisoned.
/// * Any error returned by [`SealedV2Capsule::from_reader`].
pub fn next_capsule<R: Read + Send + 'static, P: PolicyEnforcer + 'static>(
    input: Arc<Mutex<R>>,
    domain_id: &str,
    read_context: &str,
    access_log_sender: AccessLogSender,
    open_capsule: Arc<Mutex<CapsuleOpener>>,
    read_params: HashMap<String, String>,
    domain_identity_params: HashMap<String, String>,
) -> Result<
    SealedV2Capsule<std::io::Chain<std::io::Cursor<[u8; 1]>, MutexReader<R>>, P>,
    CapsuleError,
> {
    let mut next_byte = [0u8; 1];
    let bytes_read = loop {
        // The guard is a temporary, so the lock is released at the end of this statement
        // and `MutexReader` can lock the same input again later.
        let read_result = input
            .lock()
            .map_err(|_| {
                CapsuleError::Generic("reading input stream: input lock poisoned".to_string())
            })?
            .read(&mut next_byte);
        match read_result {
            // `Read::read` documents `Interrupted` as retryable, not as a real failure.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => {
                return Err(CapsuleError::Generic(format!(
                    "reading input stream: {}",
                    e
                )))
            }
            Ok(n) => break n,
        }
    };
    if bytes_read == 0 {
        return Err(CapsuleError::EndOfCapsule);
    }
    // Re-prepend the peeked byte so the capsule parser reads the stream from its start.
    let reader = std::io::Cursor::new(next_byte).chain(MutexReader { reader: input });
    Ok(SealedV2Capsule::from_reader(
        Arc::new(Mutex::new(reader)),
        domain_id,
        read_context,
        access_log_sender,
        open_capsule,
        read_params,
        domain_identity_params,
    )?)
}
impl<R: Read + Send + 'static, P: PolicyEnforcer + 'static> V2Bundle<R, P> {
pub fn from_reader(
input: Arc<Mutex<R>>,
read_context: String,
extra: String,
domain_id: String,
access_log_sender: AccessLogSender,
open_capsule: Arc<Mutex<CapsuleOpener>>,
read_params: HashMap<String, String>,
domain_identity_params: HashMap<String, String>,
) -> Result<Self, CapsuleError> {
Ok(Self {
input: input.clone(),
extra,
access_log_sender: access_log_sender.clone(),
open_capsule: open_capsule.clone(),
read_params: read_params.clone(),
domain_identity_params: domain_identity_params.clone(),
current_capsule: next_capsule(
input,
&domain_id,
&read_context,
access_log_sender,
open_capsule,
read_params,
domain_identity_params,
)?,
domain_id,
read_context,
capsule_ids: Vec::new(),
capsule_tags: Vec::new(),
columns: Vec::new(),
open_failures: Vec::new(),
})
}
fn next_capsule_and_next_row(
&mut self,
redact_tags: Vec<CapsuleTag>,
) -> Result<Box<dyn CellIterator + 'static>, CapsuleError> {
{
let current_capsule = &self.current_capsule;
self.capsule_ids.append(&mut current_capsule.capsule_ids());
self.capsule_tags
.append(&mut current_capsule.capsule_tags());
if self.columns.is_empty() {
self.columns.append(&mut current_capsule.columns());
}
}
self.current_capsule = next_capsule(
self.input.clone(),
&self.domain_id.clone(),
&self.read_context.clone(),
self.access_log_sender.clone(),
self.open_capsule.clone(),
self.read_params.clone(),
self.domain_identity_params.clone(),
)?;
self.next_row(redact_tags)
}
}
impl<R: Read + Send + 'static, P: PolicyEnforcer + 'static> RowIterator for V2Bundle<R, P> {
    /// Returns the next row of the bundle.
    ///
    /// The row is taken from the current capsule. If that capsule reports `EndOfCapsule`,
    /// reading moves on to the following capsule(s) in the stream. Any other error is
    /// returned unchanged.
    fn next_row(
        &mut self,
        redact_tags: Vec<CapsuleTag>,
    ) -> Result<Box<dyn CellIterator + 'static>, CapsuleError> {
        let mut bundle = self;
        // Inside `polonius!`, a successful row is returned from this function directly.
        // Otherwise the error is handed out and `bundle` is usable again below.
        let failure = polonius!(
            |bundle| -> Result<Box<dyn CellIterator + 'static>, CapsuleError> {
                match bundle.current_capsule.next_row(redact_tags.clone()) {
                    Ok(row) => {
                        polonius_return!(Ok(row));
                    }
                    Err(e) => exit_polonius!(e),
                }
            }
        );
        if matches!(failure, CapsuleError::EndOfCapsule) {
            bundle.next_capsule_and_next_row(redact_tags)
        } else {
            Err(failure)
        }
    }

    /// Domain identifier this bundle was created with.
    fn domain_id(&self) -> String {
        String::clone(&self.domain_id)
    }

    /// Extra data string this bundle was created with.
    fn extra_data(&self) -> String {
        self.extra.as_str().to_owned()
    }

    /// Ids gathered from the capsules read to exhaustion so far.
    fn capsule_ids(&self) -> Vec<String> {
        self.capsule_ids.to_vec()
    }

    /// Tags gathered from the capsules read to exhaustion so far.
    fn capsule_tags(&self) -> Vec<CapsuleTag> {
        self.capsule_tags.to_vec()
    }

    /// Columns adopted from the first capsule read to exhaustion.
    fn columns(&self) -> Vec<Column> {
        self.columns.to_vec()
    }

    /// Recorded open failures (only ever initialised empty in this file).
    fn open_failures(&self) -> Vec<String> {
        self.open_failures.to_vec()
    }
}