// hen 0.16.0
//
// Run protocol-aware API request collections from the command line or through MCP.
// Documentation
use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
};

use tokio::{
    sync::{mpsc, oneshot, Mutex as AsyncMutex},
    task::AbortHandle,
};

use super::RequestProtocol;

/// Thread-safe registry of named sessions.
///
/// The map lives behind an `Arc<Mutex<..>>`, so cloning a registry yields
/// another view of the same sessions: a session inserted through one clone is
/// visible through every other clone.
#[derive(Debug, Clone, Default)]
pub struct SessionRegistry {
    /// Sessions keyed by `SessionHandle::name`.
    inner: Arc<Mutex<HashMap<String, SessionHandle>>>,
}

/// A registered session: its name, protocol, caller metadata and, for
/// streaming sessions, the live runtime holding its channels and task handles.
///
/// Cloning a handle is cheap for the runtime part: the runtime is held behind
/// an `Arc`, so every clone shares the same channels and abort handles.
#[derive(Debug, Clone)]
pub struct SessionHandle {
    /// Name the session is stored under in the registry.
    pub name: String,
    /// Protocol of the session.
    pub protocol: RequestProtocol,
    /// Free-form string metadata supplied by the caller when registering.
    pub metadata: HashMap<String, String>,
    /// Live SSE/WebSocket runtime; `None` for sessions registered via
    /// `SessionRegistry::upsert`.
    runtime: Option<SessionRuntime>,
}

/// Protocol-specific live state attached to a session.
///
/// Each variant holds its runtime in an `Arc` so handle clones share it.
#[derive(Debug, Clone)]
enum SessionRuntime {
    /// Runtime of a server-sent-events session.
    Sse(Arc<SseSessionRuntime>),
    /// Runtime of a WebSocket session.
    Ws(Arc<WsSessionRuntime>),
}

/// Receiving side of an SSE session plus the handle that cancels its task.
#[derive(Debug)]
pub(crate) struct SseSessionRuntime {
    /// Incoming messages. Wrapped in an async mutex so `recv` can be called
    /// through a shared `&self`; concurrent callers wait for one another.
    receiver: AsyncMutex<mpsc::UnboundedReceiver<SseMessage>>,
    /// Aborted by `abort()`. Presumably the task that feeds `receiver` —
    /// confirm at the `upsert_sse` call site.
    abort_handle: AbortHandle,
}

/// Channels and task handles backing a WebSocket session.
#[derive(Debug)]
pub(crate) struct WsSessionRuntime {
    /// Incoming messages; async-mutex-wrapped so `recv` works through `&self`.
    receiver: AsyncMutex<mpsc::UnboundedReceiver<WsMessage>>,
    /// Outgoing send commands, consumed by the writer side of the session.
    sender: mpsc::UnboundedSender<WsCommand>,
    /// Cancels the reader task; aborted by `abort()`.
    reader_abort_handle: AbortHandle,
    /// Cancels the writer task; aborted by `abort()`.
    writer_abort_handle: AbortHandle,
}

/// One message received on an SSE session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct SseMessage {
    /// Event type, if any (presumably the SSE `event:` field — confirm with the parser).
    pub event: Option<String>,
    /// Event id, if any (presumably the SSE `id:` field — confirm with the parser).
    pub id: Option<String>,
    /// Message payload.
    pub data: String,
    /// Unparsed text of the event. NOTE(review): assumed to be the raw wire
    /// form; verify against the producer.
    pub raw_event: String,
}

/// Kind of a message delivered on a WebSocket session's receive channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum WsMessageKind {
    /// A text message.
    Text,
    /// A binary message.
    Binary,
    /// A close message.
    Close,
}

impl WsMessageKind {
    /// Returns the lowercase label for this kind: `"text"`, `"binary"` or `"close"`.
    pub(crate) fn as_str(self) -> &'static str {
        use WsMessageKind::{Binary, Close, Text};

        match self {
            Text => "text",
            Binary => "binary",
            Close => "close",
        }
    }
}

/// One message received on a WebSocket session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct WsMessage {
    /// Whether the message was text, binary, or a close.
    pub kind: WsMessageKind,
    /// Payload as a string. NOTE(review): for binary messages this is
    /// presumably an encoded form — confirm with the reader task.
    pub data: String,
    /// Unprocessed representation of the message; exact format is set by the
    /// reader — TODO confirm.
    pub raw_message: String,
}

