use crate::channel::ChannelOptions;
use crate::message::OutgoingMessage;
use crate::metrics::{HandlerMetrics, ServerMetrics};
use crate::server_side_channel::ServerSideChannel;
use crate::server_side_handlers::{
Action, Assigner, CallHandlerFactory, CastHandlerFactory, HandleCall, HandleCast,
MessageHandlers,
};
use crate::{Call, Cast, Error, ProcedureId};
use bytecodec::marker::Never;
use factory::{DefaultFactory, Factory};
use fibers::net::futures::{Connected, TcpListenerBind};
use fibers::net::streams::Incoming;
use fibers::net::TcpListener;
use fibers::sync::mpsc;
use fibers::{self, BoxSpawn, Spawn};
use futures::future::{loop_fn, Either, Loop};
use futures::{self, Async, Future, Poll, Stream};
use prometrics::metrics::MetricBuilder;
use slog::{Discard, Logger};
use std::collections::HashMap;
use std::mem;
use std::net::SocketAddr;
/// Builder for an RPC [`Server`].
///
/// Collects the bind address, logger, channel options, metric builder and the registered
/// call/cast handlers; [`ServerBuilder::finish`] turns the collected configuration into a
/// `Server` future.
#[derive(Debug)]
pub struct ServerBuilder {
    /// Address the server's TCP listener will bind to.
    bind_addr: SocketAddr,
    /// Parent logger; `finish` derives a child logger tagged with the bind address.
    logger: Logger,
    /// Registered handler factories, keyed by procedure ID.
    handlers: MessageHandlers,
    /// Options handed to every server-side channel created for an accepted client.
    channel_options: ChannelOptions,
    /// Builder from which handler metrics and server metrics are created.
    metrics: MetricBuilder,
    /// Metrics of each registered handler, passed to `ServerMetrics::new` by `finish`.
    handlers_metrics: HashMap<ProcedureId, HandlerMetrics>,
}
impl ServerBuilder {
pub fn new(bind_addr: SocketAddr) -> Self {
ServerBuilder {
bind_addr,
logger: Logger::root(Discard, o!()),
handlers: MessageHandlers(HashMap::new()),
channel_options: ChannelOptions::default(),
metrics: MetricBuilder::new(),
handlers_metrics: HashMap::new(),
}
}
pub fn logger(&mut self, logger: Logger) -> &mut Self {
self.logger = logger;
self
}
pub fn channel_options(&mut self, options: ChannelOptions) -> &mut Self {
self.channel_options = options;
self
}
pub fn metrics(&mut self, builder: MetricBuilder) -> &mut Self {
self.metrics = builder;
self
}
pub fn add_call_handler<T, H>(&mut self, handler: H) -> &mut Self
where
T: Call,
H: HandleCall<T>,
T::ReqDecoder: Default,
T::ResEncoder: Default,
{
self.add_call_handler_with_codec(handler, DefaultFactory::new(), DefaultFactory::new())
}
pub fn add_call_handler_with_decoder<T, H, D>(
&mut self,
handler: H,
decoder_factory: D,
) -> &mut Self
where
T: Call,
H: HandleCall<T>,
D: Factory<Item = T::ReqDecoder> + Send + Sync + 'static,
T::ResEncoder: Default,
{
self.add_call_handler_with_codec(handler, decoder_factory, DefaultFactory::new())
}
pub fn add_call_handler_with_encoder<T, H, E>(
&mut self,
handler: H,
encoder_factory: E,
) -> &mut Self
where
T: Call,
H: HandleCall<T>,
E: Factory<Item = T::ResEncoder> + Send + Sync + 'static,
T::ReqDecoder: Default,
{
self.add_call_handler_with_codec(handler, DefaultFactory::new(), encoder_factory)
}
pub fn add_call_handler_with_codec<T, H, D, E>(
&mut self,
handler: H,
decoder_factory: D,
encoder_factory: E,
) -> &mut Self
where
T: Call,
H: HandleCall<T>,
D: Factory<Item = T::ReqDecoder> + Send + Sync + 'static,
E: Factory<Item = T::ResEncoder> + Send + Sync + 'static,
{
assert!(
!self.handlers.0.contains_key(&T::ID),
"RPC registration conflicts: procedure={:?}, name={:?}",
T::ID,
T::NAME
);
let metrics = HandlerMetrics::new(self.metrics.clone(), T::ID, T::NAME, "call");
self.handlers_metrics.insert(T::ID, metrics.clone());
let handler = CallHandlerFactory::new(handler, decoder_factory, encoder_factory, metrics);
self.handlers.0.insert(T::ID, Box::new(handler));
self
}
pub fn add_cast_handler<T, H>(&mut self, handler: H) -> &mut Self
where
T: Cast,
H: HandleCast<T>,
T::Decoder: Default,
{
self.add_cast_handler_with_decoder(handler, DefaultFactory::new())
}
pub fn add_cast_handler_with_decoder<T, H, D>(
&mut self,
handler: H,
decoder_factory: D,
) -> &mut Self
where
T: Cast,
H: HandleCast<T>,
D: Factory<Item = T::Decoder> + Send + Sync + 'static,
{
assert!(
!self.handlers.0.contains_key(&T::ID),
"RPC registration conflicts: procedure={:?}, name={:?}",
T::ID,
T::NAME
);
let metrics = HandlerMetrics::new(self.metrics.clone(), T::ID, T::NAME, "cast");
self.handlers_metrics.insert(T::ID, metrics.clone());
let handler = CastHandlerFactory::new(handler, decoder_factory, metrics);
self.handlers.0.insert(T::ID, Box::new(handler));
self
}
pub fn finish<S>(mut self, spawner: S) -> Server<S>
where
S: Clone + Spawn + Send + 'static,
{
let logger = self.logger.new(o!("server" => self.bind_addr.to_string()));
info!(logger, "Starts RPC server");
let handlers = mem::replace(&mut self.handlers, MessageHandlers(HashMap::new()));
Server {
listener: Listener::Binding(TcpListener::bind(self.bind_addr)),
logger,
spawner,
assigner: Assigner::new(handlers),
channel_options: self.channel_options.clone(),
metrics: ServerMetrics::new(self.metrics.clone(), self.handlers_metrics.clone()),
}
}
}
/// RPC server future.
///
/// When polled, it binds its TCP listener (if not yet bound), accepts clients and spawns one
/// task per client through `spawner`. It resolves once the listener's incoming stream ends.
#[must_use = "futures do nothing unless polled"]
#[derive(Debug)]
pub struct Server<S> {
    /// Listener state: still binding, or listening with a known local address.
    listener: Listener,
    /// Logger tagged with the server's bind address.
    logger: Logger,
    /// Spawner used to run per-client tasks (and, via a boxed clone, their inner tasks).
    spawner: S,
    /// Dispatches incoming messages to the registered handlers; cloned for each client.
    assigner: Assigner,
    /// Options cloned into each client's channel.
    channel_options: ChannelOptions,
    /// Server-level metrics, including per-channel metrics.
    metrics: ServerMetrics,
}
impl<S> Server<S> {
pub fn local_addr(self) -> impl Future<Item = (Self, SocketAddr), Error = Error> {
match self.listener {
Listener::Listening(_, addr) => Either::A(futures::finished((self, addr))),
Listener::Binding(_) => {
let future = loop_fn(self, |mut this| {
if fibers::fiber::with_current_context(|_| ()).is_none() {
return Ok(Loop::Continue(this));
}
track!(this.listener.poll())?;
if let Listener::Listening(_, addr) = this.listener {
Ok(Loop::Break((this, addr)))
} else {
Ok(Loop::Continue(this))
}
});
Either::B(future)
}
}
}
pub fn poll_local_addr(&mut self) -> Poll<SocketAddr, Error> {
match self.listener {
Listener::Listening(_, addr) => Ok(Async::Ready(addr)),
Listener::Binding(_) => {
track!(self.listener.poll())?;
if let Listener::Listening(_, addr) = self.listener {
Ok(Async::Ready(addr))
} else {
Ok(Async::NotReady)
}
}
}
}
pub fn metrics(&self) -> &ServerMetrics {
&self.metrics
}
}
impl<S> Future for Server<S>
where
S: Clone + Spawn + Send + 'static,
{
type Item = ();
type Error = Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
while let Async::Ready(item) = track!(self.listener.poll())? {
if let Some((client, addr)) = item {
let logger = self.logger.new(o!("client" => addr.to_string()));
info!(logger, "New TCP client");
let options = self.channel_options.clone();
let metrics = self.metrics.channels().create_channel_metrics(addr);
let channels = self.metrics.channels().clone();
let exit_logger = logger.clone();
let spawner = self.spawner.clone().boxed();
let assigner = self.assigner.clone();
let future = client
.map_err(|e| track!(Error::from(e)))
.and_then(move |stream| {
let channel = ServerSideChannel::new(
logger,
stream,
assigner.clone(),
options,
metrics,
);
ChannelHandler::new(spawner, channel)
});
self.spawner.spawn(future.then(move |result| {
channels.remove_channel_metrics(addr);
if let Err(e) = result {
error!(exit_logger, "TCP connection aborted: {}", e);
} else {
info!(exit_logger, "TCP connection was closed");
}
Ok(())
}));
} else {
info!(self.logger, "RPC server stopped");
return Ok(Async::Ready(()));
}
}
Ok(Async::NotReady)
}
}
/// Drives a single server-side channel: dispatches the actions it produces and feeds replies
/// back into it.
struct ChannelHandler {
    /// Spawns the futures of no-reply actions and of replies that are not yet available.
    spawner: BoxSpawn,
    /// The channel for one client connection.
    channel: ServerSideChannel,
    /// Cloned into each spawned reply future, which sends its finished message here.
    reply_tx: mpsc::Sender<OutgoingMessage>,
    /// Receives messages from spawned reply futures; `poll` passes them to `channel.reply`.
    reply_rx: mpsc::Receiver<OutgoingMessage>,
}
impl ChannelHandler {
    /// Wraps `channel`, using `spawner` for deferred work and a fresh reply queue.
    fn new(spawner: BoxSpawn, channel: ServerSideChannel) -> Self {
        let (tx, rx) = mpsc::channel();
        Self {
            spawner,
            channel,
            reply_tx: tx,
            reply_rx: rx,
        }
    }
}
impl Future for ChannelHandler {
    type Item = ();
    type Error = Error;

