1pub use aho_corasick::AhoCorasick;
2#[cfg(feature = "maxminddb")]
3pub use maxminddb;
4pub use prefix_trie::PrefixMap;
5pub use radix_trie::{Trie, TrieCommon};
6#[cfg(feature = "rusqlite")]
7pub use rusqlite;
8
9#[cfg(feature = "serde_yaml_ng")]
10pub use serde_yaml_ng;
11
12use ipnet::{Ipv4Net, Ipv6Net};
13use std::collections::HashMap;
14use std::fmt::Display;
15use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
16use std::num::ParseIntError;
17use std::path::Path;
18
19use serde::{Deserialize, Serialize};
/// Clash rule type: exact domain match (`DOMAIN,example.com,TARGET`).
pub const DOMAIN: &str = "DOMAIN";
/// Clash rule type: domain suffix match (`DOMAIN-SUFFIX,example.com,TARGET`).
pub const DOMAIN_SUFFIX: &str = "DOMAIN-SUFFIX";
/// Clash rule type: substring match on the domain (`DOMAIN-KEYWORD,google,TARGET`).
pub const DOMAIN_KEYWORD: &str = "DOMAIN-KEYWORD";
/// Clash rule type: regular-expression match on the domain.
pub const DOMAIN_REGEX: &str = "DOMAIN-REGEX";
/// Clash rule type: IPv4 network match.
pub const IP_CIDR: &str = "IP-CIDR";
/// Clash rule type: IPv6 network match.
pub const IP_CIDR6: &str = "IP-CIDR6";
/// Clash rule type: exact process-name match.
pub const PROCESS_NAME: &str = "PROCESS-NAME";
/// Clash rule type: destination-port match (parsed as `u16`).
pub const DST_PORT: &str = "DST-PORT";
/// Clash rule type: country lookup of the IP (evaluated only with the `maxminddb` feature).
pub const GEOIP: &str = "GEOIP";
/// Clash rule type: exact network name match (compared verbatim with `RuleInput::network`).
pub const NETWORK: &str = "NETWORK";
/// Logic rule: all parenthesised sub-rules must match.
pub const AND: &str = "AND";
/// Logic rule: at least one parenthesised sub-rule must match.
pub const OR: &str = "OR";
/// Logic rule: negates its parenthesised sub-rule.
pub const NOT: &str = "NOT";
/// Catch-all rule that matches every input.
pub const MATCH: &str = "MATCH";
34
/// A rule-provider document: a YAML mapping with a `payload` list of entries.
///
/// Deserialized by `load_rule_set_from_str` / `load_rule_set_from_file`. How each entry is
/// interpreted (domain suffix, CIDR, classic `TYPE,payload` line) is decided by which
/// `parse_rule_set_as_*` function the caller hands the payload to.
#[derive(Serialize, Deserialize, Debug)]
pub struct RuleSet {
    /// Raw payload entries, one per list item.
    pub payload: Vec<String>,
}
/// Error returned by the YAML file loaders (`load_rule_set_from_file`, `load_rules_from_file`).
#[derive(Debug)]
pub enum LoadYamlFileError {
    /// Reading the file from disk failed.
    FileErr(std::io::Error),
    /// The file content could not be deserialized into the expected structure.
    #[cfg(feature = "serde_yaml_ng")]
    YamlErr(serde_yaml_ng::Error),
}
45impl Display for LoadYamlFileError {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 LoadYamlFileError::FileErr(error) => write!(f, "{}", error),
49 #[cfg(feature = "serde_yaml_ng")]
50 LoadYamlFileError::YamlErr(error) => write!(f, "{}", error),
51 }
52 }
53}
54
55impl From<std::io::Error> for LoadYamlFileError {
56 fn from(err: std::io::Error) -> Self {
57 LoadYamlFileError::FileErr(err)
58 }
59}
60
/// Lets `?` turn YAML deserialization errors into [`LoadYamlFileError::YamlErr`].
#[cfg(feature = "serde_yaml_ng")]
impl From<serde_yaml_ng::Error> for LoadYamlFileError {
    fn from(yaml_error: serde_yaml_ng::Error) -> Self {
        Self::YamlErr(yaml_error)
    }
}
67
/// Reads `path` and deserializes it as a [`RuleSet`].
///
/// # Errors
/// [`LoadYamlFileError::FileErr`] if the file cannot be read, or
/// [`LoadYamlFileError::YamlErr`] if its content is not a valid rule set.
#[cfg(feature = "serde_yaml_ng")]
pub fn load_rule_set_from_file<P: AsRef<Path>>(path: P) -> Result<RuleSet, LoadYamlFileError> {
    std::fs::read_to_string(path)
        .map_err(LoadYamlFileError::from)
        .and_then(|text| load_rule_set_from_str(&text).map_err(LoadYamlFileError::from))
}
/// Deserializes a YAML document with a `payload` list into a [`RuleSet`].
///
/// # Errors
/// Returns the `serde_yaml_ng` error when `s` is not valid YAML for [`RuleSet`].
#[cfg(feature = "serde_yaml_ng")]
pub fn load_rule_set_from_str(s: &str) -> Result<RuleSet, serde_yaml_ng::Error> {
    serde_yaml_ng::from_str::<RuleSet>(s)
}
79
80pub fn parse_rule_set_as_domain_suffix_trie(
82 mut trie: Trie<String, usize>,
83 payload: &[String],
84 target_id: usize,
85) {
86 for v in payload.iter() {
87 let mut r: String = v.chars().rev().collect();
88 r = r.trim_end_matches('+').to_string();
90 trie.insert(r, target_id);
91 }
92}
93
94pub fn parse_rule_set_as_ip_cidr_trie(
96 mut trie: PrefixMap<Ipv4Net, usize>,
97 mut trie6: PrefixMap<Ipv6Net, usize>,
98 payload: &[String],
99 target_id: usize,
100) {
101 for v in payload.iter() {
102 let r: Result<Ipv4Net, <Ipv4Net as std::str::FromStr>::Err> = v.parse();
103 match r {
104 Ok(r) => trie.insert(r, target_id),
105 Err(_) => {
106 let r: Ipv6Net = v.parse().unwrap();
107 trie6.insert(r, target_id)
108 }
109 };
110 }
111}
112
/// Groups a `classic` rule-set payload (`TYPE,payload[,options]` lines) by rule type.
///
/// Every row is reshaped to `[payload, target, options...]`, the layout `parse_rules`
/// produces for a config file, so both can be merged with `merge_method_rules_map`.
/// A line with no payload (e.g. `MATCH`) yields the row `[target]` instead of panicking.
pub fn parse_rule_set_as_classic(
    payload: &[String],
    target: String,
) -> HashMap<String, Vec<Vec<String>>> {
    let mut hashmap: HashMap<String, Vec<Vec<String>>> = HashMap::new();
    for rule in payload {
        let mut parts = rule.split(',');
        if let Some(key) = parts.next() {
            let mut values: Vec<String> = parts.map(str::to_string).collect();
            // The target goes right after the payload; if there is no payload it is the only field.
            let target_pos = values.len().min(1);
            values.insert(target_pos, target.clone());
            hashmap.entry(key.to_string()).or_default().push(values);
        }
    }
    hashmap
}
128
/// The `rules:` section of a Clash config.
///
/// Each entry is one raw comma-separated rule line such as `DOMAIN-SUFFIX,google.com,Proxy`.
/// `parse_rules` splits these lines and groups them by rule type.
#[derive(Serialize, Deserialize, Debug)]
pub struct RuleConfig {
    /// Raw rule lines in config order.
    pub rules: Vec<String>,
}
/// Reads a Clash config file from `path` and deserializes its `rules:` section.
///
/// # Errors
/// [`LoadYamlFileError::FileErr`] if the file cannot be read, or
/// [`LoadYamlFileError::YamlErr`] if it is not valid YAML for [`RuleConfig`].
#[cfg(feature = "serde_yaml_ng")]
pub fn load_rules_from_file<P: AsRef<Path>>(path: P) -> Result<RuleConfig, LoadYamlFileError> {
    let text = std::fs::read_to_string(path).map_err(LoadYamlFileError::from)?;
    load_rules_from_str(&text).map_err(LoadYamlFileError::from)
}
/// Deserializes a Clash config document into a [`RuleConfig`] (only `rules:` is read).
///
/// # Errors
/// Returns the `serde_yaml_ng` error when `s` is not valid YAML for [`RuleConfig`].
#[cfg(feature = "serde_yaml_ng")]
pub fn load_rules_from_str(s: &str) -> Result<RuleConfig, serde_yaml_ng::Error> {
    serde_yaml_ng::from_str::<RuleConfig>(s)
}
144
145pub fn parse_rules(rc: &RuleConfig) -> HashMap<String, Vec<Vec<String>>> {
147 let mut hashmap: HashMap<String, Vec<Vec<String>>> = HashMap::new();
149 for rule in &rc.rules {
150 let mut parts = rule.split(',');
151 if let Some(key) = parts.next() {
152 let values: Vec<String> = parts.map(|part| part.to_string()).collect();
153 hashmap.entry(key.to_string()).or_default().push(values);
154 }
155 }
156 hashmap
157}
158
/// Merges two `rule type -> rows` maps.
///
/// For a rule type present in both maps, the rows of `map2` are appended after those of
/// `map1`. Rule types present in only one map are kept as they are.
pub fn merge_method_rules_map(
    map1: HashMap<String, Vec<Vec<String>>>,
    map2: HashMap<String, Vec<Vec<String>>>,
) -> HashMap<String, Vec<Vec<String>>> {
    map2.into_iter().fold(map1, |mut merged, (rule_type, rows)| {
        merged.entry(rule_type).or_default().extend(rows);
        merged
    })
}
171
172pub fn get_domain_rules(
173 method_rules_map: &HashMap<String, Vec<Vec<String>>>,
174) -> Option<&Vec<Vec<String>>> {
175 method_rules_map.get(DOMAIN)
176}
177pub fn get_suffix_rules(
178 method_rules_map: &HashMap<String, Vec<Vec<String>>>,
179) -> Option<&Vec<Vec<String>>> {
180 method_rules_map.get(DOMAIN_SUFFIX)
181}
182pub fn get_keyword_rules(
183 method_rules_map: &HashMap<String, Vec<Vec<String>>>,
184) -> Option<&Vec<Vec<String>>> {
185 method_rules_map.get(DOMAIN_KEYWORD)
186}
187pub fn get_ip_cidr_rules(
188 method_rules_map: &HashMap<String, Vec<Vec<String>>>,
189) -> Option<&Vec<Vec<String>>> {
190 method_rules_map.get(IP_CIDR)
191}
192pub fn get_ip6_cidr_rules(
193 method_rules_map: &HashMap<String, Vec<Vec<String>>>,
194) -> Option<&Vec<Vec<String>>> {
195 method_rules_map.get(IP_CIDR6)
196}
197
/// Maps each rule's item (field 0, e.g. a domain or country code) to its target (field 1).
///
/// Rows with fewer than two fields carry no target and are skipped instead of panicking.
/// If the same item appears more than once, the first row wins. This follows Clash's
/// top-to-bottom rule evaluation, where a later duplicate can never be reached.
pub fn get_item_target_map(rules: &[Vec<String>]) -> HashMap<String, String> {
    let mut map: HashMap<String, String> = HashMap::with_capacity(rules.len());
    for row in rules {
        if let [item, target, ..] = row.as_slice() {
            map.entry(item.clone()).or_insert_with(|| target.clone());
        }
    }
    map
}
208
/// Groups rule items (field 0) by their target (field 1): `target -> [items...]`.
///
/// Items keep their input order within each target. Rows with fewer than two fields carry
/// no target and are skipped instead of panicking.
pub fn get_target_item_map(rules: &[Vec<String>]) -> HashMap<String, Vec<String>> {
    let mut map: HashMap<String, Vec<String>> = HashMap::new();
    for row in rules {
        if let [item, target, ..] = row.as_slice() {
            map.entry(target.clone()).or_default().push(item.clone());
        }
    }
    map
}
219
220pub fn gen_keywords_ac(
221 target_keywords_map: &HashMap<String, Vec<String>>,
222) -> HashMap<String, AhoCorasick> {
223 target_keywords_map
224 .iter()
225 .map(|(k, v)| (k.clone(), AhoCorasick::new(v).unwrap()))
226 .collect()
227}
228
/// Returns the target (field 1) of every rule row that has one, in row order.
///
/// Position `i` lines up with pattern id `i` of the automaton built by `gen_keywords_ac2`
/// for the same rows.
pub fn get_keywords_targets(rules: &[Vec<String>]) -> Vec<String> {
    rules.iter().flat_map(|row| row.get(1)).cloned().collect()
}
232pub fn gen_keywords_ac2(rules: &[Vec<String>]) -> AhoCorasick {
233 let result: Vec<String> = rules.iter().filter_map(|v| v.first().cloned()).collect();
234
235 AhoCorasick::new(&result).unwrap()
236}
237pub fn gen_ip_trie<T: AsRef<str>>(target_ip_map: &HashMap<T, Vec<T>>) -> PrefixMap<Ipv4Net, usize> {
238 let mut trie = PrefixMap::<Ipv4Net, usize>::new();
239 for (i, (_key, value)) in target_ip_map.iter().enumerate() {
240 for v in value {
241 let r: Ipv4Net = v.as_ref().parse().unwrap();
242 trie.insert(r, i);
243 }
244 }
245 trie
246}
247#[derive(PartialEq, Eq, Debug)]
248struct Ipv4NetWrapper(pub Ipv4Net);
249impl radix_trie::TrieKey for Ipv4NetWrapper {
250 fn encode_bytes(&self) -> Vec<u8> {
251 fn u32_to_bit_u8_vec(n: u32, len: u8) -> Vec<u8> {
252 (32 - len..32u8).rev().map(|i| (n >> i) as u8).collect()
253 }
254 let ipnet = &self.0;
255 u32_to_bit_u8_vec(
256 u32::from_be_bytes(ipnet.network().octets()),
257 ipnet.prefix_len(),
258 )
259 }
260}
261
/// Radix-trie based IPv4 matcher built by `gen_ip_trie2` and queried with `check_ip_trie2`.
/// Values are target positions, as in `gen_ip_trie`.
pub struct IpTrie2(Trie<Ipv4NetWrapper, usize>);
265pub fn gen_ip_trie2<T: AsRef<str>>(target_ip_map: &HashMap<T, Vec<T>>) -> IpTrie2 {
266 let mut trie = Trie::<Ipv4NetWrapper, usize>::new();
267 for (i, (_key, value)) in target_ip_map.iter().enumerate() {
268 for v in value {
269 let r: Ipv4Net = v.as_ref().parse().unwrap();
270 trie.insert(Ipv4NetWrapper(r), i);
271 }
272 }
273 IpTrie2(trie)
274}
275pub fn gen_ip6_trie<T: AsRef<str>>(
277 target_ip_map: &HashMap<T, Vec<T>>,
278) -> PrefixMap<Ipv6Net, usize> {
279 let mut trie = PrefixMap::<Ipv6Net, usize>::new();
280 for (i, (_key, value)) in target_ip_map.iter().enumerate() {
281 for v in value {
282 let r: Ipv6Net = v.as_ref().parse().unwrap();
283 trie.insert(r, i);
284 }
285 }
286 trie
287}
288pub fn gen_prefix_trie<T: AsRef<str>>(target_item_map: &HashMap<T, Vec<T>>) -> Trie<String, usize> {
290 let mut trie = Trie::new();
291
292 for (i, (_key, value)) in target_item_map.iter().enumerate() {
293 for v in value {
294 trie.insert(v.as_ref().to_string(), i);
295 }
296 }
297 trie
298}
299
300pub fn gen_suffix_trie<T: AsRef<str>>(
303 target_suffix_map: &HashMap<T, Vec<T>>,
304) -> Trie<String, usize> {
305 let mut trie = Trie::new();
306
307 for (i, (_key, value)) in target_suffix_map.iter().enumerate() {
308 for v in value {
309 let r: String = v.as_ref().chars().rev().collect();
310 trie.insert(r, i);
311 }
312 }
313 trie
314}
315
/// Linear-scan reference implementation of domain-suffix matching.
///
/// Returns the first target (in map iteration order) owning a suffix that matches `domain`.
/// A suffix matches the domain itself or any subdomain of it: `google.com` matches
/// `google.com` and `www.google.com` but not `notgoogle.com`. A suffix that already starts
/// with `.` only needs a textual match.
pub fn check_suffix_dummy<'a, T>(
    target_suffix_map: &'a HashMap<T, Vec<T>>,
    domain: &str,
) -> Option<&'a T>
where
    T: AsRef<str> + Eq + std::hash::Hash,
{
    let matches_suffix = |suffix: &str| {
        domain.strip_suffix(suffix).is_some_and(|head| {
            // The match must start at a label boundary.
            head.is_empty() || head.ends_with('.') || suffix.starts_with('.')
        })
    };
    for (target, items) in target_suffix_map {
        if items.iter().any(|v| matches_suffix(v.as_ref())) {
            return Some(target);
        }
    }
    None
}
332
333pub fn check_suffix_trie(trie: &Trie<String, usize>, domain: &str) -> Option<usize> {
335 let sr: String = domain.chars().rev().collect();
336 if let Some(subtree) = trie.get_ancestor(&sr) {
337 subtree.value().cloned()
338 } else {
339 None
340 }
341}
342pub fn check_prefix_trie(trie: &Trie<&str, usize>, domain: &str) -> Option<usize> {
344 if let Some(subtree) = trie.get_ancestor(domain) {
345 subtree.value().cloned()
346 } else {
347 None
348 }
349}
350pub fn check_keyword_ac<'a, T: AsRef<str>>(
351 target_keyword_ac_map: &'a HashMap<T, AhoCorasick>,
352 domain: &str,
353) -> Option<&'a str> {
354 for (target, ac) in target_keyword_ac_map {
355 if ac.is_match(domain) {
356 return Some(target.as_ref());
357 }
358 }
359 None
360}
361
362pub fn check_keyword_ac2<'a>(
364 keyword_ac: &AhoCorasick,
365 domain: &str,
366 targets: &'a [String],
367) -> Option<&'a String> {
368 if let Some(mat) = keyword_ac.find_iter(domain).next() {
369 let keyword_index = mat.pattern();
370 return Some(&targets[keyword_index]);
371 }
372 None
373}
374
/// Linear-scan reference implementation of keyword matching.
///
/// Returns the first target (in map iteration order) that owns a keyword contained in
/// `domain`.
pub fn check_keyword_dummy<'a, T>(
    target_keyword_map: &'a HashMap<T, Vec<T>>,
    domain: &str,
) -> Option<&'a T>
where
    T: AsRef<str> + Eq + std::hash::Hash,
{
    target_keyword_map
        .iter()
        .find(|(_, keywords)| keywords.iter().any(|kw| domain.contains(kw.as_ref())))
        .map(|(target, _)| target)
}
391
392pub fn check_ip_trie2(trie: &IpTrie2, ip: Ipv4Addr) -> Option<usize> {
393 let ipn = Ipv4NetWrapper(Ipv4Net::new(ip, 32).unwrap());
394 if let Some(subtree) = trie.0.get_ancestor(&ipn) {
395 subtree.value().cloned()
396 } else {
397 None
398 }
399}
400pub fn check_ip_trie(trie: &PrefixMap<Ipv4Net, usize>, ip: Ipv4Addr) -> Option<usize> {
401 trie.get_lpm(&Ipv4Net::new(ip, 32).unwrap()).map(|r| *r.1)
402}
403pub fn check_ip6_trie(trie: &PrefixMap<Ipv6Net, usize>, ip6: Ipv6Addr) -> Option<usize> {
404 trie.get_lpm(&Ipv6Net::new(ip6, 32).unwrap()).map(|r| *r.1)
405}
406
/// Sample IPv4 addresses used by the smoke tests.
#[cfg(test)]
pub fn get_test_ips() -> Vec<Ipv4Addr> {
    const OCTETS: [[u8; 4]; 4] = [[1, 2, 3, 4], [2, 2, 3, 4], [3, 2, 3, 4], [15, 207, 213, 128]];
    OCTETS.iter().map(|&octets| Ipv4Addr::from(octets)).collect()
}
/// Sample domains used by the smoke tests.
#[cfg(test)]
pub fn get_test_domains() -> Vec<&'static str> {
    [
        "www.google.com",
        "jdj.reddit.com",
        "hdjd.baidu.com",
        "hshsh.djdjdj.djdj",
    ]
    .to_vec()
}
// Smoke test over `test.yaml`, read from the current working directory. It builds every
// matcher kind from the parsed rules and prints lookup results for the sample domains and
// IPs. Beyond the `unwrap()`s succeeding it asserts nothing, so it only catches panics.
#[cfg(feature = "serde_yaml_ng")]
#[test]
fn test() {
    let rule_map = parse_rules(&load_rules_from_file("test.yaml").unwrap());

    let dr = get_domain_rules(&rule_map).unwrap();
    println!("{:?}", dr.len());
    let suffix_rules = get_suffix_rules(&rule_map).unwrap();
    println!("{:?}", suffix_rules.len());
    let suffix_map = get_target_item_map(suffix_rules);

    // Same iteration order as the one gen_suffix_trie enumerates, so trie values index this.
    let suffix_targets: Vec<&String> = suffix_map.keys().collect();

    println!("{:?}", suffix_targets);
    let trie = gen_suffix_trie(&suffix_map);

    let keyword_rules = get_keyword_rules(&rule_map).unwrap();
    println!("{:?}", keyword_rules.len());
    let kmap = get_target_item_map(keyword_rules);
    let ac = gen_keywords_ac(&kmap);
    let ac2 = gen_keywords_ac2(keyword_rules);
    // Per-rule targets, aligned with ac2's pattern ids.
    let ac2_targets = get_keywords_targets(keyword_rules);

    let ds = get_test_domains();
    for d in &ds {
        let r = check_suffix_trie(&trie, d);
        println!("{:?}", r.map(|i| suffix_targets.get(i).unwrap()));
        let r = check_keyword_ac(&ac, d);
        println!("{:?}", r);
        let r = check_keyword_ac2(&ac2, d, &ac2_targets);
        println!("{:?}", r);
    }

    let ip_rules = get_ip_cidr_rules(&rule_map).unwrap();
    println!("{:?}", ip_rules.len());
    let ip_map = get_target_item_map(ip_rules);
    let ip_targets: Vec<_> = ip_map.keys().collect();
    let it = gen_ip_trie(&ip_map);
    let it2 = gen_ip_trie2(&ip_map);

    let ips = get_test_ips();
    for ip in &ips {
        let r = check_ip_trie(&it, *ip);
        println!("{:?}", r.map(|i| ip_targets.get(i).unwrap()));
        let r = check_ip_trie2(&it2, *ip);
        println!("{:?}", r.map(|i| ip_targets.get(i).unwrap()));
    }

    let ip_rules = get_ip6_cidr_rules(&rule_map).unwrap();
    println!("{:?}", ip_rules.len());

    // End-to-end: the same file through the high-level matcher.
    let cm = ClashRuleMatcher::from_clash_config_file("test.yaml").unwrap();
    for d in ds {
        let r = cm.check_domain(d);
        println!("{:?}", r);
    }
    for ip in ips {
        let r = cm.check_ip(std::net::IpAddr::V4(ip));
        println!("{:?}", r);
    }
}
/// Returns the ISO country code of `ip` from a MaxMind country database, or `""` when the
/// lookup fails or the record has no country / ISO code.
#[cfg(feature = "maxminddb")]
pub fn get_ip_iso_by_reader(ip: IpAddr, reader: &maxminddb::Reader<Vec<u8>>) -> &str {
    reader
        .lookup(ip)
        .ok()
        .and_then(|record: maxminddb::geoip2::Country| record.country)
        .and_then(|country| country.iso_code)
        .unwrap_or("")
}
507
508#[derive(Debug)]
509pub struct DomainKeywordMatcher {
510 pub ac: AhoCorasick,
511 pub targets: Vec<String>,
512}
513impl DomainKeywordMatcher {
514 pub fn check(&self, domain: &str) -> Option<&String> {
515 check_keyword_ac2(&self.ac, domain, &self.targets)
516 }
517}
518#[derive(Debug)]
519pub struct DomainSuffixMatcher {
520 pub trie: Trie<String, usize>,
521 pub targets: Vec<String>,
522}
523impl DomainSuffixMatcher {
524 pub fn check(&self, domain: &str) -> Option<&String> {
525 check_suffix_trie(&self.trie, domain).map(|i| self.targets.get(i).unwrap())
526 }
527}
528#[derive(Debug)]
529pub struct IpMatcher {
530 pub trie: PrefixMap<Ipv4Net, usize>,
531 pub targets: Vec<String>,
532}
533impl IpMatcher {
534 pub fn check(&self, ip: Ipv4Addr) -> Option<&String> {
535 check_ip_trie(&self.trie, ip).map(|i| self.targets.get(i).unwrap())
536 }
537}
538#[derive(Debug)]
539pub struct Ip6Matcher {
540 pub trie: PrefixMap<Ipv6Net, usize>,
541 pub targets: Vec<String>,
542}
543impl Ip6Matcher {
544 pub fn check(&self, ip: Ipv6Addr) -> Option<&String> {
545 check_ip6_trie(&self.trie, ip).map(|i| self.targets.get(i).unwrap())
546 }
547}
548
/// Compiled Clash rules, split by rule type into dedicated lookup structures.
///
/// Built by `ClashRuleMatcher::from_hashmap` (or the `from_clash_config_*` helpers). Each
/// field is `None` when the source contained no rules of that type.
#[derive(Debug, Default)]
pub struct ClashRuleMatcher {
    /// `DOMAIN` rules: exact domain -> target.
    pub domain_target_map: Option<HashMap<String, String>>,
    /// `DOMAIN-KEYWORD` rules.
    pub domain_keyword_matcher: Option<DomainKeywordMatcher>,
    /// `DOMAIN-SUFFIX` rules.
    pub domain_suffix_matcher: Option<DomainSuffixMatcher>,
    /// `DOMAIN-REGEX` rules grouped by target: target -> set of that target's patterns.
    pub domain_regex_set: Option<HashMap<String, regex::RegexSet>>,
    /// `IP-CIDR` rules.
    pub ip4_matcher: Option<IpMatcher>,
    /// `IP-CIDR6` rules.
    pub ip6_matcher: Option<Ip6Matcher>,

