1use crate::apis;
2use crate::apis::ark_service_api::ark_service_confirm_registration;
3use crate::apis::ark_service_api::ark_service_delete_intent;
4use crate::apis::ark_service_api::ark_service_finalize_tx;
5use crate::apis::ark_service_api::ark_service_get_info;
6use crate::apis::ark_service_api::ark_service_register_intent;
7use crate::apis::ark_service_api::ark_service_submit_signed_forfeit_txs;
8use crate::apis::ark_service_api::ark_service_submit_tree_nonces;
9use crate::apis::ark_service_api::ark_service_submit_tree_signatures;
10use crate::apis::ark_service_api::ark_service_submit_tx;
11use crate::apis::indexer_service_api::indexer_service_get_virtual_txs;
12use crate::apis::indexer_service_api::indexer_service_get_vtxos;
13use crate::apis::indexer_service_api::indexer_service_subscribe_for_scripts;
14use crate::apis::indexer_service_api::indexer_service_unsubscribe_for_scripts;
15use crate::models;
16use crate::models::ConfirmRegistrationRequest;
17use crate::models::Intent;
18use crate::models::SubmitSignedForfeitTxsRequest;
19use crate::models::SubmitTreeNoncesRequest;
20use crate::models::SubmitTreeSignaturesRequest;
21use crate::models::SubscribeForScriptsRequest;
22use crate::models::UnsubscribeForScriptsRequest;
23use crate::Error;
24use ark_core::server::FinalizeOffchainTxResponse;
25use ark_core::server::GetVtxosRequest;
26use ark_core::server::GetVtxosRequestFilter;
27use ark_core::server::GetVtxosRequestReference;
28use ark_core::server::IndexerPage;
29use ark_core::server::NoncePks;
30use ark_core::server::PartialSigTree;
31use ark_core::server::StreamEvent;
32use ark_core::server::SubmitOffchainTxResponse;
33use ark_core::server::SubscriptionResponse;
34use ark_core::server::VirtualTxOutPoint;
35use ark_core::server::VirtualTxsResponse;
36use ark_core::server::SDK_VERSION;
37use ark_core::server::TARGET_ARKD_VERSION;
38use ark_core::ArkAddress;
39use bitcoin::base64;
40use bitcoin::base64::Engine;
41use bitcoin::secp256k1::PublicKey;
42use bitcoin::Psbt;
43use bitcoin::Txid;
44use futures::stream;
45use futures::Future;
46use futures::Stream;
47use futures::StreamExt;
48use std::error::Error as StdError;
49use std::sync::Arc;
50use std::sync::RwLock;
51
52type InfoRefreshHook = Arc<
53 dyn Fn(ark_core::server::Info) -> Result<(), Box<dyn StdError + Send + Sync + 'static>>
54 + Send
55 + Sync,
56>;
57
58pub struct Client {
59 configuration: RwLock<apis::configuration::Configuration>,
60 digest: RwLock<Option<String>>,
61 info_refresh_hook: Option<InfoRefreshHook>,
62}
63
64pub struct ListVtxosResponse {
65 pub vtxos: Vec<VirtualTxOutPoint>,
66 pub page: Option<IndexerPage>,
67}
68
69fn build_reqwest_client(digest: Option<&str>) -> Result<reqwest::Client, Error> {
70 let mut default_headers = reqwest::header::HeaderMap::new();
71 default_headers.insert(
72 "X-Build-Version",
73 reqwest::header::HeaderValue::from_static(TARGET_ARKD_VERSION),
74 );
75 default_headers.insert(
76 "X-SDK-Version",
77 reqwest::header::HeaderValue::from_static(SDK_VERSION),
78 );
79 if let Some(digest) = digest {
80 default_headers.insert(
81 "X-Digest",
82 reqwest::header::HeaderValue::from_str(digest).map_err(Error::request)?,
83 );
84 }
85
86 reqwest::Client::builder()
87 .default_headers(default_headers)
88 .build()
89 .map_err(Error::request)
90}
91
92impl Client {
93 pub fn new(ark_server_url: String) -> Result<Self, Error> {
94 let configuration = apis::configuration::Configuration {
95 base_path: ark_server_url,
96 client: build_reqwest_client(None)?,
97 ..Default::default()
98 };
99
100 Ok(Self {
101 configuration: RwLock::new(configuration),
102 digest: RwLock::new(None),
103 info_refresh_hook: None,
104 })
105 }
106
107 pub fn set_info_refresh_hook(
108 &mut self,
109 hook: impl Fn(ark_core::server::Info) -> Result<(), Box<dyn StdError + Send + Sync + 'static>>
110 + Send
111 + Sync
112 + 'static,
113 ) {
114 self.info_refresh_hook = Some(Arc::new(hook));
115 }
116
117 fn configuration(&self) -> Result<apis::configuration::Configuration, Error> {
118 self.configuration
119 .read()
120 .map(|configuration| configuration.clone())
121 .map_err(|_| Error::request("REST client configuration lock poisoned"))
122 }
123
124 fn update_digest(&self, digest: &str) -> Result<(), Error> {
125 let normalized = (!digest.is_empty()).then(|| digest.to_owned());
126
127 {
128 let current = self
129 .digest
130 .read()
131 .map_err(|_| Error::request("REST client digest lock poisoned"))?;
132 if *current == normalized {
133 return Ok(());
134 }
135 }
136
137 let mut configuration = self
142 .configuration
143 .write()
144 .map_err(|_| Error::request("REST client configuration lock poisoned"))?;
145 configuration.client = build_reqwest_client(normalized.as_deref())?;
146
147 let mut current = self
148 .digest
149 .write()
150 .map_err(|_| Error::request("REST client digest lock poisoned"))?;
151 *current = normalized;
152 Ok(())
153 }
154
155 async fn guarded<T>(&self, op: impl Future<Output = Result<T, Error>>) -> Result<T, Error> {
156 match op.await {
157 Ok(value) => Ok(value),
158 Err(err) if err.is_digest_mismatch() => {
159 let original = err;
160 self.refresh_after_digest_mismatch().await?;
161 Err(Error::server_info_changed(original))
162 }
163 Err(err) => Err(err),
164 }
165 }
166
167 async fn refresh_on_digest_mismatch(&self, err: Error) -> Error {
168 if !err.is_digest_mismatch() {
169 return err;
170 }
171
172 match self.refresh_after_digest_mismatch().await {
173 Ok(()) => Error::server_info_changed(err),
174 Err(refresh_err) => refresh_err,
175 }
176 }
177
178 async fn refresh_after_digest_mismatch(&self) -> Result<(), Error> {
179 let info = self.fetch_info_unguarded().await?;
180 let digest = info.digest.clone();
181
182 if let Some(hook) = &self.info_refresh_hook {
183 hook(info).map_err(Error::conversion)?;
184 }
185
186 self.update_digest(&digest)
189 }
190
191 async fn fetch_info_unguarded(&self) -> Result<ark_core::server::Info, Error> {
192 let configuration = self.configuration()?;
193 let info = ark_service_get_info(&configuration)
194 .await
195 .map_err(Error::request)?;
196
197 info.try_into().map_err(Error::conversion)
198 }
199
200 pub async fn get_info(&self) -> Result<ark_core::server::Info, Error> {
201 let info = self.fetch_info_unguarded().await?;
202 self.update_digest(&info.digest)?;
203 Ok(info)
204 }
205
206 pub async fn submit_offchain_transaction_request(
207 &self,
208 ark_tx: Psbt,
209 checkpoint_txs: Vec<Psbt>,
210 ) -> Result<SubmitOffchainTxResponse, Error> {
211 let base64 = base64::engine::GeneralPurpose::new(
212 &base64::alphabet::STANDARD,
213 base64::engine::GeneralPurposeConfig::new(),
214 );
215
216 let ark_tx = base64.encode(ark_tx.serialize());
217
218 let checkpoint_txs = checkpoint_txs
219 .into_iter()
220 .map(|tx| Some(base64.encode(tx.serialize())))
221 .collect();
222
223 let configuration = self.configuration()?;
224 let res = self
225 .guarded(async {
226 ark_service_submit_tx(
227 &configuration,
228 models::SubmitTxRequest {
229 signed_ark_tx: Some(ark_tx),
230 checkpoint_txs,
231 },
232 )
233 .await
234 .map_err(Error::request)
235 })
236 .await?;
237
238 let signed_ark_tx = res.final_ark_tx;
239 let signed_ark_tx = signed_ark_tx.ok_or(Error::request("Signed ark tx not received"))?;
240
241 let signed_ark_tx = base64.decode(signed_ark_tx).map_err(Error::conversion)?;
242 let signed_ark_tx = Psbt::deserialize(&signed_ark_tx).map_err(Error::conversion)?;
243
244 let signed_checkpoint_txs = res
245 .signed_checkpoint_txs
246 .ok_or(Error::request("Signed checkpoint tx not received"))?
247 .into_iter()
248 .map(|tx| {
249 let tx = base64.decode(tx).map_err(Error::conversion)?;
250 let tx = Psbt::deserialize(&tx).map_err(Error::conversion)?;
251
252 Ok(tx)
253 })
254 .collect::<Result<Vec<_>, Error>>()?;
255
256 Ok(SubmitOffchainTxResponse {
257 signed_ark_tx,
258 signed_checkpoint_txs,
259 })
260 }
261
262 pub async fn finalize_offchain_transaction(
263 &self,
264 txid: Txid,
265 checkpoint_txs: Vec<Psbt>,
266 ) -> Result<FinalizeOffchainTxResponse, Error> {
267 let base64 = base64::engine::GeneralPurpose::new(
268 &base64::alphabet::STANDARD,
269 base64::engine::GeneralPurposeConfig::new(),
270 );
271
272 let checkpoint_txs = checkpoint_txs
273 .into_iter()
274 .map(|tx| Some(base64.encode(tx.serialize())))
275 .collect();
276
277 let configuration = self.configuration()?;
278 self.guarded(async {
279 ark_service_finalize_tx(
280 &configuration,
281 models::FinalizeTxRequest {
282 ark_txid: Some(txid.to_string()),
283 final_checkpoint_txs: checkpoint_txs,
284 },
285 )
286 .await
287 .map_err(Error::request)
288 })
289 .await?;
290
291 Ok(FinalizeOffchainTxResponse {})
292 }
293
294 pub async fn list_vtxos(&self, request: GetVtxosRequest) -> Result<ListVtxosResponse, Error> {
295 let reference = request.reference();
296
297 if reference.is_empty() {
298 return Ok(ListVtxosResponse {
299 vtxos: Vec::new(),
300 page: None,
301 });
302 }
303
304 let filter = request.filter();
305
306 let (scripts, outpoints) = match reference {
307 GetVtxosRequestReference::Scripts(s) => (
308 Some(s.iter().map(|s| s.to_hex_string()).clone().collect()),
309 None,
310 ),
311 GetVtxosRequestReference::OutPoints(o) => {
312 (None, Some(o.iter().map(|o| o.to_string()).collect()))
313 }
314 };
315 let (spendable_only, spent_only, recoverable_only, pending_only) = match filter {
316 None => (Some(false), Some(false), Some(false), Some(false)),
317 Some(filter) => match filter {
318 GetVtxosRequestFilter::Spendable => {
319 (Some(true), Some(false), Some(false), Some(false))
320 }
321 GetVtxosRequestFilter::Spent => (Some(false), Some(true), Some(false), Some(false)),
322 GetVtxosRequestFilter::Recoverable => {
323 (Some(false), Some(false), Some(true), Some(false))
324 }
325 GetVtxosRequestFilter::PendingOnly => {
326 (Some(false), Some(false), Some(false), Some(true))
327 }
328 },
329 };
330
331 let page_period_size: Option<i32> = request.page().map(|p| p.size);
332 let page_period_index: Option<i32> = request.page().map(|p| p.index);
333
334 let before = request.before().map(|b| b as i64);
335 let after = request.after().map(|b| b as i64);
336
337 let configuration = self.configuration()?;
338 let response = self
339 .guarded(async {
340 indexer_service_get_vtxos(
341 &configuration,
342 scripts,
343 outpoints,
344 spendable_only,
345 spent_only,
346 recoverable_only,
347 pending_only,
348 before,
349 after,
350 page_period_size,
351 page_period_index,
352 )
353 .await
354 .map_err(Error::request)
355 })
356 .await?;
357
358 let vtxos = response.vtxos.ok_or(Error::request("VTXOs not received"))?;
359 let vtxos = vtxos
360 .into_iter()
361 .map(VirtualTxOutPoint::try_from)
362 .collect::<Result<Vec<_>, crate::conversions::ConversionError>>()?;
363
364 let page = response.page.map(|p| IndexerPage {
365 current: p.current.unwrap_or_default(),
366 next: p.next.unwrap_or_default(),
367 total: p.total.unwrap_or_default(),
368 });
369
370 Ok(ListVtxosResponse { vtxos, page })
371 }
372
373 pub async fn register_intent(
374 &self,
375 intent_message: &ark_core::intent::IntentMessage,
376 proof: &Psbt,
377 ) -> Result<String, Error> {
378 let message = intent_message.encode().map_err(Error::conversion)?;
379 let base64 = base64::engine::GeneralPurpose::new(
380 &base64::alphabet::STANDARD,
381 base64::engine::GeneralPurposeConfig::new(),
382 );
383
384 let bytes = proof.serialize();
385
386 let proof = base64.encode(&bytes);
387
388 let configuration = self.configuration()?;
389 let response = self
390 .guarded(async {
391 ark_service_register_intent(
392 &configuration,
393 models::RegisterIntentRequest {
394 intent: Some(Intent {
395 proof: Some(proof),
396 message: Some(message),
397 }),
398 },
399 )
400 .await
401 .map_err(Error::request)
402 })
403 .await?;
404 let intent_id = response
405 .intent_id
406 .ok_or(Error::request("Could not get intent id"))?;
407
408 Ok(intent_id)
409 }
410
411 pub async fn delete_intent(
412 &self,
413 intent_message: &ark_core::intent::IntentMessage,
414 proof: &Psbt,
415 ) -> Result<(), Error> {
416 let message = intent_message.encode().map_err(Error::conversion)?;
417 let base64 = base64::engine::GeneralPurpose::new(
418 &base64::alphabet::STANDARD,
419 base64::engine::GeneralPurposeConfig::new(),
420 );
421
422 let bytes = proof.serialize();
423
424 let proof = base64.encode(&bytes);
425 let configuration = self.configuration()?;
426 self.guarded(async {
427 ark_service_delete_intent(
428 &configuration,
429 models::DeleteIntentRequest {
430 intent: Some(Intent {
431 proof: Some(proof),
432 message: Some(message),
433 }),
434 },
435 )
436 .await
437 .map_err(Error::request)
438 })
439 .await?;
440
441 Ok(())
442 }
443
444 pub async fn get_event_stream(
445 &self,
446 topics: Vec<String>,
447 ) -> Result<impl Stream<Item = Result<StreamEvent, Error>> + Unpin, Error> {
448 let configuration = self.configuration()?;
449
450 let mut url = format!("{}/v1/batch/events", configuration.base_path);
452 if !topics.is_empty() {
453 let query_params: Vec<String> = topics
454 .iter()
455 .map(|topic| format!("topics={}", urlencoding::encode(topic)))
456 .collect();
457 url = format!("{}?{}", url, query_params.join("&"));
458 }
459
460 let request = configuration
462 .client
463 .get(&url)
464 .header("Accept", "text/event-stream")
465 .send()
466 .await
467 .map_err(Error::request)?;
468
469 if !request.status().is_success() {
472 let status = request.status();
473 let body = request.text().await.unwrap_or_default();
474 let err = Error::request(format!(
475 "Event stream request failed with status {status}: {body}"
476 ));
477 return Err(self.refresh_on_digest_mismatch(err).await);
478 }
479
480 let byte_stream = request.bytes_stream();
482
483 let stream = stream::unfold(byte_stream, |mut byte_stream| async move {
485 loop {
486 match byte_stream.next().await {
487 Some(chunk_result) => {
488 let result = match chunk_result {
489 Ok(bytes) => {
490 let event = String::from_utf8(bytes.to_vec());
491 match event {
492 Ok(event) => {
493 let event = event.trim();
494 if event.is_empty() || event.starts_with(':') {
496 continue;
497 }
498 let event = event.strip_prefix("data: ").unwrap_or(event);
500 if let Ok(response) =
501 serde_json::from_str::<models::GetEventStreamResponse>(
502 event,
503 )
504 {
505 match StreamEvent::try_from(response) {
506 Ok(stream_event) => Ok(stream_event),
507 Err(e) => Err(Error::conversion(e)),
508 }
509 } else {
510 Err(Error::conversion("Failed to parse JSON"))
512 }
513 }
514 Err(error) => Err(Error::conversion(error)),
515 }
516 }
517 Err(e) => Err(Error::request(e)),
518 };
519 return Some((result, byte_stream));
520 }
521 None => return None,
522 }
523 }
524 });
525
526 Ok(Box::pin(stream))
527 }
528 pub async fn confirm_registration(&self, intent_id: String) -> Result<(), Error> {
529 let configuration = self.configuration()?;
530 self.guarded(async {
531 ark_service_confirm_registration(
532 &configuration,
533 ConfirmRegistrationRequest {
534 intent_id: Some(intent_id),
535 },
536 )
537 .await
538 .map_err(Error::request)
539 })
540 .await?;
541
542 Ok(())
543 }
544
545 pub async fn submit_tree_nonces(
546 &self,
547 batch_id: &str,
548 cosigner_pubkey: PublicKey,
549 pub_nonce_tree: NoncePks,
550 ) -> Result<(), Error> {
551 let tree_nonces = pub_nonce_tree.encode();
552
553 let configuration = self.configuration()?;
554 self.guarded(async {
555 ark_service_submit_tree_nonces(
556 &configuration,
557 SubmitTreeNoncesRequest {
558 batch_id: Some(batch_id.to_string()),
559 pubkey: Some(cosigner_pubkey.to_string()),
560 tree_nonces: Some(tree_nonces),
561 },
562 )
563 .await
564 .map_err(Error::request)
565 })
566 .await?;
567
568 Ok(())
569 }
570
571 pub async fn submit_tree_signatures(
572 &self,
573 batch_id: &str,
574 cosigner_pk: PublicKey,
575 partial_sig_tree: PartialSigTree,
576 ) -> Result<(), Error> {
577 let tree_signatures = partial_sig_tree.encode();
578
579 let configuration = self.configuration()?;
580 self.guarded(async {
581 ark_service_submit_tree_signatures(
582 &configuration,
583 SubmitTreeSignaturesRequest {
584 batch_id: Some(batch_id.to_string()),
585 pubkey: Some(cosigner_pk.to_string()),
586 tree_signatures: Some(tree_signatures),
587 },
588 )
589 .await
590 .map_err(Error::request)
591 })
592 .await?;
593
594 Ok(())
595 }
596
597 pub async fn submit_signed_forfeit_txs(
598 &self,
599 signed_forfeit_txs: Vec<Psbt>,
600 signed_commitment_tx: Option<Psbt>,
601 ) -> Result<(), Error> {
602 let base64 = base64::engine::GeneralPurpose::new(
603 &base64::alphabet::STANDARD,
604 base64::engine::GeneralPurposeConfig::new(),
605 );
606
607 let signed_commitment_tx = signed_commitment_tx
608 .map(|tx| base64.encode(tx.serialize()))
609 .unwrap_or_default();
610
611 let configuration = self.configuration()?;
612 self.guarded(async {
613 ark_service_submit_signed_forfeit_txs(
614 &configuration,
615 SubmitSignedForfeitTxsRequest {
616 signed_forfeit_txs: signed_forfeit_txs
617 .iter()
618 .map(|psbt| Some(base64.encode(psbt.serialize())))
619 .collect(),
620 signed_commitment_tx: Some(signed_commitment_tx),
621 },
622 )
623 .await
624 .map_err(Error::request)
625 })
626 .await?;
627
628 Ok(())
629 }
630
631 pub async fn subscribe_to_scripts(
641 &self,
642 scripts: Vec<ArkAddress>,
643 subscription_id: Option<String>,
644 ) -> Result<String, Error> {
645 let scripts = scripts
646 .iter()
647 .map(|address| address.to_p2tr_script_pubkey().to_hex_string())
648 .collect::<Vec<_>>();
649
650 let subscription_id = subscription_id.unwrap_or_default();
652
653 let configuration = self.configuration()?;
654 let response = self
655 .guarded(async {
656 indexer_service_subscribe_for_scripts(
657 &configuration,
658 SubscribeForScriptsRequest {
659 scripts: Some(scripts),
660 subscription_id: Some(subscription_id),
661 },
662 )
663 .await
664 .map_err(Error::request)
665 })
666 .await?;
667
668 let subscription_id = response
669 .subscription_id
670 .ok_or(Error::request("No subscription id"))?;
671
672 Ok(subscription_id)
673 }
674
675 pub async fn unsubscribe_from_scripts(
677 &self,
678 scripts: Vec<ArkAddress>,
679 subscription_id: String,
680 ) -> Result<(), Error> {
681 let scripts = scripts
682 .iter()
683 .map(|address| address.to_p2tr_script_pubkey().to_hex_string())
684 .collect::<Vec<_>>();
685
686 let configuration = self.configuration()?;
687 self.guarded(async {
688 indexer_service_unsubscribe_for_scripts(
689 &configuration,
690 UnsubscribeForScriptsRequest {
691 subscription_id: Some(subscription_id),
692 scripts: Some(scripts),
693 },
694 )
695 .await
696 .map_err(Error::request)
697 })
698 .await?;
699
700 Ok(())
701 }
702
703 pub async fn get_subscription(
704 &self,
705 subscription_id: String,
706 ) -> Result<impl Stream<Item = Result<SubscriptionResponse, Error>> + Unpin, Error> {
707 let configuration = self.configuration()?;
708
709 let url = format!(
711 "{}/v1/script/subscription/{subscription_id}",
712 configuration.base_path,
713 );
714
715 let request = configuration
717 .client
718 .get(&url)
719 .header("Accept", "text/event-stream")
720 .send()
721 .await
722 .map_err(Error::request)?;
723
724 if !request.status().is_success() {
727 let status = request.status();
728 let body = request.text().await.unwrap_or_default();
729 let err = Error::request(format!(
730 "Subscription stream request failed with status {status}: {body}"
731 ));
732 return Err(self.refresh_on_digest_mismatch(err).await);
733 }
734
735 let byte_stream = request.bytes_stream();
737
738 let stream = stream::unfold(byte_stream, |mut byte_stream| async move {
740 loop {
741 match byte_stream.next().await {
742 Some(chunk_result) => {
743 let result = match chunk_result {
744 Ok(bytes) => {
745 let event = String::from_utf8(bytes.to_vec());
746 match event {
747 Ok(event) => {
748 let event = event.trim();
749 if event.is_empty() || event.starts_with(':') {
751 continue;
752 }
753 let event = event.strip_prefix("data: ").unwrap_or(event);
755 if let Ok(response) =
756 serde_json::from_str::<models::GetSubscriptionResponse>(
757 event,
758 )
759 {
760 match SubscriptionResponse::try_from(response) {
761 Ok(subscription_response) => {
762 Ok(subscription_response)
763 }
764 Err(e) => Err(Error::conversion(e)),
765 }
766 } else {
767 Err(Error::conversion("Failed to parse JSON"))
769 }
770 }
771 Err(error) => Err(Error::conversion(error)),
772 }
773 }
774 Err(e) => Err(Error::request(e)),
775 };
776 return Some((result, byte_stream));
777 }
778 None => return None,
779 }
780 }
781 });
782
783 Ok(Box::pin(stream))
784 }
785
786 pub async fn get_virtual_txs(
787 &self,
788 txids: Vec<String>,
789 size_and_index: Option<(i32, i32)>,
790 ) -> Result<VirtualTxsResponse, Error> {
791 let (size, index) = size_and_index
792 .map(|(sz, indx)| (Some(sz), Some(indx)))
793 .unwrap_or_default();
794 let configuration = self.configuration()?;
795 let response = self
796 .guarded(async {
797 indexer_service_get_virtual_txs(&configuration, txids, size, index)
798 .await
799 .map_err(Error::request)
800 })
801 .await?;
802
803 let base64 = &base64::engine::GeneralPurpose::new(
804 &base64::alphabet::STANDARD,
805 base64::engine::GeneralPurposeConfig::new(),
806 );
807
808 let txs = response
809 .txs
810 .unwrap_or_default()
811 .into_iter()
812 .map(|tx| {
813 let bytes = base64.decode(&tx).map_err(Error::conversion)?;
814 let psbt = Psbt::deserialize(&bytes).map_err(Error::conversion)?;
815
816 Ok(psbt)
817 })
818 .collect::<Result<Vec<Psbt>, Error>>()?;
819
820 Ok(VirtualTxsResponse {
821 txs,
822 page: response.page.map(|a| IndexerPage {
823 current: a.current.unwrap_or_default(),
824 next: a.next.unwrap_or_default(),
825 total: a.total.unwrap_or_default(),
826 }),
827 })
828 }
829}
830
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::Ordering;

    /// Builds a client pointing at `127.0.0.1:1` together with a flag that the installed
    /// refresh hook sets when it is called.
    fn client_with_hook_flag() -> (Client, Arc<AtomicBool>) {
        let mut client = Client::new("http://127.0.0.1:1".to_string()).unwrap();
        let hook_fired = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&hook_fired);
        client.set_info_refresh_hook(move |_info| {
            flag.store(true, Ordering::SeqCst);
            Ok(())
        });
        (client, hook_fired)
    }

    /// Errors that are not digest mismatches must be returned untouched, without a refresh.
    #[tokio::test]
    async fn guarded_passes_through_non_digest_error() {
        let (client, hook_fired) = client_with_hook_flag();

        let outcome = client
            .guarded(async { Err::<(), _>(Error::request("connection refused")) })
            .await;
        let err = outcome.expect_err("should surface the original error");

        assert!(!err.is_server_info_changed());
        assert!(!err.is_digest_mismatch());
        assert!(!hook_fired.load(Ordering::SeqCst));
    }

    /// A digest mismatch triggers a refresh attempt; the test expects that attempt to fail, so
    /// the hook never runs and the error is not reported as "server info changed".
    #[tokio::test]
    async fn guarded_detects_digest_mismatch_and_attempts_refresh() {
        let (client, hook_fired) = client_with_hook_flag();

        let outcome = client
            .guarded(async { Err::<(), _>(Error::request("DIGEST_MISMATCH")) })
            .await;
        let err = outcome
            .expect_err("digest mismatch should trigger a refresh that fails on a closed port");

        assert!(!err.is_server_info_changed());
        assert!(!hook_fired.load(Ordering::SeqCst));
    }
}