use rusqlite::Result as SqlResult;
use rusqlite::functions::{Context, FunctionFlags};
/// Registers the IPv4/CIDR helper SQL functions and `trim_quotes` on `conn`.
///
/// All functions are registered as deterministic UTF-8 scalar functions:
///
/// - `ip4_int(text)`: dotted-quad address as an integer (first octet most
///   significant); `0` if the text is not a valid IPv4 address.
/// - `cidr4_min(text)` / `cidr4_max(text)`: first / last address of a CIDR
///   block as an integer; `0` if the text is not valid CIDR notation.
/// - `ip4_between(ip, start, end)`: `1` if `ip` lies in the inclusive range
///   `[start, end]`, else `0`. `start`/`end` may be integers in `0..=u32::MAX`
///   or dotted-quad text. If `ip` or either bound is invalid the result is `0`.
/// - `cidr4_contains(ip, cidr)`: `1` if `ip` is inside `cidr`, else `0`
///   (also `0` if either argument fails to parse).
/// - `ip4_text(value)`: dotted-quad text for an integer address. Integers keep
///   only their low 32 bits; text is parsed as a `u64` (then likewise
///   truncated), and unparseable text maps to `0.0.0.0`.
/// - `trim_quotes(text)`: `text` with one pair of surrounding quotes removed.
///
/// Arguments read with `ctx.get(..)?` propagate conversion failures as SQL
/// errors instead of falling back to a default.
///
/// # Errors
/// Returns the first error reported while registering a function.
pub fn register_builtin(conn: &rusqlite::Connection) -> SqlResult<()> {
    // Spell out the text encoding explicitly (as rusqlite's examples do)
    // instead of relying on SQLite's fallback for an unset encoding.
    let flags = FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC;

    conn.create_scalar_function("ip4_int", 1, flags, |ctx: &Context| {
        let s: String = ctx.get(0)?;
        Ok(parse_ipv4_to_u32(&s).map(|v| v as i64).unwrap_or(0))
    })?;

    conn.create_scalar_function("cidr4_min", 1, flags, |ctx: &Context| {
        let s: String = ctx.get(0)?;
        Ok(cidr4_min(&s).map(|v| v as i64).unwrap_or(0))
    })?;

    conn.create_scalar_function("cidr4_max", 1, flags, |ctx: &Context| {
        let s: String = ctx.get(0)?;
        Ok(cidr4_max(&s).map(|v| v as i64).unwrap_or(0))
    })?;

    conn.create_scalar_function("ip4_between", 3, flags, |ctx: &Context| {
        let ip: String = ctx.get(0)?;
        // Any unparseable address or bound means "not in range". Sentinel
        // defaults here would let an invalid `ip` match a range that ends at
        // 255.255.255.255.
        let in_range = match (parse_ipv4_to_u32(&ip), ip4_arg(ctx, 1), ip4_arg(ctx, 2)) {
            (Some(v), Some(lo), Some(hi)) => (lo..=hi).contains(&v),
            _ => false,
        };
        Ok(i64::from(in_range))
    })?;

    conn.create_scalar_function("cidr4_contains", 2, flags, |ctx: &Context| {
        let ip_s: String = ctx.get(0)?;
        let cidr_s: String = ctx.get(1)?;
        let ip = match parse_ipv4_to_u32(&ip_s) {
            Some(v) => v,
            None => return Ok(0),
        };
        let (net_ip, mask) = match parse_cidr4(&cidr_s) {
            Some(v) => v,
            None => return Ok(0),
        };
        let net = u32::from(net_ip) & mask;
        Ok(((ip & mask) == net) as i64)
    })?;

    conn.create_scalar_function("ip4_text", 1, flags, |ctx: &Context| {
        let v = match ctx.get::<i64>(0) {
            // Only the low 32 bits are kept.
            Ok(i) => i as u32,
            Err(_) => {
                let s: String = ctx.get(0)?;
                s.trim().parse::<u64>().map(|n| n as u32).unwrap_or(0)
            }
        };
        Ok(ipv4_from_u32(v))
    })?;

    conn.create_scalar_function("trim_quotes", 1, flags, |ctx: &Context| {
        let s: String = ctx.get(0)?;
        Ok(trim_quotes(&s))
    })?;

    Ok(())
}

