Skip to main content

finance_query/streaming/
client.rs

1//! WebSocket client for Yahoo Finance real-time streaming
2//!
3//! Provides a Stream-based API for receiving real-time price updates.
4
5use std::collections::HashSet;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9use std::time::Duration;
10
11use futures::SinkExt;
12use futures::stream::Stream;
13use tokio::sync::{RwLock, broadcast, mpsc};
14use tokio::time::interval;
15use tokio_stream::wrappers::BroadcastStream;
16use tokio_tungstenite::{connect_async, tungstenite::Message};
17use tracing::{debug, error, info, warn};
18
19use super::pricing::{PriceUpdate, PricingData, PricingDecodeError};
20use crate::error::FinanceError;
21
22/// Result type for streaming operations
23pub type StreamResult<T> = std::result::Result<T, StreamError>;
24
/// Failure modes of the streaming client.
///
/// Each variant carries the underlying error rendered as a `String`.
#[derive(Debug, Clone)]
pub enum StreamError {
    /// The WebSocket connection could not be established.
    ConnectionFailed(String),
    /// Sending to or receiving from the WebSocket failed.
    WebSocketError(String),
    /// An incoming message could not be decoded.
    DecodeError(String),
}

impl std::fmt::Display for StreamError {
    /// Writes a short category label followed by the carried detail text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ConnectionFailed(reason) => write!(f, "Connection failed: {}", reason),
            Self::WebSocketError(reason) => write!(f, "WebSocket error: {}", reason),
            Self::DecodeError(reason) => write!(f, "Decode error: {}", reason),
        }
    }
}

impl std::error::Error for StreamError {}
47
48impl From<StreamError> for FinanceError {
49    fn from(e: StreamError) -> Self {
50        FinanceError::ResponseStructureError {
51            field: "streaming".to_string(),
52            context: e.to_string(),
53        }
54    }
55}
56
/// Endpoint of the Yahoo Finance streaming WebSocket.
const YAHOO_WS_URL: &str = "wss://streamer.finance.yahoo.com/?version=2";

/// Seconds between heartbeat messages that re-send the current subscription set.
const HEARTBEAT_INTERVAL_SECS: u64 = 15;

/// Default number of seconds to wait before reconnecting after a failure.
const RECONNECT_BACKOFF_SECS: u64 = 3;

