wp_knowledge/
sqlite_ext.rs1use rusqlite::Result as SqlResult;
8use rusqlite::functions::{Context, FunctionFlags};
9
10pub fn register_builtin(conn: &rusqlite::Connection) -> SqlResult<()> {
14 conn.create_scalar_function(
16 "ip4_int",
17 1,
18 FunctionFlags::SQLITE_DETERMINISTIC,
19 |ctx: &Context| {
20 let s: String = ctx.get(0)?;
21 Ok(parse_ipv4_to_u32(&s).map(|v| v as i64).unwrap_or(0))
22 },
23 )?;
24
25 conn.create_scalar_function(
27 "cidr4_min",
28 1,
29 FunctionFlags::SQLITE_DETERMINISTIC,
30 |ctx: &Context| {
31 let s: String = ctx.get(0)?;
32 Ok(cidr4_min(&s).map(|v| v as i64).unwrap_or(0))
33 },
34 )?;
35
36 conn.create_scalar_function(
38 "cidr4_max",
39 1,
40 FunctionFlags::SQLITE_DETERMINISTIC,
41 |ctx: &Context| {
42 let s: String = ctx.get(0)?;
43 Ok(cidr4_max(&s).map(|v| v as i64).unwrap_or(0))
44 },
45 )?;
46
47 conn.create_scalar_function(
49 "ip4_between",
50 3,
51 FunctionFlags::SQLITE_DETERMINISTIC,
52 |ctx: &Context| {
53 let ip: String = ctx.get(0)?;
54 let s = ctx
56 .get::<i64>(1)
57 .ok()
58 .map(|i| i as u32)
59 .or_else(|| {
60 let s: String = ctx.get(1).ok()?;
61 parse_ipv4_to_u32(&s)
62 })
63 .unwrap_or(u32::MAX);
64 let e = ctx
65 .get::<i64>(2)
66 .ok()
67 .map(|i| i as u32)
68 .or_else(|| {
69 let s: String = ctx.get(2).ok()?;
70 parse_ipv4_to_u32(&s)
71 })
72 .unwrap_or(0);
73 let v = parse_ipv4_to_u32(&ip).unwrap_or(u32::MAX);
74 Ok(((v >= s) && (v <= e)) as i64)
75 },
76 )?;
77
78 conn.create_scalar_function(
80 "cidr4_contains",
81 2,
82 FunctionFlags::SQLITE_DETERMINISTIC,
83 |ctx: &Context| {
84 let ip_s: String = ctx.get(0)?;
85 let cidr_s: String = ctx.get(1)?;
86 let ip = match parse_ipv4_to_u32(&ip_s) {
87 Some(v) => v,
88 None => return Ok(0),
89 };
90 let (net_ip, mask) = match parse_cidr4(&cidr_s) {
91 Some(v) => v,
92 None => return Ok(0),
93 };
94 let net = u32::from(net_ip) & mask;
95 Ok(((ip & mask) == net) as i64)
96 },
97 )?;
98
99 conn.create_scalar_function(
101 "ip4_text",
102 1,
103 FunctionFlags::SQLITE_DETERMINISTIC,
104 |ctx: &Context| {
105 let val = ctx.get::<i64>(0).ok();
107 let v = if let Some(i) = val {
108 i as u32
109 } else {
110 let s: String = ctx.get(0)?;
111 match s.trim().parse::<u64>() {
112 Ok(n) => n as u32,
113 Err(_) => 0,
114 }
115 };
116 Ok(ipv4_from_u32(v))
117 },
118 )?;
119
120 conn.create_scalar_function(
122 "trim_quotes",
123 1,
124 FunctionFlags::SQLITE_DETERMINISTIC,
125 |ctx: &Context| {
126 let s: String = ctx.get(0)?;
127 Ok(trim_quotes(&s))
128 },
129 )?;
130 Ok(())
131}
132
/// Parses dotted-quad IPv4 text into its integer value (first octet in the
/// most significant byte).
///
/// Surrounding whitespace, then any run of leading/trailing `"` characters,
/// are stripped before parsing. Returns `None` if the remainder is not a valid
/// IPv4 address.
fn parse_ipv4_to_u32(s: &str) -> Option<u32> {
    s.trim()
        .trim_matches('"')
        .parse::<std::net::Ipv4Addr>()
        .ok()
        .map(u32::from)
}
140
141fn cidr4_min(s: &str) -> Option<u32> {
142 let (ip, mask) = parse_cidr4(s)?;
143 Some(u32::from(ip) & mask)
144}
145
146fn cidr4_max(s: &str) -> Option<u32> {
147 let (ip, mask) = parse_cidr4(s)?;
148 Some((u32::from(ip) & mask) | !mask)
149}
150
/// Parses `a.b.c.d/n` CIDR notation into the address and its netmask.
///
/// Surrounding whitespace, then leading/trailing `"` characters, are ignored.
/// The address is returned exactly as written (host bits are not cleared); the
/// mask has the top `n` bits set. Returns `None` when the text does not contain
/// exactly one `/`, the address does not parse, or the prefix length is not an
/// integer in `0..=32`.
fn parse_cidr4(s: &str) -> Option<(std::net::Ipv4Addr, u32)> {
    let cleaned = s.trim().trim_matches('"');
    let (addr_part, prefix_part) = cleaned.split_once('/')?;
    if prefix_part.contains('/') {
        return None;
    }
    let addr = addr_part.parse::<std::net::Ipv4Addr>().ok()?;
    let prefix_len = prefix_part.parse::<u32>().ok().filter(|&p| p <= 32)?;
    // A shift by 32 (prefix length 0) is out of range for u32, which
    // `checked_shl` reports as None: that case is the empty mask.
    let netmask = u32::MAX.checked_shl(32 - prefix_len).unwrap_or(0);
    Some((addr, netmask))
}
167
/// Formats an IPv4 address held as an integer (first octet in the most
/// significant byte) as dotted-quad text.
fn ipv4_from_u32(v: u32) -> String {
    std::net::Ipv4Addr::from(v).to_string()
}
173
/// Removes one pair of matching enclosing quotes from `s`.
///
/// Leading/trailing whitespace is trimmed first. The pair may be `"…"` or
/// `'…'`, and either quote may be preceded by a backslash, so `"work_zone"`,
/// `\"work_zone\"` and `'work_zone'` all yield `work_zone`. If the trimmed
/// text is not wrapped in a matching pair, it is returned unchanged (still
/// trimmed).
///
/// The opening and closing quote must be two distinct characters: a lone
/// escaped quote such as `\"` is returned as-is rather than collapsing to an
/// empty string.
fn trim_quotes(s: &str) -> String {
    let t = s.trim();

    // Skip a backslash that escapes the opening quote (`\"…`), but only when a
    // quote really follows it.
    let from_open = match t.strip_prefix('\\') {
        Some(rest) if rest.starts_with('"') || rest.starts_with('\'') => rest,
        _ => t,
    };

    let mut chars = from_open.chars();
    let quote = match chars.next() {
        Some(c) if c == '"' || c == '\'' => c,
        _ => return t.to_string(),
    };

    // Require a separate closing quote of the same kind after the opening one.
    match chars.as_str().strip_suffix(quote) {
        // Drop a backslash that escapes the closing quote (`…\"`).
        Some(inner) => inner.strip_suffix('\\').unwrap_or(inner).to_string(),
        None => t.to_string(),
    }
}
206
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens an in-memory database with every builtin function registered.
    fn open_with_builtins() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        register_builtin(&conn).unwrap();
        conn
    }

    #[test]
    fn test_ip4_scalar_funcs() {
        let conn = open_with_builtins();
        // Runs a parameterless single-value query and returns its integer result.
        let int_of = |sql: &str| -> i64 { conn.query_row(sql, [], |r| r.get(0)).unwrap() };
        // Same, for text results.
        let text_of = |sql: &str| -> String { conn.query_row(sql, [], |r| r.get(0)).unwrap() };

        // ip4_int / ip4_text round-trip.
        assert_eq!(int_of("SELECT ip4_int('1.2.3.4')") as u32, 0x01020304);
        assert_eq!(text_of("SELECT ip4_text(16909060)"), "1.2.3.4");

        // Inclusive range check with dotted-quad bounds.
        assert_eq!(
            int_of("SELECT ip4_between('10.0.0.5','10.0.0.1','10.0.0.10')"),
            1
        );

        // First and last address of 10.0.0.0/8.
        let expected_min = parse_ipv4_to_u32("10.0.0.0").unwrap() as i64;
        let expected_max = parse_ipv4_to_u32("10.255.255.255").unwrap() as i64;
        assert_eq!(int_of("SELECT cidr4_min('10.0.0.0/8')"), expected_min);
        assert_eq!(int_of("SELECT cidr4_max('10.0.0.0/8')"), expected_max);

        assert_eq!(
            int_of("SELECT cidr4_contains('10.1.2.3','10.0.0.0/8') AS ok"),
            1
        );

        // Quote stripping: plain quotes, no quotes, backslash-escaped quotes.
        assert_eq!(text_of("SELECT trim_quotes(' \"work_zone\" ')"), "work_zone");
        assert_eq!(text_of("SELECT trim_quotes('no_quotes')"), "no_quotes");
        let escaped: String = conn
            .query_row("SELECT trim_quotes(?1)", ["\\\"work_zone\\\""], |r| {
                r.get(0)
            })
            .unwrap();
        assert_eq!(escaped, "work_zone");
    }
}