use std::collections::hash_map;
use std::collections::HashMap;
use super::super::protocol::*;
use super::super::utils::*;
use futures_util::pin_mut;
use futures_util::stream::SplitStream;
use tokio::sync::mpsc;
use tokio_tungstenite::{
connect_async, tungstenite::client::IntoClientRequest, tungstenite::Message, MaybeTlsStream,
WebSocketStream,
};
use futures_util::{stream::SplitSink, SinkExt, StreamExt};
use tokio::net::TcpStream;
/// WebSocket endpoint URL that `Session::connect` turns into its client request.
const CONNECTION: &str = "wss://data.tradingview.com/socket.io/websocket";
/// Selects which list of quote field names `get_quote_fields` produces.
// In the code visible in this file `All` is only constructed by the unit
// tests, which is presumably why the dead-code lint is silenced here.
#[allow(dead_code)]
#[derive(Debug, PartialEq)]
enum FieldTypes {
    /// Every entry of the `FIELDS` table.
    All,
    /// The short, price-oriented list of five fields.
    Price,
}
/// Plain data record bundling a symbol name with two numeric readings.
// NOTE(review): nothing in this file constructs `SymbolData`; the exact
// meaning and scale of the numeric fields should be confirmed with its users.
#[derive(Debug, Clone)]
pub struct SymbolData {
    /// Name of the symbol the readings belong to.
    pub symbol: String,
    /// Price reading for the symbol.
    pub price: f64,
    /// Technical-analysis reading for the symbol.
    pub technical_analysis: f64,
}
/// Complete table of quote field names.
///
/// `get_quote_fields(FieldTypes::All)` returns these names, in this order.
const FIELDS: [&str; 48] = [
    "base-currency-logoid",
    "ch",
    "chp",
    "currency-logoid",
    "currency_code",
    "current_session",
    "description",
    "exchange",
    "format",
    "fractional",
    "is_tradable",
    "language",
    "local_description",
    "logoid",
    "lp",
    "lp_time",
    "minmov",
    "minmove2",
    "original_name",
    "pricescale",
    "pro_name",
    "short_name",
    "type",
    "update_mode",
    "volume",
    "ask",
    "bid",
    "fundamentals",
    "high_price",
    "low_price",
    "open_price",
    "prev_close_price",
    "rch",
    "rchp",
    "rtc",
    "rtc_time",
    "status",
    "industry",
    "basic_eps_net_income",
    "beta_1_year",
    "market_cap_basic",
    "earnings_per_share_basic_ttm",
    "price_earnings_ttm",
    "sector",
    "dividends_yield",
    "timezone",
    "country_code",
    "provider_id",
];
/// State of one quote session: its id, the outgoing message channel, the
/// incoming WebSocket half, cached per-symbol values and message processors.
pub struct Session {
    /// Identifier sent as the first parameter of the quote packets.
    session_id: String,
    /// Sending half of the outgoing channel; serialized packets are queued here.
    pub tx_to_send: mpsc::Sender<String>,
    /// Per-symbol pair of values: `.0` is written by `set_data_price`,
    /// `.1` by `set_data_ta`.
    data: HashMap<String, (f64, f64)>,
    /// Receiving half of the outgoing channel; `connect` takes it and hands it
    /// to the spawned `send_message` task, leaving `None` behind.
    rx_to_send: Option<mpsc::Receiver<String>>,
    /// Read half of the WebSocket; set by `connect`, taken by `process_stream`.
    pub read: Option<SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>,
    /// Callbacks run by `process_stream` for every parsed incoming packet.
    processors: Vec<MessageProcessor>,
}
impl Session {
    /// Serializes a packet with method `method` and parameters `params` and
    /// queues it on the outgoing channel.
    ///
    /// # Panics
    /// Panics if the receiving half of the outgoing channel has been dropped.
    async fn send_packet(&self, method: &str, params: Vec<String>) {
        self.tx_to_send
            .send(format_ws_packet(WSPacket {
                m: method.to_owned(),
                p: params,
            }))
            .await
            .expect("outgoing message channel is closed");
    }

    /// Queues the `quote_create_session` packet followed by a
    /// `quote_set_fields` packet requesting the `FieldTypes::Price` fields.
    ///
    /// # Panics
    /// Panics if the outgoing channel is closed.
    pub async fn start(&self) {
        self.send_packet("quote_create_session", vec![self.session_id.clone()])
            .await;

        // The session id always precedes the requested field names.
        let mut field_params = vec![self.session_id.clone()];
        field_params.extend(get_quote_fields(FieldTypes::Price));
        self.send_packet("quote_set_fields", field_params).await;
    }

