use std::sync::Arc;
use futures_util::{StreamExt, SinkExt};
use serde_json::{json, Value};
use tokio::sync::{broadcast, Mutex};
use tokio_tungstenite::{connect_async, tungstenite::Message, WebSocketStream, MaybeTlsStream};
use crate::core::{
Credentials,
ExchangeError, ExchangeResult,
ConnectionStatus, StreamEvent,
};
use super::endpoints::FinnhubUrls;
use super::auth::FinnhubAuth;
use super::parser::FinnhubParser;
/// Concrete socket type produced by `connect_async` for this client.
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;

/// One Finnhub WebSocket feed for a single symbol.
///
/// The variant selects the `type` value placed in the subscribe/unsubscribe
/// control messages; the wrapped `String` is the symbol as given by the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FinnhubChannel {
    /// Trade feed; controlled with `subscribe` / `unsubscribe`.
    Trades(String),
    /// News feed; controlled with `subscribe-news` / `unsubscribe-news`.
    News(String),
    /// Press-release feed; controlled with `subscribe-pr` / `unsubscribe-pr`.
    PressReleases(String),
}

impl FinnhubChannel {
    /// Message `type` that starts delivery of this channel.
    pub fn subscribe_type(&self) -> &'static str {
        match *self {
            Self::Trades(_) => "subscribe",
            Self::News(_) => "subscribe-news",
            Self::PressReleases(_) => "subscribe-pr",
        }
    }

    /// Message `type` that stops delivery of this channel.
    pub fn unsubscribe_type(&self) -> &'static str {
        match *self {
            Self::Trades(_) => "unsubscribe",
            Self::News(_) => "unsubscribe-news",
            Self::PressReleases(_) => "unsubscribe-pr",
        }
    }

    /// The symbol this channel refers to, exactly as it was stored (no case change).
    pub fn symbol(&self) -> &str {
        let sym = match self {
            Self::Trades(sym) | Self::News(sym) | Self::PressReleases(sym) => sym,
        };
        sym.as_str()
    }

    /// JSON control message that subscribes to this channel.
    pub fn subscribe_message(&self) -> serde_json::Value {
        self.control_message(self.subscribe_type())
    }

    /// JSON control message that unsubscribes from this channel.
    pub fn unsubscribe_message(&self) -> serde_json::Value {
        self.control_message(self.unsubscribe_type())
    }

    /// Builds `{"type": kind, "symbol": <symbol upper-cased>}`.
    fn control_message(&self, kind: &str) -> serde_json::Value {
        let symbol = self.symbol().to_uppercase();
        serde_json::json!({
            "type": kind,
            "symbol": symbol
        })
    }
}
pub struct FinnhubWebSocket {
auth: FinnhubAuth,
urls: FinnhubUrls,
ws_stream: Arc<Mutex<Option<WsStream>>>,
event_tx: broadcast::Sender<StreamEvent>,
status: Arc<Mutex<ConnectionStatus>>,
}
impl FinnhubWebSocket {
pub async fn new(credentials: Credentials) -> ExchangeResult<Self> {
let auth = FinnhubAuth::new(&credentials)?;
let urls = FinnhubUrls::MAINNET;
let (event_tx, _) = broadcast::channel(1000);
Ok(Self {
auth,
urls,
ws_stream: Arc::new(Mutex::new(None)),
event_tx,
status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
})
}
pub async fn connect(&self) -> ExchangeResult<()> {
let url = self.auth.ws_url_with_auth(self.urls.websocket_url());
let (ws_stream, _) = connect_async(&url)
.await
.map_err(|e| ExchangeError::Network(format!("WebSocket connection failed: {}", e)))?;
*self.ws_stream.lock().await = Some(ws_stream);
*self.status.lock().await = ConnectionStatus::Connected;
Ok(())
}
pub async fn subscribe_ticker(&self, symbol: &str) -> ExchangeResult<()> {
let sub_msg = json!({
"type": "subscribe",
"symbol": symbol.to_uppercase()
});
if let Some(ref mut ws) = *self.ws_stream.lock().await {
ws.send(Message::Text(sub_msg.to_string()))
.await
.map_err(|e| ExchangeError::Network(format!("Subscribe failed: {}", e)))?;
}
Ok(())
}
pub async fn subscribe_news(&self, symbol: &str) -> ExchangeResult<()> {
let sub_msg = json!({
"type": "subscribe-news",
"symbol": symbol.to_uppercase()
});
if let Some(ref mut ws) = *self.ws_stream.lock().await {
ws.send(Message::Text(sub_msg.to_string()))
.await
.map_err(|e| ExchangeError::Network(format!("Subscribe news failed: {}", e)))?;
}
Ok(())
}
pub async fn subscribe_press_releases(&self, symbol: &str) -> ExchangeResult<()> {
let sub_msg = json!({
"type": "subscribe-pr",
"symbol": symbol.to_uppercase()
});
if let Some(ref mut ws) = *self.ws_stream.lock().await {
ws.send(Message::Text(sub_msg.to_string()))
.await
.map_err(|e| ExchangeError::Network(format!("Subscribe press releases failed: {}", e)))?;
}
Ok(())
}
pub async fn unsubscribe_ticker(&self, symbol: &str) -> ExchangeResult<()> {
let unsub_msg = json!({
"type": "unsubscribe",
"symbol": symbol.to_uppercase()
});
if let Some(ref mut ws) = *self.ws_stream.lock().await {
ws.send(Message::Text(unsub_msg.to_string()))
.await
.map_err(|e| ExchangeError::Network(format!("Unsubscribe failed: {}", e)))?;
}
Ok(())
}
pub async fn unsubscribe_news(&self, symbol: &str) -> ExchangeResult<()> {
let unsub_msg = json!({
"type": "unsubscribe-news",
"symbol": symbol.to_uppercase()
});
if let Some(ref mut ws) = *self.ws_stream.lock().await {
ws.send(Message::Text(unsub_msg.to_string()))
.await
.map_err(|e| ExchangeError::Network(format!("Unsubscribe news failed: {}", e)))?;
}
Ok(())
}
pub async fn subscribe_channel(&self, channel: &FinnhubChannel) -> ExchangeResult<()> {
let msg = channel.subscribe_message();
if let Some(ref mut ws) = *self.ws_stream.lock().await {
ws.send(Message::Text(msg.to_string()))
.await
.map_err(|e| ExchangeError::Network(format!("Subscribe channel failed: {}", e)))?;
}
Ok(())
}
pub async fn unsubscribe_channel(&self, channel: &FinnhubChannel) -> ExchangeResult<()> {
let msg = channel.unsubscribe_message();
if let Some(ref mut ws) = *self.ws_stream.lock().await {
ws.send(Message::Text(msg.to_string()))
.await
.map_err(|e| {
ExchangeError::Network(format!("Unsubscribe channel failed: {}", e))
})?;
}
Ok(())
}
pub async fn subscribe_channels(&self, channels: &[FinnhubChannel]) -> ExchangeResult<()> {
for channel in channels {
self.subscribe_channel(channel).await?;
}
Ok(())
}
pub async fn unsubscribe_channels(&self, channels: &[FinnhubChannel]) -> ExchangeResult<()> {
for channel in channels {
self.unsubscribe_channel(channel).await?;
}
Ok(())
}
pub async fn disconnect(&self) -> ExchangeResult<()> {
if let Some(mut ws) = self.ws_stream.lock().await.take() {
ws.close(None)
.await
.map_err(|e| ExchangeError::Network(format!("Disconnect failed: {}", e)))?;
}
*self.status.lock().await = ConnectionStatus::Disconnected;
Ok(())
}
pub async fn status(&self) -> ConnectionStatus {
*self.status.lock().await
}
pub fn event_stream(&self) -> broadcast::Receiver<StreamEvent> {
self.event_tx.subscribe()
}
pub async fn start_receiving(&self) -> ExchangeResult<()> {
loop {
let msg = {
let mut ws_guard = self.ws_stream.lock().await;
if let Some(ref mut ws) = *ws_guard {
match ws.next().await {
Some(Ok(msg)) => msg,
Some(Err(e)) => {
*self.status.lock().await = ConnectionStatus::Disconnected;
return Err(ExchangeError::Network(format!("WebSocket error: {}", e)));
}
None => {
*self.status.lock().await = ConnectionStatus::Disconnected;
return Err(ExchangeError::Network("WebSocket closed".to_string()));
}
}
} else {
return Err(ExchangeError::Network("WebSocket not connected".to_string()));
}
};
match msg {
Message::Text(text) => {
match serde_json::from_str::<Value>(&text) {
Ok(json_msg) => {
match FinnhubParser::parse_ws_message(&json_msg) {
Ok(events) => {
for event in events {
let _ = self.event_tx.send(event);
}
}
Err(e) => {
eprintln!("Failed to parse WebSocket message: {}", e);
}
}
}
Err(e) => {
eprintln!("Failed to parse JSON: {}", e);
}
}
}
Message::Ping(data) => {
let mut ws_guard = self.ws_stream.lock().await;
if let Some(ref mut ws) = *ws_guard {
let _ = ws.send(Message::Pong(data)).await;
}
}
Message::Close(_) => {
*self.status.lock().await = ConnectionStatus::Disconnected;
return Ok(());
}
_ => {}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client with dummy credentials, panicking if construction fails.
    async fn test_client() -> FinnhubWebSocket {
        match FinnhubWebSocket::new(Credentials::new("test_api_key", "")).await {
            Ok(ws) => ws,
            Err(_) => panic!("FinnhubWebSocket::new failed with dummy credentials"),
        }
    }

    #[tokio::test]
    async fn test_websocket_creation() {
        let credentials = Credentials::new("test_api_key", "");
        let ws = FinnhubWebSocket::new(credentials).await;
        assert!(ws.is_ok());
    }

    #[tokio::test]
    async fn test_new_client_is_disconnected() {
        let ws = test_client().await;
        assert!(matches!(ws.status().await, ConnectionStatus::Disconnected));
    }

    #[tokio::test]
    async fn test_empty_channel_batches_are_noops() {
        let ws = test_client().await;
        assert!(ws.subscribe_channels(&[]).await.is_ok());
        assert!(ws.unsubscribe_channels(&[]).await.is_ok());
    }

    // These exercise the real message builders rather than a hand-written
    // `json!` literal, so a change to the wire format is caught.
    #[test]
    fn test_subscribe_message_format() {
        let msg = FinnhubChannel::Trades("AAPL".into()).subscribe_message();
        assert_eq!(msg, json!({ "type": "subscribe", "symbol": "AAPL" }));
    }

    #[test]
    fn test_unsubscribe_message_format() {
        let msg = FinnhubChannel::Trades("AAPL".into()).unsubscribe_message();
        assert_eq!(msg, json!({ "type": "unsubscribe", "symbol": "AAPL" }));
    }

    #[test]
    fn test_messages_uppercase_symbol() {
        let ch = FinnhubChannel::News("tsla".into());
        assert_eq!(ch.subscribe_message()["symbol"], "TSLA");
        assert_eq!(ch.unsubscribe_message()["symbol"], "TSLA");
        // The stored symbol itself is returned unchanged.
        assert_eq!(ch.symbol(), "tsla");
    }

    #[test]
    fn test_channel_type_strings() {
        let cases = [
            (FinnhubChannel::Trades("X".into()), "subscribe", "unsubscribe"),
            (FinnhubChannel::News("X".into()), "subscribe-news", "unsubscribe-news"),
            (FinnhubChannel::PressReleases("X".into()), "subscribe-pr", "unsubscribe-pr"),
        ];
        for (ch, sub, unsub) in cases.iter() {
            assert_eq!(ch.subscribe_type(), *sub);
            assert_eq!(ch.unsubscribe_type(), *unsub);
            assert_eq!(ch.subscribe_message()["type"], *sub);
            assert_eq!(ch.unsubscribe_message()["type"], *unsub);
        }
    }

    #[test]
    fn test_channel_trades_subscribe() {
        let ch = FinnhubChannel::Trades("AAPL".into());
        let msg = ch.subscribe_message();
        assert_eq!(msg["type"], "subscribe");
        assert_eq!(msg["symbol"], "AAPL");
    }

    #[test]
    fn test_channel_trades_unsubscribe() {
        let ch = FinnhubChannel::Trades("AAPL".into());
        let msg = ch.unsubscribe_message();
        assert_eq!(msg["type"], "unsubscribe");
        assert_eq!(msg["symbol"], "AAPL");
    }

    #[test]
    fn test_channel_news_subscribe() {
        let ch = FinnhubChannel::News("TSLA".into());
        let msg = ch.subscribe_message();
        assert_eq!(msg["type"], "subscribe-news");
        assert_eq!(msg["symbol"], "TSLA");
    }

    #[test]
    fn test_channel_press_releases_subscribe() {
        let ch = FinnhubChannel::PressReleases("MSFT".into());
        let msg = ch.subscribe_message();
        assert_eq!(msg["type"], "subscribe-pr");
        assert_eq!(msg["symbol"], "MSFT");
    }

    #[test]
    fn test_channel_symbol() {
        assert_eq!(FinnhubChannel::Trades("AAPL".into()).symbol(), "AAPL");
        assert_eq!(FinnhubChannel::News("TSLA".into()).symbol(), "TSLA");
        assert_eq!(FinnhubChannel::PressReleases("MSFT".into()).symbol(), "MSFT");
    }
}