//! geosite-rs 0.1.3
//!
//! A simple crate that parses the `geosite.dat` file format.
use std::collections::HashMap;

use prost::Message;

/// parse dat data (which has protobuf format)
pub fn read(buf: &[u8]) -> Result<SiteGroupList, prost::DecodeError> {
    SiteGroupList::decode(buf)
}

/// save to the dat format (which has protobuf format)
pub fn save(sg: SiteGroupList) -> Vec<u8> {
    sg.encode_to_vec()
}

/// Convert a [`SiteGroupList`] into a map compatible with the one in the `clash_rules` crate.
///
/// The keys are the Clash rule names `"DOMAIN-KEYWORD"`, `"DOMAIN-SUFFIX"`, `"DOMAIN"` and
/// `"DOMAIN-REGEX"`. Each value is a list of two-element entries `[domain value, group tag]`.
///
/// The numeric `Domain.r#type` is interpreted according to `domain::Type`
/// (`Plain = 0`, `Regex = 1`, `Domain = 2`, `Full = 3`). Domains with any other type value
/// are skipped. A key is only present if at least one domain maps to it.
pub fn to_hashmap(site_group_list: &SiteGroupList) -> HashMap<String, Vec<Vec<String>>> {
    let mut map: HashMap<String, Vec<Vec<String>>> = HashMap::new();

    for group in &site_group_list.site_group {
        for domain in &group.domain {
            // These numbers must stay in sync with the discriminants of `domain::Type`.
            let key = match domain.r#type {
                0 => "DOMAIN-KEYWORD", // Plain
                1 => "DOMAIN-REGEX",   // Regex
                2 => "DOMAIN-SUFFIX",  // Domain: the value is a root domain
                3 => "DOMAIN",         // Full
                _ => continue,         // unknown type: skip it
            };

            map.entry(key.to_owned())
                .or_default()
                .push(vec![domain.value.clone(), group.tag.clone()]);
        }
    }

    map
}
//include!(concat!(env!("OUT_DIR"), "/_.rs"));
//
// This file is @generated by prost-build.
/// Domain for routing decision.
///
/// One entry of a [`SiteGroup`]: a value plus how that value is matched.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Domain {
    /// Domain matching type.
    ///
    /// Holds the raw `i32` value of a `domain::Type` variant. `to_hashmap` and
    /// `save_to_sqlite` skip values that are not one of the known variants.
    #[prost(enumeration = "domain::Type", tag = "1")]
    pub r#type: i32,
    /// Domain value.
    #[prost(string, tag = "2")]
    pub value: ::prost::alloc::string::String,
    /// Attributes of this domain. May be used for filtering.
    ///
    /// Not persisted by the SQLite helpers; `load_from_sqlite` always leaves this empty.
    #[prost(message, repeated, tag = "3")]
    pub attribute: ::prost::alloc::vec::Vec<domain::Attribute>,
}
/// Nested message and enum types in `Domain`.
pub mod domain {
    /// A key/value attribute attached to a `Domain`.
    ///
    /// The value is optional; when present it is either a bool or an i64.
    #[derive(Clone, PartialEq, ::prost::Message)]
    pub struct Attribute {
        /// Attribute key.
        #[prost(string, tag = "1")]
        pub key: ::prost::alloc::string::String,
        /// Attribute value; `None` when neither oneof field is set.
        #[prost(oneof = "attribute::TypedValue", tags = "2, 3")]
        pub typed_value: ::core::option::Option<attribute::TypedValue>,
    }
    /// Nested message and enum types in `Attribute`.
    pub mod attribute {
        /// The possible value types of an `Attribute` (a protobuf `oneof`).
        #[derive(Clone, Copy, PartialEq, ::prost::Oneof)]
        pub enum TypedValue {
            /// Boolean attribute value (field 2).
            #[prost(bool, tag = "2")]
            BoolValue(bool),
            /// Integer attribute value (field 3).
            #[prost(int64, tag = "3")]
            IntValue(i64),
        }
    }
    /// Type of domain value.
    ///
    /// The explicit discriminants are the values stored in the `r#type` field of `Domain`.
    /// Any code that matches on that raw `i32` must use these numbers
    /// (`Plain = 0`, `Regex = 1`, `Domain = 2`, `Full = 3`).
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
    #[repr(i32)]
    pub enum Type {
        /// The value is used as is.
        Plain = 0,
        /// The value is used as a regular expression.
        Regex = 1,
        /// The value is a root domain.
        Domain = 2,
        /// The value is a domain.
        Full = 3,
    }
    impl Type {
        /// String value of the enum field names used in the ProtoBuf definition.
        ///
        /// The values are not transformed in any way and thus are considered stable
        /// (if the ProtoBuf definition does not change) and safe for programmatic use.
        pub fn as_str_name(&self) -> &'static str {
            match self {
                Self::Plain => "Plain",
                Self::Regex => "Regex",
                Self::Domain => "Domain",
                Self::Full => "Full",
            }
        }
        /// Creates an enum from field names used in the ProtoBuf definition.
        ///
        /// Returns `None` for any name other than the four variant names above.
        pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
            match value {
                "Plain" => Some(Self::Plain),
                "Regex" => Some(Self::Regex),
                "Domain" => Some(Self::Domain),
                "Full" => Some(Self::Full),
                _ => None,
            }
        }
    }
}
/// A tagged group of domains.
///
/// `tag` is the group name. `to_hashmap` pairs it with every domain value, and
/// `save_to_sqlite` stores it in the `target` column.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct SiteGroup {
    /// Group name.
    #[prost(string, tag = "1")]
    pub tag: ::prost::alloc::string::String,
    /// Domains belonging to this group.
    #[prost(message, repeated, tag = "2")]
    pub domain: ::prost::alloc::vec::Vec<Domain>,
}