/// Reads SQL argument `idx` as an IPv4 address.
///
/// Accepts either an integer in `0..=u32::MAX` or dotted-quad text. Returns
/// `None` for anything else (out-of-range integer, unparseable text, or a value
/// that is neither an integer nor text).
fn ip4_arg(ctx: &Context<'_>, idx: usize) -> Option<u32> {
    match ctx.get::<i64>(idx) {
        // Reject out-of-range integers rather than silently wrapping them.
        Ok(i) => u32::try_from(i).ok(),
        Err(_) => ctx
            .get::<String>(idx)
            .ok()
            .and_then(|s| parse_ipv4_to_u32(&s)),
    }
}
/// Parses a dotted-quad IPv4 address into a `u32` (first octet in the most
/// significant byte).
///
/// Surrounding whitespace is trimmed first, then every leading/trailing `"`
/// is stripped. Returns `None` if what remains is not a valid IPv4 address.
fn parse_ipv4_to_u32(s: &str) -> Option<u32> {
    s.trim()
        .trim_matches('"')
        .parse::<std::net::Ipv4Addr>()
        .ok()
        .map(u32::from)
}
/// Returns the first (network) address of a CIDR block such as `10.0.0.0/8`.
///
/// Host bits of the written address are cleared, so `10.1.2.3/8` also yields
/// `10.0.0.0`. Returns `None` if `parse_cidr4` rejects `s`.
fn cidr4_min(s: &str) -> Option<u32> {
    parse_cidr4(s).map(|(addr, mask)| u32::from(addr) & mask)
}
/// Returns the last (highest) address of a CIDR block such as `10.0.0.0/8`.
///
/// All host bits are set, so `10.0.0.0/8` yields `10.255.255.255`. Returns
/// `None` if `parse_cidr4` rejects `s`.
fn cidr4_max(s: &str) -> Option<u32> {
    // Setting every host bit makes clearing them first unnecessary:
    // `(a & m) | !m == a | !m`.
    parse_cidr4(s).map(|(addr, mask)| u32::from(addr) | !mask)
}
/// Parses IPv4 CIDR notation (`a.b.c.d/len`) into the written address and its
/// netmask.
///
/// Surrounding whitespace and `"` characters are stripped first. Exactly one
/// `/` is allowed, and the prefix length must be in `0..=32`. The returned
/// address is exactly as written; host bits are not cleared here.
fn parse_cidr4(s: &str) -> Option<(std::net::Ipv4Addr, u32)> {
    let text = s.trim().trim_matches('"');
    let (addr_part, len_part) = text.split_once('/')?;
    if len_part.contains('/') {
        return None;
    }
    let addr = addr_part.parse::<std::net::Ipv4Addr>().ok()?;
    let prefix_len = len_part.parse::<u32>().ok().filter(|&n| n <= 32)?;
    // A 32-bit shift (prefix 0) overflows; `checked_shl` turns that into the
    // empty mask.
    let mask = u32::MAX.checked_shl(32 - prefix_len).unwrap_or(0);
    Some((addr, mask))
}
/// Formats a `u32` address (first octet in the most significant byte) as
/// dotted-quad text, e.g. `0x01020304` -> `"1.2.3.4"`.
fn ipv4_from_u32(v: u32) -> String {
    std::net::Ipv4Addr::from(v).to_string()
}
/// Strips one pair of matching surrounding quotes (`"…"` or `'…'`) from `s`.
///
/// Leading/trailing whitespace is trimmed first. The opening quote may be
/// backslash-escaped (`\"` / `\'`), and a backslash directly before the
/// closing quote is dropped as well, so `\"work_zone\"` yields `work_zone`.
/// If the trimmed input is not wrapped in a matching pair of quotes (including
/// a lone escaped quote such as `\"`), the trimmed input is returned unchanged.
fn trim_quotes(s: &str) -> String {
    let t = s.trim();
    let b = t.as_bytes();
    if b.len() < 2 {
        return t.to_string();
    }
    // Skip a leading backslash when it escapes the opening quote.
    let open = if b[0] == b'\\' && matches!(b[1], b'"' | b'\'') { 1 } else { 0 };
    let close = b.len() - 1;
    // The opening and closing quotes must be distinct bytes; otherwise `\"`
    // would use one quote as both delimiters and collapse to "".
    let quoted = close > open
        && matches!((b[open], b[close]), (b'"', b'"') | (b'\'', b'\''));
    if !quoted {
        return t.to_string();
    }
    let start = open + 1;
    // Drop a backslash that escapes the closing quote.
    let end = if b[close - 1] == b'\\' { close - 1 } else { close };
    // Both boundaries sit next to ASCII bytes, so they are char boundaries;
    // `get` keeps the slice panic-free regardless.
    t.get(start..end).unwrap_or_default().to_string()
}
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens an in-memory database with every builtin function registered.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        register_builtin(&conn).unwrap();
        conn
    }

    /// Runs a query returning one row and one column, read as an integer.
    fn query_int(conn: &Connection, sql: &str) -> i64 {
        conn.query_row(sql, [], |row| row.get(0)).unwrap()
    }

    /// Runs a query returning one row and one column, read as text.
    fn query_text(conn: &Connection, sql: &str) -> String {
        conn.query_row(sql, [], |row| row.get(0)).unwrap()
    }

    #[test]
    fn test_ip4_scalar_funcs() {
        let conn = setup();

        assert_eq!(query_int(&conn, "SELECT ip4_int('1.2.3.4')") as u32, 0x01020304);
        assert_eq!(query_text(&conn, "SELECT ip4_text(16909060)"), "1.2.3.4");
        assert_eq!(
            query_int(&conn, "SELECT ip4_between('10.0.0.5','10.0.0.1','10.0.0.10')"),
            1
        );

        let expected_min = parse_ipv4_to_u32("10.0.0.0").unwrap() as i64;
        let expected_max = parse_ipv4_to_u32("10.255.255.255").unwrap() as i64;
        assert_eq!(query_int(&conn, "SELECT cidr4_min('10.0.0.0/8')"), expected_min);
        assert_eq!(query_int(&conn, "SELECT cidr4_max('10.0.0.0/8')"), expected_max);

        assert_eq!(
            query_int(&conn, "SELECT cidr4_contains('10.1.2.3','10.0.0.0/8') AS ok"),
            1
        );

        assert_eq!(
            query_text(&conn, "SELECT trim_quotes(' \"work_zone\" ')"),
            "work_zone"
        );
        assert_eq!(query_text(&conn, "SELECT trim_quotes('no_quotes')"), "no_quotes");

        // Passed as a bound parameter; the value contains literal backslashes
        // before each quote.
        let escaped: String = conn
            .query_row("SELECT trim_quotes(?1)", ["\\\"work_zone\\\""], |row| {
                row.get(0)
            })
            .unwrap();
        assert_eq!(escaped, "work_zone");
    }
}