use std::collections::HashSet;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use futures::SinkExt;
use futures::stream::Stream;
use tokio::sync::{RwLock, broadcast, mpsc};
use tokio::time::interval;
use tokio_stream::wrappers::BroadcastStream;
use tokio_tungstenite::{connect_async, tungstenite::Message};
use tracing::{debug, error, info, warn};
use super::pricing::{PriceUpdate, PricingData, PricingDecodeError};
use crate::error::FinanceError;
/// Result alias used by the price-streaming API.
pub type StreamResult<T> = std::result::Result<T, StreamError>;

/// Errors reported by the price-streaming machinery.
///
/// Every variant carries the stringified message of the underlying failure.
#[derive(Debug, Clone)]
pub enum StreamError {
    /// Opening the WebSocket connection failed.
    ConnectionFailed(String),
    /// Reading from or writing to an open WebSocket failed.
    WebSocketError(String),
    /// A decoding failure, carrying the decoder's message.
    DecodeError(String),
}

impl std::fmt::Display for StreamError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each variant renders as "<label>: <detail>".
        let (label, detail) = match self {
            Self::ConnectionFailed(detail) => ("Connection failed", detail),
            Self::WebSocketError(detail) => ("WebSocket error", detail),
            Self::DecodeError(detail) => ("Decode error", detail),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for StreamError {}
impl From<StreamError> for FinanceError {
fn from(e: StreamError) -> Self {
FinanceError::ResponseStructureError {
field: "streaming".to_string(),
context: e.to_string(),
}
}
}
/// Yahoo Finance streaming endpoint; `version=2` is passed as a query parameter.
const YAHOO_WS_URL: &str = "wss://streamer.finance.yahoo.com/?version=2";
/// Seconds between re-sends of the full subscription list on an open connection.
const HEARTBEAT_INTERVAL_SECS: u64 = 15;
/// Default number of seconds to wait before reconnecting after a failed session.
const RECONNECT_BACKOFF_SECS: u64 = 3;
/// Capacity of the broadcast channel that fans updates out to every `PriceStream`.
/// A receiver that falls further behind than this loses the oldest updates.
const CHANNEL_CAPACITY: usize = 1 << 10;
pub struct PriceStream {
inner: BroadcastStream<PriceUpdate>,
_handle: Arc<StreamHandle>,
}
struct StreamHandle {
command_tx: mpsc::Sender<StreamCommand>,
broadcast_tx: broadcast::Sender<PriceUpdate>,
}
enum StreamCommand {
Subscribe(Vec<String>),
Unsubscribe(Vec<String>),
Close,
}
impl PriceStream {
pub async fn subscribe(symbols: &[&str]) -> StreamResult<Self> {
Self::subscribe_inner(symbols, Duration::from_secs(RECONNECT_BACKOFF_SECS)).await
}
async fn subscribe_inner(symbols: &[&str], retry_delay: Duration) -> StreamResult<Self> {
let (broadcast_tx, broadcast_rx) = broadcast::channel(CHANNEL_CAPACITY);
let (command_tx, command_rx) = mpsc::channel(32);
let initial_symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
let tx_clone = broadcast_tx.clone();
tokio::spawn(async move {
if let Err(e) =
run_websocket_loop(initial_symbols, broadcast_tx, command_rx, retry_delay).await
{
error!("WebSocket loop error: {}", e);
}
});
let handle = Arc::new(StreamHandle {
command_tx,
broadcast_tx: tx_clone,
});
Ok(PriceStream {
inner: BroadcastStream::new(broadcast_rx),
_handle: handle,
})
}
pub fn resubscribe(&self) -> Self {
PriceStream {
inner: BroadcastStream::new(self._handle.broadcast_tx.subscribe()),
_handle: Arc::clone(&self._handle),
}
}
pub async fn add_symbols(&self, symbols: &[&str]) {
let symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
let _ = self
._handle
.command_tx
.send(StreamCommand::Subscribe(symbols))
.await;
}
pub async fn remove_symbols(&self, symbols: &[&str]) {
let symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
let _ = self
._handle
.command_tx
.send(StreamCommand::Unsubscribe(symbols))
.await;
}
pub async fn close(&self) {
let _ = self._handle.command_tx.send(StreamCommand::Close).await;
}
}
impl Stream for PriceStream {
    type Item = PriceUpdate;

    /// Yields the next price update, skipping over lag notifications.
    ///
    /// If this receiver falls more than the channel capacity behind, the broadcast
    /// channel drops the oldest updates and reports a lag. The lag is logged and
    /// polling continues immediately. The stream ends when the underlying broadcast
    /// channel is closed (all senders dropped).
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            match Pin::new(&mut self.inner).poll_next(cx) {
                Poll::Ready(Some(Ok(update))) => return Poll::Ready(Some(update)),
                Poll::Ready(Some(Err(e))) => {
                    // After a lag the receiver has already been moved to the oldest value
                    // still buffered. Poll again right away instead of bouncing through
                    // the executor with a self-wake.
                    warn!("Broadcast error: {:?}", e);
                }
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
async fn run_websocket_loop(
initial_symbols: Vec<String>,
broadcast_tx: broadcast::Sender<PriceUpdate>,
mut command_rx: mpsc::Receiver<StreamCommand>,
retry_delay: Duration,
) -> StreamResult<()> {
let subscriptions = Arc::new(RwLock::new(HashSet::<String>::from_iter(initial_symbols)));
loop {
match connect_and_stream(&subscriptions, &broadcast_tx, &mut command_rx).await {
Ok(()) => {
info!("WebSocket connection closed gracefully");
break;
}
Err(e) => {
error!(
"WebSocket error: {}, reconnecting in {:.1}s...",
e,
retry_delay.as_secs_f32()
);
tokio::time::sleep(retry_delay).await;
}
}
}
Ok(())
}
async fn connect_and_stream(
subscriptions: &Arc<RwLock<HashSet<String>>>,
broadcast_tx: &broadcast::Sender<PriceUpdate>,
command_rx: &mut mpsc::Receiver<StreamCommand>,
) -> StreamResult<()> {
use futures::StreamExt;
info!("Connecting to Yahoo Finance WebSocket...");
let (ws_stream, _) = connect_async(YAHOO_WS_URL)
.await
.map_err(|e| StreamError::ConnectionFailed(e.to_string()))?;
info!("Connected to Yahoo Finance WebSocket");
let (mut write, mut read) = ws_stream.split();
{
let subs = subscriptions.read().await;
if !subs.is_empty() {
let symbols: Vec<&str> = subs.iter().map(|s| s.as_str()).collect();
let msg = serde_json::json!({ "subscribe": symbols });
write
.send(Message::Text(msg.to_string().into()))
.await
.map_err(|e| StreamError::WebSocketError(e.to_string()))?;
info!("Subscribed to {} symbols", symbols.len());
}
}
let heartbeat_subs = Arc::clone(subscriptions);
let (heartbeat_tx, mut heartbeat_rx) = mpsc::channel::<Message>(32);
tokio::spawn(async move {
let mut interval = interval(Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
loop {
interval.tick().await;
let subs = heartbeat_subs.read().await;
if !subs.is_empty() {
let symbols: Vec<&str> = subs.iter().map(|s| s.as_str()).collect();
let msg = serde_json::json!({ "subscribe": symbols });
if heartbeat_tx
.send(Message::Text(msg.to_string().into()))
.await
.is_err()
{
break;
}
debug!("Heartbeat subscription sent for {} symbols", symbols.len());
}
}
});
loop {
tokio::select! {
Some(msg) = read.next() => {
match msg {
Ok(Message::Text(text)) => {
if let Err(e) = handle_text_message(&text, broadcast_tx) {
warn!("Failed to handle message: {}", e);
}
}
Ok(Message::Binary(data)) => {
debug!("Received binary message: {} bytes", data.len());
}
Ok(Message::Close(_)) => {
info!("Received close frame");
break;
}
Ok(Message::Ping(data)) => {
let _ = write.send(Message::Pong(data)).await;
}
Ok(_) => {}
Err(e) => {
error!("WebSocket read error: {}", e);
return Err(StreamError::WebSocketError(e.to_string()));
}
}
}
Some(msg) = heartbeat_rx.recv() => {
if let Err(e) = write.send(msg).await {
error!("Failed to send heartbeat: {}", e);
return Err(StreamError::WebSocketError(e.to_string()));
}
}
Some(cmd) = command_rx.recv() => {
match cmd {
StreamCommand::Subscribe(symbols) => {
let mut newly_added = Vec::new();
{
let mut subs = subscriptions.write().await;
for s in &symbols {
if subs.insert(s.clone()) {
newly_added.push(s.clone());
}
}
}
if !newly_added.is_empty() {
let msg = serde_json::json!({ "subscribe": newly_added });
let _ = write.send(Message::Text(msg.to_string().into())).await;
info!("Added subscriptions: {:?}", newly_added);
}
}
StreamCommand::Unsubscribe(symbols) => {
let mut actually_removed = Vec::new();
{
let mut subs = subscriptions.write().await;
for s in &symbols {
if subs.remove(s) {
actually_removed.push(s.clone());
}
}
}
if !actually_removed.is_empty() {
let msg = serde_json::json!({ "unsubscribe": actually_removed });
let _ = write.send(Message::Text(msg.to_string().into())).await;
info!("Removed subscriptions: {:?}", actually_removed);
}
}
StreamCommand::Close => {
info!("Received close command");
let _ = write.send(Message::Close(None)).await;
return Ok(());
}
}
}
else => break,
}
}
Ok(())
}
/// Decodes one text frame from the streamer and publishes the resulting [`PriceUpdate`].
///
/// The frame is parsed as JSON. If it has a string `"message"` field, that value is
/// decoded with [`PricingData::from_base64`], converted into a `PriceUpdate` and sent on
/// `broadcast_tx`. The send is skipped when there are no receivers. Frames without such
/// a field are ignored.
///
/// # Errors
/// Returns a [`PricingDecodeError`] when the frame is not valid JSON or the payload fails
/// to decode.
fn handle_text_message(
    text: &str,
    broadcast_tx: &broadcast::Sender<PriceUpdate>,
) -> std::result::Result<(), PricingDecodeError> {
    // NOTE(review): JSON syntax errors are reported through the `Base64` variant;
    // confirm whether `PricingDecodeError` has a better-fitting variant.
    let frame = serde_json::from_str::<serde_json::Value>(text)
        .map_err(|err| PricingDecodeError::Base64(err.to_string()))?;

    let payload = match frame.get("message").and_then(serde_json::Value::as_str) {
        Some(encoded) => encoded,
        None => return Ok(()),
    };

    let update: PriceUpdate = PricingData::from_base64(payload)?.into();
    // Skip the send when nobody is listening. If the last receiver disappears in
    // between, `send` just returns an error, which is deliberately ignored.
    if broadcast_tx.receiver_count() != 0 {
        let _ = broadcast_tx.send(update);
    }
    Ok(())
}
/// Builder for a [`PriceStream`] with a custom initial symbol list and reconnect delay.
///
/// Defaults: no symbols, and a reconnect delay of `RECONNECT_BACKOFF_SECS` seconds.
#[derive(Debug, Clone)]
#[must_use = "a PriceStreamBuilder does nothing until `build()` is awaited"]
pub struct PriceStreamBuilder {
    symbols: Vec<String>,
    retry_delay: Duration,
}

impl PriceStreamBuilder {
    /// Creates a builder with no symbols and the default reconnect delay.
    pub fn new() -> Self {
        Self {
            symbols: Vec::new(),
            retry_delay: Duration::from_secs(RECONNECT_BACKOFF_SECS),
        }
    }

    /// Appends `symbols` to the initial subscription list. May be called repeatedly.
    pub fn symbols(mut self, symbols: &[&str]) -> Self {
        self.symbols.extend(symbols.iter().map(|s| s.to_string()));
        self
    }

    /// Sets the delay waited before each reconnection attempt.
    pub fn retry(mut self, delay: Duration) -> Self {
        self.retry_delay = delay;
        self
    }

    /// Spawns the background connection task and returns the first stream.
    ///
    /// # Errors
    /// Propagates the result of the shared `PriceStream` constructor, which does not
    /// currently fail.
    ///
    /// # Panics
    /// Panics if called outside a Tokio runtime.
    pub async fn build(self) -> StreamResult<PriceStream> {
        let symbol_refs: Vec<&str> = self.symbols.iter().map(String::as_str).collect();
        PriceStream::subscribe_inner(&symbol_refs, self.retry_delay).await
    }
}

impl Default for PriceStreamBuilder {
    /// Same as [`PriceStreamBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}