// hen 0.17.0
//
// Run protocol-aware API request collections from the command line or through MCP.
// Documentation
use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
    time::Instant,
};

use reqwest::cookie::Jar;
use tokio::{
    sync::{mpsc, oneshot, Mutex as AsyncMutex},
    task::AbortHandle,
};

use super::RequestProtocol;

/// Thread-safe registry of named sessions plus per-name OAuth state.
///
/// Both maps sit behind `Arc<Mutex<..>>`, so every clone of the registry shares the
/// same underlying state.
#[derive(Debug, Clone, Default)]
pub struct SessionRegistry {
    /// Sessions keyed by session name (the handle's `name` field).
    inner: Arc<Mutex<HashMap<String, SessionHandle>>>,
    /// OAuth runtime state keyed by the name passed to `get_or_create_oauth`. Each entry has
    /// its own async mutex, presumably so token acquisition/refresh can hold it across
    /// `.await` points — confirm against callers.
    oauth: Arc<Mutex<HashMap<String, Arc<AsyncMutex<OAuthSessionRuntime>>>>>,
}

/// A named session as stored in, and cloned out of, a `SessionRegistry`.
///
/// The live runtime is held behind `Arc`, so clones of a handle share the same runtime
/// (cookie jar, channels, task handles).
#[derive(Debug, Clone)]
pub struct SessionHandle {
    /// Session name; also the key under which the registry stores this handle.
    pub name: String,
    /// Protocol this session was registered for.
    pub protocol: RequestProtocol,
    /// Free-form string metadata supplied when the session was registered.
    pub metadata: HashMap<String, String>,
    /// Live per-protocol state; `None` for sessions registered through `SessionRegistry::upsert`.
    runtime: Option<SessionRuntime>,
}

/// Live, protocol-specific state attached to a session.
#[derive(Debug, Clone)]
enum SessionRuntime {
    /// HTTP session state (cookie jar).
    Http(Arc<HttpSessionRuntime>),
    /// SSE session state (incoming message channel + background task handle).
    Sse(Arc<SseSessionRuntime>),
    /// WebSocket session state (incoming/outgoing channels + reader/writer task handles).
    Ws(Arc<WsSessionRuntime>),
}

/// Runtime state for an HTTP session.
#[derive(Debug, Default)]
pub(crate) struct HttpSessionRuntime {
    /// The session's cookie jar; `cookie_jar()` hands out shared references to it.
    cookie_jar: Arc<Jar>,
}

/// Mutable OAuth state kept per name by `SessionRegistry::get_or_create_oauth`.
///
/// Every field starts empty (`Default`). NOTE(review): the code that fills and reads these
/// fields lives outside this file; the per-field notes below are inferred from names and
/// should be confirmed.
#[derive(Debug, Default)]
pub(crate) struct OAuthSessionRuntime {
    /// Presumably identifies which configuration the cached token belongs to — confirm.
    pub cache_key: Option<String>,
    /// Presumably the token endpoint the cached token was obtained from — confirm.
    pub token_endpoint: Option<String>,
    /// Presumably values extracted from the last token response — confirm.
    pub token_fields: HashMap<String, String>,
    /// Refresh token, if one is known; presumably used to renew the access token.
    pub refresh_token: Option<String>,
    /// Expiry as a monotonic `Instant`, if known.
    pub expires_at: Option<Instant>,
}

/// Runtime state for an SSE session: a channel of incoming messages plus the abort
/// handle of the background task associated with the session.
#[derive(Debug)]
pub(crate) struct SseSessionRuntime {
    /// Incoming SSE messages; wrapped in an async mutex so `recv` can work through `&self`.
    receiver: AsyncMutex<mpsc::UnboundedReceiver<SseMessage>>,
    /// Aborted when the session is replaced, removed, or the registry is cleared.
    abort_handle: AbortHandle,
}

