use std::collections::HashMap;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use facet::Facet;
use crate::runtime::{Mutex, Receiver, channel, sleep, spawn, spawn_with_abort};
use crate::{
ChannelError, ChannelRegistry, ConnectionHandle, Context, DriverMessage, MessageTransport,
ResponseData, RoamError, Role, ServiceDispatcher, TransportError,
};
use roam_wire::{ConnectionId, Hello, Message};
/// Limits agreed on during the Hello exchange.
///
/// Each field is the minimum of the value we advertised and the value the
/// peer advertised (computed in `establish`).
#[derive(Debug, Clone)]
pub struct Negotiated {
    /// Largest payload, in bytes, accepted in an incoming `Request` or `Data` message.
    pub max_payload_size: u32,
    /// Initial flow-control credit, in bytes, passed to each connection's handle
    /// and channel registry.
    pub initial_credit: u32,
}
/// Errors that end a connection or its handshake.
#[derive(Debug)]
pub enum ConnectionError {
    /// The underlying transport failed.
    Io(std::io::Error),
    /// The peer broke a protocol rule; `rule_id` names the rule.
    ProtocolViolation {
        rule_id: &'static str,
        context: String,
    },
    /// A dispatcher reported an error.
    Dispatch(String),
    /// The connection was closed.
    Closed,
}

impl std::fmt::Display for ConnectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(err) => write!(f, "IO error: {err}"),
            Self::ProtocolViolation { rule_id, context } => {
                write!(f, "protocol violation: {rule_id}: {context}")
            }
            Self::Dispatch(message) => write!(f, "dispatch error: {message}"),
            Self::Closed => f.write_str("connection closed"),
        }
    }
}

impl std::error::Error for ConnectionError {
    /// Only the `Io` variant wraps an underlying error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Io(err) = self {
            Some(err)
        } else {
            None
        }
    }
}

impl From<std::io::Error> for ConnectionError {
    fn from(err: std::io::Error) -> Self {
        Self::Io(err)
    }
}
/// Local settings advertised to the peer in our `Hello`.
#[derive(Debug, Clone)]
pub struct HandshakeConfig {
    /// Largest payload, in bytes, we are willing to receive.
    pub max_payload_size: u32,
    /// Initial flow-control credit, in bytes, we offer per channel.
    pub initial_channel_credit: u32,
}

impl Default for HandshakeConfig {
    /// 1 MiB payload limit and 64 KiB initial channel credit.
    fn default() -> Self {
        const ONE_MIB: u32 = 1024 * 1024;
        const SIXTY_FOUR_KIB: u32 = 64 * 1024;
        Self {
            max_payload_size: ONE_MIB,
            initial_channel_credit: SIXTY_FOUR_KIB,
        }
    }
}
impl HandshakeConfig {
    /// Builds the `Hello::V2` message that advertises this configuration.
    pub fn to_hello(&self) -> Hello {
        let &Self {
            max_payload_size,
            initial_channel_credit,
        } = self;
        Hello::V2 {
            max_payload_size,
            initial_channel_credit,
        }
    }
}
/// Opens new transports for [`FramedClient`], which calls [`connect`](Self::connect)
/// whenever it has no cached connection.
pub trait MessageConnector: Send + Sync + 'static {
    /// Transport produced by a successful connection.
    type Transport: MessageTransport;
    /// Opens a new transport to the peer.
    ///
    /// `FramedClient` treats an `Err` here as a retryable connect failure
    /// (`ConnectError::ConnectFailed`).
    fn connect(&self) -> impl Future<Output = io::Result<Self::Transport>> + Send;
}
/// Retry behaviour used by [`FramedClient`] when connecting fails or an
/// established connection is lost.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Total attempts allowed, including the first; `0` behaves like `1`.
    pub max_attempts: u32,
    /// Delay after the first failed attempt.
    pub initial_backoff: Duration,
    /// Upper bound on any single delay.
    pub max_backoff: Duration,
    /// Factor by which the delay grows after each further failure.
    pub backoff_multiplier: f64,
}

impl Default for RetryPolicy {
    /// Three attempts, starting at 100 ms and doubling up to 5 s.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(5),
            backoff_multiplier: 2.0,
        }
    }
}

impl RetryPolicy {
    /// Returns the delay to wait after the `attempt`-th failure (1-based).
    ///
    /// The delay is `initial_backoff * backoff_multiplier^(attempt - 1)`, capped
    /// at `max_backoff`. Attempt `0` is treated like attempt `1`.
    ///
    /// This never panics. A product that overflows `Duration`, is infinite, or
    /// is NaN saturates to `max_backoff`. A negative product clamps to zero.
    /// `Duration::mul_f64` would panic in all of these cases, for example around
    /// attempt 69 with the default policy.
    pub fn backoff_for_attempt(&self, attempt: u32) -> Duration {
        // Exponents beyond `i32::MAX` saturate instead of wrapping negative.
        let exponent = i32::try_from(attempt.saturating_sub(1)).unwrap_or(i32::MAX);
        let secs = self.initial_backoff.as_secs_f64() * self.backoff_multiplier.powi(exponent);
        if secs.is_nan() {
            return self.max_backoff;
        }
        Duration::try_from_secs_f64(secs.max(0.0))
            .map_or(self.max_backoff, |backoff| backoff.min(self.max_backoff))
    }
}
/// Errors returned by [`FramedClient`] when establishing or using a connection.
#[derive(Debug)]
pub enum ConnectError {
    /// Every allowed attempt failed; `original` carries an underlying I/O error.
    RetriesExhausted {
        original: io::Error,
        attempts: u32,
    },
    /// A connect or handshake attempt failed.
    ConnectFailed(io::Error),
    /// A transport-level error that is not retried.
    Rpc(TransportError),
    /// The peer rejected a virtual-connection request; carries its reason.
    Rejected(String),
}

impl std::fmt::Display for ConnectError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::RetriesExhausted { original, attempts } => write!(
                f,
                "reconnection failed after {attempts} attempts: {original}"
            ),
            Self::ConnectFailed(err) => write!(f, "connection failed: {err}"),
            Self::Rpc(err) => write!(f, "RPC error: {err}"),
            Self::Rejected(reason) => write!(f, "connection rejected: {reason}"),
        }
    }
}

