betfair_stream_api/
tls_sream.rs

1use alloc::sync::Arc;
2use core::net::SocketAddr;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5
6use betfair_adapter::BetfairUrl;
7use betfair_stream_types::request::RequestMessage;
8use betfair_stream_types::response::ResponseMessage;
9use futures::Sink;
10use futures_util::sink::SinkExt;
11use futures_util::{Stream, StreamExt};
12use tokio::io::{AsyncRead, AsyncWrite};
13use tokio::net::TcpStream;
14use tokio_util::bytes;
15use tokio_util::codec::{Decoder, Encoder, Framed};
16
17use crate::StreamError;
18
19pub(crate) async fn connect(
20    url: BetfairUrl<betfair_adapter::Stream>,
21) -> Result<RawStreamApiConnection, StreamError> {
22    let url = url.url();
23    tracing::debug!(?url, "connecting to stream");
24
25    let host = url.host_str().ok_or(StreamError::HostStringNotPresent)?;
26    let port = url.port().unwrap_or(443);
27    let socket_addr = tokio::net::lookup_host((host, port))
28        .await
29        .map_err(|err| {
30            tracing::error!(?err, "unable to look up host");
31            StreamError::UnableToLookUpHost {
32                host: host.to_owned(),
33                port,
34            }
35        })?
36        .next();
37    let domain = url.domain();
38    match (domain, socket_addr) {
39        (Some(domain), Some(socket_addr)) => {
40            let connection = connect_tls(domain, socket_addr).await?;
41            tracing::debug!("connecting to Stream API");
42            Ok(connection)
43        }
44        #[cfg(feature = "integration-test")]
45        (None, Some(socket_addr)) => {
46            let connection = connect_tls("localhost", socket_addr).await?;
47            tracing::debug!("connecting to Stream API");
48            Ok(connection)
49        }
50        params => {
51            tracing::error!(?params, "unable to connect to Stream API");
52
53            Err(StreamError::MisconfiguredStreamURL)
54        }
55    }
56}
57
58#[tracing::instrument(err)]
59async fn connect_tls(
60    domain: &str,
61    socket_addr: SocketAddr,
62) -> Result<RawStreamApiConnection, StreamError> {
63    let domain = rustls::pki_types::ServerName::try_from(domain.to_owned()).map_err(|err| {
64        tracing::error!(?err, "unable to convert domain to server name");
65        StreamError::UnableConvertDomainToServerName
66    })?;
67    let stream = TcpStream::connect(&socket_addr).await?;
68    let connector = tls_connector()?;
69    let stream = connector.connect(domain, stream).await.map_err(|err| {
70        tracing::error!(?err, "unable to connect to TLS stream");
71        StreamError::UnableConnectToTlsStream
72    })?;
73    let framed = Framed::new(stream, StreamAPIClientCodec);
74    Ok(internal::RawStreamApiConnection { io: framed })
75}
76
77pub(crate) type RawStreamApiConnection =
78    internal::RawStreamApiConnection<tokio_rustls::client::TlsStream<TcpStream>>;
79
80mod internal {
81    use super::*;
82
83    pub(crate) struct RawStreamApiConnection<
84        IO: AsyncRead + AsyncWrite + core::fmt::Debug + Send + Unpin,
85    > {
86        pub(super) io: Framed<IO, StreamAPIClientCodec>,
87    }
88
89    impl<IO: AsyncRead + AsyncWrite + core::fmt::Debug + Send + Unpin> Stream
90        for RawStreamApiConnection<IO>
91    {
92        type Item = Result<ResponseMessage, CodecError>;
93
94        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95            self.io.poll_next_unpin(cx)
96        }
97    }
98
99    impl<IO: AsyncRead + AsyncWrite + core::fmt::Debug + Send + Unpin> Sink<RequestMessage>
100        for RawStreamApiConnection<IO>
101    {
102        type Error = CodecError;
103
104        fn poll_ready(
105            mut self: Pin<&mut Self>,
106            cx: &mut Context<'_>,
107        ) -> Poll<Result<(), Self::Error>> {
108            self.io.poll_ready_unpin(cx)
109        }
110
111        fn start_send(mut self: Pin<&mut Self>, item: RequestMessage) -> Result<(), Self::Error> {
112            self.io.start_send_unpin(item)
113        }
114
115        fn poll_flush(
116            mut self: Pin<&mut Self>,
117            cx: &mut Context<'_>,
118        ) -> Poll<Result<(), Self::Error>> {
119            self.io.poll_flush_unpin(cx)
120        }
121
122        fn poll_close(
123            mut self: Pin<&mut Self>,
124            cx: &mut Context<'_>,
125        ) -> Poll<Result<(), Self::Error>> {
126            self.io.poll_close_unpin(cx)
127        }
128    }
129}
130
/// Defines the encoding and decoding of Betfair stream api data structures using tokio.
///
/// The codec is a stateless unit struct: each frame on the wire is one JSON
/// document terminated by a line break.
#[derive(Debug, Clone, Copy, Default)]
pub struct StreamAPIClientCodec;
133
/// Errors that can arise while decoding or encoding betfair stream api data.
///
/// Both variants are created through their `From` conversions, so `?` can be
/// applied directly to `serde_json` and I/O results inside the codec.
#[derive(Debug, thiserror::Error)]
pub enum CodecError {
    /// A frame could not be parsed as JSON, or a request could not be
    /// serialized to JSON.
    #[error("Serde error: {0}")]
    Serde(#[from] serde_json::Error),
    /// An I/O error (presumably raised by the underlying transport while
    /// reading or writing frames; confirm against the `Framed` usage).
    #[error("IO Error {0}")]
    IoError(#[from] std::io::Error),
}
144
145impl Decoder for StreamAPIClientCodec {
146    type Item = ResponseMessage;
147    type Error = CodecError;
148
149    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
150        // Find position of `\n` first
151        if let Some(pos) = src.iter().position(|&byte| byte == b'\n') {
152            // Check if the preceding byte is `\r`
153            let delimiter_size = if pos > 0 && src[pos - 1] == b'\r' {
154                2
155            } else {
156                1
157            };
158
159            // Extract up to and including the delimiter
160            let line = src.split_to(pos + 1);
161
162            // Separate out the delimiter bytes
163            let (json_part, _) = line.split_at(line.len().saturating_sub(delimiter_size));
164
165            // Now we can parse it as JSON
166            let data = serde_json::from_slice::<Self::Item>(json_part)?;
167            return Ok(Some(data));
168        }
169        Ok(None)
170    }
171}
172
173impl Encoder<RequestMessage> for StreamAPIClientCodec {
174    type Error = CodecError;
175
176    fn encode(
177        &mut self,
178        item: RequestMessage,
179        dst: &mut bytes::BytesMut,
180    ) -> Result<(), Self::Error> {
181        // Serialize the item to a JSON string
182        let json = serde_json::to_string(&item)?;
183        // Write the JSON string to the buffer, followed by a newline
184        dst.extend_from_slice(json.as_bytes());
185        dst.extend_from_slice(b"\r\n");
186        Ok(())
187    }
188}
189
190#[tracing::instrument(err)]
191fn tls_connector() -> Result<tokio_rustls::TlsConnector, StreamError> {
192    use tokio_rustls::TlsConnector;
193
194    let mut roots = rustls::RootCertStore::empty();
195    let native_certs = rustls_native_certs::load_native_certs();
196    for cert in native_certs.certs {
197        roots.add(cert).map_err(|err| {
198            tracing::error!(?err, "Cannot set native certificate");
199            StreamError::CannotSetNativeCertificate
200        })?;
201    }
202
203    #[cfg(feature = "integration-test")]
204    {
205        use crate::CERTIFICATE;
206
207        if let Some(cert) = CERTIFICATE.clone().take() {
208            let mut cert = cert.as_bytes();
209            let cert = rustls_pemfile::certs(&mut cert)
210                .next()
211                .ok_or(StreamError::InvalidCustomCertificate)?
212                .map_err(|_| StreamError::InvalidCustomCertificate)?;
213            roots.add(cert).map_err(|err| {
214                tracing::error!(?err, "Cannot set native certificate");
215                StreamError::CustomCertificateNotSet
216            })?;
217        }
218    };
219
220    let config = rustls::ClientConfig::builder()
221        .with_root_certificates(roots)
222        .with_no_client_auth();
223    Ok(TlsConnector::from(Arc::new(config)))
224}
225
#[cfg(test)]
mod tests {

