use alloc::format;
use alloc::string::ToString;
use alloc::vec::Vec;
use spg_sql::ast::BinOp;
use spg_storage::Value;
use super::EvalError;
/// `host(inet)`: the address portion of the single TEXT argument, i.e. everything
/// before the first `/` (or the whole text when there is no `/`).
///
/// A NULL argument yields NULL. Any other arity or argument type is reported as
/// `EvalError::TypeMismatch`.
pub(super) fn inet_host(args: &[Value]) -> Result<Value, EvalError> {
    let text = match args {
        [Value::Null] => return Ok(Value::Null),
        [Value::Text(text)] => text,
        _ => {
            return Err(EvalError::TypeMismatch {
                detail: alloc::format!("host() takes one TEXT arg, got {} args", args.len()),
            })
        }
    };
    let address = match text.split_once('/') {
        Some((address, _prefix)) => address,
        None => &text[..],
    };
    Ok(Value::Text(address.to_string()))
}
pub(super) fn inet_network(args: &[Value]) -> Result<Value, EvalError> {
let s = match args {
[Value::Text(s)] => s.clone(),
[Value::Null] => return Ok(Value::Null),
_ => {
return Err(EvalError::TypeMismatch {
detail: alloc::format!("network() takes one TEXT arg, got {} args", args.len()),
});
}
};
let mut split = s.splitn(2, '/');
let host = split.next().unwrap_or("").to_string();
let mask: u32 = split.next().and_then(|m| m.parse().ok()).unwrap_or(32);
if !host.contains('.') {
return Ok(Value::Text(s));
}
let octets: Vec<&str> = host.split('.').collect();
if octets.len() != 4 {
return Ok(Value::Text(s));
}
let keep_bytes = ((mask + 7) / 8) as usize;
let mut out = alloc::string::String::new();
for (i, oct) in octets.iter().enumerate() {
if i > 0 {
out.push('.');
}
if i < keep_bytes {
out.push_str(oct);
} else {
out.push('0');
}
}
out.push('/');
out.push_str(&mask.to_string());
Ok(Value::Text(out))
}
/// `masklen(inet)`: the prefix length of the single TEXT argument.
///
/// When the text has no `/prefix`, or the prefix does not parse as an integer,
/// the address is treated as a single host. That means 32 for IPv4 and 128 for
/// IPv6 (any address containing `:`), matching how the address family is
/// detected elsewhere in this module.
///
/// A NULL argument yields NULL. Any other arity or argument type is reported as
/// `EvalError::TypeMismatch`.
pub(super) fn inet_masklen(args: &[Value]) -> Result<Value, EvalError> {
    let s = match args {
        [Value::Text(s)] => s,
        [Value::Null] => return Ok(Value::Null),
        _ => {
            return Err(EvalError::TypeMismatch {
                detail: alloc::format!("masklen() takes one TEXT arg, got {} args", args.len()),
            });
        }
    };
    let (host, prefix) = match s.split_once('/') {
        Some((host, prefix)) => (host, Some(prefix)),
        None => (s.as_str(), None),
    };
    // Host-only default depends on the family: /32 for IPv4, /128 for IPv6.
    let default_len = if host.contains(':') { 128 } else { 32 };
    let mask: i32 = prefix
        .and_then(|m| m.parse().ok())
        .unwrap_or(default_len);
    Ok(Value::Int(mask))
}
/// A parsed inet value: an address plus its prefix length.
///
/// Small and made only of `Copy` fields, so it is `Copy`. It is `Debug` for
/// diagnostics.
#[derive(Debug, Clone, Copy)]
struct InetNet {
    /// Address bytes in network order. IPv4 uses only the first 4 bytes; the rest stay zero.
    bytes: [u8; 16],
    /// Address width in bytes: 4 for IPv4, 16 for IPv6. Used to tell the families apart.
    family_bytes: u8,
    /// Prefix length in bits (at most 32 for IPv4, 128 for IPv6).
    prefix_bits: u8,
}
/// Parses `address[/prefix]` into an [`InetNet`].
///
/// An address containing `:` is parsed as IPv6; anything else is parsed as
/// dotted-quad IPv4. The prefix defaults to the full address width (128 or 32
/// bits). It must parse as an integer no larger than that width. Returns `None`
/// if either the address or the prefix is invalid.
fn parse_inet_text(s: &str) -> Option<InetNet> {
    let (host, mask_str) = match s.split_once('/') {
        Some((host, mask)) => (host, Some(mask)),
        None => (s, None),
    };
    let (bytes, family_bytes) = if host.contains(':') {
        (parse_ipv6(host)?, 16u8)
    } else {
        // IPv4 occupies the leading 4 bytes; the remainder stays zero.
        let v4 = parse_ipv4(host)?;
        let mut padded = [0u8; 16];
        padded[..4].copy_from_slice(&v4);
        (padded, 4u8)
    };
    let width_bits = family_bytes * 8;
    let prefix_bits = match mask_str {
        None => width_bits,
        Some(m) => m.parse::<u8>().ok().filter(|&n| n <= width_bits)?,
    };
    Some(InetNet {
        bytes,
        family_bytes,
        prefix_bits,
    })
}
/// Parses a dotted-quad IPv4 address (`a.b.c.d`, each part a decimal `u8`) into
/// its four bytes. Returns `None` unless there are exactly four parts and every
/// part parses.
fn parse_ipv4(s: &str) -> Option<[u8; 4]> {
    let mut octets = [0u8; 4];
    let mut pieces = s.split('.');
    for slot in octets.iter_mut() {
        *slot = pieces.next()?.parse().ok()?;
    }
    // Anything left over means there were more than four parts.
    if pieces.next().is_some() {
        return None;
    }
    Some(octets)
}
/// Parses a textual IPv6 address into its 16 bytes in network order.
///
/// Accepts eight colon-separated hex groups, or the compressed form using a
/// single `::` to stand for one or more zero groups (so at most seven explicit
/// groups). Embedded-IPv4 notation is not supported. Returns `None` on any
/// malformed input, including a second `::` or a group that is not a valid
/// hex `u16`.
fn parse_ipv6(s: &str) -> Option<[u8; 16]> {
    // Parses colon-separated hex groups; an empty string means "no groups".
    fn hex_groups(part: &str) -> Option<Vec<u16>> {
        if part.is_empty() {
            return Some(Vec::new());
        }
        part.split(':')
            .map(|g| u16::from_str_radix(g, 16).ok())
            .collect()
    }

    let (leading, trailing) = match s.split_once("::") {
        Some((before, after)) => {
            let leading = hex_groups(before)?;
            let trailing = hex_groups(after)?;
            // `::` must replace at least one group.
            if leading.len() + trailing.len() > 7 {
                return None;
            }
            (leading, trailing)
        }
        None => {
            let all = hex_groups(s)?;
            if all.len() != 8 {
                return None;
            }
            (all, Vec::new())
        }
    };

    let mut words = [0u16; 8];
    words[..leading.len()].copy_from_slice(&leading);
    words[8 - trailing.len()..].copy_from_slice(&trailing);

    let mut bytes = [0u8; 16];
    for (pair, word) in bytes.chunks_exact_mut(2).zip(words.iter()) {
        pair.copy_from_slice(&word.to_be_bytes());
    }
    Some(bytes)
}
/// Returns true when `a` and `b` agree on their leading `prefix_bits` address bits.
///
/// Whole bytes are compared directly. A trailing partial byte is compared under a
/// mask of its high-order bits. `prefix_bits` is expected to be at most 128.
fn network_prefix_eq(a: &InetNet, b: &InetNet, prefix_bits: u8) -> bool {
    let whole = usize::from(prefix_bits / 8);
    if a.bytes[..whole] != b.bytes[..whole] {
        return false;
    }
    match prefix_bits % 8 {
        // Byte-aligned prefix: nothing left to compare. This also avoids touching
        // a byte past the end when the prefix is /128.
        0 => true,
        rem => {
            let keep_high_bits = !(0xffu8 >> rem);
            ((a.bytes[whole] ^ b.bytes[whole]) & keep_high_bits) == 0
        }
    }
}
/// True when network `a` lies within (or equals) network `b`.
///
/// That requires the same address family, a prefix on `a` at least as long as
/// `b`'s, and agreement on `b`'s prefix bits.
fn inet_contained_eq(a: &InetNet, b: &InetNet) -> bool {
    let same_family = a.family_bytes == b.family_bytes;
    let at_least_as_specific = a.prefix_bits >= b.prefix_bits;
    same_family && at_least_as_specific && network_prefix_eq(a, b, b.prefix_bits)
}
/// True when `a` and `b` describe the same network: same family, same prefix
/// length, and identical bits under that prefix. Host bits beyond the prefix are
/// ignored.
fn inet_networks_equal(a: &InetNet, b: &InetNet) -> bool {
    a.family_bytes == b.family_bytes
        && a.prefix_bits == b.prefix_bits
        && network_prefix_eq(a, b, a.prefix_bits)
}
pub(super) fn inet_op_bool_result(op: BinOp, l: &Value, r: &Value) -> Result<Value, EvalError> {
if matches!(l, Value::Null) || matches!(r, Value::Null) {
return Ok(Value::Null);
}
let (lt, rt) = match (l, r) {
(Value::Text(a), Value::Text(b)) => (a, b),
_ => {
return Err(EvalError::TypeMismatch {
detail: format!(
"inet operator requires TEXT/INET operands, got {:?} and {:?}",
l.data_type(),
r.data_type()
),
});
}
};
let a = parse_inet_text(lt).ok_or_else(|| EvalError::TypeMismatch {
detail: format!("invalid inet text: {:?}", lt),
})?;
let b = parse_inet_text(rt).ok_or_else(|| EvalError::TypeMismatch {
detail: format!("invalid inet text: {:?}", rt),
})?;
let result = match op {
BinOp::InetContainedByEq => inet_contained_eq(&a, &b),
BinOp::InetContainedBy => inet_contained_eq(&a, &b) && !inet_networks_equal(&a, &b),
BinOp::InetContainsEq => inet_contained_eq(&b, &a),
BinOp::InetContains => inet_contained_eq(&b, &a) && !inet_networks_equal(&a, &b),
BinOp::InetOverlap => inet_contained_eq(&a, &b) || inet_contained_eq(&b, &a),
_ => unreachable!("inet_op_bool_result called with non-inet op"),
};
Ok(Value::Bool(result))
}