impl std::error::Error for ConnectError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let inner: &(dyn std::error::Error + 'static) = match self {
            Self::RetriesExhausted { original, .. } => original,
            Self::ConnectFailed(err) => err,
            Self::Rpc(err) => err,
            Self::Rejected(_) => return None,
        };
        Some(inner)
    }
}

impl From<TransportError> for ConnectError {
    fn from(err: TransportError) -> Self {
        Self::Rpc(err)
    }
}
/// Performs the acceptor side of the Hello handshake over `transport`.
///
/// On success returns the root [`ConnectionHandle`], the receiver of incoming
/// virtual-connection requests, and the [`Driver`]. Messages are only processed
/// while [`Driver::run`] is awaited.
///
/// # Errors
///
/// Any handshake failure reported by `establish`.
pub async fn accept_framed<T, D>(
    transport: T,
    config: HandshakeConfig,
    dispatcher: D,
) -> Result<(ConnectionHandle, IncomingConnections, Driver<T, D>), ConnectionError>
where
    T: MessageTransport,
    D: ServiceDispatcher,
{
    let hello = config.to_hello();
    establish(transport, hello, dispatcher, Role::Acceptor).await
}
/// Creates a [`FramedClient`] that uses [`RetryPolicy::default`].
///
/// No connection is opened here; the client connects on first use.
pub fn connect_framed<C, D>(
    connector: C,
    config: HandshakeConfig,
    dispatcher: D,
) -> FramedClient<C, D>
where
    C: MessageConnector,
    D: ServiceDispatcher + Clone,
{
    connect_framed_with_policy(connector, config, dispatcher, RetryPolicy::default())
}

/// Creates a [`FramedClient`] with an explicit [`RetryPolicy`].
///
/// No connection is opened here; the client connects on first use.
pub fn connect_framed_with_policy<C, D>(
    connector: C,
    config: HandshakeConfig,
    dispatcher: D,
    retry_policy: RetryPolicy,
) -> FramedClient<C, D>
where
    C: MessageConnector,
    D: ServiceDispatcher + Clone,
{
    let connector = Arc::new(connector);
    let state = Arc::new(Mutex::new(None));
    FramedClient {
        connector,
        config,
        dispatcher,
        retry_policy,
        state,
    }
}
/// The live connection cached by a [`FramedClient`].
struct FramedClientState {
    /// Root handle of the current connection.
    handle: ConnectionHandle,
}

/// A client that connects on demand and reconnects after the connection is
/// lost, retrying calls according to its [`RetryPolicy`].
///
/// Clones share the same connector and the same cached connection.
pub struct FramedClient<C, D> {
    connector: Arc<C>,
    config: HandshakeConfig,
    dispatcher: D,
    retry_policy: RetryPolicy,
    /// `None` until the first connection, and again after a call sees it closed.
    state: Arc<Mutex<Option<FramedClientState>>>,
}