/// Runtime state for a WebSocket session: channels to and from the background reader and
/// writer tasks, plus their abort handles.
#[derive(Debug)]
pub(crate) struct WsSessionRuntime {
    /// Incoming messages; wrapped in an async mutex so `recv` can work through `&self`.
    receiver: AsyncMutex<mpsc::UnboundedReceiver<WsMessage>>,
    /// Outgoing send requests for the writer task.
    sender: mpsc::UnboundedSender<WsCommand>,
    /// Reader task handle; aborted together with the writer when the session is torn down.
    reader_abort_handle: AbortHandle,
    /// Writer task handle; aborted together with the reader when the session is torn down.
    writer_abort_handle: AbortHandle,
}

/// One server-sent event as delivered to an SSE session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct SseMessage {
    /// Event type, if the event carried one (presumably the SSE `event:` field).
    pub event: Option<String>,
    /// Event id, if the event carried one (presumably the SSE `id:` field).
    pub id: Option<String>,
    /// Event payload (presumably the joined `data:` lines).
    pub data: String,
    /// Presumably the unparsed text of the whole event — confirm against the parser.
    pub raw_event: String,
}

/// Kind of WebSocket message delivered to a session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum WsMessageKind {
    /// A text message.
    Text,
    /// A binary message.
    Binary,
    /// A close message.
    Close,
}

impl WsMessageKind {
    pub(crate) fn as_str(self) -> &'static str {
        match self {
            Self::Text => "text",
            Self::Binary => "binary",
            Self::Close => "close",
        }
    }
}

/// One WebSocket message as delivered to a session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct WsMessage {
    /// Which kind of message this is.
    pub kind: WsMessageKind,
    /// Message content as a string. NOTE(review): how binary/close payloads are rendered
    /// into a `String` is decided by the reader task outside this file — confirm.
    pub data: String,
    /// Presumably an unprocessed rendering of the message — confirm against the reader task.
    pub raw_message: String,
}

/// A request for the WebSocket writer task to send `payload`.
///
/// The writer answers on `completion` with `Ok(())` or an error message;
/// `WsSessionRuntime::send` awaits that answer.
#[derive(Debug)]
pub(crate) struct WsCommand {
    /// Text to send over the socket.
    pub payload: String,
    /// One-shot channel on which the writer reports the outcome of the send.
    pub completion: oneshot::Sender<Result<(), String>>,
}

impl SseSessionRuntime {
    /// Bundles the message channel of an SSE session with the abort handle of its
    /// background task.
    pub(crate) fn new(
        receiver: mpsc::UnboundedReceiver<SseMessage>,
        abort_handle: AbortHandle,
    ) -> Self {
        let receiver = AsyncMutex::new(receiver);
        Self {
            abort_handle,
            receiver,
        }
    }

    /// Waits for the next SSE message.
    ///
    /// Returns `None` once the channel is closed and drained. The receiver lock is held
    /// while waiting, so concurrent callers take turns.
    pub(crate) async fn recv(&self) -> Option<SseMessage> {
        let mut inbox = self.receiver.lock().await;
        inbox.recv().await
    }

    /// Aborts the background task associated with this session.
    fn abort(&self) {
        let task = &self.abort_handle;
        task.abort();
    }
}

impl HttpSessionRuntime {
    /// Hands out another shared reference to this session's cookie jar.
    pub(crate) fn cookie_jar(&self) -> Arc<Jar> {
        Arc::clone(&self.cookie_jar)
    }
}

impl WsSessionRuntime {
    /// Bundles the channels and task handles of a live WebSocket session.
    ///
    /// `receiver` yields incoming messages, `sender` feeds [`WsCommand`]s to the writer
    /// task, and the two abort handles belong to the reader and writer tasks.
    pub(crate) fn new(
        receiver: mpsc::UnboundedReceiver<WsMessage>,
        sender: mpsc::UnboundedSender<WsCommand>,
        reader_abort_handle: AbortHandle,
        writer_abort_handle: AbortHandle,
    ) -> Self {
        let receiver = AsyncMutex::new(receiver);
        Self {
            sender,
            receiver,
            writer_abort_handle,
            reader_abort_handle,
        }
    }