    /// MaxMind database used for country lookups. `from_hashmap` does not fill it in;
    /// set it explicitly before calling `check_ip_country*`.
    #[cfg(feature = "maxminddb")]
    pub mmdb_reader: Option<std::sync::Arc<maxminddb::Reader<Vec<u8>>>>,

    /// `GEOIP` rules: ISO country code -> target.
    #[cfg(feature = "maxminddb")]
    pub country_target_map: Option<HashMap<String, String>>,

    /// All remaining rules, parsed with `parse_rule`, paired with their target.
    /// Note: `check_domain` / `check_ip` do not consult this list.
    pub rules: Vec<(Rule, String)>,
}
574
575impl ClashRuleMatcher {
576 pub fn from_hashmap(
577 mut method_rules_map: HashMap<String, Vec<Vec<String>>>,
578 ) -> Result<Self, ParseRuleError> {
579 let mut s = Self::default();
580
581 if let Some(v) = get_domain_rules(&method_rules_map) {
582 s.domain_target_map = Some(get_item_target_map(v));
583 method_rules_map.remove(DOMAIN);
584 }
585 #[cfg(feature = "maxminddb")]
586 if let Some(v) = method_rules_map.get(GEOIP) {
587 s.country_target_map = Some(get_item_target_map(v));
588 method_rules_map.remove(GEOIP);
589 }
590 if let Some(v) = get_keyword_rules(&method_rules_map) {
591 let map = get_target_item_map(v);
592 let targets = map.keys().cloned().collect();
593 let ac = gen_keywords_ac2(v);
594 s.domain_keyword_matcher = Some(DomainKeywordMatcher { ac, targets });
595 method_rules_map.remove(DOMAIN_KEYWORD);
596 }
597 if let Some(v) = get_suffix_rules(&method_rules_map) {
598 let map = get_target_item_map(v);
599 let targets = map.keys().cloned().collect();
600 let trie = gen_suffix_trie(&map);
601 s.domain_suffix_matcher = Some(DomainSuffixMatcher { trie, targets });
602 method_rules_map.remove(DOMAIN_SUFFIX);
603 }
604 if let Some(v) = get_ip_cidr_rules(&method_rules_map) {
605 let map = get_target_item_map(v);
606 let targets = map.keys().cloned().collect();
607 let trie = gen_ip_trie(&map);
608 s.ip4_matcher = Some(IpMatcher { trie, targets });
609 method_rules_map.remove(IP_CIDR);
610 }
611 if let Some(v) = get_ip6_cidr_rules(&method_rules_map) {
612 let map = get_target_item_map(v);
613 let targets = map.keys().cloned().collect();
614 let trie = gen_ip6_trie(&map);
615 s.ip6_matcher = Some(Ip6Matcher { trie, targets });
616 method_rules_map.remove(IP_CIDR6);
617 }
618 if let Some(v) = method_rules_map.get(DOMAIN_REGEX) {
619 let map = get_target_item_map(v);
620 s.domain_regex_set = Some(
621 map.into_iter()
622 .map(|(t, r)| (t, regex::RegexSet::new(r).unwrap()))
623 .collect(),
624 );
625 method_rules_map.remove(DOMAIN_REGEX);
626 }
627
628 for (rt, v) in method_rules_map {
629 for ss in v {
630 let mut ss = ss.clone();
632 let target = if ss.len() > 1 {
633 ss.remove(1)
634 } else {
635 ss.pop().unwrap()
636 };
637 ss.insert(0, rt.clone());
638 let rule = ss.join(",");
639 let r = parse_rule(&rule)?;
640 s.rules.push((r, target));
641 }
642 }
643
644 Ok(s)
645 }
646 #[cfg(feature = "serde_yaml_ng")]
647 pub fn from_clash_config_str(cs: &str) -> Result<Self, Box<dyn std::error::Error>> {
648 let method_rules_map = parse_rules(&load_rules_from_str(cs)?);
649
650 Ok(Self::from_hashmap(method_rules_map)?)
651 }
652 #[cfg(feature = "serde_yaml_ng")]
653 pub fn from_clash_config_file<P: AsRef<Path>>(
654 path: P,
655 ) -> Result<Self, Box<dyn std::error::Error>> {
656 let s = std::fs::read_to_string(path)?;
657 let s = Self::from_clash_config_str(&s)?;
658 Ok(s)
659 }
660
661 pub fn check_ip4(&self, ip: Ipv4Addr) -> Option<&String> {
662 if let Some(m) = &self.ip4_matcher {
663 m.check(ip)
664 } else {
665 None
666 }
667 }
668 pub fn check_ip6(&self, ip: Ipv6Addr) -> Option<&String> {
669 if let Some(m) = &self.ip6_matcher {
670 m.check(ip)
671 } else {
672 None
673 }
674 }
675 pub fn check_ip(&self, ip: std::net::IpAddr) -> Option<&String> {
676 match ip {
677 std::net::IpAddr::V4(ipv4_addr) => self.check_ip4(ipv4_addr),
678 std::net::IpAddr::V6(ipv6_addr) => self.check_ip6(ipv6_addr),
679 }
680 }
681 #[cfg(feature = "maxminddb")]
682 pub fn check_ip_country_iso(&self, ip: std::net::IpAddr) -> &str {
683 if let Some(m) = &self.mmdb_reader {
684 get_ip_iso_by_reader(ip, m)
685 } else {
686 ""
687 }
688 }
689 #[cfg(feature = "maxminddb")]
690 pub fn check_ip_country(&self, ip: std::net::IpAddr) -> Option<&String> {
691 if let Some(m) = &self.country_target_map {
692 let c = self.check_ip_country_iso(ip);
693 m.get(c)
694 } else {
695 None
696 }
697 }
698
699 pub fn check_domain(&self, domain: &str) -> Option<&String> {
700 if let Some(m) = &self.domain_target_map {
701 let r = m.get(domain);
702 if r.is_some() {
703 return r;
704 }
705 }
706 if let Some(m) = &self.domain_suffix_matcher {
707 let r = m.check(domain);
708 if r.is_some() {
709 return r;
710 }
711 }
712 if let Some(m) = &self.domain_keyword_matcher {
713 let r = m.check(domain);
714 if r.is_some() {
715 return r;
716 }
717 }
718 if let Some(m) = &self.domain_regex_set {
719 for (t, r) in m {
720 if r.is_match(domain) {
721 return Some(t);
722 }
723 }
724 }
725 None
726 }
727}
728
// Parses nested AND/OR rules and checks their evaluation on domain-only inputs.
#[test]
fn test_logic() {
    let rule_str = "OR,((DOMAIN-KEYWORD,bili),(DOMAIN-REGEX,(?i)pcdn|mcdn))";
    let rule0 = parse_rule(rule_str).unwrap();
    println!("{:#?}", rule0);
    let rule_str = "AND,((DOMAIN-KEYWORD,bili),(DOMAIN-REGEX,(?i)pcdn|mcdn))";
    let rule = parse_rule(rule_str).unwrap();
    println!("{:#?}", rule);
    // OR: "bili" occurs, so it matches even though the regex does not.
    assert!(rule0.matches(&RuleInput {
        domain: Some("pcdbili.com".to_string()),
        ..Default::default()
    }));
    // AND: the regex (pcdn|mcdn) does not occur in "pcdbili.com".
    assert!(!rule.matches(&RuleInput {
        domain: Some("pcdbili.com".to_string()),
        ..Default::default()
    }));
    assert!(rule.matches(&RuleInput {
        domain: Some("pcdn.bili.com".to_string()),
        ..Default::default()
    }));
    // Nested logic rule: only printed, not asserted.
    let rule_str = "AND,((OR,((DOMAIN-KEYWORD,bili),(DOMAIN,0))),(DOMAIN-REGEX,(?i)pcdn|mcdn))";
    let rule = parse_rule(rule_str);
    println!("{:#?}", rule);
}
754
/// A single parsed Clash rule (without its target), as produced by `parse_rule` and
/// evaluated by `Rule::matches`.
#[derive(Debug)]
pub enum Rule {
    /// Matches when every sub-rule matches.
    And(Vec<Box<Rule>>),
    /// Matches when at least one sub-rule matches.
    Or(Vec<Box<Rule>>),
    /// Matches when the sub-rule does not.
    Not(Box<Rule>),
    /// Exact domain equality.
    Domain(String),
    /// Domain suffix match.
    DomainSuffix(String),
    /// Substring match on the domain.
    DomainKeyword(String),
    /// Regex match on the domain.
    DomainRegex(regex::Regex),
    /// IPv4 network containment.
    IpCidr(Ipv4Net),
    /// IPv6 network containment.
    IpCidr6(Ipv6Net),
    /// Country ISO code; only evaluated with the `maxminddb` feature, otherwise never matches.
    GeoIp(String),
    /// Exact comparison with `RuleInput::network`.
    Network(String),
    /// Exact destination-port equality.
    DstPort(u16),
    /// Exact comparison with `RuleInput::process_name`.
    ProcessName(String),
    /// Always matches.
    Match,
    /// Unrecognised rule type as `(rule type, raw payload)`; never matches.
    Other(String, String),
}
773
/// The connection attributes a [`Rule`] is evaluated against.
///
/// Every field is optional. A rule that inspects a field that is `None` does not match.
#[derive(Default, Debug)]
pub struct RuleInput {
    /// Destination domain (for `DOMAIN*` rules).
    pub domain: Option<String>,
    /// Originating process name (for `PROCESS-NAME`).
    pub process_name: Option<String>,
    /// Network name (for `NETWORK`), compared verbatim.
    pub network: Option<String>,
    /// Destination IP (for `IP-CIDR`, `IP-CIDR6`, `GEOIP`).
    pub ip: Option<IpAddr>,
    /// Destination port (for `DST-PORT`).
    pub dst_port: Option<u16>,

