Skip to main content

stygian_graph/adapters/
websocket.rs

//! WebSocket stream source adapter.
//!
//! Implements [`StreamSourcePort`](crate::ports::stream_source::StreamSourcePort) and
//! [`ScrapingService`](crate::ports::ScrapingService) for consuming WebSocket feeds.
//! Uses `tokio-tungstenite` for the underlying connection.
//!
//! # Example
//!
//! ```no_run
//! use stygian_graph::adapters::websocket::WebSocketSource;
//! use stygian_graph::ports::stream_source::StreamSourcePort;
//!
//! # async fn example() {
//! let source = WebSocketSource::default();
//! let events = source.subscribe("wss://api.example.com/ws", Some(10)).await.unwrap();
//! println!("received {} events", events.len());
//! # }
//! ```
18
19use async_trait::async_trait;
20use futures::stream::StreamExt;
21use serde_json::json;
22use std::time::Duration;
23use tokio::time::timeout;
24use tokio_tungstenite::tungstenite::Message;
25use tokio_tungstenite::tungstenite::client::IntoClientRequest;
26
27use crate::domain::error::{Result, ServiceError, StygianError};
28use crate::ports::stream_source::{StreamEvent, StreamSourcePort};
29use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
30
31// ─── Configuration ────────────────────────────────────────────────────────────
32
/// Configuration for a WebSocket connection.
///
/// The `Debug` output redacts `bearer_token` so configurations can be logged
/// without leaking credentials.
///
/// # Example
///
/// ```
/// use stygian_graph::adapters::websocket::WebSocketConfig;
///
/// let config = WebSocketConfig {
///     subscribe_message: Some(r#"{"type":"subscribe","channel":"prices"}"#.into()),
///     bearer_token: None,
///     timeout_secs: 30,
///     max_reconnect_attempts: 3,
/// };
/// ```
#[derive(Clone, PartialEq, Eq)]
pub struct WebSocketConfig {
    /// Optional message to send immediately after connecting (e.g. subscribe).
    pub subscribe_message: Option<String>,
    /// Optional Bearer token for the `Authorization` header on the upgrade request.
    pub bearer_token: Option<String>,
    /// Timeout in seconds. Bounds both establishing the connection and each
    /// wait for the next message. When a wait for the next message times out,
    /// collection stops and the events gathered so far are returned.
    pub timeout_secs: u64,
    /// How many times `StreamSourcePort::subscribe` retries, with exponential
    /// backoff, after an attempt fails to connect or to send the subscribe
    /// message.
    ///
    /// A connection that drops mid-stream ends collection rather than
    /// triggering a retry. `ScrapingService::execute` makes a single attempt.
    pub max_reconnect_attempts: u32,
}

