#![cfg_attr(coverage_nightly, coverage(off))]
use crate::transport::{TransportAdapter, TransportError};
use async_trait::async_trait;
use pmcp::transport::TransportMessage;
use std::fmt::Debug;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tracing::{debug, info};
/// Transport adapter that accepts messages over HTTP and keeps a Server-Sent Events endpoint open.
///
/// Built by [`HttpSseTransportAdapter::serve`], which starts the HTTP server on a background
/// task and connects it to the channel stored here.
#[derive(Debug)]
pub struct HttpSseTransportAdapter {
    /// Receiving half of the internal channel, read by `receive`. The server task pushes the
    /// bodies of `POST /message` requests into this channel.
    receiver: mpsc::Receiver<TransportMessage>,
    /// Sending half of the same channel that `receiver` reads from; used by `send`.
    sender: mpsc::Sender<TransportMessage>,
    /// Connection flag shared with the server task; cleared by `close`.
    state: Arc<RwLock<ConnectionState>>,
}
/// Connection status shared between the transport adapter and its HTTP server task.
#[derive(Debug)]
struct ConnectionState {
    /// `true` until the adapter is closed; the `/sse` stream stops once it is cleared.
    connected: bool,
    /// Identifier of the connected client.
    ///
    /// NOTE(review): `serve` always initialises this to `None` and only the tests read it, so
    /// `dead_code` is allowed outside test builds to keep `-D warnings` builds clean.
    #[cfg_attr(not(test), allow(dead_code))]
    client_id: Option<String>,
}
impl HttpSseTransportAdapter {
    /// Binds `addr`, starts the HTTP/SSE server on a background Tokio task and returns the
    /// adapter connected to it.
    ///
    /// The server exposes two routes:
    /// - `GET /sse`: a Server-Sent Events stream that only carries `keepalive` comments and
    ///   ends once the adapter is closed or dropped.
    /// - `POST /message`: the request body is wrapped with `TransportMessage::text` and made
    ///   available through `receive`. The reply is `OK`, or `503 Service Unavailable` if the
    ///   adapter can no longer accept messages.
    ///
    /// The listener is bound before this function returns, so an unusable address is reported
    /// to the caller instead of only being logged by the background task. The server shuts
    /// down gracefully once the adapter is closed or dropped.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Connection`] if `addr` cannot be bound.
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn serve(addr: &str) -> Result<Self, TransportError> {
        info!("Starting HTTP/SSE server on {}", addr);
        // Bind eagerly: a failure inside the spawned task could only be logged, never returned.
        let listener = tokio::net::TcpListener::bind(addr)
            .await
            .map_err(|e| TransportError::Connection(format!("Failed to bind: {}", e)))?;
        debug!("HTTP/SSE server listening on {}", addr);

        let (tx, rx) = mpsc::channel(100);
        let state = Arc::new(RwLock::new(ConnectionState {
            connected: true,
            client_id: None,
        }));
        let server_state = Arc::clone(&state);
        let server_tx = tx.clone();
        tokio::spawn(async move {
            if let Err(e) = Self::run_http_server(listener, server_tx, server_state).await {
                tracing::error!("HTTP server error: {}", e);
            }
        });
        Ok(Self {
            receiver: rx,
            sender: tx,
            state,
        })
    }

    /// Serves `GET /sse` and `POST /message` on the already-bound `listener`. It runs until the
    /// adapter's `connected` flag is cleared or the receiving half of `tx`'s channel is closed.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Connection`] if the server fails while running.
    async fn run_http_server(
        listener: tokio::net::TcpListener,
        tx: mpsc::Sender<TransportMessage>,
        state: Arc<RwLock<ConnectionState>>,
    ) -> Result<(), TransportError> {
        use axum::{
            extract::State,
            http::StatusCode,
            response::sse::{Event, KeepAlive, Sse},
            routing::{get, post},
            Router,
        };
        use std::convert::Infallible;
        use std::time::Duration;

        // Period between keep-alive comments on the SSE stream.
        const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);
        // How often the shutdown watcher re-checks whether the adapter is still alive.
        const SHUTDOWN_POLL_INTERVAL: Duration = Duration::from_millis(250);