/// Top-level message of a `.dat` file: the type that `read` decodes and `save` encodes.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct SiteGroupList {
    /// All site groups in the file.
    #[prost(message, repeated, tag = "1")]
    pub site_group: ::prost::alloc::vec::Vec<SiteGroup>,
}

#[cfg(feature = "rusqlite")]
pub use rusqlite;

#[cfg(feature = "rusqlite")]
use rusqlite::{params, Connection, Result};

/// Clash rule names currently supported by the SQLite format, one table per name.
///
/// `init_db` creates a table for every entry in this list.
pub const RULE_TABLE_NAMES: &[&str] = &[
    // Tables filled by `save_to_sqlite` and read by `load_from_sqlite`.
    "domain",
    "domain_keyword",
    "domain_suffix",
    "domain_regex",
    // Tables created by `init_db` but not touched by the geosite helpers.
    "ip_cidr",
    "ip_cidr6",
    "process_name",
    "geoip",
];
/// Initialize the database by creating every table in [`RULE_TABLE_NAMES`] that is missing.
///
/// Each table has an auto-incrementing `id`, a `content` column (the rule value) and a
/// `target` column (the group name). Existing tables are left untouched.
///
/// Example of obtaining a connection: `let mut conn = Connection::open("rules.db")?;`
#[cfg(feature = "rusqlite")]
pub fn init_db(conn: &Connection) -> Result<()> {
    RULE_TABLE_NAMES.iter().try_for_each(|table| {
        let ddl = format!(
            "CREATE TABLE IF NOT EXISTS {} (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                content TEXT NOT NULL,
                target TEXT NOT NULL
            )",
            table,
        );
        conn.execute(&ddl, []).map(|_| ())
    })
}

/// Save a [`SiteGroupList`] to SQLite.
///
/// Each domain is written to the rule table that matches its `domain::Type`:
/// Plain → `domain_keyword`, Regex → `domain_regex`, Domain → `domain_suffix`,
/// Full → `domain`. The domain value goes into `content` and the group tag into `target`.
/// Domains with an unknown type value are skipped. The tables must already exist
/// (see `init_db`).
///
/// All inserts run in one transaction, which is committed only after every row has been
/// written.
///
/// # Errors
/// Returns any error reported by SQLite. The uncommitted transaction is dropped, which
/// rusqlite rolls back by default.
#[cfg(feature = "rusqlite")]
pub fn save_to_sqlite(conn: &mut Connection, site_group_list: &SiteGroupList) -> Result<()> {
    let tx = conn.transaction()?;

    for group in &site_group_list.site_group {
        for domain in &group.domain {
            // These numbers must stay in sync with the discriminants of `domain::Type`
            // (the same mapping `load_from_sqlite` uses in the other direction).
            let table = match domain.r#type {
                0 => "domain_keyword", // Plain
                1 => "domain_regex",   // Regex
                2 => "domain_suffix",  // Domain: the value is a root domain
                3 => "domain",         // Full
                _ => continue,         // unknown type: skip it
            };

            let sql = format!(
                "INSERT OR IGNORE INTO {} (content, target) VALUES (?1, ?2)",
                table
            );
            // `prepare_cached` reuses the compiled statement for each of the four tables
            // instead of having SQLite re-parse the same INSERT for every row.
            tx.prepare_cached(&sql)?
                .execute(params![domain.value, group.tag])?;
        }
    }

    tx.commit()
}