impl std::fmt::Debug for WebSocketConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WebSocketConfig")
            .field("subscribe_message", &self.subscribe_message)
            // Never print the credential itself; only whether one is set.
            .field(
                "bearer_token",
                &self.bearer_token.as_ref().map(|_| "<redacted>"),
            )
            .field("timeout_secs", &self.timeout_secs)
            .field("max_reconnect_attempts", &self.max_reconnect_attempts)
            .finish()
    }
}
58
59impl Default for WebSocketConfig {
60    fn default() -> Self {
61        Self {
62            subscribe_message: None,
63            bearer_token: None,
64            timeout_secs: 30,
65            max_reconnect_attempts: 3,
66        }
67    }
68}
69
70// ─── Adapter ──────────────────────────────────────────────────────────────────
71
72/// WebSocket stream source adapter.
73///
74/// Connects to a WebSocket endpoint and collects messages until `max_events`
75/// is reached, the stream closes, or a connection timeout occurs.
76#[derive(Default)]
77pub struct WebSocketSource {
78    config: WebSocketConfig,
79}
80
81impl WebSocketSource {
82    /// Create a new WebSocket source with custom configuration.
83    pub const fn new(config: WebSocketConfig) -> Self {
84        Self { config }
85    }
86
87    /// Extract configuration from `ServiceInput.params` overrides.
88    fn config_from_params(&self, params: &serde_json::Value) -> WebSocketConfig {
89        let mut cfg = self.config.clone();
90        if let Some(msg) = params.get("subscribe_message").and_then(|v| v.as_str()) {
91            cfg.subscribe_message = Some(msg.to_string());
92        }
93        if let Some(token) = params.get("bearer_token").and_then(|v| v.as_str()) {
94            cfg.bearer_token = Some(token.to_string());
95        }
96        if let Some(t) = params
97            .get("timeout_secs")
98            .and_then(serde_json::Value::as_u64)
99        {
100            cfg.timeout_secs = t;
101        }
102        if let Some(r) = params
103            .get("max_reconnect_attempts")
104            .and_then(serde_json::Value::as_u64)
105        {
106            cfg.max_reconnect_attempts = u32::try_from(r).unwrap_or(u32::MAX);
107        }
108        cfg
109    }
110
111    /// Connect and collect events from a WebSocket endpoint.
112    async fn collect_events(
113        &self,
114        url: &str,
115        max_events: Option<usize>,
116        cfg: &WebSocketConfig,
117    ) -> Result<Vec<StreamEvent>> {
118        let mut request = url.into_client_request().map_err(|e| {
119            StygianError::Service(ServiceError::Unavailable(format!(
120                "invalid WebSocket URL: {e}"
121            )))
122        })?;
123
124        // Inject auth header if configured
125        if let Some(token) = &cfg.bearer_token {
126            request.headers_mut().insert(
127                reqwest::header::AUTHORIZATION,
128                format!("Bearer {token}").parse().map_err(|e| {
129                    StygianError::Service(ServiceError::Unavailable(format!(
130                        "invalid auth header: {e}"
131                    )))
132                })?,
133            );
134        }
135
136        let connect_timeout = Duration::from_secs(cfg.timeout_secs);
137        let (ws_stream, _) = timeout(connect_timeout, tokio_tungstenite::connect_async(request))
138            .await
139            .map_err(|_| {
140                StygianError::Service(ServiceError::Unavailable(
141                    "WebSocket connection timed out".into(),
142                ))
143            })?
144            .map_err(|e| {
145                StygianError::Service(ServiceError::Unavailable(format!(
146                    "WebSocket connection failed: {e}"
147                )))
148            })?;
149
150        let (mut write, mut read) = ws_stream.split();
151
152        // Send subscribe message if configured
153        if let Some(ref sub_msg) = cfg.subscribe_message {
154            use futures::SinkExt;
155            write
156                .send(Message::Text(sub_msg.clone().into()))
157                .await
158                .map_err(|e| {
159                    StygianError::Service(ServiceError::Unavailable(format!(
160                        "failed to send subscribe message: {e}"
161                    )))
162                })?;
163        }
164
165        let mut events = Vec::new();
166        let mut frame_idx: u64 = 0;
167
168        while let Some(msg_result) = timeout(Duration::from_secs(cfg.timeout_secs), read.next())
169            .await
170            .ok()
171            .flatten()
172        {
173            match msg_result {
174                Ok(msg) => {
175                    if let Some(event) = map_message_to_event(msg, frame_idx) {
176                        events.push(event);
177                        frame_idx += 1;
178
179                        if let Some(max) = max_events
180                            && events.len() >= max
181                        {
182                            break;
183                        }
184                    }
185                }
186                Err(e) => {
187                    tracing::warn!("WebSocket receive error: {e}");
188                    break;
189                }
190            }
191        }
192
193        Ok(events)
194    }
195}
196
197/// Map a WebSocket message to a [`StreamEvent`].
198///
199/// Returns `None` for internal frames (Pong, Close, Frame).
200fn map_message_to_event(msg: Message, frame_idx: u64) -> Option<StreamEvent> {
201    match msg {
202        Message::Text(text) => Some(StreamEvent {
203            id: Some(frame_idx.to_string()),
204            event_type: Some("text".into()),
205            data: text.to_string(),
206        }),
207        Message::Binary(data) => {
208            use base64::Engine;
209            let encoded = base64::engine::general_purpose::STANDARD.encode(&data);
210            Some(StreamEvent {
211                id: Some(frame_idx.to_string()),
212                event_type: Some("binary".into()),
213                data: encoded,
214            })
215        }
216        Message::Ping(data) => Some(StreamEvent {
217            id: Some(frame_idx.to_string()),
218            event_type: Some("ping".into()),
219            data: String::from_utf8_lossy(&data).to_string(),
220        }),
221        // Pong, Close, and Frame are internal — skip
222        Message::Pong(_) | Message::Close(_) | Message::Frame(_) => None,
223    }
224}
225
226// ─── StreamSourcePort ─────────────────────────────────────────────────────────
227
228#[async_trait]
229impl StreamSourcePort for WebSocketSource {
230    async fn subscribe(&self, url: &str, max_events: Option<usize>) -> Result<Vec<StreamEvent>> {
231        let cfg = self.config.clone();
232        let mut last_err = None;
233
234        for attempt in 0..=cfg.max_reconnect_attempts {
235            match self.collect_events(url, max_events, &cfg).await {
236                Ok(events) => return Ok(events),
237                Err(e) => {
238                    tracing::warn!(
239                        "WebSocket attempt {}/{} failed: {e}",
240                        attempt + 1,
241                        cfg.max_reconnect_attempts + 1
242                    );
243                    last_err = Some(e);
244
245                    if attempt < cfg.max_reconnect_attempts {
246                        // Exponential backoff: 1s, 2s, 4s ...
247                        let backoff = Duration::from_secs(1 << attempt);
248                        tokio::time::sleep(backoff).await;
249                    }
250                }
251            }
252        }
253
254        Err(last_err.unwrap_or_else(|| {
255            StygianError::Service(ServiceError::Unavailable(
256                "WebSocket connection failed after all retries".into(),
257            ))
258        }))
259    }
260
261    fn source_name(&self) -> &'static str {
262        "websocket"
263    }
264}
265
266// ─── ScrapingService ──────────────────────────────────────────────────────────
267
268#[async_trait]
269impl ScrapingService for WebSocketSource {
270    /// Collect messages from a WebSocket and return as JSON array.
271    ///
272    /// # Params (optional)
273    ///
274    /// * `max_events` — integer; maximum messages to collect.
275    /// * `subscribe_message` — string; message to send on connect.
276    /// * `bearer_token` — string; Bearer token for auth header.
277    /// * `timeout_secs` — integer; connection/read timeout.
278    async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
279        let cfg = self.config_from_params(&input.params);
280        let max_events = input
281            .params
282            .get("max_events")
283            .and_then(serde_json::Value::as_u64)
284            .map(|n| usize::try_from(n).unwrap_or(usize::MAX));
285
286        let events = self.collect_events(&input.url, max_events, &cfg).await?;
287        let count = events.len();
288
289        let data = serde_json::to_string(&events).map_err(|e| {
290            StygianError::Service(ServiceError::InvalidResponse(format!(
291                "websocket serialization failed: {e}"
292            )))
293        })?;
294
295        Ok(ServiceOutput {
296            data,
297            metadata: json!({
298                "source": "websocket",
299                "event_count": count,
300                "source_url": input.url,
301            }),
302        })
303    }
304
305    fn name(&self) -> &'static str {
306        "websocket"
307    }
308}
309
310// ─── Tests ────────────────────────────────────────────────────────────────────
311
#[cfg(test)]
mod tests {
    use base64::Engine;

