Skip to main content

phoenix_chan/
builder.rs

1//! Configures a [`Client`]
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use async_tungstenite::tokio::connect_async_with_tls_connector_and_config;
7use base64::Engine;
8use rustls::ClientConfig;
9use tokio_rustls::TlsConnector;
10use tracing::trace;
11use tungstenite::ClientRequestBuilder;
12use tungstenite::http::Uri;
13use tungstenite::http::uri::PathAndQuery;
14use tungstenite::protocol::WebSocketConfig;
15
16use crate::{Client, Error};
17
18/// Authentication token prefix
19///
20/// See <https://github.com/phoenixframework/phoenix/blob/ad1a7ee2c9c29ff102b94242fdbb9cb14dd0dd4b/assets/js/phoenix/constants.js#L30>
21const AUTH_TOKEN_PREFIX: &str = "base64url.bearer.phx.";
22
23const BASE_64: base64::engine::GeneralPurpose = base64::prelude::BASE64_URL_SAFE_NO_PAD;
24
/// Default timeout (10 seconds) from which the default heartbeat interval is derived.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);

/// Default heartbeat interval: [`DEFAULT_TIMEOUT`] halved, truncated to whole seconds.
const DEFAULT_HEARTBEAT: Duration = {
    let half_secs = DEFAULT_TIMEOUT.as_secs() / 2;
    Duration::from_secs(half_secs)
};
27
28/// Builder to configure a [`Client`]
29#[derive(Debug)]
30pub struct Builder {
31    client_req: ClientRequestBuilder,
32    ws_config: WebSocketConfig,
33    tls_config: Option<Arc<ClientConfig>>,
34    auth_token: Option<String>,
35    heartbeat: Duration,
36}
37
38impl Builder {
39    /// Returns a new instance with defaults set.
40    pub fn new(mut uri: Uri) -> Result<Self, Error> {
41        let has_vsn = uri
42            .query()
43            .is_some_and(|s| s.split('&').any(|s| s.starts_with("vsn=")));
44
45        if !has_vsn {
46            let pq = match uri.query() {
47                Some(query) if !query.is_empty() => {
48                    PathAndQuery::try_from(format!("{}?{query}&vsn=2.0.0", uri.path()))
49                        .map_err(Error::Uri)?
50                }
51                Some(_) | None => PathAndQuery::try_from(format!("{}?vsn=2.0.0", uri.path()))
52                    .map_err(Error::Uri)?,
53            };
54
55            uri = tungstenite::http::uri::Builder::from(uri)
56                .path_and_query(pq)
57                .build()
58                .map_err(Error::UriBuild)?;
59        }
60
61        let client_req = ClientRequestBuilder::new(uri.clone());
62
63        Ok(Self {
64            client_req,
65            ws_config: WebSocketConfig::default(),
66            tls_config: None,
67            auth_token: None,
68            // https://github.com/phoenixframework/phoenix/blob/ad1a7ee2c9c29ff102b94242fdbb9cb14dd0dd4b/assets/js/phoenix/constants.js#L6
69            heartbeat: DEFAULT_HEARTBEAT,
70        })
71    }
72
73    /// Configure the [`WebSocketConfig`]
74    #[must_use]
75    pub fn ws_config(mut self, ws_config: WebSocketConfig) -> Self {
76        self.ws_config = ws_config;
77
78        self
79    }
80
81    /// Add headers to the client connection request.
82    #[must_use]
83    pub fn add_header(mut self, key: String, value: String) -> Self {
84        self.client_req = self.client_req.with_header(key, value);
85
86        self
87    }
88
89    /// Add a sub-protocol header to the WebSocket connection.
90    #[must_use]
91    pub fn add_sub_protocol(mut self, key: String, value: String) -> Self {
92        self.client_req = self.client_req.with_header(key, value);
93
94        self
95    }
96
97    /// Set the authentication token to pass to the server.
98    #[must_use]
99    pub fn auth_token(mut self, token: &str) -> Self {
100        let encoded = BASE_64.encode(token);
101
102        self.auth_token = Some(format!("{AUTH_TOKEN_PREFIX}{encoded}"));
103
104        self.client_req = self.client_req.with_sub_protocol("phoenix");
105
106        self
107    }
108
109    /// Configure the [`WebSocketConfig`]
110    #[must_use]
111    pub fn tls_config(mut self, tls_config: Arc<ClientConfig>) -> Self {
112        self.tls_config = Some(tls_config);
113
114        self
115    }
116
117    /// Set the heart-bit interval duration.
118    #[must_use]
119    pub fn heartbeat(mut self, heartbeat: Duration) -> Self {
120        self.heartbeat = heartbeat;
121
122        self
123    }
124
125    /// Returns a configured client.
126    pub async fn connect(mut self) -> Result<Client, Error> {
127        if let Some(token) = self.auth_token {
128            self.client_req = self.client_req.with_sub_protocol(token);
129        }
130
131        let connector = self.tls_config.map(TlsConnector::from);
132
133        let (connection, resp) = connect_async_with_tls_connector_and_config(
134            self.client_req,
135            connector,
136            Some(self.ws_config),
137        )
138        .await
139        .map_err(Box::new)
140        .map_err(Error::Connect)?;
141
142        trace!(status = %resp.status(), headers = ?resp.headers());
143
144        Ok(Client::new(connection, self.heartbeat))
145    }
146}