use crate::io::Connection;
use crate::node::Creation;
#[cfg(doc)]
use crate::node::NodeName;
use crate::{HIGHEST_DISTRIBUTION_PROTOCOL_VERSION, LOWEST_DISTRIBUTION_PROTOCOL_VERSION};
use futures::io::{AsyncRead, AsyncWrite};
use std::str::FromStr;
/// Default port number of EPMD.
pub const DEFAULT_EPMD_PORT: u16 = 4369;

// EPMD message tags. They are written as byte literals; the trailing comment gives the
// equivalent decimal tag value.
const TAG_DUMP_REQ: u8 = b'd'; // 100
const TAG_KILL_REQ: u8 = b'k'; // 107
const TAG_NAMES_REQ: u8 = b'n'; // 110
const TAG_ALIVE2_X_RESP: u8 = b'v'; // 118
const TAG_PORT2_RESP: u8 = b'w'; // 119
const TAG_ALIVE2_REQ: u8 = b'x'; // 120
const TAG_ALIVE2_RESP: u8 = b'y'; // 121
const TAG_PORT_PLEASE2_REQ: u8 = b'z'; // 122
/// An EPMD node entry: the data sent in an `ALIVE2_REQ` and read back from a `PORT2_RESP`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NodeEntry {
/// Name of the node.
pub name: String,
/// Port number registered for the node.
pub port: u16,
/// Node type (normal, hidden, or another raw wire value).
pub node_type: NodeType,
/// Transport protocol of the node.
pub protocol: TransportProtocol,
/// Highest distribution protocol version the node supports.
pub highest_version: u16,
/// Lowest distribution protocol version the node supports.
pub lowest_version: u16,
/// Extra bytes carried with the entry; they are written and read without interpretation.
pub extra: Vec<u8>,
}
impl NodeEntry {
    /// Builds an entry for a normal (visible) node named `name` on `port`.
    ///
    /// The entry uses TCP/IPv4, this crate's supported distribution protocol version
    /// range, and carries no extra data.
    pub fn new(name: &str, port: u16) -> Self {
        Self {
            name: name.to_owned(),
            port,
            node_type: NodeType::Normal,
            protocol: TransportProtocol::TcpIpV4,
            highest_version: HIGHEST_DISTRIBUTION_PROTOCOL_VERSION,
            lowest_version: LOWEST_DISTRIBUTION_PROTOCOL_VERSION,
            extra: Vec::new(),
        }
    }

    /// Same as [`NodeEntry::new`], except that the node is marked as hidden.
    pub fn new_hidden(name: &str, port: u16) -> Self {
        Self {
            node_type: NodeType::Hidden,
            ..Self::new(name, port)
        }
    }

