use std::collections::HashMap;
use prost::Message;
/// Decodes a protobuf-encoded [`SiteGroupList`] from `buf`.
///
/// # Errors
/// Returns [`prost::DecodeError`] when `buf` is not a valid encoding of the message.
pub fn read(buf: &[u8]) -> Result<SiteGroupList, prost::DecodeError> {
    <SiteGroupList as prost::Message>::decode(buf)
}
pub fn save(sg: SiteGroupList) -> Vec<u8> {
sg.encode_to_vec()
}
/// Flattens a [`SiteGroupList`] into rule rows keyed by rule type.
///
/// Every domain with a known `domain::Type` becomes a two-element row
/// `[domain value, group tag]`, pushed under the key for its type:
///
/// | `domain::Type` | key              |
/// |----------------|------------------|
/// | `Plain`  (0)   | `DOMAIN-KEYWORD` |
/// | `Regex`  (1)   | `DOMAIN-REGEX`   |
/// | `Domain` (2)   | `DOMAIN-SUFFIX`  |
/// | `Full`   (3)   | `DOMAIN`         |
///
/// This is the same type mapping `load_from_sqlite` uses for its tables.
/// Domains whose type value is none of the above are skipped. Keys that receive
/// no rows are absent from the map, and rows under one key keep input order.
pub fn to_hashmap(site_group_list: &SiteGroupList) -> HashMap<String, Vec<Vec<String>>> {
    let mut map: HashMap<String, Vec<Vec<String>>> = HashMap::new();
    for group in &site_group_list.site_group {
        for domain in &group.domain {
            // `r#type` is the raw wire number of `domain::Type`; the arms must follow
            // that enum's discriminants (Plain = 0, Regex = 1, Domain = 2, Full = 3).
            let key = match domain.r#type {
                0 => "DOMAIN-KEYWORD", // Plain
                1 => "DOMAIN-REGEX",   // Regex
                2 => "DOMAIN-SUFFIX",  // Domain
                3 => "DOMAIN",         // Full
                // Values outside the enum are skipped.
                _ => continue,
            };
            map.entry(key.to_string())
                .or_default()
                .push(vec![domain.value.clone(), group.tag.clone()]);
        }
    }
    map
}
/// A single domain rule belonging to a [`SiteGroup`].
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Domain {
    /// How `value` is interpreted, kept as the raw `domain::Type` number.
    /// It is a plain `i32`, so values outside the enum can occur; the helpers in
    /// this file skip such domains with a `_` match arm.
    #[prost(enumeration = "domain::Type", tag = "1")]
    pub r#type: i32,
    /// The domain pattern itself; its meaning depends on `r#type`.
    #[prost(string, tag = "2")]
    pub value: ::prost::alloc::string::String,
    /// Zero or more typed key/value attributes attached to this rule.
    #[prost(message, repeated, tag = "3")]
    pub attribute: ::prost::alloc::vec::Vec<domain::Attribute>,
}
/// Types nested under the [`Domain`] message.
pub mod domain {
    /// A named attribute attached to a [`super::Domain`] rule.
    #[derive(Clone, PartialEq, ::prost::Message)]
    pub struct Attribute {
        /// Attribute name.
        #[prost(string, tag = "1")]
        pub key: ::prost::alloc::string::String,
        /// The attribute's value (a protobuf `oneof`): a bool or an int64,
        /// or `None` when neither field is present.
        #[prost(oneof = "attribute::TypedValue", tags = "2, 3")]
        pub typed_value: ::core::option::Option<attribute::TypedValue>,
    }
    /// Types nested under the [`Attribute`] message.
    pub mod attribute {
        /// The alternatives of [`super::Attribute::typed_value`].
        #[derive(Clone, Copy, PartialEq, ::prost::Oneof)]
        pub enum TypedValue {
            /// Boolean value (field 2).
            #[prost(bool, tag = "2")]
            BoolValue(bool),
            /// Signed 64-bit integer value (field 3).
            #[prost(int64, tag = "3")]
            IntValue(i64),
        }
    }
    /// How a [`super::Domain`]'s `value` is matched.
    ///
    /// Discriminants are the numbers stored in `Domain::r#type`. The SQLite loader
    /// in this file maps `Plain` to `domain_keyword`, `Domain` to `domain_suffix`,
    /// `Full` to `domain` and `Regex` to `domain_regex`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
    #[repr(i32)]
    pub enum Type {
        /// Wire value 0.
        Plain = 0,
        /// Wire value 1.
        Regex = 1,
        /// Wire value 2.
        Domain = 2,
        /// Wire value 3.
        Full = 3,
    }
    impl Type {
        /// Returns the variant name as a static string, e.g. `"Plain"`.
        pub fn as_str_name(&self) -> &'static str {
            match self {
                Self::Plain => "Plain",
                Self::Regex => "Regex",
                Self::Domain => "Domain",
                Self::Full => "Full",
            }
        }
        /// Inverse of [`Type::as_str_name`]: parses an exact, case-sensitive variant
        /// name and returns `None` for anything else.
        pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
            match value {
                "Plain" => Some(Self::Plain),
                "Regex" => Some(Self::Regex),
                "Domain" => Some(Self::Domain),
                "Full" => Some(Self::Full),
                _ => None,
            }
        }
    }
}
/// A named group of domain rules.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct SiteGroup {
    /// Group name; the SQLite helpers in this file store it in each row's `target` column.
    #[prost(string, tag = "1")]
    pub tag: ::prost::alloc::string::String,
    /// The rules that belong to this group.
    #[prost(message, repeated, tag = "2")]
    pub domain: ::prost::alloc::vec::Vec<Domain>,
}
/// Top-level message: every site group. Decoded by `read` and encoded by `save`.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct SiteGroupList {
    /// All groups, in encoded order.
    #[prost(message, repeated, tag = "1")]
    pub site_group: ::prost::alloc::vec::Vec<SiteGroup>,
}
#[cfg(feature = "rusqlite")]
pub use rusqlite;
#[cfg(feature = "rusqlite")]
use rusqlite::{params, Connection, Result};
/// Every rule table created by `init_db`, one table per rule kind.
///
/// Only the four `domain*` tables are written and read by `save_to_sqlite` and
/// `load_from_sqlite`. The remaining tables are created but (at least in this file)
/// not populated.
pub const RULE_TABLE_NAMES: &[&str] = &[
    // Domain rules.
    "domain", "domain_keyword", "domain_suffix", "domain_regex",
    // Non-domain rules.
    "ip_cidr", "ip_cidr6", "process_name", "geoip",
];
/// Creates every table in `RULE_TABLE_NAMES` that does not exist yet.
///
/// Each table gets an `INTEGER PRIMARY KEY AUTOINCREMENT` `id` plus `content` and
/// `target` columns, both `TEXT NOT NULL`. Existing tables are left untouched.
///
/// # Errors
/// Returns the first rusqlite error encountered; tables after the failing one
/// are not attempted.
#[cfg(feature = "rusqlite")]
pub fn init_db(conn: &Connection) -> Result<()> {
    RULE_TABLE_NAMES.iter().try_for_each(|name| {
        let ddl = format!(
            concat!(
                "CREATE TABLE IF NOT EXISTS {} (\n",
                "id INTEGER PRIMARY KEY AUTOINCREMENT,\n",
                "content TEXT NOT NULL,\n",
                "target TEXT NOT NULL\n",
                ")",
            ),
            name
        );
        conn.execute(&ddl, []).map(|_| ())
    })
}
/// Writes every domain rule in `site_group_list` into the per-type domain tables.
///
/// Each domain becomes a `(content, target)` row: `content` is the domain value and
/// `target` is the owning group's tag. The table is chosen from the domain's
/// `domain::Type`, using the same mapping `load_from_sqlite` reads back:
/// `Plain` → `domain_keyword`, `Regex` → `domain_regex`,
/// `Domain` → `domain_suffix`, `Full` → `domain`.
/// Domains with a type value outside that enum are skipped.
///
/// All inserts run in one transaction, which is committed only if every insert
/// succeeds. With the schema created by `init_db` (no UNIQUE constraint),
/// `OR IGNORE` never skips a row, so saving the same list twice duplicates rows.
///
/// # Errors
/// Returns any rusqlite error from starting the transaction, preparing or running
/// an insert, or committing. On error the transaction is dropped, which rolls it back.
#[cfg(feature = "rusqlite")]
pub fn save_to_sqlite(conn: &mut Connection, site_group_list: &SiteGroupList) -> Result<()> {
    let tx = conn.transaction()?;
    for group in &site_group_list.site_group {
        for domain in &group.domain {
            // `r#type` is the raw wire number of `domain::Type`
            // (Plain = 0, Regex = 1, Domain = 2, Full = 3).
            let table = match domain.r#type {
                0 => "domain_keyword", // Plain
                1 => "domain_regex",   // Regex
                2 => "domain_suffix",  // Domain
                3 => "domain",         // Full
                _ => continue,
            };
            // Only four distinct statements exist, so rusqlite's statement cache reuses
            // the compiled form instead of re-preparing the SQL for every row.
            let mut insert = tx.prepare_cached(&format!(
                "INSERT OR IGNORE INTO {} (content, target) VALUES (?1, ?2)",
                table
            ))?;
            insert.execute(params![domain.value, group.tag])?;
        }
    }
    tx.commit()
}
/// Rebuilds a [`SiteGroupList`] from the four domain rule tables.
///
/// Rows are decoded with this table → `domain::Type` mapping:
/// `domain_keyword` → `Plain`, `domain_suffix` → `Domain`, `domain` → `Full`,
/// `domain_regex` → `Regex`. Each row's `target` becomes the group tag and its
/// `content` the domain value. The tables hold no attributes, so every loaded
/// [`Domain`] has an empty `attribute` list.
///
/// The output is deterministic: groups are sorted by tag. Within a group, domains
/// appear table by table in the order above and, within one table, in `id` order.
///
/// # Errors
/// Returns any rusqlite error from preparing or running a query, or from reading a
/// column that is not text.
#[cfg(feature = "rusqlite")]
pub fn load_from_sqlite(conn: &Connection) -> Result<SiteGroupList> {
    let rule_tables = &[
        ("domain_keyword", domain::Type::Plain),
        ("domain_suffix", domain::Type::Domain),
        ("domain", domain::Type::Full),
        ("domain_regex", domain::Type::Regex),
    ];
    // A BTreeMap (not a HashMap) keeps groups sorted by tag, so the resulting list
    // (and its encoding) does not change from run to run.
    let mut groups_map = std::collections::BTreeMap::<String, Vec<Domain>>::new();
    for &(table, domain_type) in rule_tables {
        let mut stmt = conn.prepare(&format!(
            "SELECT content, target FROM {} ORDER BY id",
            table
        ))?;
        let rows = stmt.query_map([], |row| {
            let name: String = row.get(0)?;
            let group_name: String = row.get(1)?;
            Ok((name, group_name))
        })?;
        for row in rows {
            let (name, group_name) = row?;
            groups_map.entry(group_name).or_default().push(Domain {
                value: name,
                r#type: domain_type as i32,
                attribute: vec![],
            });
        }
    }
    let site_group = groups_map
        .into_iter()
        .map(|(tag, domain)| SiteGroup { tag, domain })
        .collect();
    Ok(SiteGroupList { site_group })
}
/// Round-trips `geosite.dat` through SQLite and checks that no domain is lost.
///
/// Requires the `geosite.dat` fixture in the working directory.
#[cfg(feature = "rusqlite")]
#[test]
fn test_sql() -> Result<(), Box<dyn std::error::Error>> {
    use std::fs;
    // An in-memory database keeps the test hermetic. Nothing is written to the
    // working directory, and re-running the test cannot pile duplicate rows onto a
    // persistent database (the schema has no UNIQUE constraint).
    let mut conn = Connection::open_in_memory()?;
    init_db(&conn)?;
    let buf = fs::read("geosite.dat")?;
    let site_group_list = read(&buf)?;
    save_to_sqlite(&mut conn, &site_group_list)?;
    let loaded_site_group_list = load_from_sqlite(&conn)?;

    // Every domain with a known type (wire values 0..=3) must come back exactly once.
    let saved_domains = site_group_list
        .site_group
        .iter()
        .flat_map(|group| &group.domain)
        .filter(|domain| (0..=3).contains(&domain.r#type))
        .count();
    let loaded_domains: usize = loaded_site_group_list
        .site_group
        .iter()
        .map(|group| group.domain.len())
        .sum();
    assert_eq!(loaded_domains, saved_domains);

    println!(
        "Loaded SiteGroupList: {:?}",
        loaded_site_group_list.site_group.len()
    );
    Ok(())
}