1use std::borrow::Cow as StdCow;
28use std::fmt;
29use std::sync::Arc;
30use std::time::Duration;
31
32use crate::rpc_service::RpcService;
33use crate::transport::{self, Error as TransportError, HttpBackend, HttpTransportClientBuilder};
34use crate::{HttpRequest, HttpResponse};
35use hyper::body::Bytes;
36use hyper::http::{Extensions, HeaderMap};
37use jsonrpsee_core::client::{
38 BatchResponse, ClientT, Error, IdKind, MiddlewareBatchResponse, MiddlewareMethodResponse, MiddlewareNotifResponse,
39 RequestIdManager, Subscription, SubscriptionClientT, generate_batch_id_range,
40};
41use jsonrpsee_core::middleware::layer::{RpcLogger, RpcLoggerLayer};
42use jsonrpsee_core::middleware::{Batch, RpcServiceBuilder, RpcServiceT};
43use jsonrpsee_core::params::BatchRequestBuilder;
44use jsonrpsee_core::traits::ToRpcParams;
45use jsonrpsee_core::{BoxError, TEN_MB_SIZE_BYTES};
46use jsonrpsee_types::{ErrorObject, InvalidRequestId, Notification, Request, ResponseSuccess, TwoPointZero};
47use serde::de::DeserializeOwned;
48use tokio::sync::Semaphore;
49use tower::layer::util::Identity;
50use tower::{Layer, Service};
51
52#[cfg(feature = "tls")]
53use crate::{CertificateStore, CustomCertStore};
54
55type Logger = tower::layer::util::Stack<RpcLoggerLayer, tower::layer::util::Identity>;
56
57#[derive(Clone, Debug)]
81pub struct HttpClientBuilder<HttpMiddleware = Identity, RpcMiddleware = Logger> {
82 max_request_size: u32,
83 max_response_size: u32,
84 request_timeout: Duration,
85 #[cfg(feature = "tls")]
86 certificate_store: CertificateStore,
87 id_kind: IdKind,
88 headers: HeaderMap,
89 service_builder: tower::ServiceBuilder<HttpMiddleware>,
90 rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
91 tcp_no_delay: bool,
92 max_concurrent_requests: Option<usize>,
93 keep_alive_duration: Option<Duration>,
94 keep_alive_interval: Option<Duration>,
95 keep_alive_retries: Option<u32>,
96}
97
98impl<HttpMiddleware, RpcMiddleware> HttpClientBuilder<HttpMiddleware, RpcMiddleware> {
99 pub fn max_request_size(mut self, size: u32) -> Self {
101 self.max_request_size = size;
102 self
103 }
104
105 pub fn max_response_size(mut self, size: u32) -> Self {
107 self.max_response_size = size;
108 self
109 }
110
111 pub fn request_timeout(mut self, timeout: Duration) -> Self {
113 self.request_timeout = timeout;
114 self
115 }
116
117 pub fn max_concurrent_requests(mut self, max_concurrent_requests: usize) -> Self {
119 self.max_concurrent_requests = Some(max_concurrent_requests);
120 self
121 }
122
123 #[cfg(feature = "tls")]
189 pub fn with_custom_cert_store(mut self, cfg: CustomCertStore) -> Self {
190 self.certificate_store = CertificateStore::Custom(cfg);
191 self
192 }
193
194 pub fn id_format(mut self, id_kind: IdKind) -> Self {
196 self.id_kind = id_kind;
197 self
198 }
199
200 pub fn set_headers(mut self, headers: HeaderMap) -> Self {
204 self.headers = headers;
205 self
206 }
207
208 pub fn set_tcp_no_delay(mut self, no_delay: bool) -> Self {
212 self.tcp_no_delay = no_delay;
213 self
214 }
215
216 pub fn set_keep_alive(mut self, duration: Option<Duration>) -> Self {
218 self.keep_alive_duration = duration;
219 self
220 }
221
222 pub fn set_keep_alive_interval(mut self, interval: Option<Duration>) -> Self {
224 self.keep_alive_interval = interval;
225 self
226 }
227
228 pub fn set_keep_alive_retries(mut self, retries: Option<u32>) -> Self {
230 self.keep_alive_retries = retries;
231 self
232 }
233
234 pub fn set_rpc_middleware<T>(self, rpc_builder: RpcServiceBuilder<T>) -> HttpClientBuilder<HttpMiddleware, T> {
236 HttpClientBuilder {
237 #[cfg(feature = "tls")]
238 certificate_store: self.certificate_store,
239 id_kind: self.id_kind,
240 headers: self.headers,
241 max_request_size: self.max_request_size,
242 max_response_size: self.max_response_size,
243 service_builder: self.service_builder,
244 rpc_middleware: rpc_builder,
245 request_timeout: self.request_timeout,
246 tcp_no_delay: self.tcp_no_delay,
247 max_concurrent_requests: self.max_concurrent_requests,
248 keep_alive_duration: self.keep_alive_duration,
249 keep_alive_interval: self.keep_alive_interval,
250 keep_alive_retries: self.keep_alive_retries,
251 }
252 }
253
254 pub fn set_http_middleware<T>(
256 self,
257 service_builder: tower::ServiceBuilder<T>,
258 ) -> HttpClientBuilder<T, RpcMiddleware> {
259 HttpClientBuilder {
260 #[cfg(feature = "tls")]
261 certificate_store: self.certificate_store,
262 id_kind: self.id_kind,
263 headers: self.headers,
264 max_request_size: self.max_request_size,
265 max_response_size: self.max_response_size,
266 service_builder,
267 rpc_middleware: self.rpc_middleware,
268 request_timeout: self.request_timeout,
269 tcp_no_delay: self.tcp_no_delay,
270 max_concurrent_requests: self.max_concurrent_requests,
271 keep_alive_duration: self.keep_alive_duration,
272 keep_alive_retries: self.keep_alive_retries,
273 keep_alive_interval: self.keep_alive_interval,
274 }
275 }
276}
277
278impl<B, S, S2, HttpMiddleware, RpcMiddleware> HttpClientBuilder<HttpMiddleware, RpcMiddleware>
279where
280 RpcMiddleware: Layer<RpcService<S>, Service = S2>,
281 <RpcMiddleware as Layer<RpcService<S>>>::Service: RpcServiceT,
282 HttpMiddleware: Layer<transport::HttpBackend, Service = S>,
283 S: Service<HttpRequest, Response = HttpResponse<B>, Error = TransportError> + Clone,
284 B: http_body::Body<Data = Bytes> + Send + Unpin + 'static,
285 B::Data: Send,
286 B::Error: Into<BoxError>,
287{
288 pub fn build(self, target: impl AsRef<str>) -> Result<HttpClient<S2>, Error> {
290 let Self {
291 max_request_size,
292 max_response_size,
293 request_timeout,
294 #[cfg(feature = "tls")]
295 certificate_store,
296 id_kind,
297 headers,
298 service_builder,
299 tcp_no_delay,
300 rpc_middleware,
301 keep_alive_duration,
302 keep_alive_interval,
303 keep_alive_retries,
304 ..
305 } = self;
306
307 let http = HttpTransportClientBuilder {
308 max_request_size,
309 max_response_size,
310 headers,
311 tcp_no_delay,
312 service_builder,
313 keep_alive_duration,
314 keep_alive_interval,
315 keep_alive_retries,
316 #[cfg(feature = "tls")]
317 certificate_store,
318 }
319 .build(target)
320 .map_err(|e| Error::Transport(e.into()))?;
321
322 let request_guard = self
323 .max_concurrent_requests
324 .map(|max_concurrent_requests| Arc::new(Semaphore::new(max_concurrent_requests)));
325
326 Ok(HttpClient {
327 service: rpc_middleware.service(RpcService::new(http)),
328 id_manager: Arc::new(RequestIdManager::new(id_kind)),
329 request_guard,
330 request_timeout,
331 })
332 }
333}
334
335impl Default for HttpClientBuilder {
336 fn default() -> Self {
337 Self {
338 max_request_size: TEN_MB_SIZE_BYTES,
339 max_response_size: TEN_MB_SIZE_BYTES,
340 request_timeout: Duration::from_secs(60),
341 #[cfg(feature = "tls")]
342 certificate_store: CertificateStore::Native,
343 id_kind: IdKind::Number,
344 headers: HeaderMap::new(),
345 service_builder: tower::ServiceBuilder::new(),
346 rpc_middleware: RpcServiceBuilder::default().rpc_logger(1024),
347 tcp_no_delay: true,
348 max_concurrent_requests: None,
349 keep_alive_duration: None,
350 keep_alive_interval: None,
351 keep_alive_retries: None,
352 }
353 }
354}
355
356impl HttpClientBuilder {
357 pub fn new() -> HttpClientBuilder<Identity, Logger> {
359 HttpClientBuilder::default()
360 }
361}
362
363#[derive(Debug, Clone)]
365pub struct HttpClient<S = RpcLogger<RpcService<HttpBackend>>> {
366 service: S,
368 id_manager: Arc<RequestIdManager>,
370 request_guard: Option<Arc<Semaphore>>,
372 request_timeout: Duration,
374}
375
376impl HttpClient<HttpBackend> {
377 pub fn builder() -> HttpClientBuilder {
379 HttpClientBuilder::new()
380 }
381
382 pub fn request_timeout(&self) -> Duration {
384 self.request_timeout
385 }
386}
387
388impl<S> ClientT for HttpClient<S>
389where
390 S: RpcServiceT<
391 MethodResponse = Result<MiddlewareMethodResponse, Error>,
392 BatchResponse = Result<MiddlewareBatchResponse, Error>,
393 NotificationResponse = Result<MiddlewareNotifResponse, Error>,
394 > + Send
395 + Sync,
396{
397 fn notification<Params>(&self, method: &str, params: Params) -> impl Future<Output = Result<(), Error>> + Send
398 where
399 Params: ToRpcParams + Send,
400 {
401 async {
402 let _permit = match self.request_guard.as_ref() {
403 Some(permit) => permit.acquire().await.ok(),
404 None => None,
405 };
406 let params = params.to_rpc_params()?.map(StdCow::Owned);
407 let fut = self.service.notification(Notification::new(method.into(), params));
408
409 run_future_until_timeout(fut, self.request_timeout).await.map_err(|e| Error::Transport(e.into()))?;
410 Ok(())
411 }
412 }
413
414 fn request<R, Params>(&self, method: &str, params: Params) -> impl Future<Output = Result<R, Error>> + Send
415 where
416 R: DeserializeOwned,
417 Params: ToRpcParams + Send,
418 {
419 async {
420 let _permit = match self.request_guard.as_ref() {
421 Some(permit) => permit.acquire().await.ok(),
422 None => None,
423 };
424 let id = self.id_manager.next_request_id();
425 let params = params.to_rpc_params()?;
426
427 let method_response = run_future_until_timeout(
428 self.service.call(Request::borrowed(method, params.as_deref(), id.clone())),
429 self.request_timeout,
430 )
431 .await?
432 .into_response();
433
434 let rp = ResponseSuccess::try_from(method_response.into_inner())?;
435
436 let result = serde_json::from_str(rp.result.get()).map_err(Error::ParseError)?;
437 if rp.id == id { Ok(result) } else { Err(InvalidRequestId::NotPendingRequest(rp.id.to_string()).into()) }
438 }
439 }
440
441 fn batch_request<'a, R>(
442 &self,
443 batch: BatchRequestBuilder<'a>,
444 ) -> impl Future<Output = Result<jsonrpsee_core::client::BatchResponse<'a, R>, Error>> + Send
445 where
446 R: DeserializeOwned + fmt::Debug + 'a,
447 {
448 async {
449 let _permit = match self.request_guard.as_ref() {
450 Some(permit) => permit.acquire().await.ok(),
451 None => None,
452 };
453 let batch = batch.build()?;
454 let id = self.id_manager.next_request_id();
455 let id_range = generate_batch_id_range(id, batch.len() as u64)?;
456
457 let mut batch_request = Batch::with_capacity(batch.len());
458 for ((method, params), id) in batch.into_iter().zip(id_range.clone()) {
459 let id = self.id_manager.as_id_kind().into_id(id);
460 let req = Request {
461 jsonrpc: TwoPointZero,
462 method: method.into(),
463 params: params.map(StdCow::Owned),
464 id,
465 extensions: Extensions::new(),
466 };
467 batch_request.push(req);
468 }
469
470 let rps = run_future_until_timeout(self.service.batch(batch_request), self.request_timeout).await?;
471
472 let mut batch_response = Vec::new();
473 let mut success = 0;
474 let mut failed = 0;
475
476 for _ in 0..rps.len() {
478 batch_response.push(Err(ErrorObject::borrowed(0, "", None)));
479 }
480
481 for rp in rps.into_iter() {
482 let id = rp.id().try_parse_inner_as_number()?;
483
484 let res = match ResponseSuccess::try_from(rp.into_inner()) {
485 Ok(r) => {
486 let v = serde_json::from_str(r.result.get()).map_err(Error::ParseError)?;
487 success += 1;
488 Ok(v)
489 }
490 Err(err) => {
491 failed += 1;
492 Err(err)
493 }
494 };
495
496 let maybe_elem = id
497 .checked_sub(id_range.start)
498 .and_then(|p| p.try_into().ok())
499 .and_then(|p: usize| batch_response.get_mut(p));
500
501 if let Some(elem) = maybe_elem {
502 *elem = res;
503 } else {
504 return Err(InvalidRequestId::NotPendingRequest(id.to_string()).into());
505 }
506 }
507
508 Ok(BatchResponse::new(success, batch_response, failed))
509 }
510 }
511}
512
513impl<S> SubscriptionClientT for HttpClient<S>
514where
515 S: RpcServiceT<
516 MethodResponse = Result<MiddlewareMethodResponse, Error>,
517 BatchResponse = Result<MiddlewareBatchResponse, Error>,
518 NotificationResponse = Result<MiddlewareNotifResponse, Error>,
519 > + Send
520 + Sync,
521{
522 fn subscribe<'a, N, Params>(
525 &self,
526 _subscribe_method: &'a str,
527 _params: Params,
528 _unsubscribe_method: &'a str,
529 ) -> impl Future<Output = Result<Subscription<N>, Error>>
530 where
531 Params: ToRpcParams + Send,
532 N: DeserializeOwned,
533 {
534 async { Err(Error::HttpNotImplemented) }
535 }
536
537 fn subscribe_to_method<N>(&self, _method: &str) -> impl Future<Output = Result<Subscription<N>, Error>>
539 where
540 N: DeserializeOwned,
541 {
542 async { Err(Error::HttpNotImplemented) }
543 }
544}
545
546async fn run_future_until_timeout<F, T>(fut: F, timeout: Duration) -> Result<T, Error>
547where
548 F: std::future::Future<Output = Result<T, Error>>,
549{
550 match tokio::time::timeout(timeout, fut).await {
551 Ok(Ok(r)) => Ok(r),
552 Err(_) => Err(Error::RequestTimeout),
553 Ok(Err(e)) => Err(e),
554 }
555}