use crate::error::Result;
use bytes::Bytes;
use slinger::{Body, Request, Response};
use std::fmt;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::time::timeout;
use uuid::Uuid;
fn generate_session_id() -> u128 {
Uuid::new_v4().as_u128()
}
/// A request observed by the MITM layer, wrapping a `slinger::Request` together
/// with bookkeeping metadata.
///
/// Both HTTP traffic and raw TCP payloads are represented. For raw TCP records
/// (`is_http == false`), the payload is stored as the body of an otherwise
/// default `Request`.
#[derive(Clone)]
pub struct MitmRequest {
// Identifier assigned at construction via `generate_session_id` (a UUID v4 as u128);
// can be overridden with `set_session_id`.
session_id: u128,
// Source socket address; `Some` only for the `*_with_source` / `with_source` constructors.
pub source: Option<SocketAddr>,
// Destination string exactly as supplied by the caller (format not validated here).
pub destination: String,
// Creation time in milliseconds since the Unix epoch (0 if the clock reads before the epoch).
pub timestamp: u64,
// `true` for HTTP records, `false` for records built by the `raw_tcp*` constructors.
is_http: bool,
// The wrapped request; for raw TCP this is a default `Request` carrying only a body.
pub request: Request,
}
impl MitmRequest {
    /// Creates an HTTP request record with no known source address.
    ///
    /// A fresh session id is generated and the timestamp is taken now.
    pub fn new(destination: impl Into<String>, request: Request) -> Self {
        Self::assemble(None, destination.into(), true, request)
    }

    /// Creates an HTTP request record that also remembers the `source` address.
    pub fn with_source(
        source: SocketAddr,
        destination: impl Into<String>,
        request: Request,
    ) -> Self {
        Self::assemble(Some(source), destination.into(), true, request)
    }

    /// Creates a raw (non-HTTP) TCP record with no known source address.
    ///
    /// `body` is stored as the body of an otherwise default `Request`.
    pub fn raw_tcp(destination: impl Into<String>, body: impl Into<Bytes>) -> Self {
        Self::assemble(
            None,
            destination.into(),
            false,
            Self::raw_body_request(body.into()),
        )
    }

    /// Creates a raw (non-HTTP) TCP record that also remembers the `source` address.
    pub fn raw_tcp_with_source(
        source: SocketAddr,
        destination: impl Into<String>,
        body: impl Into<Bytes>,
    ) -> Self {
        Self::assemble(
            Some(source),
            destination.into(),
            false,
            Self::raw_body_request(body.into()),
        )
    }

    /// Returns the session identifier.
    pub fn session_id(&self) -> u128 {
        self.session_id
    }

    /// Overrides the session identifier (presumably to correlate this record
    /// with another one — confirm against callers).
    pub fn set_session_id(&mut self, session_id: u128) {
        self.session_id = session_id;
    }

    /// Returns the source address, if one was supplied at construction.
    pub fn source(&self) -> Option<SocketAddr> {
        self.source
    }

    /// Returns the destination string supplied at construction.
    pub fn destination(&self) -> &str {
        &self.destination
    }

    /// Returns the creation time in milliseconds since the Unix epoch.
    pub fn timestamp(&self) -> u64 {
        self.timestamp
    }

    /// Borrows the wrapped request.
    pub fn request(&self) -> &Request {
        &self.request
    }

    /// Mutably borrows the wrapped request.
    pub fn request_mut(&mut self) -> &mut Request {
        &mut self.request
    }

    /// Returns the request body, if any.
    pub fn body(&self) -> Option<&Body> {
        self.request.body.as_ref()
    }

    /// Replaces the request body with `body`.
    pub fn set_body(&mut self, body: impl Into<Bytes>) {
        self.request.body = Some(Body::from(body.into()));
    }

    /// Returns `false` for records created by the `raw_tcp*` constructors.
    pub fn is_http(&self) -> bool {
        self.is_http
    }

    /// Single construction path shared by every public constructor, so session-id
    /// and timestamp initialisation live in exactly one place.
    fn assemble(
        source: Option<SocketAddr>,
        destination: String,
        is_http: bool,
        request: Request,
    ) -> Self {
        Self {
            session_id: generate_session_id(),
            source,
            destination,
            timestamp: Self::now_unix_millis(),
            is_http,
            request,
        }
    }

