Skip to main content

kora_lib/fee/
price.rs

1use crate::{config::Config, error::KoraError, token::token::TokenUtil};
2use rust_decimal::{
3    prelude::{FromPrimitive, ToPrimitive},
4    Decimal,
5};
6use serde::{Deserialize, Serialize};
7use solana_client::nonblocking::rpc_client::RpcClient;
8use solana_sdk::pubkey::Pubkey;
9use std::str::FromStr;
10use utoipa::ToSchema;
11
12#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
13#[serde(tag = "type", rename_all = "lowercase")]
14pub enum PriceModel {
15    Margin { margin: f64 },
16    Fixed { amount: u64, token: String, strict: bool },
17    Free,
18}
19
20impl Default for PriceModel {
21    fn default() -> Self {
22        Self::Margin { margin: 0.0 }
23    }
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, Default)]
27pub struct PriceConfig {
28    #[serde(flatten)]
29    pub model: PriceModel,
30}
31
32impl PriceConfig {
33    pub async fn get_required_lamports_with_fixed(
34        &self,
35        rpc_client: &RpcClient,
36        config: &Config,
37    ) -> Result<u64, KoraError> {
38        if let PriceModel::Fixed { amount, token, .. } = &self.model {
39            return TokenUtil::calculate_token_value_in_lamports(
40                *amount,
41                &Pubkey::from_str(token).map_err(|e| {
42                    log::error!("Invalid Pubkey for price {e}");
43
44                    KoraError::ConfigError(
45                        "Invalid token address in fee config: failed to parse as Solana pubkey"
46                            .to_string(),
47                    )
48                })?,
49                rpc_client,
50                config,
51            )
52            .await;
53        }
54
55        Err(KoraError::ConfigError(
56            "Price model is not 'Fixed': cannot compute fixed fee".to_string(),
57        ))
58    }
59
60    pub async fn get_required_lamports_with_margin(
61        &self,
62        min_transaction_fee: u64,
63    ) -> Result<u64, KoraError> {
64        if let PriceModel::Margin { margin } = &self.model {
65            let margin_decimal = Decimal::from_f64(*margin)
66                .ok_or_else(|| KoraError::ValidationError("Invalid margin".to_string()))?;
67
68            let multiplier = Decimal::from_u64(1u64)
69                .and_then(|result| result.checked_add(margin_decimal))
70                .ok_or_else(|| {
71                    log::error!(
72                        "Multiplier calculation overflow: min_transaction_fee={}, margin={}",
73                        min_transaction_fee,
74                        margin,
75                    );
76                    KoraError::ValidationError("Multiplier calculation overflow".to_string())
77                })?;
78
79            let result = Decimal::from_u64(min_transaction_fee)
80                .and_then(|result| result.checked_mul(multiplier))
81                .ok_or_else(|| {
82                    log::error!(
83                        "Margin calculation overflow: min_transaction_fee={}, margin={}",
84                        min_transaction_fee,
85                        margin,
86                    );
87                    KoraError::ValidationError("Margin calculation overflow".to_string())
88                })?;
89
90            return result.ceil().to_u64().ok_or_else(|| {
91                log::error!(
92                    "Margin calculation overflow: min_transaction_fee={}, margin={}, result={}",
93                    min_transaction_fee,
94                    margin,
95                    result
96                );
97                KoraError::ValidationError("Margin calculation overflow".to_string())
98            });
99        }
100
101        Err(KoraError::ConfigError(
102            "Price model is not 'Margin': cannot compute margin fee".to_string(),
103        ))
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{
        common::create_mock_rpc_client_with_mint,
        config_mock::{mock_state::get_config, ConfigMockBuilder},
    };

    #[tokio::test]
    async fn test_margin_model_get_required_lamports() {
        // Test margin of 0.1 (10%) on a 5000 lamport base fee.
        let price_config = PriceConfig { model: PriceModel::Margin { margin: 0.1 } };

        let min_transaction_fee = 5000u64;
        // Spelled out rather than derived from `5000.0 * 1.1`: the implementation uses
        // Decimal arithmetic, so the expectation must not depend on f64 rounding.
        let expected_lamports = 5500u64;

        let result =
            price_config.get_required_lamports_with_margin(min_transaction_fee).await.unwrap();

        assert_eq!(result, expected_lamports);
    }

    #[tokio::test]
    async fn test_margin_model_get_required_lamports_zero_margin() {
        // Test margin of 0.0 (no margin)
        let price_config = PriceConfig { model: PriceModel::Margin { margin: 0.0 } };

        let min_transaction_fee = 5000u64;

        let result =
            price_config.get_required_lamports_with_margin(min_transaction_fee).await.unwrap();

        assert_eq!(result, min_transaction_fee);
    }

    #[tokio::test]
    async fn test_margin_model_rounds_up_fractional_lamports() {
        // 5001 * 1.5 = 7501.5, which must round up rather than truncate.
        let price_config = PriceConfig { model: PriceModel::Margin { margin: 0.5 } };

        let result = price_config.get_required_lamports_with_margin(5001).await.unwrap();

        assert_eq!(result, 7502);
    }

    #[tokio::test]
    async fn test_margin_model_rejects_nan_margin() {
        let price_config = PriceConfig { model: PriceModel::Margin { margin: f64::NAN } };

        let err = price_config.get_required_lamports_with_margin(5000).await.unwrap_err();

        assert!(matches!(err, KoraError::ValidationError(_)));
    }

    #[tokio::test]
    async fn test_margin_fee_requires_margin_model() {
        let price_config = PriceConfig { model: PriceModel::Free };

        let err = price_config.get_required_lamports_with_margin(5000).await.unwrap_err();

        assert!(matches!(err, KoraError::ConfigError(_)));
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_with_oracle() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6); // USDC has 6 decimals

        let usdc_mint = "4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU";
        let price_config = PriceConfig {
            model: PriceModel::Fixed {
                amount: 1_000_000, // 1 USDC (1,000,000 base units with 6 decimals)
                token: usdc_mint.to_string(),
                strict: false,
            },
        };

        // Use Mock price source which returns 0.0075 SOL per USDC

        let result =
            price_config.get_required_lamports_with_fixed(&rpc_client, &config).await.unwrap();

        // Expected calculation:
        // 1,000,000 base units / 10^6 = 1.0 USDC
        // 1.0 USDC * 0.0075 SOL/USDC = 0.0075 SOL
        // 0.0075 SOL * 1,000,000,000 lamports/SOL = 7,500,000 lamports
        assert_eq!(result, 7500000);
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_with_custom_price() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(9); // 9 decimals token

        let custom_token = "So11111111111111111111111111111111111111112"; // SOL mint
        let price_config = PriceConfig {
            model: PriceModel::Fixed {
                amount: 500000000, // 0.5 tokens (500,000,000 base units with 9 decimals)
                token: custom_token.to_string(),
                strict: false,
            },
        };

        // Mock oracle returns 1.0 SOL price for SOL mint

        let result =
            price_config.get_required_lamports_with_fixed(&rpc_client, &config).await.unwrap();

        // Expected calculation:
        // 500,000,000 base units / 10^9 = 0.5 tokens
        // 0.5 tokens * 1.0 SOL/token = 0.5 SOL
        // 0.5 SOL * 1,000,000,000 lamports/SOL = 500,000,000 lamports
        assert_eq!(result, 500000000);
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_small_amount() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6); // USDC has 6 decimals

        let usdc_mint = "4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU";
        let price_config = PriceConfig {
            model: PriceModel::Fixed {
                amount: 1000, // 0.001 USDC (1,000 base units with 6 decimals)
                token: usdc_mint.to_string(),
                strict: false,
            },
        };

        let result =
            price_config.get_required_lamports_with_fixed(&rpc_client, &config).await.unwrap();

        // Expected calculation:
        // 1,000 base units / 10^6 = 0.001 USDC
        // 0.001 USDC * 0.0075 SOL/USDC = 0.0000075 SOL
        // 0.0000075 SOL * 1,000,000,000 lamports/SOL = 7,500 lamports
        assert_eq!(result, 7500);
    }

    #[tokio::test]
    async fn test_fixed_fee_requires_fixed_model() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6);

        let price_config = PriceConfig { model: PriceModel::Margin { margin: 0.1 } };

        let err = price_config
            .get_required_lamports_with_fixed(&rpc_client, &config)
            .await
            .unwrap_err();

        assert!(matches!(err, KoraError::ConfigError(_)));
    }

    #[tokio::test]
    async fn test_fixed_model_rejects_invalid_token_address() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6);

        let price_config = PriceConfig {
            model: PriceModel::Fixed {
                amount: 1000,
                token: "not-a-pubkey".to_string(),
                strict: false,
            },
        };

        let err = price_config
            .get_required_lamports_with_fixed(&rpc_client, &config)
            .await
            .unwrap_err();

        assert!(matches!(err, KoraError::ConfigError(_)));
    }

    #[test]
    fn test_default_price_config() {
        // Test that default creates Margin with 0.0 margin
        let default_config = PriceConfig::default();

        match default_config.model {
            PriceModel::Margin { margin } => assert_eq!(margin, 0.0),
            _ => panic!("Default should be Margin with 0.0 margin"),
        }
    }
}