    /// MaxMind database used to resolve `GEOIP` rules.
    #[cfg(feature = "maxminddb")]
    pub mmdb_reader: Option<std::sync::Arc<maxminddb::Reader<Vec<u8>>>>,
}
786
787impl Rule {
788 pub fn matches(&self, input: &RuleInput) -> bool {
789 match self {
790 Rule::Match => true,
791 Rule::And(rules) => {
792 for r in rules {
793 if !r.matches(input) {
794 return false;
795 }
796 }
797 true
798 }
799 Rule::Or(rules) => {
800 for r in rules {
801 if r.matches(input) {
802 return true;
803 }
804 }
805 false
806 }
807 Rule::Not(rule) => !rule.matches(input),
808
809 Rule::ProcessName(p) => input
810 .process_name
811 .as_ref()
812 .is_some_and(|real_p| real_p.eq(p)),
813 Rule::Network(n) => input.network.as_ref().is_some_and(|d| d.eq(n)),
814 Rule::Domain(domain) => input.domain.as_ref().is_some_and(|d| d.eq(domain)),
815 Rule::DomainRegex(r) => input.domain.as_ref().is_some_and(|d| r.is_match(d)),
816 Rule::DomainSuffix(suffix) => {
817 input.domain.as_ref().is_some_and(|d| d.ends_with(suffix))
818 }
819 Rule::DomainKeyword(k) => input.domain.as_ref().is_some_and(|d| d.contains(k)),
820 Rule::IpCidr6(k) => input.ip.as_ref().is_some_and(|ip| {
821 if let IpAddr::V6(i) = ip {
822 k.contains(i)
823 } else {
824 false
825 }
826 }),
827 Rule::IpCidr(k) => input.ip.as_ref().is_some_and(|ip| {
828 if let IpAddr::V4(i) = ip {
829 k.contains(i)
830 } else {
831 false
832 }
833 }),
834 #[cfg(feature = "maxminddb")]
835 Rule::GeoIp(region) => input.ip.is_some_and(|ip| {
836 if let Some(m) = &input.mmdb_reader {
837 let iso = get_ip_iso_by_reader(ip, m);
838 iso.eq(region)
839 } else {
840 false
841 }
842 }),
843 Rule::DstPort(d) => input.dst_port.is_some_and(|rd| rd == *d),
844 _ => false,
845 }
846 }
847}
848
849use thiserror::Error;
850
/// Error returned by `parse_rule` (and propagated by `ClashRuleMatcher::from_hashmap`).
#[derive(Error, Debug)]
pub enum ParseRuleError {
    /// A logic rule (`AND`/`OR`/`NOT`) payload is not wrapped in parentheses.
    #[error("not wrapped with ()")]
    E1,
    /// A `DOMAIN-REGEX` pattern failed to compile.
    #[error("regex error")]
    Regex(#[from] regex::Error),
    /// An `IP-CIDR` / `IP-CIDR6` payload is not a valid network.
    #[error("parse ipcidr err")]
    ParseIpnet(#[from] ipnet::AddrParseError),
    /// A `DST-PORT` payload is not a valid `u16`.
    #[error("parse dst port err")]
    ParseNum(#[from] ParseIntError),
}
862
863pub fn parse_rule(input: &str) -> Result<Rule, ParseRuleError> {
866 if input == MATCH {
867 return Ok(Rule::Match);
868 }
869 let (rt, r) = input.split_once(",").unwrap();
870
871 if rt.eq(AND) || rt.eq(OR) || rt.eq(NOT) {
872 if !(r.starts_with('(') && r.ends_with(')')) {
873 return Err(ParseRuleError::E1);
874 }
875 let r = &r[1..r.len() - 1];
876 let mut subrules: Vec<_> = extract_sub_rules_from(r)
877 .iter()
878 .map(|s| Box::new(parse_rule(s).unwrap()))
879 .collect();
880 let r = match rt {
881 AND => Rule::And(subrules),
882 OR => Rule::Or(subrules),
883 NOT => {
884 let b = subrules.pop().unwrap();
885 Rule::Not(b)
886 }
887 _ => unreachable!("ur"),
888 };
889 Ok(r)
890 } else {
891 let mut r = r.to_string();
892 let r = match rt {
893 DOMAIN => Rule::Domain(r),
894 DOMAIN_SUFFIX => Rule::DomainSuffix(r),
895 DOMAIN_KEYWORD => Rule::DomainKeyword(r),
896 DOMAIN_REGEX => Rule::DomainRegex(regex::Regex::new(&r)?),
897 IP_CIDR6 => {
898 if let Some(pos) = r.find(',') {
899 r.truncate(pos);
900 }
901 let r: Ipv6Net = r.parse()?;
902 Rule::IpCidr6(r)
903 }
904 IP_CIDR => {
905 if let Some(pos) = r.find(',') {
906 r.truncate(pos);
907 }
908 let r: Ipv4Net = r.parse()?;
909 Rule::IpCidr(r)
910 }
911 GEOIP => Rule::GeoIp(r),
912 DST_PORT => Rule::DstPort(r.parse()?),
913 NETWORK => Rule::Network(r),
914 PROCESS_NAME => Rule::ProcessName(r),
915 _ => Rule::Other(rt.to_string(), r),
916 };
917 Ok(r)
918 }
919}
920
921fn extract_sub_rules_from(input: &str) -> Vec<String> {
927 if input.starts_with('(') {
928 let mut v = vec![];
929 let mut lbi = 0;
930 loop {
931 let rbi = find_matching_bracket(input, lbi).unwrap();
932 let s = &input[lbi + 1..rbi];
933 v.push(s.trim().to_string());
934 if rbi + 1 == input.len() {
935 break;
936 }
937
938 let commap = input[rbi + 1..].find(',').unwrap();
939 let bp = input[rbi + 1..].find('(').unwrap();
940 if commap >= bp {
941 panic!("commap >= bp")
942 }
943
944 lbi = bp + rbi + 1;
945 }
946 v
947 } else {
948 let r = &input[1..input.len() - 1];
949 vec![r.trim().to_string()]
950 }
951}
952
/// Returns the byte index of the `)` that closes the `(` at `left_bracket_index`, or `None`
/// if the text ends before the group is closed.
///
/// Panics if `left_bracket_index` is out of bounds or not on a char boundary.
fn find_matching_bracket(text: &str, left_bracket_index: usize) -> Option<usize> {
    let mut depth = 0;
    text[left_bracket_index..]
        .char_indices()
        .find_map(|(offset, ch)| {
            match ch {
                '(' => depth += 1,
                ')' => {
                    depth -= 1;
                    if depth == 0 {
                        return Some(left_bracket_index + offset);
                    }
                }
                _ => {}
            }
            None
        })
}
969
970#[cfg(feature = "rusqlite")]
971use rusqlite::{params, Connection};
972
/// Rule types persisted to SQLite.
///
/// `init_db` creates one table per entry, named by `to_sql_table_name`, and
/// `save_to_sqlite` ignores any other rule type.
pub const RULE_TYPES: &[&str] = &[
    DOMAIN,
    DOMAIN_KEYWORD,
    DOMAIN_SUFFIX,
    IP_CIDR,
    IP_CIDR6,
    PROCESS_NAME,
    GEOIP,
];
983
/// Converts a Clash rule type to its SQLite table name: `DOMAIN-SUFFIX` -> `domain_suffix`.
pub fn to_sql_table_name(input: &str) -> String {
    input.to_lowercase().replace('-', "_")
}
/// Inverse of `to_sql_table_name`: `domain_suffix` -> `DOMAIN-SUFFIX`.
pub fn to_clash_rule_name(input: &str) -> String {
    input.to_uppercase().replace('_', "-")
}
990
/// Creates, if missing, one `(id, content, target)` table per entry of [`RULE_TYPES`] and
/// the `rules_view` view that unions them as `(rule_name, content, target)`.
///
/// The view is generated from `RULE_TYPES` rather than written out by hand, so adding a
/// rule type automatically adds its table to the view. Because of `IF NOT EXISTS`, an
/// already existing view is left as it is.
///
/// # Errors
/// Any SQLite error from the DDL statements.
#[cfg(feature = "rusqlite")]
pub fn init_db(conn: &Connection) -> rusqlite::Result<()> {
    for &rule_type in RULE_TYPES {
        let create_table_sql = format!(
            "CREATE TABLE IF NOT EXISTS {} (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    content TEXT NOT NULL,
    target TEXT NOT NULL
)",
            to_sql_table_name(rule_type)
        );
        conn.execute(&create_table_sql, [])?;
    }

    // Rule-type literals come from compile-time constants, so quoting them inline is safe.
    let selects: Vec<String> = RULE_TYPES
        .iter()
        .map(|&rule_type| {
            format!(
                "SELECT '{}' AS rule_name, content, target FROM {}",
                rule_type,
                to_sql_table_name(rule_type)
            )
        })
        .collect();
    let create_view_sql = format!(
        "CREATE VIEW IF NOT EXISTS rules_view AS {}",
        selects.join(" UNION ALL ")
    );
    conn.execute(&create_view_sql, [])?;
    Ok(())
}
1026
/// Runs `sql` and groups its rows by rule name: `rule_name -> [[content, target], ...]`.
///
/// `sql` must return `(rule_name, content, target)` text columns, e.g.
/// `SELECT rule_name, content, target FROM rules_view`. It is executed as given, so only
/// pass trusted SQL.
///
/// Gated on the `rusqlite` feature like every other function that uses `Connection`.
/// Without the gate the crate does not compile when the feature is disabled.
///
/// # Errors
/// Any SQLite error from preparing or stepping the statement, or a column type mismatch.
#[cfg(feature = "rusqlite")]
pub fn query_rules_view(
    conn: &Connection,
    sql: &str,
) -> rusqlite::Result<HashMap<String, Vec<Vec<String>>>> {
    let mut rules_map: HashMap<String, Vec<Vec<String>>> = HashMap::new();

    let mut stmt = conn.prepare(sql)?;
    let rows = stmt.query_map([], |row| {
        let rule_name: String = row.get(0)?;
        let content: String = row.get(1)?;
        let target_label: String = row.get(2)?;
        Ok((rule_name, vec![content, target_label]))
    })?;

    for row in rows {
        let (rule_name, entry) = row?;
        rules_map.entry(rule_name).or_default().push(entry);
    }

    Ok(rules_map)
}
1052
/// Inserts the rows of `rules` into their tables inside a single transaction.
///
/// Only rule types listed in [`RULE_TYPES`] are written. From each row, field 0 is stored
/// as `content` and field 1 as `target`; rows with fewer than two fields are skipped.
///
/// # Errors
/// Any SQLite error; the transaction is rolled back when it is dropped uncommitted.
#[cfg(feature = "rusqlite")]
pub fn save_to_sqlite(
    conn: &mut Connection,
    rules: &HashMap<String, Vec<Vec<String>>>,
) -> rusqlite::Result<()> {
    let tx = conn.transaction()?;

    let known_types = rules
        .iter()
        .filter(|(rule_name, _)| RULE_TYPES.contains(&rule_name.as_str()));
    for (rule_name, entries) in known_types {
        // Table name is derived from a RULE_TYPES entry, so it is safe to interpolate.
        let insert_sql = format!(
            "INSERT INTO {} (content, target) VALUES (?1, ?2)",
            to_sql_table_name(rule_name)
        );
        for entry in entries {
            let [content, target, ..] = entry.as_slice() else {
                continue;
            };
            tx.execute(&insert_sql, params![content, target])?;
        }
    }

    tx.commit()
}
1087
/// Reads every [`RULE_TYPES`] table back into a `rule type -> [[content, target], ...]` map.
///
/// A rule type only gets a key when its table has at least one row.
///
/// # Errors
/// Any SQLite error, e.g. a missing table when `init_db` was never run.
#[cfg(feature = "rusqlite")]
pub fn load_from_sqlite(conn: &Connection) -> rusqlite::Result<HashMap<String, Vec<Vec<String>>>> {
    let mut rules_map: HashMap<String, Vec<Vec<String>>> = HashMap::new();

    for rule_type in RULE_TYPES {
        let table = to_sql_table_name(rule_type);
        let mut stmt = conn.prepare(&format!("SELECT content, target FROM {}", table))?;
        let rows = stmt.query_map([], |row| {
            Ok(vec![row.get::<_, String>(0)?, row.get::<_, String>(1)?])
        })?;
        let rows = rows.collect::<rusqlite::Result<Vec<_>>>()?;
        if !rows.is_empty() {
            rules_map
                .entry(to_clash_rule_name(&table))
                .or_default()
                .extend(rows);
        }
    }

    Ok(rules_map)
}
/// Inserts one `(content, target)` row into the table of `rule_name`.
///
/// `rule_name` is turned into a table name and spliced into the SQL, so it must be one of
/// [`RULE_TYPES`], never untrusted input. `content` and `target` are bound as parameters.
///
/// # Errors
/// Any SQLite error, e.g. an unknown table.
#[cfg(feature = "rusqlite")]
pub fn add_rule(
    conn: &Connection,
    rule_name: &str,
    content: &str,
    target: &str,
) -> rusqlite::Result<()> {
    let insert_sql = format!(
        "INSERT INTO {} (content, target) VALUES (?1, ?2)",
        to_sql_table_name(rule_name)
    );
    conn.execute(&insert_sql, params![content, target]).map(drop)
}
1129
/// Deletes every row whose `content` equals `content` from the table of `rule_name`.
///
/// `rule_name` is spliced into the SQL as a table name and must be one of [`RULE_TYPES`].
///
/// # Errors
/// Any SQLite error.
#[cfg(feature = "rusqlite")]
pub fn delete_rule(conn: &Connection, rule_name: &str, content: &str) -> rusqlite::Result<()> {
    let delete_sql = format!(
        "DELETE FROM {} WHERE content = ?1",
        to_sql_table_name(rule_name)
    );
    conn.execute(&delete_sql, params![content]).map(drop)
}
1138
/// Sets `target = new_target` on every row whose `content` equals `content` in the table
/// of `rule_name`.
///
/// `rule_name` is spliced into the SQL as a table name and must be one of [`RULE_TYPES`].
///
/// # Errors
/// Any SQLite error.
#[cfg(feature = "rusqlite")]
pub fn update_rule(
    conn: &Connection,
    rule_name: &str,
    content: &str,
    new_target: &str,
) -> rusqlite::Result<()> {
    let update_sql = format!(
        "UPDATE {} SET target = ?1 WHERE content = ?2",
        to_sql_table_name(rule_name)
    );
    conn.execute(&update_sql, params![new_target, content]).map(drop)
}
1152
/// Returns every `[content, target]` row of the table of `rule_name`.
///
/// `rule_name` is spliced into the SQL as a table name and must be one of [`RULE_TYPES`].
///
/// # Errors
/// Any SQLite error.
#[cfg(feature = "rusqlite")]
pub fn query_rule(conn: &Connection, rule_name: &str) -> rusqlite::Result<Vec<Vec<String>>> {
    let select_sql = format!(
        "SELECT content, target FROM {}",
        to_sql_table_name(rule_name)
    );
    let mut stmt = conn.prepare(&select_sql)?;
    let rows = stmt.query_map([], |row| {
        Ok(vec![row.get::<_, String>(0)?, row.get::<_, String>(1)?])
    })?;
    rows.collect()
}
1170
// Exercises the SQLite round-trip: init, bulk save, load, add/update/query/delete, and the
// union view. It writes `rules.db` in the working directory and does not remove it, so rows
// accumulate across runs. Beyond the `?`s succeeding it asserts nothing.
#[cfg(feature = "rusqlite")]
#[test]
fn test_sql() -> rusqlite::Result<()> {
    println!("init");
    let mut conn = Connection::open("rules.db")?;
    init_db(&conn)?;

    // Start from the YAML rules when available, otherwise from an empty map.
    #[cfg(not(feature = "serde_yaml_ng"))]
    let mut rules: HashMap<String, Vec<Vec<String>>> = HashMap::new();
    #[cfg(feature = "serde_yaml_ng")]
    let mut rules = parse_rules(&load_rules_from_file("test.yaml").unwrap());
    rules.insert(
        "DOMAIN".to_string(),
        vec![
            vec!["example.com".to_string(), "proxy".to_string()],
            vec!["test.com".to_string(), "direct".to_string()],
        ],
    );
    rules.insert(
        "IP-CIDR".to_string(),
        vec![
            vec!["192.168.1.0/24".to_string(), "proxy".to_string()],
            vec!["10.0.0.0/8".to_string(), "direct".to_string()],
        ],
    );

    println!("save");
    save_to_sqlite(&mut conn, &rules)?;

    println!("load");
    load_from_sqlite(&conn)?;

    add_rule(&conn, "DOMAIN", "example.com", "proxy")?;
    add_rule(&conn, "DOMAIN", "test.com", "direct")?;
    add_rule(&conn, "IP-CIDR", "192.168.1.0/24", "proxy")?;

    update_rule(&conn, "DOMAIN", "test.com", "proxy")?;

    let _domain_rules = query_rule(&conn, "DOMAIN")?;

    delete_rule(&conn, "DOMAIN", "example.com")?;

    let sql = "SELECT rule_name, content, target FROM rules_view";
    let r = query_rules_view(&conn, sql)?;
    println!("all {}", r.len());

    Ok(())
}
1227
/// Lists the user tables of the main database (SQLite-internal `sqlite_*` tables excluded).
#[cfg(feature = "rusqlite")]
fn get_table_names(conn: &Connection) -> rusqlite::Result<Vec<String>> {
    let mut stmt = conn.prepare(
        "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
    )?;
    let names = stmt.query_map([], |row| row.get::<_, String>(0))?;
    names.collect()
}
1239
/// Appends the rule rows of the database at `db2_path` into the database at `db1_path`.
///
/// Every user table of db1 is filled from the table with the same name in db2. Both are
/// expected to have the `(id, content, target)` layout created by `init_db`.
///
/// Only `content` and `target` are copied, so db1 assigns fresh ids. Copying `id` as well
/// fails with a UNIQUE constraint as soon as both databases have rows in the same table,
/// since both number from 1.
///
/// The path of db2 is bound as a parameter. Splicing it into the SQL text would break on,
/// or be injectable through, a `'` in the path.
///
/// # Errors
/// Any SQLite error, e.g. a table of db1 that does not exist in db2. Merging is not
/// transactional: tables copied before an error stay copied.
#[cfg(feature = "rusqlite")]
pub fn merge_databases(db1_path: &str, db2_path: &str) -> rusqlite::Result<()> {
    let conn = Connection::open(db1_path)?;

    conn.execute("ATTACH DATABASE ?1 AS attached_db", params![db2_path])?;

    let tables = get_table_names(&conn)?;

    for table in tables {
        // Table names come from db1's own schema; quote them as identifiers anyway.
        let sql = format!(
            "INSERT INTO main.\"{table}\" (content, target) SELECT content, target FROM attached_db.\"{table}\""
        );
        conn.execute(&sql, [])?;
    }

    conn.execute("DETACH DATABASE attached_db", [])?;

    Ok(())
}
1264
// Builds two fresh databases in the working directory (1.db with a DOMAIN row, 2.db with
// an IP-CIDR row) and merges 2.db into 1.db. Only checks that merging does not error.
#[cfg(feature = "rusqlite")]
#[test]
fn merge_sql() -> rusqlite::Result<()> {
    // Start from scratch; a missing file is fine, so the results are ignored.
    let _ = std::fs::remove_file("1.db");
    let _ = std::fs::remove_file("2.db");
    {
        // Both connections live until the end of this scope; the second binding shadows the first.
        let conn = Connection::open("1.db")?;
        init_db(&conn)?;
        add_rule(&conn, "DOMAIN", "test.com", "direct")?;
        let conn = Connection::open("2.db")?;
        init_db(&conn)?;
        add_rule(&conn, "IP-CIDR", "192.168.1.0/24", "proxy")?;
    }
    let db1 = "1.db";
    let db2 = "2.db";

    merge_databases(db1, db2)?;

    println!("Databases merged successfully!");
    Ok(())
}
1286}