    /// Builds the carrier `Request` used for raw TCP payloads: default fields plus `body`.
    fn raw_body_request(body: Bytes) -> Request {
        Request {
            body: Some(Body::from(body)),
            ..Default::default()
        }
    }

    /// Current wall-clock time in milliseconds since the Unix epoch; 0 if the
    /// system clock reads earlier than the epoch.
    fn now_unix_millis() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            // u128 -> u64 cannot truncate in practice: u64 milliseconds span ~584 million years.
            .map(|d| d.as_millis() as u64)
            .unwrap_or(0)
    }
}
impl fmt::Debug for MitmRequest {
    /// Formats every field of the record. The destructuring is exhaustive, so
    /// adding a field to `MitmRequest` forces this impl to be revisited.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            session_id,
            source,
            destination,
            timestamp,
            is_http,
            request,
        } = self;
        let mut out = f.debug_struct("MitmRequest");
        out.field("session_id", session_id);
        out.field("source", source);
        out.field("destination", destination);
        out.field("timestamp", timestamp);
        out.field("is_http", is_http);
        out.field("request", request);
        out.finish()
    }
}
/// A response observed by the MITM layer, wrapping a `slinger::Response` together
/// with bookkeeping metadata.
///
/// Both HTTP traffic and raw TCP payloads are represented. For raw TCP records
/// (`is_http == false`), the payload is stored as the body of an otherwise
/// default `Response`.
#[derive(Clone)]
pub struct MitmResponse {
// Supplied by the caller at construction (presumably the id of the matching
// `MitmRequest` — TODO confirm against callers); not generated here.
session_id: u128,
// Source string exactly as supplied by the caller.
pub source: String,
// Destination socket address; `Some` only for the `*_with_destination` constructors.
pub destination: Option<SocketAddr>,
// Creation time in milliseconds since the Unix epoch (0 if the clock reads before the epoch).
pub timestamp: u64,
// `true` for HTTP records, `false` for records built by the `raw_tcp*` constructors.
is_http: bool,
// The wrapped response; for raw TCP this is a default `Response` carrying only a body.
pub response: Response,
}
impl MitmResponse {
    /// Creates an HTTP response record for `session_id` with no known destination address.
    pub fn new(session_id: u128, source: impl Into<String>, response: Response) -> Self {
        Self::assemble(session_id, source.into(), None, true, response)
    }

    /// Creates an HTTP response record that also remembers the `destination` address.
    pub fn with_destination(
        session_id: u128,
        source: impl Into<String>,
        destination: SocketAddr,
        response: Response,
    ) -> Self {
        Self::assemble(session_id, source.into(), Some(destination), true, response)
    }

    /// Creates a raw (non-HTTP) TCP record with no known destination address.
    ///
    /// `body` is stored as the body of an otherwise default `Response`.
    pub fn raw_tcp(session_id: u128, source: impl Into<String>, body: impl Into<Bytes>) -> Self {
        Self::assemble(
            session_id,
            source.into(),
            None,
            false,
            Self::raw_body_response(body.into()),
        )
    }

    /// Creates a raw (non-HTTP) TCP record that also remembers the `destination` address.
    pub fn raw_tcp_with_destination(
        session_id: u128,
        source: impl Into<String>,
        destination: SocketAddr,
        body: impl Into<Bytes>,
    ) -> Self {
        Self::assemble(
            session_id,
            source.into(),
            Some(destination),
            false,
            Self::raw_body_response(body.into()),
        )
    }

    /// Returns the session identifier supplied at construction.
    pub fn session_id(&self) -> u128 {
        self.session_id
    }

    /// Returns the source string supplied at construction.
    pub fn source(&self) -> &str {
        &self.source
    }

    /// Returns the destination address, if one was supplied at construction.
    pub fn destination(&self) -> Option<SocketAddr> {
        self.destination
    }

    /// Returns the creation time in milliseconds since the Unix epoch.
    pub fn timestamp(&self) -> u64 {
        self.timestamp
    }

    /// Borrows the wrapped response.
    pub fn response(&self) -> &Response {
        &self.response
    }

    /// Mutably borrows the wrapped response.
    pub fn response_mut(&mut self) -> &mut Response {
        &mut self.response
    }

    /// Returns the response body, if any.
    pub fn body(&self) -> Option<&Body> {
        self.response.body.as_ref()
    }

