1pub(crate) mod agent_config;
3pub mod agent_error;
4pub(crate) mod builder;
5#[doc(hidden)]
7#[deprecated(since = "0.38.0", note = "use the AgentBuilder methods")]
8pub mod http_transport;
9pub(crate) mod nonce;
10pub(crate) mod response_authentication;
11pub mod route_provider;
12pub mod status;
13pub mod subnet;
14
15pub use agent_config::AgentConfig;
16pub use agent_error::AgentError;
17use agent_error::{HttpErrorPayload, Operation};
18use async_lock::Semaphore;
19use async_trait::async_trait;
20pub use builder::AgentBuilder;
21use bytes::Bytes;
22use cached::{Cached, TimedCache};
23use http::{header::CONTENT_TYPE, HeaderMap, Method, StatusCode, Uri};
24use ic_ed25519::{PublicKey, SignatureError};
25#[doc(inline)]
26pub use ic_transport_types::{
27 signed, CallResponse, Envelope, EnvelopeContent, RejectCode, RejectResponse, ReplyResponse,
28 RequestStatusResponse,
29};
30pub use nonce::{NonceFactory, NonceGenerator};
31use rangemap::{RangeInclusiveMap, StepFns};
32use reqwest::{Client, Request, Response};
33use route_provider::{
34 dynamic_routing::{
35 dynamic_route_provider::DynamicRouteProviderBuilder, node::Node,
36 snapshot::latency_based_routing::LatencyRoutingSnapshot,
37 },
38 RouteProvider, UrlUntilReady,
39};
40pub use subnet::Subnet;
41use time::OffsetDateTime;
42use tower_service::Service;
43
44#[cfg(test)]
45mod agent_test;
46
47use crate::{
48 agent::response_authentication::{
49 extract_der, lookup_canister_info, lookup_canister_metadata, lookup_canister_ranges,
50 lookup_incomplete_subnet, lookup_request_status, lookup_subnet_and_ranges,
51 lookup_subnet_canister_ranges, lookup_subnet_metrics, lookup_time, lookup_tree,
52 lookup_value,
53 },
54 agent_error::TransportError,
55 export::Principal,
56 identity::Identity,
57 to_request_id, RequestId,
58};
59use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
60use backoff::{exponential::ExponentialBackoff, SystemClock};
61use ic_certification::{Certificate, Delegation, Label};
62use ic_transport_types::{
63 signed::{SignedQuery, SignedRequestStatus, SignedUpdate},
64 QueryResponse, ReadStateResponse, SubnetMetrics, TransportCallResponse,
65};
66use serde::Serialize;
67use status::Status;
68use std::{
69 borrow::Cow,
70 convert::TryFrom,
71 fmt::{self, Debug},
72 future::{Future, IntoFuture},
73 pin::Pin,
74 str::FromStr,
75 sync::{Arc, Mutex, RwLock},
76 task::{Context, Poll},
77 time::Duration,
78};
79
80use crate::agent::response_authentication::lookup_api_boundary_nodes;
81
/// Domain separator prepended to a certificate's root hash before the BLS signature over it
/// is verified (see `verify_cert` / `verify_cert_for_subnet`).
/// The leading byte `0x0D` (13) is the length of the ASCII tag `ic-state-root` that follows.
const IC_STATE_ROOT_DOMAIN_SEPARATOR: &[u8; 14] = b"\x0Dic-state-root";

/// DER-encoded root public key every new `Agent` starts with (`Agent::new` copies it into
/// `root_key`). `fetch_root_key` only asks the replica for a key while the agent still holds
/// exactly this value.
/// NOTE(review): presumably the IC mainnet root key — confirm before relying on that elsewhere.
const IC_ROOT_KEY: &[u8; 133] = b"\x30\x81\x82\x30\x1d\x06\x0d\x2b\x06\x01\x04\x01\x82\xdc\x7c\x05\x03\x01\x02\x01\x06\x0c\x2b\x06\x01\x04\x01\x82\xdc\x7c\x05\x03\x02\x01\x03\x61\x00\x81\x4c\x0e\x6e\xc7\x1f\xab\x58\x3b\x08\xbd\x81\x37\x3c\x25\x5c\x3c\x37\x1b\x2e\x84\x86\x3c\x98\xa4\xf1\xe0\x8b\x74\x23\x5d\x14\xfb\x5d\x9c\x0c\xd5\x46\xd9\x68\x5f\x91\x3a\x0c\x0b\x2c\xc5\x34\x15\x83\xbf\x4b\x43\x92\xe4\x67\xdb\x96\xd6\x5b\x9b\xb4\xcb\x71\x71\x12\xf8\x47\x2e\x0d\x5a\x4d\x14\x50\x5f\xfd\x74\x84\xb0\x12\x91\x09\x1c\x5f\x87\xb9\x88\x83\x46\x3f\x98\x09\x1a\x0b\xaa\xae";

/// Boxed, pinned future resolving to `Result<V, AgentError>`.
/// On non-wasm targets the future must also be `Send`.
#[cfg(not(target_family = "wasm"))]
type AgentFuture<'a, V> = Pin<Box<dyn Future<Output = Result<V, AgentError>> + Send + 'a>>;

