#![allow(clippy::map_entry)]
use std::collections::{BTreeMap, HashSet};
use std::fmt::Debug;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Instant;
use actix::prelude::io::WriteHandler;
use actix::prelude::*;
use futures::channel::oneshot;
use futures::future::LocalBoxFuture;
use futures::prelude::*;
use futures::FutureExt;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::{FramedRead, FramedWrite};
use ya_sb_proto::codec::{GsbMessage, GsbMessageDecoder, GsbMessageEncoder, ProtocolError};
use ya_sb_proto::*;
use ya_sb_util::writer;
use ya_sb_util::writer::EmptyBufferHandler;
use crate::connection::reader::InputHandler;
use crate::router::{IdBytes, InstanceConfig, RouterRef};
mod reader;
/// Framed writer over `Output` that uses [`GsbMessageEncoder`] as its codec.
pub type StreamWriter<Output> = FramedWrite<Output, GsbMessageEncoder>;
/// Actor message that forces a [`Connection`] to drop: the handler in this file
/// releases the connection's router registrations and stops the actor.
#[derive(Message)]
#[rtype("()")]
pub struct DropConnection;
/// Actor message carrying a [`CallReply`] that the receiving [`Connection`]
/// should write to its own peer.
///
/// The result is `Err(oneshot::Canceled)` when the reply was queued and the
/// queue entry was dropped before being written.
#[derive(Message)]
#[rtype("Result<(), oneshot::Canceled>")]
pub struct ForwardCallResponse {
/// Reply to pass on to the peer.
call_reply: CallReply,
}
/// Actor message carrying a [`CallRequest`] that the receiving [`Connection`]
/// should write to its own peer.
///
/// The result is `Err(oneshot::Canceled)` when the request was queued and the
/// queue entry was dropped before being written.
#[derive(Message)]
#[rtype("Result<(), oneshot::Canceled>")]
pub struct ForwardCallRequest {
/// Request to pass on to the peer.
call_request: CallRequest,
/// Recipient that replies to this request are forwarded to; it is remembered
/// by the receiving connection unless the request has `no_reply` set.
reply_to: Recipient<ForwardCallResponse>,
}
/// Actor that owns one peer connection to the router.
///
/// It reads decoded [`GsbMessage`]s from the peer, resolves destinations
/// through the shared router, and writes outgoing messages to the sink `W`.
/// `ConnInfo` is only used as a `Debug` prefix in log messages.
pub struct Connection<
W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
ConnInfo: Debug + Unpin + 'static,
> {
/// Shared instance configuration (ping interval/timeout, forward timeout,
/// high buffer mark, hello message).
config: Arc<InstanceConfig>,
/// Peer instance id taken from its `Hello`; `None` before the hello is
/// received and again after `cleanup` has run.
instance_id: Option<IdBytes>,
/// Handle to the shared router state (accessed via `read()` / `write()`).
router: RouterRef<W, ConnInfo>,
/// Service ids this connection has successfully registered with the router.
services: HashSet<String>,
/// Outgoing message writer.
output: writer::SinkWrite<GsbMessage, W>,
/// Maps a forwarded request id to the recipient that should get its replies.
reply_map: BTreeMap<String, Recipient<ForwardCallResponse>>,
/// Messages held back while the output buffer is at or above the high mark,
/// each paired with a sender used to notify the waiter once it is written.
hold_queue: Vec<(GsbMessage, oneshot::Sender<()>)>,
/// Maps a subscribed topic to the handle of the spawned forwarding future.
topic_map: BTreeMap<String, SpawnHandle>,
/// Connection description used as the prefix of log messages.
conn_info: ConnInfo,
/// Time the last item was received from the peer.
last_packet: Instant,
}
impl<W, ConnInfo> Actor for Connection<W, ConnInfo>
where
    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
    ConnInfo: Debug + Unpin + 'static,
{
    type Context = Context<Self>;

    /// Installs the periodic keep-alive / housekeeping task.
    ///
    /// Every `ping_interval` the task:
    /// * stops the actor when nothing was received for longer than
    ///   `ping_timeout`,
    /// * otherwise sends a `Ping` when the peer was silent for more than half
    ///   of `ping_interval`,
    /// * drops `reply_map` entries whose recipient is no longer connected.
    fn started(&mut self, ctx: &mut Self::Context) {
        log::debug!("[{:?}] connection started", self.conn_info);
        let _ = ctx.run_interval(self.config.ping_interval(), move |act, ctx| {
            let idle = Instant::now().duration_since(act.last_packet);
            if idle > act.config.ping_interval() / 2 {
                if idle > act.config.ping_timeout() {
                    log::warn!(
                        "[{:?}] no data for {:?} killing connection",
                        act.conn_info,
                        idle
                    );
                    ctx.stop();
                    return;
                }
                log::debug!(
                    "[{:?}] no data for: {:?}, sending ping (buffer={})",
                    act.conn_info,
                    idle,
                    act.output.buffer_len()
                );
                act.output.write(GsbMessage::Ping(Default::default()));
            }

            // Collect the ids first: the map cannot be mutated while it is iterated.
            let stale_ids: Vec<String> = act
                .reply_map
                .iter()
                .filter(|(_, recipient)| !recipient.connected())
                .map(|(request_id, _)| request_id.clone())
                .collect();
            for request_id in stale_ids {
                let _ = act.reply_map.remove(&request_id);
                log::debug!(
                    "[{:?}] removing dead reply map for {}",
                    act.conn_info,
                    request_id
                );
            }
        });
    }

    /// Releases router registrations once the actor has stopped.
    fn stopped(&mut self, ctx: &mut Self::Context) {
        self.cleanup(ctx);
    }
}
impl<W, ConnInfo> Connection<W, ConnInfo>
where
    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
    ConnInfo: Debug + Unpin + 'static,
{
    /// Unregisters this connection's services and the connection itself from
    /// the router.
    ///
    /// Does nothing when no instance id is set (no `Hello` received yet, or
    /// cleanup already ran), so calling it repeatedly is harmless.
    fn cleanup(&mut self, ctx: &mut <Self as Actor>::Context) {
        let instance_id = match self.instance_id.take() {
            Some(id) => id,
            None => return,
        };
        log::trace!("[{:?}] cleanup connection", self.conn_info);
        let me = ctx.address();
        let mut router = self.router.write();
        self.services.drain().for_each(|service_id| {
            router.unregister_service(&service_id, &me);
        });
        router.remove_connection(instance_id, &me);
    }

    /// Writes `reply` to the output immediately, bypassing the hold queue.
    fn send_reply(&mut self, reply: impl Into<GsbMessage>, _ctx: &mut <Self as Actor>::Context) {
        let message: GsbMessage = reply.into();
        self.output.write(message);
        log::trace!(
            "[{:?}] reply queued. size={}",
            self.conn_info,
            self.output.buffer_len()
        );
    }

    /// Forwards a call that expects a reply to the connection resolved for
    /// its address.
    ///
    /// The returned future yields `Err(CallReply)` with an error reply to send
    /// back to the caller when the address cannot be resolved or when
    /// forwarding is canceled, the destination mailbox is closed, or the
    /// forward times out.
    fn handle_call_request(
        &mut self,
        call_request: CallRequest,
        ctx: &mut <Self as Actor>::Context,
    ) -> impl Future<Output = Result<(), CallReply>> + 'static {
        let request_id = call_request.request_id.clone();
        match { self.router.read().resolve_node(&call_request.address) } {
            Some(dst) => {
                let reply_to = ctx.address().recipient();
                dst.send(ForwardCallRequest {
                    call_request,
                    reply_to,
                })
                .timeout(self.config.forward_timeout())
                .map(move |outcome| {
                    let reason = match outcome {
                        Ok(Ok(())) => return Ok(()),
                        Ok(Err(_)) => "request canceled",
                        Err(MailboxError::Closed) => "lost connection",
                        Err(MailboxError::Timeout) => "stalled connection",
                    };
                    let mut reply = CallReply {
                        request_id,
                        data: reason.as_bytes().to_vec(),
                        ..Default::default()
                    };
                    reply.set_code(CallReplyCode::ServiceFailure);
                    reply.set_reply_type(CallReplyType::Full);
                    Err(reply)
                })
                .left_future()
            }
            None => {
                let mut reply = CallReply {
                    request_id,
                    data: "endpoint address not found".as_bytes().to_vec(),
                    ..Default::default()
                };
                reply.set_code(CallReplyCode::CallReplyBadRequest);
                reply.set_reply_type(CallReplyType::Full);
                future::err(reply).right_future()
            }
        }
    }

    /// Forwards a `no_reply` call to the connection resolved for its address.
    ///
    /// Delivery is best effort: an unresolved address or a failed forward is
    /// silently ignored.
    fn handle_push_request(
        &mut self,
        call_request: CallRequest,
        ctx: &mut <Self as Actor>::Context,
    ) -> impl Future<Output = ()> + 'static {
        if let Some(dst) = { self.router.read().resolve_node(&call_request.address) } {
            let reply_to = ctx.address().recipient();
            dst.send(ForwardCallRequest {
                call_request,
                reply_to,
            })
            .then(|_| future::ready(()))
            .left_future()
        } else {
            future::ready(()).right_future()
        }
    }

    /// Routes a reply received from the peer to the recipient remembered for
    /// its request id.
    ///
    /// A `Full` reply removes the `reply_map` entry; a `Partial` reply keeps
    /// it for further replies. Unmatched replies are logged and otherwise
    /// ignored. The future yields `Err` with a description when the recipient
    /// could not take the reply.
    fn handle_call_reply(
        &mut self,
        call_reply: CallReply,
        _ctx: &mut <Self as Actor>::Context,
    ) -> impl Future<Output = Result<(), String>> + 'static {
        let recipient = match call_reply.reply_type() {
            CallReplyType::Full => self.reply_map.remove(&call_reply.request_id),
            CallReplyType::Partial => self.reply_map.get(&call_reply.request_id).cloned(),
        };
        match recipient {
            Some(dst) => {
                let request_id = call_reply.request_id.clone();
                dst.send(ForwardCallResponse { call_reply })
                    .map(move |delivery| match delivery {
                        Ok(Ok(())) => Ok(()),
                        Ok(Err(_)) => {
                            Err(format!("unable to send reply {}, canceled", request_id))
                        }
                        Err(e) => Err(format!("unable to send reply {:?},", e)),
                    })
                    .left_future()
            }
            None => {
                log::debug!("received unmatched reply {}", call_reply.request_id);
                future::ok(()).right_future()
            }
        }
    }

    /// Writes `msg` to the output, or parks it in the hold queue when the
    /// output buffer has reached the high mark or other messages are already
    /// waiting (which keeps ordering).
    ///
    /// The returned future is ready at once for a direct write; for a parked
    /// message it resolves when the message is written, or with `Canceled`
    /// when its queue entry is dropped.
    fn send_message(
        &mut self,
        msg: GsbMessage,
        _ctx: &mut <Self as Actor>::Context,
    ) -> LocalBoxFuture<'static, Result<(), oneshot::Canceled>> {
        let must_hold = self.output.buffer_len() >= self.config.high_buffer_mark()
            || !self.hold_queue.is_empty();
        if must_hold {
            let (tx, rx) = oneshot::channel();
            self.hold_queue.push((msg, tx));
            log::trace!("[{:?}] queue {}", self.conn_info, self.hold_queue.len());
            return rx.boxed_local();
        }
        self.output.write(msg);
        log::trace!("[{:?}] buffer {}", self.conn_info, self.output.buffer_len());
        Box::pin(future::ok(()))
    }
}
pub fn connection<
Input: AsyncRead + 'static,
Output: AsyncWrite + Unpin,
ConnInfo: Debug + Unpin + 'static,
>(
config: Arc<InstanceConfig>,
router: RouterRef<StreamWriter<Output>, ConnInfo>,
conn_info: ConnInfo,
input: Input,
output: Output,
) -> Addr<Connection<StreamWriter<Output>, ConnInfo>> {
let reader = FramedRead::new(input, GsbMessageDecoder::default());
let writer = FramedWrite::new(output, GsbMessageEncoder::default());
Connection::create(move |ctx| {
let output = writer::SinkWrite::new(writer, ctx);
let _ = Connection::add_stream(reader, ctx);
Connection {
instance_id: None,
router,
config,
services: Default::default(),
hold_queue: Default::default(),
reply_map: Default::default(),
topic_map: Default::default(),
conn_info,
output,
last_packet: Instant::now(),
}
})
}
impl<
W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
ConnInfo: Debug + Unpin + 'static,
> InputHandler<Result<GsbMessage, ProtocolError>> for Connection<W, ConnInfo>
{
fn handle(
&mut self,
item: Result<GsbMessage, ProtocolError>,
ctx: &mut Context<Self>,
) -> Pin<Box<dyn ActorFuture<Self, Output = ()>>> {
self.last_packet = Instant::now();
let msg = match item {
Err(ProtocolError::Io(e)) if e.kind() == std::io::ErrorKind::ConnectionReset => {
log::debug!("[{:?}] connection closed", self.conn_info);
ctx.stop();
return Box::pin(fut::ready(()));
}
Err(e) => {
log::error!("[{:?}] protocol error {:?}", self.conn_info, e);
ctx.stop();
return Box::pin(fut::ready(()));
}
Ok(msg) => msg,
};
match msg {
GsbMessage::CallRequest(call_request) => {
if call_request.no_reply {
return Box::pin(self.handle_push_request(call_request, ctx).into_actor(self));
}
return Box::pin(
self.handle_call_request(call_request, ctx)
.into_actor(self)
.then(|r, act, ctx| {
if let Err(error_reply) = r {
act.send_reply(error_reply, ctx);
}
fut::ready(())
}),
);
}
GsbMessage::CallReply(call_reply) => {
return Box::pin(
self.handle_call_reply(call_reply, ctx)
.into_actor(self)
.then(|r, act, _ctx| {
if let Err(msg) = r {
log::warn!("[{:?}] {}", act.conn_info, msg);
}
fut::ready(())
}),
)
}
GsbMessage::RegisterRequest(register_request) => {
let me = ctx.address();
let service_id = register_request.service_id;
let registered = { self.router.write().register_service(service_id.clone(), me) };
let mut reply = RegisterReply::default();
if registered {
self.services.insert(service_id);
} else {
reply.set_code(RegisterReplyCode::RegisterConflict);
}
self.send_reply(reply, ctx);
}
GsbMessage::UnregisterRequest(unregister_request) => {
let me = ctx.address();
let service_id = unregister_request.service_id;
let unregistered = { self.router.write().unregister_service(&service_id, &me) };
let mut reply = UnregisterReply::default();
if unregistered {
self.services.remove(&service_id);
} else {
reply.set_code(UnregisterReplyCode::NotRegistered);
}
self.send_reply(reply, ctx);
}
GsbMessage::SubscribeRequest(subscribe_request) => {
let topic_id = subscribe_request.topic;
let mut reply = SubscribeReply::default();
if self.topic_map.contains_key(&topic_id) {
reply.set_code(SubscribeReplyCode::SubscribeBadRequest);
reply.message = "topic already registered".to_string();
} else {
let rx = self.router.write().subscribe_topic(topic_id.clone());
let handle = ctx.spawn(fut::wrap_stream(rx).fold(
(),
|_, request, act: &mut Self, ctx| {
log::trace!("[{:?}] broadcast new item", act.conn_info);
match request {
Ok(broadcast_request) => act.send_message(
GsbMessage::BroadcastRequest(broadcast_request),
ctx,
),
Err(e) => {
log::debug!(
"[{:?}] failed to recv broadcast: {:?}",
act.conn_info,
e
);
Box::pin(future::ok(()))
}
}
.into_actor(act)
.then(|r, act, _ctx| {
if r.is_err() {
log::warn!("[{:?}] broadcast forward dropped", act.conn_info);
} else {
log::trace!("[{:?}] broadcast forwarded", act.conn_info);
}
fut::ready(())
})
},
));
self.topic_map.insert(topic_id, handle);
}
let _ = self.send_reply(GsbMessage::SubscribeReply(SubscribeReply::default()), ctx);
}
GsbMessage::UnsubscribeRequest(unsubscribe_request) => {
let mut reply = UnsubscribeReply::default();
log::debug!(
"[{:?}] unsubscribe {}",
self.conn_info,
unsubscribe_request.topic
);
if let Some(handle) = self.topic_map.remove(&unsubscribe_request.topic) {
ctx.cancel_future(handle);
} else {
reply.set_code(UnsubscribeReplyCode::NotSubscribed);
}
self.send_reply(GsbMessage::UnsubscribeReply(reply), ctx);
}
GsbMessage::BroadcastRequest(broadcast_request) => {
let reply = BroadcastReply::default();
if let Some(sender) = { self.router.read().find_topic(&broadcast_request.topic) } {
log::debug!(
"[{:?}] sending bcast to {} receivers",
self.conn_info,
sender.receiver_count()
);
let _ = sender.send(broadcast_request);
}
self.send_reply(GsbMessage::BroadcastReply(reply), ctx);
}
GsbMessage::Hello(hello_request) => {
if self.instance_id.is_some() {
log::error!("[{:?}] duplicate hello send", self.conn_info);
ctx.stop();
} else {
let instance_id: IdBytes = hello_request.instance_id.into();
self.instance_id = Some(instance_id.clone());
log::debug!(
"[{:?}] connection initialized peer {}/{}",
self.conn_info,
hello_request.name,
hello_request.version
);
return Box::pin(
self.router
.write()
.new_connection(instance_id, ctx.address())
.into_actor(self),
);
}
}
GsbMessage::Ping(_) => {
self.send_reply(GsbMessage::Ping(Default::default()), ctx);
}
GsbMessage::Pong(_) => {
log::trace!("[{:?}] pong recv", self.conn_info);
}
m => {
log::error!("[{:?}] unexpected gsb message: {:?}", self.conn_info, m);
ctx.stop();
}
}
Box::pin(fut::ready(()))
}
fn started(&mut self, _ctx: &mut Self::Context) {
let hello = self.config.hello();
let _ = self.output.write(GsbMessage::Hello(hello));
}
}
impl<W, ConnInfo> Handler<DropConnection> for Connection<W, ConnInfo>
where
    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
    ConnInfo: Debug + Unpin + 'static,
{
    type Result = ();

    /// Releases the router registrations right away, then stops the actor.
    fn handle(&mut self, _request: DropConnection, ctx: &mut Self::Context) -> Self::Result {
        log::debug!("[{:?}] forced connection drop", self.conn_info);
        self.cleanup(ctx);
        ctx.stop();
    }
}
/// Write-error handling for the output sink; no methods are overridden, so the
/// trait's provided behavior applies.
impl<W, ConnInfo> WriteHandler<ProtocolError> for Connection<W, ConnInfo>
where
    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
    ConnInfo: Debug + Unpin + 'static,
{
}
impl<
        W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
        ConnInfo: Debug + Unpin + 'static,
    > EmptyBufferHandler for Connection<W, ConnInfo>
{
    /// Moves held messages into the output buffer.
    ///
    /// At most `high_buffer_mark` (but at least one) live messages are written
    /// per call, and each waiting sender is notified. Entries whose receiver
    /// was already dropped are discarded. Messages beyond the limit stay in
    /// the hold queue, in order, for a later call; presumably this handler
    /// runs again once the newly written batch has been flushed (TODO confirm
    /// against `ya_sb_util::writer`).
    fn buffer_empty(&mut self, _ctx: &mut Self::Context) {
        if self.hold_queue.is_empty() {
            return;
        }
        log::trace!("[{:?}] empty buffer", self.conn_info);

        // Writing at least one message guarantees progress even with a zero mark.
        let limit = self.config.high_buffer_mark().max(1);

        // Find the end of the prefix holding exactly `limit` live messages.
        // `Vec::drain` removes its whole range even when the iterator is dropped
        // early, so draining `..` and stopping after `limit` items would silently
        // discard every remaining message and cancel its sender.
        let mut live = 0usize;
        let batch_end = self
            .hold_queue
            .iter()
            .position(|(_msg, tx)| {
                if !tx.is_canceled() {
                    live += 1;
                }
                live > limit
            })
            .unwrap_or(self.hold_queue.len());

        for (msg, tx) in self
            .hold_queue
            .drain(..batch_end)
            .filter(|(_msg, tx)| !tx.is_canceled())
        {
            self.output.write(msg);
            if tx.send(()).is_err() {
                log::error!("[{:?}] failed to notify sender", self.conn_info);
            }
        }
        log::trace!(
            "[{:?}] on empty buffer, filled {}",
            self.conn_info,
            self.output.buffer_len()
        );
    }
}
impl<S, ConnInfo> Handler<ForwardCallResponse> for Connection<S, ConnInfo>
where
S: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
ConnInfo: Debug + Unpin + 'static,
{
type Result = ResponseFuture<Result<(), oneshot::Canceled>>;
fn handle(&mut self, msg: ForwardCallResponse, ctx: &mut Self::Context) -> Self::Result {
self.send_message(GsbMessage::CallReply(msg.call_reply), ctx)
}
}
impl<S, ConnInfo> Handler<ForwardCallRequest> for Connection<S, ConnInfo>
where
S: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
ConnInfo: Debug + Unpin + 'static,
{
type Result = ResponseFuture<Result<(), oneshot::Canceled>>;
fn handle(&mut self, msg: ForwardCallRequest, ctx: &mut Self::Context) -> Self::Result {
if !msg.call_request.no_reply {
if self
.reply_map
.insert(msg.call_request.request_id.clone(), msg.reply_to)
.is_some()
{
log::warn!(
"[{:?}] duplicate message request id forwarded {}",
self.conn_info,
msg.call_request.request_id
);
}
}
self.send_message(GsbMessage::CallRequest(msg.call_request), ctx)
}
}