use std::pin::Pin;
use std::task::{Context, Poll};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use tokio_tungstenite::tungstenite::Message;
use crate::error::{FinanceError, Result};
/// A Polygon WebSocket cluster: selects the endpoint path a [`PolygonStream`] connects to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Cluster {
    Stocks,
    Options,
    Forex,
    Crypto,
    Futures,
    Indices,
}

impl Cluster {
    /// Lowercase URL path segment for this cluster, e.g. `"crypto"` for [`Cluster::Crypto`].
    fn as_str(&self) -> &'static str {
        match *self {
            Cluster::Stocks => "stocks",
            Cluster::Options => "options",
            Cluster::Forex => "forex",
            Cluster::Crypto => "crypto",
            Cluster::Futures => "futures",
            Cluster::Indices => "indices",
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct StreamTrade {
pub ev: Option<String>,
pub sym: Option<String>,
pub p: Option<f64>,
pub s: Option<f64>,
pub x: Option<i32>,
pub c: Option<Vec<i32>>,
pub t: Option<i64>,
}
/// A quote event (`ev` of `Q` or `XQ`).
///
/// Field names mirror Polygon's abbreviated wire keys. Every field is optional; keys missing
/// from the event deserialize as `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct StreamQuote {
    /// Event type tag (`ev`).
    pub ev: Option<String>,
    /// Symbol (`sym`). Polygon's crypto feed sends this key as `pair`, so `pair` is accepted
    /// as an alias on deserialization; serialization always writes `sym`.
    #[serde(alias = "pair")]
    pub sym: Option<String>,
    /// Bid price (`bp`).
    pub bp: Option<f64>,
    /// Bid size (`bs`).
    pub bs: Option<f64>,
    /// Ask price (`ap`).
    pub ap: Option<f64>,
    /// Ask size. The wire key is `as`, which is a Rust keyword, hence the longer field name.
    #[serde(rename = "as")]
    pub ask_size: Option<f64>,
    /// Bid exchange ID (`bx`).
    pub bx: Option<i32>,
    /// Ask exchange ID (`ax`).
    pub ax: Option<i32>,
    /// Condition codes (`c`).
    pub c: Option<Vec<i32>>,
    /// Timestamp (`t`); Polygon documents this as Unix epoch milliseconds.
    pub t: Option<i64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct StreamAggregate {
pub ev: Option<String>,
pub sym: Option<String>,
pub o: Option<f64>,
pub h: Option<f64>,
pub l: Option<f64>,
pub c: Option<f64>,
pub v: Option<f64>,
pub vw: Option<f64>,
pub s: Option<i64>,
pub e: Option<i64>,
pub z: Option<u64>,
}
/// A message decoded from a text frame of the Polygon WebSocket feed.
///
/// The variant is chosen by the event's `ev` tag; payloads that cannot be classified are kept
/// verbatim in [`PolygonMessage::Unknown`].
#[derive(Debug, Clone)]
pub enum PolygonMessage {
    /// A trade event (`ev` of `T` or `XT`).
    Trade(StreamTrade),
    /// A quote event (`ev` of `Q` or `XQ`).
    Quote(StreamQuote),
    /// An aggregate event (`ev` of `A`, `AM`, `XA` or `XAM`).
    Aggregate(StreamAggregate),
    /// A `status` event (e.g. `auth_success`), kept as raw JSON.
    Status(serde_json::Value),
    /// The raw frame text when it was not a JSON array or no event in it was recognised.
    Unknown(String),
}
pub struct PolygonStreamBuilder {
api_key: String,
cluster: Cluster,
subscriptions: Vec<String>,
}
impl PolygonStreamBuilder {
pub fn cluster(mut self, cluster: Cluster) -> Self {
self.cluster = cluster;
self
}
pub fn subscribe(mut self, channels: &[&str]) -> Self {
self.subscriptions
.extend(channels.iter().map(|s| s.to_string()));
self
}
pub async fn build(self) -> Result<PolygonStream> {
let url = format!("wss://socket.polygon.io/{}", self.cluster.as_str());
let (ws_stream, _) = tokio_tungstenite::connect_async(&url)
.await
.map_err(|e| FinanceError::ApiError(format!("Polygon WebSocket connect error: {e}")))?;
let (write, read) = futures::StreamExt::split(ws_stream);
let write = std::sync::Arc::new(tokio::sync::Mutex::new(write));
{
use futures::SinkExt;
let auth_msg = serde_json::json!({
"action": "auth",
"params": self.api_key
});
write
.lock()
.await
.send(Message::Text(auth_msg.to_string().into()))
.await
.map_err(|e| {
FinanceError::ApiError(format!("Polygon WebSocket auth error: {e}"))
})?;
}
if !self.subscriptions.is_empty() {
use futures::SinkExt;
let sub_msg = serde_json::json!({
"action": "subscribe",
"params": self.subscriptions.join(",")
});
write
.lock()
.await
.send(Message::Text(sub_msg.to_string().into()))
.await
.map_err(|e| {
FinanceError::ApiError(format!("Polygon WebSocket subscribe error: {e}"))
})?;
}
Ok(PolygonStream {
read: Box::pin(read),
_write: write,
})
}
}
pub struct PolygonStream {
read: Pin<
Box<
dyn Stream<Item = std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>
+ Send,
>,
>,
_write: std::sync::Arc<
tokio::sync::Mutex<
futures::stream::SplitSink<
tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
Message,
>,
>,
>,
}
impl PolygonStream {
pub fn from_singleton() -> Result<PolygonStreamBuilder> {
Ok(PolygonStreamBuilder {
api_key: super::api_key()?,
cluster: Cluster::Stocks,
subscriptions: Vec::new(),
})
}
}
impl Stream for PolygonStream {
type Item = PolygonMessage;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
match self.read.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(Message::Text(text)))) => {
return Poll::Ready(Some(parse_message(&text)));
}
Poll::Ready(Some(Ok(Message::Close(_)))) | Poll::Ready(None) => {
return Poll::Ready(None);
}
Poll::Ready(Some(Ok(_))) => continue, Poll::Ready(Some(Err(_))) => return Poll::Ready(None),
Poll::Pending => return Poll::Pending,
}
}
}
}
fn parse_message(text: &str) -> PolygonMessage {
let events: Vec<serde_json::Value> = match serde_json::from_str(text) {
Ok(v) => v,
Err(_) => return PolygonMessage::Unknown(text.to_string()),
};
for event in events {
let ev = event.get("ev").and_then(|v| v.as_str()).unwrap_or("");
match ev {
"T" | "XT" => {
if let Ok(trade) = serde_json::from_value(event) {
return PolygonMessage::Trade(trade);
}
}
"Q" | "XQ" => {
if let Ok(quote) = serde_json::from_value(event) {
return PolygonMessage::Quote(quote);
}
}
"A" | "AM" | "XA" | "XAM" => {
if let Ok(agg) = serde_json::from_value(event) {
return PolygonMessage::Aggregate(agg);
}
}
"status" => {
return PolygonMessage::Status(event);
}
_ => {}
}
}
PolygonMessage::Unknown(text.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Prices decoded from JSON are compared to within this tolerance.
    const PRICE_TOLERANCE: f64 = 0.01;

    #[test]
    fn test_parse_trade_message() {
        let frame =
            r#"[{"ev":"T","sym":"AAPL","p":186.19,"s":100,"x":4,"c":[12,37],"t":1705363200000}]"#;
        let trade = match parse_message(frame) {
            PolygonMessage::Trade(trade) => trade,
            other => panic!("Expected Trade, got {:?}", other),
        };
        assert_eq!(trade.sym.as_deref(), Some("AAPL"));
        assert!((trade.p.unwrap() - 186.19).abs() < PRICE_TOLERANCE);
        assert_eq!(trade.s.unwrap() as u64, 100);
    }

    #[test]
    fn test_parse_quote_message() {
        let frame = r#"[{"ev":"Q","sym":"AAPL","bp":186.18,"bs":2,"ap":186.25,"as":3,"bx":19,"ax":11,"t":1705363200000}]"#;
        let quote = match parse_message(frame) {
            PolygonMessage::Quote(quote) => quote,
            other => panic!("Expected Quote, got {:?}", other),
        };
        assert_eq!(quote.sym.as_deref(), Some("AAPL"));
        assert!((quote.bp.unwrap() - 186.18).abs() < PRICE_TOLERANCE);
        assert!((quote.ap.unwrap() - 186.25).abs() < PRICE_TOLERANCE);
    }

    #[test]
    fn test_parse_aggregate_message() {
        let frame = r#"[{"ev":"AM","sym":"AAPL","o":186.0,"h":186.25,"l":185.90,"c":186.19,"v":1500000,"vw":186.05,"s":1705363200000,"e":1705363260000,"z":823}]"#;
        let bar = match parse_message(frame) {
            PolygonMessage::Aggregate(bar) => bar,
            other => panic!("Expected Aggregate, got {:?}", other),
        };
        assert_eq!(bar.sym.as_deref(), Some("AAPL"));
        assert!((bar.c.unwrap() - 186.19).abs() < PRICE_TOLERANCE);
        assert_eq!(bar.ev.as_deref(), Some("AM"));
    }

    #[test]
    fn test_parse_status_message() {
        let frame = r#"[{"ev":"status","status":"auth_success","message":"authenticated"}]"#;
        let status = match parse_message(frame) {
            PolygonMessage::Status(status) => status,
            other => panic!("Expected Status, got {:?}", other),
        };
        assert_eq!(
            status.get("status").and_then(|s| s.as_str()),
            Some("auth_success")
        );
    }

    #[test]
    fn test_parse_unknown_message() {
        let parsed = parse_message("not json at all");
        assert!(matches!(parsed, PolygonMessage::Unknown(_)));
    }

    #[test]
    fn test_cluster_as_str() {
        let cases = [
            (Cluster::Stocks, "stocks"),
            (Cluster::Options, "options"),
            (Cluster::Crypto, "crypto"),
            (Cluster::Futures, "futures"),
            (Cluster::Indices, "indices"),
        ];
        for (cluster, expected) in cases.iter().copied() {
            assert_eq!(cluster.as_str(), expected);
        }
    }
}