// Implemented by hand so that cloning does not require `C: Clone`; the
// connector sits behind an `Arc`.
impl<C, D> Clone for FramedClient<C, D>
where
    D: Clone,
{
    fn clone(&self) -> Self {
        let Self {
            connector,
            config,
            dispatcher,
            retry_policy,
            state,
        } = self;
        Self {
            connector: Arc::clone(connector),
            config: config.clone(),
            dispatcher: dispatcher.clone(),
            retry_policy: retry_policy.clone(),
            state: Arc::clone(state),
        }
    }
}
impl<C, D> FramedClient<C, D>
where
    C: MessageConnector,
    D: ServiceDispatcher + Clone + 'static,
{
    /// Returns the handle of the current connection, connecting first if needed.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectError::ConnectFailed`] if connecting or the handshake
    /// fails. No retries are performed here.
    pub async fn handle(&self) -> Result<ConnectionHandle, ConnectError> {
        self.ensure_connected().await
    }

    /// Returns the cached connection handle, or establishes a new connection.
    ///
    /// The state lock is held while connecting, so concurrent callers wait for
    /// one connection attempt rather than each opening their own.
    async fn ensure_connected(&self) -> Result<ConnectionHandle, ConnectError> {
        let mut state = self.state.lock().await;
        if let Some(conn) = state.as_ref() {
            return Ok(conn.handle.clone());
        }
        let conn = self.connect_internal().await?;
        let handle = conn.handle.clone();
        *state = Some(conn);
        Ok(handle)
    }

    /// Opens a transport, runs the initiator handshake, and spawns the driver.
    ///
    /// The incoming-connections receiver is dropped when this returns.
    /// Peer `Connect` requests are therefore presumably rejected by the driver
    /// with "not listening". Errors from the spawned driver are discarded.
    async fn connect_internal(&self) -> Result<FramedClientState, ConnectError> {
        let transport = self
            .connector
            .connect()
            .await
            .map_err(ConnectError::ConnectFailed)?;
        let (handle, _incoming, driver) = establish(
            transport,
            self.config.to_hello(),
            self.dispatcher.clone(),
            Role::Initiator,
        )
        .await
        .map_err(|e| ConnectError::ConnectFailed(connection_error_to_io(e)))?;
        spawn(async move {
            let _ = driver.run().await;
        });
        Ok(FramedClientState { handle })
    }

    /// Sends a raw request and returns the raw response payload.
    ///
    /// Connect failures and closed connections are retried with backoff, up to
    /// `max_attempts` attempts in total. A request is resent after its
    /// connection drops, so a method whose request reached the peer before the
    /// drop may run more than once.
    ///
    /// # Errors
    ///
    /// - [`ConnectError::RetriesExhausted`] once all attempts have failed. Its
    ///   `original` field is the error from the final attempt.
    /// - [`ConnectError::Rpc`] for encode errors, which are not retried.
    /// - Any other [`ConnectError`] from connecting, returned immediately.
    pub async fn call_raw(
        &self,
        method_id: u64,
        payload: Vec<u8>,
    ) -> Result<Vec<u8>, ConnectError> {
        let mut attempt = 0u32;
        loop {
            let connection = self.ensure_connected().await;
            // Either return, or produce the error of this retryable failure.
            let failure = match connection {
                Ok(handle) => {
                    let outcome = handle.call_raw(method_id, payload.clone()).await;
                    match outcome {
                        Ok(response) => return Ok(response),
                        Err(TransportError::Encode(e)) => {
                            return Err(ConnectError::Rpc(TransportError::Encode(e)));
                        }
                        Err(TransportError::ConnectionClosed) | Err(TransportError::DriverGone) => {
                            // Forget the dead connection so the next attempt reconnects.
                            *self.state.lock().await = None;
                            io::Error::new(io::ErrorKind::ConnectionReset, "connection closed")
                        }
                    }
                }
                Err(ConnectError::ConnectFailed(e)) => e,
                Err(e) => return Err(e),
            };
            attempt += 1;
            if attempt >= self.retry_policy.max_attempts {
                return Err(ConnectError::RetriesExhausted {
                    original: failure,
                    attempts: attempt,
                });
            }
            sleep(self.retry_policy.backoff_for_attempt(attempt)).await;
        }
    }
}
impl<C, D> crate::Caller for FramedClient<C, D>
where
    C: MessageConnector,
    D: ServiceDispatcher + Clone + 'static,
{
    /// Calls `method_id` with `args`, reconnecting and retrying per the
    /// client's [`RetryPolicy`].
    ///
    /// Connect failures and closed connections are retried with backoff.
    /// Encode errors and non-retryable connect errors end the call immediately.
    /// Every way of giving up on the connection is reported as
    /// [`TransportError::ConnectionClosed`].
    async fn call_with_metadata<T: Facet<'static> + Send>(
        &self,
        method_id: u64,
        args: &mut T,
        metadata: roam_wire::Metadata,
    ) -> Result<ResponseData, TransportError> {
        let policy = &self.retry_policy;
        let mut failures: u32 = 0;
        loop {
            let connection = self.ensure_connected().await;
            match connection {
                Ok(handle) => {
                    let outcome = handle
                        .call_with_metadata(method_id, args, metadata.clone())
                        .await;
                    match outcome {
                        Ok(response) => return Ok(response),
                        Err(TransportError::Encode(e)) => {
                            return Err(TransportError::Encode(e));
                        }
                        Err(TransportError::ConnectionClosed | TransportError::DriverGone) => {
                            // Forget the dead connection so the next attempt reconnects.
                            *self.state.lock().await = None;
                        }
                    }
                }
                // A failed connect leaves the cached state alone and is retried.
                Err(ConnectError::ConnectFailed(_)) => {}
                Err(ConnectError::Rpc(e)) => return Err(e),
                Err(ConnectError::RetriesExhausted { .. } | ConnectError::Rejected(_)) => {
                    return Err(TransportError::ConnectionClosed);
                }
            }
            failures += 1;
            if failures >= policy.max_attempts {
                return Err(TransportError::ConnectionClosed);
            }
            sleep(policy.backoff_for_attempt(failures)).await;
        }
    }

    /// This implementation does nothing with the response or the channel ids.
    fn bind_response_streams<R: Facet<'static>>(&self, response: &mut R, channels: &[u64]) {
        let _ = (response, channels);
    }
}
fn connection_error_to_io(e: ConnectionError) -> io::Error {
match e {
ConnectionError::Io(io_err) => io_err,
ConnectionError::ProtocolViolation { rule_id, context } => io::Error::new(
io::ErrorKind::InvalidData,
format!("protocol violation: {rule_id}: {context}"),
),
ConnectionError::Dispatch(msg) => io::Error::other(format!("dispatch error: {msg}")),
ConnectionError::Closed => {
io::Error::new(io::ErrorKind::ConnectionReset, "connection closed")
}
}
}
/// Per-virtual-connection bookkeeping owned by the [`Driver`].
struct ConnectionState {
    // Never read: the driver keys connections by id in its own map. It is
    // presumably kept for debugging, hence the lint allowance.
    #[allow(dead_code)]
    conn_id: ConnectionId,
    /// Handle used for outgoing calls and client-side channels on this connection.
    handle: ConnectionHandle,
    /// Channel registry handed to dispatchers serving incoming requests.
    server_channel_registry: ChannelRegistry,
    /// Connection-specific dispatcher; `None` means the driver's default is used.
    dispatcher: Option<Box<dyn ServiceDispatcher>>,
    /// Outgoing calls awaiting a `Response`, keyed by request id.
    pending_responses:
        HashMap<u64, crate::runtime::OneshotSender<Result<ResponseData, TransportError>>>,
    /// Incoming requests whose handlers are still running, keyed by request id.
    in_flight_server_requests: HashMap<u64, crate::runtime::AbortHandle>,
}

impl ConnectionState {
    /// Builds the state for `conn_id`, creating its handle and server-side
    /// channel registry, both sending through `driver_tx`.
    fn new(
        conn_id: ConnectionId,
        driver_tx: crate::runtime::Sender<DriverMessage>,
        role: Role,
        initial_credit: u32,
        diagnostic_state: Option<Arc<crate::diagnostic::DiagnosticState>>,
        dispatcher: Option<Box<dyn ServiceDispatcher>>,
    ) -> Self {
        let handle = ConnectionHandle::new_with_diagnostics(
            conn_id,
            driver_tx.clone(),
            role,
            initial_credit,
            diagnostic_state,
        );
        Self {
            conn_id,
            handle,
            server_channel_registry: ChannelRegistry::new_with_credit_and_role(
                conn_id,
                initial_credit,
                driver_tx,
                role,
            ),
            dispatcher,
            pending_responses: HashMap::new(),
            in_flight_server_requests: HashMap::new(),
        }
    }

    /// Completes every outstanding outgoing call with `ConnectionClosed`.
    fn fail_pending_responses(&mut self) {
        self.pending_responses.drain().for_each(|(_, tx)| {
            let _ = tx.send(Err(TransportError::ConnectionClosed));
        });
    }