        let sse_state = Arc::clone(&state);
        let sse_handler = move |State(tx): State<mpsc::Sender<TransportMessage>>| {
            let state = Arc::clone(&sse_state);
            async move {
                // Emits a `keepalive` comment every interval until the adapter goes away.
                let stream = async_stream::stream! {
                    while !Self::should_stop(&tx, &state).await {
                        tokio::time::sleep(KEEPALIVE_INTERVAL).await;
                        yield Ok::<_, Infallible>(Event::default().comment("keepalive"));
                    }
                };
                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(KEEPALIVE_INTERVAL)
                        .text("keepalive"),
                )
            }
        };

        let post_handler =
            |State(tx): State<mpsc::Sender<TransportMessage>>, body: String| async move {
                debug!("Received HTTP POST message");
                tx.send(TransportMessage::text(body))
                    .await
                    .map(|()| "OK")
                    // A bare `&str` error would still be sent with `200 OK`; use an explicit
                    // status so clients can tell the message was not accepted.
                    .map_err(|_| (StatusCode::SERVICE_UNAVAILABLE, "Failed to process message"))
            };

        // Without a shutdown signal the server would keep the port and run forever after the
        // adapter is closed or dropped.
        let shutdown_tx = tx.clone();
        let shutdown_state = Arc::clone(&state);
        let shutdown_signal = async move {
            while !Self::should_stop(&shutdown_tx, &shutdown_state).await {
                tokio::time::sleep(SHUTDOWN_POLL_INTERVAL).await;
            }
            debug!("HTTP/SSE server shutting down");
        };

        let app = Router::new()
            .route("/sse", get(sse_handler))
            .route("/message", post(post_handler))
            .with_state(tx);

        axum::serve(listener, app)
            .with_graceful_shutdown(shutdown_signal)
            .await
            .map_err(|e| TransportError::Connection(format!("HTTP server error: {}", e)))
    }

    /// Returns `true` once the server should stop. That happens when the adapter's `connected`
    /// flag has been cleared, or when the receiving half of the channel behind `tx` is closed
    /// (for example because the adapter was dropped).
    async fn should_stop(
        tx: &mpsc::Sender<TransportMessage>,
        state: &RwLock<ConnectionState>,
    ) -> bool {
        tx.is_closed() || !state.read().await.connected
    }

    /// Like [`HttpSseTransportAdapter::serve`], but returns the adapter as a boxed
    /// [`TransportAdapter`] trait object.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`HttpSseTransportAdapter::serve`].
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub async fn boxed(addr: &str) -> Result<Box<dyn TransportAdapter>, TransportError> {
        Ok(Box::new(Self::serve(addr).await?))
    }
}
#[async_trait]
impl TransportAdapter for HttpSseTransportAdapter {
    /// Queues `message` on the adapter's internal channel, waiting for buffer space if needed.
    ///
    /// NOTE(review): this is the same channel that `receive` reads from, and nothing in this
    /// file forwards it to the `/sse` stream. A sent message is therefore handed back to this
    /// adapter's own `receive` rather than to HTTP clients. Confirm this is intended.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Send`] if the channel has been closed (for example after
    /// `close`).
    async fn send(&mut self, message: TransportMessage) -> Result<(), TransportError> {
        self.sender
            .send(message)
            .await
            .map_err(|_| TransportError::Send("SSE send failed".to_string()))
    }

    /// Waits for the next inbound message.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Receive`] once the channel is closed and fully drained.
    async fn receive(&mut self) -> Result<TransportMessage, TransportError> {
        self.receiver
            .recv()
            .await
            .ok_or_else(|| TransportError::Receive("Connection closed".to_string()))
    }

    /// Marks the adapter as disconnected and closes its inbound channel.
    ///
    /// Messages already buffered can still be read with `receive`. After that, `receive`
    /// returns an error and further `send`s fail. Calling `close` more than once is harmless.
    async fn close(&mut self) -> Result<(), TransportError> {
        self.state.write().await.connected = false;
        // The adapter holds its own `sender`, so the channel would never report closure on its
        // own. Without this, `receive` would wait forever after `close`.
        self.receiver.close();
        Ok(())
    }

