Skip to main content

wp_knowledge/
sqlite_ext.rs

1//! SQLite 扩展:在 rusqlite 连接上注册内置 UDF,供 KnowDB 导入/查询使用。
2//!
3//! 说明:
4//! - 每个新建的 Connection 都需要注册一次(权威库写连接、只读线程克隆连接分别注册)。
5//! - 本模块仅包含轻量、与 IP/CIDR 相关的函数;字符串类函数请优先使用 SQLite 内置的 lower/upper/trim。
6
7use rusqlite::Result as SqlResult;
8use rusqlite::functions::{Context, FunctionFlags};
9
10/// 注册内置 UDF 到给定连接。
11/// 注意:需在每个新建的 Connection 上调用一次(writer/reader 各自注册)。
12/// 在给定连接上注册内置函数集合。
13pub fn register_builtin(conn: &rusqlite::Connection) -> SqlResult<()> {
14    // ip4_int(text) -> integer(u32 转 i64)
15    conn.create_scalar_function(
16        "ip4_int",
17        1,
18        FunctionFlags::SQLITE_DETERMINISTIC,
19        |ctx: &Context| {
20            let s: String = ctx.get(0)?;
21            Ok(parse_ipv4_to_u32(&s).map(|v| v as i64).unwrap_or(0))
22        },
23    )?;
24
25    // cidr4_min(text) -> integer(CIDR 起始地址)
26    conn.create_scalar_function(
27        "cidr4_min",
28        1,
29        FunctionFlags::SQLITE_DETERMINISTIC,
30        |ctx: &Context| {
31            let s: String = ctx.get(0)?;
32            Ok(cidr4_min(&s).map(|v| v as i64).unwrap_or(0))
33        },
34    )?;
35
36    // cidr4_max(text) -> integer(CIDR 结束地址,含)
37    conn.create_scalar_function(
38        "cidr4_max",
39        1,
40        FunctionFlags::SQLITE_DETERMINISTIC,
41        |ctx: &Context| {
42            let s: String = ctx.get(0)?;
43            Ok(cidr4_max(&s).map(|v| v as i64).unwrap_or(0))
44        },
45    )?;
46
47    // ip4_between(ip, start, end) -> integer (1/0)
48    conn.create_scalar_function(
49        "ip4_between",
50        3,
51        FunctionFlags::SQLITE_DETERMINISTIC,
52        |ctx: &Context| {
53            let ip: String = ctx.get(0)?;
54            // 参数 1/2 既可能是整数列(*_int),也可能是字符串;两种都尝试
55            let s = ctx
56                .get::<i64>(1)
57                .ok()
58                .map(|i| i as u32)
59                .or_else(|| {
60                    let s: String = ctx.get(1).ok()?;
61                    parse_ipv4_to_u32(&s)
62                })
63                .unwrap_or(u32::MAX);
64            let e = ctx
65                .get::<i64>(2)
66                .ok()
67                .map(|i| i as u32)
68                .or_else(|| {
69                    let s: String = ctx.get(2).ok()?;
70                    parse_ipv4_to_u32(&s)
71                })
72                .unwrap_or(0);
73            let v = parse_ipv4_to_u32(&ip).unwrap_or(u32::MAX);
74            Ok(((v >= s) && (v <= e)) as i64)
75        },
76    )?;
77
78    // cidr4_contains(ip, cidr) -> integer (1/0)
79    conn.create_scalar_function(
80        "cidr4_contains",
81        2,
82        FunctionFlags::SQLITE_DETERMINISTIC,
83        |ctx: &Context| {
84            let ip_s: String = ctx.get(0)?;
85            let cidr_s: String = ctx.get(1)?;
86            let ip = match parse_ipv4_to_u32(&ip_s) {
87                Some(v) => v,
88                None => return Ok(0),
89            };
90            let (net_ip, mask) = match parse_cidr4(&cidr_s) {
91                Some(v) => v,
92                None => return Ok(0),
93            };
94            let net = u32::from(net_ip) & mask;
95            Ok(((ip & mask) == net) as i64)
96        },
97    )?;
98
99    // ip4_text(integer) -> text
100    conn.create_scalar_function(
101        "ip4_text",
102        1,
103        FunctionFlags::SQLITE_DETERMINISTIC,
104        |ctx: &Context| {
105            // 同时支持整型或可解析为整型的字符串
106            let val = ctx.get::<i64>(0).ok();
107            let v = if let Some(i) = val {
108                i as u32
109            } else {
110                let s: String = ctx.get(0)?;
111                match s.trim().parse::<u64>() {
112                    Ok(n) => n as u32,
113                    Err(_) => 0,
114                }
115            };
116            Ok(ipv4_from_u32(v))
117        },
118    )?;
119
120    // trim_quotes(text) -> text:去除两端成对引号(支持 ' 或 "),容忍前后空白
121    conn.create_scalar_function(
122        "trim_quotes",
123        1,
124        FunctionFlags::SQLITE_DETERMINISTIC,
125        |ctx: &Context| {
126            let s: String = ctx.get(0)?;
127            Ok(trim_quotes(&s))
128        },
129    )?;
130    Ok(())
131}
132
133/// 解析点分 IPv4 为 u32;容忍前后空白与引号。
134fn parse_ipv4_to_u32(s: &str) -> Option<u32> {
135    // 允许带空白/引号
136    let t = s.trim().trim_matches('"');
137    let ip: std::net::Ipv4Addr = t.parse().ok()?;
138    Some(u32::from(ip))
139}
140
141fn cidr4_min(s: &str) -> Option<u32> {
142    let (ip, mask) = parse_cidr4(s)?;
143    Some(u32::from(ip) & mask)
144}
145
146fn cidr4_max(s: &str) -> Option<u32> {
147    let (ip, mask) = parse_cidr4(s)?;
148    Some((u32::from(ip) & mask) | !mask)
149}
150
151fn parse_cidr4(s: &str) -> Option<(std::net::Ipv4Addr, u32)> {
152    let t = s.trim().trim_matches('"');
153    let mut it = t.split('/');
154    let ip_s = it.next()?;
155    let pfx_s = it.next()?;
156    if it.next().is_some() {
157        return None;
158    }
159    let ip: std::net::Ipv4Addr = ip_s.parse().ok()?;
160    let pfx: u32 = pfx_s.parse().ok()?;
161    if pfx > 32 {
162        return None;
163    }
164    let mask = if pfx == 0 { 0 } else { u32::MAX << (32 - pfx) };
165    Some((ip, mask))
166}
167
168/// 将整数 IPv4 转为点分字符串。
169fn ipv4_from_u32(v: u32) -> String {
170    let ip = std::net::Ipv4Addr::from(v);
171    ip.to_string()
172}
173
174/// 去除两端成对引号(' 或 "),先 trim 再判断;未成对则返回 trim 后原串
175fn trim_quotes(s: &str) -> String {
176    let t = s.trim();
177    if t.len() >= 2 {
178        let b = t.as_bytes();
179        let mut hidx = 0usize; // 参与配对判断的“头部”索引
180        // 开头允许 `\"` 或 `\'`:跳过反斜杠,仅用于配对判断
181        if b.len() >= 2 && b[0] == b'\\' && (b[1] == b'"' || b[1] == b'\'') {
182            hidx = 1;
183        }
184
185        if b.len() >= 2 {
186            let tidx = b.len() - 1; // 尾部引号所在下标(若确认为引号)
187            let head = b[hidx];
188            let tail = b[tidx];
189            if (head == b'"' && tail == b'"') || (head == b'\'' && tail == b'\'') {
190                // 生成去除头尾引号(以及尾部可能存在的反斜杠)的子串边界
191                let start = hidx + 1;
192                let mut end_excl = tidx; // 先排除尾部引号
193                // 如果尾部是转义形式(... \\" 或 ... \\\'),也排除反斜杠
194                if tidx >= 1 && b[tidx - 1] == b'\\' {
195                    end_excl = tidx - 1;
196                }
197                if start <= end_excl {
198                    return t[start..end_excl].to_string();
199                }
200                return String::new();
201            }
202        }
203    }
204    t.to_string()
205}
206
207#[cfg(test)]
208mod tests {
209    use super::*;
210    use rusqlite::Connection;
211
212    #[test]
213    fn test_ip4_scalar_funcs() {
214        let conn = Connection::open_in_memory().unwrap();
215        register_builtin(&conn).unwrap();
216
217        let v: i64 = conn
218            .query_row("SELECT ip4_int('1.2.3.4')", [], |r| r.get(0))
219            .unwrap();
220        assert_eq!(v as u32, 0x01020304);
221
222        let s: String = conn
223            .query_row("SELECT ip4_text(16909060)", [], |r| r.get(0))
224            .unwrap();
225        assert_eq!(s, "1.2.3.4");
226
227        let ok: i64 = conn
228            .query_row(
229                "SELECT ip4_between('10.0.0.5','10.0.0.1','10.0.0.10')",
230                [],
231                |r| r.get(0),
232            )
233            .unwrap();
234        assert_eq!(ok, 1);
235
236        let cmin: i64 = conn
237            .query_row("SELECT cidr4_min('10.0.0.0/8')", [], |r| r.get(0))
238            .unwrap();
239        let cmax: i64 = conn
240            .query_row("SELECT cidr4_max('10.0.0.0/8')", [], |r| r.get(0))
241            .unwrap();
242        // 10.0.0.0/8 => 10.0.0.0 .. 10.255.255.255
243        let exp_min = parse_ipv4_to_u32("10.0.0.0").unwrap() as i64;
244        let exp_max = parse_ipv4_to_u32("10.255.255.255").unwrap() as i64;
245        assert_eq!(cmin, exp_min);
246        assert_eq!(cmax, exp_max);
247
248        let contains: i64 = conn
249            .query_row(
250                "SELECT cidr4_contains('10.1.2.3','10.0.0.0/8') AS ok",
251                [],
252                |r| r.get(0),
253            )
254            .unwrap();
255        assert_eq!(contains, 1);
256
257        // trim_quotes
258        let z: String = conn
259            .query_row("SELECT trim_quotes('  \"work_zone\"  ')", [], |r| r.get(0))
260            .unwrap();
261        assert_eq!(z, "work_zone");
262        let z2: String = conn
263            .query_row("SELECT trim_quotes('no_quotes')", [], |r| r.get(0))
264            .unwrap();
265        assert_eq!(z2, "no_quotes");
266
267        // 支持反斜杠转义的成对引号
268        // 注意:在 Rust 源码里书写 SQL 字符串时需要多重转义,这里改为绑定参数方式更直观
269        let z3: String = conn
270            .query_row("SELECT trim_quotes(?1)", ["\\\"work_zone\\\""], |r| {
271                r.get(0)
272            })
273            .unwrap();
274        assert_eq!(z3, "work_zone");
275    }
276}