use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tower::{Layer, Service};
#[derive(Clone)]
pub struct JsonRpcLayer {
validate_version: bool,
require_id: bool,
}
impl JsonRpcLayer {
#[must_use]
pub fn new() -> Self {
Self {
validate_version: true,
require_id: false,
}
}
#[must_use]
pub fn validate_version(mut self, validate: bool) -> Self {
self.validate_version = validate;
self
}
#[must_use]
pub fn require_id(mut self, require: bool) -> Self {
self.require_id = require;
self
}
}
impl Default for JsonRpcLayer {
fn default() -> Self {
Self::new()
}
}
impl<S> Layer<S> for JsonRpcLayer {
type Service = JsonRpcMiddleware<S>;
fn layer(&self, service: S) -> Self::Service {
JsonRpcMiddleware {
inner: service,
validate_version: self.validate_version,
require_id: self.require_id,
}
}
}
#[derive(Clone)]
pub struct JsonRpcMiddleware<S> {
inner: S,
validate_version: bool,
require_id: bool,
}
impl<S> JsonRpcMiddleware<S> {
pub fn new(service: S) -> Self {
Self {
inner: service,
validate_version: true,
require_id: false,
}
}
pub fn with_validation(service: S, validate_version: bool, require_id: bool) -> Self {
Self {
inner: service,
validate_version,
require_id,
}
}
}
impl<S, Request> Service<Request> for JsonRpcMiddleware<S>
where
S: Service<Request> + Clone + Send + 'static,
S::Future: Send,
Request: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request) -> Self::Future {
let mut inner = self.inner.clone();
let validate_version = self.validate_version;
let require_id = self.require_id;
Box::pin(async move {
let _ = validate_version;
let _ = require_id;
inner.call(req).await
})
}
}
pub struct JsonRpcMiddlewareBuilder {
validate_version: bool,
require_id: bool,
}
impl JsonRpcMiddlewareBuilder {
#[must_use]
pub fn new() -> Self {
Self {
validate_version: true,
require_id: false,
}
}
#[must_use]
pub fn validate_version(mut self, validate: bool) -> Self {
self.validate_version = validate;
self
}
#[must_use]
pub fn require_id(mut self, require: bool) -> Self {
self.require_id = require;
self
}
#[must_use]
pub fn build(self) -> JsonRpcLayer {
JsonRpcLayer {
validate_version: self.validate_version,
require_id: self.require_id,
}
}
}
impl Default for JsonRpcMiddlewareBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::task::Poll;
#[derive(Clone)]
struct MockService;
impl<Request> Service<Request> for MockService {
type Response = String;
type Error = std::io::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: Request) -> Self::Future {
Box::pin(async { Ok("mock response".to_string()) })
}
}
#[test]
fn test_jsonrpc_layer_new() {
let layer = JsonRpcLayer::new();
assert!(layer.validate_version);
assert!(!layer.require_id);
}
#[test]
fn test_jsonrpc_layer_default() {
let layer = JsonRpcLayer::default();
assert!(layer.validate_version);
assert!(!layer.require_id);
}
#[test]
fn test_jsonrpc_layer_validate_version() {
let layer = JsonRpcLayer::new().validate_version(false);
assert!(!layer.validate_version);
assert!(!layer.require_id);
}
#[test]
fn test_jsonrpc_layer_require_id() {
let layer = JsonRpcLayer::new().require_id(true);
assert!(layer.validate_version);
assert!(layer.require_id);
}
#[test]
fn test_jsonrpc_layer_builder_chain() {
let layer = JsonRpcLayer::new().validate_version(false).require_id(true);
assert!(!layer.validate_version);
assert!(layer.require_id);
}
#[test]
fn test_jsonrpc_layer_applies_to_service() {
let layer = JsonRpcLayer::new();
let service = MockService;
let _middleware = layer.layer(service);
}
#[test]
fn test_jsonrpc_middleware_new() {
let service = MockService;
let middleware = JsonRpcMiddleware::new(service);
assert!(middleware.validate_version);
assert!(!middleware.require_id);
}
#[test]
fn test_jsonrpc_middleware_with_validation() {
let service = MockService;
let middleware = JsonRpcMiddleware::with_validation(service, false, true);
assert!(!middleware.validate_version);
assert!(middleware.require_id);
}
#[tokio::test]
async fn test_jsonrpc_middleware_call() {
let service = MockService;
let mut middleware = JsonRpcMiddleware::new(service);
let response = middleware.call("test request").await.unwrap();
assert_eq!(response, "mock response");
}
#[tokio::test]
async fn test_jsonrpc_middleware_ready() {
let service = MockService;
let middleware = JsonRpcMiddleware::new(service);
let _ = middleware;
}
#[test]
fn test_jsonrpc_middleware_builder_new() {
let builder = JsonRpcMiddlewareBuilder::new();
assert!(builder.validate_version);
assert!(!builder.require_id);
}
#[test]
fn test_jsonrpc_middleware_builder_default() {
let builder = JsonRpcMiddlewareBuilder::default();
assert!(builder.validate_version);
assert!(!builder.require_id);
}
#[test]
fn test_jsonrpc_middleware_builder_validate_version() {
let builder = JsonRpcMiddlewareBuilder::new().validate_version(false);
assert!(!builder.validate_version);
}
#[test]
fn test_jsonrpc_middleware_builder_require_id() {
let builder = JsonRpcMiddlewareBuilder::new().require_id(true);
assert!(builder.require_id);
}
#[test]
fn test_jsonrpc_middleware_builder_chain() {
let builder = JsonRpcMiddlewareBuilder::new()
.validate_version(false)
.require_id(true);
assert!(!builder.validate_version);
assert!(builder.require_id);
}
#[test]
fn test_jsonrpc_middleware_builder_build() {
let builder = JsonRpcMiddlewareBuilder::new()
.validate_version(false)
.require_id(true);
let layer = builder.build();
assert!(!layer.validate_version);
assert!(layer.require_id);
}
#[test]
fn test_layer_and_middleware_integration() {
let layer = JsonRpcLayer::new().validate_version(false).require_id(true);
let service = MockService;
let middleware = layer.layer(service);
assert!(!middleware.validate_version);
assert!(middleware.require_id);
}
#[tokio::test]
async fn test_middleware_with_layer() {
let layer = JsonRpcLayer::new();
let service = MockService;
let mut middleware = layer.layer(service);
let response = middleware.call("test").await.unwrap();
assert_eq!(response, "mock response");
}
#[test]
fn test_middleware_clone() {
let service = MockService;
let middleware1 = JsonRpcMiddleware::new(service);
let middleware2 = middleware1.clone();
assert_eq!(middleware1.validate_version, middleware2.validate_version);
assert_eq!(middleware1.require_id, middleware2.require_id);
}
#[test]
fn test_layer_clone() {
let layer1 = JsonRpcLayer::new().validate_version(false);
let layer2 = layer1.clone();
assert_eq!(layer1.validate_version, layer2.validate_version);
assert_eq!(layer1.require_id, layer2.require_id);
}
}