// kora_lib/config.rs

1use serde::{Deserialize, Serialize};
2use std::{fs, path::Path};
3use toml;
4use utoipa::ToSchema;
5
6use solana_client::nonblocking::rpc_client::RpcClient;
7
8use crate::{error::KoraError, oracle::PriceSource, token::check_valid_tokens};
9
10#[derive(Debug, Deserialize)]
11pub struct Config {
12    pub validation: ValidationConfig,
13    pub kora: KoraConfig,
14}
15
// Settings deserialized from the `[validation]` table of the config file.
//
// NOTE: plain `//` comments are used on purpose. The `ToSchema` (utoipa) derive
// turns `///` doc comments into OpenAPI schema descriptions, and this struct is
// also `Serialize`d. Doc text and field order are therefore externally
// observable, so change them deliberately.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ValidationConfig {
    // Presumably the maximum lamports a single transaction may cost/transfer;
    // enforced outside this file. TODO confirm against the validator.
    pub max_allowed_lamports: u64,
    // Presumably the maximum number of signatures per transaction. TODO confirm.
    pub max_signatures: u64,
    // Program identifiers, kept as raw strings; enforcement lives outside this file.
    pub allowed_programs: Vec<String>,
    // Token identifiers. `Config::validate` rejects an empty list and passes the
    // entries to `check_valid_tokens`.
    pub allowed_tokens: Vec<String>,
    // Presumably tokens accepted for SPL-token fee payment. TODO confirm.
    // Not checked by `Config::validate`.
    pub allowed_spl_paid_tokens: Vec<String>,
    // Account identifiers that are disallowed; enforcement lives outside this file.
    pub disallowed_accounts: Vec<String>,
    // Price source selection, written e.g. as `price_source = "Jupiter"` in TOML.
    pub price_source: PriceSource,
}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct KoraConfig {
29    pub rate_limit: u64,
30    // pub redis_url: String,
31}
32
33pub fn load_config<P: AsRef<Path>>(path: P) -> Result<Config, KoraError> {
34    let contents = fs::read_to_string(path).map_err(|e| {
35        KoraError::InternalServerError(format!("Failed to read config file: {}", e))
36    })?;
37
38    toml::from_str(&contents)
39        .map_err(|e| KoraError::InternalServerError(format!("Failed to parse config file: {}", e)))
40}
41
42impl Config {
43    pub async fn validate(&self, _rpc_client: &RpcClient) -> Result<(), KoraError> {
44        if self.validation.allowed_tokens.is_empty() {
45            return Err(KoraError::InternalServerError("No tokens enabled".to_string()));
46        }
47
48        check_valid_tokens(&self.validation.allowed_tokens)?;
49        Ok(())
50    }
51}
52
#[cfg(test)]
mod tests {
    use crate::oracle::PriceSource;

    use super::*;
    use std::fs;
    use tempfile::NamedTempFile;

    /// Writes `contents` to a fresh temporary file and returns its handle; the
    /// file is removed when the handle is dropped, so keep it alive while loading.
    fn write_temp_config(contents: &str) -> NamedTempFile {
        let temp_file = NamedTempFile::new().unwrap();
        fs::write(&temp_file, contents).unwrap();
        temp_file
    }

    #[test]
    fn test_load_valid_config() {
        let config_content = r#"
            [validation]
            max_allowed_lamports = 1000000000
            max_signatures = 10
            allowed_programs = ["program1", "program2"]
            allowed_tokens = ["token1", "token2"]
            allowed_spl_paid_tokens = ["token3"]
            disallowed_accounts = ["account1"]
            price_source = "Jupiter"
            [kora]
            rate_limit = 100
        "#;

        let temp_file = write_temp_config(config_content);
        let config = load_config(temp_file.path()).unwrap();

        assert_eq!(config.validation.max_allowed_lamports, 1000000000);
        assert_eq!(config.validation.max_signatures, 10);
        assert_eq!(config.validation.allowed_programs, vec!["program1", "program2"]);
        assert_eq!(config.validation.allowed_tokens, vec!["token1", "token2"]);
        assert_eq!(config.validation.allowed_spl_paid_tokens, vec!["token3"]);
        assert_eq!(config.validation.disallowed_accounts, vec!["account1"]);
        assert_eq!(config.validation.price_source, PriceSource::Jupiter);
        assert_eq!(config.kora.rate_limit, 100);
    }

    #[test]
    fn test_load_invalid_config() {
        let temp_file = write_temp_config("invalid toml content");

        let result = load_config(temp_file.path());
        assert!(matches!(result, Err(KoraError::InternalServerError(_))));
    }

    #[test]
    fn test_load_config_missing_kora_section() {
        // `[kora]` is required because `Config::kora` has no serde default.
        let config_content = r#"
            [validation]
            max_allowed_lamports = 1
            max_signatures = 1
            allowed_programs = []
            allowed_tokens = ["token1"]
            allowed_spl_paid_tokens = []
            disallowed_accounts = []
            price_source = "Jupiter"
        "#;
        let temp_file = write_temp_config(config_content);

        let result = load_config(temp_file.path());
        assert!(matches!(result, Err(KoraError::InternalServerError(_))));
    }

    #[test]
    fn test_load_nonexistent_file() {
        // Use a path inside a fresh, empty temp dir so the test cannot be
        // affected by whatever happens to exist in the working directory.
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("nonexistent_file.toml");

        let result = load_config(&missing);
        assert!(matches!(result, Err(KoraError::InternalServerError(_))));
    }

    #[tokio::test]
    async fn test_validate_config() {
        let mut config = Config {
            validation: ValidationConfig {
                max_allowed_lamports: 1000000000,
                max_signatures: 10,
                allowed_programs: vec!["program1".to_string()],
                allowed_tokens: vec!["token1".to_string()],
                allowed_spl_paid_tokens: vec!["token3".to_string()],
                disallowed_accounts: vec!["account1".to_string()],
                price_source: PriceSource::Jupiter,
            },
            kora: KoraConfig { rate_limit: 100 },
        };

        // An empty token list must be rejected before any token parsing.
        config.validation.allowed_tokens.clear();
        let rpc_client = RpcClient::new("http://localhost:8899".to_string());
        let result = config.validate(&rpc_client).await;
        assert!(matches!(result, Err(KoraError::InternalServerError(_))));
    }
}