/// A request, sent to the WebSocket writer, to transmit one payload.
#[derive(Debug)]
pub(crate) struct WsCommand {
    /// Text to send.
    pub payload: String,
    /// Where the writer reports the outcome. `WsSessionRuntime::send` awaits
    /// it and returns the `Result` it carries.
    pub completion: oneshot::Sender<Result<(), String>>,
}

impl SseSessionRuntime {
    /// Bundles the SSE message channel with the abort handle of its task.
    pub(crate) fn new(
        receiver: mpsc::UnboundedReceiver<SseMessage>,
        abort_handle: AbortHandle,
    ) -> Self {
        let receiver = AsyncMutex::new(receiver);
        Self {
            abort_handle,
            receiver,
        }
    }

    /// Waits for the next SSE message.
    ///
    /// The async mutex is held for the whole wait, so concurrent callers are
    /// served one at a time. Returns `None` once the channel is closed and
    /// drained, per `UnboundedReceiver::recv`.
    pub(crate) async fn recv(&self) -> Option<SseMessage> {
        let mut channel = self.receiver.lock().await;
        channel.recv().await
    }

    /// Cancels the task associated with this runtime.
    fn abort(&self) {
        let handle = &self.abort_handle;
        handle.abort();
    }
}

impl WsSessionRuntime {
    /// Bundles the WebSocket channels with the abort handles of the reader and
    /// writer tasks.
    pub(crate) fn new(
        receiver: mpsc::UnboundedReceiver<WsMessage>,
        sender: mpsc::UnboundedSender<WsCommand>,
        reader_abort_handle: AbortHandle,
        writer_abort_handle: AbortHandle,
    ) -> Self {
        let receiver = AsyncMutex::new(receiver);
        Self {
            receiver,
            sender,
            reader_abort_handle,
            writer_abort_handle,
        }
    }

    /// Waits for the next incoming WebSocket message.
    ///
    /// The async mutex is held for the whole wait, so concurrent callers are
    /// served one at a time. Returns `None` once the channel is closed and drained.
    pub(crate) async fn recv(&self) -> Option<WsMessage> {
        let mut channel = self.receiver.lock().await;
        channel.recv().await
    }

    /// Hands `payload` to the writer and waits for it to acknowledge the send.
    ///
    /// # Errors
    ///
    /// - The command channel is closed: the writer has gone away.
    /// - The writer dropped the completion sender without replying.
    /// - The writer reported a failure: its error string is returned as-is.
    pub(crate) async fn send(&self, payload: String) -> Result<(), String> {
        let (completion, acknowledgement) = oneshot::channel();
        let command = WsCommand {
            payload,
            completion,
        };

        if self.sender.send(command).is_err() {
            return Err("WebSocket session writer is closed".to_string());
        }

        match acknowledgement.await {
            Ok(outcome) => outcome,
            Err(_) => Err(
                "WebSocket session writer dropped before acknowledging the send".to_string(),
            ),
        }
    }

    /// Cancels the reader task, then the writer task.
    fn abort(&self) {
        for handle in [&self.reader_abort_handle, &self.writer_abort_handle] {
            handle.abort();
        }
    }
}

impl SessionHandle {
    /// Returns the SSE runtime attached to this session, if it has one.
    ///
    /// Yields `None` both for sessions without a runtime and for WebSocket sessions.
    pub(crate) fn sse_runtime(&self) -> Option<Arc<SseSessionRuntime>> {
        self.runtime.as_ref().and_then(|attached| match attached {
            SessionRuntime::Sse(sse) => Some(Arc::clone(sse)),
            SessionRuntime::Ws(_) => None,
        })
    }

    /// Returns the WebSocket runtime attached to this session, if it has one.
    ///
    /// Yields `None` both for sessions without a runtime and for SSE sessions.
    pub(crate) fn ws_runtime(&self) -> Option<Arc<WsSessionRuntime>> {
        self.runtime.as_ref().and_then(|attached| match attached {
            SessionRuntime::Ws(ws) => Some(Arc::clone(ws)),
            SessionRuntime::Sse(_) => None,
        })
    }
}

impl SessionRegistry {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn upsert(
        &self,
        name: impl Into<String>,
        protocol: RequestProtocol,
        metadata: HashMap<String, String>,
    ) -> SessionHandle {
        self.insert_handle(SessionHandle {
            name: name.into(),
            protocol,
            metadata,
            runtime: None,
        })
    }

    pub(crate) fn upsert_sse(
        &self,
        name: impl Into<String>,
        metadata: HashMap<String, String>,
        receiver: mpsc::UnboundedReceiver<SseMessage>,
        abort_handle: AbortHandle,
    ) -> SessionHandle {
        self.insert_handle(SessionHandle {
            name: name.into(),
            protocol: RequestProtocol::Sse,
            metadata,
            runtime: Some(SessionRuntime::Sse(Arc::new(SseSessionRuntime::new(
                receiver,
                abort_handle,
            )))),
        })
    }

