1use casper_types::{
3 account::AccountHash,
4 global_state::TrieMerkleProof,
5 system::{
6 handle_payment::{ACCUMULATION_PURSE_KEY, PAYMENT_PURSE_KEY, REFUND_PURSE_KEY},
7 mint::BalanceHoldAddrTag,
8 HANDLE_PAYMENT,
9 },
10 AccessRights, BlockTime, Digest, EntityAddr, HoldBalanceHandling, InitiatorAddr, Key,
11 ProtocolVersion, PublicKey, StoredValue, TimeDiff, URef, URefAddr, U512,
12};
13use itertools::Itertools;
14use num_rational::Ratio;
15use num_traits::CheckedMul;
16use std::{
17 collections::{btree_map::Entry, BTreeMap},
18 fmt::{Display, Formatter},
19};
20use tracing::error;
21
22use crate::{
23 global_state::state::StateReader,
24 tracking_copy::{TrackingCopyEntityExt, TrackingCopyError, TrackingCopyExt},
25 TrackingCopy,
26};
27
/// Which balance figure a balance query is asking for.
///
/// NOTE(review): the consumer of this flag is outside this file; the variant descriptions
/// follow the naming of `BalanceResult::{total_balance, available_balance}` — confirm.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
pub enum BalanceHandling {
    /// The total balance of the purse. This is the default.
    #[default]
    Total,
    /// The balance that remains once holds are taken into account (presumably).
    Available,
}

