1use std::collections::HashSet;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9use std::time::Duration;
10
11use futures::SinkExt;
12use futures::stream::Stream;
13use tokio::sync::{RwLock, broadcast, mpsc};
14use tokio::time::interval;
15use tokio_stream::wrappers::BroadcastStream;
16use tokio_tungstenite::{connect_async, tungstenite::Message};
17use tracing::{debug, error, info, warn};
18
19use super::pricing::{PriceUpdate, PricingData, PricingDecodeError};
20use crate::error::FinanceError;
21
/// Result type returned by the price-streaming API.
pub type StreamResult<T> = std::result::Result<T, StreamError>;

/// Errors produced by the price-streaming client.
///
/// Each variant carries a human-readable description of the underlying failure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamError {
    /// Establishing the WebSocket connection failed.
    ConnectionFailed(String),
    /// A read or write on an established WebSocket connection failed.
    WebSocketError(String),
    /// Intended for payload decoding failures (not constructed in this module).
    DecodeError(String),
}

impl std::fmt::Display for StreamError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (kind, detail) = match self {
            StreamError::ConnectionFailed(detail) => ("Connection failed", detail),
            StreamError::WebSocketError(detail) => ("WebSocket error", detail),
            StreamError::DecodeError(detail) => ("Decode error", detail),
        };
        write!(f, "{}: {}", kind, detail)
    }
}

impl std::error::Error for StreamError {}
47
48impl From<StreamError> for FinanceError {
49 fn from(e: StreamError) -> Self {
50 FinanceError::ResponseStructureError {
51 field: "streaming".to_string(),
52 context: e.to_string(),
53 }
54 }
55}
56
/// Yahoo Finance streaming endpoint (the URL requests `version=2` of the protocol).
const YAHOO_WS_URL: &str = "wss://streamer.finance.yahoo.com/?version=2";

/// Capacity of the broadcast channel that fans updates out to every `PriceStream`;
/// a receiver that falls this far behind starts losing updates.
const CHANNEL_CAPACITY: usize = 1024;

/// Default delay, in seconds, before reconnecting after a failed session.
const RECONNECT_BACKOFF_SECS: u64 = 3;

