Skip to main content

clasp_bridge/
websocket.rs

1//! WebSocket Bridge for CLASP
2//!
3//! Provides bidirectional WebSocket connectivity for CLASP.
4//! Supports both client and server modes.
5
6use crate::{Bridge, BridgeConfig, BridgeError, BridgeEvent, Result};
7use async_trait::async_trait;
8use clasp_core::{Message, PublishMessage, SetMessage, SignalType, Value};
9use futures::{SinkExt, StreamExt};
10use parking_lot::Mutex;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use tokio::net::{TcpListener, TcpStream};
16use tokio::sync::mpsc;
17use tokio_tungstenite::{
18    accept_async, connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream,
19    WebSocketStream,
20};
21use tracing::{debug, error, info, warn};
22
23/// WebSocket message format
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
25#[serde(rename_all = "lowercase")]
26pub enum WsMessageFormat {
27    /// JSON text messages
28    #[default]
29    Json,
30    /// MessagePack binary messages
31    MsgPack,
32    /// Raw binary/text passthrough
33    Raw,
34}
35
36/// WebSocket bridge mode
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
38#[serde(rename_all = "lowercase")]
39pub enum WsMode {
40    /// Connect to a WebSocket server
41    #[default]
42    Client,
43    /// Act as a WebSocket server
44    Server,
45}
46
47/// WebSocket Bridge configuration
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct WebSocketBridgeConfig {
50    /// Mode: client or server
51    #[serde(default)]
52    pub mode: WsMode,
53    /// URL for client mode (ws://...) or bind address for server mode (0.0.0.0:8080)
54    pub url: String,
55    /// Path for server mode (e.g., "/ws")
56    #[serde(default)]
57    pub path: Option<String>,
58    /// Message format
59    #[serde(default)]
60    pub format: WsMessageFormat,
61    /// Ping interval in seconds (0 to disable)
62    #[serde(default = "default_ping_interval")]
63    pub ping_interval_secs: u32,
64    /// Auto-reconnect on disconnect (client mode only)
65    #[serde(default = "default_true")]
66    pub auto_reconnect: bool,
67    /// Reconnect delay in seconds
68    #[serde(default = "default_reconnect_delay")]
69    pub reconnect_delay_secs: u32,
70    /// Custom headers for client mode
71    #[serde(default)]
72    pub headers: HashMap<String, String>,
73    /// CLASP namespace prefix for incoming messages
74    #[serde(default = "default_namespace")]
75    pub namespace: String,
76}
77
/// Default keep-alive ping period, in seconds.
const DEFAULT_PING_INTERVAL_SECS: u32 = 30;
/// Default pause before a client reconnect attempt, in seconds.
const DEFAULT_RECONNECT_DELAY_SECS: u32 = 5;
/// Default CLASP namespace prefix.
const DEFAULT_NAMESPACE: &str = "/ws";

/// Serde default for flags that are on unless configured otherwise.
fn default_true() -> bool {
    true
}

/// Serde default for `ping_interval_secs`.
fn default_ping_interval() -> u32 {
    DEFAULT_PING_INTERVAL_SECS
}

/// Serde default for `reconnect_delay_secs`.
fn default_reconnect_delay() -> u32 {
    DEFAULT_RECONNECT_DELAY_SECS
}

