use std::time::Duration;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
#[cfg(unix)]
use tokio::net::{UnixListener, UnixStream};
use super::cross_mob_remote::{RemoteEndpoint, RemoteMobError};
/// Largest frame body, in bytes, that `read_frame` will accept (64 KiB).
const MAX_CONTROL_PAYLOAD: u32 = 1 << 16;
/// Default deadline for a full control request/response exchange (5 seconds).
pub const DEFAULT_CONTROL_TIMEOUT: Duration = Duration::from_millis(5_000);
/// A control-channel request; `RemoteControlClient::send` encodes it as the JSON body of one
/// frame.
///
/// Serialized with an internal `"op"` tag whose value is the snake_case variant name
/// (`wire`, `unwire`, `inject`, `lookup_member`). Field declaration order determines the JSON
/// key order on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum ControlRequest {
    /// Wire mob member `remote_member` to the external peer described by the `local_*` fields.
    /// `MobHandleControlHandler` builds a `TrustedPeerDescriptor` from them and calls
    /// `MobHandle::wire`.
    Wire {
        /// Member of the receiving mob; converted to a `MeerkatId` by the handler.
        remote_member: String,
        /// Address placed in the peer descriptor.
        local_peer_spec_address: String,
        /// Comms name placed in the peer descriptor.
        local_comms_name: String,
        /// Peer id placed in the peer descriptor.
        local_peer_id: String,
        /// Optional base64-encoded public key. Omitted from the JSON when `None`, defaults to
        /// `None` when absent, and an empty string is treated as absent by the handler.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        local_pubkey_b64: Option<String>,
    },
    /// Inverse of [`ControlRequest::Wire`]; same fields, handled via `MobHandle::unwire`.
    Unwire {
        remote_member: String,
        local_peer_spec_address: String,
        local_comms_name: String,
        local_peer_id: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        local_pubkey_b64: Option<String>,
    },
    /// Deliver `content` to `remote_member`. `MobHandleControlHandler` decodes `content` as a
    /// `meerkat_core::ContentInput` and sends it with `HandlingMode::Queue`.
    Inject {
        remote_member: String,
        content: serde_json::Value,
    },
    /// Ask for `remote_member`'s peer id and comms name (answered with `ControlResponse::Member`).
    LookupMember { remote_member: String },
}
/// Reply to a [`ControlRequest`], sent back as the JSON body of one frame.
///
/// Serialized with an internal `"result"` tag whose value is the snake_case variant name
/// (`ok`, `injected`, `member`, `err`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "result", rename_all = "snake_case")]
pub enum ControlResponse {
    /// Success with no payload; `MobHandleControlHandler` returns it for a completed
    /// `Wire`/`Unwire`.
    Ok,
    /// Successful `Inject`; `session_id` is the member's bridge session id.
    Injected { session_id: String },
    /// Successful `LookupMember`. `MobHandleControlHandler` formats `comms_name` as
    /// `{mob_id}/{role}/{member}`.
    Member { peer_id: String, comms_name: String },
    /// In-band failure. Codes produced in this module: `decode`, `peer_spec`, `mob_error`,
    /// `unknown_member`, `no_session`, `no_comms`.
    Err { code: String, message: String },
}
/// A connected control-channel socket: TCP, or (on Unix) a Unix-domain stream.
///
/// Every message is one frame: a 4-byte big-endian length followed by that many payload bytes
/// (JSON-encoded requests/responses in this module).
enum ControlStream {
    Tcp(TcpStream),
    #[cfg(unix)]
    Uds(UnixStream),
}

