use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::mpsc;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::thread;
use futures_util::{SinkExt, StreamExt};
use serde_json::Value;
use tokio::sync::oneshot;
use tokio::sync::{mpsc as async_mpsc, Mutex as TokioMutex};
use tokio_tungstenite::{connect_async, tungstenite::Message};
use log::{debug, info, warn};
use crate::wamp;
/// WebSocket endpoint used by `WaapiClient::connect` and `WaapiClientSync::connect`.
const DEFAULT_WAAPI_URL: &str = "ws://localhost:8080/waapi";
/// WAMP realm passed to `wamp::hello_msg` when a connection is opened.
const DEFAULT_REALM: &str = "realm1";
/// Errors returned by the WAAPI clients in this module.
#[derive(Debug, thiserror::Error)]
pub enum WaapiError {
/// The client has no connection, or the connection closed before a reply arrived.
#[error("client already disconnected")]
Disconnected,
/// A WAMP-level failure: the error string of an `ERROR` reply, or a description of an
/// unexpected message received during the handshake.
#[error("WAMP error: {0}")]
Wamp(String),
/// Failure reported by the underlying WebSocket transport.
#[error("WebSocket error: {0}")]
WebSocket(#[from] Box<tokio_tungstenite::tungstenite::Error>),
/// A `serde_json` error.
#[error("{0}")]
Serde(#[from] serde_json::Error),
/// An I/O error, e.g. failing to build the Tokio runtime of the blocking client.
#[error("{0}")]
Io(#[from] std::io::Error),
}
/// Outcome delivered to a pending CALL: the `kwargs` of the RESULT message.
type CallResult = Result<Option<Value>, WaapiError>;
/// Outcome delivered to a pending SUBSCRIBE: the subscription id from SUBSCRIBED.
type SubResult = Result<u64, WaapiError>;
/// Outcome delivered to a pending UNSUBSCRIBE.
type UnsubResult = Result<(), WaapiError>;
/// One event forwarded to a subscriber: `(publication id, kwargs)`.
pub type EventPayload = (u64, Option<Value>);
/// Write half of the WebSocket connection.
type WsSink = futures_util::stream::SplitSink<
tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
Message,
>;
/// Shared state of one WAMP session: the socket's write half plus the tables that route
/// incoming replies and events back to whoever is waiting for them.
#[derive(Debug)]
struct WampConn {
/// Write half of the socket; an async mutex because the guard is held while a send is awaited.
ws_tx: TokioMutex<WsSink>,
/// Callers waiting for a RESULT/ERROR, keyed by request id.
pending_calls: StdMutex<HashMap<u64, oneshot::Sender<CallResult>>>,
/// Callers waiting for a SUBSCRIBED/ERROR, keyed by request id.
pending_subs: StdMutex<HashMap<u64, oneshot::Sender<SubResult>>>,
/// Callers waiting for an UNSUBSCRIBED/ERROR, keyed by request id.
pending_unsubs: StdMutex<HashMap<u64, oneshot::Sender<UnsubResult>>>,
/// Channels that EVENT messages are forwarded to, keyed by subscription id.
event_senders: StdMutex<HashMap<u64, async_mpsc::UnboundedSender<EventPayload>>>,
/// Source of request ids (starts at 1).
next_id: AtomicU64,
}
impl WampConn {
fn new(sink: WsSink) -> Self {
Self {
ws_tx: TokioMutex::new(sink),
pending_calls: StdMutex::new(HashMap::new()),
pending_subs: StdMutex::new(HashMap::new()),
pending_unsubs: StdMutex::new(HashMap::new()),
event_senders: StdMutex::new(HashMap::new()),
next_id: AtomicU64::new(1),
}
}
fn next_id(&self) -> u64 {
self.next_id.fetch_add(1, Ordering::Relaxed)
}
async fn send(&self, text: String) -> Result<(), WaapiError> {
self.ws_tx
.lock()
.await
.send(Message::Text(text.into()))
.await
.map_err(|e| WaapiError::WebSocket(Box::new(e)))
}
}
/// Full WebSocket connection type produced by `connect_async`; its read half is consumed by
/// the handshake and then by `run_event_loop`.
type WsStream = tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>;
async fn run_event_loop(
conn: Arc<WampConn>,
mut ws_rx: futures_util::stream::SplitStream<WsStream>,
connected: Arc<AtomicBool>,
) {
while let Some(msg) = ws_rx.next().await {
match msg {
Ok(Message::Text(text)) => {
if let Some(wamp_msg) = wamp::parse(&text) {
dispatch(&conn, wamp_msg);
}
}
Ok(Message::Close(_)) | Err(_) => break,
_ => {}
}
}
connected.store(false, Ordering::Release);
drain_pending(&conn);
}
fn dispatch(conn: &WampConn, msg: wamp::WampMessage) {
match msg {
wamp::WampMessage::Result { request_id, kwargs } => {
if let Some(tx) = conn
.pending_calls
.lock()
.unwrap_or_else(|e| e.into_inner())
.remove(&request_id)
{
let _ = tx.send(Ok(kwargs));
}
}
wamp::WampMessage::Error {
request_type,
request_id,
error,
} => {
let err_str = error.clone();
if request_type == 48 {
if let Some(tx) = conn
.pending_calls
.lock()
.unwrap_or_else(|e| e.into_inner())
.remove(&request_id)
{
let _ = tx.send(Err(WaapiError::Wamp(err_str)));
return;
}
}
if request_type == 32 {
if let Some(tx) = conn
.pending_subs
.lock()
.unwrap_or_else(|e| e.into_inner())
.remove(&request_id)
{
let _ = tx.send(Err(WaapiError::Wamp(error)));
return;
}
}
if request_type == 34 {
if let Some(tx) = conn
.pending_unsubs
.lock()
.unwrap_or_else(|e| e.into_inner())
.remove(&request_id)
{
let _ = tx.send(Err(WaapiError::Wamp(error)));
}
}
}
wamp::WampMessage::Subscribed {
request_id,
sub_id,
} => {
if let Some(tx) = conn
.pending_subs
.lock()
.unwrap_or_else(|e| e.into_inner())
.remove(&request_id)
{
let _ = tx.send(Ok(sub_id));
}
}
wamp::WampMessage::Unsubscribed { request_id } => {
if let Some(tx) = conn
.pending_unsubs
.lock()
.unwrap_or_else(|e| e.into_inner())
.remove(&request_id)
{
let _ = tx.send(Ok(()));
}
}
wamp::WampMessage::Event {
sub_id,
pub_id,
kwargs,
} => {
let senders = conn
.event_senders
.lock()
.unwrap_or_else(|e| e.into_inner());
if let Some(tx) = senders.get(&sub_id) {
let _ = tx.send((pub_id, kwargs));
}
}
wamp::WampMessage::Goodbye | wamp::WampMessage::Welcome { .. } => {}
}
}
fn drain_pending(conn: &WampConn) {
let calls: Vec<_> = conn
.pending_calls
.lock()
.unwrap_or_else(|e| e.into_inner())
.drain()
.collect();
for (_, tx) in calls {
let _ = tx.send(Err(WaapiError::Disconnected));
}
let subs: Vec<_> = conn
.pending_subs
.lock()
.unwrap_or_else(|e| e.into_inner())
.drain()
.collect();
for (_, tx) in subs {
let _ = tx.send(Err(WaapiError::Disconnected));
}
let unsubs: Vec<_> = conn
.pending_unsubs
.lock()
.unwrap_or_else(|e| e.into_inner())
.drain()
.collect();
for (_, tx) in unsubs {
let _ = tx.send(Err(WaapiError::Disconnected));
}
}
/// Waits for the router's reply to HELLO and returns the session id from its WELCOME.
///
/// Non-text frames are skipped. The first text frame decides the outcome: a WELCOME
/// yields its session id, and anything else is an error.
///
/// # Errors
/// - [`WaapiError::Wamp`] if the first text frame is not a WELCOME.
/// - [`WaapiError::WebSocket`] if reading from the socket fails.
/// - [`WaapiError::Disconnected`] if the stream ends before any text frame arrives.
async fn read_welcome(
    ws_rx: &mut futures_util::stream::SplitStream<WsStream>,
) -> Result<u64, WaapiError> {
    while let Some(frame) = ws_rx.next().await {
        let text = match frame {
            Ok(Message::Text(text)) => text,
            Ok(_) => continue,
            Err(e) => return Err(WaapiError::WebSocket(Box::new(e))),
        };
        return match wamp::parse(&text) {
            Some(wamp::WampMessage::Welcome { session_id }) => Ok(session_id),
            _ => Err(WaapiError::Wamp(format!("expected WELCOME, got: {text}"))),
        };
    }
    Err(WaapiError::Disconnected)
}
/// Owner of one active subscription; calling `unsubscribe` or dropping it tears the subscription down.
#[derive(Debug)]
pub struct SubscriptionHandle {
/// Subscription id returned by the router in SUBSCRIBED.
sub_id: u64,
/// Connection the subscription lives on.
conn: Arc<WampConn>,
/// The owning client's list of active subscription ids.
subscription_ids: Arc<StdMutex<Vec<u64>>>,
/// Task running the user callback (set by `WaapiClient::subscribe`); aborted on teardown.
recv_task: Option<tokio::task::JoinHandle<()>>,
/// Set once teardown has run, so the network UNSUBSCRIBE is attempted at most once.
is_unsubscribed: bool,
}
/// Sets `flag` and reports whether this call was the one that set it.
///
/// Returns `true` the first time (the flag was clear) and `false` on every later call. The
/// explicit-unsubscribe and drop paths use this to agree on which of them runs teardown.
fn mark_unsubscribed(flag: &mut bool) -> bool {
    let was_already_set = std::mem::replace(flag, true);
    !was_already_set
}
impl SubscriptionHandle {
    /// Cancels the subscription and waits for the router to confirm.
    ///
    /// Local teardown happens first and unconditionally:
    /// - the callback task (if any) is aborted;
    /// - the id is removed from the owning client's list;
    /// - the event sender is dropped, so the paired receiver sees end-of-stream.
    ///
    /// The network UNSUBSCRIBE is then sent only if the client still tracked this id.
    /// A client that has already been disconnected unsubscribed everything it tracked,
    /// so a second request could only fail.
    ///
    /// # Errors
    /// - [`WaapiError::WebSocket`] if the request cannot be sent.
    /// - [`WaapiError::Wamp`] if the router rejects it.
    /// - [`WaapiError::Disconnected`] if the connection drops before the reply.
    pub async fn unsubscribe(mut self) -> Result<(), WaapiError> {
        debug!("Unsubscribing sub_id={}", self.sub_id);
        let sub_id = self.sub_id;
        if let Some(task) = self.recv_task.take() {
            task.abort();
        }
        let still_registered = {
            let mut ids = self
                .subscription_ids
                .lock()
                .unwrap_or_else(|e| e.into_inner());
            let before = ids.len();
            ids.retain(|&id| id != sub_id);
            ids.len() != before
        };
        self.conn
            .event_senders
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .remove(&sub_id);
        // Always record the teardown (so `Drop` does not repeat it), even when the
        // network request is skipped.
        let first_teardown = mark_unsubscribed(&mut self.is_unsubscribed);
        if !first_teardown || !still_registered {
            return Ok(());
        }
        do_network_unsubscribe(&self.conn, sub_id).await
    }
}
/// Sends UNSUBSCRIBE for `sub_id` and waits for the router's reply.
///
/// The waiter is registered before the message is sent, so a fast reply cannot be missed.
/// If the send fails, the waiter is removed again rather than left behind.
///
/// # Errors
/// - [`WaapiError::WebSocket`] if the request cannot be sent.
/// - [`WaapiError::Wamp`] if the router rejects it.
/// - [`WaapiError::Disconnected`] if the connection drops before the reply.
async fn do_network_unsubscribe(conn: &WampConn, sub_id: u64) -> Result<(), WaapiError> {
    let id = conn.next_id();
    let (tx, rx) = oneshot::channel();
    conn.pending_unsubs
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .insert(id, tx);
    if let Err(err) = conn.send(wamp::unsubscribe_msg(id, sub_id)).await {
        conn.pending_unsubs
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .remove(&id);
        return Err(err);
    }
    rx.await.unwrap_or(Err(WaapiError::Disconnected))
}
impl Drop for SubscriptionHandle {
fn drop(&mut self) {
let sub_id = self.sub_id;
let conn = Arc::clone(&self.conn);
let subscription_ids = Arc::clone(&self.subscription_ids);
if let Some(task) = self.recv_task.take() {
task.abort();
}
subscription_ids
.lock()
.unwrap_or_else(|e| e.into_inner())
.retain(|&id| id != sub_id);
conn.event_senders
.lock()
.unwrap_or_else(|e| e.into_inner())
.remove(&sub_id);
if !mark_unsubscribed(&mut self.is_unsubscribed) {
return;
}
if let Ok(rt) = tokio::runtime::Handle::try_current() {
debug!("SubscriptionHandle dropped, spawning unsubscribe for sub_id={sub_id}");
rt.spawn(async move {
let _ = do_network_unsubscribe(&conn, sub_id).await;
});
} else {
warn!("SubscriptionHandle dropped without runtime, skipping network unsubscribe for sub_id={sub_id}");
}
}
}
/// Asynchronous WAAPI client speaking WAMP over a WebSocket.
#[derive(Debug)]
pub struct WaapiClient {
/// Shared session state; `None` once the client has been disconnected.
conn: Option<Arc<WampConn>>,
/// Background task running `run_event_loop`; `None` once it has been stopped.
event_loop_handle: Option<tokio::task::JoinHandle<()>>,
/// Ids to unsubscribe on disconnect; shared with every `SubscriptionHandle` this client creates.
subscription_ids: Arc<StdMutex<Vec<u64>>>,
/// Cleared by the reader task when the connection ends, and during shutdown.
connected: Arc<AtomicBool>,
}
impl WaapiClient {
pub async fn connect() -> Result<Self, WaapiError> {
Self::connect_with_url(DEFAULT_WAAPI_URL).await
}
pub async fn connect_with_url(url: &str) -> Result<Self, WaapiError> {
info!("Connecting to WAAPI at {url}");
let (ws_stream, _) = connect_async(url).await.map_err(|e| WaapiError::WebSocket(Box::new(e)))?;
let (ws_tx, mut ws_rx) = ws_stream.split();
let conn = Arc::new(WampConn::new(ws_tx));
conn.send(wamp::hello_msg(DEFAULT_REALM)).await?;
let _session_id = read_welcome(&mut ws_rx).await?;
let connected = Arc::new(AtomicBool::new(true));
let connected_flag = Arc::clone(&connected);
let conn_for_loop = Arc::clone(&conn);
let handle = tokio::spawn(async move {
run_event_loop(conn_for_loop, ws_rx, connected_flag).await;
});
info!("Connected to WAAPI at {url}");
Ok(Self {
conn: Some(conn),
event_loop_handle: Some(handle),
subscription_ids: Arc::new(StdMutex::new(Vec::new())),
connected,
})
}
pub async fn call(
&self,
uri: &str,
args: Option<Value>,
options: Option<Value>,
) -> Result<Option<Value>, WaapiError> {
let conn = self.conn.as_ref().ok_or(WaapiError::Disconnected)?;
let id = conn.next_id();
let (tx, rx) = oneshot::channel();
conn.pending_calls
.lock()
.unwrap_or_else(|e| e.into_inner())
.insert(id, tx);
debug!("Calling WAAPI: {uri} (id={id})");
conn.send(wamp::call_msg(id, uri, args.as_ref(), options.as_ref()))
.await?;
rx.await.unwrap_or(Err(WaapiError::Disconnected))
}
pub(crate) async fn subscribe_inner(
&self,
topic: &str,
options: Option<Value>,
) -> Result<
(
SubscriptionHandle,
async_mpsc::UnboundedReceiver<EventPayload>,
),
WaapiError,
> {
let conn = self.conn.as_ref().ok_or(WaapiError::Disconnected)?;
let id = conn.next_id();
let (tx, rx) = oneshot::channel();
conn.pending_subs
.lock()
.unwrap_or_else(|e| e.into_inner())
.insert(id, tx);
conn.send(wamp::subscribe_msg(id, topic, options.as_ref()))
.await?;
let sub_id = rx.await.unwrap_or(Err(WaapiError::Disconnected))?;
debug!("Subscribed to {topic} (sub_id={sub_id})");
let (event_tx, event_rx) = async_mpsc::unbounded_channel();
conn.event_senders
.lock()
.unwrap_or_else(|e| e.into_inner())
.insert(sub_id, event_tx);
self.subscription_ids
.lock()
.unwrap_or_else(|e| e.into_inner())
.push(sub_id);
let handle = SubscriptionHandle {
sub_id,
conn: Arc::clone(conn),
subscription_ids: Arc::clone(&self.subscription_ids),
recv_task: None,
is_unsubscribed: false,
};
Ok((handle, event_rx))
}
pub async fn subscribe<F>(
&self,
topic: &str,
options: Option<Value>,
callback: F,
) -> Result<SubscriptionHandle, WaapiError>
where
F: Fn(Option<Value>) + Send + Sync + 'static,
{
let (mut handle, mut event_rx) = self.subscribe_inner(topic, options).await?;
let recv_task = tokio::spawn(async move {
while let Some((_pub_id, kwargs)) = event_rx.recv().await {
callback(kwargs);
}
});
handle.recv_task = Some(recv_task);
Ok(handle)
}
#[must_use]
pub fn is_connected(&self) -> bool {
self.conn.is_some() && self.connected.load(Ordering::Acquire)
}
pub async fn disconnect(mut self) {
info!("Disconnecting from WAAPI");
self.cleanup().await;
info!("Disconnected from WAAPI");
}
async fn cleanup(&mut self) {
self.connected.store(false, Ordering::Release);
if let Some(conn) = self.conn.take() {
let ids: Vec<u64> = {
let mut guard = self.subscription_ids.lock().unwrap_or_else(|e| e.into_inner());
std::mem::take(&mut *guard)
};
for sub_id in ids {
let id = conn.next_id();
let (tx, rx) = oneshot::channel();
conn.pending_unsubs
.lock()
.unwrap_or_else(|e| e.into_inner())
.insert(id, tx);
if conn.send(wamp::unsubscribe_msg(id, sub_id)).await.is_ok() {
let _ = rx.await;
}
}
let _ = conn.send(wamp::goodbye_msg()).await;
let _ = conn.ws_tx.lock().await.close().await;
}
if let Some(handle) = self.event_loop_handle.take() {
handle.abort();
}
}
}
impl Drop for WaapiClient {
    /// Best-effort teardown for a client dropped without calling `disconnect`.
    ///
    /// Inside a Tokio runtime, the graceful shutdown is spawned as a background task:
    /// unsubscribe everything, send GOODBYE, close the socket and stop the reader task.
    /// Without a runtime, only the reader task is aborted, and the socket is dropped
    /// without a close handshake.
    fn drop(&mut self) {
        let conn = self.conn.take();
        let event_loop = self.event_loop_handle.take();
        if conn.is_none() && event_loop.is_none() {
            return;
        }
        let subscription_ids = Arc::clone(&self.subscription_ids);
        let connected = Arc::clone(&self.connected);
        let Ok(rt) = tokio::runtime::Handle::try_current() else {
            warn!("WaapiClient dropped without runtime, skipping graceful cleanup");
            connected.store(false, Ordering::Release);
            if let Some(h) = event_loop {
                h.abort();
            }
            return;
        };
        debug!("WaapiClient dropped, spawning async cleanup");
        rt.spawn(async move {
            if let Some(conn) = conn {
                let ids: Vec<u64> = {
                    let mut guard = subscription_ids.lock().unwrap_or_else(|e| e.into_inner());
                    std::mem::take(&mut *guard)
                };
                for sub_id in ids {
                    let id = conn.next_id();
                    let (tx, rx) = oneshot::channel::<UnsubResult>();
                    conn.pending_unsubs
                        .lock()
                        .unwrap_or_else(|e| e.into_inner())
                        .insert(id, tx);
                    // The reader task clears `connected` before failing every pending
                    // request. Checking after registering means we either get failed by
                    // that drain or notice here that the task is gone. Either way we never
                    // wait for a reply that nobody will deliver.
                    if !connected.load(Ordering::Acquire) {
                        conn.pending_unsubs
                            .lock()
                            .unwrap_or_else(|e| e.into_inner())
                            .remove(&id);
                        break;
                    }
                    if conn.send(wamp::unsubscribe_msg(id, sub_id)).await.is_err() {
                        conn.pending_unsubs
                            .lock()
                            .unwrap_or_else(|e| e.into_inner())
                            .remove(&id);
                        break;
                    }
                    let _ = rx.await;
                }
                let _ = conn.send(wamp::goodbye_msg()).await;
                let _ = conn.ws_tx.lock().await.close().await;
            }
            connected.store(false, Ordering::Release);
            if let Some(h) = event_loop {
                h.abort();
            }
        });
    }
}
/// Blocking counterpart of `SubscriptionHandle`, returned by `WaapiClientSync::subscribe`.
#[derive(Debug)]
pub struct SubscriptionHandleSync {
/// Runtime shared with the owning `WaapiClientSync`; used to block on the async handle.
runtime: Arc<tokio::runtime::Runtime>,
/// The async handle; taken when the subscription is torn down.
inner: Option<SubscriptionHandle>,
/// Thread that receives events and runs the user callback.
bridge_join: Option<thread::JoinHandle<()>>,
/// Id of that thread, so teardown can avoid joining it from within itself.
bridge_thread_id: Option<thread::ThreadId>,
}
impl SubscriptionHandleSync {
pub fn unsubscribe(mut self) -> Result<(), WaapiError> {
let inner = self.inner.take();
let bridge_join = self.bridge_join.take();
if let Some(h) = inner {
self.runtime.block_on(h.unsubscribe())?;
}
if let Some(jh) = bridge_join {
let _ = jh.join();
}
Ok(())
}
}
impl Drop for SubscriptionHandleSync {
    /// Unsubscribes (unless `unsubscribe` already did) and joins the bridge thread.
    ///
    /// Outside an async context, the unsubscribe runs to completion via `block_on`.
    ///
    /// Inside an async context, blocking on the runtime is not allowed, so the unsubscribe
    /// is spawned instead. The event channel is closed synchronously first, so the bridge
    /// thread can finish (and be joined below) without waiting for that task to be
    /// scheduled.
    ///
    /// The join is skipped when the handle is dropped on the bridge thread itself, because
    /// a thread cannot join itself.
    fn drop(&mut self) {
        let on_bridge_thread = self.bridge_thread_id == Some(thread::current().id());
        if let Some(handle) = self.inner.take() {
            if tokio::runtime::Handle::try_current().is_ok() {
                warn!("SubscriptionHandleSync dropped inside async context, falling back to spawn");
                handle
                    .conn
                    .event_senders
                    .lock()
                    .unwrap_or_else(|e| e.into_inner())
                    .remove(&handle.sub_id);
                self.runtime.handle().spawn(async move {
                    let _ = handle.unsubscribe().await;
                });
            } else {
                let _ = self.runtime.block_on(handle.unsubscribe());
            }
        }
        if let Some(join) = self.bridge_join.take() {
            if !on_bridge_thread {
                let _ = join.join();
            }
        }
    }
}
/// Blocking WAAPI client: wraps a `WaapiClient` and drives it on its own Tokio runtime.
#[derive(Debug)]
pub struct WaapiClientSync {
/// Runtime that runs the async client; also shared with subscription handles and bridge threads.
runtime: Arc<tokio::runtime::Runtime>,
/// The wrapped client; `None` once disconnected.
client: Option<WaapiClient>,
}
impl WaapiClientSync {
    /// Blocking counterpart of `WaapiClient::connect`: connects to the default endpoint.
    ///
    /// # Errors
    /// Same as [`WaapiClientSync::connect_with_url`].
    pub fn connect() -> Result<Self, WaapiError> {
        Self::connect_with_url(DEFAULT_WAAPI_URL)
    }

    /// Builds a dedicated multi-threaded Tokio runtime and connects to `url` on it.
    ///
    /// # Errors
    /// - [`WaapiError::Io`] if the runtime cannot be built.
    /// - Otherwise, the errors of `WaapiClient::connect_with_url`.
    ///
    /// # Panics
    /// If called from within an async context, where `Runtime::block_on` panics.
    pub fn connect_with_url(url: &str) -> Result<Self, WaapiError> {
        info!("Connecting to WAAPI (sync) at {url}");
        let runtime = Arc::new(
            tokio::runtime::Builder::new_multi_thread()
                .enable_all()
                .build()?,
        );
        let client = runtime.block_on(WaapiClient::connect_with_url(url))?;
        info!("Connected to WAAPI (sync) at {url}");
        Ok(Self {
            runtime,
            client: Some(client),
        })
    }

    /// Blocking version of `WaapiClient::call`.
    ///
    /// # Errors
    /// [`WaapiError::Disconnected`] after `disconnect`, otherwise those of `WaapiClient::call`.
    ///
    /// # Panics
    /// If called from within an async context.
    pub fn call(
        &self,
        uri: &str,
        args: Option<Value>,
        options: Option<Value>,
    ) -> Result<Option<Value>, WaapiError> {
        let client = self.client.as_ref().ok_or(WaapiError::Disconnected)?;
        self.runtime.block_on(client.call(uri, args, options))
    }

    /// Subscribes to `topic` and runs `callback` with the `kwargs` of every event.
    ///
    /// Events are received on a dedicated bridge thread, which runs the callback once per
    /// event in arrival order. The thread ends when the subscription's event channel is
    /// closed.
    ///
    /// # Errors
    /// - [`WaapiError::Disconnected`] after `disconnect`.
    /// - [`WaapiError::Io`] if the bridge thread cannot be spawned. In that case the
    ///   subscription is cancelled again.
    /// - Otherwise, the errors of the underlying subscribe request.
    ///
    /// # Panics
    /// If called from within an async context.
    pub fn subscribe<F>(
        &self,
        topic: &str,
        options: Option<Value>,
        callback: F,
    ) -> Result<SubscriptionHandleSync, WaapiError>
    where
        F: Fn(Option<Value>) + Send + Sync + 'static,
    {
        let client = self.client.as_ref().ok_or(WaapiError::Disconnected)?;
        let (inner, mut async_rx) = self
            .runtime
            .block_on(client.subscribe_inner(topic, options))?;
        let (id_tx, id_rx) = mpsc::channel();
        let runtime = Arc::clone(&self.runtime);
        let spawned = thread::Builder::new()
            .name("waapi-subscription-bridge".to_string())
            .spawn(move || {
                let _ = id_tx.send(thread::current().id());
                while let Some((_pub_id, kwargs)) = runtime.block_on(async_rx.recv()) {
                    callback(kwargs);
                }
            });
        let bridge_join = match spawned {
            Ok(join) => join,
            Err(err) => {
                // Without a bridge thread no callback would ever run, so undo the
                // subscription instead of leaving it active on the router.
                let _ = self.runtime.block_on(inner.unsubscribe());
                return Err(WaapiError::Io(err));
            }
        };
        let bridge_thread_id = id_rx.recv().ok();
        Ok(SubscriptionHandleSync {
            runtime: Arc::clone(&self.runtime),
            inner: Some(inner),
            bridge_join: Some(bridge_join),
            bridge_thread_id,
        })
    }

    /// Returns `true` while the wrapped client is connected.
    #[must_use]
    pub fn is_connected(&self) -> bool {
        self.client.as_ref().is_some_and(|c| c.is_connected())
    }

    /// Blocking version of `WaapiClient::disconnect`.
    ///
    /// # Panics
    /// If called from within an async context.
    pub fn disconnect(mut self) {
        info!("Disconnecting from WAAPI (sync)");
        if let Some(client) = self.client.take() {
            self.runtime.block_on(client.disconnect());
        }
        info!("Disconnected from WAAPI (sync)");
    }
}
impl Drop for WaapiClientSync {
    /// Disconnects the wrapped client if `disconnect` was not called.
    ///
    /// Outside an async context, this blocks on the client's shutdown.
    ///
    /// Inside an async context, blocking is not allowed, so the shutdown runs on a dedicated
    /// thread. That thread keeps its reference to the runtime until every other owner has
    /// released theirs, because Tokio panics when a runtime is dropped from within an
    /// asynchronous context. Otherwise, whenever the cleanup thread finished before this
    /// value's `runtime` field was dropped, this async caller would release the last
    /// reference and trigger that panic.
    fn drop(&mut self) {
        let Some(client) = self.client.take() else {
            return;
        };
        if tokio::runtime::Handle::try_current().is_err() {
            self.runtime.block_on(client.disconnect());
            return;
        }
        warn!("WaapiClientSync dropped inside async context, offloading cleanup to a dedicated thread");
        let runtime = Arc::clone(&self.runtime);
        let spawned = thread::Builder::new()
            .name("waapi-sync-drop-cleanup".to_string())
            .spawn(move || {
                runtime.block_on(client.disconnect());
                // Once no other owner is left, nobody can clone the Arc again, so this thread
                // is guaranteed to be the one that shuts the runtime down, on a plain thread
                // where blocking is allowed.
                while Arc::strong_count(&runtime) > 1 {
                    thread::sleep(std::time::Duration::from_millis(20));
                }
            });
        if let Err(err) = spawned {
            warn!("Failed to spawn WAAPI cleanup thread: {err}");
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the first call may report that it performed the teardown.
    #[test]
    fn test_mark_unsubscribed_is_idempotent() {
        let mut flag = false;
        let outcomes: Vec<bool> = (0..2).map(|_| mark_unsubscribed(&mut flag)).collect();
        assert_eq!(outcomes, vec![true, false]);
    }

    /// Dropping the blocking client from inside a Tokio runtime must not panic.
    #[tokio::test]
    async fn test_sync_client_drop_inside_async_context_is_safe() {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .map(Arc::new)
            .expect("failed to create runtime");
        let sync_client = WaapiClientSync {
            runtime,
            client: Some(WaapiClient {
                conn: None,
                event_loop_handle: None,
                subscription_ids: Arc::default(),
                connected: Arc::new(AtomicBool::new(false)),
            }),
        };
        drop(sync_client);
    }
}