/// Serde default for `namespace`.
fn default_namespace() -> String {
    String::from(DEFAULT_NAMESPACE)
}
93
94impl Default for WebSocketBridgeConfig {
95    fn default() -> Self {
96        Self {
97            mode: WsMode::Client,
98            url: "ws://localhost:8080".to_string(),
99            path: None,
100            format: WsMessageFormat::Json,
101            ping_interval_secs: 30,
102            auto_reconnect: true,
103            reconnect_delay_secs: 5,
104            headers: HashMap::new(),
105            namespace: "/ws".to_string(),
106        }
107    }
108}
109
110/// WebSocket client connection
111type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
112
113/// WebSocket server connection (raw TCP)
114type WsServerStream = WebSocketStream<TcpStream>;
115
116/// Type alias for split sink
117type WsSink = futures::stream::SplitSink<WsServerStream, WsMessage>;
118
119/// WebSocket Bridge implementation
120pub struct WebSocketBridge {
121    config: BridgeConfig,
122    ws_config: WebSocketBridgeConfig,
123    running: Arc<Mutex<bool>>,
124    send_tx: Option<mpsc::Sender<WsMessage>>,
125    shutdown_tx: Option<mpsc::Sender<()>>,
126}
127
128impl WebSocketBridge {
129    /// Create a new WebSocket bridge
130    pub fn new(ws_config: WebSocketBridgeConfig) -> Self {
131        let config = BridgeConfig {
132            name: "WebSocket Bridge".to_string(),
133            protocol: "websocket".to_string(),
134            bidirectional: true,
135            ..Default::default()
136        };
137
138        Self {
139            config,
140            ws_config,
141            running: Arc::new(Mutex::new(false)),
142            send_tx: None,
143            shutdown_tx: None,
144        }
145    }
146
147    /// Parse incoming WebSocket message to CLASP
148    fn parse_message(msg: &WsMessage, format: WsMessageFormat, prefix: &str) -> Option<Message> {
149        match msg {
150            WsMessage::Text(text) => match format {
151                WsMessageFormat::Json | WsMessageFormat::Raw => {
152                    // Try to parse as JSON
153                    if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
154                        let address = json
155                            .get("address")
156                            .and_then(|v| v.as_str())
157                            .map(|s| s.to_string())
158                            .unwrap_or_else(|| format!("{}/message", prefix));
159
160                        let value = json
161                            .get("value")
162                            .map(|v| Self::json_to_value(v.clone()))
163                            .or_else(|| json.get("data").map(|v| Self::json_to_value(v.clone())))
164                            .unwrap_or_else(|| Self::json_to_value(json));
165
166                        Some(Message::Set(SetMessage {
167                            address,
168                            value,
169                            revision: None,
170                            lock: false,
171                            unlock: false,
172                        }))
173                    } else {
174                        // Plain text message
175                        Some(Message::Set(SetMessage {
176                            address: format!("{}/text", prefix),
177                            value: Value::String(text.clone()),
178                            revision: None,
179                            lock: false,
180                            unlock: false,
181                        }))
182                    }
183                }
184                WsMessageFormat::MsgPack => {
185                    // MsgPack format expects binary messages, but try to parse text as JSON
186                    // and convert to CLASP message for better interoperability
187                    if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
188                        // Check if it looks like a CLASP message format
189                        if let (Some(addr), Some(val)) = (
190                            json.get("address").and_then(|a| a.as_str()),
191                            json.get("value"),
192                        ) {
193                            Some(Message::Set(SetMessage {
194                                address: addr.to_string(),
195                                value: Self::json_to_value(val.clone()),
196                                revision: None,
197                                lock: false,
198                                unlock: false,
199                            }))
200                        } else {
201                            // Wrap text as a message for the namespace
202                            Some(Message::Set(SetMessage {
203                                address: format!("{}/text", prefix),
204                                value: Value::String(text.clone()),
205                                revision: None,
206                                lock: false,
207                                unlock: false,
208                            }))
209                        }
210                    } else {
211                        // Plain text, wrap as message
212                        Some(Message::Set(SetMessage {
213                            address: format!("{}/text", prefix),
214                            value: Value::String(text.clone()),
215                            revision: None,
216                            lock: false,
217                            unlock: false,
218                        }))
219                    }
220                }
221            },
222            WsMessage::Binary(data) => match format {
223                WsMessageFormat::MsgPack => {
224                    // Try to decode as CLASP message
225                    if let Ok((msg, _)) = clasp_core::codec::decode(data) {
226                        Some(msg)
227                    } else {
228                        // Return as bytes
229                        Some(Message::Set(SetMessage {
230                            address: format!("{}/binary", prefix),
231                            value: Value::Bytes(data.clone()),
232                            revision: None,
233                            lock: false,
234                            unlock: false,
235                        }))
236                    }
237                }
238                WsMessageFormat::Raw | WsMessageFormat::Json => Some(Message::Set(SetMessage {
239                    address: format!("{}/binary", prefix),
240                    value: Value::Bytes(data.clone()),
241                    revision: None,
242                    lock: false,
243                    unlock: false,
244                })),
245            },
246            _ => None,
247        }
248    }
249
250    /// Convert JSON value to CLASP Value
251    fn json_to_value(json: serde_json::Value) -> Value {
252        match json {
253            serde_json::Value::Null => Value::Null,
254            serde_json::Value::Bool(b) => Value::Bool(b),
255            serde_json::Value::Number(n) => {
256                if let Some(i) = n.as_i64() {
257                    Value::Int(i)
258                } else if let Some(f) = n.as_f64() {
259                    Value::Float(f)
260                } else {
261                    Value::Null
262                }
263            }
264            serde_json::Value::String(s) => Value::String(s),
265            serde_json::Value::Array(arr) => {
266                Value::Array(arr.into_iter().map(Self::json_to_value).collect())
267            }
268            serde_json::Value::Object(obj) => {
269                let map: HashMap<String, Value> = obj
270                    .into_iter()
271                    .map(|(k, v)| (k, Self::json_to_value(v)))
272                    .collect();
273                Value::Map(map)
274            }
275        }
276    }
277
278    /// Convert CLASP message to WebSocket message
279    fn message_to_ws(msg: &Message, format: WsMessageFormat) -> Option<WsMessage> {
280        let (address, value) = match msg {
281            Message::Set(set) => (Some(&set.address), Some(&set.value)),
282            Message::Publish(pub_msg) => (Some(&pub_msg.address), pub_msg.value.as_ref()),
283            _ => return None,
284        };
285
286        match format {
287            WsMessageFormat::Json => {
288                let json = serde_json::json!({
289                    "address": address,
290                    "value": value,
291                });
292                Some(WsMessage::Text(json.to_string()))
293            }
294            WsMessageFormat::MsgPack => {
295                if let Ok(encoded) = clasp_core::codec::encode(msg) {
296                    Some(WsMessage::Binary(encoded.to_vec()))
297                } else {
298                    None
299                }
300            }
301            WsMessageFormat::Raw => {
302                if let Some(val) = value {
303                    match val {
304                        Value::String(s) => Some(WsMessage::Text(s.clone())),
305                        Value::Bytes(b) => Some(WsMessage::Binary(b.clone())),
306                        _ => {
307                            let json = serde_json::to_string(val).ok()?;
308                            Some(WsMessage::Text(json))
309                        }
310                    }
311                } else {
312                    None
313                }
314            }
315        }
316    }
317
318    /// Run client mode
319    async fn run_client(
320        url: String,
321        format: WsMessageFormat,
322        namespace: String,
323        auto_reconnect: bool,
324        reconnect_delay: u32,
325        ping_interval_secs: u32,
326        event_tx: mpsc::Sender<BridgeEvent>,
327        mut send_rx: mpsc::Receiver<WsMessage>,
328        mut shutdown_rx: mpsc::Receiver<()>,
329        running: Arc<Mutex<bool>>,
330    ) {
331        loop {
332            info!("WebSocket connecting to {}", url);
333
334            match connect_async(&url).await {
335                Ok((ws_stream, _)) => {
336                    info!("WebSocket connected");
337                    *running.lock() = true;
338                    let _ = event_tx.send(BridgeEvent::Connected).await;
339
340                    let (mut write, mut read) = ws_stream.split();
341
342                    // Create ping interval if configured
343                    let ping_duration = if ping_interval_secs > 0 {
344                        Some(std::time::Duration::from_secs(ping_interval_secs as u64))
345                    } else {
346                        None
347                    };
348                    let mut ping_interval = ping_duration.map(tokio::time::interval);
349                    let mut awaiting_pong = false;
350
351                    loop {
352                        tokio::select! {
353                            // Handle incoming messages
354                            msg = read.next() => {
355                                match msg {
356                                    Some(Ok(ws_msg)) => {
357                                        match &ws_msg {
358                                            WsMessage::Pong(_) => {
359                                                awaiting_pong = false;
360                                                debug!("Received pong");
361                                            }
362                                            WsMessage::Ping(data) => {
363                                                // Respond with pong
364                                                if let Err(e) = write.send(WsMessage::Pong(data.clone())).await {
365                                                    error!("Failed to send pong: {}", e);
366                                                    break;
367                                                }
368                                            }
369                                            _ => {
370                                                if let Some(clasp_msg) = Self::parse_message(&ws_msg, format, &namespace) {
371                                                    let _ = event_tx.send(BridgeEvent::ToClasp(clasp_msg)).await;
372                                                }
373                                            }
374                                        }
375                                    }
376                                    Some(Err(e)) => {
377                                        error!("WebSocket error: {}", e);
378                                        break;
379                                    }
380                                    None => {
381                                        warn!("WebSocket connection closed");
382                                        break;
383                                    }
384                                }
385                            }
386                            // Handle outgoing messages
387                            msg = send_rx.recv() => {
388                                if let Some(ws_msg) = msg {
389                                    if let Err(e) = write.send(ws_msg).await {
390                                        error!("WebSocket send error: {}", e);
391                                        break;
392                                    }
393                                }
394                            }
395                            // Handle ping interval
396                            _ = async {
397                                if let Some(ref mut interval) = ping_interval {
398                                    interval.tick().await
399                                } else {
400                                    // Never fires if ping is disabled
401                                    std::future::pending::<tokio::time::Instant>().await
402                                }
403                            } => {
404                                if awaiting_pong {
405                                    warn!("Ping timeout - no pong received");
406                                    break;
407                                }
408                                if let Err(e) = write.send(WsMessage::Ping(vec![])).await {
409                                    error!("Failed to send ping: {}", e);
410                                    break;
411                                }
412                                awaiting_pong = true;
413                                debug!("Sent ping");
414                            }
415                            // Handle shutdown
416                            _ = shutdown_rx.recv() => {
417                                info!("WebSocket shutting down");
418                                let _ = write.close().await;
419                                *running.lock() = false;
420                                return;
421                            }
422                        }
423                    }
424
425                    *running.lock() = false;
426                    let _ = event_tx
427                        .send(BridgeEvent::Disconnected {
428                            reason: Some("Connection closed".to_string()),
429                        })
430                        .await;
431                }
432                Err(e) => {
433                    error!("WebSocket connection failed: {}", e);
434                    let _ = event_tx
435                        .send(BridgeEvent::Error(format!("Connection failed: {}", e)))
436                        .await;
437                }
438            }
439
440            if !auto_reconnect {
441                *running.lock() = false;
442                return;
443            }
444
445            info!("Reconnecting in {} seconds...", reconnect_delay);
446            tokio::time::sleep(std::time::Duration::from_secs(reconnect_delay as u64)).await;
447        }
448    }
449
450    /// Run server mode
451    async fn run_server(
452        addr: SocketAddr,
453        format: WsMessageFormat,
454        namespace: String,
455        ping_interval_secs: u32,
456        event_tx: mpsc::Sender<BridgeEvent>,
457        mut send_rx: mpsc::Receiver<WsMessage>,
458        mut shutdown_rx: mpsc::Receiver<()>,
459        running: Arc<Mutex<bool>>,
460    ) {
461        use std::sync::atomic::{AtomicU64, Ordering};
462        use tokio::sync::RwLock;
463
464        let listener = match TcpListener::bind(addr).await {
465            Ok(l) => l,
466            Err(e) => {
467                error!("Failed to bind WebSocket server: {}", e);
468                let _ = event_tx
469                    .send(BridgeEvent::Error(format!("Bind failed: {}", e)))
470                    .await;
471                return;
472            }
473        };
474
475        info!("WebSocket server listening on {}", addr);
476        *running.lock() = true;
477        let _ = event_tx.send(BridgeEvent::Connected).await;
478
479        // Track connected clients with their send channels
480        let clients: Arc<RwLock<HashMap<u64, mpsc::Sender<WsMessage>>>> =
481            Arc::new(RwLock::new(HashMap::new()));
482        let next_client_id = Arc::new(AtomicU64::new(0));
483
484        loop {
485            tokio::select! {
486                result = listener.accept() => {
487                    match result {
488                        Ok((stream, peer_addr)) => {
489                            let client_id = next_client_id.fetch_add(1, Ordering::SeqCst);
490                            info!("WebSocket client {} connected: {}", client_id, peer_addr);
491
492                            let format = format;
493                            let namespace = namespace.clone();
494                            let event_tx = event_tx.clone();
495                            let clients = clients.clone();
496                            let ping_interval = ping_interval_secs;
497
498                            // Create a channel for sending to this specific client
499                            let (client_tx, mut client_rx) = mpsc::channel::<WsMessage>(100);
500                            clients.write().await.insert(client_id, client_tx);
501
502                            tokio::spawn(async move {
503                                if let Ok(ws_stream) = accept_async(stream).await {
504                                    let (mut write, mut read) = ws_stream.split();
505
506                                    // Create ping interval if configured
507                                    let ping_duration = if ping_interval > 0 {
508                                        Some(std::time::Duration::from_secs(ping_interval as u64))
509                                    } else {
510                                        None
511                                    };
512                                    let mut ping_timer = ping_duration.map(tokio::time::interval);
513                                    let mut awaiting_pong = false;
514
515                                    loop {
516                                        tokio::select! {
517                                            // Handle incoming messages from client
518                                            msg = read.next() => {
519                                                match msg {
520                                                    Some(Ok(ws_msg)) => {
521                                                        match &ws_msg {
522                                                            WsMessage::Pong(_) => {
523                                                                awaiting_pong = false;
524                                                                debug!("Client {} pong received", client_id);
525                                                            }
526                                                            WsMessage::Ping(data) => {
527                                                                if let Err(e) = write.send(WsMessage::Pong(data.clone())).await {
528                                                                    debug!("Failed to send pong to client {}: {}", client_id, e);
529                                                                    break;
530                                                                }
531                                                            }
532                                                            _ => {
533                                                                if let Some(clasp_msg) = Self::parse_message(&ws_msg, format, &namespace) {
534                                                                    let _ = event_tx.send(BridgeEvent::ToClasp(clasp_msg)).await;
535                                                                }
536                                                            }
537                                                        }
538                                                    }
539                                                    Some(Err(e)) => {
540                                                        debug!("WebSocket client {} error: {}", client_id, e);
541                                                        break;
542                                                    }
543                                                    None => break,
544                                                }
545                                            }
546                                            // Handle outgoing messages to client
547                                            msg = client_rx.recv() => {
548                                                match msg {
549                                                    Some(ws_msg) => {
550                                                        if let Err(e) = write.send(ws_msg).await {
551                                                            debug!("Failed to send to client {}: {}", client_id, e);
552                                                            break;
553                                                        }
554                                                    }
555                                                    None => break,
556                                                }
557                                            }
558                                            // Handle ping interval
559                                            _ = async {
560                                                if let Some(ref mut timer) = ping_timer {
561                                                    timer.tick().await
562                                                } else {
563                                                    std::future::pending::<tokio::time::Instant>().await
564                                                }
565                                            } => {
566                                                if awaiting_pong {
567                                                    warn!("Client {} ping timeout", client_id);
568                                                    break;
569                                                }
570                                                if let Err(e) = write.send(WsMessage::Ping(vec![])).await {
571                                                    debug!("Failed to send ping to client {}: {}", client_id, e);
572                                                    break;
573                                                }
574                                                awaiting_pong = true;
575                                            }
576                                        }
577                                    }
578                                }
579
580                                // Clean up client on disconnect
581                                clients.write().await.remove(&client_id);
582                                info!("WebSocket client {} disconnected: {}", client_id, peer_addr);
583                            });
584                        }
585                        Err(e) => {
586                            error!("WebSocket accept error: {}", e);
587                        }
588                    }
589                }
590                // Handle messages to broadcast to all clients
591                msg = send_rx.recv() => {
592                    if let Some(ws_msg) = msg {
593                        // Broadcast to all connected clients
594                        let client_list: Vec<_> = clients.read().await.values().cloned().collect();
595                        for tx in client_list {
596                            let _ = tx.send(ws_msg.clone()).await;
597                        }
598                    }
599                }
600                _ = shutdown_rx.recv() => {
601                    info!("WebSocket server shutting down");
602                    break;
603                }
604            }
605        }
606
607        *running.lock() = false;
608        let _ = event_tx
609            .send(BridgeEvent::Disconnected {
610                reason: Some("Server stopped".to_string()),
611            })
612            .await;
613    }
614}
615
616#[async_trait]
617impl Bridge for WebSocketBridge {
618    fn config(&self) -> &BridgeConfig {
619        &self.config
620    }
621
622    async fn start(&mut self) -> Result<mpsc::Receiver<BridgeEvent>> {
623        if *self.running.lock() {
624            return Err(BridgeError::Other("Bridge already running".to_string()));
625        }
626
627        let (event_tx, event_rx) = mpsc::channel(100);
628        let (send_tx, send_rx) = mpsc::channel(100);
629        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
630
631        self.send_tx = Some(send_tx);
632        self.shutdown_tx = Some(shutdown_tx);
633
634        let running = self.running.clone();
635        let ws_config = self.ws_config.clone();
636
637        match ws_config.mode {
638            WsMode::Client => {
639                tokio::spawn(Self::run_client(
640                    ws_config.url,
641                    ws_config.format,
642                    ws_config.namespace,
643                    ws_config.auto_reconnect,
644                    ws_config.reconnect_delay_secs,
645                    ws_config.ping_interval_secs,
646                    event_tx,
647                    send_rx,
648                    shutdown_rx,
649                    running,
650                ));
651            }
652            WsMode::Server => {
653                let addr: SocketAddr = ws_config
654                    .url
655                    .parse()
656                    .map_err(|e| BridgeError::Other(format!("Invalid address: {}", e)))?;
657
658                tokio::spawn(Self::run_server(
659                    addr,
660                    ws_config.format,
661                    ws_config.namespace,
662                    ws_config.ping_interval_secs,
663                    event_tx,
664                    send_rx,
665                    shutdown_rx,
666                    running,
667                ));
668            }
669        }
670
671        info!("WebSocket bridge started in {:?} mode", self.ws_config.mode);
672        Ok(event_rx)
673    }
674
675    async fn stop(&mut self) -> Result<()> {
676        *self.running.lock() = false;
677        if let Some(tx) = self.shutdown_tx.take() {
678            let _ = tx.send(()).await;
679        }
680        self.send_tx = None;
681        info!("WebSocket bridge stopped");
682        Ok(())
683    }
684
685    async fn send(&self, msg: Message) -> Result<()> {
686        let send_tx = self
687            .send_tx
688            .as_ref()
689            .ok_or_else(|| BridgeError::Other("Not connected".to_string()))?;
690
691        if let Some(ws_msg) = Self::message_to_ws(&msg, self.ws_config.format) {
692            send_tx
693                .send(ws_msg)
694                .await
695                .map_err(|e| BridgeError::Other(format!("WebSocket send failed: {}", e)))?;
696        }
697
698        Ok(())
699    }
700
701    fn is_running(&self) -> bool {
702        *self.running.lock()
703    }
704
705    fn namespace(&self) -> &str {
706        &self.ws_config.namespace
707    }
708}
709
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwrap a parsed `SET`, failing the test for any other outcome.
    fn expect_set(parsed: Option<Message>) -> SetMessage {
        match parsed {
            Some(Message::Set(set)) => set,
            _ => panic!("expected a SET message"),
        }
    }

    #[test]
    fn test_config_default() {
        let config = WebSocketBridgeConfig::default();
        assert_eq!(config.mode, WsMode::Client);
        assert_eq!(config.namespace, "/ws");
        assert_eq!(config.format, WsMessageFormat::Json);
        assert_eq!(config.ping_interval_secs, 30);
        assert_eq!(config.reconnect_delay_secs, 5);
        assert!(config.auto_reconnect);
        assert!(config.path.is_none());
        assert!(config.headers.is_empty());
    }

    #[test]
    fn test_config_serde_defaults_match_default_impl() {
        let parsed: WebSocketBridgeConfig =
            serde_json::from_str(r#"{"url": "ws://localhost:8080"}"#)
                .expect("url-only config should parse");
        let default = WebSocketBridgeConfig::default();
        assert_eq!(parsed.mode, default.mode);
        assert_eq!(parsed.url, default.url);
        assert_eq!(parsed.format, default.format);
        assert_eq!(parsed.ping_interval_secs, default.ping_interval_secs);
        assert_eq!(parsed.auto_reconnect, default.auto_reconnect);
        assert_eq!(parsed.reconnect_delay_secs, default.reconnect_delay_secs);
        assert_eq!(parsed.namespace, default.namespace);
    }

    #[test]
    fn test_enum_wire_names() {
        let mode: WsMode = serde_json::from_str(r#""server""#).unwrap();
        assert_eq!(mode, WsMode::Server);
        let format: WsMessageFormat = serde_json::from_str(r#""msgpack""#).unwrap();
        assert_eq!(format, WsMessageFormat::MsgPack);
    }

    #[test]
    fn test_message_formats() {
        let prefix = "/ws";

        // JSON envelope: address and value come from the document.
        let ws_msg = WsMessage::Text(r#"{"address": "/test", "value": 42}"#.to_string());
        let set = expect_set(WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, prefix));
        assert_eq!(set.address, "/test");
        assert!(matches!(set.value, Value::Int(42)));

        // No address: `{prefix}/message`, with `data` used as the payload.
        let ws_msg = WsMessage::Text(r#"{"data": true}"#.to_string());
        let set = expect_set(WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, prefix));
        assert_eq!(set.address, "/ws/message");
        assert!(matches!(set.value, Value::Bool(true)));

        // Plain (non-JSON) text.
        let ws_msg = WsMessage::Text("hello".to_string());
        let set = expect_set(WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Json, prefix));
        assert_eq!(set.address, "/ws/text");
        assert!(matches!(set.value, Value::String(ref s) if s == "hello"));

        // MsgPack mode only accepts a full {address, value} envelope on text frames.
        let ws_msg = WsMessage::Text(r#"{"value": 1}"#.to_string());
        let set = expect_set(WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::MsgPack, prefix));
        assert_eq!(set.address, "/ws/text");

        // Binary frames in raw mode.
        let ws_msg = WsMessage::Binary(vec![1, 2, 3]);
        let set = expect_set(WebSocketBridge::parse_message(&ws_msg, WsMessageFormat::Raw, prefix));
        assert_eq!(set.address, "/ws/binary");
        assert!(matches!(set.value, Value::Bytes(ref b) if *b == vec![1u8, 2, 3]));

        // Control frames are not forwarded.
        assert!(
            WebSocketBridge::parse_message(&WsMessage::Ping(vec![]), WsMessageFormat::Json, prefix)
                .is_none()
        );
    }

    #[test]
    fn test_message_to_ws() {
        let msg = Message::Set(SetMessage {
            address: "/a".to_string(),
            value: Value::String("hi".to_string()),
            revision: None,
            lock: false,
            unlock: false,
        });

        // Raw mode sends strings as-is.
        assert_eq!(
            WebSocketBridge::message_to_ws(&msg, WsMessageFormat::Raw),
            Some(WsMessage::Text("hi".to_string()))
        );

        // JSON output carries the address back through parse_message.
        let ws = WebSocketBridge::message_to_ws(&msg, WsMessageFormat::Json)
            .expect("SET should be representable as JSON");
        let set = expect_set(WebSocketBridge::parse_message(&ws, WsMessageFormat::Json, "/ws"));
        assert_eq!(set.address, "/a");
    }
}