impl ControlStream {
    /// Writes `payload` as one length-prefixed frame and flushes it.
    ///
    /// # Errors
    /// `InvalidInput` if `payload` is larger than [`MAX_CONTROL_PAYLOAD`]. The peer's
    /// [`ControlStream::read_frame`] would reject such a frame anyway, so failing here gives the
    /// caller a clear error instead of a reset connection. Otherwise, any socket I/O error.
    async fn write_frame(&mut self, payload: &[u8]) -> Result<(), std::io::Error> {
        let len = u32::try_from(payload.len())
            .ok()
            .filter(|len| *len <= MAX_CONTROL_PAYLOAD)
            .ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!(
                        "payload too large: {} bytes (max {MAX_CONTROL_PAYLOAD})",
                        payload.len()
                    ),
                )
            })?;
        // Send header and body in a single write. Two separate writes let Nagle's algorithm
        // hold the body back until the peer ACKs the 4-byte header, which a delayed-ACK peer
        // may postpone by tens of milliseconds per exchange.
        let mut frame = Vec::with_capacity(4 + payload.len());
        frame.extend_from_slice(&len.to_be_bytes());
        frame.extend_from_slice(payload);
        match self {
            Self::Tcp(s) => {
                s.write_all(&frame).await?;
                s.flush().await
            }
            #[cfg(unix)]
            Self::Uds(s) => {
                s.write_all(&frame).await?;
                s.flush().await
            }
        }
    }

    /// Reads one length-prefixed frame and returns its payload.
    ///
    /// # Errors
    /// `InvalidData` if the advertised length exceeds [`MAX_CONTROL_PAYLOAD`]. The check runs
    /// before any buffer is allocated. Also returns socket I/O errors, including
    /// `UnexpectedEof` if the peer closes mid-frame.
    async fn read_frame(&mut self) -> Result<Vec<u8>, std::io::Error> {
        let mut header = [0u8; 4];
        match self {
            Self::Tcp(s) => s.read_exact(&mut header).await?,
            #[cfg(unix)]
            Self::Uds(s) => s.read_exact(&mut header).await?,
        };
        let len = u32::from_be_bytes(header);
        if len > MAX_CONTROL_PAYLOAD {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("frame too large: {len} bytes"),
            ));
        }
        // `len` is bounded by MAX_CONTROL_PAYLOAD here, so the cast cannot truncate.
        let mut buf = vec![0u8; len as usize];
        match self {
            Self::Tcp(s) => s.read_exact(&mut buf).await?,
            #[cfg(unix)]
            Self::Uds(s) => s.read_exact(&mut buf).await?,
        };
        Ok(buf)
    }
}
pub struct RemoteControlClient;
impl RemoteControlClient {
pub async fn send(
endpoint: &RemoteEndpoint,
request: &ControlRequest,
timeout: Duration,
) -> Result<ControlResponse, RemoteMobError> {
tokio::time::timeout(timeout, Self::send_inner(endpoint, request))
.await
.map_err(|_| RemoteMobError::ControlChannelUnavailable {
mob_id: String::new(),
endpoint: endpoint.comms_address(),
operation: "timeout",
})?
}
async fn send_inner(
endpoint: &RemoteEndpoint,
request: &ControlRequest,
) -> Result<ControlResponse, RemoteMobError> {
let mut stream = match endpoint {
RemoteEndpoint::Tcp(addr) => ControlStream::Tcp(
TcpStream::connect(addr)
.await
.map_err(|err| io_error("connect", endpoint, err))?,
),
#[cfg(unix)]
RemoteEndpoint::Uds(path) => ControlStream::Uds(
UnixStream::connect(std::path::Path::new(path))
.await
.map_err(|err| io_error("connect", endpoint, err))?,
),
#[cfg(not(unix))]
RemoteEndpoint::Uds(_) => {
return Err(RemoteMobError::UnsupportedTransport {
mob_id: String::new(),
transport: endpoint.comms_address(),
});
}
};
let payload =
serde_json::to_vec(request).map_err(|err| encode_error(endpoint, err.to_string()))?;
stream
.write_frame(&payload)
.await
.map_err(|err| io_error("write", endpoint, err))?;
let response_payload = stream
.read_frame()
.await
.map_err(|err| io_error("read", endpoint, err))?;
serde_json::from_slice::<ControlResponse>(&response_payload)
.map_err(|err| decode_error(endpoint, err.to_string()))
}
}
/// Converts an I/O failure during `stage` of a control exchange into
/// `RemoteMobError::ControlChannelUnavailable`, attaching the error text as context.
///
/// `"connect"`, `"write"` and `"read"` are reported as-is; any other stage label becomes `"io"`.
fn io_error(stage: &'static str, endpoint: &RemoteEndpoint, err: std::io::Error) -> RemoteMobError {
    let operation = if matches!(stage, "connect" | "write" | "read") {
        stage
    } else {
        "io"
    };
    let detail = err.to_string();
    RemoteMobError::ControlChannelUnavailable {
        mob_id: String::new(),
        endpoint: endpoint.comms_address(),
        operation,
    }
    .with_context(detail)
}
/// Wraps a failure to serialize an outgoing request to `endpoint` as `RemoteMobError::Encode`.
fn encode_error(endpoint: &RemoteEndpoint, message: String) -> RemoteMobError {
    let endpoint = endpoint.comms_address();
    RemoteMobError::Encode { endpoint, message }
}

