use super::router::{JsonRpcRequestOrBatch, JsonRpcRouter};
use crate::handler_trait::RequestData;
use crate::server::request_extraction::extract_headers;
use axum::{
body::Body,
extract::State,
http::{HeaderMap, Request, StatusCode, header},
response::{IntoResponse, Response as AxumResponse},
};
use std::collections::HashMap;
use std::sync::Arc;
/// Shared state for the JSON-RPC HTTP endpoint.
///
/// [`handle_jsonrpc`] receives it through axum's `State` extractor (wrapped in
/// an `Arc`) and uses the router to dispatch single and batch requests.
#[derive(Clone)]
pub struct JsonRpcState {
    /// Router used to parse the request body and to route single requests
    /// (`route_single`) and batches (`route_batch`).
    pub router: Arc<JsonRpcRouter>,
}
pub async fn handle_jsonrpc(
State(state): State<Arc<JsonRpcState>>,
headers: HeaderMap,
uri: axum::http::Uri,
body: String,
) -> AxumResponse {
if !validate_content_type(&headers) {
return create_error_response(
StatusCode::UNSUPPORTED_MEDIA_TYPE,
"Content-Type must be application/json",
);
}
let request = match JsonRpcRouter::parse_request(&body) {
Ok(req) => req,
Err(error_response) => {
let json = serde_json::to_string(&error_response).expect("Error serialization should never fail");
return create_jsonrpc_response(json);
}
};
let request_data = create_jsonrpc_request_data(&headers, &uri);
let http_request = Request::builder()
.method("POST")
.uri(uri.clone())
.body(Body::empty())
.unwrap_or_else(|_| Request::builder().method("POST").uri("/").body(Body::empty()).unwrap());
let response = match request {
JsonRpcRequestOrBatch::Single(req) => {
let response = state.router.route_single(req, http_request, &request_data).await;
serde_json::to_string(&response).expect("Response serialization should never fail")
}
JsonRpcRequestOrBatch::Batch(batch) => {
let http_request = Request::builder()
.method("POST")
.uri(uri.clone())
.body(Body::empty())
.unwrap_or_else(|_| Request::builder().method("POST").uri("/").body(Body::empty()).unwrap());
match state.router.route_batch(batch, http_request, &request_data).await {
Ok(responses) => {
serde_json::to_string(&responses).expect("Batch response serialization should never fail")
}
Err(error_response) => {
serde_json::to_string(&error_response).expect("Error serialization should never fail")
}
}
}
};
create_jsonrpc_response(response)
}
/// Builds the [`RequestData`] passed to the router for a JSON-RPC call.
///
/// Only the request headers and the URI path come from the incoming request.
/// The method is always `"POST"`. Path/query parameters, cookies and the raw
/// body are empty; `query_params` and `body` are empty JSON objects (`{}`).
/// The JSON-RPC payload itself is not copied into this structure.
fn create_jsonrpc_request_data(headers: &HeaderMap, uri: &axum::http::Uri) -> RequestData {
    let empty_object = || Arc::new(serde_json::json!({}));

    RequestData {
        headers: Arc::new(extract_headers(headers)),
        path: uri.path().to_owned(),
        method: String::from("POST"),
        path_params: Arc::new(HashMap::new()),
        raw_query_params: Arc::new(HashMap::new()),
        cookies: Arc::new(HashMap::new()),
        query_params: empty_object(),
        body: empty_object(),
        validated_params: None,
        raw_body: None,
        #[cfg(feature = "di")]
        dependencies: None,
    }
}
/// Reports whether the request's `Content-Type` header is acceptable for a
/// JSON-RPC call.
///
/// * A missing header is accepted, and the body is treated as JSON.
/// * A header value that is not visible ASCII (`HeaderValue::to_str` fails)
///   is rejected.
/// * Otherwise only the media type is examined: the part before any `;`
///   parameters, trimmed and compared ASCII case-insensitively. Accepted
///   values are `application/json` and the JSON-RPC-specific
///   `application/json-rpc` / `application/jsonrequest`.
///
/// Parameters such as `charset=utf-8` are ignored. Matching the whole media
/// type rejects look-alikes (`application/jsonp`,
/// `text/plain; x=application/json`) that a substring search would let through.
fn validate_content_type(headers: &HeaderMap) -> bool {
    const ACCEPTED_MEDIA_TYPES: [&str; 3] = [
        "application/json",
        "application/json-rpc",
        "application/jsonrequest",
    ];

    let value = match headers.get(header::CONTENT_TYPE) {
        None => return true,
        Some(value) => value,
    };
    let content_type = match value.to_str() {
        Ok(ct) => ct,
        Err(_) => return false,
    };

    // Media type = everything before the first parameter separator.
    let media_type = content_type
        .split_once(';')
        .map_or(content_type, |(essence, _params)| essence)
        .trim();

    ACCEPTED_MEDIA_TYPES
        .iter()
        .any(|accepted| media_type.eq_ignore_ascii_case(accepted))
}
/// Wraps an already-serialized JSON-RPC payload in an HTTP response.
///
/// The status is always `200 OK` and the `Content-Type` is
/// `application/json`, whether the payload is a result or a JSON-RPC error
/// object.
fn create_jsonrpc_response(json: String) -> AxumResponse {
    let content_type = [(header::CONTENT_TYPE, "application/json")];
    let parts = (StatusCode::OK, content_type, json);
    parts.into_response()
}
/// Builds a plain-text HTTP error response (`text/plain; charset=utf-8`) with
/// the given status code and `message` as the body.
///
/// Used for transport-level failures that happen before any JSON-RPC
/// processing, such as an unacceptable `Content-Type`.
fn create_error_response(status: StatusCode, message: &str) -> AxumResponse {
    let body = String::from(message);
    let content_type = [(header::CONTENT_TYPE, "text/plain; charset=utf-8")];
    (status, content_type, body).into_response()
}
#[cfg(test)]
mod tests {
    // Tests for the HTTP adapter layer. Most handler tests assert only the HTTP
    // status and content type. `create_jsonrpc_response` always answers
    // `200 OK`, so JSON-RPC level errors (parse errors, unknown methods,
    // rejected batches) are expected to surface as 200 too. Response payloads
    // are not inspected here.
    use super::*;
    use crate::jsonrpc::{method_registry::JsonRpcMethodRegistry, router::JsonRpcRouter};
    use serde_json::json;

