use std::fmt;
use std::sync::{atomic, Arc};
use crate::core::futures::sync::mpsc;
use crate::core::{self, futures};
use crate::server_utils::{session, tokio::runtime::TaskExecutor};
use crate::ws;
use crate::error;
use crate::Origin;
/// Wrapper around a `ws::Sender` paired with a shared "active" flag.
///
/// The methods of this type load the flag before delegating to the wrapped
/// sender and fail with `error::Error::ConnectionClosed` when it is `false`.
#[derive(Clone)]
pub struct Sender {
/// Underlying WebSocket sender that every operation is forwarded to.
out: ws::Sender,
/// Shared flag consulted before each operation.
// NOTE(review): presumably cleared by the connection handler once the
// socket is closed; the writer is not visible in this file, so confirm.
active: Arc<atomic::AtomicBool>,
}
impl Sender {
pub fn new(out: ws::Sender, active: Arc<atomic::AtomicBool>) -> Self {
Sender { out, active }
}
fn check_active(&self) -> error::Result<()> {
if self.active.load(atomic::Ordering::SeqCst) {
Ok(())
} else {
Err(error::Error::ConnectionClosed)
}
}
pub fn send<M>(&self, msg: M) -> error::Result<()>
where
M: Into<ws::Message>,
{
self.check_active()?;
self.out.send(msg)?;
Ok(())
}
pub fn broadcast<M>(&self, msg: M) -> error::Result<()>
where
M: Into<ws::Message>,
{
self.check_active()?;
self.out.broadcast(msg)?;
Ok(())
}
pub fn close(&self, code: ws::CloseCode) -> error::Result<()> {
self.check_active()?;
self.out.close(code)?;
Ok(())
}
}
/// Per-request context handed to a `MetaExtractor`.
pub struct RequestContext {
/// Identifier of the session this request belongs to.
pub session_id: session::SessionId,
/// Origin associated with the request, if any.
// NOTE(review): presumably taken from the handshake's `Origin` header; confirm.
pub origin: Option<Origin>,
/// Protocol names associated with the request.
// NOTE(review): presumably the WebSocket subprotocols requested by the client; confirm.
pub protocols: Vec<String>,
/// Activity-checked sender for this connection.
pub out: Sender,
/// Executor on which `sender()` spawns its forwarding future.
pub executor: TaskExecutor,
}
impl RequestContext {
    /// Returns an `mpsc::Sender<String>` whose messages are relayed to this
    /// context's `out` sender.
    ///
    /// Each call creates a fresh `mpsc::channel(1)` and spawns a
    /// `SenderFuture` on `self.executor` that drains the receiving half and
    /// forwards every string to a clone of `self.out`.
    pub fn sender(&self) -> mpsc::Sender<String> {
        let (tx, rx) = mpsc::channel(1);
        let forwarder = SenderFuture(self.out.clone(), rx);
        self.executor.spawn(forwarder);
        tx
    }
}
impl fmt::Debug for RequestContext {
    /// Formats the context as a struct showing only `session_id`, `origin`
    /// and `protocols`; the `out` and `executor` fields are left out.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("RequestContext");
        builder.field("session_id", &self.session_id);
        builder.field("origin", &self.origin);
        builder.field("protocols", &self.protocols);
        builder.finish()
    }
}
/// Produces request metadata of type `M` from a `RequestContext`.
///
/// Implementors must be `Send + Sync + 'static`. Any closure of the form
/// `Fn(&RequestContext) -> M` with those bounds implements this trait
/// through the blanket impl in this file.
pub trait MetaExtractor<M: core::Metadata>: Send + Sync + 'static {
/// Builds the metadata value for the given request context.
fn extract(&self, _context: &RequestContext) -> M;
}
/// Lets any thread-safe `Fn(&RequestContext) -> M` closure or function act as
/// a `MetaExtractor`.
impl<M, F> MetaExtractor<M> for F
where
    M: core::Metadata,
    F: Fn(&RequestContext) -> M + Send + Sync + 'static,
{
    /// Invokes the wrapped callable with `context` and returns its result.
    fn extract(&self, context: &RequestContext) -> M {
        self(context)
    }
}
/// Extractor that ignores the request context; its `MetaExtractor` impl
/// returns `M::default()`.
#[derive(Debug, Clone)]
pub struct NoopExtractor;
impl<M: core::Metadata + Default> MetaExtractor<M> for NoopExtractor {
    /// Returns the default value of `M`; the request context is not inspected.
    fn extract(&self, _context: &RequestContext) -> M {
        <M as Default>::default()
    }
}
/// Future that relays strings from an mpsc receiver (field `1`) to a
/// `Sender` (field `0`); see its `Future` impl for the termination rules.
struct SenderFuture(Sender, mpsc::Receiver<String>);
impl futures::Future for SenderFuture {
    type Item = ();
    type Error = ();

    /// Drains the receiver, forwarding each string to the wrapped `Sender`.
    ///
    /// Returns `NotReady` when the receiver has nothing available yet, and
    /// resolves to `Ready(())` once the receiver yields `None` or a forward
    /// attempt fails. A failed forward is logged with `warn!` and is not
    /// surfaced as a future error.
    fn poll(&mut self) -> futures::Poll<Self::Item, Self::Error> {
        use self::futures::Stream;

        loop {
            // Pull the next update, or leave early when there is nothing to forward.
            let update = match self.1.poll()? {
                futures::Async::Ready(Some(update)) => update,
                futures::Async::Ready(None) => return Ok(futures::Async::Ready(())),
                futures::Async::NotReady => return Ok(futures::Async::NotReady),
            };

            // A failed send finishes the future instead of failing it.
            if let Err(e) = self.0.send(update) {
                warn!("Error sending a subscription update: {:?}", e);
                return Ok(futures::Async::Ready(()));
            }
        }
    }
}