/// Seconds between re-sends of the full subscription list on an open connection.
const HEARTBEAT_INTERVAL_SECS: u64 = 15;
68
/// A stream of live [`PriceUpdate`]s from the Yahoo Finance WebSocket streamer.
///
/// Created by [`PriceStream::subscribe`] or [`PriceStreamBuilder::build`]. A background task
/// owns the WebSocket connection and publishes updates on a broadcast channel; each
/// `PriceStream` wraps one receiver of that channel. More receivers sharing the same
/// connection can be created with [`PriceStream::resubscribe`].
pub struct PriceStream {
    /// Receiver side of the broadcast channel fed by the background task.
    inner: BroadcastStream<PriceUpdate>,
    /// Handle shared by every stream on this connection: sends commands to the background
    /// task and creates new receivers.
    _handle: Arc<StreamHandle>,
}
98
/// State shared (via `Arc`) by all [`PriceStream`]s attached to one background connection task.
struct StreamHandle {
    /// Sends subscription changes and close requests to the background task.
    command_tx: mpsc::Sender<StreamCommand>,
    /// Sender side of the update channel, kept so `resubscribe` can create new receivers.
    /// NOTE(review): because this sender lives as long as any handle does, receivers never
    /// observe end-of-stream (`None`) after the background task exits.
    broadcast_tx: broadcast::Sender<PriceUpdate>,
}
104
/// Commands sent from [`PriceStream`] handles to the background connection task.
enum StreamCommand {
    /// Add these symbols to the subscription set.
    Subscribe(Vec<String>),
    /// Remove these symbols from the subscription set.
    Unsubscribe(Vec<String>),
    /// Close the WebSocket connection and stop the background task.
    Close,
}
111
112impl PriceStream {
113 pub async fn subscribe(symbols: &[&str]) -> StreamResult<Self> {
130 Self::subscribe_inner(symbols, Duration::from_secs(RECONNECT_BACKOFF_SECS)).await
131 }
132
133 async fn subscribe_inner(symbols: &[&str], retry_delay: Duration) -> StreamResult<Self> {
134 let (broadcast_tx, broadcast_rx) = broadcast::channel(CHANNEL_CAPACITY);
135 let (command_tx, command_rx) = mpsc::channel(32);
136
137 let initial_symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
138
139 let tx_clone = broadcast_tx.clone();
140
141 tokio::spawn(async move {
143 if let Err(e) =
144 run_websocket_loop(initial_symbols, broadcast_tx, command_rx, retry_delay).await
145 {
146 error!("WebSocket loop error: {}", e);
147 }
148 });
149
150 let handle = Arc::new(StreamHandle {
151 command_tx,
152 broadcast_tx: tx_clone,
153 });
154
155 Ok(PriceStream {
156 inner: BroadcastStream::new(broadcast_rx),
157 _handle: handle,
158 })
159 }
160
161 pub fn resubscribe(&self) -> Self {
165 PriceStream {
166 inner: BroadcastStream::new(self._handle.broadcast_tx.subscribe()),
167 _handle: Arc::clone(&self._handle),
168 }
169 }
170
171 pub async fn add_symbols(&self, symbols: &[&str]) {
185 let symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
186 let _ = self
187 ._handle
188 .command_tx
189 .send(StreamCommand::Subscribe(symbols))
190 .await;
191 }
192
193 pub async fn remove_symbols(&self, symbols: &[&str]) {
207 let symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
208 let _ = self
209 ._handle
210 .command_tx
211 .send(StreamCommand::Unsubscribe(symbols))
212 .await;
213 }
214
215 pub async fn close(&self) {
217 let _ = self._handle.command_tx.send(StreamCommand::Close).await;
218 }
219}
220
221impl Stream for PriceStream {
222 type Item = PriceUpdate;
223
224 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
225 match Pin::new(&mut self.inner).poll_next(cx) {
226 Poll::Ready(Some(Ok(data))) => Poll::Ready(Some(data)),
227 Poll::Ready(Some(Err(e))) => {
228 warn!("Broadcast error: {:?}", e);
229 cx.waker().wake_by_ref();
231 Poll::Pending
232 }
233 Poll::Ready(None) => Poll::Ready(None),
234 Poll::Pending => Poll::Pending,
235 }
236 }
237}
238
239async fn run_websocket_loop(
241 initial_symbols: Vec<String>,
242 broadcast_tx: broadcast::Sender<PriceUpdate>,
243 mut command_rx: mpsc::Receiver<StreamCommand>,
244 retry_delay: Duration,
245) -> StreamResult<()> {
246 let subscriptions = Arc::new(RwLock::new(HashSet::<String>::from_iter(initial_symbols)));
247
248 loop {
249 match connect_and_stream(&subscriptions, &broadcast_tx, &mut command_rx).await {
250 Ok(()) => {
251 info!("WebSocket connection closed gracefully");
252 break;
253 }
254 Err(e) => {
255 error!(
256 "WebSocket error: {}, reconnecting in {:.1}s...",
257 e,
258 retry_delay.as_secs_f32()
259 );
260 tokio::time::sleep(retry_delay).await;
261 }
262 }
263 }
264
265 Ok(())
266}
267
268async fn connect_and_stream(
270 subscriptions: &Arc<RwLock<HashSet<String>>>,
271 broadcast_tx: &broadcast::Sender<PriceUpdate>,
272 command_rx: &mut mpsc::Receiver<StreamCommand>,
273) -> StreamResult<()> {
274 use futures::StreamExt;
275
276 info!("Connecting to Yahoo Finance WebSocket...");
277
278 let (ws_stream, _) = connect_async(YAHOO_WS_URL)
279 .await
280 .map_err(|e| StreamError::ConnectionFailed(e.to_string()))?;
281
282 info!("Connected to Yahoo Finance WebSocket");
283
284 let (mut write, mut read) = ws_stream.split();
285
286 {
288 let subs = subscriptions.read().await;
289 if !subs.is_empty() {
290 let symbols: Vec<&str> = subs.iter().map(|s| s.as_str()).collect();
291 let msg = serde_json::json!({ "subscribe": symbols });
292 write
293 .send(Message::Text(msg.to_string().into()))
294 .await
295 .map_err(|e| StreamError::WebSocketError(e.to_string()))?;
296 info!("Subscribed to {} symbols", symbols.len());
297 }
298 }
299
300 let heartbeat_subs = Arc::clone(subscriptions);
302 let (heartbeat_tx, mut heartbeat_rx) = mpsc::channel::<Message>(32);
303
304 tokio::spawn(async move {
305 let mut interval = interval(Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
306 loop {
307 interval.tick().await;
308 let subs = heartbeat_subs.read().await;
309 if !subs.is_empty() {
310 let symbols: Vec<&str> = subs.iter().map(|s| s.as_str()).collect();
311 let msg = serde_json::json!({ "subscribe": symbols });
312 if heartbeat_tx
313 .send(Message::Text(msg.to_string().into()))
314 .await
315 .is_err()
316 {
317 break;
318 }
319 debug!("Heartbeat subscription sent for {} symbols", symbols.len());
320 }
321 }
322 });
323
324 loop {
325 tokio::select! {
326 Some(msg) = read.next() => {
328 match msg {
329 Ok(Message::Text(text)) => {
330 if let Err(e) = handle_text_message(&text, broadcast_tx) {
331 warn!("Failed to handle message: {}", e);
332 }
333 }
334 Ok(Message::Binary(data)) => {
335 debug!("Received binary message: {} bytes", data.len());
336 }
337 Ok(Message::Close(_)) => {
338 info!("Received close frame");
339 break;
340 }
341 Ok(Message::Ping(data)) => {
342 let _ = write.send(Message::Pong(data)).await;
343 }
344 Ok(_) => {}
345 Err(e) => {
346 error!("WebSocket read error: {}", e);
347 return Err(StreamError::WebSocketError(e.to_string()));
348 }
349 }
350 }
351
352 Some(msg) = heartbeat_rx.recv() => {
354 if let Err(e) = write.send(msg).await {
355 error!("Failed to send heartbeat: {}", e);
356 return Err(StreamError::WebSocketError(e.to_string()));
357 }
358 }
359
360 Some(cmd) = command_rx.recv() => {
362 match cmd {
363 StreamCommand::Subscribe(symbols) => {
364 let mut newly_added = Vec::new();
365 {
366 let mut subs = subscriptions.write().await;
367 for s in &symbols {
368 if subs.insert(s.clone()) {
369 newly_added.push(s.clone());
370 }
371 }
372 }
373 if !newly_added.is_empty() {
374 let msg = serde_json::json!({ "subscribe": newly_added });
375 let _ = write.send(Message::Text(msg.to_string().into())).await;
376 info!("Added subscriptions: {:?}", newly_added);
377 }
378 }
379 StreamCommand::Unsubscribe(symbols) => {
380 let mut actually_removed = Vec::new();
381 {
382 let mut subs = subscriptions.write().await;
383 for s in &symbols {
384 if subs.remove(s) {
385 actually_removed.push(s.clone());
386 }
387 }
388 }
389 if !actually_removed.is_empty() {
390 let msg = serde_json::json!({ "unsubscribe": actually_removed });
391 let _ = write.send(Message::Text(msg.to_string().into())).await;
392 info!("Removed subscriptions: {:?}", actually_removed);
393 }
394 }
395 StreamCommand::Close => {
396 info!("Received close command");
397 let _ = write.send(Message::Close(None)).await;
398 return Ok(());
399 }
400 }
401 }
402
403 else => break,
404 }
405 }
406
407 Ok(())
408}
409
410fn handle_text_message(
412 text: &str,
413 broadcast_tx: &broadcast::Sender<PriceUpdate>,
414) -> std::result::Result<(), PricingDecodeError> {
415 let json: serde_json::Value =
417 serde_json::from_str(text).map_err(|e| PricingDecodeError::Base64(e.to_string()))?;
418
419 if let Some(encoded) = json.get("message").and_then(|v| v.as_str()) {
420 let pricing_data = PricingData::from_base64(encoded)?;
421 let price_update: PriceUpdate = pricing_data.into();
422
423 if broadcast_tx.receiver_count() > 0 {
425 let _ = broadcast_tx.send(price_update);
426 }
427 }
428
429 Ok(())
430}
431
/// Builder for a [`PriceStream`] with a custom initial symbol list and retry delay.
#[derive(Debug, Clone)]
pub struct PriceStreamBuilder {
    /// Symbols subscribed to when the stream starts.
    symbols: Vec<String>,
    /// Delay before each reconnection attempt after a failed session.
    retry_delay: Duration,
}
437
438impl PriceStreamBuilder {
439 pub fn new() -> Self {
441 Self {
442 symbols: Vec::new(),
443 retry_delay: Duration::from_secs(RECONNECT_BACKOFF_SECS),
444 }
445 }
446
447 pub fn symbols(mut self, symbols: &[&str]) -> Self {
449 self.symbols.extend(symbols.iter().map(|s| s.to_string()));
450 self
451 }
452
453 pub fn retry(mut self, delay: Duration) -> Self {
455 self.retry_delay = delay;
456 self
457 }
458
459 pub async fn build(self) -> StreamResult<PriceStream> {
461 let symbol_refs: Vec<&str> = self.symbols.iter().map(|s| s.as_str()).collect();
462 PriceStream::subscribe_inner(&symbol_refs, self.retry_delay).await
463 }
464}
465
466impl Default for PriceStreamBuilder {
467 fn default() -> Self {
468 Self::new()
469 }
470}