1use async_trait::async_trait;
20use futures::stream::StreamExt;
21use serde_json::json;
22use std::time::Duration;
23use tokio::time::timeout;
24use tokio_tungstenite::tungstenite::Message;
25use tokio_tungstenite::tungstenite::client::IntoClientRequest;
26
27use crate::domain::error::{Result, ServiceError, StygianError};
28use crate::ports::stream_source::{StreamEvent, StreamSourcePort};
29use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
30
/// Connection settings for [`WebSocketSource`].
///
/// `Debug` is implemented by hand (not derived) so that `bearer_token` is
/// never written verbatim into logs or panic messages.
#[derive(Clone)]
pub struct WebSocketConfig {
    /// Text frame sent once, right after the connection is established.
    pub subscribe_message: Option<String>,
    /// Token sent as `Authorization: Bearer <token>` during the handshake.
    pub bearer_token: Option<String>,
    /// Timeout in seconds for the connection handshake and for each wait
    /// on the next incoming frame.
    pub timeout_secs: u64,
    /// How many additional connection attempts to make after a failure.
    pub max_reconnect_attempts: u32,
}

impl std::fmt::Debug for WebSocketConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WebSocketConfig")
            .field("subscribe_message", &self.subscribe_message)
            // Only reveal whether a token is set, never its value.
            .field(
                "bearer_token",
                &self.bearer_token.as_ref().map(|_| "<redacted>"),
            )
            .field("timeout_secs", &self.timeout_secs)
            .field("max_reconnect_attempts", &self.max_reconnect_attempts)
            .finish()
    }
}
58
59impl Default for WebSocketConfig {
60 fn default() -> Self {
61 Self {
62 subscribe_message: None,
63 bearer_token: None,
64 timeout_secs: 30,
65 max_reconnect_attempts: 3,
66 }
67 }
68}
69
70#[derive(Default)]
77pub struct WebSocketSource {
78 config: WebSocketConfig,
79}
80
81impl WebSocketSource {
82 pub fn new(config: WebSocketConfig) -> Self {
84 Self { config }
85 }
86
87 fn config_from_params(&self, params: &serde_json::Value) -> WebSocketConfig {
89 let mut cfg = self.config.clone();
90 if let Some(msg) = params.get("subscribe_message").and_then(|v| v.as_str()) {
91 cfg.subscribe_message = Some(msg.to_string());
92 }
93 if let Some(token) = params.get("bearer_token").and_then(|v| v.as_str()) {
94 cfg.bearer_token = Some(token.to_string());
95 }
96 if let Some(t) = params.get("timeout_secs").and_then(|v| v.as_u64()) {
97 cfg.timeout_secs = t;
98 }
99 if let Some(r) = params
100 .get("max_reconnect_attempts")
101 .and_then(|v| v.as_u64())
102 {
103 cfg.max_reconnect_attempts = r as u32;
104 }
105 cfg
106 }
107
108 async fn collect_events(
110 &self,
111 url: &str,
112 max_events: Option<usize>,
113 cfg: &WebSocketConfig,
114 ) -> Result<Vec<StreamEvent>> {
115 let mut request = url.into_client_request().map_err(|e| {
116 StygianError::Service(ServiceError::Unavailable(format!(
117 "invalid WebSocket URL: {e}"
118 )))
119 })?;
120
121 if let Some(token) = &cfg.bearer_token {
123 request.headers_mut().insert(
124 reqwest::header::AUTHORIZATION,
125 format!("Bearer {token}").parse().map_err(|e| {
126 StygianError::Service(ServiceError::Unavailable(format!(
127 "invalid auth header: {e}"
128 )))
129 })?,
130 );
131 }
132
133 let connect_timeout = Duration::from_secs(cfg.timeout_secs);
134 let (ws_stream, _) = timeout(connect_timeout, tokio_tungstenite::connect_async(request))
135 .await
136 .map_err(|_| {
137 StygianError::Service(ServiceError::Unavailable(
138 "WebSocket connection timed out".into(),
139 ))
140 })?
141 .map_err(|e| {
142 StygianError::Service(ServiceError::Unavailable(format!(
143 "WebSocket connection failed: {e}"
144 )))
145 })?;
146
147 let (mut write, mut read) = ws_stream.split();
148
149 if let Some(ref sub_msg) = cfg.subscribe_message {
151 use futures::SinkExt;
152 write
153 .send(Message::Text(sub_msg.clone().into()))
154 .await
155 .map_err(|e| {
156 StygianError::Service(ServiceError::Unavailable(format!(
157 "failed to send subscribe message: {e}"
158 )))
159 })?;
160 }
161
162 let mut events = Vec::new();
163 let mut frame_idx: u64 = 0;
164
165 while let Some(msg_result) = timeout(Duration::from_secs(cfg.timeout_secs), read.next())
166 .await
167 .ok()
168 .flatten()
169 {
170 match msg_result {
171 Ok(msg) => {
172 if let Some(event) = map_message_to_event(msg, frame_idx) {
173 events.push(event);
174 frame_idx += 1;
175
176 if let Some(max) = max_events
177 && events.len() >= max
178 {
179 break;
180 }
181 }
182 }
183 Err(e) => {
184 tracing::warn!("WebSocket receive error: {e}");
185 break;
186 }
187 }
188 }
189
190 Ok(events)
191 }
192}
193
194fn map_message_to_event(msg: Message, frame_idx: u64) -> Option<StreamEvent> {
198 match msg {
199 Message::Text(text) => Some(StreamEvent {
200 id: Some(frame_idx.to_string()),
201 event_type: Some("text".into()),
202 data: text.to_string(),
203 }),
204 Message::Binary(data) => {
205 use base64::Engine;
206 let encoded = base64::engine::general_purpose::STANDARD.encode(&data);
207 Some(StreamEvent {
208 id: Some(frame_idx.to_string()),
209 event_type: Some("binary".into()),
210 data: encoded,
211 })
212 }
213 Message::Ping(data) => Some(StreamEvent {
214 id: Some(frame_idx.to_string()),
215 event_type: Some("ping".into()),
216 data: String::from_utf8_lossy(&data).to_string(),
217 }),
218 Message::Pong(_) | Message::Close(_) | Message::Frame(_) => None,
220 }
221}
222
223#[async_trait]
226impl StreamSourcePort for WebSocketSource {
227 async fn subscribe(&self, url: &str, max_events: Option<usize>) -> Result<Vec<StreamEvent>> {
228 let cfg = self.config.clone();
229 let mut last_err = None;
230
231 for attempt in 0..=cfg.max_reconnect_attempts {
232 match self.collect_events(url, max_events, &cfg).await {
233 Ok(events) => return Ok(events),
234 Err(e) => {
235 tracing::warn!(
236 "WebSocket attempt {}/{} failed: {e}",
237 attempt + 1,
238 cfg.max_reconnect_attempts + 1
239 );
240 last_err = Some(e);
241
242 if attempt < cfg.max_reconnect_attempts {
243 let backoff = Duration::from_secs(1 << attempt);
245 tokio::time::sleep(backoff).await;
246 }
247 }
248 }
249 }
250
251 Err(last_err.unwrap_or_else(|| {
252 StygianError::Service(ServiceError::Unavailable(
253 "WebSocket connection failed after all retries".into(),
254 ))
255 }))
256 }
257
258 fn source_name(&self) -> &str {
259 "websocket"
260 }
261}
262
263#[async_trait]
266impl ScrapingService for WebSocketSource {
267 async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
276 let cfg = self.config_from_params(&input.params);
277 let max_events = input
278 .params
279 .get("max_events")
280 .and_then(|v| v.as_u64())
281 .map(|n| n as usize);
282
283 let events = self.collect_events(&input.url, max_events, &cfg).await?;
284 let count = events.len();
285
286 let data = serde_json::to_string(&events).map_err(|e| {
287 StygianError::Service(ServiceError::InvalidResponse(format!(
288 "websocket serialization failed: {e}"
289 )))
290 })?;
291
292 Ok(ServiceOutput {
293 data,
294 metadata: json!({
295 "source": "websocket",
296 "event_count": count,
297 "source_url": input.url,
298 }),
299 })
300 }
301
302 fn name(&self) -> &'static str {
303 "websocket"
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn map_text_frame() {
        let msg = Message::Text(r#"{"price": 42.5}"#.into());
        let event = map_message_to_event(msg, 0).expect("should map");
        assert_eq!(event.id.as_deref(), Some("0"));
        assert_eq!(event.event_type.as_deref(), Some("text"));
        assert_eq!(event.data, r#"{"price": 42.5}"#);
    }

    #[test]
    fn map_binary_frame_to_base64() {
        use base64::Engine;

        let data = vec![0xDE, 0xAD, 0xBE, 0xEF];
        let msg = Message::Binary(data.into());
        let event = map_message_to_event(msg, 1).expect("should map");
        assert_eq!(event.id.as_deref(), Some("1"));
        assert_eq!(event.event_type.as_deref(), Some("binary"));
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(&event.data)
            .expect("valid base64");
        assert_eq!(decoded, vec![0xDE, 0xAD, 0xBE, 0xEF]);
    }

    #[test]
    fn map_ping_frame() {
        let msg = Message::Ping(vec![1, 2, 3].into());
        let event = map_message_to_event(msg, 2).expect("should map");
        assert_eq!(event.id.as_deref(), Some("2"));
        assert_eq!(event.event_type.as_deref(), Some("ping"));
        // The payload is decoded lossily as UTF-8; bytes 1..=3 are valid UTF-8.
        assert_eq!(event.data, "\u{1}\u{2}\u{3}");
    }

    #[test]
    fn pong_frame_is_skipped() {
        let msg = Message::Pong(vec![].into());
        assert!(map_message_to_event(msg, 0).is_none());
    }

    #[test]
    fn close_frame_is_skipped() {
        let msg = Message::Close(None);
        assert!(map_message_to_event(msg, 0).is_none());
    }

    #[test]
    fn default_config() {
        let cfg = WebSocketConfig::default();
        assert_eq!(cfg.timeout_secs, 30);
        assert_eq!(cfg.max_reconnect_attempts, 3);
        assert!(cfg.subscribe_message.is_none());
        assert!(cfg.bearer_token.is_none());
    }

    #[test]
    fn config_from_params_overrides() {
        let source = WebSocketSource::default();
        let params = json!({
            "subscribe_message": "{\"action\":\"sub\"}",
            "bearer_token": "tok123",
            "timeout_secs": 60,
            "max_reconnect_attempts": 5
        });
        let cfg = source.config_from_params(&params);
        assert_eq!(
            cfg.subscribe_message.as_deref(),
            Some("{\"action\":\"sub\"}")
        );
        assert_eq!(cfg.bearer_token.as_deref(), Some("tok123"));
        assert_eq!(cfg.timeout_secs, 60);
        assert_eq!(cfg.max_reconnect_attempts, 5);
    }

    #[test]
    fn config_from_params_ignores_missing_and_mistyped_keys() {
        let source = WebSocketSource::default();
        let params = json!({ "timeout_secs": "not a number" });
        let cfg = source.config_from_params(&params);
        assert_eq!(cfg.timeout_secs, 30);
        assert_eq!(cfg.max_reconnect_attempts, 3);
        assert!(cfg.subscribe_message.is_none());
        assert!(cfg.bearer_token.is_none());
    }

    #[test]
    fn frame_index_increments() {
        let msgs = vec![
            Message::Text("a".into()),
            Message::Pong(vec![].into()),
            Message::Text("b".into()),
        ];

        let mut idx: u64 = 0;
        let mut events = Vec::new();
        for msg in msgs {
            if let Some(event) = map_message_to_event(msg, idx) {
                events.push(event);
                idx += 1;
            }
        }

        assert_eq!(events.len(), 2);
        assert_eq!(events[0].id.as_deref(), Some("0"));
        assert_eq!(events[1].id.as_deref(), Some("1"));
    }

    #[tokio::test]
    #[ignore = "requires WebSocket echo server"]
    async fn connect_to_echo_server() {
        let source = WebSocketSource::new(WebSocketConfig {
            subscribe_message: Some("hello".into()),
            timeout_secs: 5,
            ..WebSocketConfig::default()
        });
        let events = source
            .subscribe("ws://127.0.0.1:9001/echo", Some(1))
            .await
            .expect("connect");
        assert!(!events.is_empty());
    }
}