1use url::Url;
5
6use tokio::net::TcpStream;
7
8use tracing::debug;
9use tracing::span;
10use tracing::trace;
11use tracing::Level;
12use tracing_futures::Instrument;
13
14use tungstenite::connect_async;
15use tungstenite::MaybeTlsStream;
16use tungstenite::WebSocketStream;
17
18use websocket_util::wrap::Wrapper;
19
20use crate::Error;
21
22
/// A two-variant result type that mirrors [`Result`] variant-for-variant.
///
/// The type is public but hidden from the generated documentation.
// NOTE(review): presumably a crate-local stand-in for `Result` so that
// foreign traits can be implemented on it — confirm against its users.
#[doc(hidden)]
#[derive(Debug)]
pub enum MessageResult<T, E> {
  /// The success case, carrying a `T`.
  Ok(T),
  /// The failure case, carrying an `E`.
  Err(E),
}
33
34impl<T, E> From<Result<T, E>> for MessageResult<T, E> {
35 #[inline]
36 fn from(result: Result<T, E>) -> Self {
37 match result {
38 Ok(t) => Self::Ok(t),
39 Err(e) => Self::Err(e),
40 }
41 }
42}
43
44
45async fn connect_internal(url: &Url) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, Error> {
47 let span = span!(Level::DEBUG, "stream");
48
49 async move {
50 debug!(message = "connecting", url = display(url));
51
52 let (stream, response) = connect_async(url).await?;
56 debug!("connection successful");
57 trace!(response = debug(&response));
58
59 Ok(stream)
60 }
61 .instrument(span)
62 .await
63}
64
65
66pub(crate) async fn connect(
68 url: &Url,
69) -> Result<Wrapper<WebSocketStream<MaybeTlsStream<TcpStream>>>, Error> {
70 connect_internal(url)
71 .await
72 .map(|stream| Wrapper::builder().build(stream))
73}
74
75
#[cfg(test)]
pub(crate) mod test {
  use super::*;

  use std::future::Future;

  use websocket_util::test::mock_server;
  use websocket_util::test::WebSocketStream;
  use websocket_util::tungstenite::Error as WebSocketError;

  use crate::subscribable::Subscribable;
  use crate::ApiInfo;


  /// The API key ID placed into the [`ApiInfo`] built by [`mock_stream`].
  pub(crate) const KEY_ID: &str = "USER12345678";
  /// The API secret placed into the [`ApiInfo`] built by [`mock_stream`].
  pub(crate) const SECRET: &str = "justletmein";


  /// Start a mock websocket server driven by `f` and connect an `S` to it.
  ///
  /// Both `api_stream_url` and `data_stream_base_url` of the [`ApiInfo`]
  /// passed to `S::connect` point at the mock server. The HTTP base URLs
  /// are set to the placeholder `http://example.com`.
  ///
  /// # Errors
  /// Returns whatever error `S::connect` reports.
  ///
  /// # Panics
  /// Panics if the mock server address or the placeholder do not parse as URLs.
  pub(crate) async fn mock_stream<S, F, R>(f: F) -> Result<(S::Stream, S::Subscription), Error>
  where
    S: Subscribable<Input = ApiInfo>,
    F: FnOnce(WebSocketStream) -> R + Send + Sync + 'static,
    R: Future<Output = Result<(), WebSocketError>> + Send + Sync + 'static,
  {
    let addr = mock_server(f).await;
    let stream_url = Url::parse(&format!("ws://{addr}"))
      .expect("mock server address should form a valid websocket URL");
    // NOTE(review): presumably never contacted by stream tests — confirm.
    let base_url =
      Url::parse("http://example.com").expect("placeholder base URL should be valid");

    let api_info = ApiInfo {
      api_base_url: base_url.clone(),
      api_stream_url: stream_url.clone(),
      data_base_url: base_url,
      // Last use of `stream_url`: move it instead of cloning again.
      data_stream_base_url: stream_url,
      key_id: KEY_ID.to_string(),
      secret: SECRET.to_string(),
    };

    S::connect(&api_info).await
  }
}