// async_jsonrpc_client/ws_client/builder.rs

1use std::{fmt, time::Duration};
2
3use async_tungstenite::tungstenite::handshake::client::Request as HandShakeRequest;
4use futures::channel::mpsc;
5use http::header::{self, HeaderMap, HeaderName, HeaderValue};
6
7use crate::{
8    error::WsError,
9    ws_client::{task::WsTask, WsClient},
10};
11
12/// A `WsClientBuilder` can be used to create a `HttpClient` with  custom configuration.
13#[derive(Debug)]
14pub struct WsClientBuilder {
15    headers: HeaderMap,
16    timeout: Option<Duration>,
17    max_concurrent_request_capacity: usize,
18    max_capacity_per_subscription: usize,
19}
20
21impl Default for WsClientBuilder {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl WsClientBuilder {
28    /// Creates a new `WsClientBuilder`.
29    ///
30    /// This is the same as `WsClient::builder()`.
31    pub fn new() -> Self {
32        Self {
33            headers: HeaderMap::new(),
34            timeout: None,
35            max_concurrent_request_capacity: 256,
36            max_capacity_per_subscription: 64,
37        }
38    }
39
40    // ========================================================================
41    // HTTP header options
42    // ========================================================================
43
44    /// Enables basic authentication.
45    pub fn basic_auth<U, P>(self, username: U, password: Option<P>) -> Self
46    where
47        U: fmt::Display,
48        P: fmt::Display,
49    {
50        let mut basic_auth = "Basic ".to_string();
51        let auth = if let Some(password) = password {
52            base64::encode(format!("{}:{}", username, password))
53        } else {
54            base64::encode(format!("{}:", username))
55        };
56        basic_auth.push_str(&auth);
57        let value = HeaderValue::from_str(&basic_auth).expect("basic auth header value");
58        self.header(header::AUTHORIZATION, value)
59    }
60
61    /// Enables bearer authentication.
62    pub fn bearer_auth<T>(self, token: T) -> Self
63    where
64        T: fmt::Display,
65    {
66        let bearer_auth = format!("Bearer {}", token);
67        let value = HeaderValue::from_str(&bearer_auth).expect("bearer auth header value");
68        self.header(header::AUTHORIZATION, value)
69    }
70
71    /// Adds a `Header` for handshake request.
72    pub fn header(mut self, name: HeaderName, value: HeaderValue) -> Self {
73        self.headers.insert(name, value);
74        self
75    }
76
77    /// Adds `Header`s for handshake request.
78    pub fn headers(mut self, headers: HeaderMap) -> Self {
79        self.headers.extend(headers);
80        self
81    }
82
83    // ========================================================================
84    // Channel options
85    // ========================================================================
86
87    /// Sets the max channel capacity of sending request concurrently.
88    ///
89    /// Default is 256.
90    pub fn max_concurrent_request_capacity(mut self, capacity: usize) -> Self {
91        self.max_concurrent_request_capacity = capacity;
92        self
93    }
94
95    /// Sets the max channel capacity of every subscription stream.
96    ///
97    /// Default is 64.
98    pub fn max_capacity_per_subscription(mut self, capacity: usize) -> Self {
99        self.max_capacity_per_subscription = capacity;
100        self
101    }
102
103    // ========================================================================
104    // Timeout options
105    // ========================================================================
106
107    /// Enables a request timeout.
108    ///
109    /// The timeout is applied from when the request starts connecting until the
110    /// response body has finished.
111    ///
112    /// Default is no timeout.
113    pub fn timeout(mut self, timeout: Duration) -> Self {
114        self.timeout = Some(timeout);
115        self
116    }
117
118    // ========================================================================
119
120    /// Returns a `WsClient` that uses this `WsClientBuilder` configuration.
121    pub async fn build(self, url: impl Into<String>) -> Result<WsClient, WsError> {
122        let url = url.into();
123        let mut handshake_builder = HandShakeRequest::get(&url);
124        let headers = handshake_builder.headers_mut().expect("handshake request just created");
125        headers.extend(self.headers);
126        let handshake_req = handshake_builder.body(()).map_err(WsError::HttpFormat)?;
127
128        let (to_back, from_front) = mpsc::channel(self.max_concurrent_request_capacity);
129        log::debug!("Connecting '{}' ...", url);
130        let task = WsTask::handshake(handshake_req, self.max_capacity_per_subscription).await?;
131        log::debug!("Connect '{}' successfully", url);
132        #[cfg(feature = "ws-async-std")]
133        let _handle = async_std::task::spawn(task.into_task(from_front));
134        #[cfg(feature = "ws-tokio")]
135        let _handle = tokio::spawn(task.into_task(from_front));
136
137        Ok(WsClient {
138            to_back,
139            timeout: self.timeout,
140        })
141    }
142}