1use crate::{config::Config, error::KoraError, token::token::TokenUtil};
2use rust_decimal::{
3 prelude::{FromPrimitive, ToPrimitive},
4 Decimal,
5};
6use serde::{Deserialize, Serialize};
7use solana_client::nonblocking::rpc_client::RpcClient;
8use solana_sdk::pubkey::Pubkey;
9use std::str::FromStr;
10use utoipa::ToSchema;
11
/// Strategy for computing the lamport price charged for a transaction.
///
/// Serialized as an internally tagged enum: the variant name (lowercased) is stored under
/// the `"type"` key next to the variant's fields, e.g. `{"type": "margin", "margin": 0.1}`.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum PriceModel {
    /// Charge the minimum transaction fee scaled by `1 + margin`, rounded up to a whole
    /// lamport (see `PriceConfig::get_required_lamports_with_margin`).
    Margin { margin: f64 },
    /// Charge a fixed `amount` of the token whose mint address is `token` (base58 string),
    /// converted to lamports (see `PriceConfig::get_required_lamports_with_fixed`).
    /// NOTE(review): `amount` appears to be in the token's base units and `strict` is not
    /// read in this file; confirm both against the callers.
    Fixed { amount: u64, token: String, strict: bool },
    /// Presumably "no charge"; neither pricing method in this file handles this variant
    /// (both return `KoraError::ConfigError` for it). TODO confirm against callers.
    Free,
}
19
20impl Default for PriceModel {
21 fn default() -> Self {
22 Self::Margin { margin: 0.0 }
23 }
24}
25
/// Pricing configuration wrapping a single [`PriceModel`].
///
/// The model is `#[serde(flatten)]`-ed, so its `"type"` tag and fields appear directly at
/// this struct's level when (de)serialized rather than under a nested `model` key.
/// `Default` yields `PriceModel::default()`, a zero margin.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, Default)]
pub struct PriceConfig {
    #[serde(flatten)]
    pub model: PriceModel,
}
31
32impl PriceConfig {
33 pub async fn get_required_lamports_with_fixed(
34 &self,
35 rpc_client: &RpcClient,
36 config: &Config,
37 ) -> Result<u64, KoraError> {
38 if let PriceModel::Fixed { amount, token, .. } = &self.model {
39 return TokenUtil::calculate_token_value_in_lamports(
40 *amount,
41 &Pubkey::from_str(token).map_err(|e| {
42 log::error!("Invalid Pubkey for price {e}");
43
44 KoraError::ConfigError
45 })?,
46 rpc_client,
47 config,
48 )
49 .await;
50 }
51
52 Err(KoraError::ConfigError)
53 }
54
55 pub async fn get_required_lamports_with_margin(
56 &self,
57 min_transaction_fee: u64,
58 ) -> Result<u64, KoraError> {
59 if let PriceModel::Margin { margin } = &self.model {
60 let margin_decimal = Decimal::from_f64(*margin)
61 .ok_or_else(|| KoraError::ValidationError("Invalid margin".to_string()))?;
62
63 let multiplier = Decimal::from_u64(1u64)
64 .and_then(|result| result.checked_add(margin_decimal))
65 .ok_or_else(|| {
66 log::error!(
67 "Multiplier calculation overflow: min_transaction_fee={}, margin={}",
68 min_transaction_fee,
69 margin,
70 );
71 KoraError::ValidationError("Multiplier calculation overflow".to_string())
72 })?;
73
74 let result = Decimal::from_u64(min_transaction_fee)
75 .and_then(|result| result.checked_mul(multiplier))
76 .ok_or_else(|| {
77 log::error!(
78 "Margin calculation overflow: min_transaction_fee={}, margin={}",
79 min_transaction_fee,
80 margin,
81 );
82 KoraError::ValidationError("Margin calculation overflow".to_string())
83 })?;
84
85 return result.ceil().to_u64().ok_or_else(|| {
86 log::error!(
87 "Margin calculation overflow: min_transaction_fee={}, margin={}, result={}",
88 min_transaction_fee,
89 margin,
90 result
91 );
92 KoraError::ValidationError("Margin calculation overflow".to_string())
93 });
94 }
95
96 Err(KoraError::ConfigError)
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{
        common::create_mock_rpc_client_with_mint,
        config_mock::{mock_state::get_config, ConfigMockBuilder},
    };

    /// Mint used by the fixed-price tests that run against a 6-decimal mock mint.
    const USDC_MINT: &str = "4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU";
    /// Wrapped-SOL mint used by the 9-decimal fixed-price test.
    const WSOL_MINT: &str = "So11111111111111111111111111111111111111112";

    /// Builds a `PriceConfig` using the margin model.
    fn margin_config(margin: f64) -> PriceConfig {
        PriceConfig { model: PriceModel::Margin { margin } }
    }

    /// Builds a non-strict `PriceConfig` using the fixed model.
    fn fixed_config(amount: u64, token: &str) -> PriceConfig {
        PriceConfig {
            model: PriceModel::Fixed { amount, token: token.to_string(), strict: false },
        }
    }

    #[tokio::test]
    async fn test_margin_model_get_required_lamports() {
        let result = margin_config(0.1).get_required_lamports_with_margin(5000).await.unwrap();

        // 5000 * 1.1 is exactly 5500 in decimal arithmetic; no float truncation involved.
        assert_eq!(result, 5500);
    }

    #[tokio::test]
    async fn test_margin_model_get_required_lamports_zero_margin() {
        let min_transaction_fee = 5000u64;

        let result = margin_config(0.0)
            .get_required_lamports_with_margin(min_transaction_fee)
            .await
            .unwrap();

        assert_eq!(result, min_transaction_fee);
    }

    #[tokio::test]
    async fn test_margin_model_rounds_fractional_lamports_up() {
        // 5001 * 1.1 = 5501.1, which must round up rather than truncate.
        let result = margin_config(0.1).get_required_lamports_with_margin(5001).await.unwrap();

        assert_eq!(result, 5502);
    }

    #[tokio::test]
    async fn test_margin_model_zero_fee() {
        let result = margin_config(0.5).get_required_lamports_with_margin(0).await.unwrap();

        assert_eq!(result, 0);
    }

    #[tokio::test]
    async fn test_margin_model_rejects_nan_margin() {
        let result = margin_config(f64::NAN).get_required_lamports_with_margin(5000).await;

        assert!(matches!(result, Err(KoraError::ValidationError(_))));
    }

    #[tokio::test]
    async fn test_margin_model_rejects_negative_result() {
        // margin < -1 makes the multiplier negative, so the fee cannot be a u64.
        let result = margin_config(-2.0).get_required_lamports_with_margin(5000).await;

        assert!(matches!(result, Err(KoraError::ValidationError(_))));
    }

    #[tokio::test]
    async fn test_margin_method_rejects_non_margin_models() {
        let free = PriceConfig { model: PriceModel::Free };
        assert!(matches!(
            free.get_required_lamports_with_margin(5000).await,
            Err(KoraError::ConfigError)
        ));

        let fixed = fixed_config(1_000, USDC_MINT);
        assert!(matches!(
            fixed.get_required_lamports_with_margin(5000).await,
            Err(KoraError::ConfigError)
        ));
    }

    // Expected lamport values in the fixed-price tests depend on the token price and mint
    // decimals provided by the mocks (ConfigMockBuilder / create_mock_rpc_client_with_mint).

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_with_oracle() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6);

        let result = fixed_config(1_000_000, USDC_MINT)
            .get_required_lamports_with_fixed(&rpc_client, &config)
            .await
            .unwrap();

        assert_eq!(result, 7_500_000);
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_with_custom_price() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(9);

        let result = fixed_config(500_000_000, WSOL_MINT)
            .get_required_lamports_with_fixed(&rpc_client, &config)
            .await
            .unwrap();

        assert_eq!(result, 500_000_000);
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_small_amount() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6);

        let result = fixed_config(1_000, USDC_MINT)
            .get_required_lamports_with_fixed(&rpc_client, &config)
            .await
            .unwrap();

        assert_eq!(result, 7_500);
    }

    #[tokio::test]
    async fn test_fixed_model_invalid_token_pubkey() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6);

        let result = fixed_config(1_000, "not-a-valid-pubkey")
            .get_required_lamports_with_fixed(&rpc_client, &config)
            .await;

        assert!(matches!(result, Err(KoraError::ConfigError)));
    }

    #[tokio::test]
    async fn test_fixed_method_rejects_non_fixed_models() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let config = get_config().unwrap();
        let rpc_client = create_mock_rpc_client_with_mint(6);

        let margin = margin_config(0.1);
        assert!(matches!(
            margin.get_required_lamports_with_fixed(&rpc_client, &config).await,
            Err(KoraError::ConfigError)
        ));

        let free = PriceConfig { model: PriceModel::Free };
        assert!(matches!(
            free.get_required_lamports_with_fixed(&rpc_client, &config).await,
            Err(KoraError::ConfigError)
        ));
    }

    #[test]
    fn test_default_price_config() {
        let default_config = PriceConfig::default();

        match default_config.model {
            PriceModel::Margin { margin } => assert_eq!(margin, 0.0),
            _ => panic!("Default should be Margin with 0.0 margin"),
        }
    }
}