    use core::fmt::Write;

    use super::*;

    /// Frame terminator used in the fixtures below.
    const SEPARATOR: &str = "\r\n";
    /// A `connection` response.
    const CONNECTION: &str = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
    /// An order-change (`ocm`) heartbeat response.
    const ORDER_HEARTBEAT: &str = r#"{"op":"ocm","id":3,"clk":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;

    /// Creates a decode buffer pre-filled with the bytes of `text`.
    fn buffer_from(text: &str) -> bytes::BytesMut {
        bytes::BytesMut::from(text.as_bytes())
    }

    /// Resolves the production stream host and checks the first address.
    #[tokio::test]
    async fn can_resolve_host_ipv4() {
        let url = url::Url::parse("tcptls://stream-api.betfair.com:443").unwrap();
        let host = url.host_str().unwrap();
        let port = match url.port() {
            Some(explicit) => explicit,
            None if url.scheme() == "https" => 443,
            None => 80,
        };
        let mut resolved = tokio::net::lookup_host((host, port)).await.unwrap();
        let first = resolved.next().unwrap();
        assert!(first.ip().is_ipv4());
        assert_eq!(first.port(), 443);
    }

    /// One complete frame decodes into one message.
    #[test]
    fn can_decode_single_message() {
        let mut codec = StreamAPIClientCodec;
        let mut buf = buffer_from(&format!("{CONNECTION}{SEPARATOR}"));

        let decoded = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(decoded, ResponseMessage::Connection(_)));
    }

