jsonrpsee_core/middleware/
mod.rs

1//! Middleware for the RPC service.
2
3pub mod layer;
4
5use std::borrow::Cow;
6use std::future::Future;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
10use futures_util::future::Either;
11use jsonrpsee_types::{ErrorObject, Id};
12use pin_project::pin_project;
13use serde::Serialize;
14use serde::ser::SerializeSeq;
15use serde_json::value::RawValue;
16use tower::layer::LayerFn;
17use tower::layer::util::{Identity, Stack};
18
/// A JSON-RPC notification whose parameters, if present, are kept as raw JSON.
///
/// This is [`jsonrpsee_types::Notification`] specialised to `Option<Cow<'a, RawValue>>` params,
/// provided for convenience.
pub type Notification<'a> = jsonrpsee_types::Notification<'a, Option<Cow<'a, RawValue>>>;
21/// Re-export types from `jsonrpsee_types` crate for convenience.
22pub use jsonrpsee_types::{Extensions, Request};
23
/// Error response that can be used to indicate an error in an entry of a JSON-RPC batch request.
///
/// This is used in the [`Batch`] type to mark a batch entry as an error. It wraps a JSON-RPC
/// response whose payload is always an error object.
#[derive(Debug)]
pub struct BatchEntryErr<'a>(jsonrpsee_types::Response<'a, ()>);
28
29impl<'a> BatchEntryErr<'a> {
30	/// Create a new error response.
31	pub fn new(id: Id<'a>, err: ErrorObject<'a>) -> Self {
32		let payload = jsonrpsee_types::ResponsePayload::Error(err);
33		let response = jsonrpsee_types::Response::new(payload, id);
34		Self(response)
35	}
36
37	/// Get the parts of the error response.q
38	pub fn into_parts(self) -> (ErrorObject<'a>, Id<'a>) {
39		let err = match self.0.payload {
40			jsonrpsee_types::ResponsePayload::Error(err) => err,
41			_ => unreachable!("BatchEntryErr can only be created from error payload; qed"),
42		};
43		(err, self.0.id)
44	}
45}
46
/// A batch of JSON-RPC requests.
///
/// Each entry is either a [`BatchEntry`] (a call or a notification) or a [`BatchEntryErr`]
/// for an entry that is an error.
#[derive(Debug, Default)]
pub struct Batch<'a> {
	// The batch entries, in insertion order.
	inner: Vec<Result<BatchEntry<'a>, BatchEntryErr<'a>>>,
	// Cached extensions for the whole batch. When `None`, the accessors derive them from the
	// extensions of the `Ok` entries.
	extensions: Option<Extensions>,
}
53
54impl std::fmt::Display for Batch<'_> {
55	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56		let fmt = serde_json::to_string(self).map_err(|_| std::fmt::Error)?;
57		f.write_str(&fmt)
58	}
59}
60
61impl<'a> IntoIterator for Batch<'a> {
62	type Item = Result<BatchEntry<'a>, BatchEntryErr<'a>>;
63	type IntoIter = std::vec::IntoIter<Self::Item>;
64
65	fn into_iter(self) -> Self::IntoIter {
66		self.inner.into_iter()
67	}
68}
69
70impl<'a> Batch<'a> {
71	/// Create a new batch from a vector of batch entries.
72	pub fn from(entries: Vec<Result<BatchEntry<'a>, BatchEntryErr<'a>>>) -> Self {
73		Self { inner: entries, extensions: None }
74	}
75
76	/// Create a new empty batch.
77	pub fn new() -> Self {
78		Self { inner: Vec::new(), extensions: None }
79	}
80
81	/// Create a new empty batch with the at least capacity.
82	pub fn with_capacity(capacity: usize) -> Self {
83		Self { inner: Vec::with_capacity(capacity), extensions: None }
84	}
85
86	/// Push a new batch entry to the batch.
87	pub fn push(&mut self, req: Request<'a>) {
88		match self.extensions {
89			Some(ref mut ext) => {
90				ext.extend(req.extensions().clone());
91			}
92			None => {
93				self.extensions = Some(req.extensions().clone());
94			}
95		};
96
97		self.inner.push(Ok(BatchEntry::Call(req)));
98	}
99
100	/// Get the length of the batch.
101	pub fn len(&self) -> usize {
102		self.inner.len()
103	}
104
105	/// Returns whether the batch is empty.
106	pub fn is_empty(&self) -> bool {
107		self.inner.is_empty()
108	}
109
110	/// Get an iterator over the batch.
111	pub fn iter(&self) -> impl Iterator<Item = &Result<BatchEntry<'a>, BatchEntryErr<'a>>> {
112		self.inner.iter()
113	}
114
115	/// Get a mutable iterator over the batch.
116	pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Result<BatchEntry<'a>, BatchEntryErr<'a>>> {
117		self.inner.iter_mut()
118	}
119
120	/// Consume the batch and and return the parts.
121	pub fn into_extensions(self) -> Extensions {
122		match self.extensions {
123			Some(ext) => ext,
124			None => self.extensions_from_iter(),
125		}
126	}
127
128	/// Get a reference to the extensions of the batch.
129	pub fn extensions(&mut self) -> &Extensions {
130		if self.extensions.is_none() {
131			self.extensions = Some(self.extensions_from_iter());
132		}
133
134		self.extensions.as_ref().expect("Extensions inserted above; qed")
135	}
136
137	/// Get a mutable reference to the extensions of the batch.
138	pub fn extensions_mut(&mut self) -> &mut Extensions {
139		if self.extensions.is_none() {
140			self.extensions = Some(self.extensions_from_iter());
141		}
142
143		self.extensions.as_mut().expect("Extensions inserted above; qed")
144	}
145
146	fn extensions_from_iter(&self) -> Extensions {
147		let mut ext = Extensions::new();
148		for entry in self.inner.iter().flatten() {
149			ext.extend(entry.extensions().clone());
150		}
151		ext
152	}
153}
154
155impl Serialize for Batch<'_> {
156	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
157	where
158		S: serde::Serializer,
159	{
160		let mut seq = serializer.serialize_seq(Some(self.inner.len()))?;
161		for entry in &self.inner {
162			match entry {
163				Ok(entry) => seq.serialize_element(entry)?,
164				Err(err) => seq.serialize_element(&err.0)?,
165			}
166		}
167		seq.end()
168	}
169}
170
171#[derive(Debug, Clone)]
172/// A marker type to indicate that the request is a subscription for the [`RpcServiceT::call`] method.
173pub struct IsSubscription {
174	sub_id: Id<'static>,
175	unsub_id: Id<'static>,
176	unsub_method: String,
177}
178
179impl IsSubscription {
180	/// Create a new [`IsSubscription`] instance.
181	pub fn new(sub_id: Id<'static>, unsub_id: Id<'static>, unsub_method: String) -> Self {
182		Self { sub_id, unsub_id, unsub_method }
183	}
184
185	/// Get the request id of the subscription calls.
186	pub fn sub_req_id(&self) -> Id<'static> {
187		self.sub_id.clone()
188	}
189
190	/// Get the request id of the unsubscription call.
191	pub fn unsub_req_id(&self) -> Id<'static> {
192		self.unsub_id.clone()
193	}
194
195	/// Get the unsubscription method name.
196	pub fn unsubscribe_method(&self) -> &str {
197		&self.unsub_method
198	}
199}
200
/// An extension type for the [`RpcServiceT::batch`] method describing the expected id range of
/// the batch entries.
#[derive(Debug, Clone)]
pub struct IsBatch {
	/// The range of ids for the batch entries.
	///
	/// This is a half-open [`std::ops::Range`]: `start` is included, `end` is excluded.
	pub id_range: std::ops::Range<u64>,
}
207
/// A batch entry specific for the [`RpcServiceT::batch`] method to support both
/// method calls and notifications.
///
/// Serialized untagged, i.e. as the JSON of the wrapped call or notification with no
/// enum wrapper around it.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub enum BatchEntry<'a> {
	/// A regular JSON-RPC call.
	Call(Request<'a>),
	/// A JSON-RPC notification.
	Notification(Notification<'a>),
}
218
219impl<'a> BatchEntry<'a> {
220	/// Get a reference to extensions of the batch entry.
221	pub fn extensions(&self) -> &Extensions {
222		match self {
223			BatchEntry::Call(req) => req.extensions(),
224			BatchEntry::Notification(n) => n.extensions(),
225		}
226	}
227
228	/// Get a mut reference to extensions of the batch entry.
229	pub fn extensions_mut(&mut self) -> &mut Extensions {
230		match self {
231			BatchEntry::Call(req) => req.extensions_mut(),
232			BatchEntry::Notification(n) => n.extensions_mut(),
233		}
234	}
235
236	/// Get the method name of the batch entry.
237	pub fn method_name(&self) -> &str {
238		match self {
239			BatchEntry::Call(req) => req.method_name(),
240			BatchEntry::Notification(n) => n.method_name(),
241		}
242	}
243
244	/// Set the method name of the batch entry.
245	pub fn set_method_name(&mut self, name: impl Into<Cow<'a, str>>) {
246		match self {
247			BatchEntry::Call(req) => {
248				req.method = name.into();
249			}
250			BatchEntry::Notification(n) => {
251				n.method = name.into();
252			}
253		}
254	}
255
256	/// Get the params of the batch entry (may be empty).
257	pub fn params(&self) -> Option<&Cow<'a, RawValue>> {
258		match self {
259			BatchEntry::Call(req) => req.params.as_ref(),
260			BatchEntry::Notification(n) => n.params.as_ref(),
261		}
262	}
263
264	/// Set the params of the batch entry.
265	pub fn set_params(&mut self, params: Option<Box<RawValue>>) {
266		match self {
267			BatchEntry::Call(req) => {
268				req.params = params.map(Cow::Owned);
269			}
270			BatchEntry::Notification(n) => {
271				n.params = params.map(Cow::Owned);
272			}
273		}
274	}
275
276	/// Consume the batch entry and extract the extensions.
277	pub fn into_extensions(self) -> Extensions {
278		match self {
279			BatchEntry::Call(req) => req.extensions,
280			BatchEntry::Notification(n) => n.extensions,
281		}
282	}
283}
284
/// Represents a JSON-RPC service that can process JSON-RPC calls, notifications, and batch requests.
///
/// This trait is similar to [`tower::Service`] but it's specialized for JSON-RPC operations.
///
/// Each method returns a future that resolves to the corresponding associated response type.
/// These are typically `Result<R, E>` types, because this trait is intended to be used by both
/// client and server implementations:
///
/// - In the server implementation the error is infallible, so server implementations must use
///   [`std::convert::Infallible`] as the error type. The underlying service requires that;
///   any other error type results in a compiler error that tries to explain it.
/// - In the client implementation errors can occur due to I/O failures or JSON-RPC protocol errors.
pub trait RpcServiceT {
	/// Response type for `RpcServiceT::call`.
	type MethodResponse;
	/// Response type for `RpcServiceT::notification`.
	type NotificationResponse;
	/// Response type for `RpcServiceT::batch`.
	type BatchResponse;

