clash_rules/
lib.rs

1pub use aho_corasick::AhoCorasick;
2#[cfg(feature = "maxminddb")]
3pub use maxminddb;
4pub use prefix_trie::PrefixMap;
5pub use radix_trie::{Trie, TrieCommon};
6#[cfg(feature = "rusqlite")]
7pub use rusqlite;
8
9#[cfg(feature = "serde_yaml_ng")]
10pub use serde_yaml_ng;
11
12use ipnet::{Ipv4Net, Ipv6Net};
13use std::collections::HashMap;
14use std::fmt::Display;
15use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
16use std::num::ParseIntError;
17use std::path::Path;
18
19use serde::{Deserialize, Serialize};
/// Rule type name: exact domain match.
pub const DOMAIN: &str = "DOMAIN";
/// Rule type name: domain suffix match.
pub const DOMAIN_SUFFIX: &str = "DOMAIN-SUFFIX";
/// Rule type name: domain substring match.
pub const DOMAIN_KEYWORD: &str = "DOMAIN-KEYWORD";
/// Rule type name: domain regular-expression match.
pub const DOMAIN_REGEX: &str = "DOMAIN-REGEX";
/// Rule type name: IPv4 CIDR match.
pub const IP_CIDR: &str = "IP-CIDR";
/// Rule type name: IPv6 CIDR match.
pub const IP_CIDR6: &str = "IP-CIDR6";
/// Rule type name: process name match.
pub const PROCESS_NAME: &str = "PROCESS-NAME";
/// Rule type name: destination port match.
pub const DST_PORT: &str = "DST-PORT";
/// Rule type name: country match by IP lookup.
pub const GEOIP: &str = "GEOIP";
/// Rule type name: network match (compared as a plain string).
pub const NETWORK: &str = "NETWORK";
/// Logical rule name: all sub-rules must match.
pub const AND: &str = "AND";
/// Logical rule name: at least one sub-rule must match.
pub const OR: &str = "OR";
/// Logical rule name: the sub-rule must not match.
pub const NOT: &str = "NOT";
/// Rule name that matches everything.
pub const MATCH: &str = "MATCH";
34
/// A rule-set document: a YAML mapping whose `payload` key lists the raw entries.
#[derive(Serialize, Deserialize, Debug)]
pub struct RuleSet {
    /// Raw rule-set entries, one string per entry.
    pub payload: Vec<String>,
}
/// Error returned when loading a YAML document from a file.
#[derive(Debug)]
pub enum LoadYamlFileError {
    /// The file could not be read.
    FileErr(std::io::Error),
    /// The file content could not be deserialized.
    #[cfg(feature = "serde_yaml_ng")]
    YamlErr(serde_yaml_ng::Error),
}

impl Display for LoadYamlFileError {
    /// Forwards to the wrapped error's message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LoadYamlFileError::FileErr(error) => write!(f, "{}", error),
            #[cfg(feature = "serde_yaml_ng")]
            LoadYamlFileError::YamlErr(error) => write!(f, "{}", error),
        }
    }
}

// Implementing `std::error::Error` lets callers box this type, use it with `?` in functions
// returning `Box<dyn Error>`, and walk the cause chain.
impl std::error::Error for LoadYamlFileError {
    /// Returns the wrapped I/O or YAML error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            LoadYamlFileError::FileErr(error) => Some(error),
            #[cfg(feature = "serde_yaml_ng")]
            LoadYamlFileError::YamlErr(error) => Some(error),
        }
    }
}

impl From<std::io::Error> for LoadYamlFileError {
    fn from(err: std::io::Error) -> Self {
        LoadYamlFileError::FileErr(err)
    }
}

