chia_sdk_driver/offers/royalty.rs

1use bigdecimal::{BigDecimal, RoundingMode, ToPrimitive};
2use chia_protocol::Bytes32;
3use chia_puzzle_types::offer::{NotarizedPayment, Payment};
4use chia_puzzles::SETTLEMENT_PAYMENT_HASH;
5use chia_sdk_types::conditions::TradePrice;
6
7use crate::{
8    AssetInfo, CatAssetInfo, CatInfo, DriverError, OfferAmounts, RequestedPayments, SpendContext,
9};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq)]
12pub struct RoyaltyInfo {
13    pub launcher_id: Bytes32,
14    pub puzzle_hash: Bytes32,
15    pub basis_points: u16,
16}
17
18impl RoyaltyInfo {
19    pub fn new(launcher_id: Bytes32, puzzle_hash: Bytes32, basis_points: u16) -> Self {
20        Self {
21            launcher_id,
22            puzzle_hash,
23            basis_points,
24        }
25    }
26
27    pub fn payment(
28        &self,
29        ctx: &mut SpendContext,
30        amount: u64,
31    ) -> Result<NotarizedPayment, DriverError> {
32        let hint = ctx.hint(self.puzzle_hash)?;
33        Ok(NotarizedPayment::new(
34            self.launcher_id,
35            vec![Payment::new(self.puzzle_hash, amount, hint)],
36        ))
37    }
38}
39
40pub fn calculate_trade_price_amounts(
41    amounts: &OfferAmounts,
42    royalty_nft_count: usize,
43) -> OfferAmounts {
44    if royalty_nft_count == 0 {
45        return OfferAmounts::new();
46    }
47
48    OfferAmounts {
49        xch: calculate_nft_trace_price(amounts.xch, royalty_nft_count),
50        cats: amounts
51            .cats
52            .iter()
53            .map(|(&asset_id, &amount)| {
54                let amount = calculate_nft_trace_price(amount, royalty_nft_count);
55                (asset_id, amount)
56            })
57            .collect(),
58    }
59}
60
61pub fn calculate_trade_prices(
62    trade_price_amounts: &OfferAmounts,
63    asset_info: &AssetInfo,
64) -> Vec<TradePrice> {
65    let mut trade_prices = Vec::new();
66
67    if trade_price_amounts.xch > 0 {
68        trade_prices.push(TradePrice::new(
69            trade_price_amounts.xch,
70            SETTLEMENT_PAYMENT_HASH.into(),
71        ));
72    }
73
74    for (&asset_id, &amount) in &trade_price_amounts.cats {
75        if amount == 0 {
76            continue;
77        }
78
79        let default = CatAssetInfo::default();
80        let info = asset_info.cat(asset_id).unwrap_or(&default);
81        let puzzle_hash = CatInfo::new(
82            asset_id,
83            info.hidden_puzzle_hash,
84            SETTLEMENT_PAYMENT_HASH.into(),
85        )
86        .puzzle_hash()
87        .into();
88
89        trade_prices.push(TradePrice::new(amount, puzzle_hash));
90    }
91
92    trade_prices
93}
94
95pub fn calculate_royalty_payments(
96    ctx: &mut SpendContext,
97    trade_prices: &OfferAmounts,
98    royalties: &[RoyaltyInfo],
99) -> Result<RequestedPayments, DriverError> {
100    let mut payments = RequestedPayments::new();
101
102    for royalty in royalties {
103        let amount = calculate_nft_royalty(trade_prices.xch, royalty.basis_points);
104
105        if amount > 0 {
106            payments.xch.push(royalty.payment(ctx, amount)?);
107        }
108
109        for (&asset_id, &amount) in &trade_prices.cats {
110            let amount = calculate_nft_royalty(amount, royalty.basis_points);
111
112            if amount > 0 {
113                payments
114                    .cats
115                    .entry(asset_id)
116                    .or_default()
117                    .push(royalty.payment(ctx, amount)?);
118            }
119        }
120    }
121
122    Ok(payments)
123}
124
125pub fn calculate_royalty_amounts(
126    trade_prices: &OfferAmounts,
127    royalties: &[RoyaltyInfo],
128) -> OfferAmounts {
129    let mut amounts = OfferAmounts::new();
130
131    for royalty in royalties {
132        amounts.xch = calculate_nft_royalty(trade_prices.xch, royalty.basis_points);
133
134        for (&asset_id, &amount) in &trade_prices.cats {
135            amounts.cats.insert(
136                asset_id,
137                calculate_nft_royalty(amount, royalty.basis_points),
138            );
139        }
140    }
141
142    amounts
143}
144
/// Returns the per-NFT trade price: `amount` split evenly across `royalty_nft_count`
/// royalty-bearing NFTs, rounded down.
///
/// Unsigned integer division already floors, so no decimal arithmetic is needed, and
/// the quotient can never exceed `amount` (it always fits in `u64`). The function name
/// says "trace" but this computes a trade price; the name is kept for API compatibility.
///
/// # Panics
///
/// Panics if `royalty_nft_count` is zero (division by zero).
pub fn calculate_nft_trace_price(amount: u64, royalty_nft_count: usize) -> u64 {
    // A count too large for `u64` necessarily exceeds `amount`, so the floored share is 0.
    u64::try_from(royalty_nft_count).map_or(0, |count| amount / count)
}
152
153pub fn calculate_nft_royalty(trade_price: u64, royalty_percentage: u16) -> u64 {
154    let trade_price = BigDecimal::from(trade_price);
155    let royalty_percentage = BigDecimal::from(royalty_percentage);
156    let percent = royalty_percentage / BigDecimal::from(10_000);
157    floor(trade_price * percent)
158        .to_u64()
159        .expect("out of bounds")
160}
161
162#[allow(clippy::needless_pass_by_value)]
163fn floor(amount: BigDecimal) -> BigDecimal {
164    amount.with_scale_round(0, RoundingMode::Floor)
165}