use std::collections::{HashMap, HashSet};
use futures_util::TryStreamExt;
use http::{Request, Response, StatusCode};
use jsonrpsee::server::logger::Body;
use crate::KoraError;
/// Default value for the signature-verification flag.
///
/// Always returns `false`.
pub fn default_sig_verify() -> bool {
    // Verification is opt-in: the default is off.
    false
}
pub async fn extract_parts_and_body_bytes(
request: Request<Body>,
) -> (http::request::Parts, Vec<u8>) {
let (parts, body) = request.into_parts();
let body_bytes = body
.try_fold(Vec::new(), |mut acc, chunk| async move {
acc.extend_from_slice(&chunk);
Ok(acc)
})
.await
.unwrap_or_default();
(parts, body_bytes)
}
/// Extracts the top-level `"method"` string from a JSON-RPC request body.
///
/// Returns `None` in any of these cases:
/// - the bytes are not valid JSON;
/// - the value has no `"method"` key (this includes non-objects, such as a batch array);
/// - the `"method"` value is not a JSON string.
pub fn get_jsonrpc_method(body_bytes: &[u8]) -> Option<String> {
    let payload: serde_json::Value = serde_json::from_slice(body_bytes).ok()?;
    payload.get("method")?.as_str().map(str::to_owned)
}
/// Checks that the JSON-RPC request in `body_bytes` names a method from `allowed_methods`.
///
/// On success, returns the method name.
///
/// # Errors
///
/// Returns `KoraError::InvalidRequest` in two cases:
/// - no string `"method"` can be read from the body;
/// - the method is not in the allow-list.
pub fn verify_jsonrpc_method(
    body_bytes: &[u8],
    allowed_methods: &HashSet<String>,
) -> Result<String, KoraError> {
    get_jsonrpc_method(body_bytes)
        .filter(|method| allowed_methods.contains(method))
        .ok_or_else(|| KoraError::InvalidRequest("Method not allowed".to_string()))
}
/// Builds an error response with `status_code`, plus any extra `headers`.
///
/// The response body is `error_message`. If the builder rejects its input (for example, a
/// header name or value that is not valid HTTP), the error is logged rather than propagated.
/// In that case the function returns a response with the same status, an empty body, and
/// none of the supplied headers.
pub fn build_response_with_graceful_error(
    headers: Option<HashMap<String, String>>,
    status_code: StatusCode,
    error_message: &str,
) -> Response<Body> {
    let builder = headers
        .iter()
        .flatten()
        .fold(Response::builder(), |builder, (key, value)| builder.header(key, value));

    match builder.status(status_code).body(Body::from(error_message.to_string())) {
        Ok(response) => response,
        Err(e) => {
            log::error!("Failed to build response, error: {e:?}");
            let mut fallback = Response::new(Body::empty());
            *fallback.status_mut() = status_code;
            fallback
        }
    }
}
/// Tower layer that restricts a JSON-RPC endpoint to an allow-list of method names.
#[derive(Clone)]
pub struct MethodValidationLayer {
    // Exact method names that requests are permitted to invoke.
    allowed_methods: HashSet<String>,
}

impl MethodValidationLayer {
    /// Creates a layer that accepts only the given method names.
    ///
    /// Duplicate names are collapsed.
    pub fn new(allowed_methods: Vec<String>) -> Self {
        let allowed_methods: HashSet<String> = allowed_methods.into_iter().collect();
        Self { allowed_methods }
    }
}
/// Service produced by `MethodValidationLayer`.
///
/// It checks each request's JSON-RPC method against the allow-list before forwarding the
/// request to `inner`.
#[derive(Clone)]
pub struct MethodValidationService<S> {
    // The wrapped service that handles requests which pass validation.
    inner: S,
    // This service's own copy of the layer's allow-list.
    allowed_methods: HashSet<String>,
}