    /// Encoded size of this entry inside an `ALIVE2_REQ` body, not counting the tag byte.
    fn bytes_len(&self) -> usize {
        const PORT: usize = 2;
        const NODE_TYPE: usize = 1;
        const PROTOCOL: usize = 1;
        const HIGHEST_VERSION: usize = 2;
        const LOWEST_VERSION: usize = 2;
        const NAME_LEN_PREFIX: usize = 2;
        const EXTRA_LEN_PREFIX: usize = 2;

        let fixed = PORT + NODE_TYPE + PROTOCOL + HIGHEST_VERSION + LOWEST_VERSION;
        let name = NAME_LEN_PREFIX + self.name.len();
        let extra = EXTRA_LEN_PREFIX + self.extra.len();
        fixed + name + extra
    }
}
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
#[allow(missing_docs)]
pub enum EpmdError {
#[error("received an unknown tag {tag} as the response of {request}")]
UnknownResponseTag { request: &'static str, tag: u8 },
#[error("request byte size must be less than 0xFFFF, but got {size} bytes")]
TooLongRequest { size: usize },
#[error("EPMD responded an error code {code} against a PORT_PLEASE2_REQ request")]
GetNodeEntryError { code: u8 },
#[error("EPMD responded an error code {code} against an ALIVE2_REQ request")]
RegisterNodeError { code: u8 },
#[error("found a malformed NAMES_RESP line: expected_format=\"name {{NAME}} at port {{PORT}}\", actual_line={line:?}")]
MalformedNamesResponse { line: String },
#[error(transparent)]
Io(#[from] std::io::Error),
}
/// Client for EPMD (Erlang Port Mapper Daemon) that communicates over a single stream `T`.
///
/// Every request method takes `self` by value, so one client instance issues one request.
#[derive(Debug)]
pub struct EpmdClient<T> {
connection: Connection<T>,
}
impl<T> EpmdClient<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
pub fn new(connection: T) -> Self {
Self {
connection: Connection::new(connection),
}
}
pub async fn register(mut self, node: NodeEntry) -> Result<(T, Creation), EpmdError> {
let size = 1 + node.bytes_len();
let size = u16::try_from(size).map_err(|_| EpmdError::TooLongRequest { size })?;
self.connection.write_u16(size).await?;
self.connection.write_u8(TAG_ALIVE2_REQ).await?;
self.connection.write_u16(node.port).await?;
self.connection.write_u8(node.node_type.into()).await?;
self.connection.write_u8(node.protocol.into()).await?;
self.connection.write_u16(node.highest_version).await?;
self.connection.write_u16(node.lowest_version).await?;
self.connection.write_u16(node.name.len() as u16).await?;
self.connection.write_all(node.name.as_bytes()).await?;
self.connection.write_u16(node.extra.len() as u16).await?;
self.connection.write_all(&node.extra).await?;
self.connection.flush().await?;
match self.connection.read_u8().await? {
TAG_ALIVE2_RESP => {
match self.connection.read_u8().await? {
0 => {}
code => return Err(EpmdError::RegisterNodeError { code }),
}
let creation = Creation::new(u32::from(self.connection.read_u16().await?));
Ok((self.connection.into_inner(), creation))
}
TAG_ALIVE2_X_RESP => {
match self.connection.read_u8().await? {
0 => {}
code => return Err(EpmdError::RegisterNodeError { code }),
}
let creation = Creation::new(self.connection.read_u32().await?);
Ok((self.connection.into_inner(), creation))
}
tag => Err(EpmdError::UnknownResponseTag {
request: "ALIVE2_REQ",
tag,
}),
}
}
pub async fn get_names(mut self) -> Result<Vec<(String, u16)>, EpmdError> {
self.connection.write_u16(1).await?; self.connection.write_u8(TAG_NAMES_REQ).await?;
self.connection.flush().await?;
let _epmd_port = self.connection.read_u32().await?;
let node_info_text = self.connection.read_string().await?;
node_info_text
.split('\n')
.filter(|s| !s.is_empty())
.map(|line| NodeNameAndPort::from_str(line).map(|x| (x.name, x.port)))
.collect()
}
pub async fn get_node(mut self, node_name: &str) -> Result<Option<NodeEntry>, EpmdError> {
let size = 1 + node_name.len();
let size = u16::try_from(size).map_err(|_| EpmdError::TooLongRequest { size })?;
self.connection.write_u16(size).await?;
self.connection.write_u8(TAG_PORT_PLEASE2_REQ).await?;
self.connection.write_all(node_name.as_bytes()).await?;
self.connection.flush().await?;
let tag = self.connection.read_u8().await?;
if tag != TAG_PORT2_RESP {
return Err(EpmdError::UnknownResponseTag {
request: "NAMES_REQ",
tag,
});
}
match self.connection.read_u8().await? {
0 => {}
1 => {
return Ok(None);
}
code => {
return Err(EpmdError::GetNodeEntryError { code });
}
}
Ok(Some(NodeEntry {
port: self.connection.read_u16().await?,
node_type: self.connection.read_u8().await?.into(),
protocol: self.connection.read_u8().await?.into(),
highest_version: self.connection.read_u16().await?,
lowest_version: self.connection.read_u16().await?,
name: self.connection.read_u16_string().await?,
extra: self.connection.read_u16_bytes().await?,
}))
}
pub async fn kill(mut self) -> Result<String, EpmdError> {
self.connection.write_u16(1).await?;
self.connection.write_u8(TAG_KILL_REQ).await?;
self.connection.flush().await?;
let result = self.connection.read_string().await?;
Ok(result)
}
pub async fn dump(mut self) -> Result<String, EpmdError> {
self.connection.write_u16(1).await?;
self.connection.write_u8(TAG_DUMP_REQ).await?;
self.connection.flush().await?;
let _epmd_port = self.connection.read_u32().await?;
let info = self.connection.read_string().await?;
Ok(info)
}
}
/// One parsed line of a `NAMES_REQ` reply, formatted as `name {NAME} at port {PORT}`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
struct NodeNameAndPort {
    /// Node name (the text between `name ` and the first ` at port `).
    name: String,
    /// Port number following ` at port `.
    port: u16,
}

impl FromStr for NodeNameAndPort {
    type Err = EpmdError;

