use std::{
collections::HashMap,
sync::{Arc, Mutex},
time::Instant,
};
use reqwest::cookie::Jar;
use tokio::{
sync::{mpsc, oneshot, Mutex as AsyncMutex},
task::AbortHandle,
};
use super::RequestProtocol;
/// Shared, thread-safe registry of named sessions plus per-name OAuth state.
///
/// Cloning is cheap: both maps live behind `Arc`, so every clone refers to the
/// same underlying registry.
#[derive(Debug, Clone, Default)]
pub struct SessionRegistry {
/// Session handles keyed by session name.
inner: Arc<Mutex<HashMap<String, SessionHandle>>>,
/// OAuth runtime state keyed by name. Each entry has its own tokio (async)
/// mutex so a caller can hold it across `.await`; the outer `std` mutex only
/// guards the map itself.
oauth: Arc<Mutex<HashMap<String, Arc<AsyncMutex<OAuthSessionRuntime>>>>>,
}
/// A named session as stored in, and returned by, `SessionRegistry`.
///
/// The runtime is reference-counted, so cloning a handle shares the same
/// runtime instead of duplicating it.
#[derive(Debug, Clone)]
pub struct SessionHandle {
/// Registry key for this session.
pub name: String,
/// Protocol this session was registered for.
pub protocol: RequestProtocol,
/// Free-form string metadata supplied by the caller at registration time.
pub metadata: HashMap<String, String>,
/// Protocol-specific runtime. `None` for sessions registered via
/// `SessionRegistry::upsert`; populated by `get_or_create_http`,
/// `upsert_sse`, and `upsert_ws`.
runtime: Option<SessionRuntime>,
}
/// Protocol-specific runtime attached to a `SessionHandle`.
///
/// Every variant holds its runtime behind an `Arc`, so cloning this enum (and
/// therefore a handle) shares the runtime.
#[derive(Debug, Clone)]
enum SessionRuntime {
Http(Arc<HttpSessionRuntime>),
Sse(Arc<SseSessionRuntime>),
Ws(Arc<WsSessionRuntime>),
}
/// Runtime state for an HTTP session.
#[derive(Debug, Default)]
pub(crate) struct HttpSessionRuntime {
/// reqwest cookie jar owned by this session and handed out by `cookie_jar()`.
/// NOTE(review): presumably attached to the HTTP client so cookies persist
/// across requests in the same session — confirm at the call site.
cookie_jar: Arc<Jar>,
}
/// Mutable OAuth state cached per name by `SessionRegistry::get_or_create_oauth`.
///
/// Every field starts empty (`Default`). NOTE(review): the field descriptions
/// below are inferred from the names; confirm against the code that fills them.
#[derive(Debug, Default)]
pub(crate) struct OAuthSessionRuntime {
/// Key identifying the cached configuration (presumably used to detect changes).
pub cache_key: Option<String>,
/// Token endpoint the cached state was obtained from (presumably a URL).
pub token_endpoint: Option<String>,
/// Fields taken from the token response (presumably access token, type, ...).
pub token_fields: HashMap<String, String>,
/// Refresh token, if one was issued.
pub refresh_token: Option<String>,
/// Expiry as a monotonic `Instant` (not wall-clock time), if known.
pub expires_at: Option<Instant>,
}
/// Runtime state for a Server-Sent Events session.
#[derive(Debug)]
pub(crate) struct SseSessionRuntime {
/// Incoming messages. Wrapped in an async mutex so `recv(&self)` can obtain
/// the `&mut` access `UnboundedReceiver::recv` requires through a shared `Arc`.
receiver: AsyncMutex<mpsc::UnboundedReceiver<SseMessage>>,
/// Abort handle for the task that (presumably) reads the stream and feeds
/// `receiver`; triggered by `abort()`.
abort_handle: AbortHandle,
}
/// Runtime state for a WebSocket session: an incoming message channel plus a
/// command channel to the writer task.
#[derive(Debug)]
pub(crate) struct WsSessionRuntime {
/// Incoming messages (presumably fed by the reader task). Behind an async
/// mutex so `recv(&self)` can take `&mut` access through a shared `Arc`.
receiver: AsyncMutex<mpsc::UnboundedReceiver<WsMessage>>,
/// Outgoing send requests for the writer; each carries a completion channel.
sender: mpsc::UnboundedSender<WsCommand>,
/// Abort handle for the reader task.
reader_abort_handle: AbortHandle,
/// Abort handle for the writer task.
writer_abort_handle: AbortHandle,
}
/// One event received on an SSE session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct SseMessage {
/// Event name, if present (presumably the SSE `event:` field).
pub event: Option<String>,
/// Event id, if present (presumably the SSE `id:` field).
pub id: Option<String>,
/// Event payload.
pub data: String,
/// NOTE(review): presumably the unparsed text of the whole event — confirm.
pub raw_event: String,
}
/// Kind of a [`WsMessage`] received on a WebSocket session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum WsMessageKind {
    Text,
    Binary,
    Close,
}