/// Boxed, pinned future resolving to `Result<V, AgentError>`.
/// On wasm targets the `Send` bound is dropped.
#[cfg(target_family = "wasm")]
type AgentFuture<'a, V> = Pin<Box<dyn Future<Output = Result<V, AgentError>> + 'a>>;
91
92#[cfg_attr(unix, doc = " ```rust")] #[cfg_attr(not(unix), doc = " ```ignore")]
// Client for the IC HTTP interface (query / read_state / call / status endpoints).
// Every field is either an `Arc` or a small `Copy` value, so cloning an agent is cheap and
// clones share the root key, subnet cache and concurrency semaphore.
#[derive(Clone)]
pub struct Agent {
    /// Source of nonces for update calls, and for queries when a nonce is requested.
    nonce_factory: Arc<dyn NonceGenerator>,
    /// Identity that provides the sender principal and signs every envelope.
    identity: Arc<dyn Identity>,
    /// Offset from "now" used for request expiry; also the freshness tolerance when checking
    /// certificate timestamps and query-signature timestamps.
    ingress_expiry: Duration,
    /// DER-encoded root key used to verify undelegated certificates; shared between clones
    /// and replaceable through `set_root_key` / `fetch_root_key`.
    root_key: Arc<RwLock<Vec<u8>>>,
    /// HTTP transport used to execute requests.
    client: Arc<dyn HttpService>,
    /// Decides which URL each request is sent to.
    route_provider: Arc<dyn RouteProvider>,
    /// Cache of subnet information looked up by canister id or subnet id.
    subnet_key_cache: Arc<Mutex<SubnetCache>>,
    /// Limits how many query / read_state / call requests are in flight at once.
    concurrent_requests_semaphore: Arc<Semaphore>,
    /// Default for whether node signatures on query responses are verified; a per-call
    /// override can replace it.
    verify_query_signatures: bool,
    /// Optional cap on response body size.
    /// NOTE(review): not read in the visible code — presumably enforced in `execute`; confirm.
    max_response_body_size: Option<usize>,
    /// Total time budget for polling a request's status (see `get_retry_policy`).
    max_polling_time: Duration,
    // Only stored from the config and never read here, hence the `dead_code` allowance.
    // NOTE(review): confirm whether TCP retry logic is still meant to consume this.
    #[allow(dead_code)]
    max_tcp_error_retries: usize,
}
172
173impl fmt::Debug for Agent {
174 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
175 f.debug_struct("Agent")
176 .field("ingress_expiry", &self.ingress_expiry)
177 .finish_non_exhaustive()
178 }
179}
180
181impl Agent {
182 pub fn builder() -> builder::AgentBuilder {
185 Default::default()
186 }
187
188 pub fn new(config: agent_config::AgentConfig) -> Result<Agent, AgentError> {
190 let client = config.http_service.unwrap_or_else(|| {
191 Arc::new(Retry429Logic {
192 client: config.client.unwrap_or_else(|| {
193 #[cfg(not(target_family = "wasm"))]
194 {
195 Client::builder()
196 .use_rustls_tls()
197 .timeout(Duration::from_secs(360))
198 .build()
199 .expect("Could not create HTTP client.")
200 }
201 #[cfg(all(target_family = "wasm", feature = "wasm-bindgen"))]
202 {
203 Client::new()
204 }
205 }),
206 })
207 });
208 Ok(Agent {
209 nonce_factory: config.nonce_factory,
210 identity: config.identity,
211 ingress_expiry: config.ingress_expiry,
212 root_key: Arc::new(RwLock::new(IC_ROOT_KEY.to_vec())),
213 client: client.clone(),
214 route_provider: if let Some(route_provider) = config.route_provider {
215 route_provider
216 } else if let Some(url) = config.url {
217 if config.background_dynamic_routing {
218 assert!(
219 url.scheme() == "https" && url.path() == "/" && url.port().is_none() && url.domain().is_some(),
220 "in dynamic routing mode, URL must be in the exact form https://domain with no path, port, IP, or non-HTTPS scheme"
221 );
222 let seeds = vec![Node::new(url.domain().unwrap()).unwrap()];
223 UrlUntilReady::new(url, async move {
224 DynamicRouteProviderBuilder::new(
225 LatencyRoutingSnapshot::new(),
226 seeds,
227 client,
228 )
229 .build()
230 .await
231 }) as Arc<dyn RouteProvider>
232 } else {
233 Arc::new(url)
234 }
235 } else {
236 panic!("either route_provider or url must be specified");
237 },
238 subnet_key_cache: Arc::new(Mutex::new(SubnetCache::new())),
239 verify_query_signatures: config.verify_query_signatures,
240 concurrent_requests_semaphore: Arc::new(Semaphore::new(config.max_concurrent_requests)),
241 max_response_body_size: config.max_response_body_size,
242 max_tcp_error_retries: config.max_tcp_error_retries,
243 max_polling_time: config.max_polling_time,
244 })
245 }
246
247 pub fn set_identity<I>(&mut self, identity: I)
253 where
254 I: 'static + Identity,
255 {
256 self.identity = Arc::new(identity);
257 }
258 pub fn set_arc_identity(&mut self, identity: Arc<dyn Identity>) {
264 self.identity = identity;
265 }
266
267 pub async fn fetch_root_key(&self) -> Result<(), AgentError> {
276 if self.read_root_key()[..] != IC_ROOT_KEY[..] {
277 return Ok(());
279 }
280 let status = self.status().await?;
281 let Some(root_key) = status.root_key else {
282 return Err(AgentError::NoRootKeyInStatus(status));
283 };
284 self.set_root_key(root_key);
285 Ok(())
286 }
287
288 pub fn set_root_key(&self, root_key: Vec<u8>) {
293 *self.root_key.write().unwrap() = root_key;
294 }
295
296 pub fn read_root_key(&self) -> Vec<u8> {
298 self.root_key.read().unwrap().clone()
299 }
300
301 fn get_expiry_date(&self) -> u64 {
302 let expiry_raw = OffsetDateTime::now_utc() + self.ingress_expiry;
303 let mut rounded = expiry_raw.replace_nanosecond(0).unwrap();
304 if self.ingress_expiry.as_secs() > 90 {
305 rounded = rounded.replace_second(0).unwrap();
306 }
307 rounded.unix_timestamp_nanos().try_into().unwrap()
308 }
309
310 pub fn get_principal(&self) -> Result<Principal, String> {
312 self.identity.sender()
313 }
314
315 async fn query_endpoint<A>(
316 &self,
317 effective_canister_id: Principal,
318 serialized_bytes: Vec<u8>,
319 ) -> Result<A, AgentError>
320 where
321 A: serde::de::DeserializeOwned,
322 {
323 let _permit = self.concurrent_requests_semaphore.acquire().await;
324 let bytes = self
325 .execute(
326 Method::POST,
327 &format!("api/v3/canister/{}/query", effective_canister_id.to_text()),
328 Some(serialized_bytes),
329 )
330 .await?
331 .1;
332 serde_cbor::from_slice(&bytes).map_err(AgentError::InvalidCborData)
333 }
334
335 async fn read_state_endpoint<A>(
336 &self,
337 effective_canister_id: Principal,
338 serialized_bytes: Vec<u8>,
339 ) -> Result<A, AgentError>
340 where
341 A: serde::de::DeserializeOwned,
342 {
343 let _permit = self.concurrent_requests_semaphore.acquire().await;
344 let endpoint = format!(
345 "api/v3/canister/{}/read_state",
346 effective_canister_id.to_text()
347 );
348 let bytes = self
349 .execute(Method::POST, &endpoint, Some(serialized_bytes))
350 .await?
351 .1;
352 serde_cbor::from_slice(&bytes).map_err(AgentError::InvalidCborData)
353 }
354
355 async fn read_subnet_state_endpoint<A>(
356 &self,
357 subnet_id: Principal,
358 serialized_bytes: Vec<u8>,
359 ) -> Result<A, AgentError>
360 where
361 A: serde::de::DeserializeOwned,
362 {
363 let _permit = self.concurrent_requests_semaphore.acquire().await;
364 let endpoint = format!("api/v3/subnet/{}/read_state", subnet_id.to_text());
365 let bytes = self
366 .execute(Method::POST, &endpoint, Some(serialized_bytes))
367 .await?
368 .1;
369 serde_cbor::from_slice(&bytes).map_err(AgentError::InvalidCborData)
370 }
371
372 async fn call_endpoint(
373 &self,
374 effective_canister_id: Principal,
375 serialized_bytes: Vec<u8>,
376 ) -> Result<TransportCallResponse, AgentError> {
377 let _permit = self.concurrent_requests_semaphore.acquire().await;
378 let endpoint = format!("api/v4/canister/{}/call", effective_canister_id.to_text());
379 let (status_code, response_body) = self
380 .execute(Method::POST, &endpoint, Some(serialized_bytes))
381 .await?;
382
383 if status_code == StatusCode::ACCEPTED {
384 return Ok(TransportCallResponse::Accepted);
385 }
386
387 serde_cbor::from_slice(&response_body).map_err(AgentError::InvalidCborData)
388 }
389
390 #[allow(clippy::too_many_arguments)]
393 async fn query_raw(
394 &self,
395 canister_id: Principal,
396 effective_canister_id: Principal,
397 method_name: String,
398 arg: Vec<u8>,
399 ingress_expiry_datetime: Option<u64>,
400 use_nonce: bool,
401 explicit_verify_query_signatures: Option<bool>,
402 ) -> Result<Vec<u8>, AgentError> {
403 let operation = Operation::Call {
404 canister: canister_id,
405 method: method_name.clone(),
406 };
407 let content = self.query_content(
408 canister_id,
409 method_name,
410 arg,
411 ingress_expiry_datetime,
412 use_nonce,
413 )?;
414 let serialized_bytes = sign_envelope(&content, self.identity.clone())?;
415 self.query_inner(
416 effective_canister_id,
417 serialized_bytes,
418 content.to_request_id(),
419 explicit_verify_query_signatures,
420 operation,
421 )
422 .await
423 }
424
425 pub async fn query_signed(
429 &self,
430 effective_canister_id: Principal,
431 signed_query: Vec<u8>,
432 ) -> Result<Vec<u8>, AgentError> {
433 let envelope: Envelope =
434 serde_cbor::from_slice(&signed_query).map_err(AgentError::InvalidCborData)?;
435 let EnvelopeContent::Query {
436 canister_id,
437 method_name,
438 ..
439 } = &*envelope.content
440 else {
441 return Err(AgentError::CallDataMismatch {
442 field: "request_type".to_string(),
443 value_arg: "query".to_string(),
444 value_cbor: if matches!(*envelope.content, EnvelopeContent::Call { .. }) {
445 "update"
446 } else {
447 "read_state"
448 }
449 .to_string(),
450 });
451 };
452 let operation = Operation::Call {
453 canister: *canister_id,
454 method: method_name.clone(),
455 };
456 self.query_inner(
457 effective_canister_id,
458 signed_query,
459 envelope.content.to_request_id(),
460 None,
461 operation,
462 )
463 .await
464 }
465
466 async fn query_inner(
470 &self,
471 effective_canister_id: Principal,
472 signed_query: Vec<u8>,
473 request_id: RequestId,
474 explicit_verify_query_signatures: Option<bool>,
475 operation: Operation,
476 ) -> Result<Vec<u8>, AgentError> {
477 let response = if explicit_verify_query_signatures.unwrap_or(self.verify_query_signatures) {
478 let (response, mut subnet) = futures_util::try_join!(
479 self.query_endpoint::<QueryResponse>(effective_canister_id, signed_query),
480 self.get_subnet_by_canister(&effective_canister_id)
481 )?;
482 if response.signatures().is_empty() {
483 return Err(AgentError::MissingSignature);
484 } else if response.signatures().len() > subnet.node_keys.len() {
485 return Err(AgentError::TooManySignatures {
486 had: response.signatures().len(),
487 needed: subnet.node_keys.len(),
488 });
489 }
490 for signature in response.signatures() {
491 if OffsetDateTime::now_utc()
492 - OffsetDateTime::from_unix_timestamp_nanos(signature.timestamp.into()).unwrap()
493 > self.ingress_expiry
494 {
495 return Err(AgentError::CertificateOutdated(self.ingress_expiry));
496 }
497 let signable = response.signable(request_id, signature.timestamp);
498 let node_key = if let Some(node_key) = subnet.node_keys.get(&signature.identity) {
499 node_key
500 } else {
501 subnet = self
502 .fetch_subnet_by_canister(&effective_canister_id)
503 .await?;
504 subnet
505 .node_keys
506 .get(&signature.identity)
507 .ok_or(AgentError::CertificateNotAuthorized())?
508 };
509 if node_key.len() != 44 {
510 return Err(AgentError::DerKeyLengthMismatch {
511 expected: 44,
512 actual: node_key.len(),
513 });
514 }
515 const DER_PREFIX: [u8; 12] = [48, 42, 48, 5, 6, 3, 43, 101, 112, 3, 33, 0];
516 if node_key[..12] != DER_PREFIX {
517 return Err(AgentError::DerPrefixMismatch {
518 expected: DER_PREFIX.to_vec(),
519 actual: node_key[..12].to_vec(),
520 });
521 }
522 let pubkey = PublicKey::deserialize_raw(&node_key[12..])
523 .map_err(|_| AgentError::MalformedPublicKey)?;
524
525 match pubkey.verify_signature(&signable, &signature.signature[..]) {
526 Ok(()) => (),
527 Err(SignatureError::InvalidSignature) => {
528 return Err(AgentError::QuerySignatureVerificationFailed)
529 }
530 Err(SignatureError::InvalidLength) => {
531 return Err(AgentError::MalformedSignature)
532 }
533 _ => unreachable!(),
534 }
535 }
536 response
537 } else {
538 self.query_endpoint::<QueryResponse>(effective_canister_id, signed_query)
539 .await?
540 };
541
542 match response {
543 QueryResponse::Replied { reply, .. } => Ok(reply.arg),
544 QueryResponse::Rejected { reject, .. } => Err(AgentError::UncertifiedReject {
545 reject,
546 operation: Some(operation),
547 }),
548 }
549 }
550
551 fn query_content(
552 &self,
553 canister_id: Principal,
554 method_name: String,
555 arg: Vec<u8>,
556 ingress_expiry_datetime: Option<u64>,
557 use_nonce: bool,
558 ) -> Result<EnvelopeContent, AgentError> {
559 Ok(EnvelopeContent::Query {
560 sender: self.identity.sender().map_err(AgentError::SigningError)?,
561 canister_id,
562 method_name,
563 arg,
564 ingress_expiry: ingress_expiry_datetime.unwrap_or_else(|| self.get_expiry_date()),
565 nonce: use_nonce.then(|| self.nonce_factory.generate()).flatten(),
566 })
567 }
568
569 async fn update_raw(
571 &self,
572 canister_id: Principal,
573 effective_canister_id: Principal,
574 method_name: String,
575 arg: Vec<u8>,
576 ingress_expiry_datetime: Option<u64>,
577 ) -> Result<CallResponse<(Vec<u8>, Certificate)>, AgentError> {
578 let nonce = self.nonce_factory.generate();
579 let content = self.update_content(
580 canister_id,
581 method_name.clone(),
582 arg,
583 ingress_expiry_datetime,
584 nonce,
585 )?;
586 let operation = Some(Operation::Call {
587 canister: canister_id,
588 method: method_name,
589 });
590 let request_id = to_request_id(&content)?;
591 let serialized_bytes = sign_envelope(&content, self.identity.clone())?;
592
593 let response_body = self
594 .call_endpoint(effective_canister_id, serialized_bytes)
595 .await?;
596
597 match response_body {
598 TransportCallResponse::Replied { certificate } => {
599 let certificate =
600 serde_cbor::from_slice(&certificate).map_err(AgentError::InvalidCborData)?;
601
602 self.verify(&certificate, effective_canister_id)?;
603 let status = lookup_request_status(&certificate, &request_id)?;
604
605 match status {
606 RequestStatusResponse::Replied(reply) => {
607 Ok(CallResponse::Response((reply.arg, certificate)))
608 }
609 RequestStatusResponse::Rejected(reject_response) => {
610 Err(AgentError::CertifiedReject {
611 reject: reject_response,
612 operation,
613 })?
614 }
615 _ => Ok(CallResponse::Poll(request_id)),
616 }
617 }
618 TransportCallResponse::Accepted => Ok(CallResponse::Poll(request_id)),
619 TransportCallResponse::NonReplicatedRejection(reject_response) => {
620 Err(AgentError::UncertifiedReject {
621 reject: reject_response,
622 operation,
623 })
624 }
625 }
626 }
627
628 pub async fn update_signed(
632 &self,
633 effective_canister_id: Principal,
634 signed_update: Vec<u8>,
635 ) -> Result<CallResponse<Vec<u8>>, AgentError> {
636 let envelope: Envelope =
637 serde_cbor::from_slice(&signed_update).map_err(AgentError::InvalidCborData)?;
638 let EnvelopeContent::Call {
639 canister_id,
640 method_name,
641 ..
642 } = &*envelope.content
643 else {
644 return Err(AgentError::CallDataMismatch {
645 field: "request_type".to_string(),
646 value_arg: "update".to_string(),
647 value_cbor: if matches!(*envelope.content, EnvelopeContent::Query { .. }) {
648 "query"
649 } else {
650 "read_state"
651 }
652 .to_string(),
653 });
654 };
655 let operation = Some(Operation::Call {
656 canister: *canister_id,
657 method: method_name.clone(),
658 });
659 let request_id = to_request_id(&envelope.content)?;
660
661 let response_body = self
662 .call_endpoint(effective_canister_id, signed_update)
663 .await?;
664
665 match response_body {
666 TransportCallResponse::Replied { certificate } => {
667 let certificate =
668 serde_cbor::from_slice(&certificate).map_err(AgentError::InvalidCborData)?;
669
670 self.verify(&certificate, effective_canister_id)?;
671 let status = lookup_request_status(&certificate, &request_id)?;
672
673 match status {
674 RequestStatusResponse::Replied(reply) => Ok(CallResponse::Response(reply.arg)),
675 RequestStatusResponse::Rejected(reject_response) => {
676 Err(AgentError::CertifiedReject {
677 reject: reject_response,
678 operation,
679 })?
680 }
681 _ => Ok(CallResponse::Poll(request_id)),
682 }
683 }
684 TransportCallResponse::Accepted => Ok(CallResponse::Poll(request_id)),
685 TransportCallResponse::NonReplicatedRejection(reject_response) => {
686 Err(AgentError::UncertifiedReject {
687 reject: reject_response,
688 operation,
689 })
690 }
691 }
692 }
693
694 fn update_content(
695 &self,
696 canister_id: Principal,
697 method_name: String,
698 arg: Vec<u8>,
699 ingress_expiry_datetime: Option<u64>,
700 nonce: Option<Vec<u8>>,
701 ) -> Result<EnvelopeContent, AgentError> {
702 Ok(EnvelopeContent::Call {
703 canister_id,
704 method_name,
705 arg,
706 nonce,
707 sender: self.identity.sender().map_err(AgentError::SigningError)?,
708 ingress_expiry: ingress_expiry_datetime.unwrap_or_else(|| self.get_expiry_date()),
709 })
710 }
711
712 fn get_retry_policy(&self) -> ExponentialBackoff<SystemClock> {
713 ExponentialBackoffBuilder::new()
714 .with_initial_interval(Duration::from_millis(500))
715 .with_max_interval(Duration::from_secs(1))
716 .with_multiplier(1.4)
717 .with_max_elapsed_time(Some(self.max_polling_time))
718 .build()
719 }
720
721 pub async fn wait_signed(
723 &self,
724 request_id: &RequestId,
725 effective_canister_id: Principal,
726 signed_request_status: Vec<u8>,
727 ) -> Result<(Vec<u8>, Certificate), AgentError> {
728 let mut retry_policy = self.get_retry_policy();
729
730 let mut request_accepted = false;
731 loop {
732 let (resp, cert) = self
733 .request_status_signed(
734 request_id,
735 effective_canister_id,
736 signed_request_status.clone(),
737 )
738 .await?;
739 match resp {
740 RequestStatusResponse::Unknown => {}
741
742 RequestStatusResponse::Received | RequestStatusResponse::Processing => {
743 if !request_accepted {
744 retry_policy.reset();
745 request_accepted = true;
746 }
747 }
748
749 RequestStatusResponse::Replied(ReplyResponse { arg, .. }) => {
750 return Ok((arg, cert))
751 }
752
753 RequestStatusResponse::Rejected(response) => {
754 return Err(AgentError::CertifiedReject {
755 reject: response,
756 operation: None,
757 })
758 }
759
760 RequestStatusResponse::Done => {
761 return Err(AgentError::RequestStatusDoneNoReply(String::from(
762 *request_id,
763 )))
764 }
765 };
766
767 match retry_policy.next_backoff() {
768 Some(duration) => crate::util::sleep(duration).await,
769
770 None => return Err(AgentError::TimeoutWaitingForResponse()),
771 }
772 }
773 }
774
775 pub async fn wait(
777 &self,
778 request_id: &RequestId,
779 effective_canister_id: Principal,
780 ) -> Result<(Vec<u8>, Certificate), AgentError> {
781 self.wait_inner(request_id, effective_canister_id, None)
782 .await
783 }
784
785 async fn wait_inner(
786 &self,
787 request_id: &RequestId,
788 effective_canister_id: Principal,
789 operation: Option<Operation>,
790 ) -> Result<(Vec<u8>, Certificate), AgentError> {
791 let mut retry_policy = self.get_retry_policy();
792
793 let mut request_accepted = false;
794 loop {
795 let (resp, cert) = self
796 .request_status_raw(request_id, effective_canister_id)
797 .await?;
798 match resp {
799 RequestStatusResponse::Unknown => {}
800
801 RequestStatusResponse::Received | RequestStatusResponse::Processing => {
802 if !request_accepted {
803 retry_policy.reset();
811 request_accepted = true;
812 }
813 }
814
815 RequestStatusResponse::Replied(ReplyResponse { arg, .. }) => {
816 return Ok((arg, cert))
817 }
818
819 RequestStatusResponse::Rejected(response) => {
820 return Err(AgentError::CertifiedReject {
821 reject: response,
822 operation,
823 })
824 }
825
826 RequestStatusResponse::Done => {
827 return Err(AgentError::RequestStatusDoneNoReply(String::from(
828 *request_id,
829 )))
830 }
831 };
832
833 match retry_policy.next_backoff() {
834 Some(duration) => crate::util::sleep(duration).await,
835
836 None => return Err(AgentError::TimeoutWaitingForResponse()),
837 }
838 }
839 }
840
841 pub async fn read_state_raw(
844 &self,
845 paths: Vec<Vec<Label>>,
846 effective_canister_id: Principal,
847 ) -> Result<Certificate, AgentError> {
848 let content = self.read_state_content(paths)?;
849 let serialized_bytes = sign_envelope(&content, self.identity.clone())?;
850
851 let read_state_response: ReadStateResponse = self
852 .read_state_endpoint(effective_canister_id, serialized_bytes)
853 .await?;
854 let cert: Certificate = serde_cbor::from_slice(&read_state_response.certificate)
855 .map_err(AgentError::InvalidCborData)?;
856 self.verify(&cert, effective_canister_id)?;
857 Ok(cert)
858 }
859
860 pub async fn read_subnet_state_raw(
863 &self,
864 paths: Vec<Vec<Label>>,
865 subnet_id: Principal,
866 ) -> Result<Certificate, AgentError> {
867 let content = self.read_state_content(paths)?;
868 let serialized_bytes = sign_envelope(&content, self.identity.clone())?;
869
870 let read_state_response: ReadStateResponse = self
871 .read_subnet_state_endpoint(subnet_id, serialized_bytes)
872 .await?;
873 let cert: Certificate = serde_cbor::from_slice(&read_state_response.certificate)
874 .map_err(AgentError::InvalidCborData)?;
875 self.verify_for_subnet(&cert, subnet_id)?;
876 Ok(cert)
877 }
878
879 fn read_state_content(&self, paths: Vec<Vec<Label>>) -> Result<EnvelopeContent, AgentError> {
880 Ok(EnvelopeContent::ReadState {
881 sender: self.identity.sender().map_err(AgentError::SigningError)?,
882 paths,
883 ingress_expiry: self.get_expiry_date(),
884 })
885 }
886
887 pub fn verify(
890 &self,
891 cert: &Certificate,
892 effective_canister_id: Principal,
893 ) -> Result<(), AgentError> {
894 self.verify_cert(cert, effective_canister_id)?;
895 self.verify_cert_timestamp(cert)?;
896 Ok(())
897 }
898
899 fn verify_cert(
900 &self,
901 cert: &Certificate,
902 effective_canister_id: Principal,
903 ) -> Result<(), AgentError> {
904 let sig = &cert.signature;
905
906 let root_hash = cert.tree.digest();
907 let mut msg = vec![];
908 msg.extend_from_slice(IC_STATE_ROOT_DOMAIN_SEPARATOR);
909 msg.extend_from_slice(&root_hash);
910
911 let der_key = self.check_delegation(&cert.delegation, effective_canister_id)?;
912 let key = extract_der(der_key)?;
913
914 ic_verify_bls_signature::verify_bls_signature(sig, &msg, &key)
915 .map_err(|_| AgentError::CertificateVerificationFailed())?;
916 Ok(())
917 }
918
919 pub fn verify_for_subnet(
922 &self,
923 cert: &Certificate,
924 subnet_id: Principal,
925 ) -> Result<(), AgentError> {
926 self.verify_cert_for_subnet(cert, subnet_id)?;
927 self.verify_cert_timestamp(cert)?;
928 Ok(())
929 }
930
931 fn verify_cert_for_subnet(
932 &self,
933 cert: &Certificate,
934 subnet_id: Principal,
935 ) -> Result<(), AgentError> {
936 let sig = &cert.signature;
937
938 let root_hash = cert.tree.digest();
939 let mut msg = vec![];
940 msg.extend_from_slice(IC_STATE_ROOT_DOMAIN_SEPARATOR);
941 msg.extend_from_slice(&root_hash);
942
943 let der_key = self.check_delegation_for_subnet(&cert.delegation, subnet_id)?;
944 let key = extract_der(der_key)?;
945
946 ic_verify_bls_signature::verify_bls_signature(sig, &msg, &key)
947 .map_err(|_| AgentError::CertificateVerificationFailed())?;
948 Ok(())
949 }
950
951 fn verify_cert_timestamp(&self, cert: &Certificate) -> Result<(), AgentError> {
952 let time = lookup_time(cert)?;
953 if (OffsetDateTime::now_utc()
954 - OffsetDateTime::from_unix_timestamp_nanos(time.into()).unwrap())
955 .abs()
956 > self.ingress_expiry
957 {
958 Err(AgentError::CertificateOutdated(self.ingress_expiry))
959 } else {
960 Ok(())
961 }
962 }
963
    /// Resolves the DER-encoded public key that must have signed a certificate for
    /// `effective_canister_id`.
    ///
    /// - Without a delegation, returns the agent's current root key.
    /// - With a delegation, the embedded delegation certificate is decoded and must pass these
    ///   checks: it carries no further delegation; it verifies via `verify_cert`, which, with no
    ///   nested delegation, means against the root key; and one of its
    ///   `canister_ranges/<subnet_id>` shards contains a range covering `effective_canister_id`
    ///   (per `principal_is_within_ranges`). The key returned is
    ///   `subnet/<subnet_id>/public_key` from that certificate.
    ///
    /// # Errors
    ///
    /// - `InvalidCborData` for undecodable certificates or range data;
    /// - `CertificateHasTooManyDelegations` for nested delegations;
    /// - `CertificateNotAuthorized` if no shard covers the canister;
    /// - plus any lookup or verification error from the helpers.
    fn check_delegation(
        &self,
        delegation: &Option<Delegation>,
        effective_canister_id: Principal,
    ) -> Result<Vec<u8>, AgentError> {
        match delegation {
            None => Ok(self.read_root_key()),
            Some(delegation) => {
                let cert: Certificate = serde_cbor::from_slice(&delegation.certificate)
                    .map_err(AgentError::InvalidCborData)?;
                // Only a single level of delegation is accepted.
                if cert.delegation.is_some() {
                    return Err(AgentError::CertificateHasTooManyDelegations);
                }
                // `cert` has no delegation, so this checks it against the root key.
                self.verify_cert(&cert, effective_canister_id)?;
                // The delegated subnet's canister ranges live under
                // `canister_ranges/<subnet_id>/<shard label>`.
                let canister_range_shards_lookup =
                    ["canister_ranges".as_bytes(), delegation.subnet_id.as_ref()];
                let canister_range_shards = lookup_tree(&cert.tree, canister_range_shards_lookup)?;
                // Keep only the last label of each path under that subtree: the shard label.
                let mut shard_paths = canister_range_shards
                    .list_paths()
                    .into_iter()
                    .map(|mut x| {
                        x.pop()
                            .ok_or_else(AgentError::CertificateVerificationFailed)
                    })
                    .collect::<Result<Vec<_>, _>>()?;
                if shard_paths.is_empty() {
                    return Err(AgentError::CertificateNotAuthorized());
                }
                // Binary-search for the last shard label <= the canister id's bytes.
                // NOTE(review): presumably each shard label is the lowest canister id in its
                // shard, and `Label`'s ordering matches the byte-wise comparison used in
                // `partition_point` below — confirm, otherwise the search is inconsistent.
                shard_paths.sort_unstable();
                let shard_division = shard_paths
                    .partition_point(|shard| shard.as_bytes() <= effective_canister_id.as_slice());
                // Every shard label is greater than the canister id: no shard can cover it.
                if shard_division == 0 {
                    return Err(AgentError::CertificateNotAuthorized());
                }
                let max_potential_shard = &shard_paths[shard_division - 1];
                let canister_range_lookup = [max_potential_shard.as_bytes()];
                let canister_range = lookup_value(&canister_range_shards, canister_range_lookup)?;
                let ranges: Vec<(Principal, Principal)> =
                    serde_cbor::from_slice(canister_range).map_err(AgentError::InvalidCborData)?;
                if !principal_is_within_ranges(&effective_canister_id, &ranges[..]) {
                    return Err(AgentError::CertificateNotAuthorized());
                }

                // Public key of the delegated subnet, as certified by the delegation certificate.
                let public_key_path = [
                    "subnet".as_bytes(),
                    delegation.subnet_id.as_ref(),
                    "public_key".as_bytes(),
                ];
                lookup_value(&cert.tree, public_key_path).map(<[u8]>::to_vec)
            }
        }
    }
1018
1019 fn check_delegation_for_subnet(
1020 &self,
1021 delegation: &Option<Delegation>,
1022 subnet_id: Principal,
1023 ) -> Result<Vec<u8>, AgentError> {
1024 match delegation {
1025 None => Ok(self.read_root_key()),
1026 Some(delegation) => {
1027 let cert: Certificate = serde_cbor::from_slice(&delegation.certificate)
1028 .map_err(AgentError::InvalidCborData)?;
1029 if cert.delegation.is_some() {
1030 return Err(AgentError::CertificateHasTooManyDelegations);
1031 }
1032 self.verify_cert_for_subnet(&cert, subnet_id)?;
1033 let public_key_path = [
1034 "subnet".as_bytes(),
1035 subnet_id.as_ref(),
1036 "public_key".as_bytes(),
1037 ];
1038 let pk = lookup_value(&cert.tree, public_key_path)
1039 .map_err(|_| AgentError::CertificateNotAuthorized())?
1040 .to_vec();
1041 Ok(pk)
1042 }
1043 }
1044 }
1045
1046 pub async fn read_state_canister_info(
1049 &self,
1050 canister_id: Principal,
1051 path: &str,
1052 ) -> Result<Vec<u8>, AgentError> {
1053 let paths: Vec<Vec<Label>> = vec![vec![
1054 "canister".into(),
1055 Label::from_bytes(canister_id.as_slice()),
1056 path.into(),
1057 ]];
1058
1059 let cert = self.read_state_raw(paths, canister_id).await?;
1060
1061 lookup_canister_info(cert, canister_id, path)
1062 }
1063
1064 pub async fn read_state_canister_controllers(
1066 &self,
1067 canister_id: Principal,
1068 ) -> Result<Vec<Principal>, AgentError> {
1069 let blob = self
1070 .read_state_canister_info(canister_id, "controllers")
1071 .await?;
1072 let controllers: Vec<Principal> =
1073 serde_cbor::from_slice(&blob).map_err(AgentError::InvalidCborData)?;
1074 Ok(controllers)
1075 }
1076
1077 pub async fn read_state_canister_module_hash(
1079 &self,
1080 canister_id: Principal,
1081 ) -> Result<Vec<u8>, AgentError> {
1082 self.read_state_canister_info(canister_id, "module_hash")
1083 .await
1084 }
1085
1086 pub async fn read_state_canister_metadata(
1088 &self,
1089 canister_id: Principal,
1090 path: &str,
1091 ) -> Result<Vec<u8>, AgentError> {
1092 let paths: Vec<Vec<Label>> = vec![vec![
1093 "canister".into(),
1094 Label::from_bytes(canister_id.as_slice()),
1095 "metadata".into(),
1096 path.into(),
1097 ]];
1098
1099 let cert = self.read_state_raw(paths, canister_id).await?;
1100
1101 lookup_canister_metadata(cert, canister_id, path)
1102 }
1103
1104 pub async fn read_state_subnet_metrics(
1106 &self,
1107 subnet_id: Principal,
1108 ) -> Result<SubnetMetrics, AgentError> {
1109 let paths = vec![vec![
1110 "subnet".into(),
1111 Label::from_bytes(subnet_id.as_slice()),
1112 "metrics".into(),
1113 ]];
1114 let cert = self.read_subnet_state_raw(paths, subnet_id).await?;
1115 lookup_subnet_metrics(cert, subnet_id)
1116 }
1117
1118 pub async fn read_state_subnet_canister_ranges(
1120 &self,
1121 subnet_id: Principal,
1122 ) -> Result<Vec<(Principal, Principal)>, AgentError> {
1123 let paths = vec![vec![
1124 "subnet".into(),
1125 Label::from_bytes(subnet_id.as_slice()),
1126 "canister_ranges".into(),
1127 ]];
1128 let cert = self.read_subnet_state_raw(paths, subnet_id).await?;
1129 lookup_subnet_canister_ranges(&cert, subnet_id)
1130 }
1131
1132 pub async fn request_status_raw(
1134 &self,
1135 request_id: &RequestId,
1136 effective_canister_id: Principal,
1137 ) -> Result<(RequestStatusResponse, Certificate), AgentError> {
1138 let paths: Vec<Vec<Label>> =
1139 vec![vec!["request_status".into(), request_id.to_vec().into()]];
1140
1141 let cert = self.read_state_raw(paths, effective_canister_id).await?;
1142
1143 Ok((lookup_request_status(&cert, request_id)?, cert))
1144 }
1145
1146 pub async fn request_status_signed(
1150 &self,
1151 request_id: &RequestId,
1152 effective_canister_id: Principal,
1153 signed_request_status: Vec<u8>,
1154 ) -> Result<(RequestStatusResponse, Certificate), AgentError> {
1155 let _envelope: Envelope =
1156 serde_cbor::from_slice(&signed_request_status).map_err(AgentError::InvalidCborData)?;
1157 let read_state_response: ReadStateResponse = self
1158 .read_state_endpoint(effective_canister_id, signed_request_status)
1159 .await?;
1160
1161 let cert: Certificate = serde_cbor::from_slice(&read_state_response.certificate)
1162 .map_err(AgentError::InvalidCborData)?;
1163 self.verify(&cert, effective_canister_id)?;
1164 Ok((lookup_request_status(&cert, request_id)?, cert))
1165 }
1166
1167 pub fn update<S: Into<String>>(
1170 &self,
1171 canister_id: &Principal,
1172 method_name: S,
1173 ) -> UpdateBuilder<'_> {
1174 UpdateBuilder::new(self, *canister_id, method_name.into())
1175 }
1176
1177 pub async fn status(&self) -> Result<Status, AgentError> {
1179 let endpoint = "api/v2/status";
1180 let bytes = self.execute(Method::GET, endpoint, None).await?.1;
1181
1182 let cbor: serde_cbor::Value =
1183 serde_cbor::from_slice(&bytes).map_err(AgentError::InvalidCborData)?;
1184
1185 Status::try_from(&cbor).map_err(|_| AgentError::InvalidReplicaStatus)
1186 }
1187
1188 pub fn query<S: Into<String>>(
1191 &self,
1192 canister_id: &Principal,
1193 method_name: S,
1194 ) -> QueryBuilder<'_> {
1195 QueryBuilder::new(self, *canister_id, method_name.into())
1196 }
1197
1198 pub fn sign_request_status(
1201 &self,
1202 effective_canister_id: Principal,
1203 request_id: RequestId,
1204 ) -> Result<SignedRequestStatus, AgentError> {
1205 let paths: Vec<Vec<Label>> =
1206 vec![vec!["request_status".into(), request_id.to_vec().into()]];
1207 let read_state_content = self.read_state_content(paths)?;
1208 let signed_request_status = sign_envelope(&read_state_content, self.identity.clone())?;
1209 let ingress_expiry = read_state_content.ingress_expiry();
1210 let sender = *read_state_content.sender();
1211 Ok(SignedRequestStatus {
1212 ingress_expiry,
1213 sender,
1214 effective_canister_id,
1215 request_id,
1216 signed_request_status,
1217 })
1218 }
1219
1220 pub async fn get_subnet_by_canister(
1223 &self,
1224 canister: &Principal,
1225 ) -> Result<Arc<Subnet>, AgentError> {
1226 let subnet = self
1227 .subnet_key_cache
1228 .lock()
1229 .unwrap()
1230 .get_subnet_by_canister(canister);
1231 if let Some(subnet) = subnet {
1232 Ok(subnet)
1233 } else {
1234 self.fetch_subnet_by_canister(canister).await
1235 }
1236 }
1237
1238 pub async fn get_subnet_by_id(&self, subnet_id: &Principal) -> Result<Arc<Subnet>, AgentError> {
1241 let subnet = self
1242 .subnet_key_cache
1243 .lock()
1244 .unwrap()
1245 .get_subnet_by_id(subnet_id);
1246 if let Some(subnet) = subnet {
1247 Ok(subnet)
1248 } else {
1249 self.fetch_subnet_by_id(subnet_id).await
1250 }
1251 }
1252
1253 pub async fn fetch_api_boundary_nodes_by_canister_id(
1255 &self,
1256 canister_id: Principal,
1257 ) -> Result<Vec<ApiBoundaryNode>, AgentError> {
1258 let paths = vec![vec!["api_boundary_nodes".into()]];
1259 let certificate = self.read_state_raw(paths, canister_id).await?;
1260 let api_boundary_nodes = lookup_api_boundary_nodes(certificate)?;
1261 Ok(api_boundary_nodes)
1262 }
1263
1264 pub async fn fetch_api_boundary_nodes_by_subnet_id(
1266 &self,
1267 subnet_id: Principal,
1268 ) -> Result<Vec<ApiBoundaryNode>, AgentError> {
1269 let paths = vec![vec!["api_boundary_nodes".into()]];
1270 let certificate = self.read_subnet_state_raw(paths, subnet_id).await?;
1271 let api_boundary_nodes = lookup_api_boundary_nodes(certificate)?;
1272 Ok(api_boundary_nodes)
1273 }
1274
1275 pub async fn fetch_subnet_by_canister(
1280 &self,
1281 canister: &Principal,
1282 ) -> Result<Arc<Subnet>, AgentError> {
1283 let canister_cert = self
1284 .read_state_raw(vec![vec!["subnet".into()]], *canister)
1285 .await?;
1286 let subnet_id = if let Some(delegation) = canister_cert.delegation.as_ref() {
1287 Principal::from_slice(&delegation.subnet_id)
1288 } else {
1289 Principal::self_authenticating(&self.root_key.read().unwrap()[..])
1291 };
1292 let mut subnet = lookup_incomplete_subnet(&subnet_id, &canister_cert)?;
1293 let canister_ranges = if let Some(delegation) = canister_cert.delegation.as_ref() {
1294 let delegation_cert: Certificate = serde_cbor::from_slice(&delegation.certificate)?;
1296 lookup_canister_ranges(&subnet_id, &delegation_cert)?
1297 } else {
1298 lookup_canister_ranges(&subnet_id, &canister_cert)?
1299 };
1300 subnet.canister_ranges = canister_ranges;
1301 if !subnet.canister_ranges.contains(canister) {
1302 return Err(AgentError::CertificateNotAuthorized());
1303 }
1304 let subnet = Arc::new(subnet);
1305 self.subnet_key_cache
1306 .lock()
1307 .unwrap()
1308 .insert_subnet(subnet_id, subnet.clone());
1309 Ok(subnet)
1310 }
1311
1312 pub async fn fetch_subnet_by_id(
1317 &self,
1318 subnet_id: &Principal,
1319 ) -> Result<Arc<Subnet>, AgentError> {
1320 let subnet_cert = self
1321 .read_subnet_state_raw(
1322 vec![
1323 vec!["canister_ranges".into(), subnet_id.as_slice().into()],
1324 vec!["subnet".into(), subnet_id.as_slice().into()],
1325 ],
1326 *subnet_id,
1327 )
1328 .await?;
1329 let subnet = lookup_subnet_and_ranges(subnet_id, &subnet_cert)?;
1330 let subnet = Arc::new(subnet);
1331 self.subnet_key_cache
1332 .lock()
1333 .unwrap()
1334 .insert_subnet(*subnet_id, subnet.clone());
1335 Ok(subnet)
1336 }
1337
1338 async fn request(
1339 &self,
1340 method: Method,
1341 endpoint: &str,
1342 body: Option<Vec<u8>>,
1343 ) -> Result<(StatusCode, HeaderMap, Vec<u8>), AgentError> {
1344 let body = body.map(Bytes::from);
1345
1346 let create_request_with_generated_url = || -> Result<http::Request<Bytes>, AgentError> {
1347 let url = self.route_provider.route()?.join(endpoint)?;
1348 let uri = Uri::from_str(url.as_str())
1349 .map_err(|e| AgentError::InvalidReplicaUrl(e.to_string()))?;
1350 let body = body.clone().unwrap_or_default();
1351 let request = http::Request::builder()
1352 .method(method.clone())
1353 .uri(uri)
1354 .header(CONTENT_TYPE, "application/cbor")
1355 .body(body)
1356 .map_err(|e| {
1357 AgentError::TransportError(TransportError::Generic(format!(
1358 "unable to create request: {e:#}"
1359 )))
1360 })?;
1361
1362 Ok(request)
1363 };
1364
1365 let response = self
1366 .client
1367 .call(
1368 &create_request_with_generated_url,
1369 self.max_tcp_error_retries,
1370 self.max_response_body_size,
1371 )
1372 .await?;
1373
1374 let (parts, body) = response.into_parts();
1375
1376 Ok((parts.status, parts.headers, body.to_vec()))
1377 }
1378
1379 async fn execute(
1380 &self,
1381 method: Method,
1382 endpoint: &str,
1383 body: Option<Vec<u8>>,
1384 ) -> Result<(StatusCode, Vec<u8>), AgentError> {
1385 let request_result = self.request(method.clone(), endpoint, body.clone()).await?;
1386
1387 let status = request_result.0;
1388 let headers = request_result.1;
1389 let body = request_result.2;
1390
1391 if status.is_client_error() || status.is_server_error() {
1392 Err(AgentError::HttpError(HttpErrorPayload {
1393 status: status.into(),
1394 content_type: headers
1395 .get(CONTENT_TYPE)
1396 .and_then(|value| value.to_str().ok())
1397 .map(str::to_string),
1398 content: body,
1399 }))
1400 } else if !(status == StatusCode::OK || status == StatusCode::ACCEPTED) {
1401 Err(AgentError::InvalidHttpResponse(format!(
1402 "Expected `200`, `202`, 4xx`, or `5xx` HTTP status code. Got: {status}",
1403 )))
1404 } else {
1405 Ok((status, body))
1406 }
1407 }
1408}
1409
1410fn principal_is_within_ranges(principal: &Principal, ranges: &[(Principal, Principal)]) -> bool {
1413 ranges
1414 .iter()
1415 .any(|r| principal >= &r.0 && principal <= &r.1)
1416}
1417
1418fn sign_envelope(
1419 content: &EnvelopeContent,
1420 identity: Arc<dyn Identity>,
1421) -> Result<Vec<u8>, AgentError> {
1422 let signature = identity.sign(content).map_err(AgentError::SigningError)?;
1423
1424 let envelope = Envelope {
1425 content: Cow::Borrowed(content),
1426 sender_pubkey: signature.public_key,
1427 sender_sig: signature.signature,
1428 sender_delegation: signature.delegations,
1429 };
1430
1431 let mut serialized_bytes = Vec::new();
1432 let mut serializer = serde_cbor::Serializer::new(&mut serialized_bytes);
1433 serializer.self_describe()?;
1434 envelope.serialize(&mut serializer)?;
1435
1436 Ok(serialized_bytes)
1437}
1438
1439pub fn signed_query_inspect(
1442 sender: Principal,
1443 canister_id: Principal,
1444 method_name: &str,
1445 arg: &[u8],
1446 ingress_expiry: u64,
1447 signed_query: Vec<u8>,
1448) -> Result<(), AgentError> {
1449 let envelope: Envelope =
1450 serde_cbor::from_slice(&signed_query).map_err(AgentError::InvalidCborData)?;
1451 match envelope.content.as_ref() {
1452 EnvelopeContent::Query {
1453 ingress_expiry: ingress_expiry_cbor,
1454 sender: sender_cbor,
1455 canister_id: canister_id_cbor,
1456 method_name: method_name_cbor,
1457 arg: arg_cbor,
1458 nonce: _nonce,
1459 } => {
1460 if ingress_expiry != *ingress_expiry_cbor {
1461 return Err(AgentError::CallDataMismatch {
1462 field: "ingress_expiry".to_string(),
1463 value_arg: ingress_expiry.to_string(),
1464 value_cbor: ingress_expiry_cbor.to_string(),
1465 });
1466 }
1467 if sender != *sender_cbor {
1468 return Err(AgentError::CallDataMismatch {
1469 field: "sender".to_string(),
1470 value_arg: sender.to_string(),
1471 value_cbor: sender_cbor.to_string(),
1472 });
1473 }
1474 if canister_id != *canister_id_cbor {
1475 return Err(AgentError::CallDataMismatch {
1476 field: "canister_id".to_string(),
1477 value_arg: canister_id.to_string(),
1478 value_cbor: canister_id_cbor.to_string(),
1479 });
1480 }
1481 if method_name != *method_name_cbor {
1482 return Err(AgentError::CallDataMismatch {
1483 field: "method_name".to_string(),
1484 value_arg: method_name.to_string(),
1485 value_cbor: method_name_cbor.clone(),
1486 });
1487 }
1488 if arg != *arg_cbor {
1489 return Err(AgentError::CallDataMismatch {
1490 field: "arg".to_string(),
1491 value_arg: format!("{arg:?}"),
1492 value_cbor: format!("{arg_cbor:?}"),
1493 });
1494 }
1495 }
1496 EnvelopeContent::Call { .. } => {
1497 return Err(AgentError::CallDataMismatch {
1498 field: "request_type".to_string(),
1499 value_arg: "query".to_string(),
1500 value_cbor: "call".to_string(),
1501 })
1502 }
1503 EnvelopeContent::ReadState { .. } => {
1504 return Err(AgentError::CallDataMismatch {
1505 field: "request_type".to_string(),
1506 value_arg: "query".to_string(),
1507 value_cbor: "read_state".to_string(),
1508 })
1509 }
1510 }
1511 Ok(())
1512}
1513
1514pub fn signed_update_inspect(
1517 sender: Principal,
1518 canister_id: Principal,
1519 method_name: &str,
1520 arg: &[u8],
1521 ingress_expiry: u64,
1522 signed_update: Vec<u8>,
1523) -> Result<(), AgentError> {
1524 let envelope: Envelope =
1525 serde_cbor::from_slice(&signed_update).map_err(AgentError::InvalidCborData)?;
1526 match envelope.content.as_ref() {
1527 EnvelopeContent::Call {
1528 nonce: _nonce,
1529 ingress_expiry: ingress_expiry_cbor,
1530 sender: sender_cbor,
1531 canister_id: canister_id_cbor,
1532 method_name: method_name_cbor,
1533 arg: arg_cbor,
1534 } => {
1535 if ingress_expiry != *ingress_expiry_cbor {
1536 return Err(AgentError::CallDataMismatch {
1537 field: "ingress_expiry".to_string(),
1538 value_arg: ingress_expiry.to_string(),
1539 value_cbor: ingress_expiry_cbor.to_string(),
1540 });
1541 }
1542 if sender != *sender_cbor {
1543 return Err(AgentError::CallDataMismatch {
1544 field: "sender".to_string(),
1545 value_arg: sender.to_string(),
1546 value_cbor: sender_cbor.to_string(),
1547 });
1548 }
1549 if canister_id != *canister_id_cbor {
1550 return Err(AgentError::CallDataMismatch {
1551 field: "canister_id".to_string(),
1552 value_arg: canister_id.to_string(),
1553 value_cbor: canister_id_cbor.to_string(),
1554 });
1555 }
1556 if method_name != *method_name_cbor {
1557 return Err(AgentError::CallDataMismatch {
1558 field: "method_name".to_string(),
1559 value_arg: method_name.to_string(),
1560 value_cbor: method_name_cbor.clone(),
1561 });
1562 }
1563 if arg != *arg_cbor {
1564 return Err(AgentError::CallDataMismatch {
1565 field: "arg".to_string(),
1566 value_arg: format!("{arg:?}"),
1567 value_cbor: format!("{arg_cbor:?}"),
1568 });
1569 }
1570 }
1571 EnvelopeContent::ReadState { .. } => {
1572 return Err(AgentError::CallDataMismatch {
1573 field: "request_type".to_string(),
1574 value_arg: "call".to_string(),
1575 value_cbor: "read_state".to_string(),
1576 })
1577 }
1578 EnvelopeContent::Query { .. } => {
1579 return Err(AgentError::CallDataMismatch {
1580 field: "request_type".to_string(),
1581 value_arg: "call".to_string(),
1582 value_cbor: "query".to_string(),
1583 })
1584 }
1585 }
1586 Ok(())
1587}
1588
1589pub fn signed_request_status_inspect(
1592 sender: Principal,
1593 request_id: &RequestId,
1594 ingress_expiry: u64,
1595 signed_request_status: Vec<u8>,
1596) -> Result<(), AgentError> {
1597 let paths: Vec<Vec<Label>> = vec![vec!["request_status".into(), request_id.to_vec().into()]];
1598 let envelope: Envelope =
1599 serde_cbor::from_slice(&signed_request_status).map_err(AgentError::InvalidCborData)?;
1600 match envelope.content.as_ref() {
1601 EnvelopeContent::ReadState {
1602 ingress_expiry: ingress_expiry_cbor,
1603 sender: sender_cbor,
1604 paths: paths_cbor,
1605 } => {
1606 if ingress_expiry != *ingress_expiry_cbor {
1607 return Err(AgentError::CallDataMismatch {
1608 field: "ingress_expiry".to_string(),
1609 value_arg: ingress_expiry.to_string(),
1610 value_cbor: ingress_expiry_cbor.to_string(),
1611 });
1612 }
1613 if sender != *sender_cbor {
1614 return Err(AgentError::CallDataMismatch {
1615 field: "sender".to_string(),
1616 value_arg: sender.to_string(),
1617 value_cbor: sender_cbor.to_string(),
1618 });
1619 }
1620
1621 if paths != *paths_cbor {
1622 return Err(AgentError::CallDataMismatch {
1623 field: "paths".to_string(),
1624 value_arg: format!("{paths:?}"),
1625 value_cbor: format!("{paths_cbor:?}"),
1626 });
1627 }
1628 }
1629 EnvelopeContent::Query { .. } => {
1630 return Err(AgentError::CallDataMismatch {
1631 field: "request_type".to_string(),
1632 value_arg: "read_state".to_string(),
1633 value_cbor: "query".to_string(),
1634 })
1635 }
1636 EnvelopeContent::Call { .. } => {
1637 return Err(AgentError::CallDataMismatch {
1638 field: "request_type".to_string(),
1639 value_arg: "read_state".to_string(),
1640 value_cbor: "call".to_string(),
1641 })
1642 }
1643 }
1644 Ok(())
1645}
1646
1647#[derive(Clone)]
1648struct SubnetCache {
1649 subnets: TimedCache<Principal, Arc<Subnet>>,
1650 canister_index: RangeInclusiveMap<Principal, Principal, PrincipalStep>,
1651}
1652
1653impl SubnetCache {
1654 fn new() -> Self {
1655 Self {
1656 subnets: TimedCache::with_lifespan(Duration::from_secs(300)),
1657 canister_index: RangeInclusiveMap::new_with_step_fns(),
1658 }
1659 }
1660
1661 fn get_subnet_by_canister(&mut self, canister: &Principal) -> Option<Arc<Subnet>> {
1662 self.canister_index
1663 .get(canister)
1664 .and_then(|subnet_id| self.subnets.cache_get(subnet_id).cloned())
1665 .filter(|subnet| subnet.canister_ranges.contains(canister))
1666 }
1667
1668 fn get_subnet_by_id(&mut self, subnet_id: &Principal) -> Option<Arc<Subnet>> {
1669 self.subnets.cache_get(subnet_id).cloned()
1670 }
1671
1672 fn insert_subnet(&mut self, subnet_id: Principal, subnet: Arc<Subnet>) {
1673 self.subnets.cache_set(subnet_id, subnet.clone());
1674 for range in subnet.canister_ranges.iter() {
1675 self.canister_index.insert(range.clone(), subnet_id);
1676 }
1677 }
1678}
1679
1680#[derive(Clone, Copy)]
1681pub(crate) struct PrincipalStep;
1682
1683impl StepFns<Principal> for PrincipalStep {
1684 fn add_one(start: &Principal) -> Principal {
1685 let bytes = start.as_slice();
1686 let mut arr = [0; 29];
1687 arr[..bytes.len()].copy_from_slice(bytes);
1688 for byte in arr[..bytes.len() - 1].iter_mut().rev() {
1689 *byte = byte.wrapping_add(1);
1690 if *byte != 0 {
1691 break;
1692 }
1693 }
1694 Principal::from_slice(&arr[..bytes.len()])
1695 }
1696 fn sub_one(start: &Principal) -> Principal {
1697 let bytes = start.as_slice();
1698 let mut arr = [0; 29];
1699 arr[..bytes.len()].copy_from_slice(bytes);
1700 for byte in arr[..bytes.len() - 1].iter_mut().rev() {
1701 *byte = byte.wrapping_sub(1);
1702 if *byte != 255 {
1703 break;
1704 }
1705 }
1706 Principal::from_slice(&arr[..bytes.len()])
1707 }
1708}
1709
/// An API boundary node entry, as returned by
/// [`Agent::fetch_api_boundary_nodes_by_canister_id`] and
/// [`Agent::fetch_api_boundary_nodes_by_subnet_id`].
///
/// Those methods read the certified `api_boundary_nodes` state.
#[derive(Debug, Clone)]
pub struct ApiBoundaryNode {
    /// The node's domain (presumably its hostname; confirm against the lookup).
    pub domain: String,
    /// The node's IPv6 address, as a string.
    pub ipv6_address: String,
    /// The node's IPv4 address as a string, if one is present.
    pub ipv4_address: Option<String>,
}
1720
/// Builder for a query call, created by [`Agent::query`].
///
/// To use it, do one of the following:
/// - send it with [`QueryBuilder::call`], or simply `.await` it;
/// - pick a verification mode with
///   [`call_with_verification`](QueryBuilder::call_with_verification) or
///   [`call_without_verification`](QueryBuilder::call_without_verification);
/// - produce a signed request with [`sign`](QueryBuilder::sign).
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct QueryBuilder<'agent> {
    agent: &'agent Agent,
    /// Effective canister id passed to the agent; `new` initialises it to `canister_id`.
    pub effective_canister_id: Principal,
    /// Canister whose method is queried.
    pub canister_id: Principal,
    /// Name of the method to query.
    pub method_name: String,
    /// Argument bytes sent with the query.
    pub arg: Vec<u8>,
    /// Ingress expiry in nanoseconds since the Unix epoch, when set explicitly.
    pub ingress_expiry_datetime: Option<u64>,
    /// Whether nonce generation was requested via `with_nonce_generation`.
    pub use_nonce: bool,
}
1741
1742impl<'agent> QueryBuilder<'agent> {
1743 pub fn new(agent: &'agent Agent, canister_id: Principal, method_name: String) -> Self {
1745 Self {
1746 agent,
1747 effective_canister_id: canister_id,
1748 canister_id,
1749 method_name,
1750 arg: vec![],
1751 ingress_expiry_datetime: None,
1752 use_nonce: false,
1753 }
1754 }
1755
1756 pub fn with_effective_canister_id(mut self, canister_id: Principal) -> Self {
1758 self.effective_canister_id = canister_id;
1759 self
1760 }
1761
1762 pub fn with_arg<A: Into<Vec<u8>>>(mut self, arg: A) -> Self {
1764 self.arg = arg.into();
1765 self
1766 }
1767
1768 pub fn expire_at(mut self, time: impl Into<OffsetDateTime>) -> Self {
1770 self.ingress_expiry_datetime = Some(time.into().unix_timestamp_nanos() as u64);
1771 self
1772 }
1773
1774 pub fn expire_after(mut self, duration: Duration) -> Self {
1776 self.ingress_expiry_datetime = Some(
1777 OffsetDateTime::now_utc()
1778 .saturating_add(duration.try_into().expect("negative duration"))
1779 .unix_timestamp_nanos() as u64,
1780 );
1781 self
1782 }
1783
1784 pub fn with_nonce_generation(mut self) -> Self {
1787 self.use_nonce = true;
1788 self
1789 }
1790
1791 pub async fn call(self) -> Result<Vec<u8>, AgentError> {
1793 self.agent
1794 .query_raw(
1795 self.canister_id,
1796 self.effective_canister_id,
1797 self.method_name,
1798 self.arg,
1799 self.ingress_expiry_datetime,
1800 self.use_nonce,
1801 None,
1802 )
1803 .await
1804 }
1805
1806 pub async fn call_with_verification(self) -> Result<Vec<u8>, AgentError> {
1811 self.agent
1812 .query_raw(
1813 self.canister_id,
1814 self.effective_canister_id,
1815 self.method_name,
1816 self.arg,
1817 self.ingress_expiry_datetime,
1818 self.use_nonce,
1819 Some(true),
1820 )
1821 .await
1822 }
1823
1824 pub async fn call_without_verification(self) -> Result<Vec<u8>, AgentError> {
1829 self.agent
1830 .query_raw(
1831 self.canister_id,
1832 self.effective_canister_id,
1833 self.method_name,
1834 self.arg,
1835 self.ingress_expiry_datetime,
1836 self.use_nonce,
1837 Some(false),
1838 )
1839 .await
1840 }
1841
1842 pub fn sign(self) -> Result<SignedQuery, AgentError> {
1845 let effective_canister_id = self.effective_canister_id;
1846 let identity = self.agent.identity.clone();
1847 let content = self.into_envelope()?;
1848 let signed_query = sign_envelope(&content, identity)?;
1849 let EnvelopeContent::Query {
1850 ingress_expiry,
1851 sender,
1852 canister_id,
1853 method_name,
1854 arg,
1855 nonce,
1856 } = content
1857 else {
1858 unreachable!()
1859 };
1860 Ok(SignedQuery {
1861 ingress_expiry,
1862 sender,
1863 canister_id,
1864 method_name,
1865 arg,
1866 effective_canister_id,
1867 signed_query,
1868 nonce,
1869 })
1870 }
1871
1872 pub fn into_envelope(self) -> Result<EnvelopeContent, AgentError> {
1874 self.agent.query_content(
1875 self.canister_id,
1876 self.method_name,
1877 self.arg,
1878 self.ingress_expiry_datetime,
1879 self.use_nonce,
1880 )
1881 }
1882}
1883
1884impl<'agent> IntoFuture for QueryBuilder<'agent> {
1885 type IntoFuture = AgentFuture<'agent, Vec<u8>>;
1886 type Output = Result<Vec<u8>, AgentError>;
1887 fn into_future(self) -> Self::IntoFuture {
1888 Box::pin(self.call())
1889 }
1890}
1891
/// A submitted-but-not-yet-awaited update call, returned by
/// [`UpdateBuilder::call`].
///
/// Awaiting it yields the submission's [`CallResponse`].
/// [`UpdateCall::and_wait`] additionally waits for the final reply when the
/// response is `CallResponse::Poll`.
pub struct UpdateCall<'agent> {
    agent: &'agent Agent,
    // Future performing the submission (built from `Agent::update_raw`).
    response_future: AgentFuture<'agent, CallResponse<(Vec<u8>, Certificate)>>,
    effective_canister_id: Principal,
    // `canister_id` and `method_name` are kept so `and_wait` can pass an
    // `Operation::Call` to `wait_inner`.
    canister_id: Principal,
    method_name: String,
}
1900
1901impl fmt::Debug for UpdateCall<'_> {
1902 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1903 f.debug_struct("UpdateCall")
1904 .field("agent", &self.agent)
1905 .field("effective_canister_id", &self.effective_canister_id)
1906 .finish_non_exhaustive()
1907 }
1908}
1909
1910impl Future for UpdateCall<'_> {
1911 type Output = Result<CallResponse<(Vec<u8>, Certificate)>, AgentError>;
1912 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1913 self.response_future.as_mut().poll(cx)
1914 }
1915}
1916
1917impl<'a> UpdateCall<'a> {
1918 pub async fn and_wait(self) -> Result<(Vec<u8>, Certificate), AgentError> {
1920 let response = self.response_future.await?;
1921
1922 match response {
1923 CallResponse::Response(response) => Ok(response),
1924 CallResponse::Poll(request_id) => {
1925 self.agent
1926 .wait_inner(
1927 &request_id,
1928 self.effective_canister_id,
1929 Some(Operation::Call {
1930 canister: self.canister_id,
1931 method: self.method_name,
1932 }),
1933 )
1934 .await
1935 }
1936 }
1937 }
1938}
/// Builder for an update call, created by [`Agent::update`].
///
/// To use it, do one of the following:
/// - send it and wait for the reply with
///   [`call_and_wait`](UpdateBuilder::call_and_wait), or `.await` it;
/// - submit it with [`call`](UpdateBuilder::call);
/// - produce a signed request with [`sign`](UpdateBuilder::sign).
#[derive(Debug)]
pub struct UpdateBuilder<'agent> {
    agent: &'agent Agent,
    /// Effective canister id passed to the agent; `new` initialises it to `canister_id`.
    pub effective_canister_id: Principal,
    /// Canister whose method is called.
    pub canister_id: Principal,
    /// Name of the method to call.
    pub method_name: String,
    /// Argument bytes sent with the call.
    pub arg: Vec<u8>,
    /// Ingress expiry in nanoseconds since the Unix epoch, when set explicitly.
    pub ingress_expiry_datetime: Option<u64>,
}
1957
1958impl<'agent> UpdateBuilder<'agent> {
1959 pub fn new(agent: &'agent Agent, canister_id: Principal, method_name: String) -> Self {
1961 Self {
1962 agent,
1963 effective_canister_id: canister_id,
1964 canister_id,
1965 method_name,
1966 arg: vec![],
1967 ingress_expiry_datetime: None,
1968 }
1969 }
1970
1971 pub fn with_effective_canister_id(mut self, canister_id: Principal) -> Self {
1973 self.effective_canister_id = canister_id;
1974 self
1975 }
1976
1977 pub fn with_arg<A: Into<Vec<u8>>>(mut self, arg: A) -> Self {
1979 self.arg = arg.into();
1980 self
1981 }
1982
1983 pub fn expire_at(mut self, time: impl Into<OffsetDateTime>) -> Self {
1985 self.ingress_expiry_datetime = Some(time.into().unix_timestamp_nanos() as u64);
1986 self
1987 }
1988
1989 pub fn expire_after(mut self, duration: Duration) -> Self {
1991 self.ingress_expiry_datetime = Some(
1992 OffsetDateTime::now_utc()
1993 .saturating_add(duration.try_into().expect("negative duration"))
1994 .unix_timestamp_nanos() as u64,
1995 );
1996 self
1997 }
1998
1999 pub async fn call_and_wait(self) -> Result<Vec<u8>, AgentError> {
2002 self.call().and_wait().await.map(|x| x.0)
2003 }
2004
2005 pub fn call(self) -> UpdateCall<'agent> {
2008 let method_name = self.method_name.clone();
2009 let response_future = async move {
2010 self.agent
2011 .update_raw(
2012 self.canister_id,
2013 self.effective_canister_id,
2014 self.method_name,
2015 self.arg,
2016 self.ingress_expiry_datetime,
2017 )
2018 .await
2019 };
2020 UpdateCall {
2021 agent: self.agent,
2022 response_future: Box::pin(response_future),
2023 effective_canister_id: self.effective_canister_id,
2024 canister_id: self.canister_id,
2025 method_name,
2026 }
2027 }
2028
2029 pub fn sign(self) -> Result<SignedUpdate, AgentError> {
2032 let identity = self.agent.identity.clone();
2033 let effective_canister_id = self.effective_canister_id;
2034 let content = self.into_envelope()?;
2035 let signed_update = sign_envelope(&content, identity)?;
2036 let request_id = to_request_id(&content)?;
2037 let EnvelopeContent::Call {
2038 nonce,
2039 ingress_expiry,
2040 sender,
2041 canister_id,
2042 method_name,
2043 arg,
2044 } = content
2045 else {
2046 unreachable!()
2047 };
2048 Ok(SignedUpdate {
2049 nonce,
2050 ingress_expiry,
2051 sender,
2052 canister_id,
2053 method_name,
2054 arg,
2055 effective_canister_id,
2056 signed_update,
2057 request_id,
2058 })
2059 }
2060
2061 pub fn into_envelope(self) -> Result<EnvelopeContent, AgentError> {
2063 let nonce = self.agent.nonce_factory.generate();
2064 self.agent.update_content(
2065 self.canister_id,
2066 self.method_name,
2067 self.arg,
2068 self.ingress_expiry_datetime,
2069 nonce,
2070 )
2071 }
2072}
2073
2074impl<'agent> IntoFuture for UpdateBuilder<'agent> {
2075 type IntoFuture = AgentFuture<'agent, Vec<u8>>;
2076 type Output = Result<Vec<u8>, AgentError>;
2077 fn into_future(self) -> Self::IntoFuture {
2078 Box::pin(self.call_and_wait())
2079 }
2080}
2081
/// HTTP transport used by the agent.
///
/// A blanket implementation exists for any `T` where `&T` is a `tower_service`
/// [`Service`] over `reqwest` requests. This covers, e.g., `reqwest::Client`.
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
pub trait HttpService: Send + Sync + Debug {
    /// Sends a request and returns the response with its body fully read.
    ///
    /// - `req` builds the request. Implementations may invoke it more than once
    ///   (e.g. once per retry), so it must produce a fresh request every time.
    /// - `max_retries` is a retry budget whose meaning is up to the
    ///   implementation. The non-wasm blanket impl applies it to connection
    ///   errors.
    /// - `size_limit`, when set, caps the size of the response body.
    async fn call<'a>(
        &'a self,
        req: &'a (dyn Fn() -> Result<http::Request<Bytes>, AgentError> + Send + Sync),
        max_retries: usize,
        size_limit: Option<usize>,
    ) -> Result<http::Response<Bytes>, AgentError>;
}
2094
2095fn from_http_request(req: http::Request<Bytes>) -> Result<Request, AgentError> {
2097 let (parts, body) = req.into_parts();
2098 let body = reqwest::Body::from(body);
2099 let request = http::Request::from_parts(parts, body)
2102 .try_into()
2103 .map_err(|e: reqwest::Error| AgentError::InvalidReplicaUrl(e.to_string()))?;
2104
2105 Ok(request)
2106}
2107
2108#[cfg(not(target_family = "wasm"))]
2110async fn to_http_response(
2111 resp: Response,
2112 size_limit: Option<usize>,
2113) -> Result<http::Response<Bytes>, AgentError> {
2114 use http_body_util::{BodyExt, Limited};
2115
2116 let resp: http::Response<reqwest::Body> = resp.into();
2117 let (parts, body) = resp.into_parts();
2118 let body = Limited::new(body, size_limit.unwrap_or(usize::MAX));
2119 let body = body
2120 .collect()
2121 .await
2122 .map_err(|e| {
2123 AgentError::TransportError(TransportError::Generic(format!(
2124 "unable to read response body: {e:#}"
2125 )))
2126 })?
2127 .to_bytes();
2128 let resp = http::Response::from_parts(parts, body);
2129
2130 Ok(resp)
2131}
2132
/// Wasm variant: converts a `reqwest::Response` into an `http::Response<Bytes>`.
///
/// The body is streamed chunk by chunk and collected under `size_limit`.
/// `None` means no limit.
///
/// # Errors
/// Returns a generic transport error if the body cannot be read or exceeds
/// the limit.
#[cfg(target_family = "wasm")]
async fn to_http_response(
    resp: Response,
    size_limit: Option<usize>,
) -> Result<http::Response<Bytes>, AgentError> {
    use futures_util::StreamExt;
    use http_body::Frame;
    use http_body_util::{Limited, StreamBody};

    // Capture the status and headers before `bytes_stream` consumes `resp`.
    let status = resp.status();
    let headers = resp.headers().clone();

    let frames = resp.bytes_stream().map(|chunk| chunk.map(Frame::data));
    let limited = Limited::new(StreamBody::new(frames), size_limit.unwrap_or(usize::MAX));
    let body = http_body_util::BodyExt::collect(limited)
        .await
        .map_err(|e| {
            AgentError::TransportError(TransportError::Generic(format!(
                "unable to read response body: {e:#}"
            )))
        })?
        .to_bytes();

    let mut response = http::Response::new(body);
    *response.status_mut() = status;
    *response.headers_mut() = headers;
    Ok(response)
}
2168
2169#[cfg(not(target_family = "wasm"))]
2170#[async_trait]
2171impl<T> HttpService for T
2172where
2173 for<'a> &'a T: Service<Request, Response = Response, Error = reqwest::Error>,
2174 for<'a> <&'a Self as Service<Request>>::Future: Send,
2175 T: Send + Sync + Debug + ?Sized,
2176{
2177 #[allow(clippy::needless_arbitrary_self_type)]
2178 async fn call<'a>(
2179 mut self: &'a Self,
2180 req: &'a (dyn Fn() -> Result<http::Request<Bytes>, AgentError> + Send + Sync),
2181 max_retries: usize,
2182 size_limit: Option<usize>,
2183 ) -> Result<http::Response<Bytes>, AgentError> {
2184 let mut retry_count = 0;
2185 loop {
2186 let request = from_http_request(req()?)?;
2187
2188 match Service::call(&mut self, request).await {
2189 Err(err) => {
2190 if err.is_connect() {
2192 if retry_count >= max_retries {
2193 return Err(AgentError::TransportError(TransportError::Reqwest(err)));
2194 }
2195 retry_count += 1;
2196 }
2197 else {
2199 return Err(AgentError::TransportError(TransportError::Reqwest(err)));
2200 }
2201 }
2202
2203 Ok(resp) => {
2204 let resp = to_http_response(resp, size_limit).await?;
2205 return Ok(resp);
2206 }
2207 }
2208 }
2209 }
2210}
2211
#[cfg(target_family = "wasm")]
#[async_trait(?Send)]
impl<T> HttpService for T
where
    for<'a> &'a T: Service<Request, Response = Response, Error = reqwest::Error>,
    T: Send + Sync + Debug + ?Sized,
{
    /// Sends the request built by `req` exactly once and reads the body under
    /// `size_limit`.
    ///
    /// The retry count is ignored on wasm.
    #[allow(clippy::needless_arbitrary_self_type)]
    async fn call<'a>(
        mut self: &'a Self,
        req: &'a (dyn Fn() -> Result<http::Request<Bytes>, AgentError> + Send + Sync),
        _retries: usize,
        size_limit: Option<usize>,
    ) -> Result<http::Response<Bytes>, AgentError> {
        let request = from_http_request(req()?)?;
        match Service::call(&mut self, request).await {
            Ok(response) => to_http_response(response, size_limit).await,
            Err(e) => Err(AgentError::TransportError(TransportError::Reqwest(e))),
        }
    }
}
2234
/// [`HttpService`] wrapper around a `reqwest` [`Client`] that retries requests
/// answered with `429 Too Many Requests`.
#[derive(Debug)]
struct Retry429Logic {
    /// HTTP client used for every attempt.
    client: Client,
}
2239
2240#[cfg_attr(target_family = "wasm", async_trait(?Send))]
2241#[cfg_attr(not(target_family = "wasm"), async_trait)]
2242impl HttpService for Retry429Logic {
2243 async fn call<'a>(
2244 &'a self,
2245 req: &'a (dyn Fn() -> Result<http::Request<Bytes>, AgentError> + Send + Sync),
2246 _max_tcp_retries: usize,
2247 _size_limit: Option<usize>,
2248 ) -> Result<http::Response<Bytes>, AgentError> {
2249 let mut retries = 0;
2250 loop {
2251 #[cfg(not(target_family = "wasm"))]
2252 let resp = self.client.call(req, _max_tcp_retries, _size_limit).await?;
2253 #[cfg(target_family = "wasm")]
2255 let resp = {
2256 let request = from_http_request(req()?)?;
2257 let resp = self
2258 .client
2259 .execute(request)
2260 .await
2261 .map_err(|e| AgentError::TransportError(TransportError::Reqwest(e)))?;
2262
2263 to_http_response(resp, _size_limit).await?
2264 };
2265
2266 if resp.status() == StatusCode::TOO_MANY_REQUESTS {
2267 if retries == 6 {
2268 break Ok(resp);
2269 } else {
2270 retries += 1;
2271 crate::util::sleep(Duration::from_millis(250)).await;
2272 continue;
2273 }
2274 } else {
2275 break Ok(resp);
2276 }
2277 }
2278 }
2279}
2280
// Tests that need no live replica; not built for wasm targets.
#[cfg(all(test, not(target_family = "wasm")))]
mod offline_tests {
    use super::*;
    use tokio::net::TcpListener;
    // Signs six updates back to back and counts how often `ingress_expiry`
    // increases. At most two distinct values are allowed. Presumably the agent
    // rounds the expiry, so consecutive requests share it and the run can
    // straddle at most one rounding boundary — TODO confirm against
    // `update_content`.
    #[test]
    fn rounded_expiry() {
        let agent = Agent::builder()
            .with_url("http://not-a-real-url")
            .build()
            .unwrap();
        let mut prev_expiry = None;
        let mut num_timestamps = 0;
        for _ in 0..6 {
            let update = agent
                .update(&Principal::management_canister(), "not_a_method")
                .sign()
                .unwrap();
            // `None < Some(_)`, so the first signature is always counted.
            if prev_expiry < Some(update.ingress_expiry) {
                prev_expiry = Some(update.ingress_expiry);
                num_timestamps += 1;
            }
        }
        assert!(num_timestamps <= 2, "num_timestamps:{num_timestamps} > 2");
    }

