forest/rpc/
segregation_layer.rs1use super::ApiPaths;
5use crate::{for_each_rpc_method, rpc::reflect::RpcMethod};
6use ahash::{HashMap, HashSet};
7use futures::future::Either;
8use itertools::Itertools as _;
9use jsonrpsee::MethodResponse;
10use jsonrpsee::core::middleware::{Batch, BatchEntry, BatchEntryErr, Notification};
11use jsonrpsee::server::middleware::rpc::RpcServiceT;
12use jsonrpsee::types::error::{METHOD_NOT_FOUND_CODE, METHOD_NOT_FOUND_MSG};
13use jsonrpsee::types::{ErrorObject, Id};
14use std::sync::LazyLock;
15use tower::Layer;
16
17static VERSION_METHODS_MAPPINGS: LazyLock<HashMap<ApiPaths, HashSet<&'static str>>> =
18 LazyLock::new(|| {
19 let mut map = HashMap::default();
20 for version in [ApiPaths::V0, ApiPaths::V1, ApiPaths::V2] {
21 let mut supported = HashSet::default();
22
23 macro_rules! insert {
24 ($ty:ty) => {
25 if <$ty>::API_PATHS.contains(version) {
26 supported.insert(<$ty>::NAME);
27 if let Some(alias) = <$ty>::NAME_ALIAS {
28 supported.insert(alias);
29 }
30 }
31 };
32 }
33
34 for_each_rpc_method!(insert);
35
36 supported.insert(crate::rpc::chain::CHAIN_NOTIFY);
37 supported.insert(crate::rpc::channel::CANCEL_METHOD_NAME);
38
39 map.insert(version, supported);
40 }
41
42 map
43 });
44
45#[derive(Clone, Default)]
47pub(super) struct SegregationLayer;
48
49impl<S> Layer<S> for SegregationLayer {
50 type Service = SegregationService<S>;
51
52 fn layer(&self, service: S) -> Self::Service {
53 SegregationService { service }
54 }
55}
56
57#[derive(Clone)]
58pub(super) struct SegregationService<S> {
59 service: S,
60}
61
62impl<S> SegregationService<S> {
63 fn check<'a>(&self, path: Option<&ApiPaths>, method_name: &str) -> Result<(), ErrorObject<'a>> {
64 let supported = path
65 .and_then(|p| VERSION_METHODS_MAPPINGS.get(p))
66 .map(|set| set.contains(method_name))
67 .unwrap_or(false);
68 if supported {
69 Ok(())
70 } else {
71 Err(ErrorObject::borrowed(
72 METHOD_NOT_FOUND_CODE,
73 METHOD_NOT_FOUND_MSG,
74 None,
75 ))
76 }
77 }
78}
79
80impl<S> RpcServiceT for SegregationService<S>
81where
82 S: RpcServiceT<
83 MethodResponse = MethodResponse,
84 NotificationResponse = MethodResponse,
85 BatchResponse = MethodResponse,
86 > + Send
87 + Sync
88 + Clone
89 + 'static,
90{
91 type MethodResponse = S::MethodResponse;
92 type NotificationResponse = S::NotificationResponse;
93 type BatchResponse = S::BatchResponse;
94
95 fn call<'a>(
96 &self,
97 req: jsonrpsee::types::Request<'a>,
98 ) -> impl Future<Output = Self::MethodResponse> + Send + 'a {
99 match self.check(req.extensions().get::<ApiPaths>(), req.method_name()) {
100 Ok(()) => Either::Left(self.service.call(req)),
101 Err(e) => Either::Right(async move { MethodResponse::error(req.id(), e) }),
102 }
103 }
104
105 fn batch<'a>(&self, batch: Batch<'a>) -> impl Future<Output = Self::BatchResponse> + Send + 'a {
106 let entries = batch
107 .into_iter()
108 .filter_map(|entry| match entry {
109 Ok(BatchEntry::Call(req)) => Some(
110 match self.check(req.extensions().get::<ApiPaths>(), req.method_name()) {
111 Ok(()) => Ok(BatchEntry::Call(req)),
112 Err(e) => Err(BatchEntryErr::new(req.id(), e)),
113 },
114 ),
115 Ok(BatchEntry::Notification(n)) => {
116 match self.check(n.extensions().get::<ApiPaths>(), n.method_name()) {
117 Ok(_) => Some(Ok(BatchEntry::Notification(n))),
118 Err(_) => None,
119 }
120 }
121 Err(err) => Some(Err(err)),
122 })
123 .collect_vec();
124 self.service.batch(Batch::from(entries))
125 }
126
127 fn notification<'a>(
128 &self,
129 n: Notification<'a>,
130 ) -> impl Future<Output = Self::NotificationResponse> + Send + 'a {
131 match self.check(n.extensions().get::<ApiPaths>(), n.method_name()) {
132 Ok(()) => Either::Left(self.service.notification(n)),
133 Err(e) => Either::Right(async move { MethodResponse::error(Id::Null, e) }),
134 }
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Every supported API version must have an allow-list, and the methods that are
    /// inserted unconditionally must be present in each of them.
    #[test]
    fn test_version_methods_mappings() {
        assert!(!VERSION_METHODS_MAPPINGS.is_empty());
        for version in [ApiPaths::V0, ApiPaths::V1, ApiPaths::V2] {
            let methods = VERSION_METHODS_MAPPINGS
                .get(&version)
                .expect("every API version should have a method allow-list");
            assert!(methods.contains(crate::rpc::chain::CHAIN_NOTIFY));
            assert!(methods.contains(crate::rpc::channel::CANCEL_METHOD_NAME));
        }
    }

    /// `check` accepts allow-listed methods and rejects everything else with the
    /// JSON-RPC "method not found" code.
    #[test]
    fn test_check() {
        let service = SegregationService { service: () };
        let version = ApiPaths::V1;

        assert!(
            service
                .check(Some(&version), crate::rpc::chain::CHAIN_NOTIFY)
                .is_ok()
        );

        // A request without an API version in its extensions is never allowed.
        let err = service
            .check(None, crate::rpc::chain::CHAIN_NOTIFY)
            .unwrap_err();
        assert_eq!(err.code(), METHOD_NOT_FOUND_CODE);

        // Unknown method names are rejected even on a valid version.
        let err = service
            .check(Some(&version), "Filecoin.ThisMethodDoesNotExist")
            .unwrap_err();
        assert_eq!(err.code(), METHOD_NOT_FOUND_CODE);
    }
}