/// Wraps a failure to deserialize a reply from `endpoint` as `RemoteMobError::Decode`.
fn decode_error(endpoint: &RemoteEndpoint, message: String) -> RemoteMobError {
    let endpoint = endpoint.comms_address();
    RemoteMobError::Decode { endpoint, message }
}
/// Server-side handler invoked once per decoded [`ControlRequest`].
///
/// The future resolves to exactly one [`ControlResponse`]. Failures are reported in-band via
/// [`ControlResponse::Err`] rather than through a `Result`.
pub trait ControlHandler: Send + Sync + 'static {
    /// Handles `request`; the returned future may borrow `self`.
    fn handle(
        &self,
        request: ControlRequest,
    ) -> core::pin::Pin<Box<dyn core::future::Future<Output = ControlResponse> + Send + '_>>;
}
/// [`ControlHandler`] that applies control requests to a mob through its `MobHandle`.
pub struct MobHandleControlHandler {
    handle: meerkat_mob::MobHandle,
}

impl MobHandleControlHandler {
    /// Builds a handler that serves control requests against `mob_handle`.
    pub fn new(mob_handle: meerkat_mob::MobHandle) -> Self {
        Self { handle: mob_handle }
    }
}
impl ControlHandler for MobHandleControlHandler {
    /// Routes each request variant to the matching `handle_*` helper.
    fn handle(
        &self,
        request: ControlRequest,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ControlResponse> + Send + '_>> {
        let mob = self.handle.clone();
        Box::pin(async move {
            // Wire and Unwire carry identical payloads; only the direction differs.
            let attach = matches!(request, ControlRequest::Wire { .. });
            match request {
                ControlRequest::Wire {
                    remote_member,
                    local_peer_spec_address,
                    local_comms_name,
                    local_peer_id,
                    local_pubkey_b64,
                }
                | ControlRequest::Unwire {
                    remote_member,
                    local_peer_spec_address,
                    local_comms_name,
                    local_peer_id,
                    local_pubkey_b64,
                } => {
                    handle_wire(
                        &mob,
                        &remote_member,
                        &local_peer_spec_address,
                        &local_comms_name,
                        &local_peer_id,
                        local_pubkey_b64.as_deref(),
                        attach,
                    )
                    .await
                }
                ControlRequest::Inject {
                    remote_member,
                    content,
                } => handle_inject(&mob, &remote_member, content).await,
                ControlRequest::LookupMember { remote_member } => {
                    handle_lookup_member(&mob, &remote_member).await
                }
            }
        })
    }
}
/// Wires (`wire == true`) or unwires (`wire == false`) mob member `remote_member` to or from
/// the external peer described by the `local_*` arguments.
///
/// Error codes returned: `decode` (the base64 key is invalid), `peer_spec` (the descriptor
/// could not be built), `mob_error` (the mob rejected the wire/unwire).
async fn handle_wire(
    handle: &meerkat_mob::MobHandle,
    remote_member: &str,
    local_peer_spec_address: &str,
    local_comms_name: &str,
    local_peer_id: &str,
    local_pubkey_b64: Option<&str>,
    wire: bool,
) -> ControlResponse {
    use meerkat_core::comms::TrustedPeerDescriptor;

    // An empty key string is treated exactly like an absent key.
    let decoded_key = match local_pubkey_b64.filter(|encoded| !encoded.is_empty()) {
        None => None,
        Some(encoded) => match crate::auth::peer_keys::decode_pubkey_b64(encoded) {
            Ok(key_bytes) => Some(key_bytes),
            Err(err) => {
                return ControlResponse::Err {
                    code: "decode".to_string(),
                    message: format!("local_pubkey_b64: {err}"),
                };
            }
        },
    };

    // NOTE(review): without a key this falls back to `test_only_unsigned`, whose name suggests
    // a test-only constructor. Confirm this path is intended for production control requests.
    let descriptor = if let Some(key_bytes) = decoded_key {
        TrustedPeerDescriptor::unsigned_with_pubkey(
            local_comms_name,
            local_peer_id,
            key_bytes,
            local_peer_spec_address,
        )
    } else {
        TrustedPeerDescriptor::test_only_unsigned(
            local_comms_name,
            local_peer_id,
            local_peer_spec_address,
        )
    };
    let spec = match descriptor {
        Ok(spec) => spec,
        Err(message) => {
            return ControlResponse::Err {
                code: "peer_spec".to_string(),
                message,
            };
        }
    };

    let member_id = meerkat_mob::ids::MeerkatId::from(remote_member);
    let target = meerkat_mob::PeerTarget::External(spec);
    let outcome = if wire {
        handle.wire(member_id, target).await
    } else {
        handle.unwire(member_id, target).await
    };
    match outcome {
        Ok(()) => ControlResponse::Ok,
        Err(err) => ControlResponse::Err {
            code: "mob_error".to_string(),
            message: err.to_string(),
        },
    }
}
/// Decodes `content` as a `ContentInput` and sends it to `remote_member` with
/// `HandlingMode::Queue`. On success, replies with the member's bridge session id.
///
/// Error codes returned: `decode`, `unknown_member`, `mob_error` (the send failed),
/// `no_session` (no bound bridge session after the send).
async fn handle_inject(
    handle: &meerkat_mob::MobHandle,
    remote_member: &str,
    content: serde_json::Value,
) -> ControlResponse {
    let content_input = match serde_json::from_value::<meerkat_core::ContentInput>(content) {
        Ok(parsed) => parsed,
        Err(err) => {
            return ControlResponse::Err {
                code: "decode".to_string(),
                message: format!("content: {err}"),
            };
        }
    };

    let member_id = meerkat_mob::ids::MeerkatId::from(remote_member);
    let member = match handle.member(&member_id).await {
        Ok(found) => found,
        Err(err) => {
            return ControlResponse::Err {
                code: "unknown_member".to_string(),
                message: err.to_string(),
            };
        }
    };

    if let Err(err) = member
        .send(content_input, meerkat_core::types::HandlingMode::Queue)
        .await
    {
        return ControlResponse::Err {
            code: "mob_error".to_string(),
            message: err.to_string(),
        };
    }

    // NOTE(review): the session is resolved only after the content has been queued. A
    // `no_session` reply therefore does not mean nothing was delivered, so confirm callers do
    // not blindly retry on it.
    handle.resolve_bridge_session_id(&member_id).await.map_or_else(
        || ControlResponse::Err {
            code: "no_session".to_string(),
            message: format!("member '{remote_member}' has no bound bridge session"),
        },
        |sid| ControlResponse::Injected {
            session_id: sid.to_string(),
        },
    )
}
/// Reports `remote_member`'s comms peer id and its comms name, formatted as
/// `{mob_id}/{role}/{member}`.
///
/// Error codes returned: `unknown_member` (the member is not in this mob), `no_comms` (the
/// member has no peer id).
async fn handle_lookup_member(
    handle: &meerkat_mob::MobHandle,
    remote_member: &str,
) -> ControlResponse {
    let member_id = meerkat_mob::ids::MeerkatId::from(remote_member);
    let mob_id = handle.mob_id().to_string();

    let Some(entry) = handle.get_member(&member_id).await else {
        return ControlResponse::Err {
            code: "unknown_member".to_string(),
            message: format!("member '{remote_member}' not in mob '{mob_id}'"),
        };
    };
    let Some(peer_id) = entry.peer_id() else {
        return ControlResponse::Err {
            code: "no_comms".to_string(),
            message: format!("member '{remote_member}' has no comms runtime"),
        };
    };

    ControlResponse::Member {
        peer_id: peer_id.to_string(),
        comms_name: format!("{mob_id}/{}/{remote_member}", entry.role),
    }
}
pub async fn serve_tcp_control(listener: TcpListener, handler: std::sync::Arc<dyn ControlHandler>) {
loop {
let (stream, _peer_addr) = match listener.accept().await {
Ok(pair) => pair,
Err(err) => {
tracing::warn!(error = %err, "control listener accept failed; exiting");
return;
}
};
let handler = handler.clone();
tokio::spawn(serve_one_tcp(stream, handler));
}
}
#[cfg(unix)]
pub async fn serve_uds_control(
listener: UnixListener,
handler: std::sync::Arc<dyn ControlHandler>,
) {
loop {
let (stream, _peer_addr) = match listener.accept().await {
Ok(pair) => pair,
Err(err) => {
tracing::warn!(error = %err, "uds control listener accept failed; exiting");
return;
}
};
let handler = handler.clone();
tokio::spawn(serve_one_uds(stream, handler));
}
}
/// Serves the single request/response exchange on an accepted TCP connection.
async fn serve_one_tcp(stream: TcpStream, handler: std::sync::Arc<dyn ControlHandler>) {
    serve_one(&mut ControlStream::Tcp(stream), handler).await;
}