    /// Aborts every still-running handler for incoming requests.
    fn abort_in_flight_requests(&mut self) {
        for abort_handle in self.in_flight_server_requests.drain().map(|(_, h)| h) {
            abort_handle.abort();
        }
    }
}
/// A peer's request to open a virtual connection, awaiting a decision.
///
/// Call [`accept`](Self::accept) or [`reject`](Self::reject). The decision is
/// forwarded to the driver, which replies to the peer.
pub struct IncomingConnection {
    request_id: u64,
    /// Metadata the peer sent with its `Connect` request.
    pub metadata: roam_wire::Metadata,
    response_tx: crate::runtime::OneshotSender<IncomingConnectionResponse>,
}

impl IncomingConnection {
    /// Accepts the connection, optionally installing a connection-specific
    /// dispatcher, and returns the handle of the new connection.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::DriverGone`] if the driver drops the request
    /// without replying, or whatever error the driver sends back.
    pub async fn accept(
        self,
        metadata: roam_wire::Metadata,
        dispatcher: Option<Box<dyn ServiceDispatcher>>,
    ) -> Result<ConnectionHandle, TransportError> {
        let (handle_tx, handle_rx) = crate::runtime::oneshot();
        let decision = IncomingConnectionResponse::Accept {
            request_id: self.request_id,
            metadata,
            dispatcher,
            handle_tx,
        };
        // If the driver is gone, `handle_tx` is dropped along with the message
        // and the receive below fails.
        let _ = self.response_tx.send(decision);
        match handle_rx.await {
            Ok(result) => result,
            Err(_) => Err(TransportError::DriverGone),
        }
    }

