use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::{atomic, Arc};
use std::task::{Context, Poll};
use crate::core;
use crate::core::futures::channel::mpsc;
use crate::server_utils::{reactor::TaskExecutor, session};
use crate::ws;
use crate::error;
use crate::Origin;
#[derive(Clone)]
pub struct Sender {
out: ws::Sender,
active: Arc<atomic::AtomicBool>,
}
impl Sender {
pub fn new(out: ws::Sender, active: Arc<atomic::AtomicBool>) -> Self {
Sender { out, active }
}
fn check_active(&self) -> error::Result<()> {
if self.active.load(atomic::Ordering::SeqCst) {
Ok(())
} else {
Err(error::Error::ConnectionClosed)
}
}
pub fn send<M>(&self, msg: M) -> error::Result<()>
where
M: Into<ws::Message>,
{
self.check_active()?;
self.out.send(msg)?;
Ok(())
}
pub fn broadcast<M>(&self, msg: M) -> error::Result<()>
where
M: Into<ws::Message>,
{
self.check_active()?;
self.out.broadcast(msg)?;
Ok(())
}
pub fn close(&self, code: ws::CloseCode) -> error::Result<()> {
self.check_active()?;
self.out.close(code)?;
Ok(())
}
}
/// Per-connection information passed to a `MetaExtractor` when building
/// request metadata.
pub struct RequestContext {
    /// Identifier of the session the request belongs to.
    pub session_id: session::SessionId,
    /// Origin associated with the connection, if any.
    pub origin: Option<Origin>,
    /// Protocol strings associated with the connection.
    pub protocols: Vec<String>,
    /// Guarded output handle for the connection.
    pub out: Sender,
    /// Executor used to spawn background tasks for this connection.
    pub executor: TaskExecutor,
}

impl RequestContext {
    /// Returns a channel whose messages are written to this connection.
    ///
    /// A forwarding task is spawned on `executor`. It sends each string it
    /// receives through a clone of `out` and stops when the channel is closed
    /// or a send fails.
    pub fn sender(&self) -> mpsc::UnboundedSender<String> {
        let (tx, rx) = mpsc::unbounded();
        let forwarder = SenderFuture(self.out.clone(), Box::new(rx));
        self.executor.spawn(forwarder);
        tx
    }
}

// `out` and `executor` are intentionally left out of the debug output.
impl fmt::Debug for RequestContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("RequestContext");
        dbg.field("session_id", &self.session_id);
        dbg.field("origin", &self.origin);
        dbg.field("protocols", &self.protocols);
        dbg.finish()
    }
}
/// Produces per-request metadata `M` from a connection's [`RequestContext`].
pub trait MetaExtractor<M: core::Metadata>: Send + Sync + 'static {
    /// Builds the metadata for a request arriving with `context`.
    fn extract(&self, context: &RequestContext) -> M;
}

/// Any thread-safe, `'static` closure `Fn(&RequestContext) -> M` can be used
/// directly as an extractor.
impl<M, F> MetaExtractor<M> for F
where
    F: Fn(&RequestContext) -> M + Send + Sync + 'static,
    M: core::Metadata,
{
    fn extract(&self, context: &RequestContext) -> M {
        self(context)
    }
}
/// Extractor that ignores the request context and yields `M::default()`.
#[derive(Debug, Clone)]
pub struct NoopExtractor;

impl<M> MetaExtractor<M> for NoopExtractor
where
    M: core::Metadata + Default,
{
    fn extract(&self, _: &RequestContext) -> M {
        Default::default()
    }
}
struct SenderFuture(Sender, Box<dyn futures::Stream<Item = String> + Send + Unpin>);
impl Future for SenderFuture {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
use futures::Stream;
let this = Pin::into_inner(self);
loop {
match Pin::new(&mut this.1).poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => return Poll::Ready(()),
Poll::Ready(Some(val)) => {
if let Err(e) = this.0.send(val) {
warn!("Error sending a subscription update: {:?}", e);
return Poll::Ready(());
}
}
}
}
}
}