// mcprotocol_rs/transport/http/server.rs
use crate::{protocol::Message, Result};
use async_trait::async_trait;
use axum::{
    extract::State,
    response::{
        sse::{Event, Sse},
        IntoResponse,
    },
    routing::{get, post},
    Json, Router,
};
use futures::stream::Stream;
use std::{convert::Infallible, net::SocketAddr, sync::Arc};
use tokio::sync::broadcast;
use tokio_stream::StreamExt;

/// Configuration for the HTTP transport server.
#[derive(Clone)]
pub struct HttpServerConfig {
    /// Socket address the server binds its TCP listener to.
    pub addr: SocketAddr,
    /// Optional authentication token; `None` means no token is configured.
    pub auth_token: Option<String>,
}

/// Hand-written so that the secret in `auth_token` never ends up in logs or
/// panic messages: only whether a token is configured is shown.
impl std::fmt::Debug for HttpServerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpServerConfig")
            .field("addr", &self.addr)
            .field(
                "auth_token",
                &self.auth_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}

24pub struct AxumHttpServer {
26 config: HttpServerConfig,
27 tx: broadcast::Sender<Message>,
28}
29
30impl Clone for AxumHttpServer {
31 fn clone(&self) -> Self {
32 Self {
33 config: self.config.clone(),
34 tx: self.tx.clone(),
35 }
36 }
37}
38
39impl AxumHttpServer {
40 pub fn new(config: HttpServerConfig) -> Self {
42 let (tx, _) = broadcast::channel(32);
43 Self { config, tx }
44 }
45
46 fn create_router(state: Arc<Self>) -> Router {
48 Router::new()
49 .route("/events", get(Self::sse_handler))
50 .route("/messages", post(Self::message_handler))
51 .with_state(state)
52 }
53
54 async fn sse_handler(
56 State(state): State<Arc<Self>>,
57 ) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
58 let mut rx = state.tx.subscribe();
59 let stream = async_stream::stream! {
60 while let Ok(msg) = rx.recv().await {
61 if let Ok(json) = serde_json::to_string(&msg) {
62 yield Ok(Event::default().data(json));
63 }
64 }
65 };
66
67 Sse::new(stream)
68 }
69
70 async fn message_handler(
72 State(state): State<Arc<Self>>,
73 Json(message): Json<Message>,
74 ) -> impl IntoResponse {
75 match state.tx.send(message) {
76 Ok(_) => (axum::http::StatusCode::OK, "Message sent").into_response(),
77 Err(e) => (
78 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
79 format!("Failed to broadcast message: {}", e),
80 )
81 .into_response(),
82 }
83 }
84}
85
86#[async_trait]
87impl super::HttpTransport for AxumHttpServer {
88 async fn initialize(&mut self) -> Result<()> {
89 let app = Self::create_router(Arc::new(self.clone()));
90 let addr = self.config.addr;
91
92 tokio::spawn(async move {
93 axum::serve(tokio::net::TcpListener::bind(addr).await.unwrap(), app)
94 .await
95 .unwrap();
96 });
97
98 Ok(())
99 }
100
101 async fn send(&self, message: Message) -> Result<()> {
102 self.tx
103 .send(message)
104 .map_err(|e| crate::Error::Transport(e.to_string()))?;
105 Ok(())
106 }
107
108 async fn receive(&self) -> Result<Message> {
109 let mut rx = self.tx.subscribe();
110 rx.recv()
111 .await
112 .map_err(|e| crate::Error::Transport(e.to_string()))
113 }
114
115 async fn close(&mut self) -> Result<()> {
116 Ok(())
118 }
119}
120
121pub type DefaultHttpServer = AxumHttpServer;