use std::collections::HashMap;
use std::sync::Arc;
use std::task::{Context, Poll};
use actix_service::{Service, Transform};
use actix_web::dev::{ServiceRequest, ServiceResponse};
use actix_web::{Error, HttpMessage};
use futures_util::future::{ok, LocalBoxFuture, Ready};
use crate::env::{get_env_options, EnvError};
use crate::types::{Breadcrumb, Level, RequestContext};
use crate::{add_breadcrumb, get_client, init};
use super::common::{build_url, extract_client_ip, filter_headers};
#[derive(Clone, Default)]
pub struct BugwatchActix {
capture_server_errors: bool,
add_breadcrumbs: bool,
}
impl BugwatchActix {
    /// Creates a middleware factory with server-error capture and breadcrumbs both enabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            capture_server_errors: true,
            add_breadcrumbs: true,
        }
    }

    /// Creates a middleware factory, initialising the Bugwatch client from the environment.
    ///
    /// The environment is read only when `get_client()` reports that no client exists yet.
    /// If a client is already present, the environment is not consulted and this always
    /// succeeds.
    ///
    /// # Errors
    ///
    /// Returns the [`EnvError`] produced by `get_env_options` when the environment cannot
    /// supply valid options, for example when `BUGWATCH_API_KEY` is unset.
    pub fn from_env() -> Result<Self, EnvError> {
        if get_client().is_none() {
            let options = get_env_options()?;
            init(options);
        }
        Ok(Self::new())
    }

    /// Enables or disables capturing 5xx responses as error events.
    ///
    /// Consumes and returns `self`, so the result must be used.
    #[must_use]
    pub fn capture_server_errors(mut self, capture: bool) -> Self {
        self.capture_server_errors = capture;
        self
    }

    /// Enables or disables recording an `http` breadcrumb for every completed request.
    ///
    /// Consumes and returns `self`, so the result must be used.
    #[must_use]
    pub fn add_breadcrumbs(mut self, add: bool) -> Self {
        self.add_breadcrumbs = add;
        self
    }
}
impl<S, B> Transform<S, ServiceRequest> for BugwatchActix
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type Transform = BugwatchActixMiddleware<S>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(BugwatchActixMiddleware {
service: Arc::new(service),
capture_server_errors: self.capture_server_errors,
add_breadcrumbs: self.add_breadcrumbs,
})
}
}
/// Per-service middleware produced by the `Transform` implementation of `BugwatchActix`.
///
/// The wrapped service sits behind an [`Arc`] so that each `'static` request future
/// created in `call` can hold its own handle to it.
pub struct BugwatchActixMiddleware<S> {
    /// Whether 5xx outcomes are captured as error events.
    capture_server_errors: bool,
    /// Whether each completed request is recorded as an `http` breadcrumb.
    add_breadcrumbs: bool,
    /// The next service in the chain.
    service: Arc<S>,
}
impl<S> Clone for BugwatchActixMiddleware<S> {
fn clone(&self) -> Self {
Self {
service: self.service.clone(),
capture_server_errors: self.capture_server_errors,
add_breadcrumbs: self.add_breadcrumbs,
}
}
}
impl<S, B> Service<ServiceRequest> for BugwatchActixMiddleware<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&self, req: ServiceRequest) -> Self::Future {
let service = self.service.clone();
let capture_server_errors = self.capture_server_errors;
let add_breadcrumbs = self.add_breadcrumbs;
let request_context = extract_request_context(&req);
let method = req.method().to_string();
let path = req.path().to_string();
req.extensions_mut().insert(BugwatchRequestContext(request_context.clone()));
Box::pin(async move {
let response = service.call(req).await?;
let status = response.status();
if add_breadcrumbs {
let mut data = HashMap::new();
data.insert("status_code".to_string(), serde_json::json!(status.as_u16()));
if let Some(ref url) = request_context.url {
data.insert("url".to_string(), serde_json::json!(url));
}
add_breadcrumb(
Breadcrumb::new("http", format!("{} {} -> {}", method, path, status.as_u16()))
.with_level(if status.is_server_error() {
Level::Error
} else if status.is_client_error() {
Level::Warning
} else {
Level::Info
})
.with_data(data),
);
}
if capture_server_errors && status.is_server_error() {
let message = format!(
"HTTP {} {} returned {}",
method,
path,
status.as_u16()
);
if let Some(client) = get_client() {
let mut tags = HashMap::new();
tags.insert("http.method".to_string(), method);
tags.insert("http.status_code".to_string(), status.as_u16().to_string());
if let Some(ref url) = request_context.url {
tags.insert("http.url".to_string(), url.clone());
}
let mut extra = HashMap::new();
extra.insert("request".to_string(), serde_json::to_value(&request_context).unwrap_or_default());
client.capture_message_with_options(&message, Level::Error, Some(tags), Some(extra));
}
}
Ok(response)
})
}
}
/// Request extension holding the request snapshot built by the Bugwatch middleware.
///
/// `BugwatchActixMiddleware::call` inserts one of these into the request extensions
/// before forwarding. Downstream handlers can fetch it with
/// `extensions().get::<BugwatchRequestContext>()`.
#[derive(Clone)]
pub struct BugwatchRequestContext(pub RequestContext);
/// Builds a `RequestContext` snapshot for a request seen by the middleware.
///
/// Delegates to `request_context_from_http`, so the middleware and
/// `capture_actix_error` always describe a request the same way.
fn extract_request_context(req: &ServiceRequest) -> RequestContext {
    request_context_from_http(req.request())
}

