Skip to main content

clasp_bridge/
websocket.rs

1//! WebSocket Bridge for CLASP
2//!
3//! Provides bidirectional WebSocket connectivity for CLASP.
4//! Supports both client and server modes.
5
6use crate::{Bridge, BridgeConfig, BridgeError, BridgeEvent, Result};
7use async_trait::async_trait;
8use clasp_core::{Message, SetMessage, Value};
9use futures::{SinkExt, StreamExt};
10use parking_lot::Mutex;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use tokio::net::{TcpListener, TcpStream};
16use tokio::sync::mpsc;
17use tokio_tungstenite::{
18    accept_async, connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream,
19    WebSocketStream,
20};
21use tracing::{debug, error, info, warn};
22
23/// WebSocket message format
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
25#[serde(rename_all = "lowercase")]
26pub enum WsMessageFormat {
27    /// JSON text messages
28    #[default]
29    Json,
30    /// MessagePack binary messages
31    MsgPack,
32    /// Raw binary/text passthrough
33    Raw,
34}
35
36/// WebSocket bridge mode
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
38#[serde(rename_all = "lowercase")]
39pub enum WsMode {
40    /// Connect to a WebSocket server
41    #[default]
42    Client,
43    /// Act as a WebSocket server
44    Server,
45}
46
47/// WebSocket Bridge configuration
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct WebSocketBridgeConfig {
50    /// Mode: client or server
51    #[serde(default)]
52    pub mode: WsMode,
53    /// URL for client mode (ws://...) or bind address for server mode (0.0.0.0:8080)
54    pub url: String,
55    /// Path for server mode (e.g., "/ws")
56    #[serde(default)]
57    pub path: Option<String>,
58    /// Message format
59    #[serde(default)]
60    pub format: WsMessageFormat,
61    /// Ping interval in seconds (0 to disable)
62    #[serde(default = "default_ping_interval")]
63    pub ping_interval_secs: u32,
64    /// Auto-reconnect on disconnect (client mode only)
65    #[serde(default = "default_true")]
66    pub auto_reconnect: bool,
67    /// Reconnect delay in seconds
68    #[serde(default = "default_reconnect_delay")]
69    pub reconnect_delay_secs: u32,
70    /// Custom headers for client mode
71    #[serde(default)]
72    pub headers: HashMap<String, String>,
73    /// CLASP namespace prefix for incoming messages
74    #[serde(default = "default_namespace")]
75    pub namespace: String,
76}
77
/// Serde default for flags that are on unless configured otherwise.
fn default_true() -> bool {
    true
}

/// Serde default keep-alive ping interval, in seconds.
fn default_ping_interval() -> u32 {
    30
}

/// Serde default delay between client reconnection attempts, in seconds.
fn default_reconnect_delay() -> u32 {
    5
}

