1use std::borrow::Cow as StdCow;
28use std::fmt;
29use std::sync::Arc;
30use std::time::Duration;
31
32use crate::rpc_service::RpcService;
33use crate::transport::{self, Error as TransportError, HttpBackend, HttpTransportClientBuilder};
34use crate::{HttpRequest, HttpResponse};
35use hyper::body::Bytes;
36use hyper::http::{Extensions, HeaderMap};
37use jsonrpsee_core::client::{
38 BatchResponse, ClientT, Error, IdKind, MiddlewareBatchResponse, MiddlewareMethodResponse, MiddlewareNotifResponse,
39 RequestIdManager, Subscription, SubscriptionClientT, generate_batch_id_range,
40};
41use jsonrpsee_core::middleware::layer::{RpcLogger, RpcLoggerLayer};
42use jsonrpsee_core::middleware::{Batch, RpcServiceBuilder, RpcServiceT};
43use jsonrpsee_core::params::BatchRequestBuilder;
44use jsonrpsee_core::traits::ToRpcParams;
45use jsonrpsee_core::{BoxError, TEN_MB_SIZE_BYTES};
46use jsonrpsee_types::{ErrorObject, InvalidRequestId, Notification, Request, ResponseSuccess, TwoPointZero};
47use serde::de::DeserializeOwned;
48use tokio::sync::Semaphore;
49use tower::layer::util::Identity;
50use tower::{Layer, Service};
51
52#[cfg(feature = "tls")]
53use crate::{CertificateStore, CustomCertStore};
54
55type Logger = tower::layer::util::Stack<RpcLoggerLayer, tower::layer::util::Identity>;
56
57#[derive(Clone, Debug)]
81pub struct HttpClientBuilder<HttpMiddleware = Identity, RpcMiddleware = Logger> {
82 max_request_size: u32,
83 max_response_size: u32,
84 request_timeout: Duration,
85 #[cfg(feature = "tls")]
86 certificate_store: CertificateStore,
87 id_kind: IdKind,
88 headers: HeaderMap,
89 service_builder: tower::ServiceBuilder<HttpMiddleware>,
90 rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
91 tcp_no_delay: bool,
92 max_concurrent_requests: Option<usize>,
93}
94
95impl<HttpMiddleware, RpcMiddleware> HttpClientBuilder<HttpMiddleware, RpcMiddleware> {
96 pub fn max_request_size(mut self, size: u32) -> Self {
98 self.max_request_size = size;
99 self
100 }
101
102 pub fn max_response_size(mut self, size: u32) -> Self {
104 self.max_response_size = size;
105 self
106 }
107
108 pub fn request_timeout(mut self, timeout: Duration) -> Self {
110 self.request_timeout = timeout;
111 self
112 }
113
114 pub fn max_concurrent_requests(mut self, max_concurrent_requests: usize) -> Self {
116 self.max_concurrent_requests = Some(max_concurrent_requests);
117 self
118 }
119
120 #[cfg(feature = "tls")]
186 pub fn with_custom_cert_store(mut self, cfg: CustomCertStore) -> Self {
187 self.certificate_store = CertificateStore::Custom(cfg);
188 self
189 }
190
191 pub fn id_format(mut self, id_kind: IdKind) -> Self {
193 self.id_kind = id_kind;
194 self
195 }
196
197 pub fn set_headers(mut self, headers: HeaderMap) -> Self {
201 self.headers = headers;
202 self
203 }
204
205 pub fn set_tcp_no_delay(mut self, no_delay: bool) -> Self {
209 self.tcp_no_delay = no_delay;
210 self
211 }
212
213 pub fn set_rpc_middleware<T>(self, rpc_builder: RpcServiceBuilder<T>) -> HttpClientBuilder<HttpMiddleware, T> {
215 HttpClientBuilder {
216 #[cfg(feature = "tls")]
217 certificate_store: self.certificate_store,
218 id_kind: self.id_kind,
219 headers: self.headers,
220 max_request_size: self.max_request_size,
221 max_response_size: self.max_response_size,
222 service_builder: self.service_builder,
223 rpc_middleware: rpc_builder,
224 request_timeout: self.request_timeout,
225 tcp_no_delay: self.tcp_no_delay,
226 max_concurrent_requests: self.max_concurrent_requests,
227 }
228 }
229
230 pub fn set_http_middleware<T>(
232 self,
233 service_builder: tower::ServiceBuilder<T>,
234 ) -> HttpClientBuilder<T, RpcMiddleware> {
235 HttpClientBuilder {
236 #[cfg(feature = "tls")]
237 certificate_store: self.certificate_store,
238 id_kind: self.id_kind,
239 headers: self.headers,
240 max_request_size: self.max_request_size,
241 max_response_size: self.max_response_size,
242 service_builder,
243 rpc_middleware: self.rpc_middleware,
244 request_timeout: self.request_timeout,
245 tcp_no_delay: self.tcp_no_delay,
246 max_concurrent_requests: self.max_concurrent_requests,
247 }
248 }
249}
250
251impl<B, S, S2, HttpMiddleware, RpcMiddleware> HttpClientBuilder<HttpMiddleware, RpcMiddleware>
252where
253 RpcMiddleware: Layer<RpcService<S>, Service = S2>,
254 <RpcMiddleware as Layer<RpcService<S>>>::Service: RpcServiceT,
255 HttpMiddleware: Layer<transport::HttpBackend, Service = S>,
256 S: Service<HttpRequest, Response = HttpResponse<B>, Error = TransportError> + Clone,
257 B: http_body::Body<Data = Bytes> + Send + Unpin + 'static,
258 B::Data: Send,
259 B::Error: Into<BoxError>,
260{
261 pub fn build(self, target: impl AsRef<str>) -> Result<HttpClient<S2>, Error> {
263 let Self {
264 max_request_size,
265 max_response_size,
266 request_timeout,
267 #[cfg(feature = "tls")]
268 certificate_store,
269 id_kind,
270 headers,
271 service_builder,
272 tcp_no_delay,
273 rpc_middleware,
274 ..
275 } = self;
276
277 let http = HttpTransportClientBuilder {
278 max_request_size,
279 max_response_size,
280 headers,
281 tcp_no_delay,
282 service_builder,
283 #[cfg(feature = "tls")]
284 certificate_store,
285 }
286 .build(target)
287 .map_err(|e| Error::Transport(e.into()))?;
288
289 let request_guard = self
290 .max_concurrent_requests
291 .map(|max_concurrent_requests| Arc::new(Semaphore::new(max_concurrent_requests)));
292
293 Ok(HttpClient {
294 service: rpc_middleware.service(RpcService::new(http)),
295 id_manager: Arc::new(RequestIdManager::new(id_kind)),
296 request_guard,
297 request_timeout,
298 })
299 }
300}
301
302impl Default for HttpClientBuilder {
303 fn default() -> Self {
304 Self {
305 max_request_size: TEN_MB_SIZE_BYTES,
306 max_response_size: TEN_MB_SIZE_BYTES,
307 request_timeout: Duration::from_secs(60),
308 #[cfg(feature = "tls")]
309 certificate_store: CertificateStore::Native,
310 id_kind: IdKind::Number,
311 headers: HeaderMap::new(),
312 service_builder: tower::ServiceBuilder::new(),
313 rpc_middleware: RpcServiceBuilder::default().rpc_logger(1024),
314 tcp_no_delay: true,
315 max_concurrent_requests: None,
316 }
317 }
318}
319
320impl HttpClientBuilder {
321 pub fn new() -> HttpClientBuilder<Identity, Logger> {
323 HttpClientBuilder::default()
324 }
325}
326
327#[derive(Debug, Clone)]
329pub struct HttpClient<S = RpcLogger<RpcService<HttpBackend>>> {
330 service: S,
332 id_manager: Arc<RequestIdManager>,
334 request_guard: Option<Arc<Semaphore>>,
336 request_timeout: Duration,
338}
339
340impl HttpClient<HttpBackend> {
341 pub fn builder() -> HttpClientBuilder {
343 HttpClientBuilder::new()
344 }
345
346 pub fn request_timeout(&self) -> Duration {
348 self.request_timeout
349 }
350}
351
352impl<S> ClientT for HttpClient<S>
353where
354 S: RpcServiceT<
355 MethodResponse = Result<MiddlewareMethodResponse, Error>,
356 BatchResponse = Result<MiddlewareBatchResponse, Error>,
357 NotificationResponse = Result<MiddlewareNotifResponse, Error>,
358 > + Send
359 + Sync,
360{
361 fn notification<Params>(&self, method: &str, params: Params) -> impl Future<Output = Result<(), Error>> + Send
362 where
363 Params: ToRpcParams + Send,
364 {
365 async {
366 let _permit = match self.request_guard.as_ref() {
367 Some(permit) => permit.acquire().await.ok(),
368 None => None,
369 };
370 let params = params.to_rpc_params()?.map(StdCow::Owned);
371 let fut = self.service.notification(Notification::new(method.into(), params));
372
373 run_future_until_timeout(fut, self.request_timeout).await.map_err(|e| Error::Transport(e.into()))?;
374 Ok(())
375 }
376 }
377
378 fn request<R, Params>(&self, method: &str, params: Params) -> impl Future<Output = Result<R, Error>> + Send
379 where
380 R: DeserializeOwned,
381 Params: ToRpcParams + Send,
382 {
383 async {
384 let _permit = match self.request_guard.as_ref() {
385 Some(permit) => permit.acquire().await.ok(),
386 None => None,
387 };
388 let id = self.id_manager.next_request_id();
389 let params = params.to_rpc_params()?;
390
391 let method_response = run_future_until_timeout(
392 self.service.call(Request::borrowed(method, params.as_deref(), id.clone())),
393 self.request_timeout,
394 )
395 .await?
396 .into_response();
397
398 let rp = ResponseSuccess::try_from(method_response.into_inner())?;
399
400 let result = serde_json::from_str(rp.result.get()).map_err(Error::ParseError)?;
401 if rp.id == id { Ok(result) } else { Err(InvalidRequestId::NotPendingRequest(rp.id.to_string()).into()) }
402 }
403 }
404
405 fn batch_request<'a, R>(
406 &self,
407 batch: BatchRequestBuilder<'a>,
408 ) -> impl Future<Output = Result<jsonrpsee_core::client::BatchResponse<'a, R>, Error>> + Send
409 where
410 R: DeserializeOwned + fmt::Debug + 'a,
411 {
412 async {
413 let _permit = match self.request_guard.as_ref() {
414 Some(permit) => permit.acquire().await.ok(),
415 None => None,
416 };
417 let batch = batch.build()?;
418 let id = self.id_manager.next_request_id();
419 let id_range = generate_batch_id_range(id, batch.len() as u64)?;
420
421 let mut batch_request = Batch::with_capacity(batch.len());
422 for ((method, params), id) in batch.into_iter().zip(id_range.clone()) {
423 let id = self.id_manager.as_id_kind().into_id(id);
424 let req = Request {
425 jsonrpc: TwoPointZero,
426 method: method.into(),
427 params: params.map(StdCow::Owned),
428 id,
429 extensions: Extensions::new(),
430 };
431 batch_request.push(req);
432 }
433
434 let rps = run_future_until_timeout(self.service.batch(batch_request), self.request_timeout).await?;
435
436 let mut batch_response = Vec::new();
437 let mut success = 0;
438 let mut failed = 0;
439
440 for _ in 0..rps.len() {
442 batch_response.push(Err(ErrorObject::borrowed(0, "", None)));
443 }
444
445 for rp in rps.into_iter() {
446 let id = rp.id().try_parse_inner_as_number()?;
447
448 let res = match ResponseSuccess::try_from(rp.into_inner()) {
449 Ok(r) => {
450 let v = serde_json::from_str(r.result.get()).map_err(Error::ParseError)?;
451 success += 1;
452 Ok(v)
453 }
454 Err(err) => {
455 failed += 1;
456 Err(err)
457 }
458 };
459
460 let maybe_elem = id
461 .checked_sub(id_range.start)
462 .and_then(|p| p.try_into().ok())
463 .and_then(|p: usize| batch_response.get_mut(p));
464
465 if let Some(elem) = maybe_elem {
466 *elem = res;
467 } else {
468 return Err(InvalidRequestId::NotPendingRequest(id.to_string()).into());
469 }
470 }
471
472 Ok(BatchResponse::new(success, batch_response, failed))
473 }
474 }
475}
476
477impl<S> SubscriptionClientT for HttpClient<S>
478where
479 S: RpcServiceT<
480 MethodResponse = Result<MiddlewareMethodResponse, Error>,
481 BatchResponse = Result<MiddlewareBatchResponse, Error>,
482 NotificationResponse = Result<MiddlewareNotifResponse, Error>,
483 > + Send
484 + Sync,
485{
486 fn subscribe<'a, N, Params>(
489 &self,
490 _subscribe_method: &'a str,
491 _params: Params,
492 _unsubscribe_method: &'a str,
493 ) -> impl Future<Output = Result<Subscription<N>, Error>>
494 where
495 Params: ToRpcParams + Send,
496 N: DeserializeOwned,
497 {
498 async { Err(Error::HttpNotImplemented) }
499 }
500
501 fn subscribe_to_method<N>(&self, _method: &str) -> impl Future<Output = Result<Subscription<N>, Error>>
503 where
504 N: DeserializeOwned,
505 {
506 async { Err(Error::HttpNotImplemented) }
507 }
508}
509
510async fn run_future_until_timeout<F, T>(fut: F, timeout: Duration) -> Result<T, Error>
511where
512 F: std::future::Future<Output = Result<T, Error>>,
513{
514 match tokio::time::timeout(timeout, fut).await {
515 Ok(Ok(r)) => Ok(r),
516 Err(_) => Err(Error::RequestTimeout),
517 Ok(Err(e)) => Err(e),
518 }
519}