    /// Replaces the response body with `body`.
    pub fn set_body(&mut self, body: impl Into<Bytes>) {
        self.response.body = Some(Body::from(body.into()));
    }

    /// Returns `false` for records created by the `raw_tcp*` constructors.
    pub fn is_http(&self) -> bool {
        self.is_http
    }

    /// Single construction path shared by every public constructor, so timestamp
    /// initialisation lives in exactly one place.
    fn assemble(
        session_id: u128,
        source: String,
        destination: Option<SocketAddr>,
        is_http: bool,
        response: Response,
    ) -> Self {
        Self {
            session_id,
            source,
            destination,
            timestamp: Self::now_unix_millis(),
            is_http,
            response,
        }
    }

    /// Builds the carrier `Response` used for raw TCP payloads: default fields plus `body`.
    fn raw_body_response(body: Bytes) -> Response {
        Response {
            body: Some(Body::from(body)),
            ..Default::default()
        }
    }

    /// Current wall-clock time in milliseconds since the Unix epoch; 0 if the
    /// system clock reads earlier than the epoch.
    fn now_unix_millis() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            // u128 -> u64 cannot truncate in practice: u64 milliseconds span ~584 million years.
            .map(|d| d.as_millis() as u64)
            .unwrap_or(0)
    }
}
impl fmt::Debug for MitmResponse {
    /// Formats every field of the record. The destructuring is exhaustive, so
    /// adding a field to `MitmResponse` forces this impl to be revisited.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            session_id,
            source,
            destination,
            timestamp,
            is_http,
            response,
        } = self;
        let mut out = f.debug_struct("MitmResponse");
        out.field("session_id", session_id);
        out.field("source", source);
        out.field("destination", destination);
        out.field("timestamp", timestamp);
        out.field("is_http", is_http);
        out.field("response", response);
        out.finish()
    }
}
/// Hook for inspecting or rewriting traffic passing through the MITM layer.
///
/// Both methods default to passing their argument through unchanged, so an
/// implementor only needs to override the direction it cares about.
/// Implementations must be `Send + Sync` so they can be shared as
/// `Arc<dyn Interceptor>`.
///
/// Return-value contract, when run through `InterceptorHandler`:
/// - `Ok(Some(x))` passes `x` to the next interceptor.
/// - `Ok(None)` stops the chain, and the handler returns `Ok(None)`.
/// - `Err(e)` stops the chain, and the handler returns `Err(e)`.
///
/// What `None` means to the proxy (for example, dropping the message) is decided
/// by the handler's caller and is not visible in this file.
#[async_trait::async_trait]
pub trait Interceptor: Send + Sync {
/// Inspects or rewrites an outgoing request. The default returns it unchanged.
async fn intercept_request(&self, request: MitmRequest) -> Result<Option<MitmRequest>> {
Ok(Some(request))
}
/// Inspects or rewrites an incoming response. The default returns it unchanged.
async fn intercept_response(&self, response: MitmResponse) -> Result<Option<MitmResponse>> {
Ok(Some(response))
}
}
/// Ordered chain of interceptors, each call bounded by a per-interceptor timeout.
pub struct InterceptorHandler {
// Interceptors in insertion order; they run in this order.
interceptors: Vec<Arc<dyn Interceptor>>,
// Per-interceptor timeout in seconds (60 when built via `new`); applies to each call separately.
timeout_secs: u64,
}
impl InterceptorHandler {
    /// Creates an empty chain with a 60-second per-interceptor timeout.
    pub fn new() -> Self {
        Self {
            interceptors: Vec::new(),
            timeout_secs: 60,
        }
    }