    /// Two complete frames in one buffer decode one after the other.
    #[test]
    fn can_decode_multiple_messages() {
        let mut codec = StreamAPIClientCodec;
        let mut buf = buffer_from(&format!(
            "{CONNECTION}{SEPARATOR}{ORDER_HEARTBEAT}{SEPARATOR}"
        ));

        let first = codec.decode(&mut buf).unwrap().unwrap();
        let second = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(first, ResponseMessage::Connection(_)));
        assert!(matches!(second, ResponseMessage::OrderChange(_)));
    }

    /// A frame split across two reads is only produced once its terminator arrives.
    #[test]
    fn can_decode_multiple_partial_messages() {
        // The second message arrives in two pieces.
        let head = r#"{"op":"ocm","id":3,"clk""#;
        let tail = r#":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;

        let mut codec = StreamAPIClientCodec;
        let mut buf = buffer_from(&format!("{CONNECTION}{SEPARATOR}{head}"));

        let first = codec.decode(&mut buf).unwrap().unwrap();
        // Only the head of the second frame is buffered so far.
        assert!(codec.decode(&mut buf).unwrap().is_none());

        buf.write_str(tail).unwrap();
        buf.write_str(SEPARATOR).unwrap();
        let second = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(first, ResponseMessage::Connection(_)));
        assert!(matches!(second, ResponseMessage::OrderChange(_)));
    }

    /// A drained buffer yields `None` until a further full frame is appended.
    #[test]
    fn can_decode_subsequent_messages() {
        let mut codec = StreamAPIClientCodec;
        let mut buf = buffer_from(&format!("{CONNECTION}{SEPARATOR}"));

        let first = codec.decode(&mut buf).unwrap().unwrap();
        // Nothing is left after the first frame.
        assert!(codec.decode(&mut buf).unwrap().is_none());

        let next_frame = format!("{ORDER_HEARTBEAT}{SEPARATOR}");
        buf.write_str(next_frame.as_str()).unwrap();
        let second = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(first, ResponseMessage::Connection(_)));
        assert!(matches!(second, ResponseMessage::OrderChange(_)));
    }

    /// An encoded request is JSON starting with its `op` tag and ending in CRLF.
    #[test]
    fn can_encode_message() {
        let request = RequestMessage::Authentication(
            betfair_stream_types::request::authentication_message::AuthenticationMessage {
                id: Some(1),
                session: "sss".to_owned(),
                app_key: "aaaa".to_owned(),
            },
        );
        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::new();
        codec.encode(request, &mut buf).unwrap();

        let encoded = core::str::from_utf8(&buf).unwrap();

        // The frame must be terminated by \r\n
        assert!(encoded.ends_with("\r\n"));
        // and must start with {"op":"authentication"
        assert!(encoded.starts_with("{\"op\":\"authentication\""));
    }
}
343}