reown_relay_client/
lib.rs

1pub use {
2    crate::error::{ClientError, RequestBuildError},
3    reown_relay_rpc as rpc,
4};
5use {
6    ::http::HeaderMap,
7    reown_relay_rpc::{
8        auth::{SerializedAuthToken, RELAY_WEBSOCKET_ADDRESS},
9        domain::{MessageId, ProjectId, SubscriptionId},
10        rpc::{SubscriptionError, SubscriptionResult},
11        user_agent::UserAgent,
12    },
13    serde::Serialize,
14    std::sync::{
15        atomic::{AtomicU8, Ordering},
16        Arc,
17    },
18    url::Url,
19};
20pub mod error;
21pub mod websocket;
22
23pub type HttpRequest<T> = ::http::Request<T>;
24
25/// Relay authorization method. A wrapper around [`SerializedAuthToken`].
26#[derive(Debug, Clone)]
27pub enum Authorization {
28    /// Uses query string to pass the auth token, e.g. `?auth=<token>`.
29    Query(SerializedAuthToken),
30
31    /// Uses the `Authorization: Bearer <token>` HTTP header.
32    Header(SerializedAuthToken),
33}
34
35/// Relay connection options.
36#[derive(Debug, Clone)]
37pub struct ConnectionOptions {
38    /// The Relay websocket address. The default address is
39    /// `wss://relay.walletconnect.com`.
40    pub address: String,
41
42    /// The project-specific secret key. Can be generated in the Cloud Dashboard
43    /// at the following URL: <https://cloud.walletconnect.com/app>
44    pub project_id: ProjectId,
45
46    /// The authorization method and auth token to use.
47    pub auth: Authorization,
48
49    /// Optional origin of the request. Subject to allow-list validation.
50    pub origin: Option<String>,
51
52    /// Optional user agent parameters.
53    pub user_agent: Option<UserAgent>,
54}
55
56impl ConnectionOptions {
57    pub fn new(project_id: impl Into<ProjectId>, auth: SerializedAuthToken) -> Self {
58        Self {
59            address: RELAY_WEBSOCKET_ADDRESS.into(),
60            project_id: project_id.into(),
61            auth: Authorization::Query(auth),
62            origin: None,
63            user_agent: None,
64        }
65    }
66
67    pub fn with_address(mut self, address: impl Into<String>) -> Self {
68        self.address = address.into();
69        self
70    }
71
72    pub fn with_origin(mut self, origin: impl Into<Option<String>>) -> Self {
73        self.origin = origin.into();
74        self
75    }
76
77    pub fn with_user_agent(mut self, user_agent: impl Into<Option<UserAgent>>) -> Self {
78        self.user_agent = user_agent.into();
79        self
80    }
81
82    pub fn as_url(&self) -> Result<Url, RequestBuildError> {
83        #[derive(Serialize)]
84        #[serde(rename_all = "camelCase")]
85        struct QueryParams<'a> {
86            project_id: &'a ProjectId,
87            auth: Option<&'a SerializedAuthToken>,
88            ua: Option<&'a UserAgent>,
89        }
90
91        let query = serde_qs::to_string(&QueryParams {
92            project_id: &self.project_id,
93            auth: if let Authorization::Query(auth) = &self.auth {
94                Some(auth)
95            } else {
96                None
97            },
98            ua: self.user_agent.as_ref(),
99        })
100        .map_err(RequestBuildError::Query)?;
101
102        let mut url = Url::parse(&self.address).map_err(RequestBuildError::Url)?;
103        url.set_query(Some(&query));
104
105        Ok(url)
106    }
107
108    #[cfg(not(target_arch = "wasm32"))]
109    fn as_ws_request(&self) -> Result<HttpRequest<()>, RequestBuildError> {
110        use {
111            crate::websocket::WebsocketClientError,
112            tokio_tungstenite::tungstenite::client::IntoClientRequest,
113        };
114
115        let url = self.as_url()?;
116
117        let mut request = url
118            .into_client_request()
119            .map_err(WebsocketClientError::Transport)?;
120
121        self.update_request_headers(request.headers_mut())?;
122
123        Ok(request)
124    }
125
126    #[cfg(target_arch = "wasm32")]
127    fn as_ws_request(&self) -> Result<HttpRequest<()>, RequestBuildError> {
128        use crate::websocket::WebsocketClientError;
129
130        let url = self.as_url()?;
131        let mut request = HttpRequest::builder()
132            .uri(format!("{}", url))
133            .body(())
134            .map_err(WebsocketClientError::HttpErr)?;
135
136        self.update_request_headers(request.headers_mut())?;
137        Ok(request)
138    }
139
140    fn update_request_headers(&self, headers: &mut HeaderMap) -> Result<(), RequestBuildError> {
141        if let Authorization::Header(token) = &self.auth {
142            let value = format!("Bearer {token}")
143                .parse()
144                .map_err(|_| RequestBuildError::Headers)?;
145
146            headers.append("Authorization", value);
147        }
148
149        if let Some(origin) = &self.origin {
150            let value = origin.parse().map_err(|_| RequestBuildError::Headers)?;
151
152            headers.append("Origin", value);
153        }
154
155        Ok(())
156    }
157}
158
/// Produces unique message IDs for RPC requests.
///
/// Each ID packs the current timestamp (millisecond precision) into the upper
/// 56 bits and the value of an 8-bit counter into the lower 8 bits, which
/// allows up to `256000` unique values per second. Clones share the same
/// counter.
#[derive(Clone, Debug)]
pub struct MessageIdGenerator {
    // Counter supplying the low 8 bits of every ID; shared between clones.
    next: Arc<AtomicU8>,
}
166
167impl MessageIdGenerator {
168    pub fn new() -> Self {
169        Self::default()
170    }
171
172    /// Generates a [`MessageId`].
173    pub fn next(&self) -> MessageId {
174        let next = self.next.fetch_add(1, Ordering::Relaxed) as u64;
175        let timestamp = chrono::Utc::now().timestamp_millis() as u64;
176        let id = timestamp << 8 | next;
177
178        MessageId::new(id)
179    }
180}
181
182impl Default for MessageIdGenerator {
183    fn default() -> Self {
184        Self {
185            next: Arc::new(AtomicU8::new(0)),
186        }
187    }
188}
189
190#[inline]
191fn convert_subscription_result(
192    res: SubscriptionResult,
193) -> Result<SubscriptionId, error::Error<SubscriptionError>> {
194    match res {
195        SubscriptionResult::Id(id) => Ok(id),
196        SubscriptionResult::Error(err) => Err(ClientError::from(err).into()),
197    }
198}
199
#[cfg(test)]
mod tests {
    use {
        super::*,
        std::{collections::HashSet, hash::Hash},
    };

    /// Returns `true` when no two items yielded by `iter` compare equal.
    fn elements_unique<T>(iter: T) -> bool
    where
        T: IntoIterator,
        T::Item: Eq + Hash,
    {
        let mut seen = HashSet::new();
        for item in iter {
            // `insert` reports `false` for a value that is already present.
            if !seen.insert(item) {
                return false;
            }
        }
        true
    }

    #[test]
    fn unique_message_ids() {
        let generator = MessageIdGenerator::new();
        // N.B. We can produce up to 256 unique values within 1ms.
        let ids: Vec<_> = std::iter::repeat_with(|| generator.next())
            .take(256)
            .collect();
        assert!(elements_unique(ids));
    }
}