use crate::frame::{FrameDecodeError, FrameEncodeError};
use crate::rpc::rpc_internals::{
RpcHeader, RpcSession, RpcStreamEncoder, RpcStreamEvent,
rpc_trait::{RpcEmit, RpcResponseHandler},
};
use std::collections::HashMap;
impl<'a> Default for RpcRespondableSession<'a> {
fn default() -> Self {
Self::new()
}
}
/// Boxed callback that receives the [`RpcStreamEvent`]s routed to it.
type BoxedStreamEventHandler<'a> = Box<dyn FnMut(RpcStreamEvent) + Send + 'a>;

/// An [`RpcSession`] wrapper that routes decoded stream events to callbacks.
///
/// Events carrying a request id are offered to the handler registered for
/// that id; a single optional catch-all handler receives events that no
/// per-request handler consumed.
pub struct RpcRespondableSession<'a> {
    /// Underlying session that encodes outgoing streams and decodes incoming
    /// bytes into events.
    rpc_session: RpcSession,
    /// Per-request callbacks, keyed by `rpc_request_id`.
    pub(crate) response_handlers: HashMap<u32, BoxedStreamEventHandler<'a>>,
    /// Fallback callback for events not consumed by a per-request handler.
    catch_all_response_handler: Option<BoxedStreamEventHandler<'a>>,
    /// Payload bytes accumulated per request id while prebuffering is active.
    prebuffered_responses: HashMap<u32, Vec<u8>>,
    /// Whether prebuffering was requested, keyed by `rpc_request_id`; a
    /// missing entry is treated as `false`.
    prebuffering_flags: HashMap<u32, bool>,
}
impl<'a> RpcRespondableSession<'a> {
pub fn new() -> Self {
Self {
rpc_session: RpcSession::new(),
response_handlers: HashMap::new(),
catch_all_response_handler: None,
prebuffered_responses: HashMap::new(),
prebuffering_flags: HashMap::new(),
}
}
pub fn init_respondable_request<E, R>(
&mut self,
hdr: RpcHeader,
max_chunk_size: usize,
on_emit: E,
on_response: Option<R>,
prebuffer_response: bool,
) -> Result<RpcStreamEncoder<E>, FrameEncodeError>
where
E: RpcEmit,
R: RpcResponseHandler + 'a,
{
let rpc_request_id = hdr.rpc_request_id;
self.prebuffering_flags
.insert(rpc_request_id, prebuffer_response);
if let Some(on_response) = on_response {
self.response_handlers
.insert(rpc_request_id, Box::new(on_response));
}
self.rpc_session
.init_request(hdr, max_chunk_size, on_emit)
.map_err(|_| FrameEncodeError::CorruptFrame)
}
pub fn start_reply_stream<E>(
&mut self,
hdr: RpcHeader,
max_chunk_size: usize,
on_emit: E,
) -> Result<RpcStreamEncoder<E>, FrameEncodeError>
where
E: RpcEmit,
{
self.rpc_session
.init_request(hdr, max_chunk_size, on_emit)
.map_err(|_| FrameEncodeError::CorruptFrame)
}
pub fn set_catch_all_response_handler<R>(&mut self, handler: R)
where
R: RpcResponseHandler + 'a,
{
self.catch_all_response_handler = Some(Box::new(handler));
}
pub fn read_bytes(&mut self, bytes: &[u8]) -> Result<(), FrameDecodeError> {
self.rpc_session.read_bytes(bytes, |evt| {
let id = match &evt {
RpcStreamEvent::Header { rpc_request_id, .. } => Some(*rpc_request_id),
RpcStreamEvent::PayloadChunk { rpc_request_id, .. } => Some(*rpc_request_id),
RpcStreamEvent::End { rpc_request_id, .. } => Some(*rpc_request_id),
RpcStreamEvent::Error { rpc_request_id, .. } => *rpc_request_id,
};
let method_id = match &evt {
RpcStreamEvent::Header { rpc_method_id, .. } => Some(*rpc_method_id),
RpcStreamEvent::PayloadChunk { rpc_method_id, .. } => Some(*rpc_method_id),
RpcStreamEvent::End { rpc_method_id, .. } => Some(*rpc_method_id),
RpcStreamEvent::Error { rpc_method_id, .. } => *rpc_method_id,
};
let mut handled = false;
if let Some(rpc_id) = id {
let is_prebuffering_response =
self.prebuffering_flags.get(&rpc_id).unwrap_or(&false);
if *is_prebuffering_response {
let buffer = self.prebuffered_responses.entry(rpc_id).or_default();
match &evt {
RpcStreamEvent::Header { .. } => {
if let Some(cb) = self.response_handlers.get_mut(&rpc_id) {
cb(evt.clone());
}
}
RpcStreamEvent::PayloadChunk { bytes, .. } => {
buffer.extend_from_slice(bytes);
}
RpcStreamEvent::End { rpc_header, .. } => {
if let Some(cb) = self.response_handlers.get_mut(&rpc_id) {
let rpc_method_id =
method_id.ok_or(FrameDecodeError::CorruptFrame)?;
let rpc_payload_event = RpcStreamEvent::PayloadChunk {
rpc_request_id: rpc_id,
rpc_method_id,
bytes: buffer.clone(),
rpc_header: rpc_header.clone(),
};
cb(rpc_payload_event);
cb(evt.clone());
self.prebuffered_responses.remove(&rpc_id); }
}
_ => {
tracing::error!("Unknown `RpcStreamEvent`");
}
};
} else if let Some(cb) = self.response_handlers.get_mut(&rpc_id) {
cb(evt.clone());
handled = true;
}
if matches!(
evt,
RpcStreamEvent::End { .. } | RpcStreamEvent::Error { .. }
) {
self.response_handlers.remove(&rpc_id);
}
}
if !handled && let Some(cb) = self.catch_all_response_handler.as_mut() {
cb(evt);
}
Ok(())
})?;
Ok(())
}
pub fn get_remaining_response_handlers(&self) -> usize {
self.response_handlers.len()
}
}