1use alloc::sync::Arc;
2use core::net::SocketAddr;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5
6use betfair_adapter::BetfairUrl;
7use betfair_stream_types::request::RequestMessage;
8use betfair_stream_types::response::ResponseMessage;
9use futures::Sink;
10use futures_util::sink::SinkExt;
11use futures_util::{Stream, StreamExt};
12use tokio::io::{AsyncRead, AsyncWrite};
13use tokio::net::TcpStream;
14use tokio_util::bytes;
15use tokio_util::codec::{Decoder, Encoder, Framed};
16
17use crate::StreamError;
18
19pub(crate) async fn connect(
20 url: BetfairUrl<betfair_adapter::Stream>,
21) -> Result<RawStreamApiConnection, StreamError> {
22 let url = url.url();
23 tracing::debug!(?url, "connecting to stream");
24
25 let host = url.host_str().ok_or(StreamError::HostStringNotPresent)?;
26 let port = url.port().unwrap_or(443);
27 let socket_addr = tokio::net::lookup_host((host, port))
28 .await
29 .map_err(|err| {
30 tracing::error!(?err, "unable to look up host");
31 StreamError::UnableToLookUpHost {
32 host: host.to_owned(),
33 port,
34 }
35 })?
36 .next();
37 let domain = url.domain();
38 match (domain, socket_addr) {
39 (Some(domain), Some(socket_addr)) => {
40 let connection = connect_tls(domain, socket_addr).await?;
41 tracing::debug!("connecting to Stream API");
42 Ok(connection)
43 }
44 #[cfg(feature = "integration-test")]
45 (None, Some(socket_addr)) => {
46 let connection = connect_tls("localhost", socket_addr).await?;
47 tracing::debug!("connecting to Stream API");
48 Ok(connection)
49 }
50 params => {
51 tracing::error!(?params, "unable to connect to Stream API");
52
53 Err(StreamError::MisconfiguredStreamURL)
54 }
55 }
56}
57
58#[tracing::instrument(err)]
59async fn connect_tls(
60 domain: &str,
61 socket_addr: SocketAddr,
62) -> Result<RawStreamApiConnection, StreamError> {
63 let domain = rustls::pki_types::ServerName::try_from(domain.to_owned()).map_err(|err| {
64 tracing::error!(?err, "unable to convert domain to server name");
65 StreamError::UnableConvertDomainToServerName
66 })?;
67 let stream = TcpStream::connect(&socket_addr).await?;
68 let connector = tls_connector()?;
69 let stream = connector.connect(domain, stream).await.map_err(|err| {
70 tracing::error!(?err, "unable to connect to TLS stream");
71 StreamError::UnableConnectToTlsStream
72 })?;
73 let framed = Framed::new(stream, StreamAPIClientCodec);
74 Ok(internal::RawStreamApiConnection { io: framed })
75}
76
/// Stream API connection running over a rustls TLS stream on top of a Tokio [`TcpStream`].
///
/// This is the concrete connection type returned by [`connect`].
pub(crate) type RawStreamApiConnection =
    internal::RawStreamApiConnection<tokio_rustls::client::TlsStream<TcpStream>>;