    /// Sets the per-interceptor timeout in seconds and returns the updated handler.
    #[must_use]
    pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
        self.timeout_secs = timeout_secs;
        self
    }

    /// Appends `interceptor` to the end of the chain.
    pub fn add_interceptor(&mut self, interceptor: Arc<dyn Interceptor>) {
        self.interceptors.push(interceptor);
    }

    /// Runs `request` through every interceptor in order.
    ///
    /// Returns `Ok(None)` as soon as an interceptor returns `Ok(None)`, and
    /// propagates the first error. If an interceptor exceeds the timeout, a
    /// warning is logged and the chain continues with the request as it was
    /// before that interceptor.
    pub async fn process_request(&self, request: MitmRequest) -> Result<Option<MitmRequest>> {
        self.run_chain(request, |interceptor, req| async move {
            interceptor.intercept_request(req).await
        })
        .await
    }

    /// Runs `response` through every interceptor in order.
    ///
    /// This has the same short-circuit, error and timeout semantics as
    /// [`InterceptorHandler::process_request`].
    pub async fn process_response(
        &self,
        response: MitmResponse,
    ) -> Result<Option<MitmResponse>> {
        self.run_chain(response, |interceptor, resp| async move {
            interceptor.intercept_response(resp).await
        })
        .await
    }

    /// Returns `true` if at least one interceptor is registered.
    pub fn has_interceptors(&self) -> bool {
        !self.interceptors.is_empty()
    }

    /// Shared driver for both directions: applies `call` with each interceptor
    /// under the configured timeout and threads the value through the chain.
    async fn run_chain<T, F, Fut>(&self, mut value: T, call: F) -> Result<Option<T>>
    where
        T: Clone,
        F: Fn(Arc<dyn Interceptor>, T) -> Fut,
        Fut: std::future::Future<Output = Result<Option<T>>>,
    {
        let limit = std::time::Duration::from_secs(self.timeout_secs);
        for interceptor in &self.interceptors {
            // The interceptor consumes a clone so `value` survives if the call times out.
            match timeout(limit, call(Arc::clone(interceptor), value.clone())).await {
                Ok(Ok(Some(modified))) => value = modified,
                Ok(Ok(None)) => return Ok(None),
                Ok(Err(e)) => return Err(e),
                Err(_) => {
                    tracing::warn!(
                        "Interceptor timed out after {}s; skipping",
                        self.timeout_secs
                    );
                }
            }
        }
        Ok(Some(value))
    }
}
impl Default for InterceptorHandler {
fn default() -> Self {
Self::new()
}
}
pub struct InterceptorFactory;
impl InterceptorFactory {
pub fn logging() -> LoggingInterceptor {
LoggingInterceptor
}
}
/// Pass-through interceptor that reports every request and response through
/// `tracing` at INFO level and never modifies or drops anything.
pub struct LoggingInterceptor;

#[async_trait::async_trait]
impl Interceptor for LoggingInterceptor {
    /// Logs one summary line for the request, then the following:
    /// - for HTTP, a line per header;
    /// - the source address, if known;
    /// - the capture timestamp.
    ///
    /// The summary shows method and URI for HTTP, or destination and body size
    /// for raw TCP. The request is returned unchanged.
    async fn intercept_request(&self, request: MitmRequest) -> Result<Option<MitmRequest>> {
        let sid = request.session_id();
        if !request.is_http() {
            let byte_count = request.body().map_or(0, |b| b.len());
            tracing::info!(
                "[MITM] TCP Request (session_id={}) to {}: {} bytes",
                sid,
                request.destination(),
                byte_count
            );
        } else {
            let http = request.request();
            tracing::info!(
                "[MITM] HTTP Request (session_id={}): {} {}",
                sid,
                http.method(),
                http.uri()
            );
            for (header, val) in http.headers() {
                tracing::info!(" {}: {:?}", header, val);
            }
        }
        if let Some(addr) = request.source() {
            tracing::info!(" From: {}", addr);
        }
        tracing::info!(" Timestamp: {}", request.timestamp());
        Ok(Some(request))
    }

    /// Logs one summary line for the response, then the following:
    /// - for HTTP, a line per header;
    /// - the destination address, if known;
    /// - the capture timestamp.
    ///
    /// The summary shows the status code for HTTP, or source and body size for
    /// raw TCP. The response is returned unchanged.
    async fn intercept_response(&self, response: MitmResponse) -> Result<Option<MitmResponse>> {
        let sid = response.session_id();
        if !response.is_http() {
            let byte_count = response.body().map_or(0, |b| b.len());
            tracing::info!(
                "[MITM] TCP Response (session_id={}) from {}: {} bytes",
                sid,
                response.source(),
                byte_count
            );
        } else {
            let http = response.response();
            tracing::info!(
                "[MITM] HTTP Response (session_id={}): {}",
                sid,
                http.status_code()
            );
            for (header, val) in http.headers() {
                tracing::info!(" {}: {:?}", header, val);
            }
        }
        if let Some(addr) = response.destination() {
            tracing::info!(" To: {}", addr);
        }
        tracing::info!(" Timestamp: {}", response.timestamp());
        Ok(Some(response))
    }
}