#![allow(clippy::unwrap_used, clippy::expect_used)]
use futures::{SinkExt, StreamExt};
use rustrade_integration::{
Transformer,
error::SocketError,
protocol::websocket::{WebSocket, WebSocketSerdeParser, WsMessage},
stream::ExchangeStream,
};
use serde::{Deserialize, de};
use serde_json::json;
use std::{collections::VecDeque, str::FromStr};
use tokio_tungstenite::connect_async;
use tracing::debug;
/// Shorthand for an [`ExchangeStream`] fixed to the [`WebSocketSerdeParser`] and
/// [`WebSocket`] type arguments, leaving only the last type parameter (named
/// `Exchange` here) open. In `main` that parameter is inferred from the
/// `StatefulTransformer` passed to `new`.
type ExchangeWsStream<Exchange> = ExchangeStream<WebSocketSerdeParser, WebSocket, Exchange>;
/// Running total of trade quantities, as accumulated and emitted by
/// `StatefulTransformer`.
type VolumeSum = f64;
/// The two message shapes this example deserializes from the socket.
///
/// The enum is `untagged`, so serde tries the variants in declaration order and
/// keeps the first one whose fields match the incoming payload.
// NOTE(review): `rename_all` on an enum container applies to variant names, which
// an untagged enum never reads from the input, so it appears to have no effect
// here — confirm before removing.
#[derive(Deserialize)]
#[serde(untagged, rename_all = "camelCase")]
enum BinanceMessage {
/// A payload carrying `result` and `id` fields — presumably the acknowledgement
/// of the `SUBSCRIBE` request sent in `main` (TODO confirm against the exchange
/// documentation).
SubResponse {
/// Optional list of strings taken from the `result` field.
result: Option<Vec<String>>,
/// Numeric `id` field of the payload.
id: u32,
},
/// A payload carrying a `q` field; every other field is ignored.
Trade {
/// Value of the `q` field, which arrives as a string and is parsed into an
/// `f64` by `de_str`.
#[serde(rename = "q", deserialize_with = "de_str")]
quantity: f64,
},
}
/// Transformer that keeps state between messages: the sum of every trade
/// quantity it has been given so far.
struct StatefulTransformer {
/// Cumulative quantity; incremented by `transform` for each `Trade` message.
sum_of_volume: VolumeSum,
}
/// Folds incoming [`BinanceMessage`]s into a running volume total.
impl Transformer for StatefulTransformer {
    type Error = SocketError;
    type Input = BinanceMessage;
    type Output = VolumeSum;
    type OutputIter = Vec<Result<Self::Output, Self::Error>>;

    /// Applies one message to the internal state and returns the current total.
    ///
    /// A `Trade` adds its quantity to the running sum; a `SubResponse` is only
    /// logged at debug level. In both cases exactly one `Ok` item holding the
    /// up-to-date sum is returned, so a subscription response re-emits the
    /// unchanged total.
    fn transform(&mut self, input: Self::Input) -> Self::OutputIter {
        match input {
            BinanceMessage::Trade { quantity } => self.sum_of_volume += quantity,
            BinanceMessage::SubResponse { result, id } => {
                debug!("Received SubResponse for {}: {:?}", id, result);
            }
        }

        let running_total = self.sum_of_volume;
        vec![Ok(running_total)]
    }
}
/// Connects to the Binance futures WebSocket endpoint, subscribes to the
/// `btcusdt@aggTrade` stream and prints the cumulative volume produced by
/// [`StatefulTransformer`] for every item the stream yields.
///
/// # Panics
///
/// Panics if the connection cannot be established or the subscription request
/// cannot be sent.
#[tokio::main]
async fn main() {
    // Only the socket half of the connection result is needed.
    let (mut binance_conn, _) = connect_async("wss://fstream.binance.com/ws")
        .await
        .expect("failed to connect");

    // The subscription has to be sent before the socket is handed to the stream.
    let subscribe_request = json!({"method": "SUBSCRIBE","params": ["btcusdt@aggTrade"],"id": 1});
    binance_conn
        .send(WsMessage::text(subscribe_request.to_string()))
        .await
        .expect("failed to send WsMessage over socket");

    let mut ws_stream = ExchangeWsStream::new(
        binance_conn,
        StatefulTransformer { sum_of_volume: 0.0 },
        VecDeque::new(),
    );

    // Runs until the stream yields `None`; errors are reported on stderr and do
    // not stop the loop.
    while let Some(update) = ws_stream.next().await {
        match update {
            Err(error) => eprintln!("{error}"),
            Ok(cumulative_volume) => println!("{cumulative_volume:?}"),
        }
    }
}
fn de_str<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: de::Deserializer<'de>,
T: FromStr,
T::Err: std::fmt::Display,
{
let data: String = Deserialize::deserialize(deserializer)?;
data.parse::<T>().map_err(de::Error::custom)
}