79
80mod internal {
81 use super::*;
82
83 pub(crate) struct RawStreamApiConnection<
84 IO: AsyncRead + AsyncWrite + core::fmt::Debug + Send + Unpin,
85 > {
86 pub(super) io: Framed<IO, StreamAPIClientCodec>,
87 }
88
89 impl<IO: AsyncRead + AsyncWrite + core::fmt::Debug + Send + Unpin> Stream
90 for RawStreamApiConnection<IO>
91 {
92 type Item = Result<ResponseMessage, CodecError>;
93
94 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95 self.io.poll_next_unpin(cx)
96 }
97 }
98
99 impl<IO: AsyncRead + AsyncWrite + core::fmt::Debug + Send + Unpin> Sink<RequestMessage>
100 for RawStreamApiConnection<IO>
101 {
102 type Error = CodecError;
103
104 fn poll_ready(
105 mut self: Pin<&mut Self>,
106 cx: &mut Context<'_>,
107 ) -> Poll<Result<(), Self::Error>> {
108 self.io.poll_ready_unpin(cx)
109 }
110
111 fn start_send(mut self: Pin<&mut Self>, item: RequestMessage) -> Result<(), Self::Error> {
112 self.io.start_send_unpin(item)
113 }
114
115 fn poll_flush(
116 mut self: Pin<&mut Self>,
117 cx: &mut Context<'_>,
118 ) -> Poll<Result<(), Self::Error>> {
119 self.io.poll_flush_unpin(cx)
120 }
121
122 fn poll_close(
123 mut self: Pin<&mut Self>,
124 cx: &mut Context<'_>,
125 ) -> Poll<Result<(), Self::Error>> {
126 self.io.poll_close_unpin(cx)
127 }
128 }
129}
130
/// Codec for the Betfair Stream API's line-delimited JSON protocol.
///
/// Decoding splits the input on `\n` (stripping an optional preceding `\r`) and parses each line
/// as a `ResponseMessage`. Encoding writes a `RequestMessage` as JSON followed by `\r\n`.
///
/// The codec is stateless, so it is `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct StreamAPIClientCodec;
133
/// Errors produced by [`StreamAPIClientCodec`] while encoding requests or decoding responses.
#[derive(Debug, thiserror::Error)]
pub enum CodecError {
    /// A message could not be serialised to, or deserialised from, JSON.
    #[error("Serde error: {0}")]
    Serde(#[from] serde_json::Error),
    /// An I/O error. `tokio_util`'s `Decoder`/`Encoder` traits require the codec error type to
    /// be constructible from `std::io::Error`, which this variant provides.
    #[error("IO Error {0}")]
    IoError(#[from] std::io::Error),
}
144
145impl Decoder for StreamAPIClientCodec {
146 type Item = ResponseMessage;
147 type Error = CodecError;
148
149 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
150 if let Some(pos) = src.iter().position(|&byte| byte == b'\n') {
152 let delimiter_size = if pos > 0 && src[pos - 1] == b'\r' {
154 2
155 } else {
156 1
157 };
158
159 let line = src.split_to(pos + 1);
161
162 let (json_part, _) = line.split_at(line.len().saturating_sub(delimiter_size));
164
165 let data = serde_json::from_slice::<Self::Item>(json_part)?;
167 return Ok(Some(data));
168 }
169 Ok(None)
170 }
171}
172
173impl Encoder<RequestMessage> for StreamAPIClientCodec {
174 type Error = CodecError;
175
176 fn encode(
177 &mut self,
178 item: RequestMessage,
179 dst: &mut bytes::BytesMut,
180 ) -> Result<(), Self::Error> {
181 let json = serde_json::to_string(&item)?;
183 dst.extend_from_slice(json.as_bytes());
185 dst.extend_from_slice(b"\r\n");
186 Ok(())
187 }
188}
189
190#[tracing::instrument(err)]
191fn tls_connector() -> Result<tokio_rustls::TlsConnector, StreamError> {
192 use tokio_rustls::TlsConnector;
193
194 let mut roots = rustls::RootCertStore::empty();
195 let native_certs = rustls_native_certs::load_native_certs();
196 for cert in native_certs.certs {
197 roots.add(cert).map_err(|err| {
198 tracing::error!(?err, "Cannot set native certificate");
199 StreamError::CannotSetNativeCertificate
200 })?;
201 }
202
203 #[cfg(feature = "integration-test")]
204 {
205 use crate::CERTIFICATE;
206
207 if let Some(cert) = CERTIFICATE.clone().take() {
208 let mut cert = cert.as_bytes();
209 let cert = rustls_pemfile::certs(&mut cert)
210 .next()
211 .ok_or(StreamError::InvalidCustomCertificate)?
212 .map_err(|_| StreamError::InvalidCustomCertificate)?;
213 roots.add(cert).map_err(|err| {
214 tracing::error!(?err, "Cannot set native certificate");
215 StreamError::CustomCertificateNotSet
216 })?;
217 }
218 };
219
220 let config = rustls::ClientConfig::builder()
221 .with_root_certificates(roots)
222 .with_no_client_auth();
223 Ok(TlsConnector::from(Arc::new(config)))
224}
225
// Tests for host resolution and for the `StreamAPIClientCodec` line framing.
#[cfg(test)]
mod tests {