	/// Processes a single JSON-RPC call, which may be a subscription or a regular call.
	fn call<'a>(&self, request: Request<'a>) -> impl Future<Output = Self::MethodResponse> + Send + 'a;

	/// Processes multiple JSON-RPC calls at once, similar to `RpcServiceT::call`.
	///
	/// This method wraps `RpcServiceT::call` and `RpcServiceT::notification`,
	/// but the root RPC service does not inherently recognize custom implementations
	/// of these methods.
	///
	/// As a result, if you have custom logic for individual calls or notifications,
	/// you must duplicate that implementation in this method; otherwise no middleware will be
	/// applied to calls inside the batch.
	fn batch<'a>(&self, requests: Batch<'a>) -> impl Future<Output = Self::BatchResponse> + Send + 'a;

	/// Similar to `RpcServiceT::call` but processes a JSON-RPC notification.
	fn notification<'a>(&self, n: Notification<'a>) -> impl Future<Output = Self::NotificationResponse> + Send + 'a;
}
323
324/// Similar to [`tower::ServiceBuilder`] but doesn't
325/// support any tower middleware implementations.
326#[derive(Debug, Clone)]
327pub struct RpcServiceBuilder<L>(tower::ServiceBuilder<L>);
328
329impl Default for RpcServiceBuilder<Identity> {
330	fn default() -> Self {
331		RpcServiceBuilder(tower::ServiceBuilder::new())
332	}
333}
334
335impl RpcServiceBuilder<Identity> {
336	/// Create a new [`RpcServiceBuilder`].
337	pub fn new() -> Self {
338		Self(tower::ServiceBuilder::new())
339	}
340}
341
342impl<L> RpcServiceBuilder<L> {
343	/// Optionally add a new layer `T` to the [`RpcServiceBuilder`].
344	///
345	/// See the documentation for [`tower::ServiceBuilder::option_layer`] for more details.
346	pub fn option_layer<T>(self, layer: Option<T>) -> RpcServiceBuilder<Stack<layer::Either<T, Identity>, L>> {
347		let layer =
348			if let Some(layer) = layer { layer::Either::Left(layer) } else { layer::Either::Right(Identity::new()) };
349		RpcServiceBuilder(self.0.layer(layer))
350	}
351
352	/// Add a new layer `T` to the [`RpcServiceBuilder`].
353	///
354	/// See the documentation for [`tower::ServiceBuilder::layer`] for more details.
355	pub fn layer<T>(self, layer: T) -> RpcServiceBuilder<Stack<T, L>> {
356		RpcServiceBuilder(self.0.layer(layer))
357	}
358
359	/// Add a [`tower::Layer`] built from a function that accepts a service and returns another service.
360	///
361	/// See the documentation for [`tower::ServiceBuilder::layer_fn`] for more details.
362	pub fn layer_fn<F>(self, f: F) -> RpcServiceBuilder<Stack<LayerFn<F>, L>> {
363		RpcServiceBuilder(self.0.layer_fn(f))
364	}
365
366	/// Add a logging layer to [`RpcServiceBuilder`]
367	///
368	/// This logs each request and response for every call.
369	///
370	pub fn rpc_logger(self, max_log_len: u32) -> RpcServiceBuilder<Stack<layer::RpcLoggerLayer, L>> {
371		RpcServiceBuilder(self.0.layer(layer::RpcLoggerLayer::new(max_log_len)))
372	}
373
374	/// Wrap the service `S` with the middleware.
375	pub fn service<S>(&self, service: S) -> L::Service
376	where
377		L: tower::Layer<S>,
378	{
379		self.0.service(service)
380	}
381}
382
383/// Response which may be ready or a future.
384#[derive(Debug)]
385#[pin_project]
386pub struct ResponseFuture<F, R>(#[pin] futures_util::future::Either<F, std::future::Ready<R>>);
387
388impl<F, R> ResponseFuture<F, R> {
389	/// Returns a future that resolves to a response.
390	pub fn future(f: F) -> ResponseFuture<F, R> {
391		ResponseFuture(Either::Left(f))
392	}
393
394	/// Return a response which is already computed.
395	pub fn ready(response: R) -> ResponseFuture<F, R> {
396		ResponseFuture(Either::Right(std::future::ready(response)))
397	}
398}
399
400impl<F, R> Future for ResponseFuture<F, R>
401where
402	F: Future<Output = R>,
403{
404	type Output = F::Output;
405
406	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
407		match self.project().0.poll(cx) {
408			Poll::Ready(rp) => Poll::Ready(rp),
409			Poll::Pending => Poll::Pending,
410		}
411	}
412}
413
#[cfg(test)]
mod tests {
	use super::{Batch, BatchEntry, BatchEntryErr, Notification, Request};
	use jsonrpsee_types::{ErrorCode, ErrorObject, Id};

