finance_query/streaming/
client.rs1use std::collections::HashSet;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9use std::time::Duration;
10
11use futures::SinkExt;
12use futures::stream::Stream;
13use tokio::sync::{RwLock, broadcast, mpsc};
14use tokio::time::interval;
15use tokio_stream::wrappers::BroadcastStream;
16use tokio_tungstenite::{connect_async, tungstenite::Message};
17use tracing::{debug, error, info, warn};
18
19use super::pricing::{PriceUpdate, PricingData, PricingDecodeError};
20use crate::error::YahooError;
21
22pub type StreamResult<T> = std::result::Result<T, StreamError>;
24
/// Errors produced by the streaming client.
///
/// Each variant carries a human-readable description of the underlying failure.
#[derive(Clone, Debug)]
pub enum StreamError {
    /// Establishing the WebSocket connection failed.
    ConnectionFailed(String),
    /// Reading from or writing to an open WebSocket failed.
    WebSocketError(String),
    /// A received payload could not be decoded.
    DecodeError(String),
}
35
36impl std::fmt::Display for StreamError {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 match self {
39 StreamError::ConnectionFailed(e) => write!(f, "Connection failed: {}", e),
40 StreamError::WebSocketError(e) => write!(f, "WebSocket error: {}", e),
41 StreamError::DecodeError(e) => write!(f, "Decode error: {}", e),
42 }
43 }
44}
45
46impl std::error::Error for StreamError {}
47
48impl From<StreamError> for YahooError {
49 fn from(e: StreamError) -> Self {
50 YahooError::ResponseStructureError {
51 field: "streaming".to_string(),
52 context: e.to_string(),
53 }
54 }
55}
56
/// Yahoo Finance streaming WebSocket endpoint.
const YAHOO_WS_URL: &'static str = "wss://streamer.finance.yahoo.com/?version=2";

/// Seconds between re-sends of the full `subscribe` message on an open connection.
const HEARTBEAT_INTERVAL_SECS: u64 = 15;

/// Seconds to wait before reconnecting after a connection error. It is also the
/// default reconnect delay of `PriceStreamBuilder`.
const RECONNECT_BACKOFF_SECS: u64 = 3;

