use std::{
net::IpAddr,
ops::{Deref, DerefMut},
};
use anyhow::Result;
use ipnet::IpNet;
use crate::{
core::message::Message,
types::{
addr::Address,
link::Link,
message::{AddressMessage, Attribute, RouteAttr},
},
};
use super::{handle::SocketHandle, zero_terminated};
/// Handle for address operations, built on top of a borrowed [`SocketHandle`].
///
/// The handle keeps a mutable borrow of the socket for the lifetime `'a`.
pub struct AddrHandle<'a> {
    /// The borrowed socket handle that requests are issued through.
    pub socket: &'a mut SocketHandle,
}
/// Gives shared access to the wrapped [`SocketHandle`], so its methods can be
/// called directly on an [`AddrHandle`].
impl Deref for AddrHandle<'_> {
    type Target = SocketHandle;

    fn deref(&self) -> &SocketHandle {
        &*self.socket
    }
}
impl DerefMut for AddrHandle<'_> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.socket
}
}
impl<'a> From<&'a mut SocketHandle> for AddrHandle<'a> {
fn from(socket: &'a mut SocketHandle) -> Self {
Self { socket }
}
}
impl AddrHandle<'_> {
pub fn handle<T>(&mut self, link: &T, addr: &Address, proto: u16, flags: i32) -> Result<()>
where
T: Link + ?Sized,
{
let mut req = Message::new(proto, flags);
let base = link.attrs();
let mut index: i32 = base.index;
if index == 0 {
let mut link_handle = self.handle_link();
index = match link_handle.get(base) {
Ok(link) => link.attrs().index,
Err(_) => 0,
}
}
let (family, local_addr_data) = match addr.ip {
IpNet::V4(ip) => (libc::AF_INET, ip.addr().octets().to_vec()),
IpNet::V6(ip) => (libc::AF_INET6, ip.addr().octets().to_vec()),
};
let peer_addr_data = match addr.peer {
Some(IpNet::V4(ip)) if family == libc::AF_INET6 => {
ip.addr().to_ipv6_mapped().octets().to_vec()
}
Some(IpNet::V6(ip)) if family == libc::AF_INET => {
ip.addr().to_ipv4().unwrap().octets().to_vec()
}
Some(IpNet::V4(ip)) => ip.addr().octets().to_vec(),
Some(IpNet::V6(ip)) => ip.addr().octets().to_vec(),
None => local_addr_data.clone(),
};
let msg = AddressMessage {
family: family as u8,
prefix_len: addr.ip.prefix_len(),
flags: addr.flags,
scope: addr.scope,
index,
};
let local_data = RouteAttr::new(libc::IFA_LOCAL, &local_addr_data);
let address_data = RouteAttr::new(libc::IFA_ADDRESS, &peer_addr_data);
req.add(&msg.serialize()?);
req.add(&local_data.serialize()?);
req.add(&address_data.serialize()?);
if family == libc::AF_INET {
let broadcast = match addr.broadcast {
Some(IpAddr::V4(br)) => Some(br.octets().to_vec()),
Some(IpAddr::V6(br)) => Some(br.octets().to_vec()),
None if addr.ip.prefix_len() < 31 => match addr.ip.broadcast() {
IpAddr::V4(br) => Some(br.octets().to_vec()),
IpAddr::V6(br) => Some(br.octets().to_vec()),
},
None => None,
};
if let Some(broadcast) = broadcast {
let broadcast_data = RouteAttr::new(libc::IFA_BROADCAST, &broadcast);
req.add(&broadcast_data.serialize()?);
}
if !addr.label.is_empty() {
let label_data = RouteAttr::new(libc::IFA_LABEL, &zero_terminated(&addr.label));
req.add(&label_data.serialize()?);
}
}
self.request(&mut req, 0)?;
Ok(())
}
pub fn list<T>(&mut self, link: &T, family: i32) -> Result<Vec<Address>>
where
T: Link + ?Sized,
{
let link_index = link.attrs().index;
let mut req = Message::new(libc::RTM_GETADDR, libc::NLM_F_DUMP);
let msg = AddressMessage::new(family);
req.add(&msg.serialize()?);
Ok(self
.request(&mut req, libc::RTM_NEWADDR)?
.iter()
.filter_map(|m| {
let addr = Address::from(m.as_slice());
if addr.index == link_index {
Some(addr)
} else {
None
}
})
.collect())
}
pub fn list_all(&mut self, family: i32) -> Result<Vec<Address>> {
let mut req = Message::new(libc::RTM_GETADDR, libc::NLM_F_DUMP);
let msg = AddressMessage::new(family);
req.add(&msg.serialize()?);
Ok(self
.request(&mut req, libc::RTM_NEWADDR)?
.iter()
.map(|m| Address::from(m.as_slice()))
.collect())
}
}
#[cfg(test)]
mod tests {
    use crate::{
        test_setup,
        types::{addr::AddressBuilder, link::LinkAttrs},
    };

    /// Adds `127.0.0.2/24` to `lo` and expects it to be the only address
    /// listed for that link afterwards.
    #[test]
    fn test_addr_handle() {
        // NOTE(review): presumably prepares an isolated environment for the
        // test; confirm against the macro definition.
        test_setup!();

        let mut socket = super::SocketHandle::new(libc::NETLINK_ROUTE);
        let mut links = socket.handle_link();
        let lo_attrs = LinkAttrs::new("lo");
        let lo = links.get(&lo_attrs).unwrap();

        let expected = "127.0.0.2/24".parse().unwrap();
        let new_addr = AddressBuilder::default().ip(expected).build().unwrap();

        let mut addresses = socket.handle_addr();
        addresses
            .handle(
                &lo,
                &new_addr,
                libc::RTM_NEWADDR,
                libc::NLM_F_CREATE | libc::NLM_F_EXCL | libc::NLM_F_ACK,
            )
            .unwrap();

        let listed = addresses.list(&lo, libc::AF_UNSPEC).unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].ip, expected);
    }

    /// Lists the addresses of `lo` and expects at least one.
    #[test]
    fn test_addr_list() {
        let mut socket = super::SocketHandle::new(libc::NETLINK_ROUTE);
        let mut links = socket.handle_link();
        let lo_attrs = LinkAttrs::new("lo");
        let lo = links.get(&lo_attrs).unwrap();

        let mut addresses = socket.handle_addr();
        let listed = addresses.list(&lo, libc::AF_UNSPEC).unwrap();
        listed.iter().for_each(|addr| println!("{addr:?}"));

        assert!(!listed.is_empty());
    }

    /// Lists the addresses of all links and expects at least one.
    #[test]
    fn test_addr_list_all() {
        let mut socket = super::SocketHandle::new(libc::NETLINK_ROUTE);
        let mut addresses = socket.handle_addr();

        let listed = addresses.list_all(libc::AF_UNSPEC).unwrap();
        listed.iter().for_each(|addr| println!("{addr:?}"));

        assert!(!listed.is_empty());
    }
}