/// Serves the single request/response exchange on an accepted Unix-domain connection.
#[cfg(unix)]
async fn serve_one_uds(stream: UnixStream, handler: std::sync::Arc<dyn ControlHandler>) {
    serve_one(&mut ControlStream::Uds(stream), handler).await;
}
async fn serve_one(stream: &mut ControlStream, handler: std::sync::Arc<dyn ControlHandler>) {
let payload = match stream.read_frame().await {
Ok(buf) => buf,
Err(err) => {
tracing::debug!(error = %err, "control listener: read failed");
return;
}
};
let request = match serde_json::from_slice::<ControlRequest>(&payload) {
Ok(req) => req,
Err(err) => {
let response = ControlResponse::Err {
code: "decode".to_string(),
message: err.to_string(),
};
let response_payload = serde_json::to_vec(&response).unwrap_or_default();
let _ = stream.write_frame(&response_payload).await;
return;
}
};
let response = handler.handle(request).await;
let response_payload = serde_json::to_vec(&response).unwrap_or_default();
let _ = stream.write_frame(&response_payload).await;
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Test handler that answers every request with a canned response derived from its input.
    struct EchoHandler;
    impl ControlHandler for EchoHandler {
        fn handle(
            &self,
            request: ControlRequest,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ControlResponse> + Send + '_>>
        {
            Box::pin(async move {
                match request {
                    ControlRequest::Wire { .. } | ControlRequest::Unwire { .. } => {
                        ControlResponse::Ok
                    }
                    ControlRequest::Inject { remote_member, .. } => ControlResponse::Injected {
                        session_id: format!("session-for-{remote_member}"),
                    },
                    ControlRequest::LookupMember { remote_member } => ControlResponse::Member {
                        peer_id: format!("peer-id-for-{remote_member}"),
                        comms_name: format!("mob/role/{remote_member}"),
                    },
                }
            })
        }
    }

    /// Binds an ephemeral loopback port and serves `EchoHandler` on it.
    async fn spawn_echo_server() -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let addr = listener.local_addr().expect("addr");
        let handler: Arc<dyn ControlHandler> = Arc::new(EchoHandler);
        (addr, tokio::spawn(serve_tcp_control(listener, handler)))
    }

    #[tokio::test]
    async fn tcp_round_trip_inject() {
        let (addr, server) = spawn_echo_server().await;
        let endpoint = RemoteEndpoint::Tcp(addr.to_string());
        let request = ControlRequest::Inject {
            remote_member: "alice".to_string(),
            content: serde_json::json!({"text": "hello"}),
        };
        let response = RemoteControlClient::send(&endpoint, &request, DEFAULT_CONTROL_TIMEOUT)
            .await
            .expect("control rpc");
        assert_eq!(
            response,
            ControlResponse::Injected {
                session_id: "session-for-alice".to_string(),
            },
        );
        server.abort();
    }

    #[tokio::test]
    async fn tcp_round_trip_wire() {
        let (addr, server) = spawn_echo_server().await;
        let endpoint = RemoteEndpoint::Tcp(addr.to_string());
        let request = ControlRequest::Wire {
            remote_member: "bob".to_string(),
            local_peer_spec_address: "tcp://127.0.0.1:9001".to_string(),
            local_comms_name: "demo/role/alice".to_string(),
            local_peer_id: "00000000-0000-4000-8000-000000000001".to_string(),
            local_pubkey_b64: None,
        };
        let response = RemoteControlClient::send(&endpoint, &request, DEFAULT_CONTROL_TIMEOUT)
            .await
            .expect("control rpc");
        assert_eq!(response, ControlResponse::Ok);
        server.abort();
    }

    #[tokio::test]
    async fn malformed_request_returns_decode_error() {
        let (addr, server) = spawn_echo_server().await;
        let mut stream = TcpStream::connect(addr).await.expect("connect");
        stream
            .write_all(&u32::to_be_bytes(5))
            .await
            .expect("write header");
        stream.write_all(b"hello").await.expect("write payload");
        stream.flush().await.expect("flush");
        let mut header = [0u8; 4];
        stream.read_exact(&mut header).await.expect("read header");
        let len = u32::from_be_bytes(header) as usize;
        let mut buf = vec![0u8; len];
        stream.read_exact(&mut buf).await.expect("read payload");
        let response: ControlResponse = serde_json::from_slice(&buf).expect("decode response");
        match response {
            ControlResponse::Err { code, .. } => assert_eq!(code, "decode"),
            other => panic!("expected decode error, got {other:?}"),
        }
        server.abort();
    }

    /// A request whose encoding exceeds MAX_CONTROL_PAYLOAD must fail promptly, not hang or
    /// succeed.
    #[tokio::test]
    async fn oversized_request_is_rejected() {
        let (addr, server) = spawn_echo_server().await;
        let endpoint = RemoteEndpoint::Tcp(addr.to_string());
        let request = ControlRequest::Inject {
            remote_member: "alice".to_string(),
            content: serde_json::json!({ "text": "x".repeat(MAX_CONTROL_PAYLOAD as usize) }),
        };
        let result = RemoteControlClient::send(&endpoint, &request, DEFAULT_CONTROL_TIMEOUT).await;
        assert!(result.is_err(), "oversized request must be rejected, got {result:?}");
        server.abort();
    }

    /// The server must drop a connection that advertises an over-limit frame instead of
    /// allocating for it or replying.
    #[tokio::test]
    async fn oversized_frame_header_closes_connection() {
        let (addr, server) = spawn_echo_server().await;
        let mut stream = TcpStream::connect(addr).await.expect("connect");
        stream
            .write_all(&(MAX_CONTROL_PAYLOAD + 1).to_be_bytes())
            .await
            .expect("write header");
        stream.flush().await.expect("flush");
        let mut header = [0u8; 4];
        let read = stream.read_exact(&mut header).await;
        assert!(read.is_err(), "server must close the connection without a reply");
        server.abort();
    }
}