#[cfg(feature = "serde_yaml_ng")]
impl From<serde_yaml_ng::Error> for LoadYamlFileError {
    fn from(err: serde_yaml_ng::Error) -> Self {
        LoadYamlFileError::YamlErr(err)
    }
}
67
/// Reads the file at `path` and deserializes it as a [`RuleSet`].
///
/// # Errors
/// Returns `FileErr` when the file cannot be read and `YamlErr` when its content does not
/// deserialize.
#[cfg(feature = "serde_yaml_ng")]
pub fn load_rule_set_from_file<P: AsRef<Path>>(path: P) -> Result<RuleSet, LoadYamlFileError> {
    let yaml = std::fs::read_to_string(path)?;
    load_rule_set_from_str(&yaml).map_err(LoadYamlFileError::YamlErr)
}
/// Deserializes a [`RuleSet`] from YAML text.
///
/// # Errors
/// Returns the deserializer's error when `s` is not a valid rule-set document.
#[cfg(feature = "serde_yaml_ng")]
pub fn load_rule_set_from_str(s: &str) -> Result<RuleSet, serde_yaml_ng::Error> {
    serde_yaml_ng::from_str::<RuleSet>(s)
}
79
80/// init like let mut trie = Trie::new();
81pub fn parse_rule_set_as_domain_suffix_trie(
82    mut trie: Trie<String, usize>,
83    payload: &[String],
84    target_id: usize,
85) {
86    for v in payload.iter() {
87        let mut r: String = v.chars().rev().collect();
88        // RULESET 中 表示 suffix 的 字符串 有个 加号末尾(逆序后)
89        r = r.trim_end_matches('+').to_string();
90        trie.insert(r, target_id);
91    }
92}
93
94/// init like let mut trie = PrefixMap::<Ipv4Net, usize>::new();
95pub fn parse_rule_set_as_ip_cidr_trie(
96    mut trie: PrefixMap<Ipv4Net, usize>,
97    mut trie6: PrefixMap<Ipv6Net, usize>,
98    payload: &[String],
99    target_id: usize,
100) {
101    for v in payload.iter() {
102        let r: Result<Ipv4Net, <Ipv4Net as std::str::FromStr>::Err> = v.parse();
103        match r {
104            Ok(r) => trie.insert(r, target_id),
105            Err(_) => {
106                let r: Ipv6Net = v.parse().unwrap();
107                trie6.insert(r, target_id)
108            }
109        };
110    }
111}
112
/// Parses classic rule-set entries (`TYPE,item[,extra...]`) into a map from rule type to rows.
///
/// Each row holds the comma-separated fields after the type, with `target` inserted as the
/// second field: `[item, target, extra...]`. An entry that has no field after the type (for
/// example `MATCH`) produces the row `[target]` instead of panicking.
pub fn parse_rule_set_as_classic(
    payload: &[String],
    target: String,
) -> HashMap<String, Vec<Vec<String>>> {
    let mut hashmap: HashMap<String, Vec<Vec<String>>> = HashMap::new();
    for rule in payload {
        let mut parts = rule.split(',');
        // `split` always yields at least one piece, so the type is always present.
        let Some(key) = parts.next() else { continue };
        let mut values: Vec<String> = parts.map(str::to_string).collect();
        // `Vec::insert` panics when the index exceeds the length, so clamp for empty rows.
        let target_pos = values.len().min(1);
        values.insert(target_pos, target.clone());
        hashmap.entry(key.to_string()).or_default().push(values);
    }
    hashmap
}
128
/// The part of a Clash configuration this crate reads: the `rules` list.
#[derive(Serialize, Deserialize, Debug)]
pub struct RuleConfig {
    /// Raw rule lines, e.g. `DOMAIN-KEYWORD,bili,DIRECT`.
    pub rules: Vec<String>,
}
/// Reads the file at `path` and deserializes it as a [`RuleConfig`].
///
/// # Errors
/// Returns `FileErr` when the file cannot be read and `YamlErr` when its content does not
/// deserialize.
#[cfg(feature = "serde_yaml_ng")]
pub fn load_rules_from_file<P: AsRef<Path>>(path: P) -> Result<RuleConfig, LoadYamlFileError> {
    let yaml = std::fs::read_to_string(path)?;
    load_rules_from_str(&yaml).map_err(LoadYamlFileError::YamlErr)
}
/// Deserializes a [`RuleConfig`] from YAML text.
///
/// # Errors
/// Returns the deserializer's error when `s` is not a valid configuration document.
#[cfg(feature = "serde_yaml_ng")]
pub fn load_rules_from_str(s: &str) -> Result<RuleConfig, serde_yaml_ng::Error> {
    serde_yaml_ng::from_str::<RuleConfig>(s)
}
144
145/// parse clash rules into METHOD-rules hashmap, the ',' splitted items is pushed in the inner Vec
146pub fn parse_rules(rc: &RuleConfig) -> HashMap<String, Vec<Vec<String>>> {
147    // rule, items
148    let mut hashmap: HashMap<String, Vec<Vec<String>>> = HashMap::new();
149    for rule in &rc.rules {
150        let mut parts = rule.split(',');
151        if let Some(key) = parts.next() {
152            let values: Vec<String> = parts.map(|part| part.to_string()).collect();
153            hashmap.entry(key.to_string()).or_default().push(values);
154        }
155    }
156    hashmap
157}
158
/// Merges two rule-type maps; rows of `map2` are appended after the rows `map1` already holds
/// for the same rule type.
pub fn merge_method_rules_map(
    map1: HashMap<String, Vec<Vec<String>>>,
    map2: HashMap<String, Vec<Vec<String>>>,
) -> HashMap<String, Vec<Vec<String>>> {
    map2.into_iter().fold(map1, |mut merged, (method, rows)| {
        merged.entry(method).or_default().extend(rows);
        merged
    })
}
171
/// Returns the rows stored under the `DOMAIN` rule type, if any.
pub fn get_domain_rules(
    method_rules_map: &HashMap<String, Vec<Vec<String>>>,
) -> Option<&Vec<Vec<String>>> {
    method_rules_map.get(DOMAIN)
}
/// Returns the rows stored under the `DOMAIN-SUFFIX` rule type, if any.
pub fn get_suffix_rules(
    method_rules_map: &HashMap<String, Vec<Vec<String>>>,
) -> Option<&Vec<Vec<String>>> {
    method_rules_map.get(DOMAIN_SUFFIX)
}
/// Returns the rows stored under the `DOMAIN-KEYWORD` rule type, if any.
pub fn get_keyword_rules(
    method_rules_map: &HashMap<String, Vec<Vec<String>>>,
) -> Option<&Vec<Vec<String>>> {
    method_rules_map.get(DOMAIN_KEYWORD)
}
/// Returns the rows stored under the `IP-CIDR` rule type, if any.
pub fn get_ip_cidr_rules(
    method_rules_map: &HashMap<String, Vec<Vec<String>>>,
) -> Option<&Vec<Vec<String>>> {
    method_rules_map.get(IP_CIDR)
}
/// Returns the rows stored under the `IP-CIDR6` rule type, if any.
pub fn get_ip6_cidr_rules(
    method_rules_map: &HashMap<String, Vec<Vec<String>>>,
) -> Option<&Vec<Vec<String>>> {
    method_rules_map.get(IP_CIDR6)
}
197
/// Builds an `item -> target` map for rule types that are matched by direct lookup
/// (DOMAIN, PROCESS-NAME, ...).
///
/// Each row is expected to be `[item, target, ...]`. Rows with fewer than two fields are
/// skipped instead of causing a panic. When an item occurs more than once, the last row wins.
pub fn get_item_target_map(rules: &[Vec<String>]) -> HashMap<String, String> {
    rules
        .iter()
        .filter_map(|rule| match rule.as_slice() {
            [item, target, ..] => Some((item.clone(), target.clone())),
            // Malformed row: no target to map to.
            _ => None,
        })
        .collect()
}
208
/// Builds a `target -> items` map for rule types that are matched by scanning
/// (SUFFIX, KEYWORD, CIDR, ...).
///
/// Each row is expected to be `[item, target, ...]`. Rows with fewer than two fields are
/// skipped instead of causing a panic. Items keep their input order within each target.
pub fn get_target_item_map(rules: &[Vec<String>]) -> HashMap<String, Vec<String>> {
    let mut map: HashMap<String, Vec<String>> = HashMap::new();
    for rule in rules {
        if let [item, target, ..] = rule.as_slice() {
            map.entry(target.clone()).or_default().push(item.clone());
        }
    }
    map
}
219
220pub fn gen_keywords_ac(
221    target_keywords_map: &HashMap<String, Vec<String>>,
222) -> HashMap<String, AhoCorasick> {
223    target_keywords_map
224        .iter()
225        .map(|(k, v)| (k.clone(), AhoCorasick::new(v).unwrap()))
226        .collect()
227}
228
/// Collects the second field (the target) of every row that has one, in row order.
pub fn get_keywords_targets(rules: &[Vec<String>]) -> Vec<String> {
    let mut targets = Vec::new();
    for rule in rules {
        if let Some(target) = rule.get(1) {
            targets.push(target.clone());
        }
    }
    targets
}
232pub fn gen_keywords_ac2(rules: &[Vec<String>]) -> AhoCorasick {
233    let result: Vec<String> = rules.iter().filter_map(|v| v.first().cloned()).collect();
234
235    AhoCorasick::new(&result).unwrap()
236}
237pub fn gen_ip_trie<T: AsRef<str>>(target_ip_map: &HashMap<T, Vec<T>>) -> PrefixMap<Ipv4Net, usize> {
238    let mut trie = PrefixMap::<Ipv4Net, usize>::new();
239    for (i, (_key, value)) in target_ip_map.iter().enumerate() {
240        for v in value {
241            let r: Ipv4Net = v.as_ref().parse().unwrap();
242            trie.insert(r, i);
243        }
244    }
245    trie
246}
247#[derive(PartialEq, Eq, Debug)]
248struct Ipv4NetWrapper(pub Ipv4Net);
249impl radix_trie::TrieKey for Ipv4NetWrapper {
250    fn encode_bytes(&self) -> Vec<u8> {
251        fn u32_to_bit_u8_vec(n: u32, len: u8) -> Vec<u8> {
252            (32 - len..32u8).rev().map(|i| (n >> i) as u8).collect()
253        }
254        let ipnet = &self.0;
255        u32_to_bit_u8_vec(
256            u32::from_be_bytes(ipnet.network().octets()),
257            ipnet.prefix_len(),
258        )
259    }
260}
261
/// Trie for `Ipv4Net` keys backed by `radix_trie::Trie`, which is a bit slower than
/// `prefix_trie::PrefixMap`.
pub struct IpTrie2(Trie<Ipv4NetWrapper, usize>);
265pub fn gen_ip_trie2<T: AsRef<str>>(target_ip_map: &HashMap<T, Vec<T>>) -> IpTrie2 {
266    let mut trie = Trie::<Ipv4NetWrapper, usize>::new();
267    for (i, (_key, value)) in target_ip_map.iter().enumerate() {
268        for v in value {
269            let r: Ipv4Net = v.as_ref().parse().unwrap();
270            trie.insert(Ipv4NetWrapper(r), i);
271        }
272    }
273    IpTrie2(trie)
274}
275/// the function store ips in the trie with their target index of the map
276pub fn gen_ip6_trie<T: AsRef<str>>(
277    target_ip_map: &HashMap<T, Vec<T>>,
278) -> PrefixMap<Ipv6Net, usize> {
279    let mut trie = PrefixMap::<Ipv6Net, usize>::new();
280    for (i, (_key, value)) in target_ip_map.iter().enumerate() {
281        for v in value {
282            let r: Ipv6Net = v.as_ref().parse().unwrap();
283            trie.insert(r, i);
284        }
285    }
286    trie
287}
288/// the function store domains in the trie with their target index of the map
289pub fn gen_prefix_trie<T: AsRef<str>>(target_item_map: &HashMap<T, Vec<T>>) -> Trie<String, usize> {
290    let mut trie = Trie::new();
291
292    for (i, (_key, value)) in target_item_map.iter().enumerate() {
293        for v in value {
294            trie.insert(v.as_ref().to_string(), i);
295        }
296    }
297    trie
298}
299
300/// the function store domain chars in the result trie in reversed order, and
301/// with their target index of the map
302pub fn gen_suffix_trie<T: AsRef<str>>(
303    target_suffix_map: &HashMap<T, Vec<T>>,
304) -> Trie<String, usize> {
305    let mut trie = Trie::new();
306
307    for (i, (_key, value)) in target_suffix_map.iter().enumerate() {
308        for v in value {
309            let r: String = v.as_ref().chars().rev().collect();
310            trie.insert(r, i);
311        }
312    }
313    trie
314}
315
/// Naive suffix check: returns the first target (in map iteration order) that owns a suffix
/// `domain` ends with.
pub fn check_suffix_dummy<'a, T>(
    target_suffix_map: &'a HashMap<T, Vec<T>>,
    domain: &str,
) -> Option<&'a T>
where
    T: AsRef<str> + Eq + std::hash::Hash,
{
    target_suffix_map
        .iter()
        .find(|(_, suffixes)| suffixes.iter().any(|s| domain.ends_with(s.as_ref())))
        .map(|(target, _)| target)
}
332
333/// the function matches suffix by reversing the domain
334pub fn check_suffix_trie(trie: &Trie<String, usize>, domain: &str) -> Option<usize> {
335    let sr: String = domain.chars().rev().collect();
336    if let Some(subtree) = trie.get_ancestor(&sr) {
337        subtree.value().cloned()
338    } else {
339        None
340    }
341}
342/// unlike check_suffix_trie, this function matches prefix
343pub fn check_prefix_trie(trie: &Trie<&str, usize>, domain: &str) -> Option<usize> {
344    if let Some(subtree) = trie.get_ancestor(domain) {
345        subtree.value().cloned()
346    } else {
347        None
348    }
349}
350pub fn check_keyword_ac<'a, T: AsRef<str>>(
351    target_keyword_ac_map: &'a HashMap<T, AhoCorasick>,
352    domain: &str,
353) -> Option<&'a str> {
354    for (target, ac) in target_keyword_ac_map {
355        if ac.is_match(domain) {
356            return Some(target.as_ref());
357        }
358    }
359    None
360}
361
362/// faster than ac, but requries an extra targets lookup vec by get_keywords_targets
363pub fn check_keyword_ac2<'a>(
364    keyword_ac: &AhoCorasick,
365    domain: &str,
366    targets: &'a [String],
367) -> Option<&'a String> {
368    if let Some(mat) = keyword_ac.find_iter(domain).next() {
369        let keyword_index = mat.pattern();
370        return Some(&targets[keyword_index]);
371    }
372    None
373}
374
/// Naive keyword check: returns the first target (in map iteration order) that owns a keyword
/// contained in `domain`.
pub fn check_keyword_dummy<'a, T>(
    target_keyword_map: &'a HashMap<T, Vec<T>>,
    domain: &str,
) -> Option<&'a T>
where
    T: AsRef<str> + Eq + std::hash::Hash,
{
    target_keyword_map
        .iter()
        .find(|(_, keywords)| keywords.iter().any(|k| domain.contains(k.as_ref())))
        .map(|(target, _)| target)
}
391
392pub fn check_ip_trie2(trie: &IpTrie2, ip: Ipv4Addr) -> Option<usize> {
393    let ipn = Ipv4NetWrapper(Ipv4Net::new(ip, 32).unwrap());
394    if let Some(subtree) = trie.0.get_ancestor(&ipn) {
395        subtree.value().cloned()
396    } else {
397        None
398    }
399}
400pub fn check_ip_trie(trie: &PrefixMap<Ipv4Net, usize>, ip: Ipv4Addr) -> Option<usize> {
401    trie.get_lpm(&Ipv4Net::new(ip, 32).unwrap()).map(|r| *r.1)
402}
403pub fn check_ip6_trie(trie: &PrefixMap<Ipv6Net, usize>, ip6: Ipv6Addr) -> Option<usize> {
404    trie.get_lpm(&Ipv6Net::new(ip6, 32).unwrap()).map(|r| *r.1)
405}
406
/// Sample IPv4 addresses used by the tests.
#[cfg(test)]
pub fn get_test_ips() -> Vec<Ipv4Addr> {
    const OCTETS: [[u8; 4]; 4] = [[1, 2, 3, 4], [2, 2, 3, 4], [3, 2, 3, 4], [15, 207, 213, 128]];
    OCTETS.iter().map(|&octets| Ipv4Addr::from(octets)).collect()
}
/// Sample domains used by the tests.
#[cfg(test)]
pub fn get_test_domains() -> Vec<&'static str> {
    Vec::from([
        "www.google.com",
        "jdj.reddit.com",
        "hdjd.baidu.com",
        "hshsh.djdjdj.djdj",
    ])
}
/// Smoke test that prints lookup results for the sample domains and IPs against `test.yaml`.
///
/// Run with: cargo test -- --nocapture
#[cfg(feature = "serde_yaml_ng")]
#[test]
fn test() {
    let config = load_rules_from_file("test.yaml").unwrap();
    let rule_map = parse_rules(&config);

    let domain_rules = get_domain_rules(&rule_map).unwrap();
    println!("{:?}", domain_rules.len());

    // Suffix rules: the trie stores indices into `suffix_targets`.
    let suffix_rules = get_suffix_rules(&rule_map).unwrap();
    println!("{:?}", suffix_rules.len());
    let suffix_map = get_target_item_map(suffix_rules);
    let suffix_targets: Vec<&String> = suffix_map.keys().collect();
    println!("{:?}", suffix_targets);
    let suffix_trie = gen_suffix_trie(&suffix_map);

    // Keyword rules: one automaton per target, and one combined automaton with a lookup list.
    let keyword_rules = get_keyword_rules(&rule_map).unwrap();
    println!("{:?}", keyword_rules.len());
    let keyword_map = get_target_item_map(keyword_rules);
    let per_target_ac = gen_keywords_ac(&keyword_map);
    let combined_ac = gen_keywords_ac2(keyword_rules);
    let combined_targets = get_keywords_targets(keyword_rules);

    let domains = get_test_domains();
    for domain in &domains {
        let suffix_hit = check_suffix_trie(&suffix_trie, domain);
        println!("{:?}", suffix_hit.map(|i| suffix_targets.get(i).unwrap()));
        println!("{:?}", check_keyword_ac(&per_target_ac, domain));
        println!(
            "{:?}",
            check_keyword_ac2(&combined_ac, domain, &combined_targets)
        );
    }

    // IPv4 rules, checked with both trie implementations.
    let ip4_rules = get_ip_cidr_rules(&rule_map).unwrap();
    println!("{:?}", ip4_rules.len());
    let ip4_map = get_target_item_map(ip4_rules);
    let ip4_targets: Vec<_> = ip4_map.keys().collect();
    let prefix_map_trie = gen_ip_trie(&ip4_map);
    let radix_trie = gen_ip_trie2(&ip4_map);

    let ips = get_test_ips();
    for ip in &ips {
        let hit = check_ip_trie(&prefix_map_trie, *ip);
        println!("{:?}", hit.map(|i| ip4_targets.get(i).unwrap()));
        let hit = check_ip_trie2(&radix_trie, *ip);
        println!("{:?}", hit.map(|i| ip4_targets.get(i).unwrap()));
    }

    let ip6_rules = get_ip6_cidr_rules(&rule_map).unwrap();
    println!("{:?}", ip6_rules.len());

    // The convenience matcher built from the same file.
    let matcher = ClashRuleMatcher::from_clash_config_file("test.yaml").unwrap();
    for domain in domains {
        println!("{:?}", matcher.check_domain(domain));
    }
    for ip in ips {
        println!("{:?}", matcher.check_ip(std::net::IpAddr::V4(ip)));
    }
}
#[cfg(feature = "maxminddb")]
/// Looks up the ISO country code of `ip` in the MaxMind database.
///
/// Returns an empty string when the lookup fails, or when the record has no country or no ISO
/// code. See <https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes>.
pub fn get_ip_iso_by_reader(ip: IpAddr, reader: &maxminddb::Reader<Vec<u8>>) -> &str {
    let looked_up: Result<maxminddb::geoip2::Country, _> = reader.lookup(ip);
    looked_up
        .ok()
        .and_then(|record| record.country)
        .and_then(|country| country.iso_code)
        .unwrap_or_default()
}
507
508#[derive(Debug)]
509pub struct DomainKeywordMatcher {
510    pub ac: AhoCorasick,
511    pub targets: Vec<String>,
512}
513impl DomainKeywordMatcher {
514    pub fn check(&self, domain: &str) -> Option<&String> {
515        check_keyword_ac2(&self.ac, domain, &self.targets)
516    }
517}
518#[derive(Debug)]
519pub struct DomainSuffixMatcher {
520    pub trie: Trie<String, usize>,
521    pub targets: Vec<String>,
522}
523impl DomainSuffixMatcher {
524    pub fn check(&self, domain: &str) -> Option<&String> {
525        check_suffix_trie(&self.trie, domain).map(|i| self.targets.get(i).unwrap())
526    }
527}
528#[derive(Debug)]
529pub struct IpMatcher {
530    pub trie: PrefixMap<Ipv4Net, usize>,
531    pub targets: Vec<String>,
532}
533impl IpMatcher {
534    pub fn check(&self, ip: Ipv4Addr) -> Option<&String> {
535        check_ip_trie(&self.trie, ip).map(|i| self.targets.get(i).unwrap())
536    }
537}
538#[derive(Debug)]
539pub struct Ip6Matcher {
540    pub trie: PrefixMap<Ipv6Net, usize>,
541    pub targets: Vec<String>,
542}
543impl Ip6Matcher {
544    pub fn check(&self, ip: Ipv6Addr) -> Option<&String> {
545        check_ip6_trie(&self.trie, ip).map(|i| self.targets.get(i).unwrap())
546    }
547}
548
549/// convenient struct for checking all rules.
550/// init mmdb_reader using maxminddb::Reader::from_source
551#[derive(Debug, Default)]
552pub struct ClashRuleMatcher {
553    pub domain_target_map: Option<HashMap<String, String>>,
554    pub domain_keyword_matcher: Option<DomainKeywordMatcher>,
555    pub domain_suffix_matcher: Option<DomainSuffixMatcher>,
556    pub domain_regex_set: Option<HashMap<String, regex::RegexSet>>,
557    pub ip4_matcher: Option<IpMatcher>,
558    pub ip6_matcher: Option<Ip6Matcher>,
559
560    /// for GEOIP
561    #[cfg(feature = "maxminddb")]
562    pub mmdb_reader: Option<std::sync::Arc<maxminddb::Reader<Vec<u8>>>>,
563
564    /// for GEOIP
565    #[cfg(feature = "maxminddb")]
566    pub country_target_map: Option<HashMap<String, String>>,
567
568    /// stores un-optimized left rules, which are AND,OR,NOT,PROCESS-NAME,
569    /// DST-PORT, NETWORK,MATCH
570    ///
571    /// (rule, target)
572    pub rules: Vec<(Rule, String)>,
573}
574
575impl ClashRuleMatcher {
576    pub fn from_hashmap(
577        mut method_rules_map: HashMap<String, Vec<Vec<String>>>,
578    ) -> Result<Self, ParseRuleError> {
579        let mut s = Self::default();
580
581        if let Some(v) = get_domain_rules(&method_rules_map) {
582            s.domain_target_map = Some(get_item_target_map(v));
583            method_rules_map.remove(DOMAIN);
584        }
585        #[cfg(feature = "maxminddb")]
586        if let Some(v) = method_rules_map.get(GEOIP) {
587            s.country_target_map = Some(get_item_target_map(v));
588            method_rules_map.remove(GEOIP);
589        }
590        if let Some(v) = get_keyword_rules(&method_rules_map) {
591            let map = get_target_item_map(v);
592            let targets = map.keys().cloned().collect();
593            let ac = gen_keywords_ac2(v);
594            s.domain_keyword_matcher = Some(DomainKeywordMatcher { ac, targets });
595            method_rules_map.remove(DOMAIN_KEYWORD);
596        }
597        if let Some(v) = get_suffix_rules(&method_rules_map) {
598            let map = get_target_item_map(v);
599            let targets = map.keys().cloned().collect();
600            let trie = gen_suffix_trie(&map);
601            s.domain_suffix_matcher = Some(DomainSuffixMatcher { trie, targets });
602            method_rules_map.remove(DOMAIN_SUFFIX);
603        }
604        if let Some(v) = get_ip_cidr_rules(&method_rules_map) {
605            let map = get_target_item_map(v);
606            let targets = map.keys().cloned().collect();
607            let trie = gen_ip_trie(&map);
608            s.ip4_matcher = Some(IpMatcher { trie, targets });
609            method_rules_map.remove(IP_CIDR);
610        }
611        if let Some(v) = get_ip6_cidr_rules(&method_rules_map) {
612            let map = get_target_item_map(v);
613            let targets = map.keys().cloned().collect();
614            let trie = gen_ip6_trie(&map);
615            s.ip6_matcher = Some(Ip6Matcher { trie, targets });
616            method_rules_map.remove(IP_CIDR6);
617        }
618        if let Some(v) = method_rules_map.get(DOMAIN_REGEX) {
619            let map = get_target_item_map(v);
620            s.domain_regex_set = Some(
621                map.into_iter()
622                    .map(|(t, r)| (t, regex::RegexSet::new(r).unwrap()))
623                    .collect(),
624            );
625            method_rules_map.remove(DOMAIN_REGEX);
626        }
627
628        for (rt, v) in method_rules_map {
629            for ss in v {
630                // println!("parsing {ss:?}");
631                let mut ss = ss.clone();
632                let target = if ss.len() > 1 {
633                    ss.remove(1)
634                } else {
635                    ss.pop().unwrap()
636                };
637                ss.insert(0, rt.clone());
638                let rule = ss.join(",");
639                let r = parse_rule(&rule)?;
640                s.rules.push((r, target));
641            }
642        }
643
644        Ok(s)
645    }
646    #[cfg(feature = "serde_yaml_ng")]
647    pub fn from_clash_config_str(cs: &str) -> Result<Self, Box<dyn std::error::Error>> {
648        let method_rules_map = parse_rules(&load_rules_from_str(cs)?);
649
650        Ok(Self::from_hashmap(method_rules_map)?)
651    }
652    #[cfg(feature = "serde_yaml_ng")]
653    pub fn from_clash_config_file<P: AsRef<Path>>(
654        path: P,
655    ) -> Result<Self, Box<dyn std::error::Error>> {
656        let s = std::fs::read_to_string(path)?;
657        let s = Self::from_clash_config_str(&s)?;
658        Ok(s)
659    }
660
661    pub fn check_ip4(&self, ip: Ipv4Addr) -> Option<&String> {
662        if let Some(m) = &self.ip4_matcher {
663            m.check(ip)
664        } else {
665            None
666        }
667    }
668    pub fn check_ip6(&self, ip: Ipv6Addr) -> Option<&String> {
669        if let Some(m) = &self.ip6_matcher {
670            m.check(ip)
671        } else {
672            None
673        }
674    }
675    pub fn check_ip(&self, ip: std::net::IpAddr) -> Option<&String> {
676        match ip {
677            std::net::IpAddr::V4(ipv4_addr) => self.check_ip4(ipv4_addr),
678            std::net::IpAddr::V6(ipv6_addr) => self.check_ip6(ipv6_addr),
679        }
680    }
681    #[cfg(feature = "maxminddb")]
682    pub fn check_ip_country_iso(&self, ip: std::net::IpAddr) -> &str {
683        if let Some(m) = &self.mmdb_reader {
684            get_ip_iso_by_reader(ip, m)
685        } else {
686            ""
687        }
688    }
689    #[cfg(feature = "maxminddb")]
690    pub fn check_ip_country(&self, ip: std::net::IpAddr) -> Option<&String> {
691        if let Some(m) = &self.country_target_map {
692            let c = self.check_ip_country_iso(ip);
693            m.get(c)
694        } else {
695            None
696        }
697    }
698
699    pub fn check_domain(&self, domain: &str) -> Option<&String> {
700        if let Some(m) = &self.domain_target_map {
701            let r = m.get(domain);
702            if r.is_some() {
703                return r;
704            }
705        }
706        if let Some(m) = &self.domain_suffix_matcher {
707            let r = m.check(domain);
708            if r.is_some() {
709                return r;
710            }
711        }
712        if let Some(m) = &self.domain_keyword_matcher {
713            let r = m.check(domain);
714            if r.is_some() {
715                return r;
716            }
717        }
718        if let Some(m) = &self.domain_regex_set {
719            for (t, r) in m {
720                if r.is_match(domain) {
721                    return Some(t);
722                }
723            }
724        }
725        None
726    }
727}
728
/// Checks AND / OR evaluation on parsed logical rules.
///
/// Run with: cargo test test_logic -- --nocapture
#[test]
fn test_logic() {
    // Builds an input that only carries a domain.
    let input_for = |domain: &str| RuleInput {
        domain: Some(domain.to_string()),
        ..Default::default()
    };

    let or_rule = parse_rule("OR,((DOMAIN-KEYWORD,bili),(DOMAIN-REGEX,(?i)pcdn|mcdn))").unwrap();
    println!("{:#?}", or_rule);
    let and_rule = parse_rule("AND,((DOMAIN-KEYWORD,bili),(DOMAIN-REGEX,(?i)pcdn|mcdn))").unwrap();
    println!("{:#?}", and_rule);

    // Keyword matches but the regex does not: OR holds, AND fails.
    assert!(or_rule.matches(&input_for("pcdbili.com")));
    assert!(!and_rule.matches(&input_for("pcdbili.com")));
    // Both keyword and regex match.
    assert!(and_rule.matches(&input_for("pcdn.bili.com")));

    let nested =
        parse_rule("AND,((OR,((DOMAIN-KEYWORD,bili),(DOMAIN,0))),(DOMAIN-REGEX,(?i)pcdn|mcdn))");
    println!("{:#?}", nested);
}
754
/// A parsed Clash rule: either a logical combinator or a leaf matcher.
#[derive(Debug)]
pub enum Rule {
    /// Matches when every sub-rule matches.
    And(Vec<Box<Rule>>),
    /// Matches when at least one sub-rule matches.
    Or(Vec<Box<Rule>>),
    /// Matches when the sub-rule does not match.
    Not(Box<Rule>),
    /// Matches when the input domain equals this string.
    Domain(String),
    /// Matches when the input domain ends with this string.
    DomainSuffix(String),
    /// Matches when the input domain contains this string.
    DomainKeyword(String),
    /// Matches when the regex matches the input domain.
    DomainRegex(regex::Regex),
    /// Matches when the input IPv4 address lies in this network.
    IpCidr(Ipv4Net),
    /// Matches when the input IPv6 address lies in this network.
    IpCidr6(Ipv6Net),
    /// Country code compared with the database lookup of the input IP; only evaluated when the
    /// `maxminddb` feature is enabled.
    GeoIp(String),
    /// Matches when the input network string equals this string.
    Network(String),
    /// Matches when the input destination port equals this port.
    DstPort(u16),
    /// Matches when the input process name equals this string.
    ProcessName(String),
    /// Always matches.
    Match,
    /// Unrecognized rule type and its raw payload; never matches.
    Other(String, String),
}
773
/// The facts about a connection that a [`Rule`] is evaluated against; a rule whose fact is
/// `None` does not match.
#[derive(Default, Debug)]
pub struct RuleInput {
    /// Domain checked by the DOMAIN* rules.
    pub domain: Option<String>,
    /// Process name checked by PROCESS-NAME.
    pub process_name: Option<String>,
    /// Network string checked by NETWORK.
    pub network: Option<String>,
    /// IP address checked by IP-CIDR, IP-CIDR6 and GEOIP.
    pub ip: Option<IpAddr>,
    /// Destination port checked by DST-PORT.
    pub dst_port: Option<u16>,