/// Load a [`SiteGroupList`] from SQLite.
///
/// Reads the `domain_keyword`, `domain_suffix`, `domain` and `domain_regex` tables.
/// Rows are grouped by their `target` column, which becomes the [`SiteGroup`] tag, and the
/// table a row comes from determines its `domain::Type`. Attributes are not stored in
/// SQLite, so every loaded [`Domain`] has an empty `attribute` list.
///
/// The result is deterministic:
/// - groups are sorted by tag;
/// - within a group, domains are ordered by table (keyword, suffix, full, regex), then by
///   `rowid`.
///
/// # Errors
/// Returns any SQLite error, e.g. when one of the tables does not exist.
#[cfg(feature = "rusqlite")]
pub fn load_from_sqlite(conn: &Connection) -> Result<SiteGroupList> {
    let rule_tables = &[
        ("domain_keyword", domain::Type::Plain),
        ("domain_suffix", domain::Type::Domain),
        ("domain", domain::Type::Full),
        ("domain_regex", domain::Type::Regex),
    ];

    // A BTreeMap (rather than a HashMap) iterates in tag order. Loading the same database
    // twice therefore yields the same `SiteGroupList`, and the same bytes from `save`.
    let mut groups_map = std::collections::BTreeMap::<String, Vec<Domain>>::new();

    for &(table, domain_type) in rule_tables {
        // Without ORDER BY, SQLite does not guarantee any particular row order.
        let mut stmt = conn.prepare(&format!(
            "SELECT content, target FROM {} ORDER BY rowid",
            table
        ))?;
        let rows = stmt.query_map([], |row| {
            let value: String = row.get(0)?;
            // `target` is used as the group name (the SiteGroup tag).
            let group_name: String = row.get(1)?;
            Ok((value, group_name))
        })?;

        for row in rows {
            let (value, group_name) = row?;
            groups_map.entry(group_name).or_default().push(Domain {
                value,
                r#type: domain_type as i32,
                attribute: vec![],
            });
        }
    }

    let groups = groups_map
        .into_iter()
        .map(|(tag, domain)| SiteGroup { tag, domain })
        .collect();

    Ok(SiteGroupList { site_group: groups })
}

/// Round-trip test: `geosite.dat` → SQLite → `SiteGroupList`.
///
/// Run with: `cargo test -F rusqlite -- --nocapture`
/// Requires a `geosite.dat` file in the working directory.
#[cfg(feature = "rusqlite")]
#[test]
fn test_sql() -> Result<(), Box<dyn std::error::Error>> {
    use std::fs;

    /// Sorted (tag, value, type) triples for every domain with a known type.
    fn triples(list: &SiteGroupList) -> Vec<(String, String, i32)> {
        let mut out: Vec<_> = list
            .site_group
            .iter()
            .flat_map(|g| {
                g.domain
                    .iter()
                    .filter(|d| (0..=3).contains(&d.r#type))
                    .map(move |d| (g.tag.clone(), d.value.clone(), d.r#type))
            })
            .collect();
        out.sort_unstable();
        out
    }

    // In-memory database: the test neither depends on nor leaves behind a `rules.db` file.
    // Re-running it also does not accumulate duplicate rows.
    let mut conn = Connection::open_in_memory()?;
    init_db(&conn)?;

    let buf = fs::read("geosite.dat")?;
    let site_group_list = read(&buf)?;

    save_to_sqlite(&mut conn, &site_group_list)?;

    let loaded_site_group_list = load_from_sqlite(&conn)?;
    println!(
        "Loaded SiteGroupList: {:?}",
        loaded_site_group_list.site_group.len()
    );

    // Every domain with a known type must come back under the same tag and with the same type.
    assert_eq!(triples(&site_group_list), triples(&loaded_site_group_list));

    Ok(())
}