1use std::sync::Arc;
4use std::time::Duration;
5
6use async_tungstenite::tokio::connect_async_with_tls_connector_and_config;
7use base64::Engine;
8use rustls::ClientConfig;
9use tokio_rustls::TlsConnector;
10use tracing::trace;
11use tungstenite::ClientRequestBuilder;
12use tungstenite::http::Uri;
13use tungstenite::http::uri::PathAndQuery;
14use tungstenite::protocol::WebSocketConfig;
15
16use crate::{Client, Error};
17
18const AUTH_TOKEN_PREFIX: &str = "base64url.bearer.phx.";
22
23const BASE_64: base64::engine::GeneralPurpose = base64::prelude::BASE64_URL_SAFE_NO_PAD;
24
/// Base timeout of 10 seconds; the default heartbeat below is derived from it.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
/// Default heartbeat: half of [`DEFAULT_TIMEOUT`], truncated to whole seconds.
const DEFAULT_HEARTBEAT: Duration = {
    let half_of_timeout_secs = DEFAULT_TIMEOUT.as_secs() / 2;
    Duration::from_secs(half_of_timeout_secs)
};
27
28#[derive(Debug)]
30pub struct Builder {
31 client_req: ClientRequestBuilder,
32 ws_config: WebSocketConfig,
33 tls_config: Option<Arc<ClientConfig>>,
34 auth_token: Option<String>,
35 heartbeat: Duration,
36}
37
38impl Builder {
39 pub fn new(mut uri: Uri) -> Result<Self, Error> {
41 let has_vsn = uri
42 .query()
43 .is_some_and(|s| s.split('&').any(|s| s.starts_with("vsn=")));
44
45 if !has_vsn {
46 let pq = match uri.query() {
47 Some(query) if !query.is_empty() => {
48 PathAndQuery::try_from(format!("{}?{query}&vsn=2.0.0", uri.path()))
49 .map_err(Error::Uri)?
50 }
51 Some(_) | None => PathAndQuery::try_from(format!("{}?vsn=2.0.0", uri.path()))
52 .map_err(Error::Uri)?,
53 };
54
55 uri = tungstenite::http::uri::Builder::from(uri)
56 .path_and_query(pq)
57 .build()
58 .map_err(Error::UriBuild)?;
59 }
60
61 let client_req = ClientRequestBuilder::new(uri.clone());
62
63 Ok(Self {
64 client_req,
65 ws_config: WebSocketConfig::default(),
66 tls_config: None,
67 auth_token: None,
68 heartbeat: DEFAULT_HEARTBEAT,
70 })
71 }
72
73 #[must_use]
75 pub fn ws_config(mut self, ws_config: WebSocketConfig) -> Self {
76 self.ws_config = ws_config;
77
78 self
79 }
80
81 #[must_use]
83 pub fn add_header(mut self, key: String, value: String) -> Self {
84 self.client_req = self.client_req.with_header(key, value);
85
86 self
87 }
88
89 #[must_use]
91 pub fn add_sub_protocol(mut self, key: String, value: String) -> Self {
92 self.client_req = self.client_req.with_header(key, value);
93
94 self
95 }
96
97 #[must_use]
99 pub fn auth_token(mut self, token: &str) -> Self {
100 let encoded = BASE_64.encode(token);
101
102 self.auth_token = Some(format!("{AUTH_TOKEN_PREFIX}{encoded}"));
103
104 self.client_req = self.client_req.with_sub_protocol("phoenix");
105
106 self
107 }
108
109 #[must_use]
111 pub fn tls_config(mut self, tls_config: Arc<ClientConfig>) -> Self {
112 self.tls_config = Some(tls_config);
113
114 self
115 }
116
117 #[must_use]
119 pub fn heartbeat(mut self, heartbeat: Duration) -> Self {
120 self.heartbeat = heartbeat;
121
122 self
123 }
124
125 pub async fn connect(mut self) -> Result<Client, Error> {
127 if let Some(token) = self.auth_token {
128 self.client_req = self.client_req.with_sub_protocol(token);
129 }
130
131 let connector = self.tls_config.map(TlsConnector::from);
132
133 let (connection, resp) = connect_async_with_tls_connector_and_config(
134 self.client_req,
135 connector,
136 Some(self.ws_config),
137 )
138 .await
139 .map_err(Box::new)
140 .map_err(Error::Connect)?;
141
142 trace!(status = %resp.status(), headers = ?resp.headers());
143
144 Ok(Client::new(connection, self.heartbeat))
145 }
146}