1use base64::{engine::general_purpose, Engine};
4use ethers_core::{
5 abi::AbiDecode,
6 types::{Bytes, U256},
7};
8use jsonwebtoken::{encode, errors::Error, get_current_timestamp, Algorithm, EncodingKey, Header};
9use serde::{
10 de::{self, MapAccess, Unexpected, Visitor},
11 Deserialize, Serialize,
12};
13use serde_json::{value::RawValue, Value};
14use std::fmt;
15use thiserror::Error;
16
17#[derive(Deserialize, Debug, Clone, Error)]
19pub struct JsonRpcError {
20 pub code: i64,
22 pub message: String,
24 pub data: Option<Value>,
26}
27
28fn spelunk_revert(value: &Value) -> Option<Bytes> {
33 match value {
34 Value::String(s) => s.parse().ok(),
35 Value::Object(o) => o.values().flat_map(spelunk_revert).next(),
36 _ => None,
37 }
38}
39
40impl JsonRpcError {
41 pub fn is_revert(&self) -> bool {
46 self.message.contains("revert")
48 }
49
50 pub fn as_revert_data(&self) -> Option<Bytes> {
62 self.is_revert().then(|| self.data.as_ref().and_then(spelunk_revert).unwrap_or_default())
63 }
64
65 pub fn decode_revert_data<E: AbiDecode>(&self) -> Option<E> {
67 E::decode(&self.as_revert_data()?).ok()
68 }
69}
70
71impl fmt::Display for JsonRpcError {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73 write!(f, "(code: {}, message: {}, data: {:?})", self.code, self.message, self.data)
74 }
75}
76
/// Returns `true` when `T` is a zero-sized type.
///
/// Used as a `skip_serializing_if` predicate, so zero-sized params such as `()` are
/// omitted from serialized requests.
fn is_zst<T>(value: &T) -> bool {
    std::mem::size_of_val(value) == 0
}
80
81#[derive(Serialize, Deserialize, Debug)]
82pub struct Request<'a, T> {
84 id: u64,
85 jsonrpc: &'a str,
86 method: &'a str,
87 #[serde(skip_serializing_if = "is_zst")]
88 params: T,
89}
90
91impl<'a, T> Request<'a, T> {
92 pub fn new(id: u64, method: &'a str, params: T) -> Self {
94 Self { id, jsonrpc: "2.0", method, params }
95 }
96}
97
98#[derive(Debug)]
100pub enum Response<'a> {
101 Success { id: u64, result: &'a RawValue },
102 Error { id: u64, error: JsonRpcError },
103 Notification { method: &'a str, params: Params<'a> },
104}
105
106#[derive(Deserialize, Debug)]
107pub struct Params<'a> {
108 pub subscription: U256,
109 #[serde(borrow)]
110 pub result: &'a RawValue,
111}
112
/// Hand-written deserializer for [`Response`].
///
/// The incoming JSON object is classified by which members are present:
/// - `id` + `result` becomes [`Response::Success`];
/// - `id` + `error` becomes [`Response::Error`];
/// - `method` + `params` (with no `id`) becomes [`Response::Notification`].
///
/// `jsonrpc` must be present and equal to `"2.0"`. The following are rejected with an error:
/// - duplicate members;
/// - members other than `id`, `jsonrpc`, `result`, `error`, `params` and `method`;
/// - any other combination of members.
///
/// `result` is kept as raw JSON borrowed from the input, which is why the `'de: 'a` bound
/// is needed.
impl<'de: 'a, 'a> Deserialize<'de> for Response<'a> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // The tuple field is constructed but never read; it only carries the `'a` lifetime.
        #[allow(dead_code)]
        struct ResponseVisitor<'a>(&'a ());
        impl<'de: 'a, 'a> Visitor<'de> for ResponseVisitor<'a> {
            type Value = Response<'a>;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a valid jsonrpc 2.0 response object")
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                // Set once a `"jsonrpc": "2.0"` member has been seen.
                let mut jsonrpc = false;

                // Present in success and error responses.
                let mut id = None;
                // Present only in success responses.
                let mut result = None;
                // Present only in error responses.
                let mut error = None;
                // Present only in notifications.
                let mut method = None;
                let mut params = None;

                // Keys are deserialized as borrowed `&str` (inferred from the string patterns
                // below). NOTE(review): with serde_json, a borrowed `&str` cannot be produced
                // for strings containing escape sequences, so such keys (and `jsonrpc`/`method`
                // values) would fail to deserialize — confirm this is acceptable.
                while let Some(key) = map.next_key()? {
                    match key {
                        "jsonrpc" => {
                            if jsonrpc {
                                return Err(de::Error::duplicate_field("jsonrpc"))
                            }

                            // Inferred as `&str` from the comparison and `Unexpected::Str`.
                            let value = map.next_value()?;
                            if value != "2.0" {
                                return Err(de::Error::invalid_value(Unexpected::Str(value), &"2.0"))
                            }

                            jsonrpc = true;
                        }
                        "id" => {
                            if id.is_some() {
                                return Err(de::Error::duplicate_field("id"))
                            }

                            let value: u64 = map.next_value()?;
                            id = Some(value);
                        }
                        "result" => {
                            if result.is_some() {
                                return Err(de::Error::duplicate_field("result"))
                            }

                            // Left undecoded; callers parse the raw JSON into the type they expect.
                            let value: &RawValue = map.next_value()?;
                            result = Some(value);
                        }
                        "error" => {
                            if error.is_some() {
                                return Err(de::Error::duplicate_field("error"))
                            }

                            let value: JsonRpcError = map.next_value()?;
                            error = Some(value);
                        }
                        "method" => {
                            if method.is_some() {
                                return Err(de::Error::duplicate_field("method"))
                            }

                            let value: &str = map.next_value()?;
                            method = Some(value);
                        }
                        "params" => {
                            if params.is_some() {
                                return Err(de::Error::duplicate_field("params"))
                            }

                            let value: Params = map.next_value()?;
                            params = Some(value);
                        }
                        // Any other member is treated as an error, not ignored.
                        key => {
                            return Err(de::Error::unknown_field(
                                key,
                                &["id", "jsonrpc", "result", "error", "params", "method"],
                            ))
                        }
                    }
                }

                // The `jsonrpc` member is required in every message shape.
                if !jsonrpc {
                    return Err(de::Error::missing_field("jsonrpc"))
                }

                // Exactly one of the three shapes must match. Any mix (e.g. both `result` and
                // `error`, or an `id` on a notification) is rejected.
                match (id, result, error, method, params) {
                    (Some(id), Some(result), None, None, None) => {
                        Ok(Response::Success { id, result })
                    }
                    (Some(id), None, Some(error), None, None) => Ok(Response::Error { id, error }),
                    (None, None, None, Some(method), Some(params)) => {
                        Ok(Response::Notification { method, params })
                    }
                    _ => Err(de::Error::custom(
                        "response must be either a success/error or notification object",
                    )),
                }
            }
        }

        deserializer.deserialize_map(ResponseVisitor(&()))
    }
}
231
/// Credentials whose `Display` output is `Basic <secret>`, `Bearer <token>`, or the raw
/// string unchanged. Presumably used as an HTTP `Authorization` header value — confirm
/// at call sites.
#[derive(Debug, Clone)]
pub enum Authorization {
    /// Pre-encoded basic-auth credentials. [`Authorization::basic`] stores the base64
    /// encoding of `username:password` here.
    Basic(String),
    /// A bearer token.
    Bearer(String),
    /// A value emitted verbatim, with no scheme prefix.
    Raw(String),
}
244
245impl Authorization {
246 pub fn basic(username: impl AsRef<str>, password: impl AsRef<str>) -> Self {
248 let username = username.as_ref();
249 let password = password.as_ref();
250 let auth_secret = general_purpose::STANDARD.encode(format!("{username}:{password}"));
251 Self::Basic(auth_secret)
252 }
253
254 pub fn bearer(token: impl Into<String>) -> Self {
256 Self::Bearer(token.into())
257 }
258
259 pub fn raw(token: impl Into<String>) -> Self {
261 Self::Raw(token.into())
262 }
263}
264
265impl fmt::Display for Authorization {
266 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
267 match self {
268 Authorization::Basic(auth_secret) => write!(f, "Basic {auth_secret}"),
269 Authorization::Bearer(token) => write!(f, "Bearer {token}"),
270 Authorization::Raw(s) => write!(f, "{s}"),
271 }
272 }
273}
274
275const DEFAULT_ALGORITHM: Algorithm = Algorithm::HS256;
277
278pub const JWT_SECRET_LENGTH: usize = 32;
280
281pub struct JwtKey([u8; JWT_SECRET_LENGTH]);
283
284impl JwtKey {
285 pub fn from_slice(key: &[u8]) -> Result<Self, String> {
287 if key.len() != JWT_SECRET_LENGTH {
288 return Err(format!(
289 "Invalid key length. Expected {} got {}",
290 JWT_SECRET_LENGTH,
291 key.len()
292 ))
293 }
294 let mut res = [0; JWT_SECRET_LENGTH];
295 res.copy_from_slice(key);
296 Ok(Self(res))
297 }
298
299 pub fn from_hex(hex: &str) -> Result<Self, String> {
301 let bytes = hex::decode(hex).map_err(|e| format!("Invalid hex: {}", e))?;
302 Self::from_slice(&bytes)
303 }
304
305 pub fn as_bytes(&self) -> &[u8; JWT_SECRET_LENGTH] {
307 &self.0
308 }
309
310 pub fn into_bytes(self) -> [u8; JWT_SECRET_LENGTH] {
312 self.0
313 }
314}
315
316pub struct JwtAuth {
318 key: EncodingKey,
319 id: Option<String>,
320 clv: Option<String>,
321}
322
323impl JwtAuth {
324 pub fn new(secret: JwtKey, id: Option<String>, clv: Option<String>) -> Self {
326 Self { key: EncodingKey::from_secret(secret.as_bytes()), id, clv }
327 }
328
329 pub fn generate_token(&self) -> Result<String, Error> {
331 let claims = self.generate_claims_at_timestamp();
332 self.generate_token_with_claims(&claims)
333 }
334
335 fn generate_token_with_claims(&self, claims: &Claims) -> Result<String, Error> {
337 let header = Header::new(DEFAULT_ALGORITHM);
338 encode(&header, claims, &self.key)
339 }
340
341 fn generate_claims_at_timestamp(&self) -> Claims {
343 Claims { iat: get_current_timestamp(), id: self.id.clone(), clv: self.clv.clone() }
344 }
345
346 pub fn validate_token(
348 token: &str,
349 secret: &JwtKey,
350 ) -> Result<jsonwebtoken::TokenData<Claims>, Error> {
351 let mut validation = jsonwebtoken::Validation::new(DEFAULT_ALGORITHM);
352 validation.validate_exp = false;
353 validation.required_spec_claims.remove("exp");
354
355 jsonwebtoken::decode::<Claims>(
356 token,
357 &jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()),
358 &validation,
359 )
360 .map_err(Into::into)
361 }
362}
363
364#[derive(Debug, Serialize, Deserialize, PartialEq)]
366pub struct Claims {
367 iat: u64,
369 id: Option<String>,
371 clv: Option<String>,
373}
374
#[cfg(test)]
mod tests {
    use ethers_core::types::U64;

    use super::*;

    #[test]
    fn deser_response() {
        // A success response must carry an `id`.
        let _ =
            serde_json::from_str::<Response<'_>>(r#"{"jsonrpc":"2.0","result":19}"#).unwrap_err();
        // Only JSON-RPC version 2.0 is accepted.
        let _ = serde_json::from_str::<Response<'_>>(r#"{"jsonrpc":"3.0","result":19,"id":1}"#)
            .unwrap_err();

        let response: Response<'_> =
            serde_json::from_str(r#"{"jsonrpc":"2.0","result":19,"id":1}"#).unwrap();

        match response {
            Response::Success { id, result } => {
                assert_eq!(id, 1);
                let result: u64 = serde_json::from_str(result.get()).unwrap();
                assert_eq!(result, 19);
            }
            _ => panic!("expected `Success` response"),
        }

        let response: Response<'_> = serde_json::from_str(
            r#"{"jsonrpc":"2.0","error":{"code":-32000,"message":"error occurred"},"id":2}"#,
        )
        .unwrap();

        match response {
            Response::Error { id, error } => {
                assert_eq!(id, 2);
                assert_eq!(error.code, -32000);
                assert_eq!(error.message, "error occurred");
                assert!(error.data.is_none());
            }
            _ => panic!("expected `Error` response"),
        }

        let response: Response<'_> =
            serde_json::from_str(r#"{"jsonrpc":"2.0","result":"0xfa","id":0}"#).unwrap();

        match response {
            Response::Success { id, result } => {
                assert_eq!(id, 0);
                let result: U64 = serde_json::from_str(result.get()).unwrap();
                assert_eq!(result.as_u64(), 250);
            }
            _ => panic!("expected `Success` response"),
        }
    }

    #[test]
    fn deser_response_rejects_malformed() {
        // Missing `jsonrpc`.
        assert!(serde_json::from_str::<Response<'_>>(r#"{"result":19,"id":1}"#).is_err());
        // Duplicate member.
        assert!(serde_json::from_str::<Response<'_>>(
            r#"{"jsonrpc":"2.0","id":1,"id":2,"result":19}"#
        )
        .is_err());
        // Unknown member.
        assert!(serde_json::from_str::<Response<'_>>(
            r#"{"jsonrpc":"2.0","id":1,"result":19,"foo":0}"#
        )
        .is_err());
        // Both `result` and `error` present.
        assert!(serde_json::from_str::<Response<'_>>(
            r#"{"jsonrpc":"2.0","id":1,"result":19,"error":{"code":1,"message":"x"}}"#
        )
        .is_err());
    }

    #[test]
    fn deser_notification() {
        let response: Response<'_> = serde_json::from_str(
            r#"{"jsonrpc":"2.0","method":"eth_subscription","params":{"subscription":"0x10","result":{"number":"0xfa"}}}"#,
        )
        .unwrap();

        match response {
            Response::Notification { method, params } => {
                assert_eq!(method, "eth_subscription");
                assert_eq!(params.subscription, U256::from(16u64));
                // The payload is kept verbatim as raw JSON.
                assert_eq!(params.result.get(), r#"{"number":"0xfa"}"#);
            }
            _ => panic!("expected `Notification` response"),
        }
    }

    #[test]
    fn ser_request() {
        let request: Request<()> = Request::new(0, "eth_chainId", ());
        assert_eq!(
            &serde_json::to_string(&request).unwrap(),
            r#"{"id":0,"jsonrpc":"2.0","method":"eth_chainId"}"#
        );

        let request: Request<()> = Request::new(300, "method_name", ());
        assert_eq!(
            &serde_json::to_string(&request).unwrap(),
            r#"{"id":300,"jsonrpc":"2.0","method":"method_name"}"#
        );

        let request: Request<u32> = Request::new(300, "method_name", 1);
        assert_eq!(
            &serde_json::to_string(&request).unwrap(),
            r#"{"id":300,"jsonrpc":"2.0","method":"method_name","params":1}"#
        );
    }

    #[test]
    fn revert_data_extraction() {
        let expected = Bytes::from(vec![0xde, 0xad, 0xbe, 0xef]);

        let err = JsonRpcError {
            code: 3,
            message: "execution reverted".to_string(),
            data: Some(Value::String("0xdeadbeef".to_string())),
        };
        assert!(err.is_revert());
        assert_eq!(err.as_revert_data(), Some(expected.clone()));

        // Revert data nested inside an object is found too.
        let nested =
            JsonRpcError { data: Some(serde_json::json!({ "data": "0xdeadbeef" })), ..err.clone() };
        assert_eq!(nested.as_revert_data(), Some(expected));

        // A revert without data yields empty bytes rather than `None`.
        let no_data = JsonRpcError { data: None, ..err.clone() };
        assert_eq!(no_data.as_revert_data(), Some(Bytes::default()));

        let not_revert = JsonRpcError { message: "nonce too low".to_string(), ..err };
        assert!(!not_revert.is_revert());
        assert_eq!(not_revert.as_revert_data(), None);
    }

    #[test]
    fn authorization_display() {
        assert_eq!(Authorization::basic("user", "pass").to_string(), "Basic dXNlcjpwYXNz");
        assert_eq!(Authorization::bearer("tok").to_string(), "Bearer tok");
        assert_eq!(Authorization::raw("custom").to_string(), "custom");
    }

    #[test]
    fn jwt_key_length_and_hex() {
        assert!(JwtKey::from_slice(&[0u8; JWT_SECRET_LENGTH - 1]).is_err());
        assert!(JwtKey::from_slice(&[0u8; JWT_SECRET_LENGTH + 1]).is_err());

        let key = JwtKey::from_slice(&[7u8; JWT_SECRET_LENGTH]).unwrap();
        assert_eq!(key.into_bytes(), [7u8; JWT_SECRET_LENGTH]);

        let key = JwtKey::from_hex(&"2a".repeat(JWT_SECRET_LENGTH)).unwrap();
        assert_eq!(key.into_bytes(), [42u8; JWT_SECRET_LENGTH]);

        assert!(JwtKey::from_hex("zz").is_err());
    }

    #[test]
    fn test_roundtrip() {
        let jwt_secret = [42; 32];
        let auth = JwtAuth::new(
            JwtKey::from_slice(&jwt_secret).unwrap(),
            Some("42".into()),
            Some("Lighthouse".into()),
        );
        let claims = auth.generate_claims_at_timestamp();
        let token = auth.generate_token_with_claims(&claims).unwrap();

        assert_eq!(
            JwtAuth::validate_token(&token, &JwtKey::from_slice(&jwt_secret).unwrap())
                .unwrap()
                .claims,
            claims
        );

        // A token signed with a different secret must be rejected.
        assert!(JwtAuth::validate_token(&token, &JwtKey::from_slice(&[43; 32]).unwrap()).is_err());
    }
}