impl WsMessageKind {
    /// Lowercase label for this kind: `"text"`, `"binary"`, or `"close"`.
    pub(crate) fn as_str(self) -> &'static str {
        match self {
            WsMessageKind::Close => "close",
            WsMessageKind::Binary => "binary",
            WsMessageKind::Text => "text",
        }
    }
}
/// One message received on a WebSocket session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct WsMessage {
/// Frame kind (text, binary, or close).
pub kind: WsMessageKind,
/// Message payload as a string. NOTE(review): how binary payloads are
/// rendered is decided by the producer — confirm.
pub data: String,
/// NOTE(review): presumably an unprocessed rendering of the message — confirm.
pub raw_message: String,
}
/// A send request handed to the WebSocket writer.
#[derive(Debug)]
pub(crate) struct WsCommand {
/// Payload to send.
pub payload: String,
/// The writer reports the outcome here; `WsSessionRuntime::send` awaits it
/// and returns it to its caller.
pub completion: oneshot::Sender<Result<(), String>>,
}
impl SseSessionRuntime {
    /// Builds a runtime from a message receiver and the abort handle of its task.
    pub(crate) fn new(
        receiver: mpsc::UnboundedReceiver<SseMessage>,
        abort_handle: AbortHandle,
    ) -> Self {
        let receiver = AsyncMutex::new(receiver);
        Self {
            receiver,
            abort_handle,
        }
    }

    /// Waits for the next SSE message.
    ///
    /// Returns `None` once the channel is closed and drained. Concurrent callers
    /// are serialized by the async mutex guarding the receiver.
    pub(crate) async fn recv(&self) -> Option<SseMessage> {
        let mut inbox = self.receiver.lock().await;
        inbox.recv().await
    }

    /// Aborts the task behind `abort_handle`.
    fn abort(&self) {
        let handle = &self.abort_handle;
        handle.abort();
    }
}
impl HttpSessionRuntime {
    /// Returns a shared (reference-counted) handle to this session's cookie jar.
    pub(crate) fn cookie_jar(&self) -> Arc<Jar> {
        Arc::clone(&self.cookie_jar)
    }
}
impl WsSessionRuntime {
    /// Builds a runtime from the message/command channel ends and the abort
    /// handles of the reader and writer tasks.
    pub(crate) fn new(
        receiver: mpsc::UnboundedReceiver<WsMessage>,
        sender: mpsc::UnboundedSender<WsCommand>,
        reader_abort_handle: AbortHandle,
        writer_abort_handle: AbortHandle,
    ) -> Self {
        let receiver = AsyncMutex::new(receiver);
        Self {
            receiver,
            sender,
            reader_abort_handle,
            writer_abort_handle,
        }
    }

    /// Waits for the next incoming message; `None` once the channel is closed
    /// and drained.
    pub(crate) async fn recv(&self) -> Option<WsMessage> {
        let mut inbox = self.receiver.lock().await;
        inbox.recv().await
    }

    /// Queues `payload` for the writer and waits until it acknowledges the send.
    ///
    /// # Errors
    ///
    /// Returns `Err` when the command channel is closed, when the writer drops
    /// the completion sender without replying, or with whatever error string
    /// the writer reports for this payload.
    pub(crate) async fn send(&self, payload: String) -> Result<(), String> {
        let (completion, acknowledgement) = oneshot::channel();
        let command = WsCommand {
            payload,
            completion,
        };
        if self.sender.send(command).is_err() {
            return Err("WebSocket session writer is closed".to_string());
        }
        match acknowledgement.await {
            Ok(outcome) => outcome,
            Err(_) => Err(
                "WebSocket session writer dropped before acknowledging the send".to_string(),
            ),
        }
    }

    /// Aborts the reader task, then the writer task.
    fn abort(&self) {
        for handle in [&self.reader_abort_handle, &self.writer_abort_handle] {
            handle.abort();
        }
    }
}
impl SessionHandle {
    /// Returns the HTTP runtime if this session carries one.
    pub(crate) fn http_runtime(&self) -> Option<Arc<HttpSessionRuntime>> {
        self.runtime.as_ref().and_then(|runtime| match runtime {
            SessionRuntime::Http(http) => Some(Arc::clone(http)),
            SessionRuntime::Sse(_) | SessionRuntime::Ws(_) => None,
        })
    }

    /// Returns the SSE runtime if this session carries one.
    pub(crate) fn sse_runtime(&self) -> Option<Arc<SseSessionRuntime>> {
        self.runtime.as_ref().and_then(|runtime| match runtime {
            SessionRuntime::Sse(sse) => Some(Arc::clone(sse)),
            SessionRuntime::Http(_) | SessionRuntime::Ws(_) => None,
        })
    }