    /// Opens the WebSocket connection, stores its read half in `self.read`,
    /// spawns the `send_message` task on the write half and queues the
    /// `set_auth_token` packet.
    ///
    /// # Panics
    /// Panics if the request cannot be built, the connection fails, the
    /// outgoing receiver was already taken (i.e. `connect` was called before),
    /// or the outgoing channel is closed.
    pub async fn connect(&mut self) {
        let mut request = CONNECTION
            .into_client_request()
            .expect("CONNECTION must be a valid WebSocket URL");
        request.headers_mut().append(
            http::header::ORIGIN,
            "https://s.tradingview.com"
                .parse()
                .expect("origin literal must be a valid header value"),
        );
        let (ws_stream, _) = connect_async(request).await.expect("Failed to connect");
        let (write, read) = ws_stream.split();
        self.read = Some(read);
        let rx_to_send = self.rx_to_send.take().expect("rx_to_send is None");
        tokio::spawn(send_message(rx_to_send, write));
        self.send_packet(
            "set_auth_token",
            vec!["unauthorized_user_token".to_owned()],
        )
        .await;
    }

    /// Queues a `quote_add_symbols` packet for `to_add`, unless a value is
    /// already cached for that symbol.
    ///
    /// # Panics
    /// Panics if the outgoing channel is closed.
    pub async fn add_symbol(&self, to_add: &str) {
        // Direct hash lookup instead of scanning every key.
        if self.data.contains_key(to_add) {
            return;
        }
        self.send_packet(
            "quote_add_symbols",
            vec![self.session_id.clone(), to_add.to_owned()],
        )
        .await;
    }

    /// Returns the cached value pair for `symbol`, or `(0.0, 0.0)` when the
    /// symbol has no entry.
    pub fn get_data(&self, symbol: &str) -> (f64, f64) {
        self.data.get(symbol).copied().unwrap_or((0.0, 0.0))
    }

    /// Stores `data` as the first element of the pair cached for `symbol`,
    /// creating the entry (second element `0.0`) when it does not exist.
    pub fn set_data_price(&mut self, symbol: &str, data: f64) {
        self.data.entry(symbol.to_owned()).or_insert((0.0, 0.0)).0 = data;
    }

    /// Stores `data` as the second element of the pair cached for `symbol`,
    /// creating the entry (first element `0.0`) when it does not exist.
    pub fn set_data_ta(&mut self, symbol: &str, data: f64) {
        self.data.entry(symbol.to_owned()).or_insert((0.0, 0.0)).1 = data;
    }

    /// Returns an owning iterator over the cached symbol names.
    ///
    /// The map is cloned because the return type owns its keys.
    pub fn keys(&self) -> hash_map::IntoKeys<std::string::String, (f64, f64)> {
        self.data.clone().into_keys()
    }

    /// Consumes the read half of the WebSocket and, for every incoming
    /// message, runs each registered processor on each parsed packet.
    ///
    /// Processors are executed one at a time through
    /// `tokio::task::spawn_blocking`, each receiving its own clone of the
    /// outgoing sender.
    ///
    /// # Panics
    /// Panics if `connect` has not stored a read half, if a message is an
    /// error or cannot be converted to text, or if a processor panics.
    pub async fn process_stream(&mut self) {
        let read = self
            .read
            .take()
            .expect("read is None: call connect() before process_stream()");
        let ws_to_stream = read.for_each(
            |message: Result<Message, tokio_tungstenite::tungstenite::Error>| {
                // Owned copies so the returned future does not borrow `self`.
                let tx_to_send = self.tx_to_send.clone();
                let processors = self.processors.clone();
                async move {
                    let data = message
                        .expect("Message is an invalid format")
                        .into_text()
                        .expect("Could not turn into text");
                    let parsed_data = parse_ws_packet(&data);
                    println!("\x1b[91m🠳\x1b[0m {:#?}", data);
                    for d in parsed_data {
                        for processor in processors.iter().copied() {
                            let d = d.clone();
                            let tx_to_send = tx_to_send.clone();
                            // The sender is moved into the closure; no second
                            // clone is needed at the call site.
                            tokio::task::spawn_blocking(move || processor(&d, tx_to_send))
                                .await
                                .expect("Task panicked");
                        }
                    }
                }
            },
        );
        pin_mut!(ws_to_stream);
        ws_to_stream.await;
    }

