use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use futures_util::{SinkExt, StreamExt};
use prost::Message;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, Mutex};
use tokio::time::{interval, sleep, Interval};
use tokio_tungstenite::{
connect_async,
tungstenite::{Error as WsError, Message as WsMessage},
MaybeTlsStream, WebSocketStream as TungsteniteStream,
};
// Protobuf types generated at build time into `$OUT_DIR/yahoo.finance.rs`
// (presumably by a prost-build build script — confirm in build.rs).
mod proto {
include!(concat!(env!("OUT_DIR"), "/yahoo.finance.rs"));
}
/// Ticker message decoded from the streamer, plus its companion `yaticker`
/// module (presumably nested types of `Yaticker`, per prost's naming convention).
pub use proto::{yaticker, Yaticker};
/// Endpoint of Yahoo Finance's streaming quote service.
const YAHOO_WS_URL: &str = "wss://streamer.finance.yahoo.com/";
/// Default period between heartbeats (the subscription list is re-sent on each tick).
const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(15_000);
/// Default backoff before the first reconnection attempt.
const DEFAULT_INITIAL_RECONNECT_DELAY: Duration = Duration::from_millis(1_000);
/// Ceiling for the doubling reconnection backoff.
const MAX_RECONNECT_DELAY: Duration = Duration::from_millis(60_000);
/// Default reconnection attempt limit; `0` disables the limit.
const DEFAULT_MAX_RECONNECT_ATTEMPTS: u32 = 0;
/// Default capacity of the backpressure channel.
const DEFAULT_BACKPRESSURE_BUFFER: usize = 1_000;
/// JSON control frame asking the streamer to publish quotes for the listed
/// symbols; serializes as `{"subscribe":[...]}`. Also re-sent as the heartbeat
/// and after reconnecting.
#[derive(Debug, Serialize, Deserialize)]
struct SubscriptionMessage {
subscribe: Vec<String>,
}
/// JSON control frame asking the streamer to stop publishing quotes for the
/// listed symbols; serializes as `{"unsubscribe":[...]}`.
#[derive(Debug, Serialize, Deserialize)]
struct UnsubscriptionMessage {
unsubscribe: Vec<String>,
}
/// Shared callback invoked with a clone of every decoded ticker. A returned
/// error is logged by the stream and does not interrupt it.
pub type MessageHandler = Arc<dyn Fn(Yaticker) -> Result<(), Box<dyn std::error::Error + Send + Sync>> + Send + Sync>;
/// Connection-management settings for `WebSocketStream`.
///
/// Build one with `WebSocketConfig::new()` (or `Default`) and the `with_*`
/// builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocketConfig {
    /// Backoff before the first reconnection attempt. It doubles after each
    /// attempt up to a fixed ceiling and returns to this value once a
    /// reconnect succeeds.
    pub initial_reconnect_delay: Duration,
    /// Limit on reconnection attempts; `0` means "no limit".
    pub max_reconnect_attempts: u32,
    /// How often the current subscription list is re-sent as a heartbeat.
    pub heartbeat_interval: Duration,
    /// Capacity of the bounded channel created by
    /// `WebSocketStream::enable_backpressure`.
    pub backpressure_buffer_size: usize,
    /// Whether dropped connections are re-established automatically.
    pub auto_reconnect: bool,
}
impl Default for WebSocketConfig {
    /// Same as [`WebSocketConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl WebSocketConfig {
    /// Creates a configuration populated with this module's defaults:
    /// auto-reconnect enabled and no limit on reconnection attempts.
    #[must_use]
    pub fn new() -> Self {
        Self {
            initial_reconnect_delay: DEFAULT_INITIAL_RECONNECT_DELAY,
            max_reconnect_attempts: DEFAULT_MAX_RECONNECT_ATTEMPTS,
            heartbeat_interval: HEARTBEAT_INTERVAL,
            backpressure_buffer_size: DEFAULT_BACKPRESSURE_BUFFER,
            auto_reconnect: true,
        }
    }

