1use crate::{error::KoraError, oracle::PriceSource, token::token::TokenUtil};
2use rust_decimal::{
3 prelude::{FromPrimitive, ToPrimitive},
4 Decimal,
5};
6use serde::{Deserialize, Serialize};
7use solana_client::nonblocking::rpc_client::RpcClient;
8use solana_sdk::pubkey::Pubkey;
9use std::str::FromStr;
10use utoipa::ToSchema;
11
12#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
13#[serde(tag = "type", rename_all = "lowercase")]
14pub enum PriceModel {
15 Margin { margin: f64 },
16 Fixed { amount: u64, token: String, strict: bool },
17 Free,
18}
19
20impl Default for PriceModel {
21 fn default() -> Self {
22 Self::Margin { margin: 0.0 }
23 }
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, Default)]
27pub struct PriceConfig {
28 #[serde(flatten)]
29 pub model: PriceModel,
30}
31
32impl PriceConfig {
33 pub async fn get_required_lamports_with_fixed(
34 &self,
35 rpc_client: &RpcClient,
36 price_source: PriceSource,
37 ) -> Result<u64, KoraError> {
38 if let PriceModel::Fixed { amount, token, .. } = &self.model {
39 return TokenUtil::calculate_token_value_in_lamports(
40 *amount,
41 &Pubkey::from_str(token).map_err(|e| {
42 log::error!("Invalid Pubkey for price {e}");
43
44 KoraError::ConfigError
45 })?,
46 price_source,
47 rpc_client,
48 )
49 .await;
50 }
51
52 Err(KoraError::ConfigError)
53 }
54
55 pub async fn get_required_lamports_with_margin(
56 &self,
57 min_transaction_fee: u64,
58 ) -> Result<u64, KoraError> {
59 if let PriceModel::Margin { margin } = &self.model {
60 let margin_decimal = Decimal::from_f64(*margin)
61 .ok_or_else(|| KoraError::ValidationError("Invalid margin".to_string()))?;
62
63 let multiplier = Decimal::from_u64(1u64)
64 .and_then(|result| result.checked_add(margin_decimal))
65 .ok_or_else(|| {
66 log::error!(
67 "Multiplier calculation overflow: min_transaction_fee={}, margin={}",
68 min_transaction_fee,
69 margin,
70 );
71 KoraError::ValidationError("Multiplier calculation overflow".to_string())
72 })?;
73
74 let result = Decimal::from_u64(min_transaction_fee)
75 .and_then(|result| result.checked_mul(multiplier))
76 .ok_or_else(|| {
77 log::error!(
78 "Margin calculation overflow: min_transaction_fee={}, margin={}",
79 min_transaction_fee,
80 margin,
81 );
82 KoraError::ValidationError("Margin calculation overflow".to_string())
83 })?;
84
85 return result.ceil().to_u64().ok_or_else(|| {
86 log::error!(
87 "Margin calculation overflow: min_transaction_fee={}, margin={}, result={}",
88 min_transaction_fee,
89 margin,
90 result
91 );
92 KoraError::ValidationError("Margin calculation overflow".to_string())
93 });
94 }
95
96 Err(KoraError::ConfigError)
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{common::create_mock_rpc_client_with_mint, config_mock::ConfigMockBuilder};

    const USDC_MINT: &str = "4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU";
    const WSOL_MINT: &str = "So11111111111111111111111111111111111111112";

    /// Builds a non-strict fixed-price config for `amount` units of `token`.
    fn fixed_config(amount: u64, token: &str) -> PriceConfig {
        PriceConfig { model: PriceModel::Fixed { amount, token: token.to_string(), strict: false } }
    }

    #[tokio::test]
    async fn test_margin_model_get_required_lamports() {
        let price_config = PriceConfig { model: PriceModel::Margin { margin: 0.1 } };

        let result = price_config.get_required_lamports_with_margin(5000).await.unwrap();

        // 5000 * 1.1 is exactly 5500 in Decimal arithmetic. Asserting the literal avoids
        // tying the expectation to how f64 happens to round `5000.0 * 1.1`.
        assert_eq!(result, 5500);
    }

    #[tokio::test]
    async fn test_margin_model_get_required_lamports_zero_margin() {
        let price_config = PriceConfig { model: PriceModel::Margin { margin: 0.0 } };
        let min_transaction_fee = 5000u64;

        let result =
            price_config.get_required_lamports_with_margin(min_transaction_fee).await.unwrap();

        assert_eq!(result, min_transaction_fee);
    }

    #[tokio::test]
    async fn test_margin_model_rounds_fractional_lamports_up() {
        let price_config = PriceConfig { model: PriceModel::Margin { margin: 0.5 } };

        // 3 * 1.5 = 4.5 must be rounded up, not truncated.
        let result = price_config.get_required_lamports_with_margin(3).await.unwrap();

        assert_eq!(result, 5);
    }

    #[tokio::test]
    async fn test_margin_model_rejects_non_finite_margin() {
        let price_config = PriceConfig { model: PriceModel::Margin { margin: f64::NAN } };

        let result = price_config.get_required_lamports_with_margin(5000).await;

        assert!(matches!(result, Err(KoraError::ValidationError(_))));
    }

    #[tokio::test]
    async fn test_margin_method_rejects_non_margin_model() {
        let price_config = PriceConfig { model: PriceModel::Free };

        let result = price_config.get_required_lamports_with_margin(5000).await;

        assert!(matches!(result, Err(KoraError::ConfigError)));
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_with_oracle() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let rpc_client = create_mock_rpc_client_with_mint(6);
        let price_config = fixed_config(1_000_000, USDC_MINT);
        let price_source = PriceSource::Mock;

        let result =
            price_config.get_required_lamports_with_fixed(&rpc_client, price_source).await.unwrap();

        assert_eq!(result, 100_000);
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_with_custom_price() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let rpc_client = create_mock_rpc_client_with_mint(9);
        let price_config = fixed_config(500_000_000, WSOL_MINT);
        let price_source = PriceSource::Mock;

        let result =
            price_config.get_required_lamports_with_fixed(&rpc_client, price_source).await.unwrap();

        assert_eq!(result, 500_000_000);
    }

    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_small_amount() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let rpc_client = create_mock_rpc_client_with_mint(6);
        let price_config = fixed_config(1000, USDC_MINT);
        let price_source = PriceSource::Mock;

        let result =
            price_config.get_required_lamports_with_fixed(&rpc_client, price_source).await.unwrap();

        assert_eq!(result, 100);
    }

    #[tokio::test]
    async fn test_fixed_model_rejects_invalid_token_pubkey() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let rpc_client = create_mock_rpc_client_with_mint(6);
        let price_config = fixed_config(1000, "not-a-pubkey");
        let price_source = PriceSource::Mock;

        let result = price_config.get_required_lamports_with_fixed(&rpc_client, price_source).await;

        assert!(matches!(result, Err(KoraError::ConfigError)));
    }

    #[tokio::test]
    async fn test_fixed_method_rejects_non_fixed_model() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let rpc_client = create_mock_rpc_client_with_mint(6);
        let price_config = PriceConfig::default();
        let price_source = PriceSource::Mock;

        let result = price_config.get_required_lamports_with_fixed(&rpc_client, price_source).await;

        assert!(matches!(result, Err(KoraError::ConfigError)));
    }

    #[test]
    fn test_default_price_config() {
        let default_config = PriceConfig::default();

        assert!(
            matches!(default_config.model, PriceModel::Margin { margin } if margin == 0.0),
            "Default should be Margin with 0.0 margin"
        );
    }
}