use {
super::{
super::{
types::{Ref, RefMut, State},
utils::parse_header,
},
Accessor, OneshotSender,
},
hyper::{
header::HeaderMap,
http::{StatusCode, Uri},
},
std::{collections::HashMap, fmt::Debug, io::Error as IoError, ops::Deref, sync::Arc},
tokio::sync::watch::Sender as WatchSender,
tracing::{error, info},
};
/// Guard tied to the lifetime of a WebSocket connection.
///
/// When the guard is dropped and `close_tx` holds a sender/accessor pair, the
/// accessor is handed back through the oneshot sender (see the `Drop` impl).
pub(super) struct WsConnGuard<S: Debug = ()> {
    /// Close notification: the sender to signal on, plus the accessor to send.
    /// `None` means nothing is sent on drop.
    pub close_tx: Option<(OneshotSender<HttpAccessor<S>>, WsAccessor<S>)>,
}
impl<S: Debug> WsConnGuard<S> {
    /// Creates a guard with no close notification armed.
    pub(super) fn new() -> Self {
        let close_tx = None;
        Self { close_tx }
    }
}
impl<S: Debug> Drop for WsConnGuard<S> {
    /// Logs the closing connection's URI and sends its accessor through the
    /// stored oneshot sender, if one was armed.
    ///
    /// A failed send (the receiving side is already gone) is logged, not panicked on.
    fn drop(&mut self) {
        let Some((close_tx, accessor)) = self.close_tx.take() else {
            // Nothing armed: dropping silently is the expected path.
            return;
        };
        info!("WebSocket connection closed: {}", accessor.as_uri());
        // `send` returns the unsent value on failure; it is logged as `e`.
        if let Err(e) = close_tx.send(accessor) {
            error!(?e, "Failed to send close signal.");
        }
    }
}
/// Common core of every connection view: shared user state plus the request URI.
///
/// The URI sits behind an `Arc` so clones of the connection share it cheaply.
#[derive(Debug)]
pub struct BaseConn<S = ()> {
    /// Shared user state; accessed via `borrow` / `borrow_mut`.
    state: State<S>,
    /// URI of the connection, shared between clones.
    uri: Arc<Uri>,
}
impl<S> Clone for BaseConn<S> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
uri: self.uri.clone(),
}
}
}
impl<S> BaseConn<S> {
    /// Waits for shared (read) access to the state and wraps the guard in a [`Ref`].
    pub async fn borrow(&self) -> Ref<'_, S> {
        let guard = self.state.read().await;
        Ref::new(guard)
    }

    /// Waits for exclusive (write) access to the state and wraps the guard in a [`RefMut`].
    pub async fn borrow_mut(&self) -> RefMut<'_, S> {
        let guard = self.state.write().await;
        RefMut::new(guard)
    }
}
impl<S> From<(Uri, State<S>)> for BaseConn<S> {
fn from((uri, state): (Uri, State<S>)) -> Self {
Self {
uri: uri.into(),
state,
}
}
}
/// Exposes the connection's URI directly, so `Uri` methods can be called on a `BaseConn`.
impl<S> Deref for BaseConn<S> {
    type Target = Uri;

    fn deref(&self) -> &Uri {
        &*self.uri
    }
}
/// Connection view available once a response has been received:
/// the base connection plus the response status and headers.
#[derive(Debug)]
pub struct ResponseConn<S = ()> {
    /// Shared state and URI.
    base: BaseConn<S>,
    /// HTTP status code of the response.
    status: StatusCode,
    /// Headers of the response.
    response_headers: HeaderMap,
}
// Hand-written so that cloning does not require `S: Clone`.
impl<S> Clone for ResponseConn<S> {
    fn clone(&self) -> Self {
        let base = self.base.clone();
        let response_headers = self.response_headers.clone();
        Self {
            base,
            status: self.status,
            response_headers,
        }
    }
}
impl<S> ResponseConn<S> {
    /// URI of the underlying connection (module-private helper).
    fn as_uri(&self) -> &Uri {
        &*self.base.uri
    }

    /// HTTP status code of the response.
    pub fn status(&self) -> StatusCode {
        self.status
    }

    /// Returns the first value of response header `name` as a string slice.
    ///
    /// Yields `None` both when the header is absent and when its value cannot be
    /// represented as a `&str` (`HeaderValue::to_str` fails).
    pub fn get_header(&self, name: &str) -> Option<&str> {
        let raw = self.response_headers.get(name)?;
        raw.to_str().ok()
    }

    /// All response headers, including multi-valued and non-string ones.
    pub fn get_headers(&self) -> &HeaderMap {
        &self.response_headers
    }
}
/// Gives a `ResponseConn` direct access to its `BaseConn` (state borrowing, URI).
impl<S> Deref for ResponseConn<S> {
    type Target = BaseConn<S>;