    /// Returns the WebSocket runtime if this session carries one.
    pub(crate) fn ws_runtime(&self) -> Option<Arc<WsSessionRuntime>> {
        self.runtime.as_ref().and_then(|runtime| match runtime {
            SessionRuntime::Ws(ws) => Some(Arc::clone(ws)),
            SessionRuntime::Http(_) | SessionRuntime::Sse(_) => None,
        })
    }
}
impl SessionRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn upsert(
&self,
name: impl Into<String>,
protocol: RequestProtocol,
metadata: HashMap<String, String>,
) -> SessionHandle {
self.insert_handle(SessionHandle {
name: name.into(),
protocol,
metadata,
runtime: None,
})
}
pub(crate) fn get_or_create_http(
&self,
name: impl Into<String>,
metadata: HashMap<String, String>,
) -> SessionHandle {
let name = name.into();
let mut sessions = self
.inner
.lock()
.expect("session registry mutex poisoned");
if let Some(existing) = sessions.get(&name) {
return existing.clone();
}
let handle = SessionHandle {
name: name.clone(),
protocol: RequestProtocol::Http,
metadata,
runtime: Some(SessionRuntime::Http(Arc::new(HttpSessionRuntime::default()))),
};
sessions.insert(name, handle.clone());
handle
}
pub(crate) fn upsert_sse(
&self,
name: impl Into<String>,
metadata: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<SseMessage>,
abort_handle: AbortHandle,
) -> SessionHandle {
self.insert_handle(SessionHandle {
name: name.into(),
protocol: RequestProtocol::Sse,
metadata,
runtime: Some(SessionRuntime::Sse(Arc::new(SseSessionRuntime::new(
receiver,
abort_handle,
)))),
})
}
pub(crate) fn upsert_ws(
&self,
name: impl Into<String>,
metadata: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<WsMessage>,
sender: mpsc::UnboundedSender<WsCommand>,
reader_abort_handle: AbortHandle,
writer_abort_handle: AbortHandle,
) -> SessionHandle {
self.insert_handle(SessionHandle {
name: name.into(),
protocol: RequestProtocol::Ws,
metadata,
runtime: Some(SessionRuntime::Ws(Arc::new(WsSessionRuntime::new(
receiver,
sender,
reader_abort_handle,
writer_abort_handle,
)))),
})
}
pub(crate) fn get_or_create_oauth(
&self,
name: impl Into<String>,
) -> Arc<AsyncMutex<OAuthSessionRuntime>> {
let name = name.into();
let mut runtimes = self
.oauth
.lock()
.expect("oauth registry mutex poisoned");
runtimes
.entry(name)
.or_insert_with(|| Arc::new(AsyncMutex::new(OAuthSessionRuntime::default())))
.clone()
}
fn insert_handle(&self, handle: SessionHandle) -> SessionHandle {
let mut sessions = self
.inner
.lock()
.expect("session registry mutex poisoned");
if let Some(previous) = sessions.insert(handle.name.clone(), handle.clone()) {
abort_runtime(previous.runtime.as_ref());
}
handle
}
pub fn get(&self, name: &str) -> Option<SessionHandle> {
self.inner
.lock()
.expect("session registry mutex poisoned")
.get(name)
.cloned()
}
pub fn remove(&self, name: &str) -> Option<SessionHandle> {
let removed = self
.inner
.lock()
.expect("session registry mutex poisoned")
.remove(name);
if let Some(handle) = removed.as_ref() {
abort_runtime(handle.runtime.as_ref());
}
removed
}
pub fn clear(&self) {
let mut sessions = self
.inner
.lock()
.expect("session registry mutex poisoned");
let mut oauth = self
.oauth
.lock()
.expect("oauth registry mutex poisoned");
for handle in sessions.values() {
abort_runtime(handle.runtime.as_ref());
}
sessions.clear();
oauth.clear();
}
}
/// Aborts the task(s) owned by `runtime`, if any.
///
/// HTTP runtimes have no task to stop, and `None` means there is no runtime.
fn abort_runtime(runtime: Option<&SessionRuntime>) {
    if let Some(active) = runtime {
        match active {
            SessionRuntime::Sse(sse) => sse.abort(),
            SessionRuntime::Ws(ws) => ws.abort(),
            SessionRuntime::Http(_) => {}
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A handle inserted with `upsert` can be fetched, then removed, after
    /// which it is no longer present.
    #[test]
    fn registry_round_trips_handles() {
        let registry = SessionRegistry::new();
        let inserted = registry.upsert("alpha", RequestProtocol::Http, HashMap::new());

        let fetched = registry
            .get("alpha")
            .expect("session should exist");
        assert_eq!(fetched.name, inserted.name);
        assert_eq!(fetched.protocol, RequestProtocol::Http);

        let evicted = registry
            .remove("alpha")
            .expect("session should be removed");
        assert_eq!(evicted.name, "alpha");
        assert!(registry.get("alpha").is_none());
    }
}