use std::net::IpAddr;
use super::{
MPTCP_PM_GENL_NAME, MPTCP_PM_GENL_VERSION,
types::{MptcpEndpoint, MptcpEndpointBuilder, MptcpFlags, MptcpLimits},
};
use crate::netlink::{
attr::{AttrIter, NLA_F_NESTED, get},
builder::MessageBuilder,
connection::Connection,
error::{Error, Result},
genl::{CtrlAttr, CtrlCmd, GENL_HDRLEN, GENL_ID_CTRL, GenlMsgHdr},
message::{MessageIter, NLM_F_ACK, NLM_F_DUMP, NLM_F_REQUEST, NlMsgError},
protocol::{AsyncProtocolInit, Mptcp, ProtocolState},
socket::NetlinkSocket,
};
/// Asynchronous construction of the [`Mptcp`] protocol state.
impl AsyncProtocolInit for Mptcp {
    /// Looks up the generic netlink family id over `socket` and wraps it in
    /// a new protocol state.
    ///
    /// # Errors
    ///
    /// Returns whatever error `resolve_mptcp_family` reports.
    async fn resolve_async(socket: &NetlinkSocket) -> Result<Self> {
        resolve_mptcp_family(socket)
            .await
            .map(|family_id| Self { family_id })
    }
}
use crate::netlink::types::mptcp::{mptcp_pm_addr_attr, mptcp_pm_attr, mptcp_pm_cmd};
impl Connection<Mptcp> {
/// Opens a netlink socket for [`Mptcp::PROTOCOL`], resolves the family id
/// on it, and assembles a connection from both.
///
/// # Errors
///
/// Fails if the socket cannot be created or the family cannot be resolved.
#[tracing::instrument(level = "debug", skip_all, fields(method = "new_async"))]
pub async fn new_async() -> Result<Self> {
    let socket = NetlinkSocket::new(Mptcp::PROTOCOL)?;
    let state = Mptcp {
        family_id: resolve_mptcp_family(&socket).await?,
    };
    Ok(Self::from_parts(socket, state))
}
/// Returns the generic netlink family id held in the connection's
/// protocol state.
pub fn family_id(&self) -> u16 {
    let state = self.state();
    state.family_id
}
/// Dumps all endpoints with a `GET_ADDR` request.
///
/// Each dumped message is run through `parse_endpoint_response`; messages
/// that yield no endpoint are skipped.
///
/// # Errors
///
/// Fails if the dump itself fails or if any message fails to parse.
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_endpoints"))]
pub async fn get_endpoints(&self) -> Result<Vec<MptcpEndpoint>> {
    let dump = self
        .dump_mptcp_command(mptcp_pm_cmd::GET_ADDR, |_| {})
        .await?;
    // `transpose` turns Ok(None) into "skip" while keeping Err entries so
    // that `collect` stops at the first parse failure.
    dump.iter()
        .filter_map(|message| parse_endpoint_response(message).transpose())
        .collect()
}
/// Sends an `ADD_ADDR` command describing `endpoint`.
///
/// The endpoint attributes are wrapped in a nested `ADDR` attribute.
///
/// # Errors
///
/// Fails if the request cannot be sent or the reply carries an error.
#[tracing::instrument(level = "debug", skip_all, fields(method = "add_endpoint"))]
pub async fn add_endpoint(&self, endpoint: MptcpEndpointBuilder) -> Result<()> {
    self.mptcp_command(mptcp_pm_cmd::ADD_ADDR, |msg| {
        let nest = msg.nest_start(mptcp_pm_attr::ADDR | NLA_F_NESTED);
        append_endpoint_attrs(msg, &endpoint);
        msg.nest_end(nest);
    })
    .await
    .map(|_| ())
}
/// Sends a `DEL_ADDR` command for the endpoint identified by `id`.
///
/// Only the `ID` attribute is sent, inside a nested `ADDR` attribute.
///
/// # Errors
///
/// Fails if the request cannot be sent or the reply carries an error.
#[tracing::instrument(level = "debug", skip_all, fields(method = "del_endpoint"))]
pub async fn del_endpoint(&self, id: u8) -> Result<()> {
    self.mptcp_command(mptcp_pm_cmd::DEL_ADDR, |msg| {
        let nest = msg.nest_start(mptcp_pm_attr::ADDR | NLA_F_NESTED);
        msg.append_attr_u8(mptcp_pm_addr_attr::ID, id);
        msg.nest_end(nest);
    })
    .await
    .map(|_| ())
}
/// Sends a `FLUSH_ADDRS` command with no attributes.
///
/// # Errors
///
/// Fails if the request cannot be sent or the reply carries an error.
#[tracing::instrument(level = "debug", skip_all, fields(method = "flush_endpoints"))]
pub async fn flush_endpoints(&self) -> Result<()> {
    self.mptcp_command(mptcp_pm_cmd::FLUSH_ADDRS, |_| {})
        .await
        .map(|_| ())
}
/// Queries the current limits with a `GET_LIMITS` request.
///
/// When the reply contains no recognised limit attribute, the default
/// [`MptcpLimits`] value is returned instead.
///
/// # Errors
///
/// Fails if the query fails or the reply cannot be parsed.
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_limits"))]
pub async fn get_limits(&self) -> Result<MptcpLimits> {
    let reply = self.mptcp_query(mptcp_pm_cmd::GET_LIMITS, |_| {}).await?;
    match parse_limits_response(&reply)? {
        Some(limits) => Ok(limits),
        None => Ok(MptcpLimits::default()),
    }
}
/// Sends a `SET_LIMITS` command.
///
/// Only the limits that are `Some` are written: `subflows` as `SUBFLOWS`
/// and `add_addr_accepted` as `RCV_ADD_ADDRS`, in that order.
///
/// # Errors
///
/// Fails if the request cannot be sent or the reply carries an error.
#[tracing::instrument(level = "debug", skip_all, fields(method = "set_limits"))]
pub async fn set_limits(&self, limits: MptcpLimits) -> Result<()> {
    // Attribute id paired with the optional value to send for it.
    let requested = [
        (mptcp_pm_attr::SUBFLOWS, limits.subflows),
        (mptcp_pm_attr::RCV_ADD_ADDRS, limits.add_addr_accepted),
    ];
    self.mptcp_command(mptcp_pm_cmd::SET_LIMITS, |msg| {
        for (attr, value) in requested {
            if let Some(value) = value {
                msg.append_attr_u32(attr, value);
            }
        }
    })
    .await
    .map(|_| ())
}
/// Sends a `SET_FLAGS` command for the endpoint identified by `id`.
///
/// The nested `ADDR` attribute carries the `ID` followed by the raw value
/// of `flags`.
///
/// # Errors
///
/// Fails if the request cannot be sent or the reply carries an error.
#[tracing::instrument(level = "debug", skip_all, fields(method = "set_endpoint_flags"))]
pub async fn set_endpoint_flags(&self, id: u8, flags: MptcpFlags) -> Result<()> {
    self.mptcp_command(mptcp_pm_cmd::SET_FLAGS, |msg| {
        let nest = msg.nest_start(mptcp_pm_attr::ADDR | NLA_F_NESTED);
        msg.append_attr_u8(mptcp_pm_addr_attr::ID, id);
        msg.append_attr_u32(mptcp_pm_addr_attr::FLAGS, flags.to_raw());
        msg.nest_end(nest);
    })
    .await
    .map(|_| ())
}
/// Sends a `SUBFLOW_CREATE` command built from `subflow`.
///
/// The token is always sent; local/remote ids, local/remote addresses and
/// the interface index are sent only when present, and `BACKUP` is sent
/// (as `1`) only when `subflow.backup` is set.
///
/// NOTE(review): the attributes come from `mptcp_attr`, not `mptcp_pm_attr`;
/// confirm these ids are what the kernel expects for this command.
///
/// # Errors
///
/// Fails if the request cannot be sent or the reply carries an error.
#[tracing::instrument(level = "debug", skip_all, fields(method = "create_subflow"))]
pub async fn create_subflow(&self, subflow: super::types::MptcpSubflowBuilder) -> Result<()> {
    use crate::netlink::types::mptcp::mptcp_attr;
    self.mptcp_command(mptcp_pm_cmd::SUBFLOW_CREATE, |msg| {
        msg.append_attr_u32(mptcp_attr::TOKEN, subflow.token);
        if let Some(local_id) = subflow.local_id {
            msg.append_attr_u8(mptcp_attr::LOC_ID, local_id);
        }
        if let Some(remote_id) = subflow.remote_id {
            msg.append_attr_u8(mptcp_attr::REM_ID, remote_id);
        }
        if let Some(local) = &subflow.local_addr {
            append_source_addr(msg, local);
        }
        if let Some(remote) = &subflow.remote_addr {
            append_dest_addr(msg, remote);
        }
        if let Some(ifindex) = subflow.ifindex {
            msg.append_attr_u32(mptcp_attr::IF_IDX, ifindex);
        }
        if subflow.backup {
            msg.append_attr_u8(mptcp_attr::BACKUP, 1);
        }
    })
    .await
    .map(|_| ())
}
/// Sends a `SUBFLOW_DESTROY` command built from `subflow`.
///
/// The token is always sent; local/remote ids and local/remote addresses
/// are sent only when present. `ifindex` and `backup` are not used here.
///
/// NOTE(review): the attributes come from `mptcp_attr`, not `mptcp_pm_attr`;
/// confirm these ids are what the kernel expects for this command.
///
/// # Errors
///
/// Fails if the request cannot be sent or the reply carries an error.
#[tracing::instrument(level = "debug", skip_all, fields(method = "destroy_subflow"))]
pub async fn destroy_subflow(&self, subflow: super::types::MptcpSubflowBuilder) -> Result<()> {
    use crate::netlink::types::mptcp::mptcp_attr;
    self.mptcp_command(mptcp_pm_cmd::SUBFLOW_DESTROY, |msg| {
        msg.append_attr_u32(mptcp_attr::TOKEN, subflow.token);
        if let Some(local_id) = subflow.local_id {
            msg.append_attr_u8(mptcp_attr::LOC_ID, local_id);
        }
        if let Some(remote_id) = subflow.remote_id {
            msg.append_attr_u8(mptcp_attr::REM_ID, remote_id);
        }
        if let Some(local) = &subflow.local_addr {
            append_source_addr(msg, local);
        }
        if let Some(remote) = &subflow.remote_addr {
            append_dest_addr(msg, remote);
        }
    })
    .await
    .map(|_| ())
}
/// Sends an `ANNOUNCE` command built from `announce`.
///
/// The token is always sent; the address id (as `LOC_ID`) and the address
/// (as source-address attributes) are sent only when present.
///
/// NOTE(review): the attributes come from `mptcp_attr`, not `mptcp_pm_attr`;
/// confirm these ids are what the kernel expects for this command.
///
/// # Errors
///
/// Fails if the request cannot be sent or the reply carries an error.
#[tracing::instrument(level = "debug", skip_all, fields(method = "announce_addr"))]
pub async fn announce_addr(&self, announce: super::types::MptcpAnnounceBuilder) -> Result<()> {
    use crate::netlink::types::mptcp::mptcp_attr;
    self.mptcp_command(mptcp_pm_cmd::ANNOUNCE, |msg| {
        msg.append_attr_u32(mptcp_attr::TOKEN, announce.token);
        if let Some(addr_id) = announce.addr_id {
            msg.append_attr_u8(mptcp_attr::LOC_ID, addr_id);
        }
        if let Some(address) = &announce.address {
            append_source_addr(msg, address);
        }
    })
    .await
    .map(|_| ())
}
/// Sends a `REMOVE` command carrying `token` and `addr_id` (as `LOC_ID`).
///
/// NOTE(review): the attributes come from `mptcp_attr`, not `mptcp_pm_attr`;
/// confirm these ids are what the kernel expects for this command.
///
/// # Errors
///
/// Fails if the request cannot be sent or the reply carries an error.
#[tracing::instrument(level = "debug", skip_all, fields(method = "del_addr"))]
pub async fn del_addr(&self, token: u32, addr_id: u8) -> Result<()> {
    use crate::netlink::types::mptcp::mptcp_attr;
    self.mptcp_command(mptcp_pm_cmd::REMOVE, |msg| {
        msg.append_attr_u32(mptcp_attr::TOKEN, token);
        msg.append_attr_u8(mptcp_attr::LOC_ID, addr_id);
    })
    .await
    .map(|_| ())
}
/// Sends a single acknowledged command and returns the raw reply buffer.
///
/// The request is flagged `NLM_F_REQUEST | NLM_F_ACK`, starts with a GENL
/// header for `cmd`, and is followed by whatever `build_attrs` appends.
/// Exactly one buffer is read back and checked with
/// `process_genl_response` before being returned.
///
/// # Errors
///
/// Fails on send/receive errors or when the reply contains a non-ACK error
/// for this request's sequence number.
async fn mptcp_command(
    &self,
    cmd: u8,
    build_attrs: impl FnOnce(&mut MessageBuilder),
) -> Result<Vec<u8>> {
    let mut request = MessageBuilder::new(self.state().family_id, NLM_F_REQUEST | NLM_F_ACK);
    request.append(&GenlMsgHdr::new(cmd, MPTCP_PM_GENL_VERSION));
    build_attrs(&mut request);

    let seq = self.socket().next_seq();
    request.set_seq(seq);
    request.set_pid(self.socket().pid());

    let datagram = request.finish();
    self.socket().send(&datagram).await?;

    let reply: Vec<u8> = self.socket().recv_msg().await?;
    self.process_genl_response(&reply, seq)?;
    Ok(reply)
}
/// Sends a request without `NLM_F_ACK` and returns the first data payload.
///
/// One buffer is read back. Messages with a different sequence number,
/// ACKs and "done" markers are skipped; the payload of the first remaining
/// message is returned. If no such message exists, an empty vector is
/// returned.
///
/// # Errors
///
/// Fails on send/receive errors, on malformed messages, or when the reply
/// contains a non-ACK error for this request.
async fn mptcp_query(
    &self,
    cmd: u8,
    build_attrs: impl FnOnce(&mut MessageBuilder),
) -> Result<Vec<u8>> {
    let mut request = MessageBuilder::new(self.state().family_id, NLM_F_REQUEST);
    request.append(&GenlMsgHdr::new(cmd, MPTCP_PM_GENL_VERSION));
    build_attrs(&mut request);

    let seq = self.socket().next_seq();
    request.set_seq(seq);
    request.set_pid(self.socket().pid());

    let datagram = request.finish();
    self.socket().send(&datagram).await?;

    let reply: Vec<u8> = self.socket().recv_msg().await?;
    for entry in MessageIter::new(&reply) {
        let (header, payload) = entry?;
        if header.nlmsg_seq != seq {
            continue;
        }
        if header.is_error() {
            let err = NlMsgError::from_bytes(payload)?;
            if err.is_ack() {
                continue;
            }
            return Err(Error::from_errno(err.error));
        }
        if header.is_done() {
            continue;
        }
        return Ok(payload.to_vec());
    }
    Ok(Vec::new())
}
/// Sends a dump request and collects every data payload until "done".
///
/// The request is flagged `NLM_F_REQUEST | NLM_F_DUMP`. Buffers are read
/// repeatedly; within each, messages with a different sequence number and
/// ACKs are skipped, and every other payload is collected. Reading stops at
/// the first "done" message for this request.
///
/// # Errors
///
/// Fails on send/receive errors, on malformed messages, or when a non-ACK
/// error for this request is received.
async fn dump_mptcp_command(
    &self,
    cmd: u8,
    build_attrs: impl FnOnce(&mut MessageBuilder),
) -> Result<Vec<Vec<u8>>> {
    let mut request = MessageBuilder::new(self.state().family_id, NLM_F_REQUEST | NLM_F_DUMP);
    request.append(&GenlMsgHdr::new(cmd, MPTCP_PM_GENL_VERSION));
    build_attrs(&mut request);

    let seq = self.socket().next_seq();
    request.set_seq(seq);
    request.set_pid(self.socket().pid());

    let datagram = request.finish();
    self.socket().send(&datagram).await?;

    let mut collected = Vec::new();
    'receive: loop {
        let chunk: Vec<u8> = self.socket().recv_msg().await?;
        for entry in MessageIter::new(&chunk) {
            let (header, payload) = entry?;
            if header.nlmsg_seq != seq {
                continue;
            }
            if header.is_error() {
                let err = NlMsgError::from_bytes(payload)?;
                if err.is_ack() {
                    continue;
                }
                return Err(Error::from_errno(err.error));
            }
            if header.is_done() {
                // End of the dump: stop reading further buffers.
                break 'receive;
            }
            collected.push(payload.to_vec());
        }
    }
    Ok(collected)
}
/// Scans `data` for an error reply to the request numbered `seq`.
///
/// Messages with another sequence number and non-error messages are
/// ignored, as are ACKs.
///
/// # Errors
///
/// Fails on a malformed message or on the first non-ACK error message
/// matching `seq`.
fn process_genl_response(&self, data: &[u8], seq: u32) -> Result<()> {
    for entry in MessageIter::new(data) {
        let (header, payload) = entry?;
        if header.nlmsg_seq != seq || !header.is_error() {
            continue;
        }
        let err = NlMsgError::from_bytes(payload)?;
        if !err.is_ack() {
            return Err(Error::from_errno(err.error));
        }
    }
    Ok(())
}
}
async fn resolve_mptcp_family(socket: &NetlinkSocket) -> Result<u16> {
let mut builder = MessageBuilder::new(GENL_ID_CTRL, NLM_F_REQUEST);
let genl_hdr = GenlMsgHdr::new(CtrlCmd::GetFamily as u8, 1);
builder.append(&genl_hdr);
builder.append_attr_str(CtrlAttr::FamilyName as u16, MPTCP_PM_GENL_NAME);
let seq = socket.next_seq();
builder.set_seq(seq);
builder.set_pid(socket.pid());
let msg = builder.finish();
socket.send(&msg).await?;
let response: Vec<u8> = socket.recv_msg().await?;
for result in MessageIter::new(&response) {
let (header, payload) = result?;
if header.nlmsg_seq != seq {
continue;
}
if header.is_error() {
let err = NlMsgError::from_bytes(payload)?;
if !err.is_ack() {
if err.error == -libc::ENOENT {
return Err(Error::FamilyNotFound {
name: MPTCP_PM_GENL_NAME.to_string(),
});
}
return Err(Error::from_errno(err.error));
}
continue;
}
if header.is_done() {
continue;
}
if payload.len() < GENL_HDRLEN {
return Err(Error::InvalidMessage("GENL header too short".into()));
}
let attrs_data = &payload[GENL_HDRLEN..];
for (attr_type, attr_payload) in AttrIter::new(attrs_data) {
if attr_type == CtrlAttr::FamilyId as u16 {
return get::u16_ne(attr_payload);
}
}
}
Err(Error::FamilyNotFound {
name: MPTCP_PM_GENL_NAME.to_string(),
})
}
/// Appends the attributes describing `endpoint` to `builder`.
///
/// Attributes are written in this order: `FAMILY`, then `ADDR4` or `ADDR6`,
/// then — only when present — `ID`, `PORT` and `IF_IDX`, and finally `FLAGS`
/// when the raw flag value is non-zero. The caller is responsible for
/// opening and closing the surrounding nested attribute.
fn append_endpoint_attrs(builder: &mut MessageBuilder, endpoint: &MptcpEndpointBuilder) {
    match endpoint.address {
        IpAddr::V4(addr) => {
            let family = libc::AF_INET as u16;
            builder.append_attr(mptcp_pm_addr_attr::FAMILY, &family.to_ne_bytes());
            builder.append_attr(mptcp_pm_addr_attr::ADDR4, &addr.octets());
        }
        IpAddr::V6(addr) => {
            let family = libc::AF_INET6 as u16;
            builder.append_attr(mptcp_pm_addr_attr::FAMILY, &family.to_ne_bytes());
            builder.append_attr(mptcp_pm_addr_attr::ADDR6, &addr.octets());
        }
    }
    if let Some(id) = endpoint.id {
        builder.append_attr_u8(mptcp_pm_addr_attr::ID, id);
    }
    if let Some(port) = endpoint.port {
        // The kernel reads this attribute as a host-order u16 and converts
        // it to network order itself, so it must not be pre-swapped here.
        // Writing big-endian bytes would swap the port on little-endian hosts.
        builder.append_attr(mptcp_pm_addr_attr::PORT, &port.to_ne_bytes());
    }
    if let Some(ifindex) = endpoint.ifindex {
        builder.append_attr_u32(mptcp_pm_addr_attr::IF_IDX, ifindex);
    }
    let flags = endpoint.flags.to_raw();
    if flags != 0 {
        builder.append_attr_u32(mptcp_pm_addr_attr::FLAGS, flags);
    }
}
/// Extracts the endpoint from one message payload, if it has one.
///
/// The first `GENL_HDRLEN` bytes are skipped and the remaining attributes
/// are searched for the first `ADDR` attribute, whose content is handed to
/// `parse_endpoint_attrs`. Returns `Ok(None)` when the payload is shorter
/// than the header or contains no `ADDR` attribute.
///
/// NOTE(review): the match is on the bare `ADDR` id, so this presumably
/// relies on `AttrIter` stripping flag bits such as `NLA_F_NESTED` from the
/// attribute type — confirm.
///
/// # Errors
///
/// Propagates any error from `parse_endpoint_attrs`.
fn parse_endpoint_response(payload: &[u8]) -> Result<Option<MptcpEndpoint>> {
    let Some(attrs) = payload.get(GENL_HDRLEN..) else {
        return Ok(None);
    };
    AttrIter::new(attrs)
        .find(|(kind, _)| *kind == mptcp_pm_attr::ADDR)
        .map(|(_, nested)| parse_endpoint_attrs(nested))
        .transpose()
}
/// Builds an [`MptcpEndpoint`] from the attributes in `data`.
///
/// Starts from `MptcpEndpoint::default()` and overwrites a field for each
/// recognised attribute whose payload is long enough. Unknown or too-short
/// attributes are ignored. A port or interface index of `0` is treated as
/// absent and leaves the corresponding field untouched.
///
/// The current implementation never returns an error; the `Result` is kept
/// for callers that propagate it.
fn parse_endpoint_attrs(data: &[u8]) -> Result<MptcpEndpoint> {
    let mut endpoint = MptcpEndpoint::default();
    for (attr_type, payload) in AttrIter::new(data) {
        // Each arm's guard checks the payload length, so the fixed-size
        // `try_into().unwrap()` conversions below cannot fail.
        match attr_type {
            t if t == mptcp_pm_addr_attr::ID && !payload.is_empty() => {
                endpoint.id = payload[0];
            }
            t if t == mptcp_pm_addr_attr::ADDR4 && payload.len() >= 4 => {
                let octets: [u8; 4] = payload[..4].try_into().unwrap();
                endpoint.address = IpAddr::V4(octets.into());
            }
            t if t == mptcp_pm_addr_attr::ADDR6 && payload.len() >= 16 => {
                let octets: [u8; 16] = payload[..16].try_into().unwrap();
                endpoint.address = IpAddr::V6(octets.into());
            }
            t if t == mptcp_pm_addr_attr::PORT && payload.len() >= 2 => {
                // The kernel emits the port as a host-order u16, converting
                // from network order before writing it. Decoding it as
                // big-endian would swap the port on little-endian hosts.
                let port = u16::from_ne_bytes(payload[..2].try_into().unwrap());
                if port != 0 {
                    endpoint.port = Some(port);
                }
            }
            t if t == mptcp_pm_addr_attr::FLAGS && payload.len() >= 4 => {
                let flags = u32::from_ne_bytes(payload[..4].try_into().unwrap());
                endpoint.flags = MptcpFlags::from_raw(flags);
            }
            t if t == mptcp_pm_addr_attr::IF_IDX && payload.len() >= 4 => {
                let ifindex = u32::from_ne_bytes(payload[..4].try_into().unwrap());
                if ifindex != 0 {
                    endpoint.ifindex = Some(ifindex);
                }
            }
            _ => {}
        }
    }
    Ok(endpoint)
}
/// Extracts the limits from one message payload, if it carries any.
///
/// The first `GENL_HDRLEN` bytes are skipped. `SUBFLOWS` and `RCV_ADD_ADDRS`
/// attributes with at least four payload bytes are decoded as native-endian
/// `u32` values. Returns `Ok(None)` when the payload is shorter than the
/// header or when neither attribute was decoded.
fn parse_limits_response(payload: &[u8]) -> Result<Option<MptcpLimits>> {
    let Some(attrs) = payload.get(GENL_HDRLEN..) else {
        return Ok(None);
    };
    let mut limits = MptcpLimits::default();
    let mut seen_any = false;
    for (kind, value) in AttrIter::new(attrs) {
        // Attributes with fewer than four bytes cannot hold a u32; skip them.
        let &[b0, b1, b2, b3, ..] = value else {
            continue;
        };
        let number = u32::from_ne_bytes([b0, b1, b2, b3]);
        if kind == mptcp_pm_attr::SUBFLOWS {
            limits.subflows = Some(number);
            seen_any = true;
        } else if kind == mptcp_pm_attr::RCV_ADD_ADDRS {
            limits.add_addr_accepted = Some(number);
            seen_any = true;
        }
    }
    Ok(seen_any.then_some(limits))
}
/// Appends `addr` to `builder` as source-address attributes.
///
/// Writes `FAMILY` (native-endian `u16`), then `SADDR4` or `SADDR6` with the
/// raw address octets, and — when a port is present — `SPORT` as two
/// big-endian bytes.
fn append_source_addr(builder: &mut MessageBuilder, addr: &super::types::MptcpAddress) {
    use crate::netlink::types::mptcp::mptcp_attr;
    match addr.addr {
        IpAddr::V4(v4) => {
            let family = libc::AF_INET as u16;
            builder.append_attr(mptcp_attr::FAMILY, &family.to_ne_bytes());
            builder.append_attr(mptcp_attr::SADDR4, &v4.octets());
        }
        IpAddr::V6(v6) => {
            let family = libc::AF_INET6 as u16;
            builder.append_attr(mptcp_attr::FAMILY, &family.to_ne_bytes());
            builder.append_attr(mptcp_attr::SADDR6, &v6.octets());
        }
    }
    if let Some(source_port) = addr.port {
        builder.append_attr(mptcp_attr::SPORT, &source_port.to_be_bytes());
    }
}
/// Appends `addr` to `builder` as destination-address attributes.
///
/// Writes `DADDR4` or `DADDR6` with the raw address octets and — when a port
/// is present — `DPORT` as two big-endian bytes. No `FAMILY` attribute is
/// written by this function.
fn append_dest_addr(builder: &mut MessageBuilder, addr: &super::types::MptcpAddress) {
    use crate::netlink::types::mptcp::mptcp_attr;
    match addr.addr {
        IpAddr::V4(v4) => builder.append_attr(mptcp_attr::DADDR4, &v4.octets()),
        IpAddr::V6(v6) => builder.append_attr(mptcp_attr::DADDR6, &v6.octets()),
    };
    if let Some(dest_port) = addr.port {
        builder.append_attr(mptcp_attr::DPORT, &dest_port.to_be_bytes());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty payload produces no endpoint and no error.
    #[test]
    fn test_parse_empty_payload() {
        assert!(matches!(parse_endpoint_response(&[]), Ok(None)));
    }

    /// An empty payload produces no limits and no error.
    #[test]
    fn test_parse_limits_empty() {
        assert!(matches!(parse_limits_response(&[]), Ok(None)));
    }
}