    /// Registers an additional processor to run on every parsed packet.
    pub fn add_processor(&mut self, processor: MessageProcessor) {
        self.processors.push(processor);
    }
}
/// Forwards every string received on `rx` to the WebSocket sink `interface`.
///
/// The task ends when the channel is closed (all senders dropped and the
/// queue drained) or when writing to the socket fails. Previously a closed
/// channel made the loop spin forever, because `recv()` keeps returning
/// `None` immediately once the channel is closed.
async fn send_message(
    mut rx: mpsc::Receiver<String>,
    mut interface: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
) {
    while let Some(data) = rx.recv().await {
        println!("\x1b[92m🠱\x1b[0m {:#?}", &data);
        let message = Message::from(data);
        if let Err(err) = interface.send(message).await {
            // The socket is unusable; stop the task. Dropping `rx` closes the
            // channel, so senders observe the failure on their next send.
            eprintln!("Failed to send WebSocket message: {}", err);
            break;
        }
    }
}
/// Builds a new `Session` and queues its start-up packets.
///
/// The session gets a freshly generated id, an outgoing channel with a
/// capacity of 20 messages and `process_heartbeat` as its only processor.
/// No connection is opened here: `read` starts as `None` and the channel's
/// receiving half is kept in the session.
///
/// # Panics
/// Panics if `Session::start` panics.
pub async fn constructor() -> Session {
    let session_id = generate_session_id(None);
    let (sender, receiver) = mpsc::channel::<String>(20);

    let session = Session {
        session_id,
        tx_to_send: sender,
        data: HashMap::new(),
        rx_to_send: Some(receiver),
        read: None,
        processors: vec![process_heartbeat],
    };

    session.start().await;
    session
}
/// Returns the quote field names selected by `field` as owned strings.
///
/// * `FieldTypes::All` yields every entry of `FIELDS`, in table order.
/// * `FieldTypes::Price` yields the five price-related names.
fn get_quote_fields(field: FieldTypes) -> Vec<String> {
    // Pick a borrowed slice first so each name is allocated exactly once,
    // instead of building an intermediate array of `String`s and cloning it.
    let names: &[&str] = match field {
        FieldTypes::All => &FIELDS,
        FieldTypes::Price => &[
            "lp",
            "high_price",
            "low_price",
            "price_52_week_high",
            "price_52_week_low",
        ],
    };
    names.iter().map(|&name| name.to_owned()).collect()
}
/// Callback run by `Session::process_stream` for each parsed incoming packet.
///
/// It receives the packet text and a sender for queueing outgoing messages.
/// `process_stream` invokes it inside `tokio::task::spawn_blocking`.
pub type MessageProcessor = fn(&str, mpsc::Sender<String>);
/// Answers a heartbeat packet by queueing the matching ping.
///
/// A packet is treated as a heartbeat when it contains `~h~`; the marker is
/// removed and the remainder is parsed into the value handed to
/// `format_ws_ping`. Packets without the marker are ignored. A packet whose
/// remainder cannot be parsed is logged and skipped rather than panicking,
/// since the text comes from the network.
///
/// Uses `blocking_send`, so it must be called from a blocking context.
///
/// # Panics
/// Panics if the outgoing channel is closed.
pub fn process_heartbeat(message: &str, tx_to_send: mpsc::Sender<String>) {
    if !message.contains("~h~") {
        return;
    }
    match message.replace("~h~", "").parse() {
        Ok(heartbeat_id) => {
            let ping = format_ws_ping(heartbeat_id);
            tx_to_send
                .blocking_send(ping)
                .expect("outgoing message channel is closed");
        }
        Err(_) => eprintln!("Ignoring malformed heartbeat packet: {:?}", message),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Price` must yield the five price fields in order; `All` must mirror
    /// the `FIELDS` table.
    #[test]
    fn test_get_quote_fields() {
        let expected_price = vec![
            "lp",
            "high_price",
            "low_price",
            "price_52_week_high",
            "price_52_week_low",
        ];
        let actual_price = get_quote_fields(FieldTypes::Price);
        assert_eq!(
            actual_price, expected_price,
            "The quote fields should include only 5 fields"
        );

        let expected_all = FIELDS.to_vec();
        let actual_all = get_quote_fields(FieldTypes::All);
        assert_eq!(
            actual_all, expected_all,
            "The quote fields should include all the fields"
        );
    }

    /// The derived `PartialEq` must distinguish the two variants.
    #[test]
    fn test_field_types() {
        let all = FieldTypes::All;
        let price = FieldTypes::Price;

        assert_eq!(all, all, "The `All` variant should be equal to itself");
        assert_ne!(
            all, price,
            "The `All` variant should not be equal to the `Price` variant"
        );
        assert_eq!(
            price, price,
            "The `Price` variant should be equal to itself"
        );
        assert_ne!(
            price, all,
            "The `Price` variant should not be equal to the `All` variant"
        );
    }
}