/// Capacity of the broadcast channel that carries price updates.
const CHANNEL_CAPACITY: usize = 1024;
68
69/// A streaming price subscription that yields real-time price updates.
70///
71/// This provides a Flow-like API for receiving real-time price data from Yahoo Finance.
72///
73/// # Example
74///
75/// ```no_run
76/// use finance_query::streaming::PriceStream;
77/// use futures::StreamExt;
78///
79/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
80/// // Subscribe to multiple symbols
81/// let mut stream = PriceStream::subscribe(&["AAPL", "NVDA", "TSLA"]).await?;
82///
83/// // Receive price updates
84/// while let Some(price) = stream.next().await {
85///     println!("{}: ${:.2} ({:+.2}%)",
86///         price.id,
87///         price.price,
88///         price.change_percent
89///     );
90/// }
91/// # Ok(())
92/// # }
93/// ```
94pub struct PriceStream {
95    inner: BroadcastStream<PriceUpdate>,
96    _handle: Arc<StreamHandle>,
97}
98
99/// Handle to manage the WebSocket connection
100struct StreamHandle {
101    command_tx: mpsc::Sender<StreamCommand>,
102    broadcast_tx: broadcast::Sender<PriceUpdate>,
103}
104
/// Requests a [`PriceStream`] can send to the background WebSocket task.
enum StreamCommand {
    /// Add the given symbols to the subscription set.
    Subscribe(Vec<String>),
    /// Remove the given symbols from the subscription set.
    Unsubscribe(Vec<String>),
    /// Send a close frame and stop the task.
    Close,
}
111
112impl PriceStream {
113    /// Subscribe to real-time price updates for the given symbols.
114    ///
115    /// # Arguments
116    ///
117    /// * `symbols` - Ticker symbols to subscribe to (e.g., `["AAPL", "NVDA"]`)
118    ///
119    /// # Example
120    ///
121    /// ```no_run
122    /// use finance_query::streaming::PriceStream;
123    ///
124    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
125    /// let stream = PriceStream::subscribe(&["AAPL", "GOOGL"]).await?;
126    /// # Ok(())
127    /// # }
128    /// ```
129    pub async fn subscribe(symbols: &[&str]) -> StreamResult<Self> {
130        Self::subscribe_inner(symbols, Duration::from_secs(RECONNECT_BACKOFF_SECS)).await
131    }
132
133    async fn subscribe_inner(symbols: &[&str], retry_delay: Duration) -> StreamResult<Self> {
134        let (broadcast_tx, broadcast_rx) = broadcast::channel(CHANNEL_CAPACITY);
135        let (command_tx, command_rx) = mpsc::channel(32);
136
137        let initial_symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
138
139        let tx_clone = broadcast_tx.clone();
140
141        // Spawn the WebSocket task
142        tokio::spawn(async move {
143            if let Err(e) =
144                run_websocket_loop(initial_symbols, broadcast_tx, command_rx, retry_delay).await
145            {
146                error!("WebSocket loop error: {}", e);
147            }
148        });
149
150        let handle = Arc::new(StreamHandle {
151            command_tx,
152            broadcast_tx: tx_clone,
153        });
154
155        Ok(PriceStream {
156            inner: BroadcastStream::new(broadcast_rx),
157            _handle: handle,
158        })
159    }
160
161    /// Create a new receiver for this stream.
162    ///
163    /// Useful when you need multiple consumers of the same price data.
164    pub fn resubscribe(&self) -> Self {
165        PriceStream {
166            inner: BroadcastStream::new(self._handle.broadcast_tx.subscribe()),
167            _handle: Arc::clone(&self._handle),
168        }
169    }
170
171    /// Add more symbols to the subscription.
172    ///
173    /// # Example
174    ///
175    /// ```no_run
176    /// use finance_query::streaming::PriceStream;
177    ///
178    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
179    /// let stream = PriceStream::subscribe(&["AAPL"]).await?;
180    /// stream.add_symbols(&["NVDA", "TSLA"]).await;
181    /// # Ok(())
182    /// # }
183    /// ```
184    pub async fn add_symbols(&self, symbols: &[&str]) {
185        let symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
186        let _ = self
187            ._handle
188            .command_tx
189            .send(StreamCommand::Subscribe(symbols))
190            .await;
191    }
192
193    /// Remove symbols from the subscription.
194    ///
195    /// # Example
196    ///
197    /// ```no_run
198    /// use finance_query::streaming::PriceStream;
199    ///
200    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
201    /// let stream = PriceStream::subscribe(&["AAPL", "NVDA"]).await?;
202    /// stream.remove_symbols(&["NVDA"]).await;
203    /// # Ok(())
204    /// # }
205    /// ```
206    pub async fn remove_symbols(&self, symbols: &[&str]) {
207        let symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
208        let _ = self
209            ._handle
210            .command_tx
211            .send(StreamCommand::Unsubscribe(symbols))
212            .await;
213    }
214
215    /// Close the stream and disconnect from the WebSocket.
216    pub async fn close(&self) {
217        let _ = self._handle.command_tx.send(StreamCommand::Close).await;
218    }
219}
220
221impl Stream for PriceStream {
222    type Item = PriceUpdate;
223
224    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
225        match Pin::new(&mut self.inner).poll_next(cx) {
226            Poll::Ready(Some(Ok(data))) => Poll::Ready(Some(data)),
227            Poll::Ready(Some(Err(e))) => {
228                warn!("Broadcast error: {:?}", e);
229                // Try again on lag
230                cx.waker().wake_by_ref();
231                Poll::Pending
232            }
233            Poll::Ready(None) => Poll::Ready(None),
234            Poll::Pending => Poll::Pending,
235        }
236    }
237}
238
239/// Run the WebSocket connection loop with automatic reconnection
240async fn run_websocket_loop(
241    initial_symbols: Vec<String>,
242    broadcast_tx: broadcast::Sender<PriceUpdate>,
243    mut command_rx: mpsc::Receiver<StreamCommand>,
244    retry_delay: Duration,
245) -> StreamResult<()> {
246    let subscriptions = Arc::new(RwLock::new(HashSet::<String>::from_iter(initial_symbols)));
247
248    loop {
249        match connect_and_stream(&subscriptions, &broadcast_tx, &mut command_rx).await {
250            Ok(()) => {
251                info!("WebSocket connection closed gracefully");
252                break;
253            }
254            Err(e) => {
255                error!(
256                    "WebSocket error: {}, reconnecting in {:.1}s...",
257                    e,
258                    retry_delay.as_secs_f32()
259                );
260                tokio::time::sleep(retry_delay).await;
261            }
262        }
263    }
264
265    Ok(())
266}
267
268/// Connect to Yahoo WebSocket and stream data
269async fn connect_and_stream(
270    subscriptions: &Arc<RwLock<HashSet<String>>>,
271    broadcast_tx: &broadcast::Sender<PriceUpdate>,
272    command_rx: &mut mpsc::Receiver<StreamCommand>,
273) -> StreamResult<()> {
274    use futures::StreamExt;
275
276    info!("Connecting to Yahoo Finance WebSocket...");
277
278    let (ws_stream, _) = connect_async(YAHOO_WS_URL)
279        .await
280        .map_err(|e| StreamError::ConnectionFailed(e.to_string()))?;
281
282    info!("Connected to Yahoo Finance WebSocket");
283
284    let (mut write, mut read) = ws_stream.split();
285
286    // Send initial subscriptions
287    {
288        let subs = subscriptions.read().await;
289        if !subs.is_empty() {
290            let symbols: Vec<&str> = subs.iter().map(|s| s.as_str()).collect();
291            let msg = serde_json::json!({ "subscribe": symbols });
292            write
293                .send(Message::Text(msg.to_string().into()))
294                .await
295                .map_err(|e| StreamError::WebSocketError(e.to_string()))?;
296            info!("Subscribed to {} symbols", symbols.len());
297        }
298    }
299
300    // Heartbeat task - sends subscription refresh every 15 seconds
301    let heartbeat_subs = Arc::clone(subscriptions);
302    let (heartbeat_tx, mut heartbeat_rx) = mpsc::channel::<Message>(32);
303
304    tokio::spawn(async move {
305        let mut interval = interval(Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
306        loop {
307            interval.tick().await;
308            let subs = heartbeat_subs.read().await;
309            if !subs.is_empty() {
310                let symbols: Vec<&str> = subs.iter().map(|s| s.as_str()).collect();
311                let msg = serde_json::json!({ "subscribe": symbols });
312                if heartbeat_tx
313                    .send(Message::Text(msg.to_string().into()))
314                    .await
315                    .is_err()
316                {
317                    break;
318                }
319                debug!("Heartbeat subscription sent for {} symbols", symbols.len());
320            }
321        }
322    });
323
324    loop {
325        tokio::select! {
326            // Handle incoming WebSocket messages
327            Some(msg) = read.next() => {
328                match msg {
329                    Ok(Message::Text(text)) => {
330                        if let Err(e) = handle_text_message(&text, broadcast_tx) {
331                            warn!("Failed to handle message: {}", e);
332                        }
333                    }
334                    Ok(Message::Binary(data)) => {
335                        debug!("Received binary message: {} bytes", data.len());
336                    }
337                    Ok(Message::Close(_)) => {
338                        info!("Received close frame");
339                        break;
340                    }
341                    Ok(Message::Ping(data)) => {
342                        let _ = write.send(Message::Pong(data)).await;
343                    }
344                    Ok(_) => {}
345                    Err(e) => {
346                        error!("WebSocket read error: {}", e);
347                        return Err(StreamError::WebSocketError(e.to_string()));
348                    }
349                }
350            }
351
352            // Handle heartbeat messages
353            Some(msg) = heartbeat_rx.recv() => {
354                if let Err(e) = write.send(msg).await {
355                    error!("Failed to send heartbeat: {}", e);
356                    return Err(StreamError::WebSocketError(e.to_string()));
357                }
358            }
359
360            // Handle commands (subscribe/unsubscribe)
361            Some(cmd) = command_rx.recv() => {
362                match cmd {
363                    StreamCommand::Subscribe(symbols) => {
364                        let mut newly_added = Vec::new();
365                        {
366                            let mut subs = subscriptions.write().await;
367                            for s in &symbols {
368                                if subs.insert(s.clone()) {
369                                    newly_added.push(s.clone());
370                                }
371                            }
372                        }
373                        if !newly_added.is_empty() {
374                            let msg = serde_json::json!({ "subscribe": newly_added });
375                            let _ = write.send(Message::Text(msg.to_string().into())).await;
376                            info!("Added subscriptions: {:?}", newly_added);
377                        }
378                    }
379                    StreamCommand::Unsubscribe(symbols) => {
380                        let mut actually_removed = Vec::new();
381                        {
382                            let mut subs = subscriptions.write().await;
383                            for s in &symbols {
384                                if subs.remove(s) {
385                                    actually_removed.push(s.clone());
386                                }
387                            }
388                        }
389                        if !actually_removed.is_empty() {
390                            let msg = serde_json::json!({ "unsubscribe": actually_removed });
391                            let _ = write.send(Message::Text(msg.to_string().into())).await;
392                            info!("Removed subscriptions: {:?}", actually_removed);
393                        }
394                    }
395                    StreamCommand::Close => {
396                        info!("Received close command");
397                        let _ = write.send(Message::Close(None)).await;
398                        return Ok(());
399                    }
400                }
401            }
402
403            else => break,
404        }
405    }
406
407    Ok(())
408}
409
410/// Handle incoming text message from Yahoo WebSocket
411fn handle_text_message(
412    text: &str,
413    broadcast_tx: &broadcast::Sender<PriceUpdate>,
414) -> std::result::Result<(), PricingDecodeError> {
415    // Yahoo sends JSON with base64-encoded protobuf in "message" field
416    let json: serde_json::Value =
417        serde_json::from_str(text).map_err(|e| PricingDecodeError::Base64(e.to_string()))?;
418
419    if let Some(encoded) = json.get("message").and_then(|v| v.as_str()) {
420        let pricing_data = PricingData::from_base64(encoded)?;
421        let price_update: PriceUpdate = pricing_data.into();
422
423        // Broadcast to all receivers
424        if broadcast_tx.receiver_count() > 0 {
425            let _ = broadcast_tx.send(price_update);
426        }
427    }
428
429    Ok(())
430}
431
/// Builder for creating price streams with custom configuration.
///
/// Collects the symbols to subscribe to and the delay between reconnection
/// attempts, then starts the stream with `build`.
#[derive(Debug, Clone)]
pub struct PriceStreamBuilder {
    // Symbols accumulated through `symbols`, in insertion order.
    symbols: Vec<String>,
    // Pause between reconnection attempts.
    retry_delay: Duration,
}
437
438impl PriceStreamBuilder {
439    /// Create a new builder
440    pub fn new() -> Self {
441        Self {
442            symbols: Vec::new(),
443            retry_delay: Duration::from_secs(RECONNECT_BACKOFF_SECS),
444        }
445    }
446
447    /// Add symbols to subscribe to
448    pub fn symbols(mut self, symbols: &[&str]) -> Self {
449        self.symbols.extend(symbols.iter().map(|s| s.to_string()));
450        self
451    }
452
453    /// Set the delay between reconnection attempts (default: 3s)
454    pub fn retry(mut self, delay: Duration) -> Self {
455        self.retry_delay = delay;
456        self
457    }
458
459    /// Build and start the price stream
460    pub async fn build(self) -> StreamResult<PriceStream> {
461        let symbol_refs: Vec<&str> = self.symbols.iter().map(|s| s.as_str()).collect();
462        PriceStream::subscribe_inner(&symbol_refs, self.retry_delay).await
463    }
464}
465
466impl Default for PriceStreamBuilder {
467    fn default() -> Self {
468        Self::new()
469    }
470}