    /// Sets the backoff used before the first reconnection attempt.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_initial_reconnect_delay(mut self, delay: Duration) -> Self {
        self.initial_reconnect_delay = delay;
        self
    }

    /// Sets the reconnection attempt limit; `0` removes the limit.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_max_reconnect_attempts(mut self, attempts: u32) -> Self {
        self.max_reconnect_attempts = attempts;
        self
    }

    /// Sets how often the subscription list is re-sent as a heartbeat.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_heartbeat_interval(mut self, interval: Duration) -> Self {
        self.heartbeat_interval = interval;
        self
    }

    /// Sets the capacity of the backpressure channel.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_backpressure_buffer_size(mut self, size: usize) -> Self {
        self.backpressure_buffer_size = size;
        self
    }

    /// Enables or disables automatic reconnection.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_auto_reconnect(mut self, enabled: bool) -> Self {
        self.auto_reconnect = enabled;
        self
    }
}
/// Coarse connection status reported by `WebSocketStream::state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
/// A socket was established (initially or by a reconnect).
Connected,
/// A reconnection attempt is in progress.
Reconnecting,
/// The stream stopped trying to reconnect (e.g. the attempt limit was reached).
Disconnected,
}
/// Activity counters, snapshotted by `WebSocketStream::stats` and cleared by
/// `WebSocketStream::reset_stats`.
#[derive(Debug, Clone, Default)]
pub struct StreamStats {
/// Data frames received from the server, counted before decoding (frames
/// that later fail to decode are included).
pub messages_received: u64,
/// NOTE(review): nothing in this file increments this counter — it looks
/// reserved for backpressure overflow; confirm before relying on it.
pub messages_dropped: u64,
/// Reconnection attempts started; cumulative until `reset_stats`.
pub reconnect_attempts: u32,
/// Reconnection attempts that re-established the socket.
pub successful_reconnects: u32,
/// Heartbeat frames (subscription re-sends) actually written to the socket.
pub heartbeats_sent: u64,
}
/// Client for Yahoo Finance's streaming quote feed.
///
/// Owns the socket, the tracked subscriptions, the heartbeat timer and the
/// reconnection/backoff bookkeeping; tickers are pulled with `next()`.
pub struct WebSocketStream {
/// Live socket; set to `None` after a detected failure or server close so the
/// next `next()` call can reconnect.
ws_stream: Option<TungsteniteStream<MaybeTlsStream<tokio::net::TcpStream>>>,
/// Symbols re-sent on every heartbeat and after a reconnect.
subscribed_symbols: Vec<String>,
/// Timer driving the periodic heartbeat.
heartbeat: Interval,
config: WebSocketConfig,
/// Status reported by `state()`.
state: ConnectionState,
/// Reconnection attempts counted by this instance and compared against
/// `config.max_reconnect_attempts`.
reconnect_attempts: u32,
/// Backoff to wait before the next reconnection attempt.
current_reconnect_delay: Duration,
/// Callbacks run with every decoded ticker.
message_handlers: Vec<MessageHandler>,
/// Receiver installed by `enable_backpressure`, if any.
backpressure_buffer: Option<mpsc::Receiver<Yaticker>>,
/// Activity counters exposed through `stats()`.
stats: Arc<Mutex<StreamStats>>,
}
impl WebSocketStream {
pub async fn connect() -> Result<Self, WsError> {
Self::connect_with_config(WebSocketConfig::new()).await
}
pub async fn connect_with_config(config: WebSocketConfig) -> Result<Self, WsError> {
let (ws_stream, _) = connect_async(YAHOO_WS_URL).await?;
let heartbeat = interval(config.heartbeat_interval);
Ok(Self {
ws_stream: Some(ws_stream),
subscribed_symbols: Vec::new(),
heartbeat,
config: config.clone(),
state: ConnectionState::Connected,
reconnect_attempts: 0,
current_reconnect_delay: config.initial_reconnect_delay,
message_handlers: Vec::new(),
backpressure_buffer: None,
stats: Arc::new(Mutex::new(StreamStats::default())),
})
}
async fn reconnect(&mut self) -> Result<(), WsError> {
if !self.config.auto_reconnect {
return Err(WsError::AlreadyClosed);
}
if self.config.max_reconnect_attempts > 0
&& self.reconnect_attempts >= self.config.max_reconnect_attempts {
log::error!(
"Max reconnection attempts ({}) exceeded",
self.config.max_reconnect_attempts
);
self.state = ConnectionState::Disconnected;
return Err(WsError::AlreadyClosed);
}
self.state = ConnectionState::Reconnecting;
self.reconnect_attempts += 1;
{
let mut stats = self.stats.lock().await;
stats.reconnect_attempts += 1;
}
log::info!(
"Attempting reconnection #{} after {:?}",
self.reconnect_attempts,
self.current_reconnect_delay
);
sleep(self.current_reconnect_delay).await;
self.current_reconnect_delay = std::cmp::min(
self.current_reconnect_delay * 2,
MAX_RECONNECT_DELAY
);
match connect_async(YAHOO_WS_URL).await {
Ok((ws_stream, _)) => {
self.ws_stream = Some(ws_stream);
self.state = ConnectionState::Connected;
self.current_reconnect_delay = self.config.initial_reconnect_delay;
{
let mut stats = self.stats.lock().await;
stats.successful_reconnects += 1;
}
log::info!("Reconnection successful");
if !self.subscribed_symbols.is_empty() {
log::info!("Re-subscribing to {} symbols", self.subscribed_symbols.len());
self.resubscribe_all().await?;
}
Ok(())
}
Err(e) => {
log::error!("Reconnection failed: {}", e);
Err(e)
}
}
}
async fn resubscribe_all(&mut self) -> Result<(), WsError> {
if self.subscribed_symbols.is_empty() {
return Ok(());
}
let symbols = self.subscribed_symbols.clone();
let message = SubscriptionMessage {
subscribe: symbols,
};
let json = serde_json::to_string(&message)
.map_err(|e| WsError::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))?;
if let Some(ws) = &mut self.ws_stream {
ws.send(WsMessage::Text(json)).await?;
}
Ok(())
}
pub fn add_handler<F>(&mut self, handler: F)
where
F: Fn(Yaticker) -> Result<(), Box<dyn std::error::Error + Send + Sync>> + Send + Sync + 'static,
{
self.message_handlers.push(Arc::new(handler));
}
pub fn enable_backpressure(&mut self) {
let (tx, rx) = mpsc::channel(self.config.backpressure_buffer_size);
self.backpressure_buffer = Some(rx);
let stats = Arc::clone(&self.stats);
let handlers = self.message_handlers.clone();
tokio::spawn(async move {
});
}
pub fn state(&self) -> ConnectionState {
self.state
}
pub async fn stats(&self) -> StreamStats {
self.stats.lock().await.clone()
}
pub async fn reset_stats(&mut self) {
let mut stats = self.stats.lock().await;
*stats = StreamStats::default();
}
pub async fn subscribe(&mut self, symbols: &[&str]) -> Result<(), WsError> {
let symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
let message = SubscriptionMessage {
subscribe: symbols.clone(),
};
let json = serde_json::to_string(&message)
.map_err(|e| WsError::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))?;
if let Some(ws) = &mut self.ws_stream {
ws.send(WsMessage::Text(json)).await?;
self.subscribed_symbols.extend(symbols);
} else {
return Err(WsError::AlreadyClosed);
}
Ok(())
}
pub async fn unsubscribe(&mut self, symbols: &[&str]) -> Result<(), WsError> {
let symbols: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
let message = UnsubscriptionMessage {
unsubscribe: symbols.clone(),
};
let json = serde_json::to_string(&message)
.map_err(|e| WsError::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))?;
if let Some(ws) = &mut self.ws_stream {
ws.send(WsMessage::Text(json)).await?;
self.subscribed_symbols.retain(|s| !symbols.contains(s));
} else {
return Err(WsError::AlreadyClosed);
}
Ok(())
}
pub async fn next(&mut self) -> Option<Result<Yaticker, WsError>> {
loop {
if self.backpressure_buffer.is_some() {
return self.next_buffered().await.map(Ok);
}
if self.ws_stream.is_none() && self.config.auto_reconnect {
match self.reconnect().await {
Ok(()) => continue,
Err(e) => return Some(Err(e)),
}
}
tokio::select! {
_ = self.heartbeat.tick() => {
if let Err(e) = self.send_heartbeat().await {
if self.config.auto_reconnect {
log::warn!("Heartbeat failed, attempting reconnection");
self.ws_stream = None;
continue;
}
return Some(Err(e));
}
}
msg = async {
if let Some(ws) = &mut self.ws_stream {
ws.next().await
} else {
None
}
} => {
match msg {
Some(Ok(WsMessage::Binary(data))) => {
{
let mut stats = self.stats.lock().await;
stats.messages_received += 1;
}
let decoded = match BASE64.decode(&data) {
Ok(d) => d,
Err(e) => {
log::error!("Base64 decode error: {}", e);
continue;
}
};
let ticker = match Yaticker::decode(&decoded[..]) {
Ok(ticker) => ticker,
Err(e) => {
log::error!("Protobuf decode error: {}", e);
continue;
}
};
for handler in &self.message_handlers {
if let Err(e) = handler(ticker.clone()) {
log::error!("Message handler error: {}", e);
}
}
return Some(Ok(ticker));
}
Some(Ok(WsMessage::Text(_))) => {
continue;
}
Some(Ok(WsMessage::Ping(data))) => {
if let Some(ws) = &mut self.ws_stream {
if let Err(e) = ws.send(WsMessage::Pong(data)).await {
return Some(Err(e));
}
}
continue;
}
Some(Ok(WsMessage::Pong(_))) => {
continue;
}
Some(Ok(WsMessage::Close(_))) => {
log::info!("WebSocket closed by server");
if self.config.auto_reconnect {
self.ws_stream = None;
continue;
}
return None;
}
Some(Ok(WsMessage::Frame(_))) => {
continue;
}
Some(Err(e)) => {
log::error!("WebSocket error: {}", e);
if self.config.auto_reconnect {
self.ws_stream = None;
continue;
}
return Some(Err(e));
}
None => {
log::info!("WebSocket connection closed");
if self.config.auto_reconnect {
self.ws_stream = None;
continue;
}
return None;
}
}
}
}
}
}
pub async fn next_buffered(&mut self) -> Option<Yaticker> {
if let Some(rx) = &mut self.backpressure_buffer {
rx.recv().await
} else {
None
}
}
async fn send_heartbeat(&mut self) -> Result<(), WsError> {
if self.subscribed_symbols.is_empty() {
return Ok(());
}
let message = SubscriptionMessage {
subscribe: self.subscribed_symbols.clone(),
};
let json = serde_json::to_string(&message)
.map_err(|e| WsError::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))?;
if let Some(ws) = &mut self.ws_stream {
ws.send(WsMessage::Text(json)).await?;
{
let mut stats = self.stats.lock().await;
stats.heartbeats_sent += 1;
}
}
Ok(())
}
pub fn subscribed_symbols(&self) -> &[String] {
&self.subscribed_symbols
}
pub async fn close(mut self) -> Result<(), WsError> {
if let Some(mut ws) = self.ws_stream.take() {
ws.close(None).await
} else {
Ok(())
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The exact wire format matters to the server, so pin the whole string
    /// rather than substrings.
    #[test]
    fn test_subscription_message_serialization() {
        let msg = SubscriptionMessage {
            subscribe: vec!["AAPL".to_string(), "GOOGL".to_string()],
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert_eq!(json, r#"{"subscribe":["AAPL","GOOGL"]}"#);
    }

    #[test]
    fn test_unsubscription_message_serialization() {
        let msg = UnsubscriptionMessage {
            unsubscribe: vec!["AAPL".to_string()],
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert_eq!(json, r#"{"unsubscribe":["AAPL"]}"#);
    }

    #[test]
    fn test_subscription_message_deserialization() {
        let msg: SubscriptionMessage = serde_json::from_str(r#"{"subscribe":["MSFT"]}"#).unwrap();
        assert_eq!(msg.subscribe, vec!["MSFT".to_string()]);
    }

    /// Exercises every builder method, not just a subset.
    #[test]
    fn test_config_builder() {
        let config = WebSocketConfig::new()
            .with_initial_reconnect_delay(Duration::from_millis(250))
            .with_max_reconnect_attempts(5)
            .with_heartbeat_interval(Duration::from_secs(30))
            .with_backpressure_buffer_size(500)
            .with_auto_reconnect(false);
        assert_eq!(config.initial_reconnect_delay, Duration::from_millis(250));
        assert_eq!(config.max_reconnect_attempts, 5);
        assert_eq!(config.heartbeat_interval, Duration::from_secs(30));
        assert_eq!(config.backpressure_buffer_size, 500);
        assert!(!config.auto_reconnect);
    }

    #[test]
    fn test_config_defaults() {
        let config = WebSocketConfig::new();
        assert_eq!(config.initial_reconnect_delay, DEFAULT_INITIAL_RECONNECT_DELAY);
        assert_eq!(config.max_reconnect_attempts, DEFAULT_MAX_RECONNECT_ATTEMPTS);
        assert_eq!(config.heartbeat_interval, HEARTBEAT_INTERVAL);
        assert_eq!(config.backpressure_buffer_size, DEFAULT_BACKPRESSURE_BUFFER);
        assert!(config.auto_reconnect);
    }

    #[test]
    fn test_default_matches_new() {
        let from_default = WebSocketConfig::default();
        let from_new = WebSocketConfig::new();
        assert_eq!(from_default.initial_reconnect_delay, from_new.initial_reconnect_delay);
        assert_eq!(from_default.max_reconnect_attempts, from_new.max_reconnect_attempts);
        assert_eq!(from_default.heartbeat_interval, from_new.heartbeat_interval);
        assert_eq!(from_default.backpressure_buffer_size, from_new.backpressure_buffer_size);
        assert_eq!(from_default.auto_reconnect, from_new.auto_reconnect);
    }

    #[test]
    fn test_connection_state() {
        assert_eq!(ConnectionState::Connected, ConnectionState::Connected);
        assert_ne!(ConnectionState::Connected, ConnectionState::Disconnected);
        assert_ne!(ConnectionState::Reconnecting, ConnectionState::Disconnected);
    }

    #[test]
    fn test_stream_stats_default() {
        let stats = StreamStats::default();
        assert_eq!(stats.messages_received, 0);
        assert_eq!(stats.messages_dropped, 0);
        assert_eq!(stats.reconnect_attempts, 0);
        assert_eq!(stats.successful_reconnects, 0);
        assert_eq!(stats.heartbeats_sent, 0);
    }
}