    // The mock server accepts TCP connections, counts them, and drains whatever
    // it receives without ever responding. The agent allows 2 concurrent
    // requests over HTTP/1 only. After 250 ms, three spawned queries are
    // expected to have opened exactly two connections: the third is held back
    // by the concurrency limit.
    #[tokio::test]
    async fn client_ratelimit() {
        let mock_server = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let count = Arc::new(Mutex::new(0));
        let port = mock_server.local_addr().unwrap().port();
        tokio::spawn({
            let count = count.clone();
            async move {
                loop {
                    let (mut conn, _) = mock_server.accept().await.unwrap();
                    *count.lock().unwrap() += 1;
                    // Read and discard the request; never write a response.
                    tokio::spawn(
                        async move { tokio::io::copy(&mut conn, &mut tokio::io::sink()).await },
                    );
                }
            }
        });
        let agent = Agent::builder()
            .with_http_client(Client::builder().http1_only().build().unwrap())
            .with_url(format!("http://127.0.0.1:{port}"))
            .with_max_concurrent_requests(2)
            .build()
            .unwrap();
        for _ in 0..3 {
            let agent = agent.clone();
            tokio::spawn(async move {
                agent
                    .query(&"ryjl3-tyaaa-aaaaa-aaaba-cai".parse().unwrap(), "greet")
                    .call()
                    .await
            });
        }
        crate::util::sleep(Duration::from_millis(250)).await;
        assert_eq!(*count.lock().unwrap(), 2);
    }
}