1use std::borrow::Cow as StdCow;
28use std::fmt;
29use std::sync::Arc;
30use std::time::Duration;
31
32use crate::transport::{self, Error as TransportError, HttpBackend, HttpTransportClient, HttpTransportClientBuilder};
33use crate::types::{NotificationSer, RequestSer, Response};
34use crate::{HttpRequest, HttpResponse};
35use async_trait::async_trait;
36use hyper::body::Bytes;
37use hyper::http::HeaderMap;
38use jsonrpsee_core::client::{
39	generate_batch_id_range, BatchResponse, ClientT, Error, IdKind, RequestIdManager, Subscription, SubscriptionClientT,
40};
41use jsonrpsee_core::params::BatchRequestBuilder;
42use jsonrpsee_core::traits::ToRpcParams;
43use jsonrpsee_core::{BoxError, JsonRawValue, TEN_MB_SIZE_BYTES};
44use jsonrpsee_types::{ErrorObject, InvalidRequestId, ResponseSuccess, TwoPointZero};
45use serde::de::DeserializeOwned;
46use tokio::sync::Semaphore;
47use tower::layer::util::Identity;
48use tower::{Layer, Service};
49use tracing::instrument;
50
51#[cfg(feature = "tls")]
52use crate::{CertificateStore, CustomCertStore};
53
/// Builder for [`HttpClient`].
///
/// All settings have defaults (see the `Default` impl). `L` is the tower layer stack that is
/// applied on top of the HTTP backend; it starts out as [`Identity`] (no middleware).
#[derive(Debug)]
pub struct HttpClientBuilder<L = Identity> {
	/// Forwarded to the HTTP transport as its `max_request_size` (default `TEN_MB_SIZE_BYTES`).
	max_request_size: u32,
	/// Forwarded to the HTTP transport as its `max_response_size` (default `TEN_MB_SIZE_BYTES`).
	max_response_size: u32,
	/// Per-call timeout applied around each transport send (default 60 seconds).
	request_timeout: Duration,
	/// Certificate store handed to the transport (default `CertificateStore::Native`).
	#[cfg(feature = "tls")]
	certificate_store: CertificateStore,
	/// Format of the request ids generated by the client (default `IdKind::Number`).
	id_kind: IdKind,
	/// Forwarded to the HTTP transport as its `max_log_length` (default 4096).
	max_log_length: u32,
	/// Extra HTTP headers forwarded to the transport (default: none).
	headers: HeaderMap,
	/// tower middleware stack layered over the HTTP backend.
	service_builder: tower::ServiceBuilder<L>,
	/// Forwarded to the HTTP transport as its `tcp_no_delay` setting (default `true`).
	tcp_no_delay: bool,
	/// `Some(n)` caps the number of in-flight calls via a semaphore; `None` (default) means no cap.
	max_concurrent_requests: Option<usize>,
}
91
92impl<L> HttpClientBuilder<L> {
93	pub fn max_request_size(mut self, size: u32) -> Self {
95		self.max_request_size = size;
96		self
97	}
98
99	pub fn max_response_size(mut self, size: u32) -> Self {
101		self.max_response_size = size;
102		self
103	}
104
105	pub fn request_timeout(mut self, timeout: Duration) -> Self {
107		self.request_timeout = timeout;
108		self
109	}
110
111	pub fn max_concurrent_requests(mut self, max_concurrent_requests: usize) -> Self {
113		self.max_concurrent_requests = Some(max_concurrent_requests);
114		self
115	}
116
117	#[cfg(feature = "tls")]
183	pub fn with_custom_cert_store(mut self, cfg: CustomCertStore) -> Self {
184		self.certificate_store = CertificateStore::Custom(cfg);
185		self
186	}
187
188	pub fn id_format(mut self, id_kind: IdKind) -> Self {
190		self.id_kind = id_kind;
191		self
192	}
193
194	pub fn set_max_logging_length(mut self, max: u32) -> Self {
198		self.max_log_length = max;
199		self
200	}
201
202	pub fn set_headers(mut self, headers: HeaderMap) -> Self {
206		self.headers = headers;
207		self
208	}
209
210	pub fn set_tcp_no_delay(mut self, no_delay: bool) -> Self {
214		self.tcp_no_delay = no_delay;
215		self
216	}
217
218	pub fn set_http_middleware<T>(self, service_builder: tower::ServiceBuilder<T>) -> HttpClientBuilder<T> {
220		HttpClientBuilder {
221			#[cfg(feature = "tls")]
222			certificate_store: self.certificate_store,
223			id_kind: self.id_kind,
224			headers: self.headers,
225			max_log_length: self.max_log_length,
226			max_request_size: self.max_request_size,
227			max_response_size: self.max_response_size,
228			service_builder,
229			request_timeout: self.request_timeout,
230			tcp_no_delay: self.tcp_no_delay,
231			max_concurrent_requests: self.max_concurrent_requests,
232		}
233	}
234}
235
236impl<B, S, L> HttpClientBuilder<L>
237where
238	L: Layer<transport::HttpBackend, Service = S>,
239	S: Service<HttpRequest, Response = HttpResponse<B>, Error = TransportError> + Clone,
240	B: http_body::Body<Data = Bytes> + Send + Unpin + 'static,
241	B::Data: Send,
242	B::Error: Into<BoxError>,
243{
244	pub fn build(self, target: impl AsRef<str>) -> Result<HttpClient<S>, Error> {
246		let Self {
247			max_request_size,
248			max_response_size,
249			request_timeout,
250			#[cfg(feature = "tls")]
251			certificate_store,
252			id_kind,
253			headers,
254			max_log_length,
255			service_builder,
256			tcp_no_delay,
257			..
258		} = self;
259
260		let transport = HttpTransportClientBuilder {
261			max_request_size,
262			max_response_size,
263			headers,
264			max_log_length,
265			tcp_no_delay,
266			service_builder,
267			#[cfg(feature = "tls")]
268			certificate_store,
269		}
270		.build(target)
271		.map_err(|e| Error::Transport(e.into()))?;
272
273		let request_guard = self
274			.max_concurrent_requests
275			.map(|max_concurrent_requests| Arc::new(Semaphore::new(max_concurrent_requests)));
276
277		Ok(HttpClient {
278			transport,
279			id_manager: Arc::new(RequestIdManager::new(id_kind)),
280			request_timeout,
281			request_guard,
282		})
283	}
284}
285
286impl Default for HttpClientBuilder<Identity> {
287	fn default() -> Self {
288		Self {
289			max_request_size: TEN_MB_SIZE_BYTES,
290			max_response_size: TEN_MB_SIZE_BYTES,
291			request_timeout: Duration::from_secs(60),
292			#[cfg(feature = "tls")]
293			certificate_store: CertificateStore::Native,
294			id_kind: IdKind::Number,
295			max_log_length: 4096,
296			headers: HeaderMap::new(),
297			service_builder: tower::ServiceBuilder::new(),
298			tcp_no_delay: true,
299			max_concurrent_requests: None,
300		}
301	}
302}
303
304impl HttpClientBuilder<Identity> {
305	pub fn new() -> HttpClientBuilder<Identity> {
307		HttpClientBuilder::default()
308	}
309}
310
/// JSON-RPC client that sends each call as an HTTP request through its transport.
///
/// Clones share the request-id generator and the concurrency limiter, because both are held
/// behind an `Arc`.
#[derive(Debug, Clone)]
pub struct HttpClient<S = HttpBackend> {
	/// HTTP transport (the backend service, possibly wrapped in tower middleware) that delivers
	/// serialized JSON-RPC messages.
	transport: HttpTransportClient<S>,
	/// Upper bound applied with `tokio::time::timeout` to each transport send.
	request_timeout: Duration,
	/// Generates request ids; shared between clones.
	id_manager: Arc<RequestIdManager>,
	/// Limits how many calls may be in flight at once; `None` means no limit. Shared between clones.
	request_guard: Option<Arc<Semaphore>>,
}
323
324impl HttpClient<HttpBackend> {
325	pub fn builder() -> HttpClientBuilder<Identity> {
327		HttpClientBuilder::new()
328	}
329}
330
331#[async_trait]
332impl<B, S> ClientT for HttpClient<S>
333where
334	S: Service<HttpRequest, Response = HttpResponse<B>, Error = TransportError> + Send + Sync + Clone,
335	<S as Service<HttpRequest>>::Future: Send,
336	B: http_body::Body<Data = Bytes> + Send + Unpin + 'static,
337	B::Error: Into<BoxError>,
338	B::Data: Send,
339{
340	#[instrument(name = "notification", skip(self, params), level = "trace")]
341	async fn notification<Params>(&self, method: &str, params: Params) -> Result<(), Error>
342	where
343		Params: ToRpcParams + Send,
344	{
345		let _permit = match self.request_guard.as_ref() {
346			Some(permit) => permit.acquire().await.ok(),
347			None => None,
348		};
349		let params = params.to_rpc_params()?;
350		let notif =
351			serde_json::to_string(&NotificationSer::borrowed(&method, params.as_deref())).map_err(Error::ParseError)?;
352
353		let fut = self.transport.send(notif);
354
355		match tokio::time::timeout(self.request_timeout, fut).await {
356			Ok(Ok(ok)) => Ok(ok),
357			Err(_) => Err(Error::RequestTimeout),
358			Ok(Err(e)) => Err(Error::Transport(e.into())),
359		}
360	}
361
362	#[instrument(name = "method_call", skip(self, params), level = "trace")]
363	async fn request<R, Params>(&self, method: &str, params: Params) -> Result<R, Error>
364	where
365		R: DeserializeOwned,
366		Params: ToRpcParams + Send,
367	{
368		let _permit = match self.request_guard.as_ref() {
369			Some(permit) => permit.acquire().await.ok(),
370			None => None,
371		};
372		let id = self.id_manager.next_request_id();
373		let params = params.to_rpc_params()?;
374
375		let request = RequestSer::borrowed(&id, &method, params.as_deref());
376		let raw = serde_json::to_string(&request).map_err(Error::ParseError)?;
377
378		let fut = self.transport.send_and_read_body(raw);
379		let body = match tokio::time::timeout(self.request_timeout, fut).await {
380			Ok(Ok(body)) => body,
381			Err(_e) => {
382				return Err(Error::RequestTimeout);
383			}
384			Ok(Err(e)) => {
385				return Err(Error::Transport(e.into()));
386			}
387		};
388
389		let response = ResponseSuccess::try_from(serde_json::from_slice::<Response<&JsonRawValue>>(&body)?)?;
392
393		let result = serde_json::from_str(response.result.get()).map_err(Error::ParseError)?;
394
395		if response.id == id {
396			Ok(result)
397		} else {
398			Err(InvalidRequestId::NotPendingRequest(response.id.to_string()).into())
399		}
400	}
401
402	#[instrument(name = "batch", skip(self, batch), level = "trace")]
403	async fn batch_request<'a, R>(&self, batch: BatchRequestBuilder<'a>) -> Result<BatchResponse<'a, R>, Error>
404	where
405		R: DeserializeOwned + fmt::Debug + 'a,
406	{
407		let _permit = match self.request_guard.as_ref() {
408			Some(permit) => permit.acquire().await.ok(),
409			None => None,
410		};
411		let batch = batch.build()?;
412		let id = self.id_manager.next_request_id();
413		let id_range = generate_batch_id_range(id, batch.len() as u64)?;
414
415		let mut batch_request = Vec::with_capacity(batch.len());
416		for ((method, params), id) in batch.into_iter().zip(id_range.clone()) {
417			let id = self.id_manager.as_id_kind().into_id(id);
418			batch_request.push(RequestSer {
419				jsonrpc: TwoPointZero,
420				id,
421				method: method.into(),
422				params: params.map(StdCow::Owned),
423			});
424		}
425
426		let fut = self.transport.send_and_read_body(serde_json::to_string(&batch_request).map_err(Error::ParseError)?);
427
428		let body = match tokio::time::timeout(self.request_timeout, fut).await {
429			Ok(Ok(body)) => body,
430			Err(_e) => return Err(Error::RequestTimeout),
431			Ok(Err(e)) => return Err(Error::Transport(e.into())),
432		};
433
434		let json_rps: Vec<Response<&JsonRawValue>> = serde_json::from_slice(&body).map_err(Error::ParseError)?;
435
436		let mut responses = Vec::with_capacity(json_rps.len());
437		let mut successful_calls = 0;
438		let mut failed_calls = 0;
439
440		for _ in 0..json_rps.len() {
441			responses.push(Err(ErrorObject::borrowed(0, "", None)));
442		}
443
444		for rp in json_rps {
445			let id = rp.id.try_parse_inner_as_number()?;
446
447			let res = match ResponseSuccess::try_from(rp) {
448				Ok(r) => {
449					let result = serde_json::from_str(r.result.get())?;
450					successful_calls += 1;
451					Ok(result)
452				}
453				Err(err) => {
454					failed_calls += 1;
455					Err(err)
456				}
457			};
458
459			let maybe_elem = id
460				.checked_sub(id_range.start)
461				.and_then(|p| p.try_into().ok())
462				.and_then(|p: usize| responses.get_mut(p));
463
464			if let Some(elem) = maybe_elem {
465				*elem = res;
466			} else {
467				return Err(InvalidRequestId::NotPendingRequest(id.to_string()).into());
468			}
469		}
470
471		Ok(BatchResponse::new(successful_calls, responses, failed_calls))
472	}
473}
474
475#[async_trait]
476impl<B, S> SubscriptionClientT for HttpClient<S>
477where
478	S: Service<HttpRequest, Response = HttpResponse<B>, Error = TransportError> + Send + Sync + Clone,
479	<S as Service<HttpRequest>>::Future: Send,
480	B: http_body::Body<Data = Bytes> + Send + Unpin + 'static,
481	B::Data: Send,
482	B::Error: Into<BoxError>,
483{
484	#[instrument(name = "subscription", fields(method = _subscribe_method), skip(self, _params, _subscribe_method, _unsubscribe_method), level = "trace")]
487	async fn subscribe<'a, N, Params>(
488		&self,
489		_subscribe_method: &'a str,
490		_params: Params,
491		_unsubscribe_method: &'a str,
492	) -> Result<Subscription<N>, Error>
493	where
494		Params: ToRpcParams + Send,
495		N: DeserializeOwned,
496	{
497		Err(Error::HttpNotImplemented)
498	}
499
500	#[instrument(name = "subscribe_method", fields(method = _method), skip(self, _method), level = "trace")]
502	async fn subscribe_to_method<'a, N>(&self, _method: &'a str) -> Result<Subscription<N>, Error>
503	where
504		N: DeserializeOwned,
505	{
506		Err(Error::HttpNotImplemented)
507	}
508}