/// Capacity of the broadcast channel that fans price updates out to consumers.
const CHANNEL_CAPACITY: usize = 1024;
68
69pub struct PriceStream {
95 inner: BroadcastStream<PriceUpdate>,
96 _handle: Arc<StreamHandle>,
97}
98
99struct StreamHandle {
101 command_tx: mpsc::Sender<StreamCommand>,
102 broadcast_tx: broadcast::Sender<PriceUpdate>,
103}
104
/// Control messages sent from `PriceStream` handles to the connection task.
enum StreamCommand {
    /// Add these symbols to the subscription set.
    Subscribe(Vec<String>),
    /// Remove these symbols from the subscription set.
    Unsubscribe(Vec<String>),
    /// Close the WebSocket and stop the connection task.
    Close,
}
111
112impl PriceStream {
113 pub async fn subscribe(symbols: &[&str]) -> StreamResult<Self> {
130 let (broadcast_tx, broadcast_rx) = broadcast::channel(CHANNEL_CAPACITY);
131 let (command_tx, command_rx) = mpsc::channel(32);
132
133 let symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
134 let initial_symbols = symbols.clone();
135
136 let tx_clone = broadcast_tx.clone();
137
138 tokio::spawn(async move {
140 if let Err(e) = run_websocket_loop(initial_symbols, broadcast_tx, command_rx).await {
141 error!("WebSocket loop error: {}", e);
142 }
143 });
144
145 let handle = Arc::new(StreamHandle {
146 command_tx,
147 broadcast_tx: tx_clone,
148 });
149
150 Ok(PriceStream {
151 inner: BroadcastStream::new(broadcast_rx),
152 _handle: handle,
153 })
154 }
155
156 pub fn resubscribe(&self) -> Self {
160 PriceStream {
161 inner: BroadcastStream::new(self._handle.broadcast_tx.subscribe()),
162 _handle: Arc::clone(&self._handle),
163 }
164 }
165
166 pub async fn add_symbols(&self, symbols: &[&str]) {
180 let symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
181 let _ = self
182 ._handle
183 .command_tx
184 .send(StreamCommand::Subscribe(symbols))
185 .await;
186 }
187
188 pub async fn remove_symbols(&self, symbols: &[&str]) {
202 let symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
203 let _ = self
204 ._handle
205 .command_tx
206 .send(StreamCommand::Unsubscribe(symbols))
207 .await;
208 }
209
210 pub async fn close(&self) {
212 let _ = self._handle.command_tx.send(StreamCommand::Close).await;
213 }
214}
215
216impl Stream for PriceStream {
217 type Item = PriceUpdate;
218
219 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
220 match Pin::new(&mut self.inner).poll_next(cx) {
221 Poll::Ready(Some(Ok(data))) => Poll::Ready(Some(data)),
222 Poll::Ready(Some(Err(e))) => {
223 warn!("Broadcast error: {:?}", e);
224 cx.waker().wake_by_ref();
226 Poll::Pending
227 }
228 Poll::Ready(None) => Poll::Ready(None),
229 Poll::Pending => Poll::Pending,
230 }
231 }
232}
233
234async fn run_websocket_loop(
236 initial_symbols: Vec<String>,
237 broadcast_tx: broadcast::Sender<PriceUpdate>,
238 mut command_rx: mpsc::Receiver<StreamCommand>,
239) -> StreamResult<()> {
240 let subscriptions = Arc::new(RwLock::new(HashSet::<String>::from_iter(initial_symbols)));
241
242 loop {
243 match connect_and_stream(&subscriptions, &broadcast_tx, &mut command_rx).await {
244 Ok(()) => {
245 info!("WebSocket connection closed gracefully");
246 break;
247 }
248 Err(e) => {
249 error!(
250 "WebSocket error: {}, reconnecting in {}s...",
251 e, RECONNECT_BACKOFF_SECS
252 );
253 tokio::time::sleep(Duration::from_secs(RECONNECT_BACKOFF_SECS)).await;
254 }
255 }
256 }
257
258 Ok(())
259}
260
261async fn connect_and_stream(
263 subscriptions: &Arc<RwLock<HashSet<String>>>,
264 broadcast_tx: &broadcast::Sender<PriceUpdate>,
265 command_rx: &mut mpsc::Receiver<StreamCommand>,
266) -> StreamResult<()> {
267 use futures::StreamExt;
268
269 info!("Connecting to Yahoo Finance WebSocket...");
270
271 let (ws_stream, _) = connect_async(YAHOO_WS_URL)
272 .await
273 .map_err(|e| StreamError::ConnectionFailed(e.to_string()))?;
274
275 info!("Connected to Yahoo Finance WebSocket");
276
277 let (mut write, mut read) = ws_stream.split();
278
279 {
281 let subs = subscriptions.read().await;
282 if !subs.is_empty() {
283 let symbols: Vec<&str> = subs.iter().map(|s| s.as_str()).collect();
284 let msg = serde_json::json!({ "subscribe": symbols });
285 write
286 .send(Message::Text(msg.to_string().into()))
287 .await
288 .map_err(|e| StreamError::WebSocketError(e.to_string()))?;
289 info!("Subscribed to {} symbols", symbols.len());
290 }
291 }
292
293 let heartbeat_subs = Arc::clone(subscriptions);
295 let (heartbeat_tx, mut heartbeat_rx) = mpsc::channel::<Message>(32);
296
297 tokio::spawn(async move {
298 let mut interval = interval(Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
299 loop {
300 interval.tick().await;
301 let subs = heartbeat_subs.read().await;
302 if !subs.is_empty() {
303 let symbols: Vec<&str> = subs.iter().map(|s| s.as_str()).collect();
304 let msg = serde_json::json!({ "subscribe": symbols });
305 if heartbeat_tx
306 .send(Message::Text(msg.to_string().into()))
307 .await
308 .is_err()
309 {
310 break;
311 }
312 debug!("Heartbeat subscription sent for {} symbols", symbols.len());
313 }
314 }
315 });
316
317 loop {
318 tokio::select! {
319 Some(msg) = read.next() => {
321 match msg {
322 Ok(Message::Text(text)) => {
323 if let Err(e) = handle_text_message(&text, broadcast_tx) {
324 warn!("Failed to handle message: {}", e);
325 }
326 }
327 Ok(Message::Binary(data)) => {
328 debug!("Received binary message: {} bytes", data.len());
329 }
330 Ok(Message::Close(_)) => {
331 info!("Received close frame");
332 break;
333 }
334 Ok(Message::Ping(data)) => {
335 let _ = write.send(Message::Pong(data)).await;
336 }
337 Ok(_) => {}
338 Err(e) => {
339 error!("WebSocket read error: {}", e);
340 return Err(StreamError::WebSocketError(e.to_string()));
341 }
342 }
343 }
344
345 Some(msg) = heartbeat_rx.recv() => {
347 if let Err(e) = write.send(msg).await {
348 error!("Failed to send heartbeat: {}", e);
349 return Err(StreamError::WebSocketError(e.to_string()));
350 }
351 }
352
353 Some(cmd) = command_rx.recv() => {
355 match cmd {
356 StreamCommand::Subscribe(symbols) => {
357 let mut newly_added = Vec::new();
358 {
359 let mut subs = subscriptions.write().await;
360 for s in &symbols {
361 if subs.insert(s.clone()) {
362 newly_added.push(s.clone());
363 }
364 }
365 }
366 if !newly_added.is_empty() {
367 let msg = serde_json::json!({ "subscribe": newly_added });
368 let _ = write.send(Message::Text(msg.to_string().into())).await;
369 info!("Added subscriptions: {:?}", newly_added);
370 }
371 }
372 StreamCommand::Unsubscribe(symbols) => {
373 let mut actually_removed = Vec::new();
374 {
375 let mut subs = subscriptions.write().await;
376 for s in &symbols {
377 if subs.remove(s) {
378 actually_removed.push(s.clone());
379 }
380 }
381 }
382 if !actually_removed.is_empty() {
383 let msg = serde_json::json!({ "unsubscribe": actually_removed });
384 let _ = write.send(Message::Text(msg.to_string().into())).await;
385 info!("Removed subscriptions: {:?}", actually_removed);
386 }
387 }
388 StreamCommand::Close => {
389 info!("Received close command");
390 let _ = write.send(Message::Close(None)).await;
391 return Ok(());
392 }
393 }
394 }
395
396 else => break,
397 }
398 }
399
400 Ok(())
401}
402
403fn handle_text_message(
405 text: &str,
406 broadcast_tx: &broadcast::Sender<PriceUpdate>,
407) -> std::result::Result<(), PricingDecodeError> {
408 let json: serde_json::Value =
410 serde_json::from_str(text).map_err(|e| PricingDecodeError::Base64(e.to_string()))?;
411
412 if let Some(encoded) = json.get("message").and_then(|v| v.as_str()) {
413 let pricing_data = PricingData::from_base64(encoded)?;
414 let price_update: PriceUpdate = pricing_data.into();
415
416 if broadcast_tx.receiver_count() > 0 {
418 let _ = broadcast_tx.send(price_update);
419 }
420 }
421
422 Ok(())
423}
424
/// Builder for configuring a [`PriceStream`] before it is started.
pub struct PriceStreamBuilder {
    /// Symbols to subscribe to when the stream is built.
    symbols: Vec<String>,
    /// Requested delay between reconnection attempts.
    reconnect_delay: Duration,
}
430
431impl PriceStreamBuilder {
432 pub fn new() -> Self {
434 Self {
435 symbols: Vec::new(),
436 reconnect_delay: Duration::from_secs(RECONNECT_BACKOFF_SECS),
437 }
438 }
439
440 pub fn symbols(mut self, symbols: &[&str]) -> Self {
442 self.symbols.extend(symbols.iter().map(|s| s.to_string()));
443 self
444 }
445
446 pub fn reconnect_delay(mut self, delay: Duration) -> Self {
448 self.reconnect_delay = delay;
449 self
450 }
451
452 pub async fn build(self) -> StreamResult<PriceStream> {
454 let symbol_refs: Vec<&str> = self.symbols.iter().map(|s| s.as_str()).collect();
455 PriceStream::subscribe(&symbol_refs).await
456 }
457}
458
459impl Default for PriceStreamBuilder {
460 fn default() -> Self {
461 Self::new()
462 }
463}