/// Serde default CLASP namespace prefix for inbound messages.
fn default_namespace() -> String {
    String::from("/ws")
}
93
94impl Default for WebSocketBridgeConfig {
95    fn default() -> Self {
96        Self {
97            mode: WsMode::Client,
98            url: "ws://localhost:8080".to_string(),
99            path: None,
100            format: WsMessageFormat::Json,
101            ping_interval_secs: 30,
102            auto_reconnect: true,
103            reconnect_delay_secs: 5,
104            headers: HashMap::new(),
105            namespace: "/ws".to_string(),
106        }
107    }
108}
109
110/// WebSocket client connection
111#[allow(dead_code)]
112type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
113
114/// WebSocket server connection (raw TCP)
115#[allow(dead_code)]
116type WsServerStream = WebSocketStream<TcpStream>;
117
118/// Type alias for split sink
119#[allow(dead_code)]
120type WsSink = futures::stream::SplitSink<WsServerStream, WsMessage>;
121
122/// WebSocket Bridge implementation
123pub struct WebSocketBridge {
124    config: BridgeConfig,
125    ws_config: WebSocketBridgeConfig,
126    running: Arc<Mutex<bool>>,
127    send_tx: Option<mpsc::Sender<WsMessage>>,
128    shutdown_tx: Option<mpsc::Sender<()>>,
129}
130
131impl WebSocketBridge {
132    /// Create a new WebSocket bridge
133    pub fn new(ws_config: WebSocketBridgeConfig) -> Self {
134        let config = BridgeConfig {
135            name: "WebSocket Bridge".to_string(),
136            protocol: "websocket".to_string(),
137            bidirectional: true,
138            ..Default::default()
139        };
140
141        Self {
142            config,
143            ws_config,
144            running: Arc::new(Mutex::new(false)),
145            send_tx: None,
146            shutdown_tx: None,
147        }
148    }
149
150    /// Parse incoming WebSocket message to CLASP
151    fn parse_message(msg: &WsMessage, format: WsMessageFormat, prefix: &str) -> Option<Message> {
152        match msg {
153            WsMessage::Text(text) => match format {
154                WsMessageFormat::Json | WsMessageFormat::Raw => {
155                    // Try to parse as JSON
156                    if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
157                        let address = json
158                            .get("address")
159                            .and_then(|v| v.as_str())
160                            .map(|s| s.to_string())
161                            .unwrap_or_else(|| format!("{}/message", prefix));
162
163                        let value = json
164                            .get("value")
165                            .map(|v| Self::json_to_value(v.clone()))
166                            .or_else(|| json.get("data").map(|v| Self::json_to_value(v.clone())))
167                            .unwrap_or_else(|| Self::json_to_value(json));
168
169                        Some(Message::Set(SetMessage {
170                            address,
171                            value,
172                            revision: None,
173                            lock: false,
174                            unlock: false,
175                        }))
176                    } else {
177                        // Plain text message
178                        Some(Message::Set(SetMessage {
179                            address: format!("{}/text", prefix),
180                            value: Value::String(text.clone()),
181                            revision: None,
182                            lock: false,
183                            unlock: false,
184                        }))
185                    }
186                }
187                WsMessageFormat::MsgPack => {
188                    // MsgPack format expects binary messages, but try to parse text as JSON
189                    // and convert to CLASP message for better interoperability
190                    if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
191                        // Check if it looks like a CLASP message format
192                        if let (Some(addr), Some(val)) = (
193                            json.get("address").and_then(|a| a.as_str()),
194                            json.get("value"),
195                        ) {
196                            Some(Message::Set(SetMessage {
197                                address: addr.to_string(),
198                                value: Self::json_to_value(val.clone()),
199                                revision: None,
200                                lock: false,
201                                unlock: false,
202                            }))
203                        } else {
204                            // Wrap text as a message for the namespace
205                            Some(Message::Set(SetMessage {
206                                address: format!("{}/text", prefix),
207                                value: Value::String(text.clone()),
208                                revision: None,
209                                lock: false,
210                                unlock: false,
211                            }))
212                        }
213                    } else {
214                        // Plain text, wrap as message
215                        Some(Message::Set(SetMessage {
216                            address: format!("{}/text", prefix),
217                            value: Value::String(text.clone()),
218                            revision: None,
219                            lock: false,
220                            unlock: false,
221                        }))
222                    }
223                }
224            },
225            WsMessage::Binary(data) => match format {
226                WsMessageFormat::MsgPack => {
227                    // Try to decode as CLASP message
228                    if let Ok((msg, _)) = clasp_core::codec::decode(data) {
229                        Some(msg)
230                    } else {
231                        // Return as bytes
232                        Some(Message::Set(SetMessage {
233                            address: format!("{}/binary", prefix),
234                            value: Value::Bytes(data.clone()),
235                            revision: None,
236                            lock: false,
237                            unlock: false,
238                        }))
239                    }
240                }
241                WsMessageFormat::Raw | WsMessageFormat::Json => Some(Message::Set(SetMessage {
242                    address: format!("{}/binary", prefix),
243                    value: Value::Bytes(data.clone()),
244                    revision: None,
245                    lock: false,
246                    unlock: false,
247                })),
248            },
249            _ => None,
250        }
251    }
252
253    /// Convert JSON value to CLASP Value
254    fn json_to_value(json: serde_json::Value) -> Value {
255        match json {
256            serde_json::Value::Null => Value::Null,
257            serde_json::Value::Bool(b) => Value::Bool(b),
258            serde_json::Value::Number(n) => {
259                if let Some(i) = n.as_i64() {
260                    Value::Int(i)
261                } else if let Some(f) = n.as_f64() {
262                    Value::Float(f)
263                } else {
264                    Value::Null
265                }
266            }
267            serde_json::Value::String(s) => Value::String(s),
268            serde_json::Value::Array(arr) => {
269                Value::Array(arr.into_iter().map(Self::json_to_value).collect())
270            }
271            serde_json::Value::Object(obj) => {
272                let map: HashMap<String, Value> = obj
273                    .into_iter()
274                    .map(|(k, v)| (k, Self::json_to_value(v)))
275                    .collect();
276                Value::Map(map)
277            }
278        }
279    }
280
281    /// Convert CLASP message to WebSocket message
282    fn message_to_ws(msg: &Message, format: WsMessageFormat) -> Option<WsMessage> {
283        let (address, value) = match msg {
284            Message::Set(set) => (Some(&set.address), Some(&set.value)),
285            Message::Publish(pub_msg) => (Some(&pub_msg.address), pub_msg.value.as_ref()),
286            _ => return None,
287        };
288
289        match format {
290            WsMessageFormat::Json => {
291                let json = serde_json::json!({
292                    "address": address,
293                    "value": value,
294                });
295                Some(WsMessage::Text(json.to_string()))
296            }
297            WsMessageFormat::MsgPack => {
298                if let Ok(encoded) = clasp_core::codec::encode(msg) {
299                    Some(WsMessage::Binary(encoded.to_vec()))
300                } else {
301                    None
302                }
303            }
304            WsMessageFormat::Raw => {
305                if let Some(val) = value {
306                    match val {
307                        Value::String(s) => Some(WsMessage::Text(s.clone())),
308                        Value::Bytes(b) => Some(WsMessage::Binary(b.clone())),
309                        _ => {
310                            let json = serde_json::to_string(val).ok()?;
311                            Some(WsMessage::Text(json))
312                        }
313                    }
314                } else {
315                    None
316                }
317            }
318        }
319    }
320
321    /// Run client mode
322    #[allow(clippy::too_many_arguments)]
323    async fn run_client(
324        url: String,
325        format: WsMessageFormat,
326        namespace: String,
327        auto_reconnect: bool,
328        reconnect_delay: u32,
329        ping_interval_secs: u32,
330        event_tx: mpsc::Sender<BridgeEvent>,
331        mut send_rx: mpsc::Receiver<WsMessage>,
332        mut shutdown_rx: mpsc::Receiver<()>,
333        running: Arc<Mutex<bool>>,
334    ) {
335        loop {
336            info!("WebSocket connecting to {}", url);
337
338            match connect_async(&url).await {
339                Ok((ws_stream, _)) => {
340                    info!("WebSocket connected");
341                    *running.lock() = true;
342                    let _ = event_tx.send(BridgeEvent::Connected).await;
343
344                    let (mut write, mut read) = ws_stream.split();
345
346                    // Create ping interval if configured
347                    let ping_duration = if ping_interval_secs > 0 {
348                        Some(std::time::Duration::from_secs(ping_interval_secs as u64))
349                    } else {
350                        None
351                    };
352                    let mut ping_interval = ping_duration.map(tokio::time::interval);
353                    let mut awaiting_pong = false;
354
355                    loop {
356                        tokio::select! {
357                            // Handle incoming messages
358                            msg = read.next() => {
359                                match msg {
360                                    Some(Ok(ws_msg)) => {
361                                        match &ws_msg {
362                                            WsMessage::Pong(_) => {
363                                                awaiting_pong = false;
364                                                debug!("Received pong");
365                                            }
366                                            WsMessage::Ping(data) => {
367                                                // Respond with pong
368                                                if let Err(e) = write.send(WsMessage::Pong(data.clone())).await {
369                                                    error!("Failed to send pong: {}", e);
370                                                    break;
371                                                }
372                                            }
373                                            _ => {
374                                                if let Some(clasp_msg) = Self::parse_message(&ws_msg, format, &namespace) {
375                                                    let _ = event_tx.send(BridgeEvent::ToClasp(Box::new(clasp_msg))).await;
376                                                }
377                                            }
378                                        }
379                                    }
380                                    Some(Err(e)) => {
381                                        error!("WebSocket error: {}", e);
382                                        break;
383                                    }
384                                    None => {
385                                        warn!("WebSocket connection closed");
386                                        break;
387                                    }
388                                }
389                            }
390                            // Handle outgoing messages
391                            msg = send_rx.recv() => {
392                                if let Some(ws_msg) = msg {
393                                    if let Err(e) = write.send(ws_msg).await {
394                                        error!("WebSocket send error: {}", e);
395                                        break;
396                                    }
397                                }
398                            }
399                            // Handle ping interval
400                            _ = async {
401                                if let Some(ref mut interval) = ping_interval {
402                                    interval.tick().await
403                                } else {
404                                    // Never fires if ping is disabled
405                                    std::future::pending::<tokio::time::Instant>().await
406                                }
407                            } => {
408                                if awaiting_pong {
409                                    warn!("Ping timeout - no pong received");
410                                    break;
411                                }
412                                if let Err(e) = write.send(WsMessage::Ping(vec![])).await {
413                                    error!("Failed to send ping: {}", e);
414                                    break;
415                                }
416                                awaiting_pong = true;
417                                debug!("Sent ping");
418                            }
419                            // Handle shutdown
420                            _ = shutdown_rx.recv() => {
421                                info!("WebSocket shutting down");
422                                let _ = write.close().await;
423                                *running.lock() = false;
424                                return;
425                            }
426                        }
427                    }
428
429                    *running.lock() = false;
430                    let _ = event_tx
431                        .send(BridgeEvent::Disconnected {
432                            reason: Some("Connection closed".to_string()),
433                        })
434                        .await;
435                }
436                Err(e) => {
437                    error!("WebSocket connection failed: {}", e);
438                    let _ = event_tx
439                        .send(BridgeEvent::Error(format!("Connection failed: {}", e)))
440                        .await;
441                }
442            }
443
444            if !auto_reconnect {
445                *running.lock() = false;
446                return;
447            }
448
449            info!("Reconnecting in {} seconds...", reconnect_delay);
450            tokio::time::sleep(std::time::Duration::from_secs(reconnect_delay as u64)).await;
451        }
452    }
453
454    /// Run server mode
455    #[allow(clippy::too_many_arguments)]
456    async fn run_server(
457        addr: SocketAddr,
458        format: WsMessageFormat,
459        namespace: String,
460        ping_interval_secs: u32,
461        event_tx: mpsc::Sender<BridgeEvent>,
462        mut send_rx: mpsc::Receiver<WsMessage>,
463        mut shutdown_rx: mpsc::Receiver<()>,
464        running: Arc<Mutex<bool>>,
465    ) {
466        use std::sync::atomic::{AtomicU64, Ordering};
467        use tokio::sync::RwLock;
468
469        let listener = match TcpListener::bind(addr).await {
470            Ok(l) => l,
471            Err(e) => {
472                error!("Failed to bind WebSocket server: {}", e);
473                let _ = event_tx
474                    .send(BridgeEvent::Error(format!("Bind failed: {}", e)))
475                    .await;
476                return;
477            }
478        };
479
480        info!("WebSocket server listening on {}", addr);
481        *running.lock() = true;
482        let _ = event_tx.send(BridgeEvent::Connected).await;
483
484        // Track connected clients with their send channels
485        let clients: Arc<RwLock<HashMap<u64, mpsc::Sender<WsMessage>>>> =
486            Arc::new(RwLock::new(HashMap::new()));
487        let next_client_id = Arc::new(AtomicU64::new(0));
488
489        loop {
490            tokio::select! {
491                result = listener.accept() => {
492                    match result {
493                        Ok((stream, peer_addr)) => {
494                            let client_id = next_client_id.fetch_add(1, Ordering::SeqCst);
495                            info!("WebSocket client {} connected: {}", client_id, peer_addr);
496
497                            let namespace = namespace.clone();
498                            let event_tx = event_tx.clone();
499                            let clients = clients.clone();
500                            let ping_interval = ping_interval_secs;
501
502                            // Create a channel for sending to this specific client
503                            let (client_tx, mut client_rx) = mpsc::channel::<WsMessage>(100);
504                            clients.write().await.insert(client_id, client_tx);
505
506                            tokio::spawn(async move {
507                                if let Ok(ws_stream) = accept_async(stream).await {
508                                    let (mut write, mut read) = ws_stream.split();
509
510                                    // Create ping interval if configured
511                                    let ping_duration = if ping_interval > 0 {
512                                        Some(std::time::Duration::from_secs(ping_interval as u64))
513                                    } else {
514                                        None
515                                    };
516                                    let mut ping_timer = ping_duration.map(tokio::time::interval);
517                                    let mut awaiting_pong = false;
518
519                                    loop {
520                                        tokio::select! {
521                                            // Handle incoming messages from client
522                                            msg = read.next() => {
523                                                match msg {
524                                                    Some(Ok(ws_msg)) => {
525                                                        match &ws_msg {
526                                                            WsMessage::Pong(_) => {
527                                                                awaiting_pong = false;
528                                                                debug!("Client {} pong received", client_id);
529                                                            }
530                                                            WsMessage::Ping(data) => {
531                                                                if let Err(e) = write.send(WsMessage::Pong(data.clone())).await {
532                                                                    debug!("Failed to send pong to client {}: {}", client_id, e);
533                                                                    break;
534                                                                }
535                                                            }
536                                                            _ => {
537                                                                if let Some(clasp_msg) = Self::parse_message(&ws_msg, format, &namespace) {
538                                                                    let _ = event_tx.send(BridgeEvent::ToClasp(Box::new(clasp_msg))).await;
539                                                                }
540                                                            }
541                                                        }
542                                                    }
543                                                    Some(Err(e)) => {
544                                                        debug!("WebSocket client {} error: {}", client_id, e);
545                                                        break;
546                                                    }
547                                                    None => break,
548                                                }
549                                            }
550                                            // Handle outgoing messages to client
551                                            msg = client_rx.recv() => {
552                                                match msg {
553                                                    Some(ws_msg) => {
554                                                        if let Err(e) = write.send(ws_msg).await {
555                                                            debug!("Failed to send to client {}: {}", client_id, e);
556                                                            break;
557                                                        }
558                                                    }
559                                                    None => break,
560                                                }
561                                            }
562                                            // Handle ping interval
563                                            _ = async {
564                                                if let Some(ref mut timer) = ping_timer {
565                                                    timer.tick().await
566                                                } else {
567                                                    std::future::pending::<tokio::time::Instant>().await
568                                                }
569                                            } => {
570                                                if awaiting_pong {
571                                                    warn!("Client {} ping timeout", client_id);
572                                                    break;
573                                                }
574                                                if let Err(e) = write.send(WsMessage::Ping(vec![])).await {
575                                                    debug!("Failed to send ping to client {}: {}", client_id, e);
576                                                    break;
577                                                }
578                                                awaiting_pong = true;
579                                            }
580                                        }
581                                    }
582                                }
583
584                                // Clean up client on disconnect
585                                clients.write().await.remove(&client_id);
586                                info!("WebSocket client {} disconnected: {}", client_id, peer_addr);
587                            });
588                        }
589                        Err(e) => {
590                            error!("WebSocket accept error: {}", e);
591                        }
592                    }
593                }
594                // Handle messages to broadcast to all clients
595                msg = send_rx.recv() => {
596                    if let Some(ws_msg) = msg {
597                        // Broadcast to all connected clients
598                        let client_list: Vec<_> = clients.read().await.values().cloned().collect();
599                        for tx in client_list {
600                            let _ = tx.send(ws_msg.clone()).await;
601                        }
602                    }
603                }
604                _ = shutdown_rx.recv() => {
605                    info!("WebSocket server shutting down");
606                    break;
607                }
608            }
609        }
610
611        *running.lock() = false;
612        let _ = event_tx
613            .send(BridgeEvent::Disconnected {
614                reason: Some("Server stopped".to_string()),
615            })
616            .await;
617    }
618}
619
620#[async_trait]
621impl Bridge for WebSocketBridge {
622    fn config(&self) -> &BridgeConfig {
623        &self.config
624    }
625
626    async fn start(&mut self) -> Result<mpsc::Receiver<BridgeEvent>> {
627        if *self.running.lock() {
628            return Err(BridgeError::Other("Bridge already running".to_string()));
629        }
630
631        let (event_tx, event_rx) = mpsc::channel(100);
632        let (send_tx, send_rx) = mpsc::channel(100);
633        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
634
635        self.send_tx = Some(send_tx);
636        self.shutdown_tx = Some(shutdown_tx);
637
638        let running = self.running.clone();
639        let ws_config = self.ws_config.clone();
640
641        match ws_config.mode {
642            WsMode::Client => {
643                tokio::spawn(Self::run_client(
644                    ws_config.url,
645                    ws_config.format,
646                    ws_config.namespace,
647                    ws_config.auto_reconnect,
648                    ws_config.reconnect_delay_secs,
649                    ws_config.ping_interval_secs,
650                    event_tx,
651                    send_rx,
652                    shutdown_rx,
653                    running,
654                ));
655            }
656            WsMode::Server => {
657                let addr: SocketAddr = ws_config
658                    .url
659                    .parse()
660                    .map_err(|e| BridgeError::Other(format!("Invalid address: {}", e)))?;
661
662                tokio::spawn(Self::run_server(
663                    addr,
664                    ws_config.format,
665                    ws_config.namespace,
666                    ws_config.ping_interval_secs,
667                    event_tx,
668                    send_rx,
669                    shutdown_rx,
670                    running,
671                ));
672            }
673        }
674
675        info!("WebSocket bridge started in {:?} mode", self.ws_config.mode);
676        Ok(event_rx)
677    }
678
679    async fn stop(&mut self) -> Result<()> {
680        *self.running.lock() = false;
681        if let Some(tx) = self.shutdown_tx.take() {
682            let _ = tx.send(()).await;
683        }
684        self.send_tx = None;
685        info!("WebSocket bridge stopped");
686        Ok(())
687    }
688
689    async fn send(&self, msg: Message) -> Result<()> {
690        let send_tx = self
691            .send_tx
692            .as_ref()
693            .ok_or_else(|| BridgeError::Other("Not connected".to_string()))?;
694
695        if let Some(ws_msg) = Self::message_to_ws(&msg, self.ws_config.format) {
696            send_tx
697                .send(ws_msg)
698                .await
699                .map_err(|e| BridgeError::Other(format!("WebSocket send failed: {}", e)))?;
700        }
701
702        Ok(())
703    }
704
705    fn is_running(&self) -> bool {
706        *self.running.lock()
707    }
708
709    fn namespace(&self) -> &str {
710        &self.ws_config.namespace
711    }
712}
713
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a parsed frame into its SET payload, failing the test otherwise.
    fn expect_set(parsed: Option<Message>) -> SetMessage {
        match parsed {
            Some(Message::Set(set)) => set,
            _ => panic!("expected a SET message"),
        }
    }

    #[test]
    fn test_config_default() {
        let config = WebSocketBridgeConfig::default();
        assert_eq!(config.mode, WsMode::Client);
        assert_eq!(config.format, WsMessageFormat::Json);
        assert_eq!(config.namespace, "/ws");
        // `Default` must agree with the serde defaults used for omitted fields.
        assert_eq!(config.namespace, default_namespace());
        assert_eq!(config.ping_interval_secs, default_ping_interval());
        assert_eq!(config.reconnect_delay_secs, default_reconnect_delay());
        assert_eq!(config.auto_reconnect, default_true());
        assert!(config.path.is_none());
        assert!(config.headers.is_empty());
    }

    #[test]
    fn test_message_formats() {
        let prefix = "/ws";

        // JSON text with an explicit address and value.
        let ws_msg = WsMessage::Text(r#"{"address": "/test", "value": 42}"#.to_string());
        let set = expect_set(WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, prefix));
        assert_eq!(set.address, "/test");
        assert!(matches!(set.value, Value::Int(42)));

        // JSON without an address falls back to `{prefix}/message` and uses `data`.
        let ws_msg = WsMessage::Text(r#"{"data": true}"#.to_string());
        let set = expect_set(WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, prefix));
        assert_eq!(set.address, "/ws/message");
        assert!(matches!(set.value, Value::Bool(true)));

        // Plain (non-JSON) text becomes a string SET on `{prefix}/text`.
        let ws_msg = WsMessage::Text("hello".to_string());
        let set = expect_set(WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, prefix));
        assert_eq!(set.address, "/ws/text");
        assert!(matches!(&set.value, Value::String(s) if s == "hello"));

        // Binary frames in raw mode become a bytes SET on `{prefix}/binary`.
        let ws_msg = WsMessage::Binary(vec![1, 2, 3]);
        let set = expect_set(WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Raw, prefix));
        assert_eq!(set.address, "/ws/binary");
        assert!(matches!(&set.value, Value::Bytes(b) if *b == [1, 2, 3]));

        // In MsgPack mode, text lacking a value is wrapped as plain text.
        let ws_msg = WsMessage::Text(r#"{"address": "/x"}"#.to_string());
        let set = expect_set(WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::MsgPack, prefix));
        assert_eq!(set.address, "/ws/text");

        // Control frames are never forwarded to CLASP.
        let ws_msg = WsMessage::Ping(vec![]);
        assert!(WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, prefix).is_none());
    }

    #[test]
    fn test_outgoing_conversion() {
        let msg = Message::Set(SetMessage {
            address: "/out".to_string(),
            value: Value::String("hi".to_string()),
            revision: None,
            lock: false,
            unlock: false,
        });

        match WebSocketBridge::message_to_ws(&msg, WsMessageFormat::Raw) {
            Some(WsMessage::Text(text)) => assert_eq!(text, "hi"),
            _ => panic!("expected a raw text frame"),
        }

        match WebSocketBridge::message_to_ws(&msg, WsMessageFormat::Json) {
            Some(WsMessage::Text(text)) => {
                let json: serde_json::Value = serde_json::from_str(&text).expect("valid JSON");
                assert_eq!(json["address"], "/out");
            }
            _ => panic!("expected a JSON text frame"),
        }
    }
}