    /// Rejects the connection; the driver replies `Reject` with `reason` and
    /// `metadata`.
    pub fn reject(self, reason: String, metadata: roam_wire::Metadata) {
        let decision = IncomingConnectionResponse::Reject {
            request_id: self.request_id,
            reason,
            metadata,
        };
        let _ = self.response_tx.send(decision);
    }
}
/// The application's decision on an [`IncomingConnection`], relayed to the driver.
enum IncomingConnectionResponse {
    /// The driver allocates a connection id, registers the connection, replies
    /// `Accept`, then sends the new handle through `handle_tx`.
    Accept {
        request_id: u64,
        metadata: roam_wire::Metadata,
        dispatcher: Option<Box<dyn ServiceDispatcher>>,
        handle_tx: crate::runtime::OneshotSender<Result<ConnectionHandle, TransportError>>,
    },
    /// The driver replies `Reject` with `reason` and `metadata`.
    Reject {
        request_id: u64,
        reason: String,
        metadata: roam_wire::Metadata,
    },
}
/// One of our `Connect` requests awaiting the peer's `Accept` or `Reject`.
struct PendingConnect {
    /// Receives the new connection's handle, or `ConnectError::Rejected`.
    response_tx: crate::runtime::OneshotSender<Result<ConnectionHandle, ConnectError>>,
    /// Dispatcher to install on the connection if the peer accepts.
    dispatcher: Option<Box<dyn ServiceDispatcher>>,
}
/// Owns a transport and multiplexes all virtual connections over it.
///
/// Messages are processed only while [`Driver::run`] is being awaited.
pub struct Driver<T, D> {
    /// Underlying message transport.
    io: T,
    /// Dispatcher for connections that have no dispatcher of their own.
    dispatcher: D,
    /// Our role on this link; passed to every connection the driver creates.
    role: Role,
    /// Limits agreed during the Hello exchange.
    negotiated: Negotiated,
    /// Commands from handles and handlers to the driver.
    driver_rx: Receiver<DriverMessage>,
    /// Sender cloned into each connection. Holding it here presumably keeps
    /// `driver_rx` from ever closing.
    driver_tx: crate::runtime::Sender<DriverMessage>,
    /// All live virtual connections, including the root.
    connections: HashMap<ConnectionId, ConnectionState>,
    /// Next id to assign when accepting a peer's `Connect`.
    next_conn_id: u64,
    /// Our outgoing `Connect` requests, keyed by request id.
    pending_connects: HashMap<u64, PendingConnect>,
    /// Delivers peer `Connect` requests to the application.
    incoming_connections_tx: Option<crate::runtime::Sender<IncomingConnection>>,
    /// Accept/reject decisions coming back from the application.
    incoming_response_rx: Option<Receiver<IncomingConnectionResponse>>,
    /// Sender side of `incoming_response_rx`, cloned for each incoming request.
    incoming_response_tx: crate::runtime::Sender<IncomingConnectionResponse>,
    /// Diagnostics passed to connections created by this driver.
    diagnostic_state: Option<Arc<crate::diagnostic::DiagnosticState>>,
}
impl<T, D> Driver<T, D>
where
    T: MessageTransport,
    D: ServiceDispatcher,
{
    /// Returns a clone of the root connection's handle.
    ///
    /// # Panics
    ///
    /// Panics if the root connection entry is missing. `establish` always
    /// inserts it when constructing the driver.
    pub fn root_handle(&self) -> ConnectionHandle {
        self.connections
            .get(&ConnectionId::ROOT)
            .expect("root connection always exists")
            .handle
            .clone()
    }

    /// Drives the link until the peer goes away or a fatal error occurs.
    ///
    /// Each iteration handles whichever arrives first:
    /// - a command from a local handle or handler (`driver_rx`);
    /// - an accept/reject decision for an incoming connection;
    /// - a message from the transport.
    ///
    /// Returns `Ok(())` when the transport reports end-of-stream or the peer
    /// sends `Goodbye` on the root connection.
    ///
    /// # Errors
    ///
    /// Returns the first transport error or protocol violation. For protocol
    /// violations, a `Goodbye` naming the rule has been attempted before
    /// returning.
    pub async fn run(mut self) -> Result<(), ConnectionError> {
        use futures_util::FutureExt;
        loop {
            futures_util::select! {
                msg = self.driver_rx.recv().fuse() => {
                    if let Some(msg) = msg {
                        self.handle_driver_message(msg).await?;
                    }
                }
                response = async {
                    // Without a receiver this branch simply never completes.
                    if let Some(rx) = &mut self.incoming_response_rx {
                        rx.recv().await
                    } else {
                        std::future::pending().await
                    }
                }.fuse() => {
                    if let Some(response) = response {
                        self.handle_incoming_response(response).await?;
                    }
                }
                result = self.io.recv().fuse() => {
                    match self.handle_recv(result).await {
                        Ok(true) => continue,
                        Ok(false) => return Ok(()),
                        Err(e) => return Err(e),
                    }
                }
            }
        }
    }

    /// Applies the application's decision on a peer's `Connect` request.
    ///
    /// On accept, the driver allocates the next connection id and registers the
    /// connection before sending `Accept`. It then hands the new handle back to
    /// [`IncomingConnection::accept`]. On reject, it sends `Reject` with the
    /// given reason and metadata.
    async fn handle_incoming_response(
        &mut self,
        response: IncomingConnectionResponse,
    ) -> Result<(), ConnectionError> {
        match response {
            IncomingConnectionResponse::Accept {
                request_id,
                metadata,
                dispatcher,
                handle_tx,
            } => {
                let conn_id = ConnectionId::new(self.next_conn_id);
                self.next_conn_id += 1;
                let conn_state = ConnectionState::new(
                    conn_id,
                    self.driver_tx.clone(),
                    self.role,
                    self.negotiated.initial_credit,
                    self.diagnostic_state.clone(),
                    dispatcher,
                );
                let handle = conn_state.handle.clone();
                self.connections.insert(conn_id, conn_state);
                let msg = Message::Accept {
                    request_id,
                    conn_id,
                    metadata,
                };
                self.io.send(&msg).await?;
                let _ = handle_tx.send(Ok(handle));
            }
            IncomingConnectionResponse::Reject {
                request_id,
                reason,
                metadata,
            } => {
                let msg = Message::Reject {
                    request_id,
                    reason,
                    metadata,
                };
                self.io.send(&msg).await?;
            }
        }
        Ok(())
    }

    /// Translates a local command into wire traffic.
    ///
    /// - `Call`: records the response slot, then sends `Request`. An unknown
    ///   connection fails the call immediately with `ConnectionClosed`.
    /// - `Data` / `Close`: forwarded as-is.
    /// - `Response`: sent only if the request is still in flight. Responses
    ///   for cancelled requests (already answered by `handle_cancel`) or for
    ///   removed connections are dropped.
    /// - `Connect`: records the pending request, then sends `Connect`.
    async fn handle_driver_message(&mut self, msg: DriverMessage) -> Result<(), ConnectionError> {
        match msg {
            DriverMessage::Call {
                conn_id,
                request_id,
                method_id,
                metadata,
                channels,
                payload,
                response_tx,
            } => {
                if let Some(conn) = self.connections.get_mut(&conn_id) {
                    conn.pending_responses.insert(request_id, response_tx);
                } else {
                    let _ = response_tx.send(Err(TransportError::ConnectionClosed));
                    return Ok(());
                }
                let req = Message::Request {
                    conn_id,
                    request_id,
                    method_id,
                    metadata,
                    channels,
                    payload,
                };
                self.io.send(&req).await?;
            }
            DriverMessage::Data {
                conn_id,
                channel_id,
                payload,
            } => {
                let wire_msg = Message::Data {
                    conn_id,
                    channel_id,
                    payload,
                };
                self.io.send(&wire_msg).await?;
            }
            DriverMessage::Close {
                conn_id,
                channel_id,
            } => {
                let wire_msg = Message::Close {
                    conn_id,
                    channel_id,
                };
                self.io.send(&wire_msg).await?;
            }
            DriverMessage::Response {
                conn_id,
                request_id,
                channels,
                payload,
            } => {
                let should_send = if let Some(conn) = self.connections.get_mut(&conn_id) {
                    conn.in_flight_server_requests.remove(&request_id).is_some()
                } else {
                    false
                };
                if !should_send {
                    return Ok(());
                }
                let wire_msg = Message::Response {
                    conn_id,
                    request_id,
                    metadata: vec![],
                    channels,
                    payload,
                };
                self.io.send(&wire_msg).await?;
            }
            DriverMessage::Connect {
                request_id,
                metadata,
                response_tx,
                dispatcher,
            } => {
                self.pending_connects.insert(
                    request_id,
                    PendingConnect {
                        response_tx,
                        dispatcher,
                    },
                );
                let wire_msg = Message::Connect {
                    request_id,
                    metadata,
                };
                self.io.send(&wire_msg).await?;
            }
        }
        Ok(())
    }

    /// Handles one result from the transport.
    ///
    /// Returns `Ok(true)` to keep running and `Ok(false)` for a clean shutdown
    /// (end-of-stream or root `Goodbye`). On a decode error, the last raw bytes
    /// are inspected so the peer gets a precise `Goodbye` reason:
    /// - an unknown Hello version;
    /// - an unknown message variant;
    /// - a generic decode error.
    async fn handle_recv(
        &mut self,
        result: std::io::Result<Option<Message>>,
    ) -> Result<bool, ConnectionError> {
        let msg = match result {
            Ok(Some(m)) => m,
            Ok(None) => return Ok(false),
            Err(e) => {
                let raw = self.io.last_decoded();
                // `Message::Hello` tag (0x00) followed by a version tag past V2 (0x01).
                // This matches the check in `establish`; a malformed V2 Hello is a
                // decode error, not an unknown version.
                if raw.len() >= 2 && raw[0] == 0x00 && raw[1] > 0x01 {
                    return Err(self.goodbye("message.hello.unknown-version").await);
                }
                // `Message` has 12 variants (see `handle_message`); a higher tag is unknown.
                if !raw.is_empty() && raw[0] >= 12 {
                    return Err(self.goodbye("message.unknown-variant").await);
                }
                if e.kind() == std::io::ErrorKind::InvalidData {
                    return Err(self.goodbye("message.decode-error").await);
                }
                return Err(ConnectionError::Io(e));
            }
        };
        match self.handle_message(msg).await {
            Ok(()) => Ok(true),
            Err(ConnectionError::Closed) => Ok(false),
            Err(e) => Err(e),
        }
    }

    /// Dispatches one decoded message from the peer.
    ///
    /// Returns `Err(ConnectionError::Closed)` when the peer says `Goodbye` on
    /// the root connection; `handle_recv` treats that as a clean shutdown.
    async fn handle_message(&mut self, msg: Message) -> Result<(), ConnectionError> {
        match msg {
            Message::Hello(_) => {
                // Hellos after the handshake are ignored.
            }
            Message::Connect {
                request_id,
                metadata,
            } => {
                if let Some(tx) = &self.incoming_connections_tx {
                    let (response_tx, response_rx) = crate::runtime::oneshot();
                    let incoming = IncomingConnection {
                        request_id,
                        metadata,
                        response_tx,
                    };
                    if tx.try_send(incoming).is_ok() {
                        // Relay the application's decision back through a task, so the
                        // driver loop never waits on the application.
                        let incoming_response_tx = self.incoming_response_tx.clone();
                        spawn(async move {
                            if let Ok(response) = response_rx.await {
                                let _ = incoming_response_tx.send(response).await;
                            }
                        });
                    } else {
                        // Queue full or receiver dropped.
                        let msg = Message::Reject {
                            request_id,
                            reason: "not listening".into(),
                            metadata: vec![],
                        };
                        self.io.send(&msg).await?;
                    }
                } else {
                    let msg = Message::Reject {
                        request_id,
                        reason: "not listening".into(),
                        metadata: vec![],
                    };
                    self.io.send(&msg).await?;
                }
            }
            Message::Accept {
                request_id,
                conn_id,
                metadata: _,
            } => {
                // Accepts for unknown request ids are ignored.
                if let Some(pending) = self.pending_connects.remove(&request_id) {
                    let conn_state = ConnectionState::new(
                        conn_id,
                        self.driver_tx.clone(),
                        self.role,
                        self.negotiated.initial_credit,
                        self.diagnostic_state.clone(),
                        pending.dispatcher,
                    );
                    let handle = conn_state.handle.clone();
                    self.connections.insert(conn_id, conn_state);
                    let _ = pending.response_tx.send(Ok(handle));
                }
            }
            Message::Reject {
                request_id,
                reason,
                metadata: _,
            } => {
                if let Some(pending) = self.pending_connects.remove(&request_id) {
                    let _ = pending
                        .response_tx
                        .send(Err(ConnectError::Rejected(reason)));
                }
            }
            Message::Goodbye { conn_id, reason: _ } => {
                if conn_id.is_root() {
                    // Root goodbye tears down every connection.
                    for (_, mut conn) in self.connections.drain() {
                        conn.fail_pending_responses();
                        conn.abort_in_flight_requests();
                    }
                    return Err(ConnectionError::Closed);
                } else {
                    if let Some(mut conn) = self.connections.remove(&conn_id) {
                        conn.fail_pending_responses();
                        conn.abort_in_flight_requests();
                    }
                }
            }
            Message::Request {
                conn_id,
                request_id,
                method_id,
                metadata,
                channels,
                payload,
            } => {
                self.handle_incoming_request(
                    conn_id, request_id, method_id, metadata, channels, payload,
                )
                .await?;
            }
            Message::Response {
                conn_id,
                request_id,
                channels,
                payload,
                ..
            } => {
                // Responses to unknown requests are ignored; response metadata is discarded.
                if let Some(conn) = self.connections.get_mut(&conn_id)
                    && let Some(tx) = conn.pending_responses.remove(&request_id)
                {
                    let _ = tx.send(Ok(ResponseData { payload, channels }));
                }
            }
            Message::Cancel {
                conn_id,
                request_id,
            } => {
                self.handle_cancel(conn_id, request_id).await?;
            }
            Message::Data {
                conn_id,
                channel_id,
                payload,
            } => {
                self.handle_data(conn_id, channel_id, payload).await?;
            }
            Message::Close {
                conn_id,
                channel_id,
            } => {
                self.handle_close(conn_id, channel_id).await?;
            }
            Message::Reset {
                conn_id,
                channel_id,
            } => {
                self.handle_reset(conn_id, channel_id)?;
            }
            Message::Credit {
                conn_id,
                channel_id,
                bytes,
            } => {
                self.handle_credit(conn_id, channel_id, bytes)?;
            }
        }
        Ok(())
    }

    /// Validates an incoming request and spawns its handler.
    ///
    /// Any of these sends `Goodbye` and fails with a protocol violation:
    /// - an unknown connection;
    /// - a request id already in flight;
    /// - invalid metadata;
    /// - a payload above the negotiated limit.
    ///
    /// The handler runs as an abortable task. It is expected to answer via
    /// `DriverMessage::Response`, which `handle_driver_message` forwards only
    /// while the request is still registered here.
    async fn handle_incoming_request(
        &mut self,
        conn_id: ConnectionId,
        request_id: u64,
        method_id: u64,
        metadata: Vec<(String, roam_wire::MetadataValue)>,
        channels: Vec<u64>,
        payload: Vec<u8>,
    ) -> Result<(), ConnectionError> {
        let conn = match self.connections.get_mut(&conn_id) {
            Some(c) => c,
            None => {
                return Err(self.goodbye("message.conn-id").await);
            }
        };
        if conn.in_flight_server_requests.contains_key(&request_id) {
            return Err(self.goodbye("call.request-id.duplicate-detection").await);
        }
        if let Err(rule_id) = roam_wire::validate_metadata(&metadata) {
            return Err(self.goodbye(rule_id).await);
        }
        // Compare in usize: `len() as u32` would truncate lengths >= 4 GiB.
        if payload.len() > self.negotiated.max_payload_size as usize {
            return Err(self.goodbye("flow.call.payload-limit").await);
        }
        let cx = Context::new(
            conn_id,
            roam_wire::RequestId::new(request_id),
            roam_wire::MethodId::new(method_id),
            metadata,
            channels,
        );
        // Prefer the connection's own dispatcher over the driver-wide default.
        let dispatcher: &dyn ServiceDispatcher = if let Some(ref conn_dispatcher) = conn.dispatcher
        {
            conn_dispatcher.as_ref()
        } else {
            &self.dispatcher
        };
        debug!(
            conn_id = conn_id.raw(),
            request_id, method_id, "dispatching incoming request"
        );
        let handler_fut = dispatcher.dispatch(&cx, payload, &mut conn.server_channel_registry);
        let abort_handle = spawn_with_abort(async move {
            handler_fut.await;
        });
        // The handler's Response is processed by this same driver loop, so
        // registering after spawning cannot miss it.
        conn.in_flight_server_requests
            .insert(request_id, abort_handle);
        Ok(())
    }

    /// Cancels an in-flight incoming request and answers it on the handler's behalf.
    ///
    /// Unknown connections or requests are ignored. The payload `[1, 3]` is a
    /// pre-encoded error response. It is presumably `Err(RoamError::Cancelled)`
    /// in postcard; TODO confirm against the `RoamError` definition.
    async fn handle_cancel(
        &mut self,
        conn_id: ConnectionId,
        request_id: u64,
    ) -> Result<(), ConnectionError> {
        let conn = match self.connections.get_mut(&conn_id) {
            Some(c) => c,
            None => {
                return Ok(());
            }
        };
        if let Some(abort_handle) = conn.in_flight_server_requests.remove(&request_id) {
            abort_handle.abort();
            let wire_msg = Message::Response {
                conn_id,
                request_id,
                metadata: vec![],
                channels: vec![],
                payload: vec![1, 3],
            };
            self.io.send(&wire_msg).await?;
        }
        Ok(())
    }

    /// Routes a `Data` message to the owning channel.
    ///
    /// The server-side registry is tried first, then the connection handle's
    /// client-side channels. Each of these is a protocol violation:
    /// - channel id 0 (reserved);
    /// - a payload above the negotiated limit;
    /// - an unknown connection or channel;
    /// - data after close;
    /// - a credit overrun.
    async fn handle_data(
        &mut self,
        conn_id: ConnectionId,
        channel_id: u64,
        payload: Vec<u8>,
    ) -> Result<(), ConnectionError> {
        if channel_id == 0 {
            return Err(self.goodbye("channeling.id.zero-reserved").await);
        }
        // Compare in usize: `len() as u32` would truncate lengths >= 4 GiB.
        if payload.len() > self.negotiated.max_payload_size as usize {
            return Err(self.goodbye("flow.call.payload-limit").await);
        }
        let conn = match self.connections.get_mut(&conn_id) {
            Some(c) => c,
            None => return Err(self.goodbye("message.conn-id").await),
        };
        let result = if conn.server_channel_registry.contains_incoming(channel_id) {
            conn.server_channel_registry
                .route_data(channel_id, payload)
                .await
        } else if conn.handle.contains_channel(channel_id) {
            conn.handle.route_data(channel_id, payload).await
        } else {
            Err(ChannelError::Unknown)
        };
        match result {
            Ok(()) => Ok(()),
            Err(ChannelError::Unknown) => Err(self.goodbye("channeling.unknown").await),
            Err(ChannelError::DataAfterClose) => {
                Err(self.goodbye("channeling.data-after-close").await)
            }
            Err(ChannelError::CreditOverrun) => {
                Err(self.goodbye("flow.channel.credit-overrun").await)
            }
        }
    }

    /// Closes a channel on the server-side registry or the connection handle.
    ///
    /// Channel id 0, an unknown connection, or an unknown channel is a
    /// protocol violation.
    async fn handle_close(
        &mut self,
        conn_id: ConnectionId,
        channel_id: u64,
    ) -> Result<(), ConnectionError> {
        if channel_id == 0 {
            return Err(self.goodbye("channeling.id.zero-reserved").await);
        }
        let conn = match self.connections.get_mut(&conn_id) {
            Some(c) => c,
            None => return Err(self.goodbye("message.conn-id").await),
        };
        if conn.server_channel_registry.contains(channel_id) {
            conn.server_channel_registry.close(channel_id);
        } else if conn.handle.contains_channel(channel_id) {
            conn.handle.close_channel(channel_id);
        } else {
            return Err(self.goodbye("channeling.unknown").await);
        }
        Ok(())
    }

    /// Resets a channel. Unlike close, unknown connections and channels are
    /// silently ignored.
    fn handle_reset(
        &mut self,
        conn_id: ConnectionId,
        channel_id: u64,
    ) -> Result<(), ConnectionError> {
        if let Some(conn) = self.connections.get_mut(&conn_id) {
            if conn.server_channel_registry.contains(channel_id) {
                conn.server_channel_registry.reset(channel_id);
            } else if conn.handle.contains_channel(channel_id) {
                conn.handle.reset_channel(channel_id);
            }
        }
        Ok(())
    }

    /// Adds `bytes` of send credit to a channel. Unknown connections and
    /// channels are silently ignored.
    fn handle_credit(
        &mut self,
        conn_id: ConnectionId,
        channel_id: u64,
        bytes: u32,
    ) -> Result<(), ConnectionError> {
        if let Some(conn) = self.connections.get_mut(&conn_id) {
            if conn.server_channel_registry.contains(channel_id) {
                conn.server_channel_registry
                    .receive_credit(channel_id, bytes);
            } else if conn.handle.contains_channel(channel_id) {
                conn.handle.receive_credit(channel_id, bytes);
            }
        }
        Ok(())
    }

    /// Reacts to a protocol violation by the peer.
    ///
    /// - Fails all pending calls and aborts all handlers on every connection.
    ///   The connection entries themselves are kept.
    /// - Sends a root `Goodbye` naming `rule_id`, ignoring send errors.
    /// - Returns the error to propagate, with an empty context.
    async fn goodbye(&mut self, rule_id: &'static str) -> ConnectionError {
        for (_, conn) in self.connections.iter_mut() {
            conn.fail_pending_responses();
            conn.abort_in_flight_requests();
        }
        let _ = self
            .io
            .send(&Message::Goodbye {
                conn_id: ConnectionId::ROOT,
                reason: rule_id.into(),
            })
            .await;
        ConnectionError::ProtocolViolation {
            rule_id,
            context: String::new(),
        }
    }
}
/// Performs the initiator side of the Hello handshake over `transport`.
///
/// On success returns the root [`ConnectionHandle`], the receiver of incoming
/// virtual-connection requests, and the [`Driver`]. Messages are only processed
/// while [`Driver::run`] is awaited.
///
/// # Errors
///
/// Any handshake failure reported by `establish`.
pub async fn initiate_framed<T, D>(
    transport: T,
    config: HandshakeConfig,
    dispatcher: D,
) -> Result<(ConnectionHandle, IncomingConnections, Driver<T, D>), ConnectionError>
where
    T: MessageTransport,
    D: ServiceDispatcher,
{
    let hello = config.to_hello();
    establish(transport, hello, dispatcher, Role::Initiator).await
}
/// Receiver of peer requests to open virtual connections; see [`IncomingConnection`].
pub type IncomingConnections = Receiver<IncomingConnection>;
async fn establish<T, D>(
mut io: T,
our_hello: Hello,
dispatcher: D,
role: Role,
) -> Result<(ConnectionHandle, IncomingConnections, Driver<T, D>), ConnectionError>
where
T: MessageTransport,
D: ServiceDispatcher,
{
io.send(&Message::Hello(our_hello.clone())).await?;
let peer_hello = match io.recv_timeout(Duration::from_secs(5)).await {
Ok(Some(Message::Hello(Hello::V2 {
max_payload_size,
initial_channel_credit,
}))) => Hello::V2 {
max_payload_size,
initial_channel_credit,
},
Ok(Some(Message::Hello(Hello::V1 { .. }))) => {
let _ = io
.send(&Message::Goodbye {
conn_id: ConnectionId::ROOT,
reason: "message.hello.unknown-version".into(),
})
.await;
return Err(ConnectionError::ProtocolViolation {
rule_id: "message.hello.unknown-version",
context: "received Hello::V1, but V1 is no longer supported".into(),
});
}
Ok(Some(_)) => {
let _ = io
.send(&Message::Goodbye {
conn_id: ConnectionId::ROOT,
reason: "message.hello.ordering".into(),
})
.await;
return Err(ConnectionError::ProtocolViolation {
rule_id: "message.hello.ordering",
context: "received non-Hello before Hello exchange".into(),
});
}
Ok(None) => return Err(ConnectionError::Closed),
Err(e) => {
let raw = io.last_decoded();
let is_unknown_hello = raw.len() >= 2 && raw[0] == 0x00 && raw[1] > 0x01;
let version = if is_unknown_hello { raw[1] } else { 0 };
if is_unknown_hello {
let _ = io
.send(&Message::Goodbye {
conn_id: ConnectionId::ROOT,
reason: "message.hello.unknown-version".into(),
})
.await;
return Err(ConnectionError::ProtocolViolation {
rule_id: "message.hello.unknown-version",
context: format!("unknown Hello version: {version}"),
});
}
return Err(ConnectionError::Io(e));
}
};
let (our_max, our_credit) = match &our_hello {
Hello::V2 {
max_payload_size,
initial_channel_credit,
} => (*max_payload_size, *initial_channel_credit),
Hello::V1 { .. } => unreachable!("we always send V2"),
};
let (peer_max, peer_credit) = match &peer_hello {
Hello::V2 {
max_payload_size,
initial_channel_credit,
} => (*max_payload_size, *initial_channel_credit),
Hello::V1 { .. } => unreachable!("V1 is rejected above"),
};
let negotiated = Negotiated {
max_payload_size: our_max.min(peer_max),
initial_credit: our_credit.min(peer_credit),
};
let (driver_tx, driver_rx) = channel(256);
let root_conn = ConnectionState::new(
ConnectionId::ROOT,
driver_tx.clone(),
role,
negotiated.initial_credit,
None,
None,
);
let handle = root_conn.handle.clone();
let mut connections = HashMap::new();
connections.insert(ConnectionId::ROOT, root_conn);
let (incoming_connections_tx, incoming_connections_rx) = channel(64);
let (incoming_response_tx, incoming_response_rx) = channel(64);
let driver = Driver {
io,
dispatcher,
role,
negotiated: negotiated.clone(),
driver_rx,
driver_tx,
connections,
next_conn_id: 1, pending_connects: HashMap::new(),
incoming_connections_tx: Some(incoming_connections_tx), incoming_response_rx: Some(incoming_response_rx),
incoming_response_tx,
diagnostic_state: None,
};
Ok((handle, incoming_connections_rx, driver))
}
/// A [`ServiceDispatcher`] that implements no methods.
///
/// Every request is answered with `Err(RoamError::UnknownMethod)`.
#[derive(Clone)]
pub struct NoDispatcher;

