use crate::storage::{LogEvent, LogFilter, LogStorage, SortOrder};
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::extract::State;
use axum::response::{IntoResponse, Response};
use axum::Json;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// Shared state handed to every log handler (REST and WebSocket).
#[derive(Clone)]
pub struct LogsState {
    /// Store that the handlers query (`get_filtered`, `get_targets`) and subscribe to.
    pub storage: LogStorage,
}
impl LogsState {
pub fn new(storage: LogStorage) -> Self {
Self { storage }
}
}
/// JSON body accepted by [`get_logs`].
///
/// Every field may be omitted from the incoming JSON: `Option` fields become `None`
/// and the `#[serde(default)]` fields fall back to `0` / an empty map.
#[derive(Debug, Deserialize)]
pub struct LogsRequest {
    /// Maximum number of events to return; forwarded as-is to `LogStorage::get_filtered`.
    pub limit: Option<usize>,
    /// Number of matching events to skip before the page starts; defaults to `0`.
    #[serde(default)]
    pub offset: usize,
    /// Level filter applied across all targets; upper-cased by `get_logs`, and an empty
    /// string is treated as unset.
    pub global_level: Option<String>,
    /// Per-target level filters (presumably keyed by target name — confirm against
    /// `LogFilter`); values are upper-cased by `get_logs`.
    #[serde(default)]
    pub target_levels: HashMap<String, String>,
    /// Search text; an empty string is treated as unset by `get_logs`.
    pub search: Option<String>,
    /// Target to restrict results to; an empty string is treated as unset by `get_logs`.
    pub target: Option<String>,
    /// `"oldest_first"` selects ascending order; any other value (or none) sorts newest first.
    // No `#[serde(default)]` needed: serde already yields `None` for a missing `Option` field.
    pub sort_order: Option<String>,
}
/// JSON body returned by [`get_logs`].
#[derive(Debug, Serialize)]
pub struct LogsResponse {
    /// Events returned by `LogStorage::get_filtered` for the requested page.
    pub logs: Vec<LogEvent>,
    /// Count reported by `get_filtered` alongside the page (presumably the number of
    /// matches before `limit`/`offset` — TODO confirm).
    pub total: usize,
}
/// JSON body returned by [`get_targets`]: `{"targets": [...]}`.
#[derive(Debug, Serialize)]
pub struct TargetsResponse {
    /// Target names as returned by `LogStorage::get_targets`.
    pub targets: Vec<String>,
}
/// Returns one page of stored log events that match the filter described by the JSON body.
///
/// Level names are upper-cased before being handed to [`LogFilter`]. Empty `global_level`,
/// `search` and `target` strings are treated as "not set". `sort_order` accepts
/// `"oldest_first"`; any other value, or none, sorts newest first. `total` in the response is
/// the count returned by `LogStorage::get_filtered` next to the page.
pub async fn get_logs(
    State(state): State<Arc<LogsState>>,
    Json(request): Json<LogsRequest>,
) -> Response {
    let sort_order = match request.sort_order.as_deref() {
        Some("oldest_first") => SortOrder::OldestFirst,
        // Unknown or missing values fall back to the default order instead of erroring.
        _ => SortOrder::NewestFirst,
    };

    let filter = LogFilter {
        // An empty level means "no level filter", matching how `search`/`target` are handled.
        global_level: request
            .global_level
            .filter(|l| !l.is_empty())
            .map(|l| l.to_uppercase()),
        // The request is owned, so consume the map and move the keys instead of cloning them.
        target_levels: request
            .target_levels
            .into_iter()
            .map(|(target, level)| (target, level.to_uppercase()))
            .collect(),
        search: request.search.filter(|s| !s.is_empty()),
        target: request.target.filter(|t| !t.is_empty()),
        sort_order,
    };

    let (logs, total_filtered) =
        state
            .storage
            .get_filtered(&filter, request.limit, Some(request.offset));

    Json(LogsResponse {
        logs,
        total: total_filtered,
    })
    .into_response()
}
/// Upgrades the request to a WebSocket and hands the socket to `handle_ws_connection`.
pub async fn ws_logs(ws: WebSocketUpgrade, State(state): State<Arc<LogsState>>) -> Response {
    // The shared state is moved into the upgrade callback, which outlives this handler.
    ws.on_upgrade(move |socket| handle_ws_connection(socket, state))
}
/// Streams events from `state.storage`'s broadcast subscription to one WebSocket client.
///
/// The loop multiplexes three sources:
/// - new events from the subscription, each sent as a JSON text frame;
/// - inbound client frames: pings are answered, and a close frame, an error or end-of-stream
///   ends the session;
/// - a 30-second timer that sends an empty keepalive ping.
///
/// If this subscriber lags behind the channel, the skipped events are dropped (logged at debug
/// level), not replayed. If the channel closes, a close frame is sent before returning.
async fn handle_ws_connection(mut socket: WebSocket, state: Arc<LogsState>) {
    tracing::debug!("WebSocket connection established");
    let mut rx = state.storage.subscribe();

    let ping_period = std::time::Duration::from_secs(30);
    // `tokio::time::interval` completes its first tick immediately; start one period out so
    // the first keepalive ping is not sent the instant the connection opens.
    let mut ping_interval =
        tokio::time::interval_at(tokio::time::Instant::now() + ping_period, ping_period);
    ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

    loop {
        tokio::select! {
            result = rx.recv() => {
                match result {
                    Ok(log_event) => {
                        let json = match serde_json::to_string(&log_event) {
                            Ok(json) => json,
                            Err(e) => {
                                // Skip the one bad event rather than tearing down the stream.
                                tracing::error!("Failed to serialize log event: {}", e);
                                continue;
                            }
                        };
                        if socket.send(Message::Text(json.into())).await.is_err() {
                            tracing::debug!("WebSocket client disconnected");
                            break;
                        }
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
                        // The receiver has already skipped ahead; keep streaming from here.
                        tracing::debug!("WebSocket receiver lagged, missed {} messages", count);
                        continue;
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                        tracing::warn!("Broadcast channel closed");
                        // Best effort: tell the client the stream is over so it sees a clean
                        // closure instead of an abrupt disconnect. The socket is dropped either way.
                        let _ = socket.send(Message::Close(None)).await;
                        break;
                    }
                }
            }
            msg = socket.recv() => {
                match msg {
                    Some(Ok(Message::Ping(data))) => {
                        if socket.send(Message::Pong(data)).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(Message::Pong(_))) => {
                        // Reply to our keepalive ping; nothing to do.
                    }
                    Some(Ok(Message::Close(_))) => {
                        tracing::debug!("WebSocket client sent close frame");
                        break;
                    }
                    Some(Ok(_)) => {
                        // This endpoint is push-only; other client frames are ignored.
                    }
                    Some(Err(e)) => {
                        tracing::debug!("WebSocket error: {}", e);
                        break;
                    }
                    None => {
                        tracing::debug!("WebSocket connection closed by client");
                        break;
                    }
                }
            }
            _ = ping_interval.tick() => {
                if socket.send(Message::Ping(vec![].into())).await.is_err() {
                    tracing::debug!("Failed to send ping, client disconnected");
                    break;
                }
            }
        }
    }
    tracing::debug!("WebSocket connection closed");
}
/// Returns the target names from `LogStorage::get_targets` as `{"targets": [...]}`.
pub async fn get_targets(State(state): State<Arc<LogsState>>) -> Response {
    Json(TargetsResponse {
        targets: state.storage.get_targets(),
    })
    .into_response()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty JSON object must deserialize: every field is optional or has a serde default.
    #[test]
    fn test_logs_request_default() {
        let request: LogsRequest =
            serde_json::from_str("{}").expect("an empty request body should deserialize");
        assert_eq!(request.limit, None);
        assert_eq!(request.offset, 0);
        assert!(request.global_level.is_none());
        assert!(request.target_levels.is_empty());
        assert!(request.search.is_none());
        assert!(request.target.is_none());
        assert!(request.sort_order.is_none());
    }

    /// Every field round-trips from its JSON name.
    #[test]
    fn test_logs_request_all_fields() {
        let request: LogsRequest = serde_json::from_str(
            r#"{
                "limit": 100,
                "offset": 5,
                "global_level": "info",
                "target_levels": {"app": "debug"},
                "search": "boom",
                "target": "app",
                "sort_order": "oldest_first"
            }"#,
        )
        .expect("a fully populated request body should deserialize");
        assert_eq!(request.limit, Some(100));
        assert_eq!(request.offset, 5);
        assert_eq!(request.global_level.as_deref(), Some("info"));
        assert_eq!(
            request.target_levels.get("app").map(String::as_str),
            Some("debug")
        );
        assert_eq!(request.search.as_deref(), Some("boom"));
        assert_eq!(request.target.as_deref(), Some("app"));
        assert_eq!(request.sort_order.as_deref(), Some("oldest_first"));
    }
}