finance_query/streaming/
client.rs

//! WebSocket client for Yahoo Finance real-time streaming.
//!
//! Provides a `Stream`-based API for receiving real-time price updates.
4
5use std::collections::HashSet;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9use std::time::Duration;
10
11use futures::SinkExt;
12use futures::stream::Stream;
13use tokio::sync::{RwLock, broadcast, mpsc};
14use tokio::time::interval;
15use tokio_stream::wrappers::BroadcastStream;
16use tokio_tungstenite::{connect_async, tungstenite::Message};
17use tracing::{debug, error, info, warn};
18
19use super::pricing::{PriceUpdate, PricingData, PricingDecodeError};
20use crate::error::YahooError;
21
/// Convenience alias: the result of a streaming operation, failing with
/// [`StreamError`].
pub type StreamResult<T> = std::result::Result<T, StreamError>;

/// Failure modes of the streaming client.
///
/// Each variant carries a human-readable description of the underlying
/// failure as a `String`.
#[derive(Debug, Clone)]
pub enum StreamError {
    /// The WebSocket connection could not be established.
    ConnectionFailed(String),
    /// Sending or receiving on the WebSocket failed.
    WebSocketError(String),
    /// An incoming message could not be decoded.
    DecodeError(String),
}

impl std::fmt::Display for StreamError {
    /// Renders the error as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ConnectionFailed(detail) => write!(f, "Connection failed: {}", detail),
            Self::WebSocketError(detail) => write!(f, "WebSocket error: {}", detail),
            Self::DecodeError(detail) => write!(f, "Decode error: {}", detail),
        }
    }
}

impl std::error::Error for StreamError {}
47
48impl From<StreamError> for YahooError {
49    fn from(e: StreamError) -> Self {
50        YahooError::ResponseStructureError {
51            field: "streaming".to_string(),
52            context: e.to_string(),
53        }
54    }
55}
56
/// Endpoint of the Yahoo Finance streaming WebSocket.
const YAHOO_WS_URL: &str = "wss://streamer.finance.yahoo.com/?version=2";

/// Seconds between heartbeat messages that re-send the current subscription.
const HEARTBEAT_INTERVAL_SECS: u64 = 15;

/// Seconds to wait after a connection error before reconnecting.
const RECONNECT_BACKOFF_SECS: u64 = 3;

