forest/rpc/
segregation_layer.rs

1// Copyright 2019-2025 ChainSafe Systems
2// SPDX-License-Identifier: Apache-2.0, MIT
3
4use super::ApiPaths;
5use crate::{for_each_rpc_method, rpc::reflect::RpcMethod};
6use ahash::{HashMap, HashSet};
7use futures::future::Either;
8use itertools::Itertools as _;
9use jsonrpsee::MethodResponse;
10use jsonrpsee::core::middleware::{Batch, BatchEntry, BatchEntryErr, Notification};
11use jsonrpsee::server::middleware::rpc::RpcServiceT;
12use jsonrpsee::types::error::{METHOD_NOT_FOUND_CODE, METHOD_NOT_FOUND_MSG};
13use jsonrpsee::types::{ErrorObject, Id};
14use std::sync::LazyLock;
15use tower::Layer;
16
17static VERSION_METHODS_MAPPINGS: LazyLock<HashMap<ApiPaths, HashSet<&'static str>>> =
18    LazyLock::new(|| {
19        let mut map = HashMap::default();
20        for version in [ApiPaths::V0, ApiPaths::V1, ApiPaths::V2] {
21            let mut supported = HashSet::default();
22
23            macro_rules! insert {
24                ($ty:ty) => {
25                    if <$ty>::API_PATHS.contains(version) {
26                        supported.insert(<$ty>::NAME);
27                        if let Some(alias) = <$ty>::NAME_ALIAS {
28                            supported.insert(alias);
29                        }
30                    }
31                };
32            }
33
34            for_each_rpc_method!(insert);
35
36            supported.insert(crate::rpc::chain::CHAIN_NOTIFY);
37            supported.insert(crate::rpc::channel::CANCEL_METHOD_NAME);
38
39            map.insert(version, supported);
40        }
41
42        map
43    });
44
45/// JSON-RPC middleware layer for segregating RPC methods by the versions they support.
46#[derive(Clone, Default)]
47pub(super) struct SegregationLayer;
48
49impl<S> Layer<S> for SegregationLayer {
50    type Service = SegregationService<S>;
51
52    fn layer(&self, service: S) -> Self::Service {
53        SegregationService { service }
54    }
55}
56
57#[derive(Clone)]
58pub(super) struct SegregationService<S> {
59    service: S,
60}
61
62impl<S> SegregationService<S> {
63    fn check<'a>(&self, path: Option<&ApiPaths>, method_name: &str) -> Result<(), ErrorObject<'a>> {
64        let supported = path
65            .and_then(|p| VERSION_METHODS_MAPPINGS.get(p))
66            .map(|set| set.contains(method_name))
67            .unwrap_or(false);
68        if supported {
69            Ok(())
70        } else {
71            Err(ErrorObject::borrowed(
72                METHOD_NOT_FOUND_CODE,
73                METHOD_NOT_FOUND_MSG,
74                None,
75            ))
76        }
77    }
78}
79
80impl<S> RpcServiceT for SegregationService<S>
81where
82    S: RpcServiceT<
83            MethodResponse = MethodResponse,
84            NotificationResponse = MethodResponse,
85            BatchResponse = MethodResponse,
86        > + Send
87        + Sync
88        + Clone
89        + 'static,
90{
91    type MethodResponse = S::MethodResponse;
92    type NotificationResponse = S::NotificationResponse;
93    type BatchResponse = S::BatchResponse;
94
95    fn call<'a>(
96        &self,
97        req: jsonrpsee::types::Request<'a>,
98    ) -> impl Future<Output = Self::MethodResponse> + Send + 'a {
99        match self.check(req.extensions().get::<ApiPaths>(), req.method_name()) {
100            Ok(()) => Either::Left(self.service.call(req)),
101            Err(e) => Either::Right(async move { MethodResponse::error(req.id(), e) }),
102        }
103    }
104
105    fn batch<'a>(&self, batch: Batch<'a>) -> impl Future<Output = Self::BatchResponse> + Send + 'a {
106        let entries = batch
107            .into_iter()
108            .filter_map(|entry| match entry {
109                Ok(BatchEntry::Call(req)) => Some(
110                    match self.check(req.extensions().get::<ApiPaths>(), req.method_name()) {
111                        Ok(()) => Ok(BatchEntry::Call(req)),
112                        Err(e) => Err(BatchEntryErr::new(req.id(), e)),
113                    },
114                ),
115                Ok(BatchEntry::Notification(n)) => {
116                    match self.check(n.extensions().get::<ApiPaths>(), n.method_name()) {
117                        Ok(_) => Some(Ok(BatchEntry::Notification(n))),
118                        Err(_) => None,
119                    }
120                }
121                Err(err) => Some(Err(err)),
122            })
123            .collect_vec();
124        self.service.batch(Batch::from(entries))
125    }
126
127    fn notification<'a>(
128        &self,
129        n: Notification<'a>,
130    ) -> impl Future<Output = Self::NotificationResponse> + Send + 'a {
131        match self.check(n.extensions().get::<ApiPaths>(), n.method_name()) {
132            Ok(()) => Either::Left(self.service.notification(n)),
133            Err(e) => Either::Right(async move { MethodResponse::error(Id::Null, e) }),
134        }
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Every known API version must have a method set. Each set must contain the names the
    /// initializer inserts unconditionally.
    #[test]
    fn test_version_methods_mappings() {
        assert!(!VERSION_METHODS_MAPPINGS.is_empty());
        for version in [ApiPaths::V0, ApiPaths::V1, ApiPaths::V2] {
            let methods = VERSION_METHODS_MAPPINGS
                .get(&version)
                .expect("every API version should have a method set");
            assert!(methods.contains(crate::rpc::chain::CHAIN_NOTIFY));
            assert!(methods.contains(crate::rpc::channel::CANCEL_METHOD_NAME));
        }
    }

    /// `check` accepts known methods for a version. It rejects requests with no version and
    /// requests for unknown methods, returning METHOD_NOT_FOUND in both cases.
    #[test]
    fn test_check_filters_by_version() {
        let service = SegregationService { service: () };
        let notify = crate::rpc::chain::CHAIN_NOTIFY;

        assert!(service.check(Some(&ApiPaths::V0), notify).is_ok());

        // No `ApiPaths` extension on the request means it cannot be routed.
        let err = service.check(None, notify).unwrap_err();
        assert_eq!(err.code(), METHOD_NOT_FOUND_CODE);

        let err = service
            .check(Some(&ApiPaths::V1), "Filecoin.DefinitelyNotAMethod")
            .unwrap_err();
        assert_eq!(err.code(), METHOD_NOT_FOUND_CODE);
    }
}