/// Builds a `RequestContext` from an [`actix_web::HttpRequest`].
///
/// How each field is filled:
/// * Header values that `HeaderValue::to_str` rejects are skipped.
/// * The stored headers are the result of `filter_headers`.
/// * The client IP is derived by `extract_client_ip` from the unfiltered headers.
/// * The URL is assembled by `build_url` from the connection scheme and host, the path,
///   and the query string when it is non-empty.
fn request_context_from_http(req: &actix_web::HttpRequest) -> RequestContext {
    let headers: HashMap<String, String> = req
        .headers()
        .iter()
        .filter_map(|(name, value)| {
            value.to_str().ok().map(|v| (name.to_string(), v.to_string()))
        })
        .collect();
    let filtered_headers = filter_headers(&headers);
    let client_ip = extract_client_ip(&headers);

    // An empty query string is represented as "no query".
    let raw_query = req.query_string();
    let query = if raw_query.is_empty() {
        None
    } else {
        Some(raw_query)
    };

    // `connection_info()` holds a borrow of the request extensions; it is released when
    // this function returns.
    let conn_info = req.connection_info();
    let url = build_url(conn_info.scheme(), conn_info.host(), req.path(), query);

    RequestContext {
        url: Some(url),
        method: Some(req.method().to_string()),
        headers: Some(filtered_headers),
        query_string: query.map(String::from),
        client_ip: client_ip.map(|ip| ip.to_string()),
        ..Default::default()
    }
}

/// Captures `error` for the given request, with request details attached.
///
/// The event is sent at `Error` level. It is tagged with the HTTP method and URL, and
/// carries the serialized request context under the `request` extra key.
///
/// Returns whatever `capture_error_with_options` returns (presumably the event id —
/// confirm against the client API). Returns an empty string when `get_client()` yields
/// no client.
pub fn capture_actix_error<E: std::error::Error>(
    req: &actix_web::HttpRequest,
    error: &E,
) -> String {
    if let Some(client) = get_client() {
        let request_context = request_context_from_http(req);

        let mut tags = HashMap::new();
        tags.insert("http.method".to_string(), req.method().to_string());
        if let Some(ref url) = request_context.url {
            tags.insert("http.url".to_string(), url.clone());
        }

        let mut extra = HashMap::new();
        extra.insert(
            "request".to_string(),
            serde_json::to_value(&request_context).unwrap_or_default(),
        );
        client.capture_error_with_options(error, Level::Error, Some(tags), Some(extra))
    } else {
        String::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bugwatch_actix_builder() {
        let middleware = BugwatchActix::new()
            .capture_server_errors(false)
            .add_breadcrumbs(false);
        assert!(!middleware.capture_server_errors);
        assert!(!middleware.add_breadcrumbs);
    }

    #[test]
    fn test_bugwatch_actix_new_enables_everything() {
        let middleware = BugwatchActix::new();
        assert!(middleware.capture_server_errors);
        assert!(middleware.add_breadcrumbs);
    }

    /// Expects `from_env` to fail when `BUGWATCH_API_KEY` is absent.
    ///
    /// NOTE(review): this only holds while no client has been initialised. If another test
    /// in the same process initialises one first, `from_env` returns `Ok`. Env mutation is
    /// also process-wide, so tests reading this variable concurrently may observe it unset.
    #[test]
    fn test_from_env_missing_key() {
        // `var_os` preserves the original value even when it is not valid UTF-8.
        let original = std::env::var_os("BUGWATCH_API_KEY");
        std::env::remove_var("BUGWATCH_API_KEY");
        let result = BugwatchActix::from_env();
        // Restore before asserting, so a failing assertion does not leak the change.
        if let Some(val) = original {
            std::env::set_var("BUGWATCH_API_KEY", val);
        }
        assert!(result.is_err());
    }
}