	/// Each kind of batch entry serializes as the bare call / notification (untagged).
	#[test]
	fn serialize_batch_entry() {
		let call = BatchEntry::Call(Request::borrowed("say_hello", None, Id::Number(1)));
		assert_eq!(
			serde_json::to_string(&call).unwrap(),
			r#"{"jsonrpc":"2.0","id":1,"method":"say_hello"}"#,
		);

		let notif = BatchEntry::Notification(Notification::new("say_hello".into(), None));
		assert_eq!(
			serde_json::to_string(&notif).unwrap(),
			r#"{"jsonrpc":"2.0","method":"say_hello","params":null}"#,
		);
	}

	/// A batch serializes as a JSON array; error entries become error responses.
	#[test]
	fn serialize_batch_works() {
		let entries = vec![
			Ok(BatchEntry::Call(Request::borrowed("say_hello", None, Id::Number(1)))),
			Ok(BatchEntry::Notification(Notification::new("say_hello".into(), None))),
			Err(BatchEntryErr::new(Id::Number(2), ErrorObject::from(ErrorCode::InvalidRequest))),
		];
		let expected = r#"[{"jsonrpc":"2.0","id":1,"method":"say_hello"},{"jsonrpc":"2.0","method":"say_hello","params":null},{"jsonrpc":"2.0","id":2,"error":{"code":-32600,"message":"Invalid request"}}]"#;

		assert_eq!(serde_json::to_string(&Batch::from(entries)).unwrap(), expected);
	}
}