use std::fmt;
use base64::{engine::general_purpose, Engine};
use jsonwebtoken::{encode, errors::Error, get_current_timestamp, Algorithm, EncodingKey, Header};
use primitive_types::U256;
use serde::{
de::{self, MapAccess, Unexpected, Visitor},
Deserialize, Serialize,
};
use serde_json::{value::RawValue, Value};
use thiserror::Error;
use neo3::prelude::Bytes;
/// A JSON-RPC 2.0 error object, as found in the `error` member of a response.
///
/// The `thiserror` derive is used without an `#[error]` attribute; `Display` is
/// implemented by hand in this module.
#[derive(Deserialize, Debug, Clone, Error, PartialEq)]
pub struct JsonRpcError {
    /// Numeric error code reported by the server.
    pub code: i64,
    /// Human-readable message; [`JsonRpcError::is_revert`] searches it for `"revert"`.
    pub message: String,
    /// Optional extra payload, kept as arbitrary JSON.
    pub data: Option<Value>,
}
/// Digs through a JSON error payload looking for revert data.
///
/// - A JSON string is returned as its raw UTF-8 bytes (it is *not* hex-decoded).
/// - A JSON object is searched recursively over its values, in the map's iteration
///   order, and the first hit is returned.
/// - Every other JSON type (null, bool, number, array) yields `None`.
///
/// NOTE(review): because any string matches, an object containing several string
/// fields returns whichever one is visited first. Confirm that this is the intended
/// behaviour for the node's error format.
fn spelunk_revert(value: &Value) -> Option<Bytes> {
    if let Value::String(text) = value {
        return Some(text.as_bytes().to_vec());
    }
    if let Value::Object(fields) = value {
        return fields.values().find_map(spelunk_revert);
    }
    None
}
impl JsonRpcError {
pub fn is_revert(&self) -> bool {
self.message.contains("revert")
}
pub fn as_revert_data(&self) -> Option<Bytes> {
self.is_revert()
.then(|| self.data.as_ref().and_then(spelunk_revert).unwrap_or_default())
}
}
impl fmt::Display for JsonRpcError {
    /// Renders as `(code: <code>, message: <message>, data: <data as Debug>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { code, message, data } = self;
        write!(f, "(code: {}, message: {}, data: {:?})", code, message, data)
    }
}
/// Returns `true` when `T` is a zero-sized type such as `()`.
///
/// The value itself is ignored; only its type matters. [`Request`] uses this as its
/// `skip_serializing_if` predicate, so zero-sized params omit the `params` key.
fn is_zst<T>(_t: &T) -> bool {
    let size_in_bytes = std::mem::size_of::<T>();
    size_in_bytes == 0
}
/// A JSON-RPC 2.0 request envelope.
///
/// Fields serialize in declaration order (`id`, `jsonrpc`, `method`, `params`). The
/// `params` key is omitted entirely when `T` is zero-sized, e.g. `()`.
#[derive(Serialize, Deserialize, Debug)]
pub struct Request<'a, T> {
    /// Numeric request id.
    id: u64,
    /// Protocol version string; [`Request::new`] always sets `"2.0"`.
    jsonrpc: &'a str,
    /// Name of the RPC method to call.
    method: &'a str,
    // Skipped on serialization when `T` has size 0 (see `is_zst`).
    #[serde(skip_serializing_if = "is_zst")]
    params: T,
}
impl<'a, T> Request<'a, T> {
pub fn new(id: u64, method: &'a str, params: T) -> Self {
Self { id, jsonrpc: "2.0", method, params }
}
}
/// A decoded JSON-RPC 2.0 response that borrows from the input text.
///
/// Built by the hand-written `Deserialize` impl for this type, which picks the variant
/// based on which members are present.
#[derive(Debug)]
pub enum Response<'a> {
    /// `id` plus the still-unparsed `result` JSON.
    Success { id: u64, result: &'a RawValue },
    /// `id` plus the decoded `error` object.
    Error { id: u64, error: JsonRpcError },
    /// A message with no `id`, carrying `method` and subscription `params`.
    Notification { method: &'a str, params: Params<'a> },
}
/// The `params` member of a [`Response::Notification`].
#[derive(Deserialize, Debug)]
pub struct Params<'a> {
    /// Subscription identifier, decoded as a `U256`.
    pub subscription: U256,
    /// Unparsed notification result, borrowed from the input.
    #[serde(borrow)]
    pub result: &'a RawValue,
}
/// Hand-written deserializer, so `Response` can borrow `result`, `method` and `params`
/// from the input while strictly validating the envelope.
///
/// Rules enforced by the visitor below:
/// - each key may appear at most once;
/// - keys other than `id`, `jsonrpc`, `result`, `error`, `method` and `params` are rejected;
/// - `jsonrpc` must be present and equal to `"2.0"`;
/// - `id` + `result` gives `Success`, `id` + `error` gives `Error`, and `method` + `params`
///   (with no `id`) gives `Notification`; any other combination is an error.
///
/// NOTE(review): `id` is decoded as `u64`, so string or `null` ids fail to deserialize.
/// Keys and the `jsonrpc`/`method` values are borrowed `&str`. With serde_json this
/// presumably means strings containing escape sequences cannot be borrowed and are
/// rejected — confirm whether that matters for the target nodes.
impl<'de: 'a, 'a> Deserialize<'de> for Response<'a> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Zero-sized visitor; the PhantomData only carries the output lifetime `'a`.
        struct ResponseVisitor<'a>(std::marker::PhantomData<&'a ()>);
        impl<'de: 'a, 'a> Visitor<'de> for ResponseVisitor<'a> {
            type Value = Response<'a>;
            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a valid jsonrpc 2.0 response object")
            }
            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                // One slot per recognised key; `None`/`false` means "not seen yet".
                let mut jsonrpc = false;
                let mut id = None;
                let mut result = None;
                let mut error = None;
                let mut method = None;
                let mut params = None;
                // The key type is inferred as a borrowed `&str` from the string patterns below.
                while let Some(key) = map.next_key()? {
                    match key {
                        "jsonrpc" => {
                            if jsonrpc {
                                return Err(de::Error::duplicate_field("jsonrpc"));
                            }
                            let value = map.next_value()?;
                            // Only protocol version 2.0 is accepted.
                            if value != "2.0" {
                                return Err(de::Error::invalid_value(
                                    Unexpected::Str(value),
                                    &"2.0",
                                ));
                            }
                            jsonrpc = true;
                        },
                        "id" => {
                            if id.is_some() {
                                return Err(de::Error::duplicate_field("id"));
                            }
                            let value: u64 = map.next_value()?;
                            id = Some(value);
                        },
                        "result" => {
                            if result.is_some() {
                                return Err(de::Error::duplicate_field("result"));
                            }
                            // Kept as raw JSON; the caller decodes it into the expected type later.
                            let value: &RawValue = map.next_value()?;
                            result = Some(value);
                        },
                        "error" => {
                            if error.is_some() {
                                return Err(de::Error::duplicate_field("error"));
                            }
                            let value: JsonRpcError = map.next_value()?;
                            error = Some(value);
                        },
                        "method" => {
                            if method.is_some() {
                                return Err(de::Error::duplicate_field("method"));
                            }
                            let value: &str = map.next_value()?;
                            method = Some(value);
                        },
                        "params" => {
                            if params.is_some() {
                                return Err(de::Error::duplicate_field("params"));
                            }
                            let value: Params = map.next_value()?;
                            params = Some(value);
                        },
                        // Strict: any unexpected member makes the whole response invalid.
                        key => {
                            return Err(de::Error::unknown_field(
                                key,
                                &["id", "jsonrpc", "result", "error", "params", "method"],
                            ))
                        },
                    }
                }
                if !jsonrpc {
                    return Err(de::Error::missing_field("jsonrpc"));
                }
                // Exactly one of the three valid shapes must match; all mixtures are rejected.
                match (id, result, error, method, params) {
                    (Some(id), Some(result), None, None, None) => {
                        Ok(Response::Success { id, result })
                    },
                    (Some(id), None, Some(error), None, None) => Ok(Response::Error { id, error }),
                    (None, None, None, Some(method), Some(params)) => {
                        Ok(Response::Notification { method, params })
                    },
                    _ => Err(de::Error::custom(
                        "response must be either a success/error or notification object",
                    )),
                }
            }
        }
        deserializer.deserialize_map(ResponseVisitor(std::marker::PhantomData))
    }
}
/// Credentials that `Display` renders as `Basic <secret>`, `Bearer <token>`, or verbatim.
///
/// `Debug` is implemented by hand so that the secret is redacted. `Display` emits the
/// full value, secret included. Presumably used as an HTTP `Authorization` header
/// value — confirm at the call sites.
#[derive(Clone)]
pub enum Authorization {
    /// Base64-encoded `username:password`; see [`Authorization::basic`].
    Basic(String),
    /// Bearer token.
    Bearer(String),
    /// Value emitted as-is, with no scheme prefix.
    Raw(String),
}
impl fmt::Debug for Authorization {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Authorization::Basic(_) => {
f.debug_tuple("Authorization::Basic").field(&"<redacted>").finish()
},
Authorization::Bearer(_) => {
f.debug_tuple("Authorization::Bearer").field(&"<redacted>").finish()
},
Authorization::Raw(_) => {
f.debug_tuple("Authorization::Raw").field(&"<redacted>").finish()
},
}
}
}
impl Authorization {
pub fn basic(username: impl AsRef<str>, password: impl AsRef<str>) -> Self {
let username = username.as_ref();
let password = password.as_ref();
let auth_secret = general_purpose::STANDARD.encode(format!("{username}:{password}"));
Self::Basic(auth_secret)
}
pub fn bearer(token: impl Into<String>) -> Self {
Self::Bearer(token.into())
}
pub fn raw(token: impl Into<String>) -> Self {
Self::Raw(token.into())
}
}
impl fmt::Display for Authorization {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Authorization::Basic(auth_secret) => write!(f, "Basic {auth_secret}"),
Authorization::Bearer(token) => write!(f, "Bearer {token}"),
Authorization::Raw(s) => write!(f, "{s}"),
}
}
}
/// Algorithm used both to sign tokens and to validate them (HMAC with SHA-256).
const DEFAULT_ALGORITHM: Algorithm = Algorithm::HS256;
/// Required length, in bytes, of a [`JwtKey`] secret.
pub const JWT_SECRET_LENGTH: usize = 32;
/// A shared secret of exactly [`JWT_SECRET_LENGTH`] bytes, used to sign and verify JWTs.
///
/// It deliberately implements neither `Debug` nor `Display`, so the secret cannot be
/// printed by accident.
pub struct JwtKey([u8; JWT_SECRET_LENGTH]);

impl JwtKey {
    /// Builds a key from raw bytes.
    ///
    /// # Errors
    /// Returns a message if `key` is not exactly [`JWT_SECRET_LENGTH`] bytes long.
    pub fn from_slice(key: &[u8]) -> Result<Self, String> {
        if key.len() != JWT_SECRET_LENGTH {
            return Err(format!(
                "Invalid key length. Expected {} got {}",
                JWT_SECRET_LENGTH,
                key.len()
            ));
        }
        let mut res = [0; JWT_SECRET_LENGTH];
        res.copy_from_slice(key);
        Ok(Self(res))
    }

    /// Decodes a hex-encoded key.
    ///
    /// Surrounding whitespace (for example, the trailing newline of a secret file) and an
    /// optional `0x`/`0X` prefix are ignored. Such input used to be rejected outright, so
    /// accepting it does not change the result for any previously valid input.
    ///
    /// # Errors
    /// Returns a message if the remaining text is not valid hex, or if it does not decode
    /// to exactly [`JWT_SECRET_LENGTH`] bytes.
    pub fn from_hex(hex: &str) -> Result<Self, String> {
        let trimmed = hex.trim();
        let digits = trimmed
            .strip_prefix("0x")
            .or_else(|| trimmed.strip_prefix("0X"))
            .unwrap_or(trimmed);
        let bytes = hex::decode(digits).map_err(|e| format!("Invalid hex: {}", e))?;
        Self::from_slice(&bytes)
    }

    /// Borrows the raw secret bytes.
    pub fn as_bytes(&self) -> &[u8; JWT_SECRET_LENGTH] {
        &self.0
    }

    /// Consumes the key and returns the raw secret bytes.
    pub fn into_bytes(self) -> [u8; JWT_SECRET_LENGTH] {
        self.0
    }
}
/// Issues and checks JWTs signed with [`DEFAULT_ALGORITHM`] using a shared [`JwtKey`].
pub struct JwtAuth {
    key: EncodingKey,
    id: Option<String>,
    clv: Option<String>,
}

impl JwtAuth {
    /// Creates an issuer from `secret`. The optional `id` and `clv` values are copied
    /// into every token it generates.
    pub fn new(secret: JwtKey, id: Option<String>, clv: Option<String>) -> Self {
        let key = EncodingKey::from_secret(secret.as_bytes());
        Self { id, clv, key }
    }

    /// Signs a fresh token whose `iat` is the current time.
    ///
    /// # Errors
    /// Propagates any `jsonwebtoken` encoding error.
    pub fn generate_token(&self) -> Result<String, Error> {
        self.generate_token_with_claims(&self.generate_claims_at_timestamp())
    }

    /// Signs `claims` with this issuer's key and the default-algorithm header.
    fn generate_token_with_claims(&self, claims: &Claims) -> Result<String, Error> {
        encode(&Header::new(DEFAULT_ALGORITHM), claims, &self.key)
    }

    /// Builds the claims for a token issued now.
    fn generate_claims_at_timestamp(&self) -> Claims {
        let iat = get_current_timestamp();
        Claims { iat, id: self.id.clone(), clv: self.clv.clone() }
    }

    /// Decodes `token` and checks its signature against `secret`.
    ///
    /// The token is decoded with [`DEFAULT_ALGORITHM`]. The `exp` claim is neither
    /// required nor validated; no other claim checks are configured here.
    ///
    /// # Errors
    /// Returns the `jsonwebtoken` error when decoding or validation fails.
    pub fn validate_token(
        token: &str,
        secret: &JwtKey,
    ) -> Result<jsonwebtoken::TokenData<Claims>, Error> {
        let decoding_key = jsonwebtoken::DecodingKey::from_secret(secret.as_bytes());
        let mut validation = jsonwebtoken::Validation::new(DEFAULT_ALGORITHM);
        // Tokens produced by this type carry no `exp`, so neither require nor check it.
        validation.required_spec_claims.remove("exp");
        validation.validate_exp = false;
        jsonwebtoken::decode::<Claims>(token, &decoding_key, &validation)
    }
}
/// The JWT claims issued by [`JwtAuth`].
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct Claims {
    /// Issued-at time, filled from `jsonwebtoken::get_current_timestamp` (seconds since the UNIX epoch).
    iat: u64,
    /// Optional `id` claim; serialized as JSON `null` when absent (no skip attribute).
    id: Option<String>,
    /// Optional `clv` claim (presumably a client version string — confirm); `null` when absent.
    clv: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Success and error envelopes decode into the right variants; bad ones are rejected.
    #[test]
    fn deser_response() {
        let _ =
            serde_json::from_str::<Response<'_>>(r#"{"jsonrpc":"2.0","result":19}"#).unwrap_err();
        let _ = serde_json::from_str::<Response<'_>>(r#"{"jsonrpc":"3.0","result":19,"id":1}"#)
            .unwrap_err();
        let response: Response<'_> =
            serde_json::from_str(r#"{"jsonrpc":"2.0","result":19,"id":1}"#).unwrap();
        match response {
            Response::Success { id, result } => {
                assert_eq!(id, 1);
                let result: u64 = serde_json::from_str(result.get()).unwrap();
                assert_eq!(result, 19);
            },
            _ => panic!("Expected Success response but got: {:?}", response),
        }
        let response: Response<'_> = serde_json::from_str(
            r#"{"jsonrpc":"2.0","error":{"code":-32000,"message":"error occurred"},"id":2}"#,
        )
        .unwrap();
        match response {
            Response::Error { id, error } => {
                assert_eq!(id, 2);
                assert_eq!(error.code, -32000);
                assert_eq!(error.message, "error occurred");
                assert!(error.data.is_none());
            },
            _ => panic!("Expected Error response but got: {:?}", response),
        }
        let response: Response<'_> =
            serde_json::from_str(r#"{"jsonrpc":"2.0","result":"0xfa","id":0}"#).unwrap();
        match response {
            Response::Success { id, result } => {
                assert_eq!(id, 0);
                let result: String = serde_json::from_str(result.get()).unwrap();
                assert_eq!(i64::from_str_radix(result.trim_start_matches("0x"), 16).unwrap(), 250);
            },
            _ => panic!("Expected Success response but got: {:?}", response),
        }
    }

    /// A message with `method` + `params` and no `id` decodes as a notification.
    #[test]
    fn deser_notification() {
        let response: Response<'_> = serde_json::from_str(
            r#"{"jsonrpc":"2.0","method":"subscription","params":{"subscription":"0x10","result":{"block":1}}}"#,
        )
        .unwrap();
        match response {
            Response::Notification { method, params } => {
                assert_eq!(method, "subscription");
                assert_eq!(params.subscription, U256::from(16u64));
                assert_eq!(params.result.get(), r#"{"block":1}"#);
            },
            _ => panic!("Expected Notification response but got: {:?}", response),
        }
    }

    /// The strict visitor rejects malformed envelopes instead of guessing.
    #[test]
    fn deser_rejects_malformed_envelopes() {
        // Missing `jsonrpc`.
        assert!(serde_json::from_str::<Response<'_>>(r#"{"result":19,"id":1}"#).is_err());
        // Duplicate key.
        assert!(serde_json::from_str::<Response<'_>>(
            r#"{"jsonrpc":"2.0","id":1,"id":2,"result":19}"#
        )
        .is_err());
        // Unknown key.
        assert!(serde_json::from_str::<Response<'_>>(
            r#"{"jsonrpc":"2.0","id":1,"result":19,"extra":true}"#
        )
        .is_err());
        // Both `result` and `error`.
        assert!(serde_json::from_str::<Response<'_>>(
            r#"{"jsonrpc":"2.0","id":1,"result":19,"error":{"code":1,"message":"x"}}"#
        )
        .is_err());
    }

    /// Revert detection keys off the message; non-revert errors yield no revert data.
    #[test]
    fn revert_detection() {
        let reverted =
            JsonRpcError { code: 3, message: "execution reverted".to_string(), data: None };
        assert!(reverted.is_revert());
        assert!(reverted.as_revert_data().is_some());

        let other = JsonRpcError { code: -32000, message: "boom".to_string(), data: None };
        assert!(!other.is_revert());
        assert!(other.as_revert_data().is_none());
    }

    #[test]
    fn ser_request() {
        let request: Request<()> = Request::new(0, "neo_chainId", ());
        assert_eq!(
            &serde_json::to_string(&request).unwrap(),
            r#"{"id":0,"jsonrpc":"2.0","method":"neo_chainId"}"#
        );
        let request: Request<()> = Request::new(300, "method_name", ());
        assert_eq!(
            &serde_json::to_string(&request).unwrap(),
            r#"{"id":300,"jsonrpc":"2.0","method":"method_name"}"#
        );
        let request: Request<u32> = Request::new(300, "method_name", 1);
        assert_eq!(
            &serde_json::to_string(&request).unwrap(),
            r#"{"id":300,"jsonrpc":"2.0","method":"method_name","params":1}"#
        );
    }

    /// `Display` yields the full value; `Debug` never exposes the secret.
    #[test]
    fn authorization_formatting() {
        let basic = Authorization::basic("user", "pass");
        assert_eq!(basic.to_string(), "Basic dXNlcjpwYXNz");
        assert_eq!(Authorization::bearer("tok").to_string(), "Bearer tok");
        assert_eq!(Authorization::raw("custom").to_string(), "custom");
        assert_eq!(format!("{basic:?}"), r#"Authorization::Basic("<redacted>")"#);
        assert!(!format!("{:?}", Authorization::bearer("tok")).contains("tok"));
    }

    #[test]
    fn test_roundtrip() {
        let jwt_secret = [42; 32];
        let auth = JwtAuth::new(
            JwtKey::from_slice(&jwt_secret).unwrap(),
            Some("42".into()),
            Some("Lighthouse".into()),
        );
        let claims = auth.generate_claims_at_timestamp();
        let token = auth.generate_token_with_claims(&claims).unwrap();
        assert_eq!(
            JwtAuth::validate_token(&token, &JwtKey::from_slice(&jwt_secret).unwrap())
                .unwrap()
                .claims,
            claims
        );
    }

    /// A token validates only with the key that signed it.
    #[test]
    fn token_rejected_with_wrong_key() {
        let auth = JwtAuth::new(JwtKey::from_slice(&[7; 32]).unwrap(), None, None);
        let token = auth.generate_token().unwrap();
        assert!(JwtAuth::validate_token(&token, &JwtKey::from_slice(&[7; 32]).unwrap()).is_ok());
        assert!(JwtAuth::validate_token(&token, &JwtKey::from_slice(&[8; 32]).unwrap()).is_err());
    }
}