use crate::transport::{Error, PendingRequests, TransportMessage};
use async_trait::async_trait;
use eventsource_client::{Client, SSE};
use futures::TryStreamExt;
use mcp_core_fishcode2025::protocol::{JsonRpcMessage, JsonRpcRequest};
use reqwest::Client as HttpClient;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tokio::time::{timeout, Duration};
use tracing::warn;
use url::Url;
use super::{send_message, Transport, TransportHandle};
/// Upper bound, in seconds, on how long `start` waits for the server's
/// `endpoint` event before reporting a connection failure.
const ENDPOINT_TIMEOUT_SECS: u64 = 5;
/// Background actor behind [`SseTransport`]: it reads the server's SSE stream
/// and POSTs queued outgoing messages to the endpoint the server advertises.
pub struct SseActor {
    /// URL of the server's SSE stream.
    sse_url: String,
    /// Outgoing messages queued by transport handles.
    receiver: mpsc::Receiver<TransportMessage>,
    /// POST URL taken from the server's `endpoint` event; `None` until discovered.
    post_endpoint: Arc<RwLock<Option<String>>>,
    /// In-flight requests awaiting a response; shared by the incoming and
    /// outgoing halves of the actor.
    pending_requests: Arc<PendingRequests>,
    /// Client used to POST outgoing messages.
    http_client: HttpClient,
}
impl SseActor {
    /// Creates an actor that reads events from `sse_url` and forwards messages
    /// received on `receiver` to the POST endpoint the server advertises.
    ///
    /// `post_endpoint` is shared with the caller so it can observe when the
    /// endpoint has been discovered.
    pub fn new(
        receiver: mpsc::Receiver<TransportMessage>,
        pending_requests: Arc<PendingRequests>,
        sse_url: String,
        post_endpoint: Arc<RwLock<Option<String>>>,
    ) -> Self {
        Self {
            receiver,
            pending_requests,
            sse_url,
            post_endpoint,
            http_client: HttpClient::new(),
        }
    }

    /// Drives the incoming (SSE) and outgoing (HTTP POST) halves concurrently.
    /// Completes only once both halves have finished.
    pub async fn run(self) {
        tokio::join!(
            Self::handle_incoming_messages(
                self.sse_url.clone(),
                Arc::clone(&self.pending_requests),
                Arc::clone(&self.post_endpoint)
            ),
            Self::handle_outgoing_messages(
                self.receiver,
                self.http_client.clone(),
                Arc::clone(&self.post_endpoint),
                Arc::clone(&self.pending_requests),
            )
        );
    }

    /// Reads the SSE stream in two phases.
    ///
    /// 1. Wait for the `endpoint` event, resolve its data against `sse_url`,
    ///    and publish the result into `post_endpoint`.
    /// 2. Route every `message` event to [`Self::dispatch_message`].
    ///
    /// If the client cannot be built, the endpoint cannot be resolved, or the
    /// stream ends or errors, all pending requests are cleared.
    async fn handle_incoming_messages(
        sse_url: String,
        pending_requests: Arc<PendingRequests>,
        post_endpoint: Arc<RwLock<Option<String>>>,
    ) {
        let client = match eventsource_client::ClientBuilder::for_url(&sse_url) {
            Ok(builder) => builder.build(),
            Err(e) => {
                pending_requests.clear().await;
                warn!("Failed to connect SSE client: {}", e);
                return;
            }
        };
        let mut stream = client.stream();

        // Phase 1: learn where outgoing messages must be POSTed.
        while let Ok(Some(event)) = stream.try_next().await {
            match event {
                SSE::Event(e) if e.event_type == "endpoint" => {
                    // The endpoint string comes from the server, so a malformed
                    // value must not panic the actor task.
                    match Url::parse(&sse_url).and_then(|base| base.join(&e.data)) {
                        Ok(post_url) => {
                            tracing::debug!("Discovered SSE POST endpoint: {}", post_url);
                            *post_endpoint.write().await = Some(post_url.to_string());
                            break;
                        }
                        Err(err) => {
                            warn!("Failed to resolve SSE POST endpoint {:?}: {err}", e.data);
                            pending_requests.clear().await;
                            return;
                        }
                    }
                }
                _ => continue,
            }
        }

        // Phase 2: deliver responses to whoever is waiting on them.
        while let Ok(Some(event)) = stream.try_next().await {
            match event {
                SSE::Event(e) if e.event_type == "message" => {
                    Self::dispatch_message(&pending_requests, &e.data).await;
                }
                _ => {}
            }
        }
        tracing::error!("SSE stream ended or encountered an error; clearing pending requests.");
        pending_requests.clear().await;
    }

    /// Parses one SSE `message` payload as a JSON-RPC message.
    ///
    /// If it is a response or error carrying an id, it is passed to
    /// `pending_requests.respond` under that id. Other message kinds, and
    /// responses/errors without an id, are ignored. Unparseable payloads are
    /// logged and dropped.
    async fn dispatch_message(pending_requests: &PendingRequests, data: &str) {
        let message = match serde_json::from_str::<JsonRpcMessage>(data) {
            Ok(message) => message,
            Err(err) => {
                warn!("Failed to parse SSE message: {err}");
                return;
            }
        };
        let id = match &message {
            JsonRpcMessage::Response(response) => response.id.as_ref().map(|id| id.to_string()),
            JsonRpcMessage::Error(error) => error.id.as_ref().map(|id| id.to_string()),
            _ => None,
        };
        if let Some(id) = id {
            pending_requests.respond(&id, Ok(message)).await;
        }
    }