    /// Pumps the channel until it has no more work.
    ///
    /// Each round has two phases:
    /// 1. Dispatch every action the channel has ready.
    /// 2. Forward any replies finished by spawned futures back into the channel.
    ///
    /// When at least one reply was forwarded, another round runs so the channel is polled
    /// again. The handler resolves once the channel's stream ends.
    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        loop {
            loop {
                match track!(self.channel.poll())? {
                    Async::NotReady => break,
                    Async::Ready(None) => return Ok(Async::Ready(())),
                    Async::Ready(Some(Action::NoReply(noreply))) => {
                        if let Some(task) = noreply.into_future() {
                            self.spawner.spawn(task.map_err(|_: Never| ()));
                        }
                    }
                    Async::Ready(Some(Action::Reply(mut reply))) => match reply.try_take() {
                        Some(message) => {
                            // The reply is already available: hand it over right away.
                            self.channel.reply(message);
                        }
                        None => {
                            // Finish the reply on a separate task and route the message back
                            // through `reply_tx`; a failed send is deliberately ignored.
                            let reply_tx = self.reply_tx.clone();
                            let task = reply.map(move |message| {
                                let _ = reply_tx.send(message);
                            });
                            self.spawner.spawn(task.map_err(|_: Never| ()));
                        }
                    },
                }
            }

            // `self` keeps `reply_tx` alive, so the receiver is never expected to fail or end.
            let mut forwarded = 0usize;
            while let Async::Ready(item) = self.reply_rx.poll().expect("Never fails") {
                self.channel.reply(item.expect("Never fails"));
                forwarded += 1;
            }
            if forwarded == 0 {
                return Ok(Async::NotReady);
            }
        }
    }
}
/// State of the server's TCP listener.
#[derive(Debug)]
enum Listener {
    /// Binding to the configured address is still in progress.
    Binding(TcpListenerBind),
    /// Bound: the stream of incoming connections and the listener's local address.
    Listening(Incoming, SocketAddr),
}
impl Stream for Listener {
type Item = (Connected, SocketAddr);
type Error = Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
loop {
let next = match self {
Listener::Binding(f) => {
if let Async::Ready(listener) = track!(f.poll().map_err(Error::from))? {
let addr = track!(listener.local_addr().map_err(Error::from))?;
Listener::Listening(listener.incoming(), addr)
} else {
break;
}
}
Listener::Listening(s, _) => return track!(s.poll().map_err(Error::from)),
};
*self = next;
}
Ok(Async::NotReady)
}
}