    /// Builds handler state around a freshly created registry. Nothing is
    /// registered after `JsonRpcMethodRegistry::new()`.
    ///
    /// NOTE(review): the `true, 100` arguments look like "batching enabled"
    /// and "maximum batch size" (compare `test_batch_disabled` and
    /// `test_batch_size_exceeded`). Confirm against `JsonRpcRouter::new`.
    fn create_test_state() -> Arc<JsonRpcState> {
        let registry = Arc::new(JsonRpcMethodRegistry::new());
        let router = Arc::new(JsonRpcRouter::new(registry, true, 100));
        Arc::new(JsonRpcState { router })
    }

    /// Headers with `Content-Type: application/json`.
    fn create_json_headers() -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(header::CONTENT_TYPE, "application/json".parse().unwrap());
        headers
    }

    /// Headers with a non-JSON `Content-Type` (`text/plain`).
    fn create_wrong_content_type_headers() -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(header::CONTENT_TYPE, "text/plain".parse().unwrap());
        headers
    }

    /// Headers with no `Content-Type` at all.
    fn create_empty_headers() -> HeaderMap {
        HeaderMap::new()
    }

    /// URI used for every handler call in these tests.
    fn create_test_uri() -> axum::http::Uri {
        axum::http::Uri::from_static("/rpc")
    }

    // A single well-formed request yields 200 with a JSON content type.
    // `test.method` is never registered, so this presumably exercises the
    // method-not-found path.
    #[tokio::test]
    async fn test_handle_jsonrpc_single_request() {
        let state = create_test_state();
        let headers = create_json_headers();
        let uri = create_test_uri();
        let body = r#"{"jsonrpc":"2.0","method":"test.method","params":{},"id":1}"#.to_string();
        let response = handle_jsonrpc(State(state), headers, uri, body).await;
        assert_eq!(response.status(), StatusCode::OK);
        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap().to_str().unwrap();
        assert!(content_type.contains("application/json"));
    }

    // A two-element batch also yields 200 with a JSON content type.
    #[tokio::test]
    async fn test_handle_jsonrpc_batch_request() {
        let state = create_test_state();
        let headers = create_json_headers();
        let uri = create_test_uri();
        let body = r#"[
{"jsonrpc":"2.0","method":"test.method","params":{},"id":1},
{"jsonrpc":"2.0","method":"test.method","params":{},"id":2}
]"#
        .to_string();
        let response = handle_jsonrpc(State(state), headers, uri, body).await;
        assert_eq!(response.status(), StatusCode::OK);
        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap().to_str().unwrap();
        assert!(content_type.contains("application/json"));
    }

    // A non-JSON Content-Type is rejected with 415 and a plain-text body,
    // before the body is parsed.
    #[tokio::test]
    async fn test_invalid_content_type() {
        let state = create_test_state();
        let headers = create_wrong_content_type_headers();
        let uri = create_test_uri();
        let body = r#"{"jsonrpc":"2.0","method":"test","id":1}"#.to_string();
        let response = handle_jsonrpc(State(state), headers, uri, body).await;
        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap().to_str().unwrap();
        assert!(content_type.contains("text/plain"));
    }

    // With no Content-Type header, the request is processed as JSON.
    #[tokio::test]
    async fn test_missing_content_type_defaults_to_json() {
        let state = create_test_state();
        let headers = create_empty_headers();
        let uri = create_test_uri();
        let body = r#"{"jsonrpc":"2.0","method":"test.method","params":{},"id":1}"#.to_string();
        let response = handle_jsonrpc(State(state), headers, uri, body).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // Malformed JSON is answered with the router's parse-error object, still
    // with HTTP 200.
    #[tokio::test]
    async fn test_invalid_json_parse_error() {
        let state = create_test_state();
        let headers = create_json_headers();
        let uri = create_test_uri();
        let body = r#"{"invalid json"}"#.to_string();
        let response = handle_jsonrpc(State(state), headers, uri, body).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // The middle entry has no `id` (a JSON-RPC notification). Mixing it into
    // a batch must not change the HTTP status.
    #[tokio::test]
    async fn test_notification_in_batch() {
        let state = create_test_state();
        let headers = create_json_headers();
        let uri = create_test_uri();
        let body = r#"[
{"jsonrpc":"2.0","method":"test","params":{},"id":1},
{"jsonrpc":"2.0","method":"test","params":{}},
{"jsonrpc":"2.0","method":"test","params":{},"id":2}
]"#
        .to_string();
        let response = handle_jsonrpc(State(state), headers, uri, body).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // validate_content_type: the plain JSON media type is accepted.
    #[test]
    fn test_validate_content_type_valid() {
        let mut headers = HeaderMap::new();
        headers.insert(header::CONTENT_TYPE, "application/json".parse().unwrap());
        assert!(validate_content_type(&headers));
    }

    // validate_content_type: a charset parameter does not cause rejection.
    #[test]
    fn test_validate_content_type_valid_with_charset() {
        let mut headers = HeaderMap::new();
        headers.insert(header::CONTENT_TYPE, "application/json; charset=utf-8".parse().unwrap());
        assert!(validate_content_type(&headers));
    }

    // validate_content_type: a non-JSON media type is rejected.
    #[test]
    fn test_validate_content_type_invalid() {
        let mut headers = HeaderMap::new();
        headers.insert(header::CONTENT_TYPE, "text/plain".parse().unwrap());
        assert!(!validate_content_type(&headers));
    }

    // validate_content_type: a missing header is accepted.
    #[test]
    fn test_validate_content_type_missing() {
        let headers = HeaderMap::new();
        assert!(validate_content_type(&headers));
    }

    // create_jsonrpc_response: always 200 with exactly `application/json`.
    #[test]
    fn test_create_jsonrpc_response() {
        let json = r#"{"jsonrpc":"2.0","result":42,"id":1}"#.to_string();
        let response = create_jsonrpc_response(json);
        assert_eq!(response.status(), StatusCode::OK);
        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap().to_str().unwrap();
        assert_eq!(content_type, "application/json");
    }

    // create_error_response: keeps the given status and uses a plain-text
    // content type.
    #[test]
    fn test_create_error_response() {
        let response = create_error_response(StatusCode::UNSUPPORTED_MEDIA_TYPE, "Invalid content type");
        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap().to_str().unwrap();
        assert!(content_type.contains("text/plain"));
    }

    // An unregistered method is a JSON-RPC error, not an HTTP error: status
    // stays 200.
    #[tokio::test]
    async fn test_method_not_found_in_single_request() {
        let state = create_test_state();
        let headers = create_json_headers();
        let uri = create_test_uri();
        let body = r#"{"jsonrpc":"2.0","method":"nonexistent.method","params":{},"id":1}"#.to_string();
        let response = handle_jsonrpc(State(state), headers, uri, body).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // Router built with the second argument `false` (presumably "batching
    // disabled"). Whatever the router returns for the batch is still
    // delivered with HTTP 200.
    #[tokio::test]
    async fn test_batch_disabled() {
        let registry = Arc::new(JsonRpcMethodRegistry::new());
        let router = Arc::new(JsonRpcRouter::new(registry, false, 100));
        let state = Arc::new(JsonRpcState { router });
        let headers = create_json_headers();
        let uri = create_test_uri();
        let body = r#"[
{"jsonrpc":"2.0","method":"test","params":{},"id":1}
]"#
        .to_string();
        let response = handle_jsonrpc(State(state), headers, uri, body).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // Three entries against a router whose third argument is 2 (presumably
    // the maximum batch size). The HTTP status stays 200.
    #[tokio::test]
    async fn test_batch_size_exceeded() {
        let registry = Arc::new(JsonRpcMethodRegistry::new());
        let router = Arc::new(JsonRpcRouter::new(registry, true, 2));
        let state = Arc::new(JsonRpcState { router });
        let headers = create_json_headers();
        let uri = create_test_uri();
        let body = r#"[
{"jsonrpc":"2.0","method":"test","params":{},"id":1},
{"jsonrpc":"2.0","method":"test","params":{},"id":2},
{"jsonrpc":"2.0","method":"test","params":{},"id":3}
]"#
        .to_string();
        let response = handle_jsonrpc(State(state), headers, uri, body).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // An empty batch `[]` must not change the HTTP status. The payload the
    // router produces for it is not checked here.
    #[tokio::test]
    async fn test_empty_batch() {
        let state = create_test_state();
        let headers = create_json_headers();
        let uri = create_test_uri();
        let body = r#"[]"#.to_string();
        let response = handle_jsonrpc(State(state), headers, uri, body).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // A request whose object-valued params are built with `json!` and
    // serialized, rather than written as a literal string.
    #[tokio::test]
    async fn test_response_with_params() {
        let state = create_test_state();
        let headers = create_json_headers();
        let uri = create_test_uri();
        let params = json!({"key": "value", "number": 42});
        let body = serde_json::to_string(&json!({
            "jsonrpc": "2.0",
            "method": "test.method",
            "params": params,
            "id": 1
        }))
        .unwrap();
        let response = handle_jsonrpc(State(state), headers, uri, body).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // A mixed-case media type (`Application/JSON`) passes the content-type
    // check.
    #[tokio::test]
    async fn test_content_type_case_insensitive() {
        let state = create_test_state();
        let mut headers = HeaderMap::new();
        headers.insert(header::CONTENT_TYPE, "Application/JSON".parse().unwrap());
        let uri = create_test_uri();
        let body = r#"{"jsonrpc":"2.0","method":"test","id":1}"#.to_string();
        let response = handle_jsonrpc(State(state), headers, uri, body).await;
        assert_eq!(response.status(), StatusCode::OK);
    }
}