rustic_jsonrpc/
registry.rs

1use std::any::type_name;
2use std::collections::HashMap;
3use std::error::Error as StdError;
4use std::fmt::{Display, Formatter};
5use std::future::Future;
6use std::ops::{Deref, DerefMut};
7use std::pin::Pin;
8
9use serde::Serialize;
10use serde_json::{from_slice, Value};
11
12use crate::{Container, Error, Id, Request, Response, METHOD_NOT_FOUND, PARSE_ERROR, SERVER_ERROR};
13
14pub type BoxError = Box<dyn StdError + Send + Sync>;
15
16/// Type alias for an RPC handler function.
17pub type Handler = for<'a> fn(
18    &'a Container,
19    &'a str,
20)
21    -> Pin<Box<dyn Future<Output = Result<Value, BoxError>> + Send + 'a>>;
22
23/// The `Registry` struct maintains a collection of RPC methods and a dependency injection container.
24pub struct Registry {
25    container: Container,
26    methods: HashMap<&'static str, Handler>,
27    post_call: Option<
28        Box<
29            dyn for<'a> Fn(
30                    &'a Request<'a>,
31                    &'a Result<Value, BoxError>,
32                ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
33                + Send
34                + Sync,
35        >,
36    >,
37}
38
39impl Registry {
40    /// Creates a new, empty `Registry`.
41    pub fn new() -> Self {
42        Self {
43            container: Container::new(),
44            methods: HashMap::new(),
45            post_call: None,
46        }
47    }
48
49    /// Registers a list of methods with the `Registry`.
50    pub fn register(&mut self, methods: &[Method]) {
51        for method in methods {
52            assert!(
53                self.methods.insert(method.name, method.handler).is_none(),
54                "method `{}` exists",
55                method.name,
56            );
57        }
58    }
59
60    /// Sets a function to be called after each RPC call.
61    pub fn post_call(
62        &mut self,
63        func: impl for<'a> Fn(
64                &'a Request<'a>,
65                &'a Result<Value, BoxError>,
66            ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
67            + Send
68            + Sync
69            + 'static,
70    ) {
71        self.post_call = Some(Box::new(func))
72    }
73
74    /// Handles an incoming RPC request.
75    pub async fn handle<'a>(&self, request: &'a [u8]) -> Option<Amount<Response<'a>>> {
76        if is_object(request) {
77            let response = match from_slice::<Request>(request) {
78                Ok(v) => self.invoke(&v).await?,
79                Err(e) => Response::error(Error::new(PARSE_ERROR, e, None), Id::Null),
80            };
81            return Some(Amount::One(response));
82        }
83
84        match from_slice::<Vec<Request>>(request) {
85            Ok(batch) => {
86                let mut response = Vec::with_capacity(batch.len());
87                for r in &batch {
88                    if let Some(v) = self.invoke(r).await {
89                        response.push(v);
90                    }
91                }
92                (!response.is_empty()).then_some(Amount::Batch(response))
93            }
94            Err(e) => Some(Amount::One(Response::error(
95                Error::new(PARSE_ERROR, e, None),
96                Id::Null,
97            ))),
98        }
99    }
100
101    async fn invoke<'a>(&self, req: &Request<'a>) -> Option<Response<'a>> {
102        let handler = match self.methods.get(req.method) {
103            Some(handler) => handler,
104            None if matches!(req.id, Id::None) => return None,
105            None => {
106                let err = Error::new(
107                    METHOD_NOT_FOUND,
108                    format!("method `{}` not found", req.method),
109                    None,
110                );
111                return Some(Response::error(err, req.id));
112            }
113        };
114
115        let params = req.params.map(|v| v.get()).unwrap_or("{}");
116        let result = handler(&self.container, params).await;
117        if let Some(ref f) = self.post_call {
118            f(req, &result).await;
119        }
120
121        if matches!(req.id, Id::None) {
122            return None;
123        }
124
125        match result {
126            Ok(v) => Some(Response::result(v, req.id)),
127            Err(e) => match Error::cast(&*e) {
128                Some(e) => Some(Response::error(e.clone(), req.id)),
129                None => {
130                    let e = Error::new(SERVER_ERROR, "server error", None);
131                    Some(Response::error(e, req.id))
132                }
133            },
134        }
135    }
136
137    /// Returns a list of registered method names.
138    pub fn methods(&self) -> Vec<&'static str> {
139        let mut methods = self.methods.keys().map(|v| *v).collect::<Vec<_>>();
140        methods.sort();
141        methods
142    }
143}
144
145impl Deref for Registry {
146    type Target = Container;
147
148    fn deref(&self) -> &Self::Target {
149        &self.container
150    }
151}
152
153impl DerefMut for Registry {
154    fn deref_mut(&mut self) -> &mut Self::Target {
155        &mut self.container
156    }
157}
158
/// Reports whether `s` looks like a single JSON object, i.e. whether its first
/// non-ASCII-whitespace byte is `{`.
///
/// Only that one byte is inspected; the rest of the payload is not validated.
/// Empty or all-whitespace input yields `false`.
fn is_object(s: &[u8]) -> bool {
    let first_significant = s.iter().find(|byte| !byte.is_ascii_whitespace());
    first_significant == Some(&b'{')
}
169
170/// Represents a response that can be a single response or a batch of responses.
171#[derive(Debug, Serialize)]
172#[serde(untagged)]
173pub enum Amount<T> {
174    One(T),
175    Batch(Vec<T>),
176}
177
178/// Represents an RPC method.
179pub struct Method {
180    name: &'static str,
181    handler: Handler,
182}
183
184impl Method {
185    /// Creates a new `Method` with the given name and handler.
186    pub const fn new(name: &'static str, handler: Handler) -> Self {
187        Self { name, handler }
188    }
189}
190
/// Error reported when a handler argument cannot be injected.
///
/// Carries the argument's name and the name of its type as reported by
/// [`std::any::type_name`].
#[derive(Debug)]
pub struct InjectError {
    /// Name of the argument that failed to inject.
    name: &'static str,
    /// `type_name` of the argument's type.
    ty: &'static str,
}

impl InjectError {
    /// Builds an `InjectError` for the argument called `name`, whose type is `T`.
    pub fn new<T>(name: &'static str) -> Self {
        let ty = type_name::<T>();
        Self { name, ty }
    }
}

impl Display for InjectError {
    /// Renders ``error inject argument `NAME: TYPE` ``.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self { name, ty } = self;
        write!(f, "error inject argument `{name}: {ty}`")
    }
}

/// Uses the default trait methods, so `source()` returns `None`.
impl StdError for InjectError {}
215
216mod sealed {
217    pub trait Sealed {}
218}
219
220/// Trait representing the result of an RPC method.
221pub trait MethodResult: sealed::Sealed {
222    const ASSERT: () = ();
223}
224
225impl<T, E> sealed::Sealed for Result<T, E>
226where
227    T: Serialize,
228    E: Into<BoxError>,
229{
230}
231
232impl<T, E> MethodResult for Result<T, E>
233where
234    T: Serialize,
235    E: Into<BoxError>,
236{
237}
238
239/// Trait for converting an argument from one type to another.
240#[allow(async_fn_in_trait)]
241pub trait FromArg<T>: Sized {
242    type Error: StdError;
243
244    async fn from_arg(container: &Container, arg: T) -> Result<Self, Self::Error>;
245}