use crate::protocol::{RequestId, Response};
use crate::{protocol::Message, Error, Result};
use async_trait::async_trait;
use axum::{
extract::State,
http::StatusCode,
middleware::{self, Next},
response::{
sse::{Event, Sse},
IntoResponse,
},
routing::{get, post},
Json, Router,
};
use futures::{
channel::mpsc,
stream::{Stream, StreamExt},
};
use serde_json::json;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use std::{convert::Infallible, net::SocketAddr, sync::Arc};
use tokio::sync::Mutex;
/// Identifier handed to each SSE client, drawn from `AxumHttpServer::next_client_id`.
type ClientId = u64;
/// Sending half of the channel that feeds one client's SSE stream.
type MessageSender = mpsc::UnboundedSender<Message>;
/// Registry entry describing one connected SSE client.
#[derive(Clone)]
struct ClientInfo {
    /// Queue whose messages are forwarded to the client as `message` events.
    sender: MessageSender,
    /// Id of the last request this client posted; responses passed to `send` are
    /// routed back to the client whose last request id matches.
    last_request_id: Option<RequestId>,
    /// Set at registration and refreshed whenever the client POSTs with its
    /// `X-Client-ID` header, so despite its name it acts as a last-activity timestamp
    /// for the idle sweep.
    connected_at: std::time::Instant,
}
/// Settings used to construct an `AxumHttpServer`.
#[derive(Clone)]
pub struct HttpServerConfig {
    /// Address the HTTP listener binds to; also used to build the `/messages` URL that is
    /// advertised to SSE clients.
    pub addr: SocketAddr,
    /// When `Some`, every request must carry `Authorization: Bearer <token>` with exactly
    /// this token; when `None`, requests are not authenticated.
    pub auth_token: Option<String>,
}
/// HTTP transport server: clients open a Server-Sent Events stream on `GET /events` and
/// post JSON messages to `POST /messages`.
///
/// Clones share the client registry and the id counter (both sit behind `Arc`), so every
/// clone observes and mutates the same set of connected clients.
#[derive(Clone)]
pub struct AxumHttpServer {
    /// Listen address and optional bearer token.
    config: HttpServerConfig,
    /// Connected SSE clients, keyed by the id announced in their `endpoint` event.
    clients: Arc<Mutex<HashMap<ClientId, ClientInfo>>>,
    /// Monotonic source of client ids; shared so clones never hand out duplicates.
    next_client_id: Arc<AtomicU64>,
}
impl AxumHttpServer {
pub fn new(config: HttpServerConfig) -> Self {
Self {
config,
clients: Arc::new(Mutex::new(HashMap::new())),
next_client_id: Arc::new(AtomicU64::new(1)),
}
}
fn validate_auth_token(
headers: &axum::http::HeaderMap,
auth_token: &Option<String>,
) -> Result<()> {
if let Some(expected_token) = auth_token {
match headers.get("Authorization") {
Some(auth_header) => {
let auth_str = auth_header
.to_str()
.map_err(|_| Error::Transport("Invalid authorization header".into()))?;
if !auth_str.starts_with("Bearer ") {
return Err(Error::Transport("Invalid authorization format".into()));
}
let token = &auth_str["Bearer ".len()..];
if token != expected_token {
return Err(Error::Transport("Invalid token".into()));
}
}
None => return Err(Error::Transport("Missing authorization header".into())),
}
}
Ok(())
}
async fn auth_middleware(
State(auth_token): State<Option<String>>,
headers: axum::http::HeaderMap,
request: axum::http::Request<axum::body::Body>,
next: Next,
) -> impl IntoResponse {
match Self::validate_auth_token(&headers, &auth_token) {
Ok(_) => Ok(next.run(request).await),
Err(_) => Err(StatusCode::UNAUTHORIZED),
}
}
fn create_router(state: Arc<Self>) -> Router {
let auth_token = state.config.auth_token.clone();
Router::new()
.route("/events", get(Self::sse_handler))
.route("/messages", post(Self::message_handler))
.layer(middleware::from_fn_with_state(
auth_token.clone(),
Self::auth_middleware,
))
.with_state(state)
}
async fn cleanup_inactive_clients(&self) {
let now = std::time::Instant::now();
let timeout = std::time::Duration::from_secs(300);
let mut clients = self.clients.lock().await;
clients.retain(|_, info| {
let is_active = now.duration_since(info.connected_at) < timeout;
is_active
});
}
async fn sse_handler(
State(state): State<Arc<Self>>,
) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
let (tx, rx) = mpsc::unbounded();
let client_id = state.next_client_id.fetch_add(1, Ordering::SeqCst);
let client_info = ClientInfo {
sender: tx,
last_request_id: None,
connected_at: std::time::Instant::now(),
};
state.clients.lock().await.insert(client_id, client_info);
let state_clone = state.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
loop {
interval.tick().await;
state_clone.cleanup_inactive_clients().await;
}
});
let clients = state.clients.clone();
let stream = async_stream::stream! {
let endpoint = format!("http://{}/messages", state.config.addr);
yield Ok(Event::default()
.event("endpoint")
.data(format!("{{\"endpoint\":\"{}\",\"clientId\":\"{}\"}}", endpoint, client_id)));
let mut rx = rx;
while let Some(msg) = rx.next().await {
if let Ok(json) = serde_json::to_string(&msg) {
yield Ok(Event::default()
.event("message")
.data(json));
}
}
clients.lock().await.remove(&client_id);
};
Sse::new(stream).keep_alive(
axum::response::sse::KeepAlive::new()
.interval(Duration::from_secs(1))
.text("ping"),
)
}
async fn find_client_by_request_id(&self, request_id: &RequestId) -> Option<ClientId> {
let clients = self.clients.lock().await;
for (client_id, info) in clients.iter() {
if let Some(last_request_id) = &info.last_request_id {
if last_request_id == request_id {
return Some(*client_id);
}
}
}
None
}
async fn message_handler(
State(state): State<Arc<Self>>,
headers: axum::http::HeaderMap,
Json(message): Json<Message>,
) -> impl IntoResponse {
let client_id = headers
.get("X-Client-ID")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok());
if let Some(client_id) = client_id {
if let Some(client_info) = state.clients.lock().await.get_mut(&client_id) {
client_info.connected_at = std::time::Instant::now();
}
}
match &message {
Message::Request(request) => {
if let Some(client_id) = client_id {
if let Some(client_info) = state.clients.lock().await.get_mut(&client_id) {
client_info.last_request_id = Some(request.id.clone());
}
let response = match request.method.as_str() {
"ping" => {
Response::success(json!({}), request.id.clone())
}
"shutdown" => {
Response::success(json!(null), request.id.clone())
}
_ => {
Response::error(
crate::protocol::ResponseError {
code: crate::error_codes::METHOD_NOT_FOUND,
message: "Method not found".to_string(),
data: None,
},
request.id.clone(),
)
}
};
if let Some(client_info) = state.clients.lock().await.get(&client_id) {
let _ = client_info
.sender
.unbounded_send(Message::Response(response));
}
}
}
Message::Notification(notification) => {
if notification.method.as_str() == "exit" {
state.clients.lock().await.clear();
}
}
_ => {
}
}
(axum::http::StatusCode::OK, "Message sent").into_response()
}
async fn send_to_client(&self, client_id: ClientId, message: Message) -> Result<()> {
if let Some(client_info) = self.clients.lock().await.get(&client_id) {
client_info
.sender
.unbounded_send(message)
.map_err(|e| crate::Error::Transport(e.to_string()))?;
}
Ok(())
}
}
#[async_trait]
impl super::HttpTransport for AxumHttpServer {
async fn initialize(&mut self) -> Result<()> {
let app = Self::create_router(Arc::new(self.clone()));
let addr = self.config.addr;
tokio::spawn(async move {
axum::serve(tokio::net::TcpListener::bind(addr).await.unwrap(), app)
.await
.unwrap();
});
Ok(())
}
async fn send(&self, message: Message) -> Result<()> {
match &message {
Message::Response(response) => {
if let Some(client_id) = self.find_client_by_request_id(&response.id).await {
self.send_to_client(client_id, message).await?;
}
}
Message::Notification(_) => {
let clients = self.clients.lock().await;
for (client_id, _) in clients.iter() {
self.send_to_client(*client_id, message.clone()).await?;
}
}
_ => {
}
}
Ok(())
}
async fn receive(&self) -> Result<Message> {
Err(crate::Error::Transport(
"Server does not support direct message receiving. Use HTTP POST endpoint instead."
.into(),
))
}
async fn close(&mut self) -> Result<()> {
self.clients.lock().await.clear();
Ok(())
}
}
/// The HTTP transport server implementation this module provides by default.
pub type DefaultHttpServer = AxumHttpServer;