libsubconverter/settings/
toml_deserializer.rs

1use serde::de::{MapAccess, SeqAccess, Visitor};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fmt;
5
6use crate::models::{
7    cron::CronTaskConfig, BalanceStrategy, ProxyGroupConfig, ProxyGroupType, RegexMatchConfig,
8    RulesetConfig,
9};
10use crate::settings::settings::toml_settings::TemplateSettings;
11
12pub trait ImportableInToml: serde::de::DeserializeOwned + Clone {
13    fn is_import_node(&self) -> bool;
14    fn get_import_path(&self) -> Option<String>;
15    fn try_from_toml_value(value: &toml::Value) -> Result<Self, Box<dyn std::error::Error>> {
16        Ok(value.clone().try_into()?)
17    }
18}
19
20/// Stream rule configuration
21#[derive(Debug, Clone, Serialize, Deserialize, Default)]
22#[serde(default)]
23pub struct RegexMatchRuleInToml {
24    #[serde(rename = "match")]
25    pub match_str: Option<String>,
26
27    #[serde(alias = "emoji")]
28    pub replace: Option<String>,
29    pub script: Option<String>,
30    pub import: Option<String>,
31}
32
33impl Into<RegexMatchConfig> for RegexMatchRuleInToml {
34    fn into(self) -> RegexMatchConfig {
35        RegexMatchConfig {
36            _match: self.match_str.unwrap_or_default(),
37            replace: self.replace.unwrap_or_default(),
38        }
39    }
40}
41
42impl ImportableInToml for RegexMatchRuleInToml {
43    fn is_import_node(&self) -> bool {
44        self.import.is_some()
45    }
46
47    fn get_import_path(&self) -> Option<String> {
48        self.import.clone()
49    }
50}
51
52/// Ruleset configuration
53#[derive(Debug, Clone, Serialize, Deserialize, Default)]
54#[serde(default)]
55pub struct RulesetConfigInToml {
56    pub group: String,
57    pub ruleset: Option<String>,
58    #[serde(rename = "type")]
59    pub ruleset_type: Option<String>,
60    pub interval: Option<u32>,
61    pub import: Option<String>,
62}
63
64impl ImportableInToml for RulesetConfigInToml {
65    fn is_import_node(&self) -> bool {
66        self.import.is_some()
67    }
68
69    fn get_import_path(&self) -> Option<String> {
70        self.import.clone()
71    }
72}
73
74impl Into<RulesetConfig> for RulesetConfigInToml {
75    fn into(self) -> RulesetConfig {
76        RulesetConfig {
77            url: self.ruleset.unwrap_or_default(),
78            group: self.group,
79            interval: self.interval.unwrap_or(300),
80        }
81    }
82}
83
/// Serde default for `ProxyGroupConfigInToml::url`: the gstatic
/// `generate_204` endpoint, wrapped in `Some`.
fn default_test_url() -> Option<String> {
    let url = String::from("http://www.gstatic.com/generate_204");
    Some(url)
}
87
/// Serde default for `ProxyGroupConfigInToml::interval`: `Some(300)`.
fn default_interval() -> Option<u32> {
    const DEFAULT_INTERVAL: u32 = 300;
    Some(DEFAULT_INTERVAL)
}
91
92/// Proxy group configuration
93#[derive(Debug, Clone, Serialize, Deserialize, Default)]
94#[serde(default)]
95pub struct ProxyGroupConfigInToml {
96    pub name: String,
97    #[serde(rename = "type")]
98    pub group_type: String,
99    pub strategy: Option<String>,
100    pub rule: Vec<String>,
101    #[serde(default = "default_test_url")]
102    pub url: Option<String>,
103    #[serde(default = "default_interval")]
104    pub interval: Option<u32>,
105    pub lazy: Option<bool>,
106    pub tolerance: Option<u32>,
107    pub timeout: Option<u32>,
108    pub disable_udp: Option<bool>,
109    pub import: Option<String>,
110}
111
112impl ImportableInToml for ProxyGroupConfigInToml {
113    fn is_import_node(&self) -> bool {
114        self.import.is_some()
115    }
116
117    fn get_import_path(&self) -> Option<String> {
118        self.import.clone()
119    }
120}
121
122impl Into<ProxyGroupConfig> for ProxyGroupConfigInToml {
123    fn into(self) -> ProxyGroupConfig {
124        let group_type = match self.group_type.as_str() {
125            "select" => ProxyGroupType::Select,
126            "url-test" => ProxyGroupType::URLTest,
127            "load-balance" => ProxyGroupType::LoadBalance,
128            "fallback" => ProxyGroupType::Fallback,
129            "relay" => ProxyGroupType::Relay,
130            "ssid" => ProxyGroupType::SSID,
131            "smart" => ProxyGroupType::Smart,
132            _ => ProxyGroupType::Select, // 默认为 Select
133        };
134
135        // 处理 strategy 字段
136        let strategy = match self.strategy.as_deref() {
137            Some("consistent-hashing") => BalanceStrategy::ConsistentHashing,
138            Some("round-robin") => BalanceStrategy::RoundRobin,
139            _ => BalanceStrategy::ConsistentHashing,
140        };
141
142        // 创建基本的 ProxyGroupConfig
143        let mut config = ProxyGroupConfig {
144            name: self.name,
145            group_type,
146            proxies: self.rule,
147            url: self.url.unwrap_or_default(),
148            interval: self.interval.unwrap_or(300),
149            tolerance: self.tolerance.unwrap_or(0),
150            timeout: self.timeout.unwrap_or(5),
151            lazy: self.lazy.unwrap_or(false),
152            disable_udp: self.disable_udp.unwrap_or(false),
153            strategy,
154            // 添加缺失的字段
155            persistent: false,
156            evaluate_before_use: false,
157            using_provider: Vec::new(),
158        };
159
160        // 根据不同的代理组类型设置特定属性
161        match config.group_type {
162            ProxyGroupType::URLTest | ProxyGroupType::Smart => {
163                // 这些类型需要 URL 和 interval
164                if config.url.is_empty() {
165                    config.url = "http://www.gstatic.com/generate_204".to_string();
166                }
167            }
168            ProxyGroupType::LoadBalance => {
169                // 负载均衡需要 URL、interval 和 strategy
170                if config.url.is_empty() {
171                    config.url = "http://www.gstatic.com/generate_204".to_string();
172                }
173            }
174            ProxyGroupType::Fallback => {
175                // 故障转移需要 URL 和 interval
176                if config.url.is_empty() {
177                    config.url = "http://www.gstatic.com/generate_204".to_string();
178                }
179            }
180            _ => {}
181        }
182
183        config
184    }
185}
186
187/// Task configuration
188#[derive(Debug, Clone, Serialize, Deserialize, Default)]
189#[serde(default)]
190pub struct TaskConfigInToml {
191    pub name: String,
192    pub cronexp: String,
193    pub path: String,
194    pub timeout: u32,
195    pub import: Option<String>,
196}
197
198impl ImportableInToml for TaskConfigInToml {
199    fn is_import_node(&self) -> bool {
200        self.import.is_some()
201    }
202
203    fn get_import_path(&self) -> Option<String> {
204        self.import.clone()
205    }
206}
207
208impl Into<CronTaskConfig> for TaskConfigInToml {
209    fn into(self) -> CronTaskConfig {
210        CronTaskConfig {
211            name: self.name,
212            cron_exp: self.cronexp,
213            path: self.path,
214            timeout: self.timeout,
215        }
216    }
217}
218
219pub fn deserialize_template_as_template_settings<'de, D>(
220    deserializer: D,
221) -> Result<TemplateSettings, D::Error>
222where
223    D: serde::Deserializer<'de>,
224{
225    struct TemplateSettingsVisitor;
226
227    impl<'de> Visitor<'de> for TemplateSettingsVisitor {
228        type Value = TemplateSettings;
229
230        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
231            formatter.write_str("a TemplateSettings struct")
232        }
233
234        fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
235        where
236            V: MapAccess<'de>,
237        {
238            let mut template_settings = TemplateSettings::default();
239            while let Some(key) = map.next_key::<String>()? {
240                let value = map.next_value::<String>()?;
241                if key == "template_path" {
242                    template_settings.template_path = value.clone();
243                } else {
244                    template_settings.globals.insert(key, value);
245                }
246            }
247            Ok(template_settings)
248        }
249    }
250
251    deserializer.deserialize_any(TemplateSettingsVisitor)
252}
253
254/// Template argument structure for deserialization
255#[derive(Debug, Clone, Deserialize, Default)]
256struct TemplateArgument {
257    pub key: String,
258    pub value: String,
259}
260
261pub fn deserialize_template_args_as_hash_map<'de, D>(
262    deserializer: D,
263) -> Result<Option<HashMap<String, String>>, D::Error>
264where
265    D: serde::Deserializer<'de>,
266{
267    struct TemplateArgsVisitor;
268
269    impl<'de> Visitor<'de> for TemplateArgsVisitor {
270        type Value = Option<HashMap<String, String>>;
271
272        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
273            formatter.write_str("a sequence of template arguments or a map of key-value pairs")
274        }
275
276        fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
277        where
278            S: SeqAccess<'de>,
279        {
280            let mut template_args = HashMap::new();
281
282            while let Some(item) = seq.next_element::<TemplateArgument>()? {
283                template_args.insert(item.key, item.value);
284            }
285
286            if template_args.is_empty() {
287                Ok(None)
288            } else {
289                Ok(Some(template_args))
290            }
291        }
292
293        fn visit_none<E>(self) -> Result<Self::Value, E>
294        where
295            E: serde::de::Error,
296        {
297            Ok(None)
298        }
299
300        fn visit_unit<E>(self) -> Result<Self::Value, E>
301        where
302            E: serde::de::Error,
303        {
304            Ok(None)
305        }
306
307        fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
308        where
309            M: MapAccess<'de>,
310        {
311            let mut template_args = HashMap::new();
312
313            while let Some((key, value)) = map.next_entry::<String, String>()? {
314                template_args.insert(key, value);
315            }
316
317            if template_args.is_empty() {
318                Ok(None)
319            } else {
320                Ok(Some(template_args))
321            }
322        }
323    }
324
325    deserializer.deserialize_any(TemplateArgsVisitor)
326}