1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
pub mod client;
pub mod connection;
pub mod error;
pub mod model;

use crate::error::ClientError;
use crate::model::{Candle, Trade};
use async_trait::async_trait;
use serde::Serialize;
use tokio::net::TcpStream;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use tracing::debug;

// Todo: general:
//  - Increase test coverage significantly now that the PoC design is known to work
//  - Ensure all .unwrap()s have been exchanged for more robust handling
//  - Ensure proper error handling & swap unwrap()s for more robust handling
//     '-> ensure all methods return an appropriate Result which is handled by the caller

// Todo: connection.rs:
//  - Improve the method of confirming a subscription request so the test_subscribe unit test passes
//     '-> the subscription currently appears to succeed even when it didn't; need to confirm the
//         first message arrives?
//     '-> ensure logging is aligned once this has been done
//  - manage(): add connection recovery and reconnection handling

/// Useful type alias for a [WebSocketStream] connection.
///
/// The stream wraps a [MaybeTlsStream] over a tokio [TcpStream]. This is the stream type yielded
/// by [connect_async], and therefore the type returned by this module's `connect` helper.
pub type WSStream = WebSocketStream<MaybeTlsStream<TcpStream>>;

/// Client trait defining the behaviour of all implementing ExchangeClients. All methods return
/// a stream of normalised data.
///
/// The `async fn` methods are made possible in a trait by the `#[async_trait]` attribute.
#[async_trait]
pub trait ExchangeClient {
    /// Identifier for the exchange this client talks to (presumably a human-readable exchange
    /// name — confirm against implementors).
    const EXCHANGE_NAME: &'static str;

    /// Starts consuming trades for `symbol` and returns them as an unbounded stream of
    /// normalised [Trade]s.
    ///
    /// NOTE(review): the expected `symbol` format is not fixed here and is presumably
    /// exchange-specific; the failure conditions for the returned [ClientError] are defined by
    /// each implementor.
    async fn consume_trades(
        &mut self,
        symbol: String,
    ) -> Result<UnboundedReceiverStream<Trade>, ClientError>;
    /// Starts consuming candles for `symbol` at the given `interval` and returns them as an
    /// unbounded stream of normalised [Candle]s.
    ///
    /// NOTE(review): the accepted `interval` strings are not defined here and are presumably
    /// exchange-specific; confirm against implementors.
    async fn consume_candles(
        &mut self,
        symbol: String,
        interval: &str,
    ) -> Result<UnboundedReceiverStream<Candle>, ClientError>;
}

/// Utilised to subscribe to an exchange's [WebSocketStream] via a ConnectionHandler (eg/ Trade stream).
pub trait Subscription {
    /// Constructs a new [Subscription] implementation.
    fn new(stream_name: String, ticker_pair: String) -> Self;
    /// Serializes the [Subscription] in a String data format.
    fn as_text(&self) -> Result<String, ClientError>
    where
        Self: Serialize,
    {
        Ok(serde_json::to_string(self)?)
    }
}

/// Returns a stream identifier that can be used to route messages from a [Subscription].
///
/// NOTE(review): presumably implemented by incoming exchange message types so they can be
/// matched to the stream they belong to — confirm against implementors.
pub trait StreamIdentifier {
    /// Returns [Identifier::Yes] carrying the stream Id when `self` has one, or
    /// [Identifier::No] otherwise.
    fn get_stream_id(&self) -> Identifier;
}

/// Enum returned from [StreamIdentifier] representing whether a value has an identifiable
/// stream Id.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so identifiers can be logged, copied and
/// compared when routing messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Identifier {
    /// The value belongs to the stream with this Id.
    Yes(String),
    /// The value has no identifiable stream Id.
    No,
}

/// Connect asynchronously to an exchange's WebSocket server at `base_uri`, returning a
/// [WebSocketStream].
///
/// The handshake response returned alongside the stream by [connect_async] is discarded.
///
/// # Errors
/// Returns [ClientError::WebSocketConnect] wrapping the underlying error whenever
/// [connect_async] fails (e.g. for a malformed `base_uri`).
async fn connect(base_uri: &str) -> Result<WSStream, ClientError> {
    debug!("Establishing WebSocket connection to: {:?}", base_uri);
    connect_async(base_uri)
        .await
        // Only the stream is needed; drop the handshake response.
        .map(|(ws_stream, _response)| ws_stream)
        .map_err(ClientError::WebSocketConnect)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A malformed base URI must fail to connect. It is expected to be rejected before any
    /// live server is contacted, so this test does not rely on network access.
    #[tokio::test]
    async fn test_connect_invalid_uri() {
        let invalid_base_uri = String::from("not a valid base uri");
        let actual_result = connect(&invalid_base_uri).await;
        assert!(
            actual_result.is_err(),
            "expected connecting to an invalid base URI to fail"
        );
    }

    /// Connects to real exchange endpoints. Ignored by default because it needs outbound
    /// network access and depends on third-party availability; run it explicitly with
    /// `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires network access to live exchange WebSocket endpoints"]
    async fn test_connect_live_endpoints() {
        let base_uris = [
            // Test case 0: Valid Binance WS base URI
            "wss://stream.binance.com:9443/ws".to_string(),
            // Test case 1: Valid Bitstamp WS base URI
            "wss://ws.bitstamp.net/".to_string(),
        ];

        for (index, base_uri) in base_uris.iter().enumerate() {
            let actual_result = connect(base_uri).await;
            assert!(actual_result.is_ok(), "Test case: {:?}", index);
        }
    }
}