use crate::types::{
Bar, StreamData, StreamMessage, Subscription, Subscriptions
};
use crate::STREAMING_URL;
use anyhow::Result;
use futures_util::StreamExt;
use std::time::{Duration, UNIX_EPOCH};
use tokio::sync::mpsc;
/// Upper bound on the number of bytes held back while waiting for a line
/// terminator. Guards against unbounded growth if the peer never sends `\n`.
const MAX_PENDING_LINE_BYTES: usize = 1024 * 1024;

/// Removes every newline-terminated line from the front of `pending` and
/// returns the non-empty ones, trimmed and decoded (lossily) as UTF-8.
///
/// Bytes after the last `\n` stay in `pending` so they can be completed by a
/// later chunk. Splitting on the `\n` byte before decoding is safe because
/// that byte never occurs inside a multi-byte UTF-8 sequence.
fn take_complete_lines(pending: &mut Vec<u8>) -> Vec<String> {
    let last_newline = match pending.iter().rposition(|&byte| byte == b'\n') {
        Some(index) => index,
        None => return Vec::new(),
    };
    let complete: Vec<u8> = pending.drain(..=last_newline).collect();
    String::from_utf8_lossy(&complete)
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .map(str::to_owned)
        .collect()
}

/// Connects to `STREAMING_URL` and feeds every received line to
/// `process_stream_line`, reconnecting whenever the request fails or the
/// stream ends.
///
/// The body is consumed chunk by chunk. Chunk boundaries are not guaranteed
/// to coincide with line boundaries, so incoming bytes are accumulated and
/// only newline-terminated lines are dispatched; whatever is left when the
/// stream stops is dispatched as a final line.
/// NOTE(review): this assumes the feed terminates each message with `\n` —
/// confirm against the server.
///
/// Reconnection: each attempt sleeps for `delay` and then doubles it, starting
/// at 3 s with 3 retries. A successful connection restores both values. When
/// the retry counter reaches zero it is reset to 3 and the delay is set to
/// 30 s; that iteration itself does not sleep.
///
/// The loop has no exit path, so this function does not return in practice.
pub async fn start_streaming(subscriptions: Subscriptions) -> Result<()> {
    let client = reqwest::Client::new();
    let mut retries = 3;
    let mut delay = Duration::from_secs(3);
    loop {
        let resp = client.get(STREAMING_URL).send().await;
        if let Ok(response) = resp {
            retries = 3;
            delay = Duration::from_secs(3);
            println!("[stream] Connected to streaming endpoint");
            let mut stream = response.bytes_stream();
            // Bytes received so far that do not yet form a complete line.
            let mut pending: Vec<u8> = Vec::new();
            while let Some(result) = stream.next().await {
                match result {
                    Ok(chunk) => {
                        pending.extend_from_slice(&chunk);
                        for line in take_complete_lines(&mut pending) {
                            process_stream_line(&line, &subscriptions);
                        }
                        if pending.len() > MAX_PENDING_LINE_BYTES {
                            eprintln!(
                                "[stream] Discarding {} buffered bytes without a line terminator",
                                pending.len()
                            );
                            pending.clear();
                        }
                    }
                    Err(e) => {
                        eprintln!("[stream] Error reading from stream: {:?}", e);
                        break;
                    }
                }
            }
            // A final message may arrive without a trailing newline.
            let tail = String::from_utf8_lossy(&pending);
            let tail = tail.trim();
            if !tail.is_empty() {
                process_stream_line(tail, &subscriptions);
            }
            println!("[stream] Stream ended, will attempt to reconnect");
        } else {
            eprintln!(
                "[stream] Error fetching from the streaming endpoint: {:?}",
                resp.err()
            );
        }
        if retries > 0 {
            eprintln!("[stream] Attempting to reconnect in {:?}...", delay);
            tokio::time::sleep(delay).await;
            retries -= 1;
            delay *= 2;
        } else {
            eprintln!("[stream] Maximum reconnection attempts reached. Resetting retry counter");
            retries = 3;
            delay = Duration::from_secs(30);
        }
    }
}
/// Decodes one line of the stream as a `StreamMessage` and dispatches it.
///
/// * `Data` messages are forwarded to `handle_streaming_data`.
/// * `HeartbeatOrId` values are logged only when divisible by 100.
/// * `Other` messages are logged as-is.
///
/// When the line is not valid JSON for `StreamMessage`:
/// * a literal `{}` is ignored silently;
/// * a line starting with `data: ` is processed again without that prefix
///   (nothing happens if the remainder is empty);
/// * anything else is reported on stderr together with the raw line.
fn process_stream_line(line: &str, subscriptions: &Subscriptions) {
    let message = match serde_json::from_str::<StreamMessage>(line) {
        Ok(message) => message,
        Err(parse_error) => {
            if line == "{}" {
                return;
            }
            match line.strip_prefix("data: ") {
                Some(payload) => {
                    if !payload.is_empty() {
                        process_stream_line(payload, subscriptions);
                    }
                }
                None => {
                    eprintln!("[stream] Error parsing JSON: {:?}", parse_error);
                    eprintln!("[stream] Raw data: {}", line);
                }
            }
            return;
        }
    };

    match message {
        StreamMessage::Data(payload) => handle_streaming_data(payload, subscriptions),
        StreamMessage::HeartbeatOrId(id) => {
            if id % 100 == 0 {
                println!("[stream] Received heartbeat/id: {}", id);
            }
        }
        StreamMessage::Other(value) => {
            println!("[stream] Received other message format: {:?}", value);
        }
    }
}
/// Registers a subscription for `symbol` in the shared subscription map.
///
/// The new entry holds a single handler built from `subscriber_uid` and
/// `callback`, plus the given `resolution` and `last_daily_bar`. It is stored
/// with `insert`, so any entry already present under `symbol` is replaced.
///
/// If the map's lock cannot be acquired the call does nothing.
pub async fn subscribe_on_stream(
    symbol: String,
    resolution: String,
    subscriber_uid: String,
    last_daily_bar: Bar,
    subscriptions: Subscriptions,
    callback: mpsc::UnboundedSender<Bar>,
) {
    let mut map = match subscriptions.lock() {
        Ok(guard) => guard,
        Err(_) => return,
    };

    let entry = Subscription {
        // The handler id must be cloned before `subscriber_uid` is moved below.
        handlers: vec![crate::types::Handler {
            id: subscriber_uid.clone(),
            callback,
        }],
        subscriber_uid,
        resolution,
        last_daily_bar,
    };

    map.insert(symbol.clone(), entry);
    println!("[subscribeBars]: Subscribe to streaming. Channel: {}", symbol);
}
/// Removes the first channel that has a handler whose id equals
/// `subscriber_uid`.
///
/// The whole channel entry is removed, not just the matching handler, and at
/// most one channel is removed per call. If no handler matches, or the map's
/// lock cannot be acquired, nothing happens.
pub fn unsubscribe_from_stream(subscriber_uid: &str, subscriptions: &Subscriptions) {
    let mut map = match subscriptions.lock() {
        Ok(guard) => guard,
        Err(_) => return,
    };

    // Copy the key out so the shared borrow ends before the entry is removed.
    let owning_channel = map
        .iter()
        .find(|(_, sub)| sub.handlers.iter().any(|handler| handler.id == subscriber_uid))
        .map(|(channel, _)| channel.clone());

    if let Some(channel) = owning_channel {
        println!("[unsubscribeBars]: Unsubscribe from streaming. Channel: {}", channel);
        map.remove(&channel);
    }
}
/// Applies one streamed update to the subscription stored under `data.id`.
///
/// If `data.t` is earlier than the start of the bar following the stored
/// `last_daily_bar`, that bar is extended: `high`/`low` are widened with
/// `data.p` and `close` becomes `data.p`. Otherwise a new bar is opened at
/// the following bar's time with all four prices set to `data.p`.
///
/// The resulting bar is sent to every handler of the subscription (send
/// failures are logged, not propagated) and then stored as the new
/// `last_daily_bar`. Nothing happens when the lock cannot be acquired or no
/// subscription exists for `data.id`.
fn handle_streaming_data(data: StreamData, subscriptions: &Subscriptions) {
    let mut map = match subscriptions.lock() {
        Ok(guard) => guard,
        Err(_) => return,
    };
    let sub_item = match map.get_mut(&data.id) {
        Some(item) => item,
        None => return,
    };

    let price = data.p;
    let previous = &sub_item.last_daily_bar;
    let rollover_time = get_next_daily_bar_time(previous.time);

    let updated = if data.t < rollover_time {
        // Still inside the current bar: keep its open, stretch the range.
        Bar {
            time: previous.time,
            open: previous.open,
            high: previous.high.max(price),
            low: previous.low.min(price),
            close: price,
        }
    } else {
        // At or past the rollover: start a fresh bar at the rollover time.
        Bar {
            time: rollover_time,
            open: price,
            high: price,
            low: price,
            close: price,
        }
    };

    sub_item.handlers.iter().for_each(|handler| {
        if let Err(e) = handler.callback.send(updated.clone()) {
            eprintln!("[stream] Failed to send bar update: {:?}", e);
        }
    });
    sub_item.last_daily_bar = updated;
}
/// Returns the timestamp, in milliseconds since the Unix epoch, that lies
/// exactly 24 hours after `bar_time`.
///
/// `bar_time` is in milliseconds and is truncated to whole seconds before the
/// day is added, so any sub-second part is dropped from the result.
///
/// The computation never panics: if the result cannot be represented (either
/// by the platform's `SystemTime` or by `u64` milliseconds) `u64::MAX` is
/// returned instead. This matters because the caller runs while holding the
/// subscriptions mutex, and a panic there would poison the lock.
fn get_next_daily_bar_time(bar_time: u64) -> u64 {
    const ONE_DAY: Duration = Duration::from_secs(24 * 60 * 60);

    let bar_start = Duration::from_secs(bar_time / 1000);
    UNIX_EPOCH
        .checked_add(bar_start)
        .and_then(|start| start.checked_add(ONE_DAY))
        .and_then(|next| next.duration_since(UNIX_EPOCH).ok())
        .map_or(u64::MAX, |since_epoch| {
            since_epoch.as_secs().saturating_mul(1000)
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    /// Length of one day in milliseconds.
    const DAY_MS: u64 = 24 * 60 * 60 * 1000;

    /// The next daily bar starts exactly one day after a whole-second timestamp.
    #[test]
    fn test_get_next_daily_bar_time() {
        let bar_start_ms = 1_717_000_000_000;
        let following_ms = get_next_daily_bar_time(bar_start_ms);
        assert_eq!(following_ms - bar_start_ms, DAY_MS);
    }

    /// Feeding a bare `{}` line with no subscriptions registered must not panic.
    #[test]
    fn test_process_stream_line_empty() {
        let registry = Arc::new(Mutex::new(HashMap::new()));
        process_stream_line("{}", &registry);
    }
}