    /// Waits for the next incoming WebSocket message.
    ///
    /// Returns `None` once the channel is closed and drained. The receiver lock is held
    /// while waiting, so concurrent callers take turns.
    pub(crate) async fn recv(&self) -> Option<WsMessage> {
        let mut inbox = self.receiver.lock().await;
        inbox.recv().await
    }

    /// Queues `payload` for the writer task and waits for the writer to acknowledge it.
    ///
    /// # Errors
    ///
    /// Returns an error string when the writer's command channel is closed, when the
    /// writer drops the acknowledgement channel without answering, or when the writer
    /// itself reports a failure for this send.
    pub(crate) async fn send(&self, payload: String) -> Result<(), String> {
        let (ack_tx, ack_rx) = oneshot::channel();
        let command = WsCommand {
            payload,
            completion: ack_tx,
        };

        if self.sender.send(command).is_err() {
            return Err("WebSocket session writer is closed".to_string());
        }

        match ack_rx.await {
            Ok(outcome) => outcome,
            Err(_) => Err(
                "WebSocket session writer dropped before acknowledging the send".to_string(),
            ),
        }
    }

    /// Aborts the reader task and then the writer task.
    fn abort(&self) {
        for task in [&self.reader_abort_handle, &self.writer_abort_handle] {
            task.abort();
        }
    }
}

impl SessionHandle {
    /// Shared HTTP runtime of this session, or `None` if it carries no HTTP runtime.
    pub(crate) fn http_runtime(&self) -> Option<Arc<HttpSessionRuntime>> {
        self.runtime.as_ref().and_then(|runtime| match runtime {
            SessionRuntime::Http(http) => Some(Arc::clone(http)),
            SessionRuntime::Sse(_) | SessionRuntime::Ws(_) => None,
        })
    }

    /// Shared SSE runtime of this session, or `None` if it carries no SSE runtime.
    pub(crate) fn sse_runtime(&self) -> Option<Arc<SseSessionRuntime>> {
        self.runtime.as_ref().and_then(|runtime| match runtime {
            SessionRuntime::Sse(sse) => Some(Arc::clone(sse)),
            SessionRuntime::Http(_) | SessionRuntime::Ws(_) => None,
        })
    }

    /// Shared WebSocket runtime of this session, or `None` if it carries no WebSocket runtime.
    pub(crate) fn ws_runtime(&self) -> Option<Arc<WsSessionRuntime>> {
        self.runtime.as_ref().and_then(|runtime| match runtime {
            SessionRuntime::Ws(ws) => Some(Arc::clone(ws)),
            SessionRuntime::Http(_) | SessionRuntime::Sse(_) => None,
        })
    }
}

impl SessionRegistry {
    /// Creates an empty registry with no sessions and no OAuth state.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers, or replaces, a session named `name` that has no live runtime attached.
    ///
    /// If a session with that name already exists, it is replaced and its background tasks
    /// (SSE task, WebSocket reader/writer) are aborted. Returns a clone of the stored handle.
    pub fn upsert(
        &self,
        name: impl Into<String>,
        protocol: RequestProtocol,
        metadata: HashMap<String, String>,
    ) -> SessionHandle {
        self.insert_handle(SessionHandle {
            name: name.into(),
            protocol,
            metadata,
            runtime: None,
        })
    }

    /// Returns the session named `name`, creating an HTTP session with a fresh cookie jar
    /// if none exists.
    ///
    /// An existing session is returned unchanged whatever its protocol or runtime, and
    /// `metadata` is only used when a new session is created.
    /// NOTE(review): an existing runtime-less (`upsert`) or SSE/WS session of the same name
    /// yields a handle whose `http_runtime()` is `None` — confirm callers expect that.
    pub(crate) fn get_or_create_http(
        &self,
        name: impl Into<String>,
        metadata: HashMap<String, String>,
    ) -> SessionHandle {
        let mut sessions = self.lock_sessions();
        // Single hash lookup: insert only when the name is vacant.
        sessions
            .entry(name.into())
            .or_insert_with_key(|key| SessionHandle {
                name: key.clone(),
                protocol: RequestProtocol::Http,
                metadata,
                runtime: Some(SessionRuntime::Http(Arc::new(HttpSessionRuntime::default()))),
            })
            .clone()
    }

