Skip to main content

clasp_bridge/
websocket.rs

1//! WebSocket Bridge for CLASP
2//!
3//! Provides bidirectional WebSocket connectivity for CLASP.
4//! Supports both client and server modes.
5
6use crate::{Bridge, BridgeConfig, BridgeError, BridgeEvent, Result};
7use async_trait::async_trait;
8use clasp_core::{Message, SetMessage, Value};
9use futures::{SinkExt, StreamExt};
10use parking_lot::Mutex;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use tokio::net::{TcpListener, TcpStream};
16use tokio::sync::mpsc;
17use tokio_tungstenite::{
18    accept_async, connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream,
19    WebSocketStream,
20};
21use tracing::{debug, error, info, warn};
22
23/// WebSocket message format
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
25#[serde(rename_all = "lowercase")]
26pub enum WsMessageFormat {
27    /// JSON text messages
28    #[default]
29    Json,
30    /// MessagePack binary messages
31    MsgPack,
32    /// Raw binary/text passthrough
33    Raw,
34}
35
36/// WebSocket bridge mode
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
38#[serde(rename_all = "lowercase")]
39pub enum WsMode {
40    /// Connect to a WebSocket server
41    #[default]
42    Client,
43    /// Act as a WebSocket server
44    Server,
45}
46
47/// WebSocket Bridge configuration
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct WebSocketBridgeConfig {
50    /// Mode: client or server
51    #[serde(default)]
52    pub mode: WsMode,
53    /// URL for client mode (ws://...) or bind address for server mode (0.0.0.0:8080)
54    pub url: String,
55    /// Path for server mode (e.g., "/ws")
56    #[serde(default)]
57    pub path: Option<String>,
58    /// Message format
59    #[serde(default)]
60    pub format: WsMessageFormat,
61    /// Ping interval in seconds (0 to disable)
62    #[serde(default = "default_ping_interval")]
63    pub ping_interval_secs: u32,
64    /// Auto-reconnect on disconnect (client mode only)
65    #[serde(default = "default_true")]
66    pub auto_reconnect: bool,
67    /// Reconnect delay in seconds
68    #[serde(default = "default_reconnect_delay")]
69    pub reconnect_delay_secs: u32,
70    /// Custom headers for client mode
71    #[serde(default)]
72    pub headers: HashMap<String, String>,
73    /// CLASP namespace prefix for incoming messages
74    #[serde(default = "default_namespace")]
75    pub namespace: String,
76}
77
78fn default_true() -> bool {
79    true
80}
81
82fn default_ping_interval() -> u32 {
83    30
84}
85
86fn default_reconnect_delay() -> u32 {
87    5
88}
89
90fn default_namespace() -> String {
91    "/ws".to_string()
92}
93
94impl Default for WebSocketBridgeConfig {
95    fn default() -> Self {
96        Self {
97            mode: WsMode::Client,
98            url: "ws://localhost:8080".to_string(),
99            path: None,
100            format: WsMessageFormat::Json,
101            ping_interval_secs: 30,
102            auto_reconnect: true,
103            reconnect_delay_secs: 5,
104            headers: HashMap::new(),
105            namespace: "/ws".to_string(),
106        }
107    }
108}
109
110/// WebSocket client connection
111#[allow(dead_code)]
112type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
113
114/// WebSocket server connection (raw TCP)
115#[allow(dead_code)]
116type WsServerStream = WebSocketStream<TcpStream>;
117
118/// Type alias for split sink
119#[allow(dead_code)]
120type WsSink = futures::stream::SplitSink<WsServerStream, WsMessage>;
121
122/// WebSocket Bridge implementation
123pub struct WebSocketBridge {
124    config: BridgeConfig,
125    ws_config: WebSocketBridgeConfig,
126    running: Arc<Mutex<bool>>,
127    send_tx: Option<mpsc::Sender<WsMessage>>,
128    shutdown_tx: Option<mpsc::Sender<()>>,
129}
130
131impl WebSocketBridge {
132    /// Create a new WebSocket bridge
133    pub fn new(ws_config: WebSocketBridgeConfig) -> Self {
134        let config = BridgeConfig {
135            name: "WebSocket Bridge".to_string(),
136            protocol: "websocket".to_string(),
137            bidirectional: true,
138            ..Default::default()
139        };
140
141        Self {
142            config,
143            ws_config,
144            running: Arc::new(Mutex::new(false)),
145            send_tx: None,
146            shutdown_tx: None,
147        }
148    }
149
150    /// Parse incoming WebSocket message to CLASP
151    fn parse_message(msg: &WsMessage, format: WsMessageFormat, prefix: &str) -> Option<Message> {
152        match msg {
153            WsMessage::Text(text) => match format {
154                WsMessageFormat::Json | WsMessageFormat::Raw => {
155                    // Try to parse as JSON
156                    if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
157                        let address = json
158                            .get("address")
159                            .and_then(|v| v.as_str())
160                            .map(|s| s.to_string())
161                            .unwrap_or_else(|| format!("{}/message", prefix));
162
163                        let value = json
164                            .get("value")
165                            .map(|v| Self::json_to_value(v.clone()))
166                            .or_else(|| json.get("data").map(|v| Self::json_to_value(v.clone())))
167                            .unwrap_or_else(|| Self::json_to_value(json));
168
169                        Some(Message::Set(SetMessage {
170                            address,
171                            value,
172                            revision: None,
173                            lock: false,
174                            unlock: false,
175                            ttl: None,
176                        }))
177                    } else {
178                        // Plain text message
179                        Some(Message::Set(SetMessage {
180                            address: format!("{}/text", prefix),
181                            value: Value::String(text.clone()),
182                            revision: None,
183                            lock: false,
184                            unlock: false,
185                            ttl: None,
186                        }))
187                    }
188                }
189                WsMessageFormat::MsgPack => {
190                    // MsgPack format expects binary messages, but try to parse text as JSON
191                    // and convert to CLASP message for better interoperability
192                    if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
193                        // Check if it looks like a CLASP message format
194                        if let (Some(addr), Some(val)) = (
195                            json.get("address").and_then(|a| a.as_str()),
196                            json.get("value"),
197                        ) {
198                            Some(Message::Set(SetMessage {
199                                address: addr.to_string(),
200                                value: Self::json_to_value(val.clone()),
201                                revision: None,
202                                lock: false,
203                                unlock: false,
204                                ttl: None,
205                            }))
206                        } else {
207                            // Wrap text as a message for the namespace
208                            Some(Message::Set(SetMessage {
209                                address: format!("{}/text", prefix),
210                                value: Value::String(text.clone()),
211                                revision: None,
212                                lock: false,
213                                unlock: false,
214                                ttl: None,
215                            }))
216                        }
217                    } else {
218                        // Plain text, wrap as message
219                        Some(Message::Set(SetMessage {
220                            address: format!("{}/text", prefix),
221                            value: Value::String(text.clone()),
222                            revision: None,
223                            lock: false,
224                            unlock: false,
225                            ttl: None,
226                        }))
227                    }
228                }
229            },
230            WsMessage::Binary(data) => match format {
231                WsMessageFormat::MsgPack => {
232                    // Try to decode as CLASP message
233                    if let Ok((msg, _)) = clasp_core::codec::decode(data) {
234                        Some(msg)
235                    } else {
236                        // Return as bytes
237                        Some(Message::Set(SetMessage {
238                            address: format!("{}/binary", prefix),
239                            value: Value::Bytes(data.clone()),
240                            revision: None,
241                            lock: false,
242                            unlock: false,
243                            ttl: None,
244                        }))
245                    }
246                }
247                WsMessageFormat::Raw | WsMessageFormat::Json => Some(Message::Set(SetMessage {
248                    address: format!("{}/binary", prefix),
249                    value: Value::Bytes(data.clone()),
250                    revision: None,
251                    lock: false,
252                    unlock: false,
253                    ttl: None,
254                })),
255            },
256            _ => None,
257        }
258    }
259
260    /// Convert JSON value to CLASP Value
261    fn json_to_value(json: serde_json::Value) -> Value {
262        match json {
263            serde_json::Value::Null => Value::Null,
264            serde_json::Value::Bool(b) => Value::Bool(b),
265            serde_json::Value::Number(n) => {
266                if let Some(i) = n.as_i64() {
267                    Value::Int(i)
268                } else if let Some(f) = n.as_f64() {
269                    Value::Float(f)
270                } else {
271                    Value::Null
272                }
273            }
274            serde_json::Value::String(s) => Value::String(s),
275            serde_json::Value::Array(arr) => {
276                Value::Array(arr.into_iter().map(Self::json_to_value).collect())
277            }
278            serde_json::Value::Object(obj) => {
279                let map: HashMap<String, Value> = obj
280                    .into_iter()
281                    .map(|(k, v)| (k, Self::json_to_value(v)))
282                    .collect();
283                Value::Map(map)
284            }
285        }
286    }
287
288    /// Convert CLASP message to WebSocket message
289    fn message_to_ws(msg: &Message, format: WsMessageFormat) -> Option<WsMessage> {
290        let (address, value) = match msg {
291            Message::Set(set) => (Some(&set.address), Some(&set.value)),
292            Message::Publish(pub_msg) => (Some(&pub_msg.address), pub_msg.value.as_ref()),
293            _ => return None,
294        };
295
296        match format {
297            WsMessageFormat::Json => {
298                let json = serde_json::json!({
299                    "address": address,
300                    "value": value,
301                });
302                Some(WsMessage::Text(json.to_string()))
303            }
304            WsMessageFormat::MsgPack => {
305                if let Ok(encoded) = clasp_core::codec::encode(msg) {
306                    Some(WsMessage::Binary(encoded.to_vec()))
307                } else {
308                    None
309                }
310            }
311            WsMessageFormat::Raw => {
312                if let Some(val) = value {
313                    match val {
314                        Value::String(s) => Some(WsMessage::Text(s.clone())),
315                        Value::Bytes(b) => Some(WsMessage::Binary(b.clone())),
316                        _ => {
317                            let json = serde_json::to_string(val).ok()?;
318                            Some(WsMessage::Text(json))
319                        }
320                    }
321                } else {
322                    None
323                }
324            }
325        }
326    }
327
328    /// Run client mode
329    #[allow(clippy::too_many_arguments)]
330    async fn run_client(
331        url: String,
332        format: WsMessageFormat,
333        namespace: String,
334        auto_reconnect: bool,
335        reconnect_delay: u32,
336        ping_interval_secs: u32,
337        event_tx: mpsc::Sender<BridgeEvent>,
338        mut send_rx: mpsc::Receiver<WsMessage>,
339        mut shutdown_rx: mpsc::Receiver<()>,
340        running: Arc<Mutex<bool>>,
341    ) {
342        loop {
343            info!("WebSocket connecting to {}", url);
344
345            match connect_async(&url).await {
346                Ok((ws_stream, _)) => {
347                    info!("WebSocket connected");
348                    *running.lock() = true;
349                    let _ = event_tx.send(BridgeEvent::Connected).await;
350
351                    let (mut write, mut read) = ws_stream.split();
352
353                    // Create ping interval if configured
354                    let ping_duration = if ping_interval_secs > 0 {
355                        Some(std::time::Duration::from_secs(ping_interval_secs as u64))
356                    } else {
357                        None
358                    };
359                    let mut ping_interval = ping_duration.map(tokio::time::interval);
360                    let mut awaiting_pong = false;
361
362                    loop {
363                        tokio::select! {
364                            // Handle incoming messages
365                            msg = read.next() => {
366                                match msg {
367                                    Some(Ok(ws_msg)) => {
368                                        match &ws_msg {
369                                            WsMessage::Pong(_) => {
370                                                awaiting_pong = false;
371                                                debug!("Received pong");
372                                            }
373                                            WsMessage::Ping(data) => {
374                                                // Respond with pong
375                                                if let Err(e) = write.send(WsMessage::Pong(data.clone())).await {
376                                                    error!("Failed to send pong: {}", e);
377                                                    break;
378                                                }
379                                            }
380                                            _ => {
381                                                if let Some(clasp_msg) = Self::parse_message(&ws_msg, format, &namespace) {
382                                                    let _ = event_tx.send(BridgeEvent::ToClasp(Box::new(clasp_msg))).await;
383                                                }
384                                            }
385                                        }
386                                    }
387                                    Some(Err(e)) => {
388                                        error!("WebSocket error: {}", e);
389                                        break;
390                                    }
391                                    None => {
392                                        warn!("WebSocket connection closed");
393                                        break;
394                                    }
395                                }
396                            }
397                            // Handle outgoing messages
398                            msg = send_rx.recv() => {
399                                if let Some(ws_msg) = msg {
400                                    if let Err(e) = write.send(ws_msg).await {
401                                        error!("WebSocket send error: {}", e);
402                                        break;
403                                    }
404                                }
405                            }
406                            // Handle ping interval
407                            _ = async {
408                                if let Some(ref mut interval) = ping_interval {
409                                    interval.tick().await
410                                } else {
411                                    // Never fires if ping is disabled
412                                    std::future::pending::<tokio::time::Instant>().await
413                                }
414                            } => {
415                                if awaiting_pong {
416                                    warn!("Ping timeout - no pong received");
417                                    break;
418                                }
419                                if let Err(e) = write.send(WsMessage::Ping(vec![])).await {
420                                    error!("Failed to send ping: {}", e);
421                                    break;
422                                }
423                                awaiting_pong = true;
424                                debug!("Sent ping");
425                            }
426                            // Handle shutdown
427                            _ = shutdown_rx.recv() => {
428                                info!("WebSocket shutting down");
429                                let _ = write.close().await;
430                                *running.lock() = false;
431                                return;
432                            }
433                        }
434                    }
435
436                    *running.lock() = false;
437                    let _ = event_tx
438                        .send(BridgeEvent::Disconnected {
439                            reason: Some("Connection closed".to_string()),
440                        })
441                        .await;
442                }
443                Err(e) => {
444                    error!("WebSocket connection failed: {}", e);
445                    let _ = event_tx
446                        .send(BridgeEvent::Error(format!("Connection failed: {}", e)))
447                        .await;
448                }
449            }
450
451            if !auto_reconnect {
452                *running.lock() = false;
453                return;
454            }
455
456            info!("Reconnecting in {} seconds...", reconnect_delay);
457            tokio::time::sleep(std::time::Duration::from_secs(reconnect_delay as u64)).await;
458        }
459    }
460
461    /// Run server mode
462    #[allow(clippy::too_many_arguments)]
463    async fn run_server(
464        addr: SocketAddr,
465        format: WsMessageFormat,
466        namespace: String,
467        ping_interval_secs: u32,
468        event_tx: mpsc::Sender<BridgeEvent>,
469        mut send_rx: mpsc::Receiver<WsMessage>,
470        mut shutdown_rx: mpsc::Receiver<()>,
471        running: Arc<Mutex<bool>>,
472    ) {
473        use std::sync::atomic::{AtomicU64, Ordering};
474        use tokio::sync::RwLock;
475
476        let listener = match TcpListener::bind(addr).await {
477            Ok(l) => l,
478            Err(e) => {
479                error!("Failed to bind WebSocket server: {}", e);
480                let _ = event_tx
481                    .send(BridgeEvent::Error(format!("Bind failed: {}", e)))
482                    .await;
483                return;
484            }
485        };
486
487        info!("WebSocket server listening on {}", addr);
488        *running.lock() = true;
489        let _ = event_tx.send(BridgeEvent::Connected).await;
490
491        // Track connected clients with their send channels
492        let clients: Arc<RwLock<HashMap<u64, mpsc::Sender<WsMessage>>>> =
493            Arc::new(RwLock::new(HashMap::new()));
494        let next_client_id = Arc::new(AtomicU64::new(0));
495
496        loop {
497            tokio::select! {
498                result = listener.accept() => {
499                    match result {
500                        Ok((stream, peer_addr)) => {
501                            let client_id = next_client_id.fetch_add(1, Ordering::SeqCst);
502                            info!("WebSocket client {} connected: {}", client_id, peer_addr);
503
504                            let namespace = namespace.clone();
505                            let event_tx = event_tx.clone();
506                            let clients = clients.clone();
507                            let ping_interval = ping_interval_secs;
508
509                            // Create a channel for sending to this specific client
510                            let (client_tx, mut client_rx) = mpsc::channel::<WsMessage>(100);
511                            clients.write().await.insert(client_id, client_tx);
512
513                            tokio::spawn(async move {
514                                if let Ok(ws_stream) = accept_async(stream).await {
515                                    let (mut write, mut read) = ws_stream.split();
516
517                                    // Create ping interval if configured
518                                    let ping_duration = if ping_interval > 0 {
519                                        Some(std::time::Duration::from_secs(ping_interval as u64))
520                                    } else {
521                                        None
522                                    };
523                                    let mut ping_timer = ping_duration.map(tokio::time::interval);
524                                    let mut awaiting_pong = false;
525
526                                    loop {
527                                        tokio::select! {
528                                            // Handle incoming messages from client
529                                            msg = read.next() => {
530                                                match msg {
531                                                    Some(Ok(ws_msg)) => {
532                                                        match &ws_msg {
533                                                            WsMessage::Pong(_) => {
534                                                                awaiting_pong = false;
535                                                                debug!("Client {} pong received", client_id);
536                                                            }
537                                                            WsMessage::Ping(data) => {
538                                                                if let Err(e) = write.send(WsMessage::Pong(data.clone())).await {
539                                                                    debug!("Failed to send pong to client {}: {}", client_id, e);
540                                                                    break;
541                                                                }
542                                                            }
543                                                            _ => {
544                                                                if let Some(clasp_msg) = Self::parse_message(&ws_msg, format, &namespace) {
545                                                                    let _ = event_tx.send(BridgeEvent::ToClasp(Box::new(clasp_msg))).await;
546                                                                }
547                                                            }
548                                                        }
549                                                    }
550                                                    Some(Err(e)) => {
551                                                        debug!("WebSocket client {} error: {}", client_id, e);
552                                                        break;
553                                                    }
554                                                    None => break,
555                                                }
556                                            }
557                                            // Handle outgoing messages to client
558                                            msg = client_rx.recv() => {
559                                                match msg {
560                                                    Some(ws_msg) => {
561                                                        if let Err(e) = write.send(ws_msg).await {
562                                                            debug!("Failed to send to client {}: {}", client_id, e);
563                                                            break;
564                                                        }
565                                                    }
566                                                    None => break,
567                                                }
568                                            }
569                                            // Handle ping interval
570                                            _ = async {
571                                                if let Some(ref mut timer) = ping_timer {
572                                                    timer.tick().await
573                                                } else {
574                                                    std::future::pending::<tokio::time::Instant>().await
575                                                }
576                                            } => {
577                                                if awaiting_pong {
578                                                    warn!("Client {} ping timeout", client_id);
579                                                    break;
580                                                }
581                                                if let Err(e) = write.send(WsMessage::Ping(vec![])).await {
582                                                    debug!("Failed to send ping to client {}: {}", client_id, e);
583                                                    break;
584                                                }
585                                                awaiting_pong = true;
586                                            }
587                                        }
588                                    }
589                                }
590
591                                // Clean up client on disconnect
592                                clients.write().await.remove(&client_id);
593                                info!("WebSocket client {} disconnected: {}", client_id, peer_addr);
594                            });
595                        }
596                        Err(e) => {
597                            error!("WebSocket accept error: {}", e);
598                        }
599                    }
600                }
601                // Handle messages to broadcast to all clients
602                msg = send_rx.recv() => {
603                    if let Some(ws_msg) = msg {
604                        // Broadcast to all connected clients
605                        let client_list: Vec<_> = clients.read().await.values().cloned().collect();
606                        for tx in client_list {
607                            let _ = tx.send(ws_msg.clone()).await;
608                        }
609                    }
610                }
611                _ = shutdown_rx.recv() => {
612                    info!("WebSocket server shutting down");
613                    break;
614                }
615            }
616        }
617
618        *running.lock() = false;
619        let _ = event_tx
620            .send(BridgeEvent::Disconnected {
621                reason: Some("Server stopped".to_string()),
622            })
623            .await;
624    }
625}
626
627#[async_trait]
628impl Bridge for WebSocketBridge {
629    fn config(&self) -> &BridgeConfig {
630        &self.config
631    }
632
633    async fn start(&mut self) -> Result<mpsc::Receiver<BridgeEvent>> {
634        if *self.running.lock() {
635            return Err(BridgeError::Other("Bridge already running".to_string()));
636        }
637
638        let (event_tx, event_rx) = mpsc::channel(100);
639        let (send_tx, send_rx) = mpsc::channel(100);
640        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
641
642        self.send_tx = Some(send_tx);
643        self.shutdown_tx = Some(shutdown_tx);
644
645        let running = self.running.clone();
646        let ws_config = self.ws_config.clone();
647
648        match ws_config.mode {
649            WsMode::Client => {
650                tokio::spawn(Self::run_client(
651                    ws_config.url,
652                    ws_config.format,
653                    ws_config.namespace,
654                    ws_config.auto_reconnect,
655                    ws_config.reconnect_delay_secs,
656                    ws_config.ping_interval_secs,
657                    event_tx,
658                    send_rx,
659                    shutdown_rx,
660                    running,
661                ));
662            }
663            WsMode::Server => {
664                let addr: SocketAddr = ws_config
665                    .url
666                    .parse()
667                    .map_err(|e| BridgeError::Other(format!("Invalid address: {}", e)))?;
668
669                tokio::spawn(Self::run_server(
670                    addr,
671                    ws_config.format,
672                    ws_config.namespace,
673                    ws_config.ping_interval_secs,
674                    event_tx,
675                    send_rx,
676                    shutdown_rx,
677                    running,
678                ));
679            }
680        }
681
682        info!("WebSocket bridge started in {:?} mode", self.ws_config.mode);
683        Ok(event_rx)
684    }
685
686    async fn stop(&mut self) -> Result<()> {
687        *self.running.lock() = false;
688        if let Some(tx) = self.shutdown_tx.take() {
689            let _ = tx.send(()).await;
690        }
691        self.send_tx = None;
692        info!("WebSocket bridge stopped");
693        Ok(())
694    }
695
696    async fn send(&self, msg: Message) -> Result<()> {
697        let send_tx = self
698            .send_tx
699            .as_ref()
700            .ok_or_else(|| BridgeError::Other("Not connected".to_string()))?;
701
702        if let Some(ws_msg) = Self::message_to_ws(&msg, self.ws_config.format) {
703            send_tx
704                .send(ws_msg)
705                .await
706                .map_err(|e| BridgeError::Other(format!("WebSocket send failed: {}", e)))?;
707        }
708
709        Ok(())
710    }
711
712    fn is_running(&self) -> bool {
713        *self.running.lock()
714    }
715
716    fn namespace(&self) -> &str {
717        &self.ws_config.namespace
718    }
719}
720
721#[cfg(test)]
722mod tests {
723    use super::*;
724
725    #[test]
726    fn test_config_default() {
727        let config = WebSocketBridgeConfig::default();
728        assert_eq!(config.mode, WsMode::Client);
729        assert_eq!(config.namespace, "/ws");
730    }
731
732    #[test]
733    fn test_message_formats() {
734        let prefix = "/ws";
735
736        // JSON text message
737        let ws_msg = WsMessage::Text(r#"{"address": "/test", "value": 42}"#.to_string());
738        let clasp = WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, prefix);
739        assert!(clasp.is_some());
740
741        // Plain text
742        let ws_msg = WsMessage::Text("hello".to_string());
743        let clasp = WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, prefix);
744        assert!(clasp.is_some());
745
746        // Binary
747        let ws_msg = WsMessage::Binary(vec![1, 2, 3]);
748        let clasp = WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Raw, prefix);
749        assert!(clasp.is_some());
750    }
751}