use std::sync::{Arc, Mutex};
use agentlink_core::mqtt::{
MqttClient, MqttConfig, MqttConnectionState, MqttEvent, MqttMessage, MqttQoS,
};
use agentlink_core::error::SdkResult;
use async_trait::async_trait;
use wasm_bindgen::prelude::*;
use wasm_bindgen::closure::Closure;
use web_sys::{BinaryType, CloseEvent, ErrorEvent, MessageEvent, WebSocket};
/// MQTT client backed by a `web_sys::WebSocket`.
///
/// Every field lives behind `Arc<Mutex<..>>` so that the event-handler closures created in
/// `connect` can hold their own handles to the connection state and the user callback.
pub struct WasmMqttClient {
/// Current socket; stored by `connect` and taken (left as `None`) by `disconnect`.
websocket: Arc<Mutex<Option<WebSocket>>>,
/// State returned by `connection_state`; written by `connect`, `disconnect` and the
/// socket's `open` / `close` / `error` handlers.
state: Arc<Mutex<MqttConnectionState>>,
/// Callback registered through `on_event`; `None` until one is registered.
event_callback: Arc<Mutex<Option<Box<dyn Fn(MqttEvent)>>>>,
/// Closure that `connect` installs as the socket's `onmessage` handler.
on_message_closure: Arc<Mutex<Option<Closure<dyn FnMut(MessageEvent)>>>>,
/// Closure that `connect` installs as the socket's `onclose` handler.
on_close_closure: Arc<Mutex<Option<Closure<dyn FnMut(CloseEvent)>>>>,
/// Closure that `connect` installs as the socket's `onerror` handler.
on_error_closure: Arc<Mutex<Option<Closure<dyn FnMut(ErrorEvent)>>>>,
}
impl WasmMqttClient {
/// Builds a client in the `Disconnected` state with no socket, no event callback and no
/// handler closures.
pub fn new() -> Self {
    let state = Arc::new(Mutex::new(MqttConnectionState::Disconnected));
    Self {
        // `Option` defaults to `None`, so `Arc::default()` yields an empty shared slot.
        websocket: Arc::default(),
        state,
        event_callback: Arc::default(),
        on_message_closure: Arc::default(),
        on_close_closure: Arc::default(),
        on_error_closure: Arc::default(),
    }
}
/// Translates a broker URL into the URL handed to `WebSocket::new`.
///
/// * `mqtts://host` becomes `wss://host/mqtt`
/// * `mqtt://host` becomes `ws://host/mqtt`
/// * `ws://` and `wss://` URLs are returned unchanged
///
/// Only the leading scheme is rewritten, and trailing slashes are removed before `/mqtt`
/// is appended, so `mqtt://host/` does not turn into `ws://host//mqtt`.
///
/// # Errors
///
/// Returns `SdkError::Config` when the URL starts with any other scheme.
fn convert_mqtt_url_to_ws(url: &str) -> SdkResult<String> {
    if let Some(rest) = url.strip_prefix("mqtts://") {
        Ok(format!("wss://{}/mqtt", rest.trim_end_matches('/')))
    } else if let Some(rest) = url.strip_prefix("mqtt://") {
        Ok(format!("ws://{}/mqtt", rest.trim_end_matches('/')))
    } else if url.starts_with("wss://") || url.starts_with("ws://") {
        Ok(url.to_string())
    } else {
        Err(agentlink_core::error::SdkError::Config(
            format!("Invalid MQTT URL: {}", url)
        ))
    }
}
}
#[async_trait(?Send)]
impl MqttClient for WasmMqttClient {
async fn connect(&self, config: MqttConfig) -> SdkResult<()> {
let ws_url = Self::convert_mqtt_url_to_ws(&config.broker_url)?;
let ws = WebSocket::new(&ws_url).map_err(|e| {
agentlink_core::error::SdkError::Mqtt(format!("Failed to create WebSocket: {:?}", e))
})?;
ws.set_binary_type(BinaryType::Arraybuffer);
*self.state.lock().unwrap() = MqttConnectionState::Connecting;
let state_open = self.state.clone();
let state_close = self.state.clone();
let state_error = self.state.clone();
let callback = self.event_callback.clone();
let onopen_callback = callback.clone();
let onopen = Closure::wrap(Box::new(move |_event: web_sys::Event| {
*state_open.lock().unwrap() = MqttConnectionState::Connected;
if let Some(cb) = onopen_callback.lock().unwrap().as_ref() {
cb(MqttEvent::Connected);
}
}) as Box<dyn FnMut(_)>);
ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
onopen.forget();
let onmessage_callback = callback.clone();
let onmessage = Closure::wrap(Box::new(move |event: MessageEvent| {
if let Ok(data) = event.data().dyn_into::<js_sys::ArrayBuffer>() {
let array = js_sys::Uint8Array::new(&data);
let mut payload = vec![0u8; array.length() as usize];
array.copy_to(&mut payload);
let msg = MqttMessage::new("topic", payload);
if let Some(cb) = onmessage_callback.lock().unwrap().as_ref() {
cb(MqttEvent::MessageReceived(msg));
}
}
}) as Box<dyn FnMut(_)>);
ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
*self.on_message_closure.lock().unwrap() = Some(onmessage);
let onclose_callback = callback.clone();
let onclose = Closure::wrap(Box::new(move |_event: CloseEvent| {
*state_close.lock().unwrap() = MqttConnectionState::Disconnected;
if let Some(cb) = onclose_callback.lock().unwrap().as_ref() {
cb(MqttEvent::Disconnected);
}
}) as Box<dyn FnMut(_)>);
ws.set_onclose(Some(onclose.as_ref().unchecked_ref()));
*self.on_close_closure.lock().unwrap() = Some(onclose);
let onerror_callback = callback.clone();
let onerror = Closure::wrap(Box::new(move |event: ErrorEvent| {
*state_error.lock().unwrap() = MqttConnectionState::Failed;
if let Some(cb) = onerror_callback.lock().unwrap().as_ref() {
cb(MqttEvent::Error {
error: event.message(),
});
}
}) as Box<dyn FnMut(_)>);
ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
*self.on_error_closure.lock().unwrap() = Some(onerror);
*self.websocket.lock().unwrap() = Some(ws);
Ok(())
}
async fn disconnect(&self) -> SdkResult<()> {
if let Some(ws) = self.websocket.lock().unwrap().take() {
let _ = ws.close();
}
*self.state.lock().unwrap() = MqttConnectionState::Disconnected;
Ok(())
}
async fn subscribe(&self, topic: &str, _qos: MqttQoS) -> SdkResult<()> {
if let Some(ref ws) = *self.websocket.lock().unwrap() {
let subscribe_packet = format!("SUBSCRIBE {}", topic);
ws.send_with_str(&subscribe_packet).map_err(|e| {
agentlink_core::error::SdkError::Mqtt(format!("Subscribe failed: {:?}", e))
})?;
if let Some(cb) = self.event_callback.lock().unwrap().as_ref() {
cb(MqttEvent::Subscribed { topic: topic.to_string() });
}
}
Ok(())
}
async fn unsubscribe(&self, topic: &str) -> SdkResult<()> {
if let Some(ref ws) = *self.websocket.lock().unwrap() {
let unsubscribe_packet = format!("UNSUBSCRIBE {}", topic);
ws.send_with_str(&unsubscribe_packet).map_err(|e| {
agentlink_core::error::SdkError::Mqtt(format!("Unsubscribe failed: {:?}", e))
})?;
if let Some(cb) = self.event_callback.lock().unwrap().as_ref() {
cb(MqttEvent::Unsubscribed { topic: topic.to_string() });
}
}
Ok(())
}
async fn publish(&self, message: MqttMessage) -> SdkResult<()> {
if let Some(ref ws) = *self.websocket.lock().unwrap() {
ws.send_with_u8_array(&message.payload).map_err(|e| {
agentlink_core::error::SdkError::Mqtt(format!("Publish failed: {:?}", e))
})?;
if let Some(cb) = self.event_callback.lock().unwrap().as_ref() {
cb(MqttEvent::Published { topic: message.topic });
}
}
Ok(())
}
/// Returns a copy of the connection state currently stored in the client.
fn connection_state(&self) -> MqttConnectionState {
    let current = self.state.lock().unwrap();
    *current
}
}
impl Clone for WasmMqttClient {
fn clone(&self) -> Self {
Self {
websocket: Arc::new(Mutex::new(self.websocket.lock().unwrap().clone())),
state: Arc::new(Mutex::new(*self.state.lock().unwrap())),
event_callback: Arc::new(Mutex::new(None)),
on_message_closure: Arc::new(Mutex::new(None)),
on_close_closure: Arc::new(Mutex::new(None)),
on_error_closure: Arc::new(Mutex::new(None)),
}
}
}
impl WasmMqttClient {
/// Registers `callback` as the receiver of the [`MqttEvent`]s this client emits,
/// replacing any callback that was registered before.
pub fn on_event<F>(&self, callback: F)
where
    F: Fn(MqttEvent) + 'static,
{
    let handler: Box<dyn Fn(MqttEvent)> = Box::new(callback);
    let mut slot = self.event_callback.lock().unwrap();
    *slot = Some(handler);
}
}