1use crate::{config::Config, error::KoraError, token::token::TokenUtil};
2use rust_decimal::{
3 prelude::{FromPrimitive, ToPrimitive},
4 Decimal,
5};
6use serde::{Deserialize, Serialize};
7use solana_client::nonblocking::rpc_client::RpcClient;
8use solana_sdk::pubkey::Pubkey;
9use std::str::FromStr;
10use utoipa::ToSchema;
11
12#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
13#[serde(tag = "type", rename_all = "lowercase")]
14pub enum PriceModel {
15 Margin { margin: f64 },
16 Fixed { amount: u64, token: String, strict: bool },
17 Free,
18}
19
20impl Default for PriceModel {
21 fn default() -> Self {
22 Self::Margin { margin: 0.0 }
23 }
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, Default)]
27pub struct PriceConfig {
28 #[serde(flatten)]
29 pub model: PriceModel,
30}
31
32impl PriceConfig {
33 pub async fn get_required_lamports_with_fixed(
34 &self,
35 rpc_client: &RpcClient,
36 config: &Config,
37 ) -> Result<u64, KoraError> {
38 if let PriceModel::Fixed { amount, token, .. } = &self.model {
39 return TokenUtil::calculate_token_value_in_lamports(
40 *amount,
41 &Pubkey::from_str(token).map_err(|e| {
42 log::error!("Invalid Pubkey for price {e}");
43
44 KoraError::ConfigError(
45 "Invalid token address in fee config: failed to parse as Solana pubkey"
46 .to_string(),
47 )
48 })?,
49 rpc_client,
50 config,
51 )
52 .await;
53 }
54
55 Err(KoraError::ConfigError(
56 "Price model is not 'Fixed': cannot compute fixed fee".to_string(),
57 ))
58 }
59
60 pub async fn get_required_lamports_with_margin(
61 &self,
62 min_transaction_fee: u64,
63 ) -> Result<u64, KoraError> {
64 if let PriceModel::Margin { margin } = &self.model {
65 let margin_decimal = Decimal::from_f64(*margin)
66 .ok_or_else(|| KoraError::ValidationError("Invalid margin".to_string()))?;
67
68 let multiplier = Decimal::from_u64(1u64)
69 .and_then(|result| result.checked_add(margin_decimal))
70 .ok_or_else(|| {
71 log::error!(
72 "Multiplier calculation overflow: min_transaction_fee={}, margin={}",
73 min_transaction_fee,
74 margin,
75 );
76 KoraError::ValidationError("Multiplier calculation overflow".to_string())
77 })?;
78
79 let result = Decimal::from_u64(min_transaction_fee)
80 .and_then(|result| result.checked_mul(multiplier))
81 .ok_or_else(|| {
82 log::error!(
83 "Margin calculation overflow: min_transaction_fee={}, margin={}",
84 min_transaction_fee,
85 margin,
86 );
87 KoraError::ValidationError("Margin calculation overflow".to_string())
88 })?;
89
90 return result.ceil().to_u64().ok_or_else(|| {
91 log::error!(
92 "Margin calculation overflow: min_transaction_fee={}, margin={}, result={}",
93 min_transaction_fee,
94 margin,
95 result
96 );
97 KoraError::ValidationError("Margin calculation overflow".to_string())
98 });
99 }
100
101 Err(KoraError::ConfigError(
102 "Price model is not 'Margin': cannot compute margin fee".to_string(),
103 ))
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{
        common::create_mock_rpc_client_with_mint,
        config_mock::{mock_state::get_config, ConfigMockBuilder},
    };

    /// Token priced against a 6-decimal mock mint in the fixed-price tests.
    const USDC_MINT: &str = "4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU";
    /// Native SOL mint, priced against a 9-decimal mock mint.
    const SOL_MINT: &str = "So11111111111111111111111111111111111111112";

    /// Builds a non-strict fixed-price configuration.
    fn fixed_price(amount: u64, token: &str) -> PriceConfig {
        PriceConfig { model: PriceModel::Fixed { amount, token: token.to_string(), strict: false } }
    }

    /// Builds a margin-based configuration.
    fn margin_price(margin: f64) -> PriceConfig {
        PriceConfig { model: PriceModel::Margin { margin } }
    }

    #[tokio::test]
    async fn test_margin_model_get_required_lamports() {
        let price_config = margin_price(0.1);

        let result = price_config.get_required_lamports_with_margin(5000).await.unwrap();

        // 5000 * 1.1, stated as a literal rather than recomputed with f64 arithmetic.
        assert_eq!(result, 5500);
    }

    #[tokio::test]
    async fn test_margin_model_get_required_lamports_zero_margin() {
        let price_config = margin_price(0.0);
        let min_transaction_fee = 5000u64;

        let result =
            price_config.get_required_lamports_with_margin(min_transaction_fee).await.unwrap();

        assert_eq!(result, min_transaction_fee);
    }

    #[tokio::test]
    async fn test_margin_model_rounds_up_fractional_lamports() {
        let price_config = margin_price(0.5);

        // 1001 * 1.5 = 1501.5, which must round up.
        let result = price_config.get_required_lamports_with_margin(1001).await.unwrap();

        assert_eq!(result, 1502);
    }

    #[tokio::test]
    async fn test_margin_model_rejects_nan_margin() {
        let price_config = margin_price(f64::NAN);

        let result = price_config.get_required_lamports_with_margin(5000).await;

        assert!(matches!(result, Err(KoraError::ValidationError(_))));
    }

    #[tokio::test]
    async fn test_margin_model_rejects_negative_fee() {
        // A multiplier of 1 + (-2.0) = -1.0 would produce a negative fee.
        let price_config = margin_price(-2.0);

        let result = price_config.get_required_lamports_with_margin(5000).await;

        assert!(matches!(result, Err(KoraError::ValidationError(_))));
    }

    #[tokio::test]
    async fn test_margin_fee_requires_margin_model() {
        let price_config = PriceConfig { model: PriceModel::Free };

        let result = price_config.get_required_lamports_with_margin(5000).await;

        assert!(matches!(result, Err(KoraError::ConfigError(_))));
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_with_oracle() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6);
        let price_config = fixed_price(1_000_000, USDC_MINT);

        let result =
            price_config.get_required_lamports_with_fixed(&rpc_client, &config).await.unwrap();

        assert_eq!(result, 7_500_000);
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_with_custom_price() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(9);
        let price_config = fixed_price(500_000_000, SOL_MINT);

        let result =
            price_config.get_required_lamports_with_fixed(&rpc_client, &config).await.unwrap();

        assert_eq!(result, 500_000_000);
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_small_amount() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6);
        let price_config = fixed_price(1_000, USDC_MINT);

        let result =
            price_config.get_required_lamports_with_fixed(&rpc_client, &config).await.unwrap();

        assert_eq!(result, 7_500);
    }

    #[tokio::test]
    async fn test_fixed_model_rejects_invalid_token_address() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6);
        let price_config = fixed_price(1_000, "not-a-pubkey");

        let result = price_config.get_required_lamports_with_fixed(&rpc_client, &config).await;

        assert!(matches!(result, Err(KoraError::ConfigError(_))));
    }

    #[tokio::test]
    async fn test_fixed_fee_requires_fixed_model() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6);
        let price_config = PriceConfig::default();

        let result = price_config.get_required_lamports_with_fixed(&rpc_client, &config).await;

        assert!(matches!(result, Err(KoraError::ConfigError(_))));
    }

    #[test]
    fn test_default_price_config() {
        let default_config = PriceConfig::default();

        match default_config.model {
            PriceModel::Margin { margin } => assert_eq!(margin, 0.0),
            _ => panic!("Default should be Margin with 0.0 margin"),
        }
    }
}