mcprotocol_rs/transport/http/
server.rs

1use crate::{protocol::Message, Result};
2use async_trait::async_trait;
3use axum::{
4    extract::State,
5    response::{
6        sse::{Event, Sse},
7        IntoResponse,
8    },
9    routing::{get, post},
10    Json, Router,
11};
12use futures::stream::Stream;
13use std::{convert::Infallible, net::SocketAddr, sync::Arc};
14use tokio::sync::broadcast;
15use tokio_stream::StreamExt;
16
/// Settings used to construct an HTTP transport server.
#[derive(Clone)]
pub struct HttpServerConfig {
    /// Socket address the server's TCP listener binds to.
    pub addr: SocketAddr,
    /// Optional authentication token.
    ///
    /// NOTE(review): nothing in this file reads this field, so requests are
    /// not checked against it here — confirm where (or whether) it is enforced.
    pub auth_token: Option<String>,
}
23
24/// Axum HTTP server implementation
25pub struct AxumHttpServer {
26    config: HttpServerConfig,
27    tx: broadcast::Sender<Message>,
28}
29
30impl Clone for AxumHttpServer {
31    fn clone(&self) -> Self {
32        Self {
33            config: self.config.clone(),
34            tx: self.tx.clone(),
35        }
36    }
37}
38
39impl AxumHttpServer {
40    /// Create a new Axum HTTP server
41    pub fn new(config: HttpServerConfig) -> Self {
42        let (tx, _) = broadcast::channel(32);
43        Self { config, tx }
44    }
45
46    /// Create Axum router
47    fn create_router(state: Arc<Self>) -> Router {
48        Router::new()
49            .route("/events", get(Self::sse_handler))
50            .route("/messages", post(Self::message_handler))
51            .with_state(state)
52    }
53
54    /// SSE event handler
55    async fn sse_handler(
56        State(state): State<Arc<Self>>,
57    ) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
58        let mut rx = state.tx.subscribe();
59        let stream = async_stream::stream! {
60            while let Ok(msg) = rx.recv().await {
61                if let Ok(json) = serde_json::to_string(&msg) {
62                    yield Ok(Event::default().data(json));
63                }
64            }
65        };
66
67        Sse::new(stream)
68    }
69
70    /// Message handler
71    async fn message_handler(
72        State(state): State<Arc<Self>>,
73        Json(message): Json<Message>,
74    ) -> impl IntoResponse {
75        match state.tx.send(message) {
76            Ok(_) => (axum::http::StatusCode::OK, "Message sent").into_response(),
77            Err(e) => (
78                axum::http::StatusCode::INTERNAL_SERVER_ERROR,
79                format!("Failed to broadcast message: {}", e),
80            )
81                .into_response(),
82        }
83    }
84}
85
86#[async_trait]
87impl super::HttpTransport for AxumHttpServer {
88    async fn initialize(&mut self) -> Result<()> {
89        let app = Self::create_router(Arc::new(self.clone()));
90        let addr = self.config.addr;
91
92        tokio::spawn(async move {
93            axum::serve(tokio::net::TcpListener::bind(addr).await.unwrap(), app)
94                .await
95                .unwrap();
96        });
97
98        Ok(())
99    }
100
101    async fn send(&self, message: Message) -> Result<()> {
102        self.tx
103            .send(message)
104            .map_err(|e| crate::Error::Transport(e.to_string()))?;
105        Ok(())
106    }
107
108    async fn receive(&self) -> Result<Message> {
109        let mut rx = self.tx.subscribe();
110        rx.recv()
111            .await
112            .map_err(|e| crate::Error::Transport(e.to_string()))
113    }
114
115    async fn close(&mut self) -> Result<()> {
116        // Axum server will close automatically when dropped
117        Ok(())
118    }
119}
120
121/// Default HTTP server type
122pub type DefaultHttpServer = AxumHttpServer;