impl<S> tower::Layer<S> for MethodValidationLayer {
    type Service = MethodValidationService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        // Each wrapped service gets an independent copy of the allow-list.
        let allowed_methods = self.allowed_methods.clone();
        MethodValidationService { inner, allowed_methods }
    }
}
impl<S> tower::Service<Request<Body>> for MethodValidationService<S>
where
    S: tower::Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    /// Delegates readiness to the wrapped service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Validates the request's JSON-RPC method, then forwards it or rejects it.
    ///
    /// The request body is buffered. If its JSON-RPC `method` is not in the allow-list, the
    /// request is answered with `405 Method Not Allowed` and an empty body. Otherwise the
    /// request is rebuilt from the buffered bytes and forwarded to the inner service.
    fn call(&mut self, request: Request<Body>) -> Self::Future {
        let allowed_methods = self.allowed_methods.clone();
        // `poll_ready` was driven on `self.inner`, so that instance holds the readiness.
        // Take it for this call and leave a fresh clone in its place; the clone will be
        // polled ready before its own use. Calling a never-polled clone (as a plain
        // `self.inner.clone()` would) can panic or bypass backpressure for services
        // such as `Buffer`.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);
        Box::pin(async move {
            let (parts, body_bytes) = extract_parts_and_body_bytes(request).await;
            if verify_jsonrpc_method(&body_bytes, &allowed_methods).is_err() {
                return Ok(build_response_with_graceful_error(
                    None,
                    StatusCode::METHOD_NOT_ALLOWED,
                    "",
                ));
            }
            let forwarded = Request::from_parts(parts, Body::from(body_bytes));
            inner.call(forwarded).await
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;
    use std::{
        future::Ready,
        task::{Context, Poll},
    };
    use tower::{Layer, Service, ServiceExt};

    /// Inner-service stand-in that answers every request with an empty `200 OK`.
    #[derive(Clone)]
    struct MockService;

    impl tower::Service<Request<Body>> for MockService {
        type Response = Response<Body>;
        type Error = std::convert::Infallible;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request<Body>) -> Self::Future {
            std::future::ready(Ok(Response::builder().status(200).body(Body::empty()).unwrap()))
        }
    }

    /// Builds a validation service over `MockService` that allows exactly `methods`.
    fn service_allowing(methods: &[&str]) -> MethodValidationService<MockService> {
        let allowed = methods.iter().map(|m| m.to_string()).collect::<Vec<_>>();
        MethodValidationLayer::new(allowed).layer(MockService)
    }

    /// POSTs `body` to `/test` through `service` and returns the response status.
    async fn post_status(
        service: &mut MethodValidationService<MockService>,
        body: Body,
    ) -> StatusCode {
        let request = Request::builder().method(Method::POST).uri("/test").body(body).unwrap();
        let response = service.ready().await.unwrap().call(request).await.unwrap();
        response.status()
    }

    #[tokio::test]
    async fn test_method_validation_disallowed_method() {
        let mut service = service_allowing(&["liveness", "getConfig"]);
        let body = r#"{"jsonrpc":"2.0","method":"unknownMethod","id":1}"#;
        let status = post_status(&mut service, Body::from(body)).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[tokio::test]
    async fn test_method_validation_malformed_json() {
        let mut service = service_allowing(&["liveness", "getConfig"]);
        let body = r#"{"invalid json"#;
        let status = post_status(&mut service, Body::from(body)).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[tokio::test]
    async fn test_method_validation_missing_method_field() {
        let mut service = service_allowing(&["liveness", "getConfig"]);
        let body = r#"{"jsonrpc":"2.0","id":1}"#;
        let status = post_status(&mut service, Body::from(body)).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[tokio::test]
    async fn test_method_validation_multiple_allowed_methods() {
        let methods = ["liveness", "getConfig", "signTransaction", "estimateTransactionFee"];
        let mut service = service_allowing(&methods);
        for method in &methods {
            let body = format!(r#"{{"jsonrpc":"2.0","method":"{}","id":1}}"#, method);
            let status = post_status(&mut service, Body::from(body)).await;
            assert_eq!(status, StatusCode::OK, "Method {} should be allowed", method);
        }
    }
}