1use std::sync::Arc;
4use std::time::Duration;
5
6use async_tungstenite::tokio::connect_async_with_tls_connector_and_config;
7use base64::Engine;
8use rustls::ClientConfig;
9use tokio_rustls::TlsConnector;
10use tracing::{instrument, trace};
11use tungstenite::http::Uri;
12use tungstenite::protocol::WebSocketConfig;
13use tungstenite::ClientRequestBuilder;
14
15use crate::{Client, Error};
16
17const AUTH_TOKEN_PREFIX: &str = "base64url.bearer.phx.";
21
22const BASE_64: base64::engine::GeneralPurpose = base64::prelude::BASE64_URL_SAFE_NO_PAD;
23
/// Base timeout from which the default heartbeat interval is derived.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);

/// Default heartbeat interval: half of `DEFAULT_TIMEOUT`, truncated to whole
/// seconds.
const DEFAULT_HEARTBEAT: Duration = {
    let half_timeout_secs = DEFAULT_TIMEOUT.as_secs() / 2;
    Duration::from_secs(half_timeout_secs)
};
26
27#[derive(Debug)]
29pub struct Builder {
30 uri: Uri,
31 client_req: ClientRequestBuilder,
32 ws_config: WebSocketConfig,
33 tls_config: Option<Arc<ClientConfig>>,
34 auth_token: Option<String>,
35 heartbeat: Duration,
36}
37
38impl Builder {
39 #[must_use]
41 pub fn new(uri: Uri) -> Self {
42 let client_req = ClientRequestBuilder::new(uri.clone());
43
44 Self {
45 uri,
46 client_req,
47 ws_config: WebSocketConfig::default(),
48 tls_config: None,
49 auth_token: None,
50 heartbeat: DEFAULT_HEARTBEAT,
52 }
53 }
54
55 #[must_use]
57 pub fn ws_config(mut self, ws_config: WebSocketConfig) -> Self {
58 self.ws_config = ws_config;
59
60 self
61 }
62
63 #[must_use]
65 pub fn add_header(mut self, key: String, value: String) -> Self {
66 self.client_req = self.client_req.with_header(key, value);
67
68 self
69 }
70
71 #[must_use]
73 pub fn add_sub_protocol(mut self, key: String, value: String) -> Self {
74 self.client_req = self.client_req.with_header(key, value);
75
76 self
77 }
78
79 #[must_use]
81 pub fn auth_token(mut self, token: &str) -> Self {
82 let encoded = BASE_64.encode(token);
83
84 self.auth_token = Some(format!("{AUTH_TOKEN_PREFIX}{encoded}"));
85
86 self.client_req = self.client_req.with_sub_protocol("phoenix");
87
88 self
89 }
90
91 #[must_use]
93 pub fn tls_config(mut self, tls_config: Arc<ClientConfig>) -> Self {
94 self.tls_config = Some(tls_config);
95
96 self
97 }
98
99 #[must_use]
101 pub fn heartbeat(mut self, heartbeat: Duration) -> Self {
102 self.heartbeat = heartbeat;
103
104 self
105 }
106
107 #[must_use]
109 #[instrument(skip(self), fields(uri = %self.uri))]
110 pub async fn connect(mut self) -> Result<Client, Error> {
111 if let Some(token) = self.auth_token {
112 self.client_req = self.client_req.with_sub_protocol(token);
113 }
114
115 let connector = self.tls_config.map(TlsConnector::from);
116
117 let (connection, resp) = connect_async_with_tls_connector_and_config(
118 self.client_req,
119 connector,
120 Some(self.ws_config),
121 )
122 .await
123 .map_err(Error::Connect)?;
124
125 trace!(status = %resp.status(), headers = ?resp.headers());
126
127 Ok(Client::new(connection, self.heartbeat))
128 }
129}