/// Capacity of the broadcast channel that carries price updates to consumers.
const CHANNEL_CAPACITY: usize = 1024;
68
69/// A streaming price subscription that yields real-time price updates.
70///
71/// This provides a Flow-like API for receiving real-time price data from Yahoo Finance.
72///
73/// # Example
74///
75/// ```no_run
76/// use finance_query::streaming::PriceStream;
77/// use futures::StreamExt;
78///
79/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
80/// // Subscribe to multiple symbols
81/// let mut stream = PriceStream::subscribe(&["AAPL", "NVDA", "TSLA"]).await?;
82///
83/// // Receive price updates
84/// while let Some(price) = stream.next().await {
85///     println!("{}: ${:.2} ({:+.2}%)",
86///         price.id,
87///         price.price,
88///         price.change_percent
89///     );
90/// }
91/// # Ok(())
92/// # }
93/// ```
94pub struct PriceStream {
95    inner: BroadcastStream<PriceUpdate>,
96    _handle: Arc<StreamHandle>,
97}
98
99/// Handle to manage the WebSocket connection
100struct StreamHandle {
101    command_tx: mpsc::Sender<StreamCommand>,
102    broadcast_tx: broadcast::Sender<PriceUpdate>,
103}
104
/// Control messages understood by the background WebSocket task.
enum StreamCommand {
    /// Start streaming the given ticker symbols.
    Subscribe(Vec<String>),
    /// Stop streaming the given ticker symbols.
    Unsubscribe(Vec<String>),
    /// Send a close frame and stop the task.
    Close,
}
111
112impl PriceStream {
113    /// Subscribe to real-time price updates for the given symbols.
114    ///
115    /// # Arguments
116    ///
117    /// * `symbols` - Ticker symbols to subscribe to (e.g., `["AAPL", "NVDA"]`)
118    ///
119    /// # Example
120    ///
121    /// ```no_run
122    /// use finance_query::streaming::PriceStream;
123    ///
124    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
125    /// let stream = PriceStream::subscribe(&["AAPL", "GOOGL"]).await?;
126    /// # Ok(())
127    /// # }
128    /// ```
129    pub async fn subscribe(symbols: &[&str]) -> StreamResult<Self> {
130        let (broadcast_tx, broadcast_rx) = broadcast::channel(CHANNEL_CAPACITY);
131        let (command_tx, command_rx) = mpsc::channel(32);
132
133        let symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
134        let initial_symbols = symbols.clone();
135
136        let tx_clone = broadcast_tx.clone();
137
138        // Spawn the WebSocket task
139        tokio::spawn(async move {
140            if let Err(e) = run_websocket_loop(initial_symbols, broadcast_tx, command_rx).await {
141                error!("WebSocket loop error: {}", e);
142            }
143        });
144
145        let handle = Arc::new(StreamHandle {
146            command_tx,
147            broadcast_tx: tx_clone,
148        });
149
150        Ok(PriceStream {
151            inner: BroadcastStream::new(broadcast_rx),
152            _handle: handle,
153        })
154    }
155
156    /// Create a new receiver for this stream.
157    ///
158    /// Useful when you need multiple consumers of the same price data.
159    pub fn resubscribe(&self) -> Self {
160        PriceStream {
161            inner: BroadcastStream::new(self._handle.broadcast_tx.subscribe()),
162            _handle: Arc::clone(&self._handle),
163        }
164    }
165
166    /// Add more symbols to the subscription.
167    ///
168    /// # Example
169    ///
170    /// ```no_run
171    /// use finance_query::streaming::PriceStream;
172    ///
173    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
174    /// let stream = PriceStream::subscribe(&["AAPL"]).await?;
175    /// stream.add_symbols(&["NVDA", "TSLA"]).await;
176    /// # Ok(())
177    /// # }
178    /// ```
179    pub async fn add_symbols(&self, symbols: &[&str]) {
180        let symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
181        let _ = self
182            ._handle
183            .command_tx
184            .send(StreamCommand::Subscribe(symbols))
185            .await;
186    }
187
188    /// Remove symbols from the subscription.
189    ///
190    /// # Example
191    ///
192    /// ```no_run
193    /// use finance_query::streaming::PriceStream;
194    ///
195    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
196    /// let stream = PriceStream::subscribe(&["AAPL", "NVDA"]).await?;
197    /// stream.remove_symbols(&["NVDA"]).await;
198    /// # Ok(())
199    /// # }
200    /// ```
201    pub async fn remove_symbols(&self, symbols: &[&str]) {
202        let symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
203        let _ = self
204            ._handle
205            .command_tx
206            .send(StreamCommand::Unsubscribe(symbols))
207            .await;
208    }
209
210    /// Close the stream and disconnect from the WebSocket.
211    pub async fn close(&self) {
212        let _ = self._handle.command_tx.send(StreamCommand::Close).await;
213    }
214}
215
216impl Stream for PriceStream {
217    type Item = PriceUpdate;
218
219    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
220        match Pin::new(&mut self.inner).poll_next(cx) {
221            Poll::Ready(Some(Ok(data))) => Poll::Ready(Some(data)),
222            Poll::Ready(Some(Err(e))) => {
223                warn!("Broadcast error: {:?}", e);
224                // Try again on lag
225                cx.waker().wake_by_ref();
226                Poll::Pending
227            }
228            Poll::Ready(None) => Poll::Ready(None),
229            Poll::Pending => Poll::Pending,
230        }
231    }
232}
233
234/// Run the WebSocket connection loop with automatic reconnection
235async fn run_websocket_loop(
236    initial_symbols: Vec<String>,
237    broadcast_tx: broadcast::Sender<PriceUpdate>,
238    mut command_rx: mpsc::Receiver<StreamCommand>,
239) -> StreamResult<()> {
240    let subscriptions = Arc::new(RwLock::new(HashSet::<String>::from_iter(initial_symbols)));
241
242    loop {
243        match connect_and_stream(&subscriptions, &broadcast_tx, &mut command_rx).await {
244            Ok(()) => {
245                info!("WebSocket connection closed gracefully");
246                break;
247            }
248            Err(e) => {
249                error!(
250                    "WebSocket error: {}, reconnecting in {}s...",
251                    e, RECONNECT_BACKOFF_SECS
252                );
253                tokio::time::sleep(Duration::from_secs(RECONNECT_BACKOFF_SECS)).await;
254            }
255        }
256    }
257
258    Ok(())
259}
260
261/// Connect to Yahoo WebSocket and stream data
262async fn connect_and_stream(
263    subscriptions: &Arc<RwLock<HashSet<String>>>,
264    broadcast_tx: &broadcast::Sender<PriceUpdate>,
265    command_rx: &mut mpsc::Receiver<StreamCommand>,
266) -> StreamResult<()> {
267    use futures::StreamExt;
268
269    info!("Connecting to Yahoo Finance WebSocket...");
270
271    let (ws_stream, _) = connect_async(YAHOO_WS_URL)
272        .await
273        .map_err(|e| StreamError::ConnectionFailed(e.to_string()))?;
274
275    info!("Connected to Yahoo Finance WebSocket");
276
277    let (mut write, mut read) = ws_stream.split();
278
279    // Send initial subscriptions
280    {
281        let subs = subscriptions.read().await;
282        if !subs.is_empty() {
283            let symbols: Vec<&str> = subs.iter().map(|s| s.as_str()).collect();
284            let msg = serde_json::json!({ "subscribe": symbols });
285            write
286                .send(Message::Text(msg.to_string().into()))
287                .await
288                .map_err(|e| StreamError::WebSocketError(e.to_string()))?;
289            info!("Subscribed to {} symbols", symbols.len());
290        }
291    }
292
293    // Heartbeat task - sends subscription refresh every 15 seconds
294    let heartbeat_subs = Arc::clone(subscriptions);
295    let (heartbeat_tx, mut heartbeat_rx) = mpsc::channel::<Message>(32);
296
297    tokio::spawn(async move {
298        let mut interval = interval(Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
299        loop {
300            interval.tick().await;
301            let subs = heartbeat_subs.read().await;
302            if !subs.is_empty() {
303                let symbols: Vec<&str> = subs.iter().map(|s| s.as_str()).collect();
304                let msg = serde_json::json!({ "subscribe": symbols });
305                if heartbeat_tx
306                    .send(Message::Text(msg.to_string().into()))
307                    .await
308                    .is_err()
309                {
310                    break;
311                }
312                debug!("Heartbeat subscription sent for {} symbols", symbols.len());
313            }
314        }
315    });
316
317    loop {
318        tokio::select! {
319            // Handle incoming WebSocket messages
320            Some(msg) = read.next() => {
321                match msg {
322                    Ok(Message::Text(text)) => {
323                        if let Err(e) = handle_text_message(&text, broadcast_tx) {
324                            warn!("Failed to handle message: {}", e);
325                        }
326                    }
327                    Ok(Message::Binary(data)) => {
328                        debug!("Received binary message: {} bytes", data.len());
329                    }
330                    Ok(Message::Close(_)) => {
331                        info!("Received close frame");
332                        break;
333                    }
334                    Ok(Message::Ping(data)) => {
335                        let _ = write.send(Message::Pong(data)).await;
336                    }
337                    Ok(_) => {}
338                    Err(e) => {
339                        error!("WebSocket read error: {}", e);
340                        return Err(StreamError::WebSocketError(e.to_string()));
341                    }
342                }
343            }
344
345            // Handle heartbeat messages
346            Some(msg) = heartbeat_rx.recv() => {
347                if let Err(e) = write.send(msg).await {
348                    error!("Failed to send heartbeat: {}", e);
349                    return Err(StreamError::WebSocketError(e.to_string()));
350                }
351            }
352
353            // Handle commands (subscribe/unsubscribe)
354            Some(cmd) = command_rx.recv() => {
355                match cmd {
356                    StreamCommand::Subscribe(symbols) => {
357                        let mut newly_added = Vec::new();
358                        {
359                            let mut subs = subscriptions.write().await;
360                            for s in &symbols {
361                                if subs.insert(s.clone()) {
362                                    newly_added.push(s.clone());
363                                }
364                            }
365                        }
366                        if !newly_added.is_empty() {
367                            let msg = serde_json::json!({ "subscribe": newly_added });
368                            let _ = write.send(Message::Text(msg.to_string().into())).await;
369                            info!("Added subscriptions: {:?}", newly_added);
370                        }
371                    }
372                    StreamCommand::Unsubscribe(symbols) => {
373                        let mut actually_removed = Vec::new();
374                        {
375                            let mut subs = subscriptions.write().await;
376                            for s in &symbols {
377                                if subs.remove(s) {
378                                    actually_removed.push(s.clone());
379                                }
380                            }
381                        }
382                        if !actually_removed.is_empty() {
383                            let msg = serde_json::json!({ "unsubscribe": actually_removed });
384                            let _ = write.send(Message::Text(msg.to_string().into())).await;
385                            info!("Removed subscriptions: {:?}", actually_removed);
386                        }
387                    }
388                    StreamCommand::Close => {
389                        info!("Received close command");
390                        let _ = write.send(Message::Close(None)).await;
391                        return Ok(());
392                    }
393                }
394            }
395
396            else => break,
397        }
398    }
399
400    Ok(())
401}
402
403/// Handle incoming text message from Yahoo WebSocket
404fn handle_text_message(
405    text: &str,
406    broadcast_tx: &broadcast::Sender<PriceUpdate>,
407) -> std::result::Result<(), PricingDecodeError> {
408    // Yahoo sends JSON with base64-encoded protobuf in "message" field
409    let json: serde_json::Value =
410        serde_json::from_str(text).map_err(|e| PricingDecodeError::Base64(e.to_string()))?;
411
412    if let Some(encoded) = json.get("message").and_then(|v| v.as_str()) {
413        let pricing_data = PricingData::from_base64(encoded)?;
414        let price_update: PriceUpdate = pricing_data.into();
415
416        // Broadcast to all receivers
417        if broadcast_tx.receiver_count() > 0 {
418            let _ = broadcast_tx.send(price_update);
419        }
420    }
421
422    Ok(())
423}
424
425/// Builder for creating price streams with custom configuration
426pub struct PriceStreamBuilder {
427    symbols: Vec<String>,
428    reconnect_delay: Duration,
429}
430
431impl PriceStreamBuilder {
432    /// Create a new builder
433    pub fn new() -> Self {
434        Self {
435            symbols: Vec::new(),
436            reconnect_delay: Duration::from_secs(RECONNECT_BACKOFF_SECS),
437        }
438    }
439
440    /// Add symbols to subscribe to
441    pub fn symbols(mut self, symbols: &[&str]) -> Self {
442        self.symbols.extend(symbols.iter().map(|s| s.to_string()));
443        self
444    }
445
446    /// Set reconnection delay
447    pub fn reconnect_delay(mut self, delay: Duration) -> Self {
448        self.reconnect_delay = delay;
449        self
450    }
451
452    /// Build and start the price stream
453    pub async fn build(self) -> StreamResult<PriceStream> {
454        let symbol_refs: Vec<&str> = self.symbols.iter().map(|s| s.as_str()).collect();
455        PriceStream::subscribe(&symbol_refs).await
456    }
457}
458
459impl Default for PriceStreamBuilder {
460    fn default() -> Self {
461        Self::new()
462    }
463}