impl ServiceDispatcher for NoDispatcher {
    fn method_ids(&self) -> Vec<u64> {
        Vec::new()
    }

    fn dispatch(
        &self,
        cx: &Context,
        _payload: Vec<u8>,
        registry: &mut ChannelRegistry,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> {
        let (conn_id, request_id) = (cx.conn_id, cx.request_id.raw());
        let driver_tx = registry.driver_tx();
        Box::pin(async move {
            let unknown: Result<(), RoamError<()>> = Err(RoamError::UnknownMethod);
            // If encoding fails, an empty payload is sent instead.
            let payload = facet_postcard::to_vec(&unknown).unwrap_or_default();
            let reply = DriverMessage::Response {
                conn_id,
                request_id,
                channels: Vec::new(),
                payload,
            };
            let _ = driver_tx.send(reply).await;
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default policy starts at 100 ms, doubles per attempt, and caps at 5 s.
    #[test]
    fn test_backoff_calculation() {
        let policy = RetryPolicy::default();
        let expected = [
            (1, Duration::from_millis(100)),
            (2, Duration::from_millis(200)),
            (3, Duration::from_millis(400)),
            (10, Duration::from_secs(5)),
        ];
        for (attempt, backoff) in expected {
            assert_eq!(policy.backoff_for_attempt(attempt), backoff, "attempt {attempt}");
        }
    }
}