    /// Registers, or replaces, an SSE session backed by `receiver` and the task behind
    /// `abort_handle`.
    ///
    /// A replaced session's background tasks are aborted.
    pub(crate) fn upsert_sse(
        &self,
        name: impl Into<String>,
        metadata: HashMap<String, String>,
        receiver: mpsc::UnboundedReceiver<SseMessage>,
        abort_handle: AbortHandle,
    ) -> SessionHandle {
        self.insert_handle(SessionHandle {
            name: name.into(),
            protocol: RequestProtocol::Sse,
            metadata,
            runtime: Some(SessionRuntime::Sse(Arc::new(SseSessionRuntime::new(
                receiver,
                abort_handle,
            )))),
        })
    }

    /// Registers, or replaces, a WebSocket session backed by the given channels and
    /// reader/writer task handles.
    ///
    /// A replaced session's background tasks are aborted.
    pub(crate) fn upsert_ws(
        &self,
        name: impl Into<String>,
        metadata: HashMap<String, String>,
        receiver: mpsc::UnboundedReceiver<WsMessage>,
        sender: mpsc::UnboundedSender<WsCommand>,
        reader_abort_handle: AbortHandle,
        writer_abort_handle: AbortHandle,
    ) -> SessionHandle {
        self.insert_handle(SessionHandle {
            name: name.into(),
            protocol: RequestProtocol::Ws,
            metadata,
            runtime: Some(SessionRuntime::Ws(Arc::new(WsSessionRuntime::new(
                receiver,
                sender,
                reader_abort_handle,
                writer_abort_handle,
            )))),
        })
    }

    /// Returns the shared OAuth state for `name`, creating an empty one on first use.
    ///
    /// Repeated calls with the same name return the same `Arc` until `clear` is called.
    pub(crate) fn get_or_create_oauth(
        &self,
        name: impl Into<String>,
    ) -> Arc<AsyncMutex<OAuthSessionRuntime>> {
        self.lock_oauth()
            .entry(name.into())
            .or_insert_with(|| Arc::new(AsyncMutex::new(OAuthSessionRuntime::default())))
            .clone()
    }

    /// Stores `handle` under its name and aborts the tasks of any session it replaces.
    fn insert_handle(&self, handle: SessionHandle) -> SessionHandle {
        // The guard is a temporary, so the registry lock is released at the end of this
        // statement; aborting happens outside the critical section, as in `remove`.
        let previous = self
            .lock_sessions()
            .insert(handle.name.clone(), handle.clone());

        if let Some(previous) = previous {
            abort_runtime(previous.runtime.as_ref());
        }

        handle
    }

    /// Returns a clone of the session named `name`, if registered.
    ///
    /// The clone shares the session's live runtime.
    pub fn get(&self, name: &str) -> Option<SessionHandle> {
        self.lock_sessions().get(name).cloned()
    }

    /// Removes the session named `name`, aborting its background tasks.
    ///
    /// Returns the removed handle, or `None` if no such session existed. OAuth state is
    /// left untouched.
    pub fn remove(&self, name: &str) -> Option<SessionHandle> {
        let removed = self.lock_sessions().remove(name);

        if let Some(handle) = &removed {
            abort_runtime(handle.runtime.as_ref());
        }

        removed
    }

    /// Drops every session (aborting their background tasks) and all OAuth state.
    pub fn clear(&self) {
        // Both maps are held together, sessions first, so the reset appears atomic; this
        // is the only place that holds both locks, so the fixed order cannot deadlock.
        let mut sessions = self.lock_sessions();
        let mut oauth = self.lock_oauth();

        for (_, handle) in sessions.drain() {
            abort_runtime(handle.runtime.as_ref());
        }

        oauth.clear();
    }