    // Provides `write_str` on `BytesMut`, used to append more input to a decode buffer.
    use core::fmt::Write;

    use super::*;

    // Resolves the production Stream API host and checks the first address returned.
    // NOTE(review): this needs live DNS/network access, and it assumes the resolver lists an IPv4
    // address first. It may fail in sandboxed CI or on IPv6-preferring hosts.
    #[tokio::test]
    async fn can_resolve_host_ipv4() {
        let url = url::Url::parse("tcptls://stream-api.betfair.com:443").unwrap();
        let host = url.host_str().unwrap();
        let port = url
            .port()
            .unwrap_or_else(|| if url.scheme() == "https" { 443 } else { 80 });
        let socket_addr = tokio::net::lookup_host((host, port))
            .await
            .unwrap()
            .next()
            .unwrap();
        assert!(socket_addr.ip().is_ipv4());
        assert_eq!(socket_addr.port(), 443);
    }

    // A single CRLF-terminated line decodes to exactly one message.
    #[test]
    fn can_decode_single_message() {
        let msg = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let separator = "\r\n";
        let data = format!("{msg}{separator}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg, ResponseMessage::Connection(_)));
    }

    // Two complete lines in one buffer decode in order across two calls.
    #[test]
    fn can_decode_multiple_messages() {
        let msg_one = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let msg_two = r#"{"op":"ocm","id":3,"clk":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
        let separator = "\r\n";
        let data = format!("{msg_one}{separator}{msg_two}{separator}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg_one = codec.decode(&mut buf).unwrap().unwrap();
        let msg_two = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg_one, ResponseMessage::Connection(_)));
        assert!(matches!(msg_two, ResponseMessage::OrderChange(_)));
    }

    // A trailing partial line yields `None` and is decoded once the rest of it arrives.
    #[test]
    fn can_decode_multiple_partial_messages() {
        let msg_one = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let msg_two_pt_one = r#"{"op":"ocm","id":3,"clk""#;
        let msg_two_pt_two = r#":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
        let separator = "\r\n";
        let data = format!("{msg_one}{separator}{msg_two_pt_one}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg_one = codec.decode(&mut buf).unwrap().unwrap();
        let msg_two_attempt = codec.decode(&mut buf).unwrap();
        assert!(msg_two_attempt.is_none());
        buf.write_str(msg_two_pt_two).unwrap();
        buf.write_str(separator).unwrap();
        let msg_two = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg_one, ResponseMessage::Connection(_)));
        assert!(matches!(msg_two, ResponseMessage::OrderChange(_)));
    }

    // An emptied buffer yields `None`; a later complete line is then decoded normally.
    #[test]
    fn can_decode_subsequent_messages() {
        let msg_one = r#"{"op":"connection","connectionId":"002-051134157842-432409"}"#;
        let msg_two = r#"{"op":"ocm","id":3,"clk":"AAAAAAAA","status":503,"pt":1498137379766,"ct":"HEARTBEAT"}"#;
        let separator = "\r\n";
        let data = format!("{msg_one}{separator}");

        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::from(data.as_bytes());
        let msg_one = codec.decode(&mut buf).unwrap().unwrap();
        let msg_two_attempt = codec.decode(&mut buf).unwrap();
        assert!(msg_two_attempt.is_none());
        let data = format!("{msg_two}{separator}");
        buf.write_str(data.as_str()).unwrap();
        let msg_two = codec.decode(&mut buf).unwrap().unwrap();

        assert!(matches!(msg_one, ResponseMessage::Connection(_)));
        assert!(matches!(msg_two, ResponseMessage::OrderChange(_)));
    }

    // Encoding produces a CRLF-terminated JSON object tagged with its `op`.
    #[test]
    fn can_encode_message() {
        let msg = RequestMessage::Authentication(
            betfair_stream_types::request::authentication_message::AuthenticationMessage {
                id: Some(1),
                session: "sss".to_owned(),
                app_key: "aaaa".to_owned(),
            },
        );
        let mut codec = StreamAPIClientCodec;
        let mut buf = bytes::BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();

        let data = buf.freeze();
        let data = core::str::from_utf8(&data).unwrap();

        assert!(data.ends_with("\r\n"));
        assert!(data.starts_with("{\"op\":\"authentication\""));
    }
}