    /// Forwards each queued message to the discovered POST endpoint.
    ///
    /// - No endpoint yet: the sender receives `Error::NotConnected`.
    /// - Serialization failure: the sender receives `Error::Serialization`.
    /// - Request with an id and a response channel: it is registered in
    ///   `pending_requests` before the POST, because its answer arrives later
    ///   on the SSE stream.
    /// - POST fails or returns a non-success status: the error is handed to
    ///   `pending_requests.respond` for that id, so the caller is not left
    ///   waiting for a reply that will never come.
    ///
    /// When the channel closes, all pending requests are cleared.
    async fn handle_outgoing_messages(
        mut receiver: mpsc::Receiver<TransportMessage>,
        http_client: HttpClient,
        post_endpoint: Arc<RwLock<Option<String>>>,
        pending_requests: Arc<PendingRequests>,
    ) {
        while let Some(transport_msg) = receiver.recv().await {
            // Clone the URL out in its own statement so the read guard is
            // released before any further awaiting.
            let endpoint = post_endpoint.read().await.clone();
            let post_url = match endpoint {
                Some(url) => url,
                None => {
                    if let Some(response_tx) = transport_msg.response_tx {
                        let _ = response_tx.send(Err(Error::NotConnected));
                    }
                    continue;
                }
            };

            let message_str = match serde_json::to_string(&transport_msg.message) {
                Ok(s) => s,
                Err(e) => {
                    if let Some(tx) = transport_msg.response_tx {
                        let _ = tx.send(Err(Error::Serialization(e)));
                    }
                    continue;
                }
            };

            // Only requests that carry an id can later be matched to a response.
            let pending_id = match (transport_msg.response_tx, &transport_msg.message) {
                (
                    Some(response_tx),
                    JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }),
                ) => {
                    let id = id.to_string();
                    pending_requests.insert(id.clone(), response_tx).await;
                    Some(id)
                }
                _ => None,
            };

            let failure = match http_client
                .post(&post_url)
                .header("Content-Type", "application/json")
                .body(message_str)
                .send()
                .await
            {
                Ok(resp) if resp.status().is_success() => None,
                Ok(resp) => {
                    let status = resp.status();
                    let err = Error::HttpError {
                        status: status.as_u16(),
                        message: status.to_string(),
                    };
                    warn!("HTTP request returned error: {err}");
                    Some(err)
                }
                Err(e) => {
                    warn!("HTTP POST failed: {e}");
                    Some(Error::SseConnection(format!("HTTP POST failed: {e}")))
                }
            };

            // The server never received the request, so no response will come
            // over SSE; hand the error to the registered waiter instead.
            if let (Some(err), Some(id)) = (failure, pending_id) {
                pending_requests.respond(&id, Err(err)).await;
            }
        }
        tracing::error!("SseActor: outgoing message loop ended. Clearing pending requests.");
        pending_requests.clear().await;
    }
}
/// Cloneable handle for sending JSON-RPC messages through a started SSE
/// transport; every clone feeds the same actor queue.
#[derive(Clone)]
pub struct SseTransportHandle {
    /// Queue consumed by the background actor's outgoing-message loop.
    sender: mpsc::Sender<TransportMessage>,
}
#[async_trait::async_trait]
impl TransportHandle for SseTransportHandle {
async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error> {
send_message(&self.sender, message).await
}
}
/// Client transport that receives JSON-RPC messages from a server over
/// Server-Sent Events at `sse_url`. Outgoing messages are POSTed to the
/// endpoint the server advertises.
#[derive(Clone)]
pub struct SseTransport {
    /// URL of the server's SSE stream.
    sse_url: String,
    /// Environment variables exported into the current process on `start`.
    env: HashMap<String, String>,
}
impl SseTransport {
    /// Creates a transport for the SSE stream at `sse_url`.
    ///
    /// `env` is exported into the current process when the transport is started.
    pub fn new<S: Into<String>>(sse_url: S, env: HashMap<String, String>) -> Self {
        Self {
            sse_url: sse_url.into(),
            env,
        }
    }

    /// Polls `post_endpoint` every 100 ms until the actor has published the
    /// POST URL, then returns it.
    ///
    /// This function does not give up on its own. The caller is expected to
    /// bound it with a timeout (`start` uses `ENDPOINT_TIMEOUT_SECS`), so that
    /// the configured timeout is the one that actually applies.
    ///
    /// # Errors
    /// Currently never returns `Err`; the `Result` is kept for interface
    /// stability.
    async fn wait_for_endpoint(
        post_endpoint: Arc<RwLock<Option<String>>>,
    ) -> Result<String, Error> {
        let check_interval = Duration::from_millis(100);
        loop {
            // The read guard is a temporary of this `if let`, so it is released
            // before the sleep below.
            if let Some(url) = post_endpoint.read().await.clone() {
                return Ok(url);
            }
            tokio::time::sleep(check_interval).await;
        }
    }
}
#[async_trait]
impl Transport for SseTransport {
type Handle = SseTransportHandle;
async fn start(&self) -> Result<Self::Handle, Error> {
for (key, value) in &self.env {
std::env::set_var(key, value);
}
let (tx, rx) = mpsc::channel(32);
let post_endpoint: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
let post_endpoint_clone = Arc::clone(&post_endpoint);
let actor = SseActor::new(
rx,
Arc::new(PendingRequests::new()),
self.sse_url.clone(),
post_endpoint,
);
tokio::spawn(actor.run());
match timeout(
Duration::from_secs(ENDPOINT_TIMEOUT_SECS),
Self::wait_for_endpoint(post_endpoint_clone),
)
.await
{
Ok(_) => Ok(SseTransportHandle { sender: tx }),
Err(e) => Err(Error::SseConnection(e.to_string())),
}
}
async fn close(&self) -> Result<(), Error> {
Ok(())
}
}