    /// for geoip
    #[cfg(feature = "maxminddb")]
    pub mmdb_reader: Option<std::sync::Arc<maxminddb::Reader<Vec<u8>>>>,
}
786
787impl Rule {
788    pub fn matches(&self, input: &RuleInput) -> bool {
789        match self {
790            Rule::Match => true,
791            Rule::And(rules) => {
792                for r in rules {
793                    if !r.matches(input) {
794                        return false;
795                    }
796                }
797                true
798            }
799            Rule::Or(rules) => {
800                for r in rules {
801                    if r.matches(input) {
802                        return true;
803                    }
804                }
805                false
806            }
807            Rule::Not(rule) => !rule.matches(input),
808
809            Rule::ProcessName(p) => input
810                .process_name
811                .as_ref()
812                .is_some_and(|real_p| real_p.eq(p)),
813            Rule::Network(n) => input.network.as_ref().is_some_and(|d| d.eq(n)),
814            Rule::Domain(domain) => input.domain.as_ref().is_some_and(|d| d.eq(domain)),
815            Rule::DomainRegex(r) => input.domain.as_ref().is_some_and(|d| r.is_match(d)),
816            Rule::DomainSuffix(suffix) => {
817                input.domain.as_ref().is_some_and(|d| d.ends_with(suffix))
818            }
819            Rule::DomainKeyword(k) => input.domain.as_ref().is_some_and(|d| d.contains(k)),
820            Rule::IpCidr6(k) => input.ip.as_ref().is_some_and(|ip| {
821                if let IpAddr::V6(i) = ip {
822                    k.contains(i)
823                } else {
824                    false
825                }
826            }),
827            Rule::IpCidr(k) => input.ip.as_ref().is_some_and(|ip| {
828                if let IpAddr::V4(i) = ip {
829                    k.contains(i)
830                } else {
831                    false
832                }
833            }),
834            #[cfg(feature = "maxminddb")]
835            Rule::GeoIp(region) => input.ip.is_some_and(|ip| {
836                if let Some(m) = &input.mmdb_reader {
837                    let iso = get_ip_iso_by_reader(ip, m);
838                    iso.eq(region)
839                } else {
840                    false
841                }
842            }),
843            Rule::DstPort(d) => input.dst_port.is_some_and(|rd| rd == *d),
844            _ => false,
845        }
846    }
847}
848
849use thiserror::Error;
850
/// Errors produced by `parse_rule`.
#[derive(Error, Debug)]
pub enum ParseRuleError {
    /// The payload of an AND / OR / NOT rule is not wrapped in parentheses.
    #[error("not wrapped with ()")]
    E1,
    /// A DOMAIN-REGEX pattern failed to compile.
    #[error("regex error")]
    Regex(#[from] regex::Error),
    /// An IP-CIDR / IP-CIDR6 payload is not a valid network.
    #[error("parse ipcidr err")]
    ParseIpnet(#[from] ipnet::AddrParseError),
    /// A DST-PORT payload is not a valid port number.
    #[error("parse dst port err")]
    ParseNum(#[from] ParseIntError),
}
862
863///eg: DOMAIN-KEYWORD,bili
864///eg: AND,((DOMAIN-KEYWORD,bili),(DOMAIN-REGEX,(?i)pcdn|mcdn))
865pub fn parse_rule(input: &str) -> Result<Rule, ParseRuleError> {
866    if input == MATCH {
867        return Ok(Rule::Match);
868    }
869    let (rt, r) = input.split_once(",").unwrap();
870
871    if rt.eq(AND) || rt.eq(OR) || rt.eq(NOT) {
872        if !(r.starts_with('(') && r.ends_with(')')) {
873            return Err(ParseRuleError::E1);
874        }
875        let r = &r[1..r.len() - 1];
876        let mut subrules: Vec<_> = extract_sub_rules_from(r)
877            .iter()
878            .map(|s| Box::new(parse_rule(s).unwrap()))
879            .collect();
880        let r = match rt {
881            AND => Rule::And(subrules),
882            OR => Rule::Or(subrules),
883            NOT => {
884                let b = subrules.pop().unwrap();
885                Rule::Not(b)
886            }
887            _ => unreachable!("ur"),
888        };
889        Ok(r)
890    } else {
891        let mut r = r.to_string();
892        let r = match rt {
893            DOMAIN => Rule::Domain(r),
894            DOMAIN_SUFFIX => Rule::DomainSuffix(r),
895            DOMAIN_KEYWORD => Rule::DomainKeyword(r),
896            DOMAIN_REGEX => Rule::DomainRegex(regex::Regex::new(&r)?),
897            IP_CIDR6 => {
898                if let Some(pos) = r.find(',') {
899                    r.truncate(pos);
900                }
901                let r: Ipv6Net = r.parse()?;
902                Rule::IpCidr6(r)
903            }
904            IP_CIDR => {
905                if let Some(pos) = r.find(',') {
906                    r.truncate(pos);
907                }
908                let r: Ipv4Net = r.parse()?;
909                Rule::IpCidr(r)
910            }
911            GEOIP => Rule::GeoIp(r),
912            DST_PORT => Rule::DstPort(r.parse()?),
913            NETWORK => Rule::Network(r),
914            PROCESS_NAME => Rule::ProcessName(r),
915            _ => Rule::Other(rt.to_string(), r),
916        };
917        Ok(r)
918    }
919}
920
921/// panics if malformed.
922///
923/// input eg: (DOMAIN-KEYWORD,bili),(DOMAIN-REGEX,(?i)pcdn|mcdn)
924///
925/// input eg: (AND,((DOMAIN,1),(DOMAIN,2))),(DOMAIN-REGEX,(?i)pcdn|mcdn)
926fn extract_sub_rules_from(input: &str) -> Vec<String> {
927    if input.starts_with('(') {
928        let mut v = vec![];
929        let mut lbi = 0;
930        loop {
931            let rbi = find_matching_bracket(input, lbi).unwrap();
932            let s = &input[lbi + 1..rbi];
933            v.push(s.trim().to_string());
934            if rbi + 1 == input.len() {
935                break;
936            }
937
938            let commap = input[rbi + 1..].find(',').unwrap();
939            let bp = input[rbi + 1..].find('(').unwrap();
940            if commap >= bp {
941                panic!("commap >= bp")
942            }
943
944            lbi = bp + rbi + 1;
945        }
946        v
947    } else {
948        let r = &input[1..input.len() - 1];
949        vec![r.trim().to_string()]
950    }
951}
952
/// Finds the `)` that balances the brackets seen from `left_bracket_index` onwards.
///
/// Scanning starts at byte offset `left_bracket_index` of `text`. Each `(` raises the
/// nesting depth and each `)` lowers it; the byte index (relative to `text`) of the first
/// `)` that brings the depth back to zero is returned. Returns `None` when the end of the
/// text is reached without that happening.
///
/// # Panics
///
/// Panics if `left_bracket_index` is not a valid slice start for `text`.
fn find_matching_bracket(text: &str, left_bracket_index: usize) -> Option<usize> {
    let mut depth = 0;
    text[left_bracket_index..]
        .char_indices()
        .find_map(|(offset, ch)| {
            if ch == '(' {
                depth += 1;
            } else if ch == ')' {
                depth -= 1;
                if depth == 0 {
                    return Some(offset);
                }
            }
            None
        })
        // Offsets are relative to the slice; shift back into `text` coordinates.
        .map(|offset| left_bracket_index + offset)
}
969
970#[cfg(feature = "rusqlite")]
971use rusqlite::{params, Connection};
972
/// Clash rule names currently supported by the SQLite format.
974pub const RULE_TYPES: &[&str] = &[
975    DOMAIN,
976    DOMAIN_KEYWORD,
977    DOMAIN_SUFFIX,
978    IP_CIDR,
979    IP_CIDR6,
980    PROCESS_NAME,
981    GEOIP,
982];
983
/// Converts a Clash rule name into its SQL table name: every `-` becomes `_` and the
/// result is lowercased (e.g. `DOMAIN-SUFFIX` -> `domain_suffix`).
pub fn to_sql_table_name(input: &str) -> String {
    let underscored = input.split('-').collect::<Vec<_>>().join("_");
    underscored.to_lowercase()
}
/// Converts an SQL table name back into a Clash rule name: every `_` becomes `-` and the
/// result is uppercased (e.g. `domain_suffix` -> `DOMAIN-SUFFIX`).
pub fn to_clash_rule_name(input: &str) -> String {
    let hyphenated: String = input
        .chars()
        .map(|c| if c == '_' { '-' } else { c })
        .collect();
    hyphenated.to_uppercase()
}
990
/// Initializes the SQLite database: creates one table per entry of `RULE_TYPES` and the
/// `rules_view` view that unions all of them.
///
/// Every table has the columns `id` (autoincrement primary key), `content` and `target`.
/// Both the tables and the view are created with `IF NOT EXISTS`.
///
/// # Errors
///
/// Returns the first error reported by SQLite while executing a statement.
#[cfg(feature = "rusqlite")]
pub fn init_db(conn: &Connection) -> rusqlite::Result<()> {
    for rule_type in RULE_TYPES.iter().copied() {
        let table = to_sql_table_name(rule_type);
        let ddl = format!(
            "CREATE TABLE IF NOT EXISTS {} (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                content TEXT NOT NULL,
                target TEXT NOT NULL
            )",
            table
        );
        conn.execute(&ddl, [])?;
    }

    // The view exposes every rule table under its Clash rule name in `rule_name`.
    let view_ddl = "
        CREATE VIEW IF NOT EXISTS rules_view AS
        SELECT 'DOMAIN' AS rule_name, content, target FROM domain
        UNION ALL
        SELECT 'DOMAIN-SUFFIX', content, target FROM domain_suffix
        UNION ALL
        SELECT 'DOMAIN-KEYWORD', content, target FROM domain_keyword
        UNION ALL
        SELECT 'IP-CIDR', content, target FROM ip_cidr
        UNION ALL
        SELECT 'IP-CIDR6', content, target FROM ip_cidr6
        UNION ALL
        SELECT 'PROCESS-NAME', content, target FROM process_name
        UNION ALL
        SELECT 'GEOIP', content, target FROM geoip;
    ";
    conn.execute(view_ddl, []).map(|_| ())
}
1026
/// Runs `sql` and groups the returned rows by rule name.
///
/// The query must yield three text columns, in this order: rule name, content, target.
/// Each row is stored as `[content, target]` under its rule name.
///
/// eg: let sql = "SELECT rule_name, content, target FROM rules_view";
///
/// # Errors
///
/// Returns an error if the statement cannot be prepared or executed, or if a column
/// cannot be read as a `String`.
// Gated like the other database helpers: `Connection` and `rusqlite::Result` only exist
// when the `rusqlite` feature is enabled.
#[cfg(feature = "rusqlite")]
pub fn query_rules_view(
    conn: &Connection,
    sql: &str,
) -> rusqlite::Result<HashMap<String, Vec<Vec<String>>>> {
    let mut rules_map: HashMap<String, Vec<Vec<String>>> = HashMap::new();

    let mut stmt = conn.prepare(sql)?;
    let rows = stmt.query_map([], |row| {
        let rule_name: String = row.get(0)?;
        let content: String = row.get(1)?;
        let target: String = row.get(2)?;
        Ok((rule_name, vec![content, target]))
    })?;

    for row in rows {
        let (rule_name, entry) = row?;
        rules_map.entry(rule_name).or_default().push(entry);
    }

    Ok(rules_map)
}
1052
/// Stores a `HashMap<String, Vec<Vec<String>>>` of rules in SQLite, one table per rule type.
///
/// Keys that are not listed in `RULE_TYPES` are skipped. Each entry is expected to be
/// `[content, target, ..]`; entries with fewer than two elements are skipped. All inserts
/// run inside a single transaction that is committed at the end.
///
/// # Errors
///
/// Returns an error if the transaction cannot be started or committed, or if an insert
/// fails; the transaction is then dropped without being committed.
#[cfg(feature = "rusqlite")]
pub fn save_to_sqlite(
    conn: &mut Connection,
    rules: &HashMap<String, Vec<Vec<String>>>,
) -> rusqlite::Result<()> {
    let tx = conn.transaction()?;

    for (rule_name, entries) in rules {
        if !RULE_TYPES.contains(&rule_name.as_str()) {
            continue;
        }

        // The statement text depends only on the table, so build it once per rule type
        // instead of once per row.
        let insert_sql = format!(
            "INSERT INTO {} (content, target) VALUES (?1, ?2)",
            to_sql_table_name(rule_name)
        );

        for entry in entries {
            // Entries must look like [content, target, ..]; shorter ones are ignored.
            if let [content, target, ..] = entry.as_slice() {
                tx.execute(&insert_sql, params![content, target])?;
            }
        }
    }

    tx.commit()
}
1087
/// Reads every rule table from SQLite and rebuilds the
/// `HashMap<String, Vec<Vec<String>>>` representation.
///
/// For each entry of `RULE_TYPES` the matching table is read and its rows are stored as
/// `[content, target]` under the Clash rule name. Tables without rows produce no key.
///
/// # Errors
///
/// Returns an error if a table cannot be queried or a column cannot be read as a `String`.
#[cfg(feature = "rusqlite")]
pub fn load_from_sqlite(conn: &Connection) -> rusqlite::Result<HashMap<String, Vec<Vec<String>>>> {
    let mut rules_map: HashMap<String, Vec<Vec<String>>> = HashMap::new();

    for rule_type in RULE_TYPES.iter().copied() {
        let table = to_sql_table_name(rule_type);
        let sql = format!("SELECT content, target FROM {}", table);

        let mut stmt = conn.prepare(&sql)?;
        let entries = stmt
            .query_map([], |row| {
                Ok(vec![row.get::<_, String>(0)?, row.get::<_, String>(1)?])
            })?
            .collect::<rusqlite::Result<Vec<Vec<String>>>>()?;

        // Only tables that actually hold rows get a key in the result.
        if !entries.is_empty() {
            rules_map
                .entry(to_clash_rule_name(&table))
                .or_default()
                .extend(entries);
        }
    }

    Ok(rules_map)
}
/// Adds a rule: inserts `(content, target)` into the table belonging to `rule_name`.
///
/// `rule_name` is mapped to a table name with `to_sql_table_name`.
///
/// # Errors
///
/// Returns an error if the insert fails, e.g. because no table exists for `rule_name`.
#[cfg(feature = "rusqlite")]
pub fn add_rule(
    conn: &Connection,
    rule_name: &str,
    content: &str,
    target: &str,
) -> rusqlite::Result<()> {
    // Table names cannot be bound as parameters, so the name is emitted as a double-quoted
    // identifier with embedded quotes doubled; a crafted `rule_name` therefore cannot
    // change the shape of the statement.
    let table_name = to_sql_table_name(rule_name).replace('"', "\"\"");
    let insert_sql = format!(
        "INSERT INTO \"{}\" (content, target) VALUES (?1, ?2)",
        table_name
    );
    conn.execute(&insert_sql, params![content, target])?;
    Ok(())
}
1129
/// Deletes rules by content: removes every row of the `rule_name` table whose `content`
/// equals the given value.
///
/// `rule_name` is mapped to a table name with `to_sql_table_name`.
///
/// # Errors
///
/// Returns an error if the delete fails, e.g. because no table exists for `rule_name`.
#[cfg(feature = "rusqlite")]
pub fn delete_rule(conn: &Connection, rule_name: &str, content: &str) -> rusqlite::Result<()> {
    // Table names cannot be bound as parameters, so the name is emitted as a double-quoted
    // identifier with embedded quotes doubled; a crafted `rule_name` therefore cannot
    // change the shape of the statement.
    let table_name = to_sql_table_name(rule_name).replace('"', "\"\"");
    let delete_sql = format!("DELETE FROM \"{}\" WHERE content = ?1", table_name);
    conn.execute(&delete_sql, params![content])?;
    Ok(())
}
1138
/// Updates rules (changes the target label): sets `target` to `new_target` on every row
/// of the `rule_name` table whose `content` equals the given value.
///
/// `rule_name` is mapped to a table name with `to_sql_table_name`.
///
/// # Errors
///
/// Returns an error if the update fails, e.g. because no table exists for `rule_name`.
#[cfg(feature = "rusqlite")]
pub fn update_rule(
    conn: &Connection,
    rule_name: &str,
    content: &str,
    new_target: &str,
) -> rusqlite::Result<()> {
    // Table names cannot be bound as parameters, so the name is emitted as a double-quoted
    // identifier with embedded quotes doubled; a crafted `rule_name` therefore cannot
    // change the shape of the statement.
    let table_name = to_sql_table_name(rule_name).replace('"', "\"\"");
    let update_sql = format!(
        "UPDATE \"{}\" SET target = ?1 WHERE content = ?2",
        table_name
    );
    conn.execute(&update_sql, params![new_target, content])?;
    Ok(())
}
1152
/// Queries all data of one rule type: returns every row of the `rule_name` table as
/// `[content, target]`.
///
/// `rule_name` is mapped to a table name with `to_sql_table_name`.
///
/// # Errors
///
/// Returns an error if the query fails (e.g. because no table exists for `rule_name`) or
/// if a column cannot be read as a `String`.
#[cfg(feature = "rusqlite")]
pub fn query_rule(conn: &Connection, rule_name: &str) -> rusqlite::Result<Vec<Vec<String>>> {
    // Table names cannot be bound as parameters, so the name is emitted as a double-quoted
    // identifier with embedded quotes doubled; a crafted `rule_name` therefore cannot
    // change the shape of the statement.
    let table_name = to_sql_table_name(rule_name).replace('"', "\"\"");
    let mut stmt = conn.prepare(&format!(
        "SELECT content, target FROM \"{}\"",
        table_name
    ))?;
    let rows = stmt.query_map([], |row| {
        let content: String = row.get(0)?;
        let target: String = row.get(1)?;
        Ok(vec![content, target])
    })?;

    // Collecting into a `Result` stops at the first row that fails to decode.
    let result = rows.collect::<rusqlite::Result<Vec<_>>>()?;
    Ok(result)
}
1170
#[cfg(feature = "rusqlite")]
#[test]
/// Exercises save/load and the add/update/query/delete helpers end to end.
///
/// cargo test -- --nocapture
fn test_sql() -> rusqlite::Result<()> {
    println!("init");
    // An in-memory database keeps the test hermetic: nothing is left on disk and no rows
    // from earlier runs can influence the assertions below.
    let mut conn = Connection::open_in_memory()?;
    init_db(&conn)?;

    // Sample data
    #[cfg(not(feature = "serde_yaml_ng"))]
    let mut rules: HashMap<String, Vec<Vec<String>>> = HashMap::new();
    #[cfg(feature = "serde_yaml_ng")]
    let mut rules = parse_rules(&load_rules_from_file("test.yaml").unwrap());
    // `insert` replaces whatever the YAML file provided for these two keys, so the
    // DOMAIN and IP-CIDR tables have known contents.
    rules.insert(
        "DOMAIN".to_string(),
        vec![
            vec!["example.com".to_string(), "proxy".to_string()],
            vec!["test.com".to_string(), "direct".to_string()],
        ],
    );
    rules.insert(
        "IP-CIDR".to_string(),
        vec![
            vec!["192.168.1.0/24".to_string(), "proxy".to_string()],
            vec!["10.0.0.0/8".to_string(), "direct".to_string()],
        ],
    );

    println!("save");
    // Store in the database
    save_to_sqlite(&mut conn, &rules)?;

    println!("load");
    // Read the database back into a HashMap
    let loaded = load_from_sqlite(&conn)?;
    assert_eq!(loaded.get("DOMAIN"), rules.get("DOMAIN"));
    assert_eq!(loaded.get("IP-CIDR"), rules.get("IP-CIDR"));

    // Insert rules
    add_rule(&conn, "DOMAIN", "example.com", "proxy")?;
    add_rule(&conn, "DOMAIN", "test.com", "direct")?;
    add_rule(&conn, "IP-CIDR", "192.168.1.0/24", "proxy")?;

    // Update rules
    update_rule(&conn, "DOMAIN", "test.com", "proxy")?;

    // Query one rule type: two saved rows plus two added rows
    let domain_rules = query_rule(&conn, "DOMAIN")?;
    assert_eq!(domain_rules.len(), 4);

    // Delete rules
    delete_rule(&conn, "DOMAIN", "example.com")?;
    let domain_rules = query_rule(&conn, "DOMAIN")?;
    assert_eq!(
        domain_rules,
        vec![vec!["test.com".to_string(), "proxy".to_string()]; 2]
    );

    let sql = "SELECT rule_name, content, target FROM rules_view";
    let r = query_rules_view(&conn, sql)?;
    println!("all {}", r.len());
    assert_eq!(r.get("DOMAIN"), Some(&domain_rules));
    assert_eq!(r.get("IP-CIDR").map(Vec::len), Some(3));

    Ok(())
}
1227
/// Gets the names of all tables in the database.
///
/// Reads `sqlite_master` through `conn` and returns the `name` of every entry of type
/// `table` whose name does not match the pattern `sqlite_%`.
///
/// # Errors
///
/// Returns an error if the query cannot be prepared or run, or if a name cannot be read
/// as a `String`.
#[cfg(feature = "rusqlite")]
fn get_table_names(conn: &Connection) -> rusqlite::Result<Vec<String>> {
    let mut stmt = conn.prepare(
        "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
    )?;

    let mut names = Vec::new();
    for name in stmt.query_map([], |row| row.get::<_, String>(0))? {
        names.push(name?);
    }
    Ok(names)
}
1239
/// Merges the rule data of the database at `db2_path` into the database at `db1_path`.
///
/// The second database is attached as `attached_db`. For every table reported by
/// `get_table_names` for the connection to `db1_path`, the `content` and `target` columns
/// of the same-named table in the attached database are appended. The `id` column is not
/// copied, so the destination assigns fresh ids and rows never collide on the primary key
/// when both databases already hold rules of the same type.
///
/// # Errors
///
/// Returns an error if either database cannot be opened or attached, or if a table is
/// missing from the attached database or lacks the `content` / `target` columns.
#[cfg(feature = "rusqlite")]
pub fn merge_databases(db1_path: &str, db2_path: &str) -> rusqlite::Result<()> {
    let conn = Connection::open(db1_path)?;

    // Attach the second database. The path is bound as a parameter instead of being
    // spliced into the SQL text, so paths containing quotes are handled correctly.
    conn.execute("ATTACH DATABASE ?1 AS attached_db", params![db2_path])?;

    // Tables to merge, as listed by the destination connection.
    let tables = get_table_names(&conn)?;

    for table in tables {
        // Quote the identifier so unusual table names cannot break the statement.
        let table = table.replace('"', "\"\"");
        let sql = format!(
            "INSERT INTO \"{table}\" (content, target) \
             SELECT content, target FROM attached_db.\"{table}\""
        );
        conn.execute(&sql, [])?;
    }

    // Detach the second database again
    conn.execute("DETACH DATABASE attached_db", [])?;

    Ok(())
}
1264
#[cfg(feature = "rusqlite")]
#[test]
/// Merges two freshly created databases and checks that both rule sets end up in the first.
///
/// cargo test merge_sql -- --nocapture
fn merge_sql() -> rusqlite::Result<()> {
    let db1 = "1.db";
    let db2 = "2.db";

    // Start from a clean slate; a missing file is not an error here.
    let _ = std::fs::remove_file(db1);
    let _ = std::fs::remove_file(db2);
    {
        // Both connections are closed at the end of this scope, before the merge runs.
        let conn = Connection::open(db1)?;
        init_db(&conn)?;
        add_rule(&conn, "DOMAIN", "test.com", "direct")?;
        let conn = Connection::open(db2)?;
        init_db(&conn)?;
        add_rule(&conn, "IP-CIDR", "192.168.1.0/24", "proxy")?;
    }

    merge_databases(db1, db2)?;

    // The first database must now hold its own rule plus the one merged from the second.
    let merged = Connection::open(db1)?;
    assert_eq!(
        query_rule(&merged, "DOMAIN")?,
        vec![vec!["test.com".to_string(), "direct".to_string()]]
    );
    assert_eq!(
        query_rule(&merged, "IP-CIDR")?,
        vec![vec!["192.168.1.0/24".to_string(), "proxy".to_string()]]
    );

    println!("Databases merged successfully!");
    Ok(())
}