use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use parking_lot::Mutex;
use rapace_core::{
ControlPayload, ErrorCode, Frame, FrameFlags, MsgDescHot, NO_DEADLINE, RpcError, Transport,
TransportError, control_method,
};
/// Send credits a newly tracked channel starts with (see `ChannelState::default`).
/// Credits are consumed by the payload length of each DATA frame sent.
pub const DEFAULT_INITIAL_CREDITS: u32 = 65536;
/// Tombstone capacity used when `RAPACE_MAX_TOMBSTONES` is unset, unparsable, or zero.
const DEFAULT_MAX_TOMBSTONES: usize = 8192;
/// Tracked-channel cap used when `RAPACE_SESSION_MAX_TRACKED_CHANNELS` is unset,
/// unparsable, or zero.
const DEFAULT_MAX_TRACKED_CHANNELS: usize = 4096;
fn max_tombstones() -> usize {
std::env::var("RAPACE_MAX_TOMBSTONES")
.ok()
.and_then(|v| v.parse::<usize>().ok())
.filter(|v| *v > 0)
.unwrap_or(DEFAULT_MAX_TOMBSTONES)
}
fn max_tracked_channels() -> usize {
std::env::var("RAPACE_SESSION_MAX_TRACKED_CHANNELS")
.ok()
.and_then(|v| v.parse::<usize>().ok())
.filter(|v| *v > 0)
.unwrap_or(DEFAULT_MAX_TRACKED_CHANNELS)
}
/// Bounded record of channel ids that are finished (closed or cancelled).
///
/// Entries are evicted oldest-first once more than `max` ids are stored.
#[derive(Debug)]
struct Tombstones {
/// Maximum number of ids retained before the oldest is evicted.
max: usize,
/// Channel ids in first-insertion order; the front is the next eviction candidate.
order: VecDeque<u32>,
/// Final state recorded for each tombstoned channel id.
map: HashMap<u32, TombstoneInfo>,
}
impl Tombstones {
fn new(max: usize) -> Self {
Self {
max,
order: VecDeque::new(),
map: HashMap::new(),
}
}
fn contains(&self, channel_id: u32) -> bool {
self.map.contains_key(&channel_id)
}
fn get(&self, channel_id: u32) -> Option<&TombstoneInfo> {
self.map.get(&channel_id)
}
fn insert(&mut self, channel_id: u32, info: TombstoneInfo) {
let existed = self.map.insert(channel_id, info).is_some();
if !existed {
self.order.push_back(channel_id);
}
while self.order.len() > self.max {
if let Some(evicted) = self.order.pop_front() {
self.map.remove(&evicted);
}
}
}
}
/// Final state remembered for a channel after it has been removed from the live map.
#[derive(Debug, Clone, Copy)]
struct TombstoneInfo {
/// Lifecycle reported for the tombstoned channel (the session always records `Closed`).
lifecycle: ChannelLifecycle,
/// Whether the channel ended through cancellation rather than a normal close.
cancelled: bool,
}
/// Half-close state machine for a channel, driven by EOS in each direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChannelLifecycle {
/// Neither side has signalled EOS; both sending and receiving are allowed.
#[default]
Open,
/// The local side has sent EOS; only receiving is still allowed.
HalfClosedLocal,
/// The remote side has sent EOS; only sending is still allowed.
HalfClosedRemote,
/// Both sides have signalled EOS; no traffic is allowed.
Closed,
}
/// Per-channel bookkeeping kept by a session.
#[derive(Debug, Clone)]
pub struct ChannelState {
/// Current half-close state of the channel.
pub lifecycle: ChannelLifecycle,
/// Remaining send budget; each sent DATA frame subtracts its payload length.
pub send_credits: u32,
/// Whether the channel has been cancelled; a cancelled channel can neither send nor receive.
pub cancelled: bool,
/// Number of DATA frames sent on this channel while it was tracked.
pub frames_sent: u64,
/// Number of DATA frames received on this channel while it was tracked.
pub frames_received: u64,
}
impl Default for ChannelState {
fn default() -> Self {
Self {
lifecycle: ChannelLifecycle::Open,
send_credits: DEFAULT_INITIAL_CREDITS,
cancelled: false,
frames_sent: 0,
frames_received: 0,
}
}
}
impl ChannelState {
    /// Returns `true` while the local side may still send: the channel is not
    /// cancelled and the local direction has not been closed by EOS.
    pub fn can_send(&self) -> bool {
        if self.cancelled {
            return false;
        }
        match self.lifecycle {
            ChannelLifecycle::Open | ChannelLifecycle::HalfClosedRemote => true,
            ChannelLifecycle::HalfClosedLocal | ChannelLifecycle::Closed => false,
        }
    }

