Skip to main content

kora_lib/fee/
price.rs

1use crate::{error::KoraError, oracle::PriceSource, token::token::TokenUtil};
2use rust_decimal::{
3    prelude::{FromPrimitive, ToPrimitive},
4    Decimal,
5};
6use serde::{Deserialize, Serialize};
7use solana_client::nonblocking::rpc_client::RpcClient;
8use solana_sdk::pubkey::Pubkey;
9use std::str::FromStr;
10use utoipa::ToSchema;
11
12#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
13#[serde(tag = "type", rename_all = "lowercase")]
14pub enum PriceModel {
15    Margin { margin: f64 },
16    Fixed { amount: u64, token: String, strict: bool },
17    Free,
18}
19
20impl Default for PriceModel {
21    fn default() -> Self {
22        Self::Margin { margin: 0.0 }
23    }
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, Default)]
27pub struct PriceConfig {
28    #[serde(flatten)]
29    pub model: PriceModel,
30}
31
32impl PriceConfig {
33    pub async fn get_required_lamports_with_fixed(
34        &self,
35        rpc_client: &RpcClient,
36        price_source: PriceSource,
37    ) -> Result<u64, KoraError> {
38        if let PriceModel::Fixed { amount, token, .. } = &self.model {
39            return TokenUtil::calculate_token_value_in_lamports(
40                *amount,
41                &Pubkey::from_str(token).map_err(|e| {
42                    log::error!("Invalid Pubkey for price {e}");
43
44                    KoraError::ConfigError
45                })?,
46                price_source,
47                rpc_client,
48            )
49            .await;
50        }
51
52        Err(KoraError::ConfigError)
53    }
54
55    pub async fn get_required_lamports_with_margin(
56        &self,
57        min_transaction_fee: u64,
58    ) -> Result<u64, KoraError> {
59        if let PriceModel::Margin { margin } = &self.model {
60            let margin_decimal = Decimal::from_f64(*margin)
61                .ok_or_else(|| KoraError::ValidationError("Invalid margin".to_string()))?;
62
63            let multiplier = Decimal::from_u64(1u64)
64                .and_then(|result| result.checked_add(margin_decimal))
65                .ok_or_else(|| {
66                    log::error!(
67                        "Multiplier calculation overflow: min_transaction_fee={}, margin={}",
68                        min_transaction_fee,
69                        margin,
70                    );
71                    KoraError::ValidationError("Multiplier calculation overflow".to_string())
72                })?;
73
74            let result = Decimal::from_u64(min_transaction_fee)
75                .and_then(|result| result.checked_mul(multiplier))
76                .ok_or_else(|| {
77                    log::error!(
78                        "Margin calculation overflow: min_transaction_fee={}, margin={}",
79                        min_transaction_fee,
80                        margin,
81                    );
82                    KoraError::ValidationError("Margin calculation overflow".to_string())
83                })?;
84
85            return result.ceil().to_u64().ok_or_else(|| {
86                log::error!(
87                    "Margin calculation overflow: min_transaction_fee={}, margin={}, result={}",
88                    min_transaction_fee,
89                    margin,
90                    result
91                );
92                KoraError::ValidationError("Margin calculation overflow".to_string())
93            });
94        }
95
96        Err(KoraError::ConfigError)
97    }
98}
99
#[cfg(test)]
mod tests {

    use super::*;
    use crate::tests::{common::create_mock_rpc_client_with_mint, config_mock::ConfigMockBuilder};

    // Mint addresses used by the fixed-price tests.
    const USDC_MINT: &str = "4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU";
    const SOL_MINT: &str = "So11111111111111111111111111111111111111112";

    /// Builds a margin-based price configuration.
    fn margin_config(margin: f64) -> PriceConfig {
        PriceConfig { model: PriceModel::Margin { margin } }
    }

    /// Builds a non-strict fixed price configuration for `token`.
    fn fixed_config(amount: u64, token: &str) -> PriceConfig {
        PriceConfig {
            model: PriceModel::Fixed { amount, token: token.to_string(), strict: false },
        }
    }

    #[tokio::test]
    async fn test_margin_model_get_required_lamports() {
        // A 10% margin on a 5000-lamport base fee should come out at 5500 lamports.
        let config = margin_config(0.1);
        let base_fee = 5000u64;
        let expected_lamports = (5000.0 * 1.1) as u64;

        let lamports = config.get_required_lamports_with_margin(base_fee).await.unwrap();

        assert_eq!(lamports, expected_lamports);
    }

    #[tokio::test]
    async fn test_margin_model_get_required_lamports_zero_margin() {
        // With no margin the base fee must pass through untouched.
        let config = margin_config(0.0);
        let base_fee = 5000u64;

        let lamports = config.get_required_lamports_with_margin(base_fee).await.unwrap();

        assert_eq!(lamports, base_fee);
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_with_oracle() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        // Mock mint with 6 decimals, like USDC.
        let rpc_client = create_mock_rpc_client_with_mint(6);

        // 1 USDC = 1,000,000 base units at 6 decimals.
        let config = fixed_config(1_000_000, USDC_MINT);

        let lamports = config
            .get_required_lamports_with_fixed(&rpc_client, PriceSource::Mock)
            .await
            .unwrap();

        // The mock price source is expected to quote 0.0001 SOL per USDC:
        //   1,000,000 base units / 10^6        = 1.0 USDC
        //   1.0 USDC * 0.0001 SOL/USDC         = 0.0001 SOL
        //   0.0001 SOL * 10^9 lamports per SOL = 100,000 lamports
        assert_eq!(lamports, 100000);
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_with_custom_price() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        // Mock mint with 9 decimals.
        let rpc_client = create_mock_rpc_client_with_mint(9);

        // 0.5 tokens = 500,000,000 base units at 9 decimals.
        let config = fixed_config(500000000, SOL_MINT);

        let lamports = config
            .get_required_lamports_with_fixed(&rpc_client, PriceSource::Mock)
            .await
            .unwrap();

        // The mock price source is expected to quote 1.0 SOL for the SOL mint:
        //   500,000,000 base units / 10^9   = 0.5 tokens
        //   0.5 tokens * 1.0 SOL/token      = 0.5 SOL
        //   0.5 SOL * 10^9 lamports per SOL = 500,000,000 lamports
        assert_eq!(lamports, 500000000);
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_small_amount() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        // Mock mint with 6 decimals, like USDC.
        let rpc_client = create_mock_rpc_client_with_mint(6);

        // 0.001 USDC = 1,000 base units at 6 decimals.
        let config = fixed_config(1000, USDC_MINT);

        let lamports = config
            .get_required_lamports_with_fixed(&rpc_client, PriceSource::Mock)
            .await
            .unwrap();

        // With the same 0.0001 SOL per USDC mock quote:
        //   1,000 base units / 10^6               = 0.001 USDC
        //   0.001 USDC * 0.0001 SOL/USDC          = 0.0000001 SOL
        //   0.0000001 SOL * 10^9 lamports per SOL = 100 lamports
        assert_eq!(lamports, 100);
    }

    #[tokio::test]
    async fn test_default_price_config() {
        // The default configuration must be margin pricing with a 0.0 margin.
        if let PriceModel::Margin { margin } = PriceConfig::default().model {
            assert_eq!(margin, 0.0);
        } else {
            panic!("Default should be Margin with 0.0 margin");
        }
    }
}