use chrono::{DateTime, Utc};
use hyper::{HeaderMap, Method, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::net::SocketAddr;
use std::sync::Arc;
#[cfg(feature = "database-logging")]
use sqlx::{sqlite::SqlitePool, Pool, Sqlite};
#[derive(Clone)]
pub struct DbLogger {
#[cfg(feature = "database-logging")]
pool: Option<Arc<Pool<Sqlite>>>,
#[cfg(not(feature = "database-logging"))]
_dummy: (), }
#[derive(Debug, Serialize, Deserialize)]
pub struct RequestLog {
timestamp: DateTime<Utc>,
client_ip: String,
method: String,
path: String,
status_code: Option<u16>,
direction: String, headers: Value, body: Value, }
impl DbLogger {
pub async fn new(db_url: Option<String>) -> Self {
#[cfg(feature = "database-logging")]
{
if let Some(url) = db_url {
match SqlitePool::connect(&url).await {
Ok(pool) => {
if let Err(e) = Self::initialize_database(&pool).await {
eprintln!("Failed to initialize database: {e}");
return Self { pool: None };
}
return Self {
pool: Some(Arc::new(pool)),
};
}
Err(e) => {
eprintln!("Failed to connect to database: {e}");
return Self { pool: None };
}
}
}
Self { pool: None }
}
#[cfg(not(feature = "database-logging"))]
{
Self { _dummy: () }
}
}
#[cfg(feature = "database-logging")]
async fn initialize_database(pool: &Pool<Sqlite>) -> Result<(), sqlx::Error> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS request_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
client_ip TEXT NOT NULL,
method TEXT NOT NULL,
path TEXT NOT NULL,
status_code INTEGER,
direction TEXT NOT NULL,
headers TEXT NOT NULL,
body TEXT NOT NULL
)
"#,
)
.execute(pool)
.await?;
Ok(())
}
pub async fn log_request(
&self,
client_ip: &SocketAddr,
method: &Method,
path: &str,
headers: &HeaderMap,
body: &[u8],
) {
#[cfg(feature = "database-logging")]
if let Some(pool) = &self.pool {
let body_json = if body.is_empty() {
serde_json::json!({ "content": "<empty>" })
} else if let Ok(value) = serde_json::from_slice::<serde_json::Value>(body) {
value
} else {
serde_json::json!({ "content": String::from_utf8_lossy(body) })
};
let headers_json = headers
.iter()
.map(|(k, v)| {
let key = k.as_str().to_string();
let value =
serde_json::Value::String(v.to_str().unwrap_or("<binary>").to_string());
(key, value)
})
.collect::<serde_json::Map<String, serde_json::Value>>();
let log = RequestLog {
timestamp: Utc::now(),
client_ip: client_ip.to_string(),
method: method.to_string(),
path: path.to_string(),
status_code: None, direction: "request".to_string(),
headers: serde_json::Value::Object(headers_json),
body: body_json,
};
if let Err(e) = self.save_log(&log, pool).await {
eprintln!("Failed to log request to database: {e}");
}
}
}
pub async fn log_response(
&self,
client_ip: &SocketAddr,
method: &Method,
path: &str,
status: &StatusCode,
headers: &HeaderMap,
body: &[u8],
) {
#[cfg(feature = "database-logging")]
if let Some(pool) = &self.pool {
let body_json = if body.is_empty() {
serde_json::json!({ "content": "<empty>" })
} else if let Ok(value) = serde_json::from_slice::<serde_json::Value>(body) {
value
} else {
serde_json::json!({ "content": String::from_utf8_lossy(body) })
};
let headers_json = headers
.iter()
.map(|(k, v)| {
let key = k.as_str().to_string();
let value =
serde_json::Value::String(v.to_str().unwrap_or("<binary>").to_string());
(key, value)
})
.collect::<serde_json::Map<String, serde_json::Value>>();
let log = RequestLog {
timestamp: Utc::now(),
client_ip: client_ip.to_string(),
method: method.to_string(),
path: path.to_string(),
status_code: Some(status.as_u16()),
direction: "response".to_string(),
headers: serde_json::Value::Object(headers_json),
body: body_json,
};
if let Err(e) = self.save_log(&log, pool).await {
eprintln!("Failed to log response to database: {e}");
}
}
}
#[cfg(feature = "database-logging")]
async fn save_log(&self, log: &RequestLog, pool: &Pool<Sqlite>) -> Result<(), sqlx::Error> {
let headers_str = serde_json::to_string(&log.headers).unwrap_or_default();
let body_str = serde_json::to_string(&log.body).unwrap_or_default();
sqlx::query(
r#"
INSERT INTO request_logs
(timestamp, client_ip, method, path, status_code, direction, headers, body)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(log.timestamp.to_rfc3339())
.bind(&log.client_ip)
.bind(&log.method)
.bind(&log.path)
.bind(log.status_code)
.bind(&log.direction)
.bind(headers_str)
.bind(body_str)
.execute(pool)
.await?;
Ok(())
}
}