use crate::extensions::AuditLogEntry;
use crate::logging::api::objects::AuditLogV3;
use crate::service::{Layer, Service};
use conjure_error::Error;
use futures_sink::Sink;
use futures_util::SinkExt;
use http::{Response, StatusCode};
use http_body::{Body, Frame, SizeHint};
use pin_project::pin_project;
use std::error;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::Mutex;
use witchcraft_log::error;
/// A [`Layer`] which wraps a service in an [`AuditLogService`].
///
/// The resulting service inspects each response for an [`AuditLogEntry`] extension and, when one
/// is present, hands it to the shared `logger` sink before returning the response.
pub struct AuditLogLayer<T> {
    logger: Arc<Mutex<T>>,
}

impl<T> AuditLogLayer<T> {
    /// Creates a layer which writes audit log entries to the given shared sink.
    pub fn new(logger: Arc<Mutex<T>>) -> Self {
        Self { logger }
    }
}

impl<S, T> Layer<S> for AuditLogLayer<T> {
    type Service = AuditLogService<S, T>;

    fn layer(self, inner: S) -> Self::Service {
        let Self { logger } = self;
        AuditLogService { logger, inner }
    }
}
/// The service produced by [`AuditLogLayer`].
///
/// Forwards requests to `inner` and writes any [`AuditLogEntry`] found on the resulting response
/// to `logger`.
pub struct AuditLogService<S, T> {
    /// Shared sink that audit log entries are written to; the mutex guard is held for each write.
    logger: Arc<Mutex<T>>,
    /// The wrapped service whose responses are inspected for audit log entries.
    inner: S,
}
impl<S, T, R, B> Service<R> for AuditLogService<S, T>
where
    S: Service<R, Response = Response<B>> + Sync,
    T: Sink<AuditLogV3> + Unpin + 'static + Send,
    T::Error: Into<Box<dyn error::Error + Sync + Send>>,
    R: Send,
    B: Send,
{
    type Response = Response<AuditLogResponseBody<B>>;

    /// Calls the inner service and persists any audit log entry attached to its response.
    ///
    /// If the response carries an [`AuditLogEntry`] extension, the entry is removed from the
    /// response and sent to the logger sink, including a flush, before the response is returned.
    /// If writing or flushing fails, the error is logged and an empty
    /// `500 Internal Server Error` response is returned in place of the inner response.
    async fn call(&self, req: R) -> Self::Response {
        let mut response = self.inner.call(req).await;

        if let Some(audit_log_entry) = response.extensions_mut().remove::<AuditLogEntry>() {
            let send = async {
                // `send` (unlike `feed`) also flushes the sink, so the entry has been fully
                // processed, and any flush error observed, before the response is released.
                // The mutex guard is held for the duration of the write.
                self.logger
                    .lock()
                    .await
                    .send(audit_log_entry.0)
                    .await
                    .map_err(Error::internal_safe)?;
                Ok(())
            };

            if let Err(e) = send.await {
                error!("error persisting audit log entry", error: e);
                let mut response = Response::new(AuditLogResponseBody { inner: None });
                *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
                return response;
            }
        }

        response.map(|inner| AuditLogResponseBody { inner: Some(inner) })
    }
}

/// The response body produced by [`AuditLogService`].
///
/// Wraps the inner service's body, or is empty (`inner: None`) when the response was replaced
/// because the audit log entry could not be persisted.
#[pin_project]
pub struct AuditLogResponseBody<B> {
    #[pin]
    inner: Option<B>,
}

impl<B> Body for AuditLogResponseBody<B>
where
    B: Body,
{
    type Data = B::Data;
    type Error = B::Error;

    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let this = self.project();
        match this.inner.as_pin_mut() {
            Some(inner) => inner.poll_frame(cx),
            // An empty body yields no frames.
            None => Poll::Ready(None),
        }
    }

    fn size_hint(&self) -> SizeHint {
        match &self.inner {
            Some(inner) => inner.size_hint(),
            None => SizeHint::with_exact(0),
        }
    }

    fn is_end_stream(&self) -> bool {
        match &self.inner {
            Some(inner) => inner.is_end_stream(),
            None => true,
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::logging::api::objects::{AuditProducer, AuditResult};
    use crate::service::test_util::service_fn;
    use conjure_object::{Utc, Uuid};

    // `Item` carries a full `AuditLogV3`, which dwarfs the unit `Flush` variant; boxing it is not
    // worth the noise in test-only code.
    #[allow(clippy::large_enum_variant)]
    #[derive(PartialEq, Debug)]
    enum TestSinkEvent {
        Item(AuditLogV3),
        Flush,
    }

    /// A sink which records every item and flush it observes.
    struct TestSink {
        events: Vec<TestSinkEvent>,
    }

    impl Sink<AuditLogV3> for TestSink {
        type Error = &'static str;

        fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn start_send(mut self: Pin<&mut Self>, item: AuditLogV3) -> Result<(), Self::Error> {
            self.events.push(TestSinkEvent::Item(item));
            Ok(())
        }

        fn poll_flush(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            self.events.push(TestSinkEvent::Flush);
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }
    }

    /// A sink which accepts items but fails every flush.
    struct FailingFlushSink;

    impl Sink<AuditLogV3> for FailingFlushSink {
        type Error = &'static str;

        fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn start_send(self: Pin<&mut Self>, _: AuditLogV3) -> Result<(), Self::Error> {
            Ok(())
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Err("flush failed"))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            unimplemented!()
        }
    }

    fn test_log() -> AuditLogV3 {
        AuditLogV3::builder()
            .type_("audit.3")
            .product("baz")
            .product_version("1")
            .producer_type(AuditProducer::Server)
            .event_id(Uuid::new_v4())
            .time(Utc::now())
            .name("PUT_FILE")
            .result(AuditResult::Success)
            .build()
    }

    #[tokio::test]
    async fn no_op_with_no_audit_event() {
        let service = AuditLogLayer::new(Arc::new(Mutex::new(TestSink { events: vec![] })))
            .layer(service_fn(|_| async { Response::new(()) }));

        let response = service.call(()).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.body().inner.is_some());

        assert_eq!(service.logger.lock().await.events, vec![]);
    }

    #[tokio::test]
    async fn log_audit_event() {
        let log = test_log();

        let service = AuditLogLayer::new(Arc::new(Mutex::new(TestSink { events: vec![] }))).layer(
            service_fn(|_| {
                let log = log.clone();
                async {
                    let mut response = Response::new(());
                    response.extensions_mut().insert(AuditLogEntry::v3(log));
                    response
                }
            }),
        );

        let response = service.call(()).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.body().inner.is_some());

        // The entry must be flushed, not merely buffered, before the response is returned.
        assert_eq!(
            service.logger.lock().await.events,
            vec![TestSinkEvent::Item(log.clone()), TestSinkEvent::Flush]
        );
    }

    #[tokio::test]
    async fn flush_failure_returns_server_error() {
        let log = test_log();

        let service = AuditLogLayer::new(Arc::new(Mutex::new(FailingFlushSink))).layer(
            service_fn(|_| {
                let log = log.clone();
                async {
                    let mut response = Response::new(());
                    response.extensions_mut().insert(AuditLogEntry::v3(log));
                    response
                }
            }),
        );

        let response = service.call(()).await;
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert!(response.body().inner.is_none());
    }
}