use crate::{error::KoraError, oracle::PriceSource, token::token::TokenUtil};
use rust_decimal::{
prelude::{FromPrimitive, ToPrimitive},
Decimal,
};
use serde::{Deserialize, Serialize};
use solana_client::nonblocking::rpc_client::RpcClient;
use solana_sdk::pubkey::Pubkey;
use std::str::FromStr;
use utoipa::ToSchema;
// Pricing strategy used to decide how many lamports a transaction must cover.
//
// Serde represents this enum as an internally tagged object: the variant is
// selected by a `type` field holding the lowercase variant name
// (`margin`, `fixed` or `free`).
//
// Plain `//` comments are used on purpose: the `ToSchema` derive copies `///`
// doc comments into the generated schema descriptions.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum PriceModel {
    // Charge the minimum transaction fee multiplied by `1 + margin`
    // (see `PriceConfig::get_required_lamports_with_margin`).
    Margin {
        // Fraction added on top of the fee, e.g. `0.1` for +10%.
        margin: f64,
    },
    // Charge `amount` of `token`, converted to lamports
    // (see `PriceConfig::get_required_lamports_with_fixed`).
    Fixed {
        // Quantity of `token` to charge.
        // NOTE(review): presumably expressed in the token's smallest unit — confirm
        // against `TokenUtil::calculate_token_value_in_lamports`.
        amount: u64,
        // Token mint address as a string; parsed into a `Pubkey` when pricing.
        token: String,
        // NOTE(review): not read anywhere in this file; its meaning is defined by
        // callers elsewhere — confirm before relying on it.
        strict: bool,
    },
    // NOTE(review): neither pricing method in this file accepts this variant (both
    // return `KoraError::ConfigError` for it); presumably callers short-circuit
    // before pricing — confirm.
    Free,
}
impl Default for PriceModel {
fn default() -> Self {
Self::Margin { margin: 0.0 }
}
}
// Price configuration: a thin wrapper around the selected `PriceModel`.
//
// The derived `Default` yields `PriceModel::default()` for `model`.
//
// Plain `//` comments are used on purpose: the `ToSchema` derive copies `///`
// doc comments into the generated schema descriptions.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, Default)]
pub struct PriceConfig {
    // `flatten` inlines the model's fields (including its `type` tag) directly
    // into this struct's serialized form instead of nesting them under `model`.
    #[serde(flatten)]
    pub model: PriceModel,
}
impl PriceConfig {
    /// Converts the configured fixed token price into lamports.
    ///
    /// The `token` string of the `PriceModel::Fixed` variant is parsed into a
    /// `Pubkey` and handed, together with `amount`, `price_source` and
    /// `rpc_client`, to `TokenUtil::calculate_token_value_in_lamports`.
    ///
    /// # Errors
    ///
    /// * `KoraError::ConfigError` if the model is not `PriceModel::Fixed`.
    /// * `KoraError::ConfigError` if `token` is not a valid `Pubkey` string
    ///   (the parse error is logged).
    /// * Whatever `TokenUtil::calculate_token_value_in_lamports` returns on
    ///   failure, passed through unchanged.
    pub async fn get_required_lamports_with_fixed(
        &self,
        rpc_client: &RpcClient,
        price_source: PriceSource,
    ) -> Result<u64, KoraError> {
        // Guard: this method only applies to the fixed pricing model.
        let (amount, token) = match &self.model {
            PriceModel::Fixed { amount, token, .. } => (*amount, token),
            _ => return Err(KoraError::ConfigError),
        };

        let mint = match Pubkey::from_str(token) {
            Ok(mint) => mint,
            Err(e) => {
                log::error!("Invalid Pubkey for price {e}");
                return Err(KoraError::ConfigError);
            }
        };

        TokenUtil::calculate_token_value_in_lamports(
            amount,
            &mint,
            price_source,
            rpc_client,
        )
        .await
    }