    use super::*;

    #[test]
    fn map_text_frame() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let msg = Message::Text(r#"{"price": 42.5}"#.into());
        let event =
            map_message_to_event(msg, 0).ok_or_else(|| std::io::Error::other("should map"))?;
        assert_eq!(event.id.as_deref(), Some("0"));
        assert_eq!(event.event_type.as_deref(), Some("text"));
        assert_eq!(event.data, r#"{"price": 42.5}"#);
        Ok(())
    }

    #[test]
    fn map_binary_frame_to_base64() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let data = vec![0xDE, 0xAD, 0xBE, 0xEF];
        let msg = Message::Binary(data.into());
        let event =
            map_message_to_event(msg, 1).ok_or_else(|| std::io::Error::other("should map"))?;
        assert_eq!(event.id.as_deref(), Some("1"));
        assert_eq!(event.event_type.as_deref(), Some("binary"));
        // Verify it's valid base64 that round-trips to the original bytes.
        let decoded = base64::engine::general_purpose::STANDARD.decode(&event.data)?;
        assert_eq!(decoded, vec![0xDE, 0xAD, 0xBE, 0xEF]);
        Ok(())
    }

    #[test]
    fn map_empty_binary_frame() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let msg = Message::Binary(Vec::<u8>::new().into());
        let event =
            map_message_to_event(msg, 3).ok_or_else(|| std::io::Error::other("should map"))?;
        assert_eq!(event.event_type.as_deref(), Some("binary"));
        assert_eq!(event.data, "");
        Ok(())
    }

    #[test]
    fn map_ping_frame() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let msg = Message::Ping(vec![1, 2, 3].into());
        let event =
            map_message_to_event(msg, 2).ok_or_else(|| std::io::Error::other("should map"))?;
        assert_eq!(event.id.as_deref(), Some("2"));
        assert_eq!(event.event_type.as_deref(), Some("ping"));
        // Ping payloads are decoded lossily as UTF-8.
        assert_eq!(event.data, "\u{1}\u{2}\u{3}");
        Ok(())
    }

    #[test]
    fn pong_frame_is_skipped() {
        let msg = Message::Pong(vec![].into());
        assert!(map_message_to_event(msg, 0).is_none());
    }

    #[test]
    fn close_frame_is_skipped() {
        let msg = Message::Close(None);
        assert!(map_message_to_event(msg, 0).is_none());
    }

    #[test]
    fn default_config() {
        let cfg = WebSocketConfig::default();
        assert_eq!(cfg.timeout_secs, 30);
        assert_eq!(cfg.max_reconnect_attempts, 3);
        assert!(cfg.subscribe_message.is_none());
        assert!(cfg.bearer_token.is_none());
    }

    #[test]
    fn config_from_params_overrides() {
        let source = WebSocketSource::default();
        let params = json!({
            "subscribe_message": "{\"action\":\"sub\"}",
            "bearer_token": "tok123",
            "timeout_secs": 60,
            "max_reconnect_attempts": 5
        });
        let cfg = source.config_from_params(&params);
        assert_eq!(
            cfg.subscribe_message.as_deref(),
            Some("{\"action\":\"sub\"}")
        );
        assert_eq!(cfg.bearer_token.as_deref(), Some("tok123"));
        assert_eq!(cfg.timeout_secs, 60);
        assert_eq!(cfg.max_reconnect_attempts, 5);
    }

    #[test]
    fn config_from_params_keeps_base_when_absent() {
        let source = WebSocketSource::default();
        let cfg = source.config_from_params(&json!({}));
        let base = WebSocketConfig::default();
        assert_eq!(cfg.subscribe_message, base.subscribe_message);
        assert_eq!(cfg.bearer_token, base.bearer_token);
        assert_eq!(cfg.timeout_secs, base.timeout_secs);
        assert_eq!(cfg.max_reconnect_attempts, base.max_reconnect_attempts);
    }

    #[test]
    fn config_from_params_ignores_mistyped_values() {
        let source = WebSocketSource::default();
        let params = json!({
            "subscribe_message": 42,
            "bearer_token": null,
            "timeout_secs": "60",
            "max_reconnect_attempts": -1
        });
        let cfg = source.config_from_params(&params);
        assert!(cfg.subscribe_message.is_none());
        assert!(cfg.bearer_token.is_none());
        assert_eq!(cfg.timeout_secs, 30);
        assert_eq!(cfg.max_reconnect_attempts, 3);
    }

    #[test]
    fn config_from_params_saturates_reconnect_attempts() {
        let source = WebSocketSource::default();
        let cfg =
            source.config_from_params(&json!({ "max_reconnect_attempts": 5_000_000_000_u64 }));
        assert_eq!(cfg.max_reconnect_attempts, u32::MAX);
    }

    #[test]
    fn config_from_params_layers_over_custom_base() {
        let source = WebSocketSource::new(WebSocketConfig {
            subscribe_message: Some("base-sub".into()),
            bearer_token: Some("base-token".into()),
            timeout_secs: 7,
            max_reconnect_attempts: 1,
        });
        let cfg = source.config_from_params(&json!({ "timeout_secs": 9 }));
        assert_eq!(cfg.subscribe_message.as_deref(), Some("base-sub"));
        assert_eq!(cfg.bearer_token.as_deref(), Some("base-token"));
        assert_eq!(cfg.timeout_secs, 9);
        assert_eq!(cfg.max_reconnect_attempts, 1);
    }

    #[test]
    fn adapter_reports_websocket_name() {
        let source = WebSocketSource::default();
        assert_eq!(StreamSourcePort::source_name(&source), "websocket");
        assert_eq!(ScrapingService::name(&source), "websocket");
    }

    #[test]
    fn frame_index_increments() {
        let msgs = vec![
            Message::Text("a".into()),
            Message::Pong(vec![].into()), // skipped
            Message::Text("b".into()),
        ];

        let mut idx: u64 = 0;
        let mut events = Vec::new();
        for msg in msgs {
            if let Some(event) = map_message_to_event(msg, idx) {
                events.push(event);
                idx += 1;
            }
        }

        assert_eq!(events.len(), 2);
        assert_eq!(events.first().and_then(|e| e.id.as_deref()), Some("0"));
        assert_eq!(events.get(1).and_then(|e| e.id.as_deref()), Some("1"));
    }

    // Integration tests require a running WebSocket server — marked #[ignore]
    #[tokio::test]
    #[ignore = "requires WebSocket echo server"]
    async fn connect_to_echo_server() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let source = WebSocketSource::new(WebSocketConfig {
            subscribe_message: Some("hello".into()),
            timeout_secs: 5,
            ..WebSocketConfig::default()
        });
        let events = source
            .subscribe("ws://127.0.0.1:9001/echo", Some(1))
            .await?;
        assert!(!events.is_empty());
        Ok(())
    }
}