/// Whether Merkle proofs should accompany a balance result.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
pub enum ProofHandling {
    /// Do not produce proofs. This is the default.
    #[default]
    NoProofs,
    /// Produce proofs (presumably surfaced as `ProofsResult::Proofs`).
    Proofs,
}
47
48#[derive(Debug, Clone, PartialEq, Eq)]
50pub enum BalanceIdentifier {
51 Refund,
53 Payment,
55 Accumulate,
57 Purse(URef),
59 Public(PublicKey),
61 Account(AccountHash),
63 Entity(EntityAddr),
65 Internal(URefAddr),
67 PenalizedAccount(AccountHash),
69 PenalizedPayment,
71}
72
73impl BalanceIdentifier {
74 pub fn as_purse_addr(&self) -> Option<URefAddr> {
76 match self {
77 BalanceIdentifier::Internal(addr) => Some(*addr),
78 BalanceIdentifier::Purse(uref) => Some(uref.addr()),
79 BalanceIdentifier::Public(_)
80 | BalanceIdentifier::Account(_)
81 | BalanceIdentifier::PenalizedAccount(_)
82 | BalanceIdentifier::PenalizedPayment
83 | BalanceIdentifier::Entity(_)
84 | BalanceIdentifier::Refund
85 | BalanceIdentifier::Payment
86 | BalanceIdentifier::Accumulate => None,
87 }
88 }
89
90 pub fn purse_uref<S>(
92 &self,
93 tc: &mut TrackingCopy<S>,
94 protocol_version: ProtocolVersion,
95 ) -> Result<URef, TrackingCopyError>
96 where
97 S: StateReader<Key, StoredValue, Error = crate::global_state::error::Error>,
98 {
99 let purse_uref = match self {
100 BalanceIdentifier::Internal(addr) => URef::new(*addr, AccessRights::READ),
101 BalanceIdentifier::Purse(purse_uref) => *purse_uref,
102 BalanceIdentifier::Public(public_key) => {
103 let account_hash = public_key.to_account_hash();
104 match tc.runtime_footprint_by_account_hash(protocol_version, account_hash) {
105 Ok((_, entity)) => entity
106 .main_purse()
107 .ok_or(TrackingCopyError::Authorization)?,
108 Err(tce) => return Err(tce),
109 }
110 }
111 BalanceIdentifier::Account(account_hash)
112 | BalanceIdentifier::PenalizedAccount(account_hash) => {
113 match tc.runtime_footprint_by_account_hash(protocol_version, *account_hash) {
114 Ok((_, entity)) => entity
115 .main_purse()
116 .ok_or(TrackingCopyError::Authorization)?,
117 Err(tce) => return Err(tce),
118 }
119 }
120 BalanceIdentifier::Entity(entity_addr) => {
121 match tc.runtime_footprint_by_entity_addr(*entity_addr) {
122 Ok(entity) => entity
123 .main_purse()
124 .ok_or(TrackingCopyError::Authorization)?,
125 Err(tce) => return Err(tce),
126 }
127 }
128 BalanceIdentifier::Refund => {
129 self.get_system_purse(tc, HANDLE_PAYMENT, REFUND_PURSE_KEY)?
130 }
131 BalanceIdentifier::Payment | BalanceIdentifier::PenalizedPayment => {
132 self.get_system_purse(tc, HANDLE_PAYMENT, PAYMENT_PURSE_KEY)?
133 }
134 BalanceIdentifier::Accumulate => {
135 self.get_system_purse(tc, HANDLE_PAYMENT, ACCUMULATION_PURSE_KEY)?
136 }
137 };
138 Ok(purse_uref)
139 }
140
141 fn get_system_purse<S>(
142 &self,
143 tc: &mut TrackingCopy<S>,
144 system_contract_name: &str,
145 named_key_name: &str,
146 ) -> Result<URef, TrackingCopyError>
147 where
148 S: StateReader<Key, StoredValue, Error = crate::global_state::error::Error>,
149 {
150 let system_contract_registry = tc.get_system_entity_registry()?;
151
152 let entity_hash = system_contract_registry
153 .get(system_contract_name)
154 .ok_or_else(|| {
155 error!("Missing system handle payment contract hash");
156 TrackingCopyError::MissingSystemContractHash(system_contract_name.to_string())
157 })?;
158
159 let named_keys = tc
160 .runtime_footprint_by_entity_addr(EntityAddr::System(*entity_hash))?
161 .take_named_keys();
162
163 let named_key =
164 named_keys
165 .get(named_key_name)
166 .ok_or(TrackingCopyError::NamedKeyNotFound(
167 named_key_name.to_string(),
168 ))?;
169 let uref = named_key
170 .as_uref()
171 .ok_or(TrackingCopyError::UnexpectedKeyVariant(*named_key))?;
172 Ok(*uref)
173 }
174
175 pub fn is_penalty(&self) -> bool {
177 matches!(
178 self,
179 BalanceIdentifier::PenalizedAccount(_) | BalanceIdentifier::PenalizedPayment
180 )
181 }
182}
183
184impl Default for BalanceIdentifier {
185 fn default() -> Self {
186 BalanceIdentifier::Purse(URef::default())
187 }
188}
189
190impl From<InitiatorAddr> for BalanceIdentifier {
191 fn from(value: InitiatorAddr) -> Self {
192 match value {
193 InitiatorAddr::PublicKey(public_key) => BalanceIdentifier::Public(public_key),
194 InitiatorAddr::AccountHash(account_hash) => BalanceIdentifier::Account(account_hash),
195 }
196 }
197}
198
199#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
201pub struct ProcessingHoldBalanceHandling {}
202
203impl ProcessingHoldBalanceHandling {
204 pub fn new() -> Self {
206 ProcessingHoldBalanceHandling::default()
207 }
208
209 pub fn handling(&self) -> HoldBalanceHandling {
211 HoldBalanceHandling::Accrued
212 }
213
214 pub fn is_amortized(&self) -> bool {
216 false
217 }
218
219 pub fn interval(&self) -> TimeDiff {
221 TimeDiff::default()
222 }
223}
224
225impl From<(HoldBalanceHandling, u64)> for ProcessingHoldBalanceHandling {
226 fn from(_value: (HoldBalanceHandling, u64)) -> Self {
227 ProcessingHoldBalanceHandling::default()
228 }
229}
230
231#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
233pub struct GasHoldBalanceHandling {
234 handling: HoldBalanceHandling,
235 interval: TimeDiff,
236}
237
238impl GasHoldBalanceHandling {
239 pub fn new(handling: HoldBalanceHandling, interval: TimeDiff) -> Self {
241 GasHoldBalanceHandling { handling, interval }
242 }
243
244 pub fn handling(&self) -> HoldBalanceHandling {
246 self.handling
247 }
248
249 pub fn interval(&self) -> TimeDiff {
251 self.interval
252 }
253
254 pub fn is_amortized(&self) -> bool {
256 matches!(self.handling, HoldBalanceHandling::Amortized)
257 }
258}
259
260impl From<(HoldBalanceHandling, TimeDiff)> for GasHoldBalanceHandling {
261 fn from(value: (HoldBalanceHandling, TimeDiff)) -> Self {
262 GasHoldBalanceHandling {
263 handling: value.0,
264 interval: value.1,
265 }
266 }
267}
268
269impl From<(HoldBalanceHandling, u64)> for GasHoldBalanceHandling {
270 fn from(value: (HoldBalanceHandling, u64)) -> Self {
271 GasHoldBalanceHandling {
272 handling: value.0,
273 interval: TimeDiff::from_millis(value.1),
274 }
275 }
276}
277
278#[derive(Debug, Clone, PartialEq, Eq)]
280pub struct BalanceRequest {
281 state_hash: Digest,
282 protocol_version: ProtocolVersion,
283 identifier: BalanceIdentifier,
284 balance_handling: BalanceHandling,
285 proof_handling: ProofHandling,
286}
287
288impl BalanceRequest {
289 pub fn new(
291 state_hash: Digest,
292 protocol_version: ProtocolVersion,
293 identifier: BalanceIdentifier,
294 balance_handling: BalanceHandling,
295 proof_handling: ProofHandling,
296 ) -> Self {
297 BalanceRequest {
298 state_hash,
299 protocol_version,
300 identifier,
301 balance_handling,
302 proof_handling,
303 }
304 }
305
306 pub fn from_purse(
308 state_hash: Digest,
309 protocol_version: ProtocolVersion,
310 purse_uref: URef,
311 balance_handling: BalanceHandling,
312 proof_handling: ProofHandling,
313 ) -> Self {
314 BalanceRequest {
315 state_hash,
316 protocol_version,
317 identifier: BalanceIdentifier::Purse(purse_uref),
318 balance_handling,
319 proof_handling,
320 }
321 }
322
323 pub fn from_public_key(
325 state_hash: Digest,
326 protocol_version: ProtocolVersion,
327 public_key: PublicKey,
328 balance_handling: BalanceHandling,
329 proof_handling: ProofHandling,
330 ) -> Self {
331 BalanceRequest {
332 state_hash,
333 protocol_version,
334 identifier: BalanceIdentifier::Public(public_key),
335 balance_handling,
336 proof_handling,
337 }
338 }
339
340 pub fn from_account_hash(
342 state_hash: Digest,
343 protocol_version: ProtocolVersion,
344 account_hash: AccountHash,
345 balance_handling: BalanceHandling,
346 proof_handling: ProofHandling,
347 ) -> Self {
348 BalanceRequest {
349 state_hash,
350 protocol_version,
351 identifier: BalanceIdentifier::Account(account_hash),
352 balance_handling,
353 proof_handling,
354 }
355 }
356
357 pub fn from_entity_addr(
359 state_hash: Digest,
360 protocol_version: ProtocolVersion,
361 entity_addr: EntityAddr,
362 balance_handling: BalanceHandling,
363 proof_handling: ProofHandling,
364 ) -> Self {
365 BalanceRequest {
366 state_hash,
367 protocol_version,
368 identifier: BalanceIdentifier::Entity(entity_addr),
369 balance_handling,
370 proof_handling,
371 }
372 }
373
374 pub fn from_internal(
376 state_hash: Digest,
377 protocol_version: ProtocolVersion,
378 balance_addr: URefAddr,
379 balance_handling: BalanceHandling,
380 proof_handling: ProofHandling,
381 ) -> Self {
382 BalanceRequest {
383 state_hash,
384 protocol_version,
385 identifier: BalanceIdentifier::Internal(balance_addr),
386 balance_handling,
387 proof_handling,
388 }
389 }
390
391 pub fn state_hash(&self) -> Digest {
393 self.state_hash
394 }
395
396 pub fn protocol_version(&self) -> ProtocolVersion {
398 self.protocol_version
399 }
400
401 pub fn identifier(&self) -> &BalanceIdentifier {
403 &self.identifier
404 }
405
406 pub fn balance_handling(&self) -> BalanceHandling {
408 self.balance_handling
409 }
410
411 pub fn proof_handling(&self) -> ProofHandling {
413 self.proof_handling
414 }
415}
416
417pub trait AvailableBalanceChecker {
419 fn available_balance(
421 &self,
422 block_time: BlockTime,
423 total_balance: U512,
424 gas_hold_balance_handling: GasHoldBalanceHandling,
425 processing_hold_balance_handling: ProcessingHoldBalanceHandling,
426 ) -> Result<U512, BalanceFailure> {
427 if self.is_empty() {
428 return Ok(total_balance);
429 }
430
431 let gas_held = match gas_hold_balance_handling.handling() {
432 HoldBalanceHandling::Accrued => self.accrued(BalanceHoldAddrTag::Gas),
433 HoldBalanceHandling::Amortized => {
434 let interval = gas_hold_balance_handling.interval();
435 self.amortization(BalanceHoldAddrTag::Gas, block_time, interval)?
436 }
437 };
438
439 let processing_held = match processing_hold_balance_handling.handling() {
440 HoldBalanceHandling::Accrued => self.accrued(BalanceHoldAddrTag::Processing),
441 HoldBalanceHandling::Amortized => {
442 let interval = processing_hold_balance_handling.interval();
443 self.amortization(BalanceHoldAddrTag::Processing, block_time, interval)?
444 }
445 };
446
447 let held = gas_held.saturating_add(processing_held);
448
449 if held > total_balance {
450 return Ok(U512::zero());
451 }
452
453 debug_assert!(
454 total_balance >= held,
455 "it should not be possible to hold more than the total available"
456 );
457 match total_balance.checked_sub(held) {
458 Some(available_balance) => Ok(available_balance),
459 None => {
460 error!(%held, %total_balance, "held amount exceeds total balance, which should never occur.");
461 Err(BalanceFailure::HeldExceedsTotal)
462 }
463 }
464 }
465
466 fn amortization(
468 &self,
469 hold_kind: BalanceHoldAddrTag,
470 block_time: BlockTime,
471 interval: TimeDiff,
472 ) -> Result<U512, BalanceFailure> {
473 let mut held = U512::zero();
474 let block_time = block_time.value();
475 let interval = interval.millis();
476
477 for (hold_created_time, holds) in self.holds(hold_kind) {
478 let hold_created_time = hold_created_time.value();
479 if hold_created_time > block_time {
480 continue;
481 }
482 let expiry = hold_created_time.saturating_add(interval);
483 if block_time > expiry {
484 continue;
485 }
486 let held_ratio = Ratio::new_raw(
488 holds.values().copied().collect_vec().into_iter().sum(),
489 U512::one(),
490 );
491 let remaining_time = U512::from(expiry.saturating_sub(block_time));
493 let ratio = Ratio::new_raw(remaining_time, U512::from(interval));
495 match held_ratio.checked_mul(&ratio) {
506 Some(amortized) => held += amortized.to_integer(),
507 None => return Err(BalanceFailure::AmortizationFailure),
508 }
509 }
510 Ok(held)
511 }
512
513 fn accrued(&self, hold_kind: BalanceHoldAddrTag) -> U512;
515
516 fn holds(&self, hold_kind: BalanceHoldAddrTag) -> BTreeMap<BlockTime, BalanceHolds>;
518
519 fn is_empty(&self) -> bool;
521}
522
523pub type BalanceHolds = BTreeMap<BalanceHoldAddrTag, U512>;
525
526impl AvailableBalanceChecker for BTreeMap<BlockTime, BalanceHolds> {
527 fn accrued(&self, hold_kind: BalanceHoldAddrTag) -> U512 {
528 self.values()
529 .filter_map(|holds| holds.get(&hold_kind).copied())
530 .collect_vec()
531 .into_iter()
532 .sum()
533 }
534
535 fn holds(&self, hold_kind: BalanceHoldAddrTag) -> BTreeMap<BlockTime, BalanceHolds> {
536 let mut ret = BTreeMap::new();
537 for (k, v) in self {
538 if let Some(hold) = v.get(&hold_kind) {
539 let mut inner = BTreeMap::new();
540 inner.insert(hold_kind, *hold);
541 ret.insert(*k, inner);
542 }
543 }
544 ret
545 }
546
547 fn is_empty(&self) -> bool {
548 self.is_empty()
549 }
550}
551
552pub type BalanceHoldsWithProof =
554 BTreeMap<BalanceHoldAddrTag, (U512, TrieMerkleProof<Key, StoredValue>)>;
555
556impl AvailableBalanceChecker for BTreeMap<BlockTime, BalanceHoldsWithProof> {
557 fn accrued(&self, hold_kind: BalanceHoldAddrTag) -> U512 {
558 self.values()
559 .filter_map(|holds| holds.get(&hold_kind))
560 .map(|(amount, _)| *amount)
561 .collect_vec()
562 .into_iter()
563 .sum()
564 }
565
566 fn holds(&self, hold_kind: BalanceHoldAddrTag) -> BTreeMap<BlockTime, BalanceHolds> {
567 let mut ret: BTreeMap<BlockTime, BalanceHolds> = BTreeMap::new();
568 for (block_time, holds_with_proof) in self {
569 let mut holds: BTreeMap<BalanceHoldAddrTag, U512> = BTreeMap::new();
570 for (addr, (held, _)) in holds_with_proof {
571 if addr == &hold_kind {
572 match holds.entry(*addr) {
573 Entry::Vacant(v) => v.insert(*held),
574 Entry::Occupied(mut o) => &mut o.insert(*held),
575 };
576 }
577 }
578 if !holds.is_empty() {
579 match ret.entry(*block_time) {
580 Entry::Vacant(v) => v.insert(holds),
581 Entry::Occupied(mut o) => &mut o.insert(holds),
582 };
583 }
584 }
585 ret
586 }
587
588 fn is_empty(&self) -> bool {
589 self.is_empty()
590 }
591}
592
593#[derive(Debug, Clone, PartialEq, Eq)]
595pub enum ProofsResult {
596 NotRequested {
598 balance_holds: BTreeMap<BlockTime, BalanceHolds>,
600 },
601 Proofs {
603 total_balance_proof: Box<TrieMerkleProof<Key, StoredValue>>,
605 balance_holds: BTreeMap<BlockTime, BalanceHoldsWithProof>,
607 },
608}
609
610impl ProofsResult {
611 pub fn total_balance_proof(&self) -> Option<&TrieMerkleProof<Key, StoredValue>> {
613 match self {
614 ProofsResult::NotRequested { .. } => None,
615 ProofsResult::Proofs {
616 total_balance_proof,
617 ..
618 } => Some(total_balance_proof),
619 }
620 }
621
622 pub fn balance_holds_with_proof(&self) -> Option<&BTreeMap<BlockTime, BalanceHoldsWithProof>> {
624 match self {
625 ProofsResult::NotRequested { .. } => None,
626 ProofsResult::Proofs { balance_holds, .. } => Some(balance_holds),
627 }
628 }
629
630 pub fn balance_holds(&self) -> Option<&BTreeMap<BlockTime, BalanceHolds>> {
632 match self {
633 ProofsResult::NotRequested { balance_holds } => Some(balance_holds),
634 ProofsResult::Proofs { .. } => None,
635 }
636 }
637
638 pub fn total_held_amount(&self) -> U512 {
640 match self {
641 ProofsResult::NotRequested { balance_holds } => balance_holds
642 .values()
643 .flat_map(|holds| holds.values().copied())
644 .collect_vec()
645 .into_iter()
646 .sum(),
647 ProofsResult::Proofs { balance_holds, .. } => balance_holds
648 .values()
649 .flat_map(|holds| holds.values().map(|(v, _)| *v))
650 .collect_vec()
651 .into_iter()
652 .sum(),
653 }
654 }
655
656 #[allow(clippy::result_unit_err)]
658 pub fn available_balance(
659 &self,
660 block_time: BlockTime,
661 total_balance: U512,
662 gas_hold_balance_handling: GasHoldBalanceHandling,
663 processing_hold_balance_handling: ProcessingHoldBalanceHandling,
664 ) -> Result<U512, BalanceFailure> {
665 match self {
666 ProofsResult::NotRequested { balance_holds } => balance_holds.available_balance(
667 block_time,
668 total_balance,
669 gas_hold_balance_handling,
670 processing_hold_balance_handling,
671 ),
672 ProofsResult::Proofs { balance_holds, .. } => balance_holds.available_balance(
673 block_time,
674 total_balance,
675 gas_hold_balance_handling,
676 processing_hold_balance_handling,
677 ),
678 }
679 }
680}
681
682#[derive(Debug, Clone)]
684pub enum BalanceFailure {
685 AmortizationFailure,
687 HeldExceedsTotal,
689}
690
691impl Display for BalanceFailure {
692 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
693 match self {
694 BalanceFailure::AmortizationFailure => {
695 write!(
696 f,
697 "AmortizationFailure: failed to calculate amortization (checked multiplication)."
698 )
699 }
700 BalanceFailure::HeldExceedsTotal => {
701 write!(
702 f,
703 "HeldExceedsTotal: held amount exceeds total balance, which should never occur."
704 )
705 }
706 }
707 }
708}
709
710#[derive(Debug, Clone)]
712pub enum BalanceResult {
713 RootNotFound,
715 Success {
717 purse_addr: URefAddr,
719 total_balance: U512,
721 available_balance: U512,
723 proofs_result: ProofsResult,
725 },
726 Failure(TrackingCopyError),
728}
729
730impl BalanceResult {
731 pub fn purse_addr(&self) -> Option<URefAddr> {
733 match self {
734 BalanceResult::Success { purse_addr, .. } => Some(*purse_addr),
735 _ => None,
736 }
737 }
738
739 pub fn total_balance(&self) -> Option<&U512> {
741 match self {
742 BalanceResult::Success { total_balance, .. } => Some(total_balance),
743 _ => None,
744 }
745 }
746
747 pub fn available_balance(&self) -> Option<&U512> {
749 match self {
750 BalanceResult::Success {
751 available_balance, ..
752 } => Some(available_balance),
753 _ => None,
754 }
755 }
756
757 pub fn proofs_result(self) -> Option<ProofsResult> {
759 match self {
760 BalanceResult::Success { proofs_result, .. } => Some(proofs_result),
761 _ => None,
762 }
763 }
764
765 pub fn is_sufficient(&self, cost: U512) -> bool {
767 match self {
768 BalanceResult::RootNotFound | BalanceResult::Failure(_) => false,
769 BalanceResult::Success {
770 available_balance, ..
771 } => available_balance >= &cost,
772 }
773 }
774
775 pub fn is_success(&self) -> bool {
777 match self {
778 BalanceResult::RootNotFound | BalanceResult::Failure(_) => false,
779 BalanceResult::Success { .. } => true,
780 }
781 }
782
783 pub fn error(&self) -> Option<&TrackingCopyError> {
785 match self {
786 BalanceResult::RootNotFound | BalanceResult::Success { .. } => None,
787 BalanceResult::Failure(err) => Some(err),
788 }
789 }
790}
791
792impl From<TrackingCopyError> for BalanceResult {
793 fn from(tce: TrackingCopyError) -> Self {
794 BalanceResult::Failure(tce)
795 }
796}