    /// Reports whether `close` has not yet been called.
    ///
    /// Uses a non-blocking read, so it also returns `false` while the state lock is held for
    /// writing (in this file only `close` writes it).
    fn is_connected(&self) -> bool {
        self.state.try_read().map(|s| s.connected).unwrap_or(false)
    }

    /// Identifies this transport as `"http-sse"`.
    fn transport_type(&self) -> &'static str {
        "http-sse"
    }
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Shared connection state in its initial, connected form.
    fn connected_state() -> Arc<RwLock<ConnectionState>> {
        Arc::new(RwLock::new(ConnectionState {
            connected: true,
            client_id: None,
        }))
    }

    // NOTE(review): these properties only check `format!` and `Duration` arithmetic; they do
    // not exercise the adapter itself.
    proptest! {
        #[test]
        fn test_sse_event_format(data in "\\PC+", event_type in "[a-z]+") {
            let formatted = format!("event: {}\ndata: {}\n\n", event_type, data);
            prop_assert!(formatted.starts_with("event: "));
            prop_assert!(formatted.contains("\ndata: "));
            prop_assert!(formatted.ends_with("\n\n"));
        }

        #[test]
        fn test_keepalive_intervals(interval_secs in 1u64..120) {
            let duration = std::time::Duration::from_secs(interval_secs);
            prop_assert!(duration.as_secs() > 0);
            prop_assert!(duration.as_secs() <= 120);
        }
    }

    #[tokio::test]
    async fn test_http_sse_server_creation() {
        let transport = HttpSseTransportAdapter::serve("127.0.0.1:0")
            .await
            .expect("binding an ephemeral localhost port should succeed");
        assert!(transport.is_connected());
        assert_eq!(transport.transport_type(), "http-sse");
    }

    #[tokio::test]
    async fn test_connection_state_management() {
        let mut transport = HttpSseTransportAdapter::serve("127.0.0.1:0").await.unwrap();
        assert!(transport.is_connected());
        transport.close().await.unwrap();
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_http_sse_transport_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<HttpSseTransportAdapter>();
    }

    #[test]
    fn test_transport_type() {
        let (tx, rx) = mpsc::channel(10);
        let transport = HttpSseTransportAdapter {
            receiver: rx,
            sender: tx,
            state: connected_state(),
        };
        assert_eq!(transport.transport_type(), "http-sse");
    }

    #[tokio::test]
    async fn test_send_message() {
        // `sender` and `receiver` belong to the same channel, so it stays open for the send.
        let (tx, rx) = mpsc::channel(10);
        let mut transport = HttpSseTransportAdapter {
            receiver: rx,
            sender: tx,
            state: connected_state(),
        };
        let result = transport.send(TransportMessage::text("test")).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_send_to_closed_channel() {
        let (tx, orphaned_rx) = mpsc::channel::<TransportMessage>(10);
        let (_unused_tx, rx) = mpsc::channel::<TransportMessage>(10);
        // Dropping the only receiver of `tx`'s channel makes every send on it fail.
        drop(orphaned_rx);
        let mut transport = HttpSseTransportAdapter {
            receiver: rx,
            sender: tx,
            state: connected_state(),
        };
        let result = transport.send(TransportMessage::text("test")).await;
        assert!(matches!(result, Err(TransportError::Send(_))));
    }

    #[tokio::test]
    async fn test_receive_closed_channel() {
        let (tx, rx) = mpsc::channel::<TransportMessage>(10);
        // The adapter gets a sender from a different channel. If it held a clone of `tx`, the
        // receiving channel would never close and `receive` would wait forever.
        let (unrelated_tx, _unrelated_rx) = mpsc::channel::<TransportMessage>(10);
        let mut transport = HttpSseTransportAdapter {
            receiver: rx,
            sender: unrelated_tx,
            state: connected_state(),
        };
        drop(tx);
        let result = transport.receive().await;
        assert!(matches!(result, Err(TransportError::Receive(_))));
    }

    #[test]
    fn test_connection_state_default() {
        let state = ConnectionState {
            connected: false,
            client_id: Some("test-client".to_string()),
        };
        assert!(!state.connected);
        assert_eq!(state.client_id, Some("test-client".to_string()));
    }
}