// kora_lib/fee/price.rs

1use crate::{config::Config, error::KoraError, token::token::TokenUtil};
2use rust_decimal::{
3    prelude::{FromPrimitive, ToPrimitive},
4    Decimal,
5};
6use serde::{Deserialize, Serialize};
7use solana_client::nonblocking::rpc_client::RpcClient;
8use solana_sdk::pubkey::Pubkey;
9use std::str::FromStr;
10use utoipa::ToSchema;
11
12#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
13#[serde(tag = "type", rename_all = "lowercase")]
14pub enum PriceModel {
15    Margin { margin: f64 },
16    Fixed { amount: u64, token: String, strict: bool },
17    Free,
18}
19
20impl Default for PriceModel {
21    fn default() -> Self {
22        Self::Margin { margin: 0.0 }
23    }
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, Default)]
27pub struct PriceConfig {
28    #[serde(flatten)]
29    pub model: PriceModel,
30}
31
32impl PriceConfig {
33    pub async fn get_required_lamports_with_fixed(
34        &self,
35        rpc_client: &RpcClient,
36        config: &Config,
37    ) -> Result<u64, KoraError> {
38        if let PriceModel::Fixed { amount, token, .. } = &self.model {
39            return TokenUtil::calculate_token_value_in_lamports(
40                *amount,
41                &Pubkey::from_str(token).map_err(|e| {
42                    log::error!("Invalid Pubkey for price {e}");
43
44                    KoraError::ConfigError
45                })?,
46                rpc_client,
47                config,
48            )
49            .await;
50        }
51
52        Err(KoraError::ConfigError)
53    }
54
55    pub async fn get_required_lamports_with_margin(
56        &self,
57        min_transaction_fee: u64,
58    ) -> Result<u64, KoraError> {
59        if let PriceModel::Margin { margin } = &self.model {
60            let margin_decimal = Decimal::from_f64(*margin)
61                .ok_or_else(|| KoraError::ValidationError("Invalid margin".to_string()))?;
62
63            let multiplier = Decimal::from_u64(1u64)
64                .and_then(|result| result.checked_add(margin_decimal))
65                .ok_or_else(|| {
66                    log::error!(
67                        "Multiplier calculation overflow: min_transaction_fee={}, margin={}",
68                        min_transaction_fee,
69                        margin,
70                    );
71                    KoraError::ValidationError("Multiplier calculation overflow".to_string())
72                })?;
73
74            let result = Decimal::from_u64(min_transaction_fee)
75                .and_then(|result| result.checked_mul(multiplier))
76                .ok_or_else(|| {
77                    log::error!(
78                        "Margin calculation overflow: min_transaction_fee={}, margin={}",
79                        min_transaction_fee,
80                        margin,
81                    );
82                    KoraError::ValidationError("Margin calculation overflow".to_string())
83                })?;
84
85            return result.ceil().to_u64().ok_or_else(|| {
86                log::error!(
87                    "Margin calculation overflow: min_transaction_fee={}, margin={}, result={}",
88                    min_transaction_fee,
89                    margin,
90                    result
91                );
92                KoraError::ValidationError("Margin calculation overflow".to_string())
93            });
94        }
95
96        Err(KoraError::ConfigError)
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{
        common::create_mock_rpc_client_with_mint,
        config_mock::{mock_state::get_config, ConfigMockBuilder},
    };

    // Mint used by the fixed-price tests. The expectations below assume the mock price
    // source quotes it at 0.0075 SOL per whole token.
    const USDC_MINT: &str = "4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU";

    #[tokio::test]
    async fn test_margin_model_get_required_lamports() {
        // 10% margin on a 5000-lamport base fee. The expectation is a literal. The
        // implementation uses exact decimal arithmetic plus `ceil`, whereas an f64 expression
        // like `(5000.0 * 1.1) as u64` truncates, so the two only agree by coincidence.
        let price_config = PriceConfig { model: PriceModel::Margin { margin: 0.1 } };

        let result = price_config.get_required_lamports_with_margin(5000).await.unwrap();

        assert_eq!(result, 5500);
    }

    #[tokio::test]
    async fn test_margin_model_get_required_lamports_zero_margin() {
        let price_config = PriceConfig { model: PriceModel::Margin { margin: 0.0 } };

        let min_transaction_fee = 5000u64;

        let result =
            price_config.get_required_lamports_with_margin(min_transaction_fee).await.unwrap();

        assert_eq!(result, min_transaction_fee);
    }

    #[tokio::test]
    async fn test_margin_model_rounds_fractional_lamports_up() {
        let price_config = PriceConfig { model: PriceModel::Margin { margin: 0.5 } };

        // 0 * 1.5 = 0; 1 * 1.5 = 1.5 -> 2; 3 * 1.5 = 4.5 -> 5.
        assert_eq!(price_config.get_required_lamports_with_margin(0).await.unwrap(), 0);
        assert_eq!(price_config.get_required_lamports_with_margin(1).await.unwrap(), 2);
        assert_eq!(price_config.get_required_lamports_with_margin(3).await.unwrap(), 5);
    }

    #[tokio::test]
    async fn test_margin_model_rejects_nan_margin() {
        let price_config = PriceConfig { model: PriceModel::Margin { margin: f64::NAN } };

        let result = price_config.get_required_lamports_with_margin(5000).await;

        assert!(matches!(result, Err(KoraError::ValidationError(_))));
    }

    #[tokio::test]
    async fn test_margin_method_rejects_non_margin_model() {
        let price_config = PriceConfig { model: PriceModel::Free };

        let result = price_config.get_required_lamports_with_margin(5000).await;

        assert!(matches!(result, Err(KoraError::ConfigError)));
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_with_oracle() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6); // USDC has 6 decimals

        let price_config = PriceConfig {
            model: PriceModel::Fixed {
                amount: 1_000_000, // 1 USDC (1,000,000 base units with 6 decimals)
                token: USDC_MINT.to_string(),
                strict: false,
            },
        };

        let result =
            price_config.get_required_lamports_with_fixed(&rpc_client, &config).await.unwrap();

        // 1,000,000 base units / 10^6 = 1.0 USDC
        // 1.0 USDC * 0.0075 SOL/USDC = 0.0075 SOL
        // 0.0075 SOL * 1,000,000,000 lamports/SOL = 7,500,000 lamports
        assert_eq!(result, 7_500_000);
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_with_custom_price() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(9); // 9 decimals token

        let custom_token = "So11111111111111111111111111111111111111112"; // SOL mint
        let price_config = PriceConfig {
            model: PriceModel::Fixed {
                amount: 500_000_000, // 0.5 tokens (500,000,000 base units with 9 decimals)
                token: custom_token.to_string(),
                strict: false,
            },
        };

        // The mock oracle prices the SOL mint at 1.0 SOL.
        let result =
            price_config.get_required_lamports_with_fixed(&rpc_client, &config).await.unwrap();

        // 500,000,000 base units / 10^9 = 0.5 tokens
        // 0.5 tokens * 1.0 SOL/token = 0.5 SOL
        // 0.5 SOL * 1,000,000,000 lamports/SOL = 500,000,000 lamports
        assert_eq!(result, 500_000_000);
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_small_amount() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6); // USDC has 6 decimals

        let price_config = PriceConfig {
            model: PriceModel::Fixed {
                amount: 1_000, // 0.001 USDC (1,000 base units with 6 decimals)
                token: USDC_MINT.to_string(),
                strict: false,
            },
        };

        let result =
            price_config.get_required_lamports_with_fixed(&rpc_client, &config).await.unwrap();

        // 1,000 base units / 10^6 = 0.001 USDC
        // 0.001 USDC * 0.0075 SOL/USDC = 0.0000075 SOL
        // 0.0000075 SOL * 1,000,000,000 lamports/SOL = 7,500 lamports
        assert_eq!(result, 7_500);
    }

    #[tokio::test]
    async fn test_fixed_method_rejects_non_fixed_model() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6);

        // The default model is `Margin`, which the fixed-price path must refuse.
        let price_config = PriceConfig::default();

        let result = price_config.get_required_lamports_with_fixed(&rpc_client, &config).await;

        assert!(matches!(result, Err(KoraError::ConfigError)));
    }

    #[tokio::test]
    async fn test_fixed_model_rejects_invalid_token_pubkey() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6);

        let price_config = PriceConfig {
            model: PriceModel::Fixed {
                amount: 1_000_000,
                token: "not-a-pubkey".to_string(),
                strict: false,
            },
        };

        let result = price_config.get_required_lamports_with_fixed(&rpc_client, &config).await;

        assert!(matches!(result, Err(KoraError::ConfigError)));
    }

    #[test]
    fn test_default_price_config() {
        // The default must be a Margin model with a 0.0 margin.
        match PriceConfig::default().model {
            PriceModel::Margin { margin } => assert_eq!(margin, 0.0),
            other => panic!("Default should be Margin with 0.0 margin, got {other:?}"),
        }
    }
}