1use serde::{Deserialize, Serialize};
2use std::{fs, path::Path};
3use toml;
4use utoipa::ToSchema;
5
6use solana_client::nonblocking::rpc_client::RpcClient;
7
8use crate::{error::KoraError, oracle::PriceSource, token::check_valid_tokens};
9
10#[derive(Debug, Deserialize)]
11pub struct Config {
12 pub validation: ValidationConfig,
13 pub kora: KoraConfig,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
17pub struct ValidationConfig {
18 pub max_allowed_lamports: u64,
19 pub max_signatures: u64,
20 pub allowed_programs: Vec<String>,
21 pub allowed_tokens: Vec<String>,
22 pub allowed_spl_paid_tokens: Vec<String>,
23 pub disallowed_accounts: Vec<String>,
24 pub price_source: PriceSource,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct KoraConfig {
29 pub rate_limit: u64,
30 }
32
33pub fn load_config<P: AsRef<Path>>(path: P) -> Result<Config, KoraError> {
34 let contents = fs::read_to_string(path).map_err(|e| {
35 KoraError::InternalServerError(format!("Failed to read config file: {}", e))
36 })?;
37
38 toml::from_str(&contents)
39 .map_err(|e| KoraError::InternalServerError(format!("Failed to parse config file: {}", e)))
40}
41
42impl Config {
43 pub async fn validate(&self, _rpc_client: &RpcClient) -> Result<(), KoraError> {
44 if self.validation.allowed_tokens.is_empty() {
45 return Err(KoraError::InternalServerError("No tokens enabled".to_string()));
46 }
47
48 check_valid_tokens(&self.validation.allowed_tokens)?;
49 Ok(())
50 }
51}
52
#[cfg(test)]
mod tests {
    use crate::oracle::PriceSource;

    use super::*;
    use std::fs;
    use tempfile::NamedTempFile;

    /// Writes `contents` to a new temporary file and returns its handle.
    ///
    /// The file is deleted when the handle is dropped, so callers must keep the
    /// handle alive for as long as they use its path.
    fn write_temp_config(contents: &str) -> NamedTempFile {
        let temp_file = NamedTempFile::new().unwrap();
        fs::write(&temp_file, contents).unwrap();
        temp_file
    }

    #[test]
    fn test_load_valid_config() {
        let config_content = r#"
            [validation]
            max_allowed_lamports = 1000000000
            max_signatures = 10
            allowed_programs = ["program1", "program2"]
            allowed_tokens = ["token1", "token2"]
            allowed_spl_paid_tokens = ["token3"]
            disallowed_accounts = ["account1"]
            price_source = "Jupiter"
            [kora]
            rate_limit = 100
        "#;

        let temp_file = write_temp_config(config_content);
        let config = load_config(temp_file.path()).unwrap();

        assert_eq!(config.validation.max_allowed_lamports, 1000000000);
        assert_eq!(config.validation.max_signatures, 10);
        assert_eq!(config.validation.allowed_programs, vec!["program1", "program2"]);
        assert_eq!(config.validation.allowed_tokens, vec!["token1", "token2"]);
        assert_eq!(config.validation.allowed_spl_paid_tokens, vec!["token3"]);
        assert_eq!(config.validation.disallowed_accounts, vec!["account1"]);
        assert_eq!(config.validation.price_source, PriceSource::Jupiter);
        assert_eq!(config.kora.rate_limit, 100);
    }

    #[test]
    fn test_load_invalid_config() {
        let temp_file = write_temp_config("invalid toml content");

        let result = load_config(temp_file.path());
        // `load_config` maps parse failures to `InternalServerError`; pin the variant.
        assert!(matches!(result, Err(KoraError::InternalServerError(_))));
    }

    #[test]
    fn test_load_nonexistent_file() {
        // Point at a path inside a fresh, empty temp dir so the test cannot be fooled
        // by a stray `nonexistent_file.toml` in the current working directory.
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("nonexistent_file.toml");

        let result = load_config(&missing);
        assert!(matches!(result, Err(KoraError::InternalServerError(_))));
    }

    #[tokio::test]
    async fn test_validate_config() {
        let config = Config {
            validation: ValidationConfig {
                max_allowed_lamports: 1000000000,
                max_signatures: 10,
                allowed_programs: vec!["program1".to_string()],
                // Empty allow-list: `validate` must reject this.
                allowed_tokens: Vec::new(),
                allowed_spl_paid_tokens: vec!["token3".to_string()],
                disallowed_accounts: vec!["account1".to_string()],
                price_source: PriceSource::Jupiter,
            },
            kora: KoraConfig { rate_limit: 100 },
        };

        // `validate` ignores the client, so no running validator is needed.
        let rpc_client = RpcClient::new("http://localhost:8899".to_string());
        match config.validate(&rpc_client).await {
            Err(KoraError::InternalServerError(msg)) => assert_eq!(msg, "No tokens enabled"),
            other => panic!("expected InternalServerError, got {:?}", other),
        }
    }
}