    /// Locks the session map.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned (a previous holder panicked).
    fn lock_sessions(&self) -> std::sync::MutexGuard<'_, HashMap<String, SessionHandle>> {
        self.inner.lock().expect("session registry mutex poisoned")
    }

    /// Locks the OAuth map.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned (a previous holder panicked).
    fn lock_oauth(
        &self,
    ) -> std::sync::MutexGuard<'_, HashMap<String, Arc<AsyncMutex<OAuthSessionRuntime>>>> {
        self.oauth.lock().expect("oauth registry mutex poisoned")
    }
}

/// Aborts the background tasks owned by `runtime`, if any.
///
/// SSE runtimes abort their task and WebSocket runtimes abort both reader and writer.
/// HTTP runtimes and runtime-less sessions have nothing to stop.
fn abort_runtime(runtime: Option<&SessionRuntime>) {
    if let Some(active) = runtime {
        match active {
            SessionRuntime::Sse(sse) => sse.abort(),
            SessionRuntime::Ws(ws) => ws.abort(),
            SessionRuntime::Http(_) => {}
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned metadata map from borrowed key/value pairs.
    fn metadata(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn registry_round_trips_handles() {
        let registry = SessionRegistry::new();
        let handle = registry.upsert("alpha", RequestProtocol::Http, HashMap::new());

        let loaded = registry.get("alpha").expect("session should exist");
        assert_eq!(loaded.name, handle.name);
        assert_eq!(loaded.protocol, RequestProtocol::Http);

        let removed = registry.remove("alpha").expect("session should be removed");
        assert_eq!(removed.name, "alpha");
        assert!(registry.get("alpha").is_none());
    }

    #[test]
    fn remove_of_unknown_session_returns_none() {
        let registry = SessionRegistry::new();
        assert!(registry.remove("missing").is_none());
        assert!(registry.get("missing").is_none());
    }

    #[test]
    fn upsert_replaces_existing_handle() {
        let registry = SessionRegistry::new();
        registry.upsert("alpha", RequestProtocol::Http, metadata(&[("v", "1")]));
        registry.upsert("alpha", RequestProtocol::Sse, metadata(&[("v", "2")]));

        let loaded = registry.get("alpha").expect("session should exist");
        assert_eq!(loaded.protocol, RequestProtocol::Sse);
        assert_eq!(loaded.metadata, metadata(&[("v", "2")]));
        // `upsert` never attaches a live runtime.
        assert!(loaded.http_runtime().is_none());
        assert!(loaded.sse_runtime().is_none());
        assert!(loaded.ws_runtime().is_none());
    }

    #[test]
    fn get_or_create_http_reuses_existing_session() {
        let registry = SessionRegistry::new();
        let first = registry.get_or_create_http("web", metadata(&[("origin", "first")]));
        let second = registry.get_or_create_http("web", metadata(&[("origin", "second")]));

        assert_eq!(second.protocol, RequestProtocol::Http);
        // Metadata is only applied when the session is first created.
        assert_eq!(second.metadata, metadata(&[("origin", "first")]));

        let first_runtime = first.http_runtime().expect("http runtime should exist");
        let second_runtime = second.http_runtime().expect("http runtime should exist");
        assert!(Arc::ptr_eq(&first_runtime, &second_runtime));
        assert!(Arc::ptr_eq(
            &first_runtime.cookie_jar(),
            &second_runtime.cookie_jar()
        ));
    }

    #[test]
    fn oauth_state_is_shared_until_clear() {
        let registry = SessionRegistry::new();
        let first = registry.get_or_create_oauth("token");
        let again = registry.get_or_create_oauth("token");
        assert!(Arc::ptr_eq(&first, &again));

        registry.upsert("alpha", RequestProtocol::Http, HashMap::new());
        registry.clear();

        assert!(registry.get("alpha").is_none());
        let after_clear = registry.get_or_create_oauth("token");
        assert!(!Arc::ptr_eq(&first, &after_clear));
    }

    #[test]
    fn clones_share_state() {
        let registry = SessionRegistry::new();
        let clone = registry.clone();
        clone.upsert("alpha", RequestProtocol::Ws, HashMap::new());

        assert_eq!(
            registry.get("alpha").map(|handle| handle.protocol),
            Some(RequestProtocol::Ws)
        );
    }
}