    pub(crate) fn upsert_ws(
        &self,
        name: impl Into<String>,
        metadata: HashMap<String, String>,
        receiver: mpsc::UnboundedReceiver<WsMessage>,
        sender: mpsc::UnboundedSender<WsCommand>,
        reader_abort_handle: AbortHandle,
        writer_abort_handle: AbortHandle,
    ) -> SessionHandle {
        self.insert_handle(SessionHandle {
            name: name.into(),
            protocol: RequestProtocol::Ws,
            metadata,
            runtime: Some(SessionRuntime::Ws(Arc::new(WsSessionRuntime::new(
                receiver,
                sender,
                reader_abort_handle,
                writer_abort_handle,
            )))),
        })
    }

    fn insert_handle(&self, handle: SessionHandle) -> SessionHandle {
        let mut sessions = self
            .inner
            .lock()
            .expect("session registry mutex poisoned");

        if let Some(previous) = sessions.insert(handle.name.clone(), handle.clone()) {
            abort_runtime(previous.runtime.as_ref());
        }

        handle
    }

    pub fn get(&self, name: &str) -> Option<SessionHandle> {
        self.inner
            .lock()
            .expect("session registry mutex poisoned")
            .get(name)
            .cloned()
    }

    pub fn remove(&self, name: &str) -> Option<SessionHandle> {
        let removed = self
            .inner
            .lock()
            .expect("session registry mutex poisoned")
            .remove(name);

        if let Some(handle) = removed.as_ref() {
            abort_runtime(handle.runtime.as_ref());
        }

        removed
    }

    pub fn clear(&self) {
        let mut sessions = self
            .inner
            .lock()
            .expect("session registry mutex poisoned");

        for handle in sessions.values() {
            abort_runtime(handle.runtime.as_ref());
        }

        sessions.clear();
    }
}

/// Cancels the background task(s) of `runtime`; does nothing when it is `None`.
fn abort_runtime(runtime: Option<&SessionRuntime>) {
    if let Some(attached) = runtime {
        match attached {
            SessionRuntime::Sse(sse) => sse.abort(),
            SessionRuntime::Ws(ws) => ws.abort(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn registry_round_trips_handles() {
        let registry = SessionRegistry::new();
        let handle = registry.upsert("alpha", RequestProtocol::Http, HashMap::new());

        let loaded = registry.get("alpha").expect("session should exist");
        assert_eq!(loaded.name, handle.name);
        assert_eq!(loaded.protocol, RequestProtocol::Http);

        let removed = registry.remove("alpha").expect("session should be removed");
        assert_eq!(removed.name, "alpha");
        assert!(registry.get("alpha").is_none());
    }

    /// A second upsert under the same name replaces the first entry entirely.
    #[test]
    fn upsert_replaces_session_with_same_name() {
        let registry = SessionRegistry::new();
        registry.upsert("alpha", RequestProtocol::Http, HashMap::new());

        let mut metadata = HashMap::new();
        metadata.insert("key".to_string(), "value".to_string());
        registry.upsert("alpha", RequestProtocol::Ws, metadata.clone());

        let loaded = registry.get("alpha").expect("session should exist");
        assert_eq!(loaded.protocol, RequestProtocol::Ws);
        assert_eq!(loaded.metadata, metadata);
        // `upsert` never attaches a runtime, whatever the protocol.
        assert!(loaded.ws_runtime().is_none());
        assert!(loaded.sse_runtime().is_none());
    }

    #[test]
    fn clear_removes_every_session() {
        let registry = SessionRegistry::new();
        registry.upsert("alpha", RequestProtocol::Http, HashMap::new());
        registry.upsert("beta", RequestProtocol::Sse, HashMap::new());

        registry.clear();

        assert!(registry.get("alpha").is_none());
        assert!(registry.get("beta").is_none());
    }

    /// Clones of a registry share one underlying map.
    #[test]
    fn registry_clones_share_sessions() {
        let registry = SessionRegistry::new();
        let shared = registry.clone();

        shared.upsert("alpha", RequestProtocol::Http, HashMap::new());
        assert!(registry.get("alpha").is_some());

        registry.remove("alpha");
        assert!(shared.get("alpha").is_none());
    }

    #[test]
    fn remove_missing_session_returns_none() {
        let registry = SessionRegistry::new();
        assert!(registry.remove("missing").is_none());
    }
}