chia_sdk_driver/offers/
royalty.rs1use bigdecimal::{BigDecimal, RoundingMode, ToPrimitive};
2use chia_protocol::Bytes32;
3use chia_puzzle_types::offer::{NotarizedPayment, Payment};
4use chia_puzzles::SETTLEMENT_PAYMENT_HASH;
5use chia_sdk_types::conditions::TradePrice;
6
7use crate::{
8 AssetInfo, CatAssetInfo, CatInfo, DriverError, OfferAmounts, RequestedPayments, SpendContext,
9};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq)]
12pub struct RoyaltyInfo {
13 pub launcher_id: Bytes32,
14 pub puzzle_hash: Bytes32,
15 pub basis_points: u16,
16}
17
18impl RoyaltyInfo {
19 pub fn new(launcher_id: Bytes32, puzzle_hash: Bytes32, basis_points: u16) -> Self {
20 Self {
21 launcher_id,
22 puzzle_hash,
23 basis_points,
24 }
25 }
26
27 pub fn payment(
28 &self,
29 ctx: &mut SpendContext,
30 amount: u64,
31 ) -> Result<NotarizedPayment, DriverError> {
32 let hint = ctx.hint(self.puzzle_hash)?;
33 Ok(NotarizedPayment::new(
34 self.launcher_id,
35 vec![Payment::new(self.puzzle_hash, amount, hint)],
36 ))
37 }
38}
39
40pub fn calculate_trade_price_amounts(
41 amounts: &OfferAmounts,
42 royalty_nft_count: usize,
43) -> OfferAmounts {
44 if royalty_nft_count == 0 {
45 return OfferAmounts::new();
46 }
47
48 OfferAmounts {
49 xch: calculate_nft_trace_price(amounts.xch, royalty_nft_count),
50 cats: amounts
51 .cats
52 .iter()
53 .map(|(&asset_id, &amount)| {
54 let amount = calculate_nft_trace_price(amount, royalty_nft_count);
55 (asset_id, amount)
56 })
57 .collect(),
58 }
59}
60
61pub fn calculate_trade_prices(
62 trade_price_amounts: &OfferAmounts,
63 asset_info: &AssetInfo,
64) -> Vec<TradePrice> {
65 let mut trade_prices = Vec::new();
66
67 if trade_price_amounts.xch > 0 {
68 trade_prices.push(TradePrice::new(
69 trade_price_amounts.xch,
70 SETTLEMENT_PAYMENT_HASH.into(),
71 ));
72 }
73
74 for (&asset_id, &amount) in &trade_price_amounts.cats {
75 if amount == 0 {
76 continue;
77 }
78
79 let default = CatAssetInfo::default();
80 let info = asset_info.cat(asset_id).unwrap_or(&default);
81 let puzzle_hash = CatInfo::new(
82 asset_id,
83 info.hidden_puzzle_hash,
84 SETTLEMENT_PAYMENT_HASH.into(),
85 )
86 .puzzle_hash()
87 .into();
88
89 trade_prices.push(TradePrice::new(amount, puzzle_hash));
90 }
91
92 trade_prices
93}
94
95pub fn calculate_royalty_payments(
96 ctx: &mut SpendContext,
97 trade_prices: &OfferAmounts,
98 royalties: &[RoyaltyInfo],
99) -> Result<RequestedPayments, DriverError> {
100 let mut payments = RequestedPayments::new();
101
102 for royalty in royalties {
103 let amount = calculate_nft_royalty(trade_prices.xch, royalty.basis_points);
104
105 if amount > 0 {
106 payments.xch.push(royalty.payment(ctx, amount)?);
107 }
108
109 for (&asset_id, &amount) in &trade_prices.cats {
110 let amount = calculate_nft_royalty(amount, royalty.basis_points);
111
112 if amount > 0 {
113 payments
114 .cats
115 .entry(asset_id)
116 .or_default()
117 .push(royalty.payment(ctx, amount)?);
118 }
119 }
120 }
121
122 Ok(payments)
123}
124
125pub fn calculate_royalty_amounts(
126 trade_prices: &OfferAmounts,
127 royalties: &[RoyaltyInfo],
128) -> OfferAmounts {
129 let mut amounts = OfferAmounts::new();
130
131 for royalty in royalties {
132 amounts.xch = calculate_nft_royalty(trade_prices.xch, royalty.basis_points);
133
134 for (&asset_id, &amount) in &trade_prices.cats {
135 amounts.cats.insert(
136 asset_id,
137 calculate_nft_royalty(amount, royalty.basis_points),
138 );
139 }
140 }
141
142 amounts
143}
144
145pub fn calculate_nft_trace_price(amount: u64, royalty_nft_count: usize) -> u64 {
146 let amount = BigDecimal::from(amount);
147 let royalty_nft_count = BigDecimal::from(royalty_nft_count as u64);
148 floor(amount / royalty_nft_count)
149 .to_u64()
150 .expect("out of bounds")
151}
152
153pub fn calculate_nft_royalty(trade_price: u64, royalty_percentage: u16) -> u64 {
154 let trade_price = BigDecimal::from(trade_price);
155 let royalty_percentage = BigDecimal::from(royalty_percentage);
156 let percent = royalty_percentage / BigDecimal::from(10_000);
157 floor(trade_price * percent)
158 .to_u64()
159 .expect("out of bounds")
160}
161
162#[allow(clippy::needless_pass_by_value)]
163fn floor(amount: BigDecimal) -> BigDecimal {
164 amount.with_scale_round(0, RoundingMode::Floor)
165}