#![deny(missing_docs)]
use actix_web::{
body::BodySize,
dev::{MessageBody, ResponseBody, Service, ServiceRequest, ServiceResponse, Transform},
http::{header::REFERER, Method, StatusCode},
web::Bytes,
};
use chrono::{DateTime, Utc};
use futures::{future, Future};
use serde::{Serialize, Serializer};
use std::{
collections::HashMap,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
struct Serializers;
impl Serializers {
fn to_rfc3339<S>(time: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&time.to_rfc3339())
}
fn to_u16<S>(status: &StatusCode, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_u16(status.as_u16())
}
fn to_string<S, T>(field: T, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
T: ToString,
{
serializer.serialize_str(&field.to_string())
}
}
#[derive(Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum Level {
Info,
Warn,
Error,
}
impl From<&StatusCode> for Level {
fn from(status: &StatusCode) -> Self {
let code = status.as_u16();
if code < 400 {
Self::Info
} else if code >= 400 && code < 500 {
Self::Warn
} else {
Self::Error
}
}
}
#[derive(Serialize)]
struct HttpDescriptors {
latency: String,
#[serde(flatten)]
request: RequestDescriptors,
#[serde(flatten)]
response: ResponseDescriptors,
}
#[derive(Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct RequestDescriptors {
#[serde(skip_serializing_if = "Option::is_none")]
referer: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
remote_ip: Option<String>,
#[serde(serialize_with = "Serializers::to_string")]
request_method: Method,
request_url: String,
#[serde(serialize_with = "Serializers::to_rfc3339")]
time: DateTime<Utc>,
}
impl From<&ServiceRequest> for RequestDescriptors {
fn from(request: &ServiceRequest) -> Self {
let request_method = request.method().to_owned();
let request_url = request.path().to_owned();
let headers = request.headers();
let time = Utc::now();
let remote_ip = request.connection_info().remote().map(String::from);
let referer = headers.get(REFERER).and_then(|header| {
if let Ok(valid_header) = header.to_str() {
Some(valid_header.to_string())
} else {
None
}
});
Self {
referer,
remote_ip,
request_method,
request_url,
time,
}
}
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct ResponseDescriptors {
#[serde(skip_serializing)]
time: DateTime<Utc>,
#[serde(serialize_with = "Serializers::to_u16")]
status: StatusCode,
}
impl<B> From<&ServiceResponse<B>> for ResponseDescriptors {
fn from(response: &ServiceResponse<B>) -> Self {
let status = response.status();
let time = Utc::now();
Self { status, time }
}
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Log<'a> {
#[serde(skip_serializing_if = "HashMap::is_empty")]
#[serde(flatten)]
fields: HashMap<&'static str, &'static str>,
http_request: &'a HttpDescriptors,
severity: Level,
#[serde(serialize_with = "Serializers::to_rfc3339")]
time: &'a DateTime<Utc>,
}
#[derive(Clone, Default)]
pub struct RequestLogger {
fields: HashMap<&'static str, &'static str>,
}
impl RequestLogger {
pub fn new() -> Self {
Self::default()
}
pub fn field(mut self, key: &'static str, value: &'static str) -> Self {
self.fields.insert(key, value);
self
}
}
impl<S, B> Transform<S> for RequestLogger
where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>
+ 'static,
B: MessageBody + 'static,
{
type Request = ServiceRequest;
type Response = ServiceResponse<LogMessage<B>>;
type Error = actix_web::Error;
type InitError = ();
type Transform = LoggerMiddleware<S>;
type Future = future::Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
let fields = self.fields.clone();
future::ok(LoggerMiddleware { service, fields })
}
}
pub struct LoggerMiddleware<S> {
fields: HashMap<&'static str, &'static str>,
service: S,
}
impl<S, B> Service for LoggerMiddleware<S>
where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
B: MessageBody,
{
type Request = ServiceRequest;
type Response = ServiceResponse<LogMessage<B>>;
type Error = actix_web::Error;
type Future = LoggerResponse<S, B>;
fn poll_ready(&mut self, context: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(context)
}
fn call(&mut self, request: Self::Request) -> Self::Future {
let request_descriptors = RequestDescriptors::from(&request);
let fields = std::mem::take(&mut self.fields);
LoggerResponse {
fields,
future: self.service.call(request),
request_descriptors,
_t: PhantomData,
}
}
}
#[pin_project::pin_project]
pub struct LoggerResponse<S, B>
where
B: MessageBody,
S: Service,
{
fields: HashMap<&'static str, &'static str>,
#[pin]
future: S::Future,
request_descriptors: RequestDescriptors,
_t: PhantomData<B>,
}
impl<S, B> Future for LoggerResponse<S, B>
where
B: MessageBody,
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
{
type Output = Result<ServiceResponse<LogMessage<B>>, actix_web::Error>;
fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
let projected = self.project();
let response = futures::ready!(projected.future.poll(context));
let request_descriptors = projected.request_descriptors.clone();
let fields = projected.fields.clone();
Poll::Ready(response.map(|response| {
let response_descriptors = ResponseDescriptors::from(&response);
let milliseconds_latency = response_descriptors
.time
.signed_duration_since(request_descriptors.time)
.num_milliseconds();
let seconds_latency = milliseconds_latency / 1000;
let latency = format!("{}s", &seconds_latency);
let http_descriptors = HttpDescriptors {
latency,
response: response_descriptors,
request: request_descriptors,
};
response.map_body(move |_, body| {
ResponseBody::Body(LogMessage {
body,
fields,
http_descriptors,
})
})
}))
}
}
pub struct LogMessage<B> {
body: ResponseBody<B>,
fields: HashMap<&'static str, &'static str>,
http_descriptors: HttpDescriptors,
}
impl<B> Drop for LogMessage<B> {
fn drop(&mut self) {
let severity = Level::from(&self.http_descriptors.response.status);
let time = &self.http_descriptors.response.time;
let fields = std::mem::take(&mut self.fields);
let log = Log {
http_request: &self.http_descriptors,
fields,
severity,
time,
};
if let Ok(message) = serde_json::to_string(&log) {
println!("{}", &message);
}
}
}
impl<B: MessageBody> MessageBody for LogMessage<B> {
fn size(&self) -> BodySize {
self.body.size()
}
fn poll_next(
&mut self,
context: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, actix_web::Error>>> {
match self.body.poll_next(context) {
Poll::Ready(Some(Ok(chunk))) => Poll::Ready(Some(Ok(chunk))),
poll_state => poll_state,
}
}
}