    fn deref(&self) -> &BaseConn<S> {
        let Self { base, .. } = self;
        base
    }
}
impl<S> From<(Uri, StatusCode, HeaderMap, State<S>)> for ResponseConn<S> {
fn from(
(uri, status, response_headers, state): (Uri, StatusCode, HeaderMap, State<S>),
) -> Self {
Self {
base: (uri, state).into(),
status,
response_headers,
}
}
}
/// Connection view available while a request is being prepared.
///
/// Header and query-parameter edits are published through `watch` senders, so
/// whoever holds the matching receivers observes every change.
#[derive(Debug)]
pub struct RequestConn<S = ()> {
    /// Shared state and URI.
    base: BaseConn<S>,
    /// Publishes the current request header map.
    request_headers: WatchSender<HeaderMap>,
    /// Publishes the current query parameters (name -> value).
    query_params: WatchSender<HashMap<String, String>>,
}
// Hand-written so that cloning does not require `S: Clone`; the watch senders
// are cloned as handles to the same channels.
impl<S> Clone for RequestConn<S> {
    fn clone(&self) -> Self {
        let base = self.base.clone();
        let request_headers = self.request_headers.clone();
        let query_params = self.query_params.clone();
        Self {
            base,
            request_headers,
            query_params,
        }
    }
}
/// Gives a `RequestConn` direct access to its `BaseConn` (state borrowing, URI).
impl<S> Deref for RequestConn<S> {
    type Target = BaseConn<S>;

    fn deref(&self) -> &BaseConn<S> {
        let Self { base, .. } = self;
        base
    }
}
impl<S>
From<(
Uri,
WatchSender<HeaderMap>,
WatchSender<HashMap<String, String>>,
State<S>,
)> for RequestConn<S>
{
fn from(
(uri, request_headers, query_params, state): (
Uri,
WatchSender<HeaderMap>,
WatchSender<HashMap<String, String>>,
State<S>,
),
) -> Self {
Self {
base: (uri, state).into(),
request_headers,
query_params,
}
}
}
impl<S> RequestConn<S> {
    /// Sets request header `name` to exactly one value, `value`, replacing every
    /// value previously stored under that name.
    ///
    /// Watchers of the header channel are notified only when the map actually
    /// changes.
    ///
    /// # Returns
    ///
    /// `Ok(true)` if the header map was modified, `Ok(false)` if `value` was
    /// already the one and only value stored under `name`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `parse_header` when `name` or `value` is rejected.
    pub fn set_header(&self, name: &str, value: &str) -> Result<bool, IoError> {
        let (name, value) = parse_header(name, value)?;
        Ok(self.request_headers.send_if_modified(|map| {
            // Compare against *all* stored values. `HeaderMap::get` only sees the
            // first one, so a header extended via `add_header` whose first value
            // matched would otherwise keep its extra values, breaking "set" semantics.
            let unchanged = {
                let mut existing = map.get_all(&name).iter();
                matches!(
                    (existing.next(), existing.next()),
                    (Some(current), None) if current == value
                )
            };
            if unchanged {
                false
            } else {
                // `insert` drops any previous values for `name`.
                map.insert(name, value);
                true
            }
        }))
    }

    /// Appends `value` to request header `name`, keeping any existing values.
    ///
    /// Always notifies watchers of the header channel, even if an identical value
    /// is already present.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `parse_header` when `name` or `value` is rejected.
    pub fn add_header(&self, name: &str, value: &str) -> Result<(), IoError> {
        let (name, value) = parse_header(name, value)?;
        self.request_headers.send_modify(|map| {
            map.append(name, value);
        });
        Ok(())
    }

    /// Sets query parameter `name` to `value`.
    ///
    /// Watchers of the query channel are notified only when the stored value changes.
    /// Returns `true` if the parameters were modified.
    pub fn set_argument(&self, name: &str, value: &str) -> bool {
        self.query_params
            .send_if_modified(|params| match params.get_mut(name) {
                // Same value already stored: nothing to publish.
                Some(current) if current.as_str() == value => false,
                // Overwrite in place, reusing the existing key and buffer.
                Some(current) => {
                    value.clone_into(current);
                    true
                }
                None => {
                    params.insert(name.to_owned(), value.to_owned());
                    true
                }
            })
    }
}
/// Accessor over a completed HTTP response.
pub type HttpAccessor<S = ()> = Accessor<ResponseConn<S>>;
/// Accessor over a WebSocket connection.
///
/// This is the same type as [`HttpAccessor`]. `WsConnGuard`'s drop path relies on
/// that identity when it sends a `WsAccessor` through an
/// `OneshotSender<HttpAccessor>`.
pub type WsAccessor<S = ()> = HttpAccessor<S>;
/// Accessor over a request that is still being prepared.
pub type RequestAccessor<S = ()> = Accessor<RequestConn<S>>;