    /// Applies the configured margin to `min_transaction_fee`.
    ///
    /// Computes `ceil(min_transaction_fee * (1 + margin))` using `Decimal`
    /// arithmetic and returns it as a `u64`. The function is `async` but
    /// contains no await points.
    ///
    /// # Errors
    ///
    /// * `KoraError::ConfigError` if the model is not `PriceModel::Margin`.
    /// * `KoraError::ValidationError("Invalid margin")` if `margin` cannot be
    ///   converted to a `Decimal`.
    /// * `KoraError::ValidationError("Multiplier calculation overflow")` if
    ///   `1 + margin` overflows.
    /// * `KoraError::ValidationError("Margin calculation overflow")` if the
    ///   multiplication overflows or the rounded result cannot be represented
    ///   as a `u64`.
    pub async fn get_required_lamports_with_margin(
        &self,
        min_transaction_fee: u64,
    ) -> Result<u64, KoraError> {
        // Guard: this method only applies to the margin pricing model.
        let margin = match &self.model {
            PriceModel::Margin { margin } => *margin,
            _ => return Err(KoraError::ConfigError),
        };

        let margin_decimal = match Decimal::from_f64(margin) {
            Some(value) => value,
            None => return Err(KoraError::ValidationError("Invalid margin".to_string())),
        };

        // multiplier = 1 + margin
        let one = Decimal::from_u64(1u64);
        let multiplier = match one.and_then(|unit| unit.checked_add(margin_decimal)) {
            Some(value) => value,
            None => {
                log::error!(
                    "Multiplier calculation overflow: min_transaction_fee={}, margin={}",
                    min_transaction_fee,
                    margin,
                );
                return Err(KoraError::ValidationError(
                    "Multiplier calculation overflow".to_string(),
                ));
            }
        };

        // total = min_transaction_fee * multiplier (not yet rounded)
        let base_fee = Decimal::from_u64(min_transaction_fee);
        let total = match base_fee.and_then(|fee| fee.checked_mul(multiplier)) {
            Some(value) => value,
            None => {
                log::error!(
                    "Margin calculation overflow: min_transaction_fee={}, margin={}",
                    min_transaction_fee,
                    margin,
                );
                return Err(KoraError::ValidationError(
                    "Margin calculation overflow".to_string(),
                ));
            }
        };

        // Round up before converting to an integer number of lamports.
        match total.ceil().to_u64() {
            Some(lamports) => Ok(lamports),
            None => {
                // The unrounded value is logged, not the ceiling.
                log::error!(
                    "Margin calculation overflow: min_transaction_fee={}, margin={}, result={}",
                    min_transaction_fee,
                    margin,
                    total
                );
                Err(KoraError::ValidationError("Margin calculation overflow".to_string()))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{common::create_mock_rpc_client_with_mint, config_mock::ConfigMockBuilder};

    /// Mint address used by the tests that price a 6-decimal mock mint.
    const USDC_MINT: &str = "4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU";

    /// Mint address used by the test that prices a 9-decimal mock mint.
    const CUSTOM_TOKEN: &str = "So11111111111111111111111111111111111111112";

    /// Builds a `PriceConfig` using the margin model.
    fn margin_config(margin: f64) -> PriceConfig {
        PriceConfig { model: PriceModel::Margin { margin } }
    }

    /// Builds a non-strict `PriceConfig` using the fixed model.
    fn fixed_config(amount: u64, token: &str) -> PriceConfig {
        PriceConfig {
            model: PriceModel::Fixed { amount, token: token.to_string(), strict: false },
        }
    }

    /// A 10% margin on a 5000-lamport fee yields the fee scaled by 1.1.
    #[tokio::test]
    async fn test_margin_model_get_required_lamports() {
        let base_fee = 5000u64;
        let expected_lamports = (5000.0 * 1.1) as u64;

        let lamports =
            margin_config(0.1).get_required_lamports_with_margin(base_fee).await.unwrap();

        assert_eq!(lamports, expected_lamports);
    }

    /// A zero margin leaves the fee unchanged.
    #[tokio::test]
    async fn test_margin_model_get_required_lamports_zero_margin() {
        let base_fee = 5000u64;

        let lamports =
            margin_config(0.0).get_required_lamports_with_margin(base_fee).await.unwrap();

        assert_eq!(lamports, base_fee);
    }

    /// Fixed price against a 6-decimal mock mint priced by `PriceSource::Mock`.
    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_with_oracle() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let rpc_client = create_mock_rpc_client_with_mint(6);
        let price_config = fixed_config(1_000_000, USDC_MINT);

        let lamports = price_config
            .get_required_lamports_with_fixed(&rpc_client, PriceSource::Mock)
            .await
            .unwrap();

        assert_eq!(lamports, 100000);
    }

    /// Fixed price against a 9-decimal mock mint priced by `PriceSource::Mock`.
    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_with_custom_price() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let rpc_client = create_mock_rpc_client_with_mint(9);
        let price_config = fixed_config(500000000, CUSTOM_TOKEN);

        let lamports = price_config
            .get_required_lamports_with_fixed(&rpc_client, PriceSource::Mock)
            .await
            .unwrap();

        assert_eq!(lamports, 500000000);
    }

    /// A small fixed amount against a 6-decimal mock mint.
    #[tokio::test]
    async fn test_fixed_model_get_required_lamports_small_amount() {
        let _m = ConfigMockBuilder::new().build_and_setup();
        let rpc_client = create_mock_rpc_client_with_mint(6);
        let price_config = fixed_config(1000, USDC_MINT);

        let lamports = price_config
            .get_required_lamports_with_fixed(&rpc_client, PriceSource::Mock)
            .await
            .unwrap();

        assert_eq!(lamports, 100);
    }

    /// The default configuration is the margin model with a zero margin.
    #[tokio::test]
    async fn test_default_price_config() {
        if let PriceModel::Margin { margin } = PriceConfig::default().model {
            assert_eq!(margin, 0.0);
        } else {
            panic!("Default should be Margin with 0.0 margin");
        }
    }
}