    /// Parses a single reply line.
    ///
    /// The line must start with `name `. The remainder is split at the first ` at port `,
    /// and the text after it must parse as a `u16`. Any violation yields
    /// `EpmdError::MalformedNamesResponse` carrying the full input line.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let malformed = || EpmdError::MalformedNamesResponse { line: s.to_owned() };
        let rest = s.strip_prefix("name ").ok_or_else(malformed)?;
        let (name, port_text) = rest.split_once(" at port ").ok_or_else(malformed)?;
        let port = port_text.parse::<u16>().map_err(|_| malformed())?;
        Ok(Self {
            name: name.to_owned(),
            port,
        })
    }
}
/// Transport protocol advertised for a node in EPMD.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum TransportProtocol {
    /// TCP over IPv4 (wire value `0`).
    TcpIpV4,
    /// Any other wire value, kept verbatim.
    Other(u8),
}

impl TransportProtocol {
    /// Wire value of [`TransportProtocol::TcpIpV4`].
    const TCP_IP_V4_CODE: u8 = 0;
}

impl From<u8> for TransportProtocol {
    fn from(code: u8) -> Self {
        match code {
            TransportProtocol::TCP_IP_V4_CODE => TransportProtocol::TcpIpV4,
            other => TransportProtocol::Other(other),
        }
    }
}

impl From<TransportProtocol> for u8 {
    fn from(protocol: TransportProtocol) -> Self {
        match protocol {
            TransportProtocol::TcpIpV4 => TransportProtocol::TCP_IP_V4_CODE,
            TransportProtocol::Other(code) => code,
        }
    }
}
/// Type of a node as advertised in EPMD.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum NodeType {
    /// Hidden node (wire value `72`, i.e. `b'H'`).
    Hidden,
    /// Normal node (wire value `77`, i.e. `b'M'`).
    Normal,
    /// Any other wire value, kept verbatim.
    Other(u8),
}

impl NodeType {
    /// Wire value of [`NodeType::Hidden`] (72).
    const HIDDEN_CODE: u8 = b'H';
    /// Wire value of [`NodeType::Normal`] (77).
    const NORMAL_CODE: u8 = b'M';
}

impl From<u8> for NodeType {
    fn from(code: u8) -> Self {
        match code {
            NodeType::HIDDEN_CODE => NodeType::Hidden,
            NodeType::NORMAL_CODE => NodeType::Normal,
            other => NodeType::Other(other),
        }
    }
}

impl From<NodeType> for u8 {
    fn from(node_type: NodeType) -> Self {
        match node_type {
            NodeType::Hidden => NodeType::HIDDEN_CODE,
            NodeType::Normal => NodeType::NORMAL_CODE,
            NodeType::Other(code) => code,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check against a live EPMD.
    ///
    /// The test looks up a spawned Erlang node and registers a new hidden node. It then
    /// checks that the new node is listed while its registration stream is open and gone
    /// after the stream is dropped.
    #[test]
    fn epmd_client_works() {
        let node_name = "epmd_client_works";
        smol::block_on(async {
            let erl_node = crate::tests::TestErlangNode::new(node_name)
                .await
                .expect("failed to run a test erlang node");

            // The spawned node must be visible through PORT_PLEASE2_REQ.
            let node = crate::tests::epmd_client()
                .await
                .get_node(node_name)
                .await
                .expect("failed to get node");
            let node = node.expect("no such node");
            assert_eq!(node.name, node_name);

            // Register a new node; `stream` keeps the registration alive.
            let client = crate::tests::epmd_client().await;
            let new_node_name = "erl_dist_test_new_node";
            let new_node = NodeEntry::new_hidden(new_node_name, 3000);
            let (stream, _creation) = client
                .register(new_node)
                .await
                .expect("failed to register a new node");

            let node = crate::tests::epmd_client()
                .await
                .get_node(new_node_name)
                .await
                .expect("failed to get node");
            let node = node.expect("no such node");
            assert_eq!(node.name, new_node_name);

            // Closing the registration stream should unregister the node. Give EPMD a
            // moment to notice, using an async timer instead of `std::thread::sleep` so
            // the `block_on` thread is not blocked.
            std::mem::drop(stream);
            smol::Timer::after(std::time::Duration::from_millis(100)).await;

            let node = crate::tests::epmd_client()
                .await
                .get_node(new_node_name)
                .await
                .expect("failed to get node");
            assert!(node.is_none());

            std::mem::drop(erl_node);
        });
    }
}