    /// Returns `true` while frames from the peer are still accepted: the
    /// channel is not cancelled and the remote direction has not been closed
    /// by EOS.
    pub fn can_receive(&self) -> bool {
        if self.cancelled {
            return false;
        }
        match self.lifecycle {
            ChannelLifecycle::Open | ChannelLifecycle::HalfClosedLocal => true,
            ChannelLifecycle::HalfClosedRemote | ChannelLifecycle::Closed => false,
        }
    }

    /// Records that the local side sent EOS.
    ///
    /// `Open` becomes `HalfClosedLocal` and `HalfClosedRemote` becomes
    /// `Closed`; states where the local side is already closed are unchanged.
    pub fn mark_local_eos(&mut self) {
        match self.lifecycle {
            ChannelLifecycle::Open => self.lifecycle = ChannelLifecycle::HalfClosedLocal,
            ChannelLifecycle::HalfClosedRemote => self.lifecycle = ChannelLifecycle::Closed,
            ChannelLifecycle::HalfClosedLocal | ChannelLifecycle::Closed => {}
        }
    }

    /// Records that the remote side sent EOS.
    ///
    /// `Open` becomes `HalfClosedRemote` and `HalfClosedLocal` becomes
    /// `Closed`; states where the remote side is already closed are unchanged.
    pub fn mark_remote_eos(&mut self) {
        match self.lifecycle {
            ChannelLifecycle::Open => self.lifecycle = ChannelLifecycle::HalfClosedRemote,
            ChannelLifecycle::HalfClosedLocal => self.lifecycle = ChannelLifecycle::Closed,
            ChannelLifecycle::HalfClosedRemote | ChannelLifecycle::Closed => {}
        }
    }
}
/// Wraps a transport with per-channel bookkeeping: send credits, half-close
/// lifecycle, cancellation, and tombstones for channels that are finished.
pub struct Session {
/// Underlying transport used to send and receive frames.
transport: Arc<Transport>,
/// State of channels that are currently tracked, keyed by channel id.
channels: Mutex<HashMap<u32, ChannelState>>,
/// Bounded record of channel ids that were closed or cancelled; frames for these are dropped.
tombstones: Mutex<Tombstones>,
}
impl Session {
pub fn new(transport: Transport) -> Self {
Self {
transport: Arc::new(transport),
channels: Mutex::new(HashMap::new()),
tombstones: Mutex::new(Tombstones::new(max_tombstones())),
}
}
/// Borrows the underlying transport.
pub fn transport(&self) -> &Transport {
    self.transport.as_ref()
}
/// Reports whether `channel_id` has a tombstone. The tombstone lock is held
/// only for the duration of the lookup.
fn is_tombstoned(&self, channel_id: u32) -> bool {
    let tombstones = self.tombstones.lock();
    tombstones.contains(channel_id)
}
/// Records `channel_id` as closed (not cancelled) in the tombstone set.
fn tombstone(&self, channel_id: u32) {
    let info = TombstoneInfo {
        cancelled: false,
        lifecycle: ChannelLifecycle::Closed,
    };
    let mut tombstones = self.tombstones.lock();
    tombstones.insert(channel_id, info);
}
/// Records `channel_id` as closed *and* cancelled in the tombstone set.
fn tombstone_cancelled(&self, channel_id: u32) {
    let info = TombstoneInfo {
        cancelled: true,
        lifecycle: ChannelLifecycle::Closed,
    };
    let mut tombstones = self.tombstones.lock();
    tombstones.insert(channel_id, info);
}
/// Sends `frame` to the peer, applying channel bookkeeping to DATA frames first.
///
/// Frames on channel 0 and frames without the DATA flag are handed to the
/// transport untouched. For DATA frames on any other channel the session:
///
/// * silently drops the frame (returns `Ok(())`) if the channel is tombstoned
///   or can no longer send;
/// * starts tracking the channel if needed, unless the tracked-channel cap is
///   reached;
/// * deducts the payload length from the channel's send credits;
/// * applies local EOS and tombstones the channel once it is fully closed.
///
/// # Errors
///
/// * `RpcError::Status` with `ErrorCode::ResourceExhausted` when a new channel
///   would exceed the tracked-channel cap, or when the payload is larger than
///   the remaining send credits.
/// * `RpcError::Transport` when the underlying transport fails to send.
pub async fn send_frame(&self, frame: Frame) -> Result<(), RpcError> {
    let channel_id = frame.desc.channel_id;
    let tracked_data = channel_id != 0 && frame.desc.flags.contains(FrameFlags::DATA);

    if tracked_data {
        let payload_len = frame.desc.payload_len;
        let has_eos = frame.desc.flags.contains(FrameFlags::EOS);
        // All lock-protected bookkeeping happens synchronously here, so no
        // mutex guard is alive across the `.await` below.
        let forward = self.account_outgoing_data(channel_id, payload_len, has_eos)?;
        if !forward {
            return Ok(());
        }
    }

    match self.transport.send_frame(frame).await {
        Ok(()) => Ok(()),
        Err(err) => Err(RpcError::Transport(err)),
    }
}

/// Performs the bookkeeping for one outgoing DATA frame on a non-zero channel.
///
/// Returns `Ok(true)` when the frame should be handed to the transport,
/// `Ok(false)` when it must be silently dropped, and `Err` when it is refused
/// (tracked-channel cap reached or not enough send credits).
fn account_outgoing_data(
    &self,
    channel_id: u32,
    payload_len: u32,
    has_eos: bool,
) -> Result<bool, RpcError> {
    if self.is_tombstoned(channel_id) {
        // Make sure no live entry lingers for a channel that is already finished.
        self.channels.lock().remove(&channel_id);
        tracing::trace!(channel_id, "dropping send on tombstoned channel");
        return Ok(false);
    }

    // `Some(cancelled)` when this frame finished the channel and it must be tombstoned.
    let finished = {
        let mut channels = self.channels.lock();

        let is_new = !channels.contains_key(&channel_id);
        if is_new && channels.len() >= max_tracked_channels() {
            tracing::warn!(
                channel_id,
                tracked_channels = channels.len(),
                max_tracked_channels = max_tracked_channels(),
                "too many tracked channels; refusing send"
            );
            return Err(RpcError::Status {
                code: ErrorCode::ResourceExhausted,
                message: "too many tracked channels".into(),
            });
        }

        let state = channels.entry(channel_id).or_default();
        if !state.can_send() {
            return Ok(false);
        }
        if payload_len > state.send_credits {
            return Err(RpcError::Status {
                code: ErrorCode::ResourceExhausted,
                message: format!(
                    "insufficient credits: need {}, have {}",
                    payload_len, state.send_credits
                ),
            });
        }

        state.send_credits -= payload_len;
        state.frames_sent += 1;
        if has_eos {
            state.mark_local_eos();
        }

        let cancelled = state.cancelled;
        if cancelled || state.lifecycle == ChannelLifecycle::Closed {
            channels.remove(&channel_id);
            Some(cancelled)
        } else {
            None
        }
    };

    // The channel lock is released before the tombstone lock is taken.
    match finished {
        Some(true) => self.tombstone_cancelled(channel_id),
        Some(false) => self.tombstone(channel_id),
        None => {}
    }
    Ok(true)
}
/// Receives the next frame that should be delivered to the caller.
///
/// Control frames on channel 0 are processed (cancellation, credit grants)
/// and then returned. Every other frame goes through channel bookkeeping;
/// frames for tombstoned channels, for channels that can no longer receive,
/// or for new channels beyond the tracked-channel cap are skipped and the
/// loop waits for the next frame.
///
/// # Errors
///
/// Returns the transport's error when receiving fails.
pub async fn recv_frame(&self) -> Result<Frame, TransportError> {
    loop {
        let frame = self.transport.recv_frame().await?;
        let channel_id = frame.desc.channel_id;

        if channel_id == 0 && frame.desc.flags.contains(FrameFlags::CONTROL) {
            self.process_control_frame(&frame);
            return Ok(frame);
        }

        let has_eos = frame.desc.flags.contains(FrameFlags::EOS);
        let is_data = frame.desc.flags.contains(FrameFlags::DATA);
        // Bookkeeping is synchronous so no mutex guard lives across an `.await`.
        if self.admit_incoming_frame(channel_id, is_data, has_eos) {
            return Ok(frame);
        }
    }
}

/// Updates channel state for one incoming non-control frame and decides
/// whether it is delivered (`true`) or dropped (`false`).
fn admit_incoming_frame(&self, channel_id: u32, is_data: bool, has_eos: bool) -> bool {
    if channel_id != 0 && self.is_tombstoned(channel_id) {
        // Make sure no live entry lingers for a channel that is already finished.
        self.channels.lock().remove(&channel_id);
        tracing::trace!(channel_id, "dropping recv on tombstoned channel");
        return false;
    }

    // `tombstone_as` is `Some(cancelled)` when the channel must be tombstoned.
    let (deliver, tombstone_as) = {
        let mut channels = self.channels.lock();

        if !channels.contains_key(&channel_id) && channels.len() >= max_tracked_channels() {
            tracing::warn!(
                channel_id,
                tracked_channels = channels.len(),
                max_tracked_channels = max_tracked_channels(),
                "too many tracked channels; dropping incoming frame"
            );
            // Tombstone the refused channel so later frames for it are dropped early.
            (false, Some(false))
        } else {
            let state = channels.entry(channel_id).or_default();
            if !state.can_receive() {
                (false, None)
            } else {
                if is_data {
                    state.frames_received += 1;
                }
                if has_eos {
                    state.mark_remote_eos();
                }

                let cancelled = state.cancelled;
                if cancelled || state.lifecycle == ChannelLifecycle::Closed {
                    channels.remove(&channel_id);
                    (true, Some(cancelled))
                } else {
                    (true, None)
                }
            }
        }
    };

    // The channel lock is released before the tombstone lock is taken.
    match tombstone_as {
        Some(true) => self.tombstone_cancelled(channel_id),
        Some(false) => self.tombstone(channel_id),
        None => {}
    }
    deliver
}
/// Reports whether the deadline carried by `desc` has passed.
///
/// A descriptor whose `deadline_ns` equals `NO_DEADLINE` never expires.
/// Otherwise the deadline is compared against the process-local clock
/// returned by `now_ns()`; the clock is only read when a deadline is set.
pub fn is_deadline_exceeded(&self, desc: &MsgDescHot) -> bool {
    let deadline = desc.deadline_ns;
    deadline != NO_DEADLINE && now_ns() > deadline
}
pub fn grant_credits(&self, channel_id: u32, bytes: u32) {
let mut channels = self.channels.lock();
let state = channels.entry(channel_id).or_default();
state.send_credits = state.send_credits.saturating_add(bytes);
}
/// Cancels `channel_id`: its live state is discarded and a cancelled
/// tombstone is recorded so later frames for it are dropped.
///
/// Channel 0 cannot be cancelled; the call is a no-op for it.
pub fn cancel_channel(&self, channel_id: u32) {
    if channel_id != 0 {
        // The channel lock is released at the end of this statement, before
        // the tombstone lock is taken.
        self.channels.lock().remove(&channel_id);
        self.tombstone_cancelled(channel_id);
    }
}
/// Reports whether `channel_id` is cancelled.
///
/// A tracked channel answers from its live state; otherwise the tombstone
/// (if any) is consulted. Unknown channels are not cancelled.
pub fn is_cancelled(&self, channel_id: u32) -> bool {
    // The channel lock is released at the end of this statement.
    let live = self.channels.lock().get(&channel_id).map(|state| state.cancelled);
    match live {
        Some(flag) => flag,
        None => self
            .tombstones
            .lock()
            .get(channel_id)
            .is_some_and(|tomb| tomb.cancelled),
    }
}
/// Returns the send credits currently available on `channel_id`.
///
/// A tracked channel reports its live balance. An untracked, tombstoned,
/// non-zero channel reports `0`. Any other channel reports
/// `DEFAULT_INITIAL_CREDITS`, the balance it would start with.
pub fn get_credits(&self, channel_id: u32) -> u32 {
    // The channel lock is released at the end of this statement.
    let live = self.channels.lock().get(&channel_id).map(|state| state.send_credits);
    match live {
        Some(credits) => credits,
        None if channel_id != 0 && self.is_tombstoned(channel_id) => 0,
        None => DEFAULT_INITIAL_CREDITS,
    }
}
/// Returns the lifecycle of `channel_id`.
///
/// A tracked channel answers from its live state; otherwise the tombstone
/// (if any) is consulted. Unknown channels are reported as `Open`.
pub fn get_lifecycle(&self, channel_id: u32) -> ChannelLifecycle {
    // The channel lock is released at the end of this statement.
    let live = self.channels.lock().get(&channel_id).map(|state| state.lifecycle);
    if let Some(lifecycle) = live {
        return lifecycle;
    }

    let buried = self.tombstones.lock().get(channel_id).map(|tomb| tomb.lifecycle);
    buried.unwrap_or(ChannelLifecycle::Open)
}
pub fn get_channel_state(&self, channel_id: u32) -> ChannelState {
let channels = self.channels.lock();
if let Some(state) = channels.get(&channel_id) {
return state.clone();
}
drop(channels);
if let Some(t) = self.tombstones.lock().get(channel_id).copied() {
return ChannelState {
lifecycle: t.lifecycle,
send_credits: 0,
cancelled: t.cancelled,
frames_sent: 0,
frames_received: 0,
};
}
ChannelState::default()
}
/// Reports whether `channel_id` is fully closed, according to `get_lifecycle`.
pub fn is_closed(&self, channel_id: u32) -> bool {
    matches!(self.get_lifecycle(channel_id), ChannelLifecycle::Closed)
}
/// Applies the session-level effect of a control frame.
///
/// Only `CANCEL_CHANNEL` and `GRANT_CREDITS` are acted upon, and only when the
/// payload decodes to the variant matching the frame's method id. Other
/// methods, undecodable payloads, and mismatched payloads are ignored.
fn process_control_frame(&self, frame: &Frame) {
    let method = frame.desc.method_id;
    let handled = matches!(
        method,
        control_method::CANCEL_CHANNEL | control_method::GRANT_CREDITS
    );
    if !handled {
        return;
    }

    let Ok(payload) = facet_postcard::from_slice::<ControlPayload>(frame.payload_bytes()) else {
        return;
    };

    match (method, payload) {
        (control_method::CANCEL_CHANNEL, ControlPayload::CancelChannel { channel_id, .. }) => {
            self.cancel_channel(channel_id);
        }
        (control_method::GRANT_CREDITS, ControlPayload::GrantCredits { channel_id, bytes }) => {
            self.grant_credits(channel_id, bytes);
        }
        // Payload variant does not match the method id: ignore the frame.
        _ => {}
    }
}
/// Closes the underlying transport.
pub fn close(&self) {
    self.transport.close();
}
}
/// Returns nanoseconds elapsed since the first call to this function in the
/// current process.
///
/// The reference instant is captured lazily on first use. The `u128`
/// nanosecond count is clamped to `u64::MAX` rather than truncated with a
/// bare `as` cast, so the value can never wrap around.
fn now_ns() -> u64 {
    use web_time::Instant;
    static START: std::sync::OnceLock<Instant> = std::sync::OnceLock::new();
    let elapsed = START.get_or_init(Instant::now).elapsed();
    // After clamping, the cast below is lossless.
    elapsed.as_nanos().min(u128::from(u64::MAX)) as u64
}
// Tests for channel pruning, tombstoning, and dropping of late frames,
// exercised over an in-memory transport pair.
#[cfg(test)]
mod tests {
use super::*;
use rapace_core::Transport;
use rapace_core::control_method;
// Builds an empty-payload frame on `channel_id` carrying both DATA and EOS flags.
fn data_eos_frame(channel_id: u32) -> Frame {
let mut desc = MsgDescHot::new();
desc.channel_id = channel_id;
desc.flags = FrameFlags::DATA | FrameFlags::EOS;
Frame::with_inline_payload(desc, &[]).expect("empty payload should fit inline")
}
// Builds a PING control frame on channel 0 with an 8-byte zero payload.
fn ping_frame() -> Frame {
let mut desc = MsgDescHot::new();
desc.channel_id = 0;
desc.method_id = control_method::PING;
desc.flags = FrameFlags::CONTROL;
Frame::with_inline_payload(desc, &[0u8; 8]).expect("ping payload should fit inline")
}
// Builds a CANCEL_CHANNEL control frame (channel 0) whose payload targets `channel_id`.
fn cancel_frame(channel_id: u32) -> Frame {
let payload = ControlPayload::CancelChannel {
channel_id,
reason: rapace_core::CancelReason::ClientCancel,
};
let bytes = facet_postcard::to_vec(&payload).expect("cancel payload should serialize");
let mut desc = MsgDescHot::new();
desc.channel_id = 0;
desc.method_id = control_method::CANCEL_CHANNEL;
desc.flags = FrameFlags::CONTROL;
Frame::with_inline_payload(desc, &bytes).expect("cancel payload should fit inline")
}
// EOS in both directions must remove the channel from the live map and tombstone it.
#[tokio::test]
async fn test_closed_channel_is_pruned_and_tombstoned() {
let (a, b) = Transport::mem_pair();
let session = Session::new(a);
// Local EOS only: the channel is half-closed and still tracked.
session.send_frame(data_eos_frame(2)).await.unwrap();
assert_eq!(session.channels.lock().len(), 1);
// Remote EOS completes the close.
b.send_frame(data_eos_frame(2)).await.unwrap();
let _ = session.recv_frame().await.unwrap();
assert_eq!(session.channels.lock().len(), 0);
assert!(session.tombstones.lock().contains(2));
}
// A DATA frame arriving after the channel closed is skipped by `recv_frame`.
#[tokio::test]
async fn test_late_frames_on_closed_channel_are_dropped() {
let (a, b) = Transport::mem_pair();
let session = Session::new(a);
session.send_frame(data_eos_frame(2)).await.unwrap();
b.send_frame(data_eos_frame(2)).await.unwrap();
let _ = session.recv_frame().await.unwrap();
assert!(session.tombstones.lock().contains(2));
let mut late_desc = MsgDescHot::new();
late_desc.channel_id = 2;
late_desc.flags = FrameFlags::DATA;
let late = Frame::with_inline_payload(late_desc, &[1, 2, 3]).unwrap();
b.send_frame(late).await.unwrap();
// The ping follows the late frame; receiving it proves the late frame was dropped.
b.send_frame(ping_frame()).await.unwrap();
let got = session.recv_frame().await.unwrap();
assert_eq!(got.desc.channel_id, 0);
assert_eq!(got.desc.method_id, control_method::PING);
}
// A CANCEL_CHANNEL control frame marks the channel cancelled and later frames are dropped.
#[tokio::test]
async fn test_cancelled_channel_is_tombstoned_and_drops_late_frames() {
let (a, b) = Transport::mem_pair();
let session = Session::new(a);
b.send_frame(cancel_frame(2)).await.unwrap();
// The control frame itself is still returned to the caller.
let got = session.recv_frame().await.unwrap();
assert_eq!(got.desc.channel_id, 0);
assert_eq!(got.desc.method_id, control_method::CANCEL_CHANNEL);
assert!(session.is_cancelled(2));
let mut late_desc = MsgDescHot::new();
late_desc.channel_id = 2;
late_desc.flags = FrameFlags::DATA;
let late = Frame::with_inline_payload(late_desc, &[1, 2, 3]).unwrap();
b.send_frame(late).await.unwrap();
// The ping follows the late frame; receiving it proves the late frame was dropped.
b.send_frame(ping_frame()).await.unwrap();
let got = session.recv_frame().await.unwrap();
assert_eq!(got.desc.channel_id, 0);
assert_eq!(got.desc.method_id, control_method::PING);
}
}