use {
super::{
super::{
types::{Ref, RefMut, State, WsAsyncFn},
utils::{get_path, parse_header, url_decode},
},
Accessor, Command,
},
hyper::{StatusCode, header::HeaderMap, http::Uri},
std::{
collections::HashMap,
io::Error as IoError,
mem::take,
net::SocketAddr,
ops::Deref,
sync::{Arc, LazyLock, Weak},
},
tokio::{
runtime::Runtime,
sync::{
Mutex,
mpsc::{Sender as MpscSender, WeakSender},
watch::{Sender as WatchSender, channel as watch_channel},
},
},
tokio_tungstenite::tungstenite::Message,
tracing::{error, warn},
};
/// Registry record for one live WebSocket connection, stored as a value in [`WS_CONNS`].
pub(super) struct ConnInfo {
    /// URI of the connection; cloned into each `WsConn` built from this record.
    pub uri: Uri,
    /// Headers of the connection; cloned into each `WsConn` built from this record.
    pub headers: HeaderMap,
    /// Channel carrying tungstenite `Message`s for this connection.
    /// NOTE(review): presumably drained by the task writing to the client socket — confirm.
    pub sender: MpscSender<Message>,
}
/// Global registry of WebSocket connections, keyed by `(path, remote socket address)`.
///
/// Lookups compare the path component of the key; `WsConnGuard`'s `Drop` removes the entry
/// keyed by `(uri.path(), socket_addr)`. Insertion happens outside this section — presumably when
/// a WebSocket upgrade completes; verify against the caller.
pub(super) static WS_CONNS: LazyLock<Mutex<HashMap<(String, SocketAddr), ConnInfo>>> =
    LazyLock::new(Default::default);
/// Finds the first registered WebSocket connection on the route of handler `F` accepted by
/// `predicated`.
///
/// The route path comes from the handler *type* via `get_path::<F>()`. The `_target` value is only
/// taken so callers can pass the handler itself and let `F` be inferred. Each candidate is wrapped
/// in a fresh [`WsAccessor`] sharing `state`. The first one for which `predicated` returns `true`
/// is returned, or `None` if no registered connection on that path matches.
///
/// Matching entries are snapshotted and the `WS_CONNS` lock is released *before* the predicate
/// runs. This is the same pattern `get_other_conns`/`get_all_ws_conns` use, so caller-supplied
/// code never executes while the global registry is locked. As with those lookups, a connection
/// may close between the snapshot and the moment the accessor is used.
async fn find_ws_connection<F, Args, Ret, Acc, S, P>(
    _target: F,
    state: State<S>,
    mut predicated: P,
) -> Option<WsAccessor<S>>
where
    F: WsAsyncFn<Args, Ret, Acc, S>,
    P: FnMut(&mut WsAccessor<S>) -> bool,
{
    let target_path = get_path::<F>();
    // Copy out only what is needed to rebuild a `WsConn`; the guard drops at the end of this block.
    let candidates = {
        let conns = WS_CONNS.lock().await;
        conns
            .iter()
            .filter(|((path, _addr), _info)| path == &target_path)
            .map(|((_, addr), info)| (info.uri.clone(), *addr, info.headers.clone()))
            .collect::<Vec<_>>()
    };
    candidates.into_iter().find_map(|(uri, addr, headers)| {
        let mut accessor: WsAccessor<S> =
            WsConn::from((uri, addr, headers, state.clone())).into();
        predicated(&mut accessor).then_some(accessor)
    })
}
/// Request data shared by the HTTP and WebSocket connection handles.
///
/// `Deref`s to the request [`Uri`], so `path()`/`query()` can be called directly on it.
#[derive(Debug)]
pub struct BaseConn<S = ()> {
    /// Request URI; behind `Arc` so cloning the handle does not copy it.
    uri: Arc<Uri>,
    /// Remote peer address.
    socket_addr: SocketAddr,
    /// Request headers; behind `Arc` so cloning the handle does not copy them.
    headers: Arc<HeaderMap>,
    /// Shared application state, accessed through `borrow`/`borrow_mut`.
    state: State<S>,
}
impl<S> BaseConn<S> {
    /// Returns the remote socket address of the connection.
    pub fn get_addr(&self) -> SocketAddr {
        self.socket_addr
    }

    /// Returns the first value of query parameter `name`, URL-decoded with `url_decode`.
    ///
    /// Keys are decoded before comparison. A bare key without `=` (e.g. `?flag`) yields
    /// `Some("")`. Components whose key or value fails to decode are skipped rather than ending
    /// the search. The result is therefore always `get_arguments(name).next()`, and it matches
    /// the value `get_all_arguments` reports for `name`.
    pub fn get_argument(&self, name: &str) -> Option<String> {
        self.get_arguments(name).next()
    }

    /// Yields every decoded value of query parameter `name`, in query-string order.
    ///
    /// Uses the same rules as [`Self::get_argument`]: bare keys yield `""`, and undecodable
    /// components are skipped.
    pub fn get_arguments<'a>(&'a self, name: &'a str) -> impl Iterator<Item = String> + 'a {
        self.query_pairs().filter_map(move |(key, value)| {
            let key = url_decode(key).ok()?;
            if key == name {
                Self::decode_value(value)
            } else {
                None
            }
        })
    }

    /// Collects all query parameters into a map of decoded key -> decoded value.
    ///
    /// If a key repeats, the first successfully decoded occurrence wins. Bare keys map to `""`.
    /// Components whose key or value fails to decode are skipped.
    pub fn get_all_arguments(&self) -> HashMap<String, String> {
        let mut result = HashMap::new();
        for (key, value) in self.query_pairs() {
            if let (Ok(key), Some(value)) = (url_decode(key), Self::decode_value(value)) {
                result.entry(key).or_insert(value);
            }
        }
        result
    }

    /// Returns the first value of request header `name` as text.
    ///
    /// Returns `None` if the header is absent or its value is not visible ASCII.
    pub fn get_header(&self, name: &str) -> Option<&str> {
        self.headers.get(name).and_then(|v| v.to_str().ok())
    }

    /// Returns all request headers.
    pub fn get_headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// Awaits shared (read) access to the application state.
    pub async fn borrow(&self) -> Ref<'_, S> {
        Ref::new(self.state.read().await)
    }

    /// Awaits exclusive (write) access to the application state.
    pub async fn borrow_mut(&self) -> RefMut<'_, S> {
        RefMut::new(self.state.write().await)
    }

    /// Splits the raw query string into `(raw_key, raw_value)` components.
    ///
    /// Empty segments (from `&&` or a trailing `&`) are dropped. Only the first `=` separates key
    /// from value, so `a=b=c` gives value `b=c`. A component without `=` has value `None`.
    fn query_pairs(&self) -> impl Iterator<Item = (&str, Option<&str>)> + '_ {
        self.query()
            .into_iter()
            .flat_map(|query| query.split('&'))
            .filter(|pair| !pair.is_empty())
            .map(|pair| match pair.split_once('=') {
                Some((key, value)) => (key, Some(value)),
                None => (pair, None),
            })
    }

    /// Decodes a raw value from [`Self::query_pairs`]; a missing value (bare key) decodes to `""`.
    fn decode_value(value: Option<&str>) -> Option<String> {
        match value {
            Some(raw) => url_decode(raw).ok(),
            None => Some(String::new()),
        }
    }
}
impl<S> Clone for BaseConn<S> {
fn clone(&self) -> Self {
Self {
headers: self.headers.clone(),
uri: self.uri.clone(),
state: self.state.clone(),
socket_addr: self.socket_addr,
}
}
}
impl<S> From<(Uri, SocketAddr, HeaderMap, State<S>)> for BaseConn<S> {
fn from((uri, socket_addr, headers, state): (Uri, SocketAddr, HeaderMap, State<S>)) -> Self {
Self {
uri: uri.into(),
socket_addr,
headers: headers.into(),
state,
}
}
}
impl<S> Deref for BaseConn<S> {
type Target = Uri;
fn deref(&self) -> &Self::Target {
&self.uri
}
}
/// `Accessor` wrapping a [`WsConn`]; the type returned by the WebSocket connection lookups.
pub type WsAccessor<S = ()> = Accessor<WsConn<S>>;
// Identity `AsRef`. NOTE(review): presumably lets APIs bounded on `AsRef<WsAccessor<S>>` also
// accept the accessor itself — confirm against callers.
impl<S> AsRef<WsAccessor<S>> for WsAccessor<S> {
    fn as_ref(&self) -> &Self {
        self
    }
}
#[derive(Debug)]
pub struct WsConn<S = ()> {
inner: BaseConn<S>,
}
impl<S> Clone for WsConn<S> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<S> WsConn<S> {
    /// Finds a registered WebSocket connection on the route served by `target`.
    ///
    /// Delegates to the module-level lookup using this connection's state. Returns the first
    /// connection for which `predicated` returns `true`, or `None`.
    pub async fn find_conn<F, Args, Ret, Acc, P>(
        &self,
        target: F,
        predicated: P,
    ) -> Option<WsAccessor<S>>
    where
        F: WsAsyncFn<Args, Ret, Acc, S>,
        P: FnMut(&mut WsAccessor<S>) -> bool,
    {
        let shared_state = self.inner.state.clone();
        find_ws_connection(target, shared_state, predicated).await
    }

    /// Returns accessors for every *other* registered connection on this connection's URI path.
    ///
    /// The registry is copied while `WS_CONNS` is locked. The lock is released once the copy is
    /// taken; the accessors are built lazily as the returned iterator is consumed.
    pub async fn get_other_conns(&self) -> impl Iterator<Item = WsAccessor<S>> {
        let own_path = self.uri.path().to_owned();
        let own_addr = self.socket_addr;
        let mut peers = Vec::new();
        // The lock guard is a temporary of the `for` head and is dropped right after the loop.
        for ((path, addr), info) in WS_CONNS.lock().await.iter() {
            if *path == own_path && *addr != own_addr {
                peers.push((info.uri.clone(), *addr, info.headers.clone()));
            }
        }
        peers.into_iter().map(|(uri, addr, headers)| {
            WsConn::from((uri, addr, headers, self.state.clone())).into()
        })
    }
}
impl<S> From<(Uri, SocketAddr, HeaderMap, State<S>)> for WsConn<S> {
fn from((uri, socket_addr, headers, state): (Uri, SocketAddr, HeaderMap, State<S>)) -> Self {
Self {
inner: (uri, socket_addr, headers, state).into(),
}
}
}
impl<S> Deref for WsConn<S> {
type Target = BaseConn<S>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
pub(super) struct WsConnGuard {
pub uri: Uri,
pub socket_addr: SocketAddr,
pub headers: HeaderMap,
pub rt: Weak<Runtime>,
pub command: WeakSender<Command>,
}
impl Drop for WsConnGuard {
fn drop(&mut self) {
let Some(rt) = self.rt.upgrade() else {
warn!("Runtime already dropped.");
return;
};
let Some(command) = self.command.upgrade() else {
warn!("Command sender already dropped.");
return;
};
let socket_addr = self.socket_addr;
let headers = take(&mut self.headers);
let uri = take(&mut self.uri);
rt.spawn(async move {
let path = uri.path().to_owned();
if let Err(e) = command
.send(Command::WsClose {
uri,
socket_addr,
headers,
})
.await
{
error!(?e, "Failed to send close event.");
}
WS_CONNS.lock().await.remove(&(path, socket_addr));
});
}
}
/// `Accessor` wrapping an [`HttpConn`].
pub type HttpAccessor<S = ()> = Accessor<HttpConn<S>>;
// Identity `AsRef`. NOTE(review): presumably lets APIs bounded on `AsRef<HttpAccessor<S>>` also
// accept the accessor itself — confirm against callers.
impl<S> AsRef<HttpAccessor<S>> for HttpAccessor<S> {
    fn as_ref(&self) -> &Self {
        self
    }
}
/// Handle to one HTTP request; `Deref`s to the shared [`BaseConn`] request data.
///
/// Response headers and status are published through `tokio::sync::watch` senders, so changes
/// made by the handler are visible to whoever holds the matching receivers.
#[derive(Debug)]
pub struct HttpConn<S = ()> {
    inner: BaseConn<S>,
    response_headers: WatchSender<HeaderMap>,
    response_status: WatchSender<StatusCode>,
}
// Hand-written so that cloning the handle does not require `S: Clone`; clones share the same
// response channels.
impl<S> Clone for HttpConn<S> {
    fn clone(&self) -> Self {
        let Self {
            inner,
            response_headers,
            response_status,
        } = self;
        Self {
            inner: inner.clone(),
            response_headers: response_headers.clone(),
            response_status: response_status.clone(),
        }
    }
}
impl<S> HttpConn<S> {
    /// Sets response header `name` to exactly `value`, replacing every existing value.
    ///
    /// Returns `Ok(true)` if the response headers changed, in which case watchers are notified.
    /// Returns `Ok(false)` if `value` was already the *only* value of `name`.
    ///
    /// # Errors
    /// Propagates the error from `parse_header` (presumably an invalid name or value).
    pub fn set_header(&self, name: &str, value: &str) -> Result<bool, IoError> {
        let (name, value) = parse_header(name, value)?;
        Ok(self.response_headers.send_if_modified(|map| {
            // Checking only the first value is not enough: if `add_header` appended more values,
            // `insert` must still run to collapse them down to this single one.
            let already_set = {
                let mut existing = map.get_all(&name).iter();
                existing.next() == Some(&value) && existing.next().is_none()
            };
            if already_set {
                false
            } else {
                map.insert(name, value);
                true
            }
        }))
    }

    /// Appends `value` to response header `name`, keeping existing values.
    ///
    /// Always notifies watchers.
    ///
    /// # Errors
    /// Propagates the error from `parse_header` (presumably an invalid name or value).
    pub fn add_header(&self, name: &str, value: &str) -> Result<(), IoError> {
        let (name, value) = parse_header(name, value)?;
        self.response_headers.send_modify(|map| {
            map.append(name, value);
        });
        Ok(())
    }

    /// Sets the response status. Watchers are notified even if the status is unchanged.
    pub fn set_status(&self, status: StatusCode) {
        self.response_status.send_modify(|i| *i = status);
    }

    /// Finds a registered WebSocket connection on the route served by `target`.
    ///
    /// Returns the first connection for which `predicated` returns `true`, or `None`.
    pub async fn find_ws_conn<F, Args, Ret, Acc, P>(
        &self,
        target: F,
        predicated: P,
    ) -> Option<WsAccessor<S>>
    where
        F: WsAsyncFn<Args, Ret, Acc, S>,
        P: FnMut(&mut WsAccessor<S>) -> bool,
    {
        find_ws_connection(target, self.inner.state.clone(), predicated).await
    }

    /// Returns accessors for every registered WebSocket connection on the route of handler `F`.
    ///
    /// The registry is copied under the `WS_CONNS` lock; the lock is released before any
    /// accessor is built.
    pub async fn get_all_ws_conns<F, Args, Ret, Acc>(
        &self,
        _target: F,
    ) -> impl Iterator<Item = WsAccessor<S>>
    where
        F: WsAsyncFn<Args, Ret, Acc, S>,
    {
        let target_path = get_path::<F>();
        let conn_data = {
            let conns = WS_CONNS.lock().await;
            conns
                .iter()
                .filter(|((path, _addr), _info)| path == &target_path)
                .map(|(key, info)| (info.uri.clone(), key.1, info.headers.clone()))
                .collect::<Vec<_>>()
        };
        conn_data.into_iter().map(|(uri, addr, headers)| {
            WsConn::from((uri, addr, headers, self.state.clone())).into()
        })
    }
}
impl<S>
From<(
Uri,
SocketAddr,
HeaderMap,
State<S>,
WatchSender<HeaderMap>,
WatchSender<StatusCode>,
)> for HttpConn<S>
{
fn from(
(uri, socket_addr, request_headers, state, response_headers, response_status): (
Uri,
SocketAddr,
HeaderMap,
State<S>,
WatchSender<HeaderMap>,
WatchSender<StatusCode>,
),
) -> Self {
Self {
inner: (uri, socket_addr, request_headers, state).into(),
response_headers,
response_status,
}
}
}
impl<S> From<(Uri, SocketAddr, HeaderMap, State<S>)> for HttpConn<S> {
fn from((uri, socket_addr, headers, state): (Uri, SocketAddr, HeaderMap, State<S>)) -> Self {
let (headers_tx, _) = watch_channel(Default::default());
let (status_tx, _) = watch_channel(StatusCode::OK);
Self {
inner: (uri, socket_addr, headers, state).into(),
response_headers: headers_tx,
response_status: status_tx,
}
}
}
impl<S> Deref for HttpConn<S> {
type Target = BaseConn<S>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    /// Builds a `BaseConn` for `uri` with a fixed peer address, no headers and unit state.
    fn create_base_conn(uri: &str) -> BaseConn {
        BaseConn::from((
            uri.parse().unwrap(),
            "127.0.0.1:8080".parse().unwrap(),
            HeaderMap::new(),
            Arc::new(RwLock::new(())),
        ))
    }

    #[test]
    fn test_path() {
        let conn = create_base_conn("/api/user");
        assert_eq!(conn.path(), "/api/user");
        let conn = create_base_conn("/api/user?id=123");
        assert_eq!(conn.path(), "/api/user");
        let conn = create_base_conn("/api/user?id=123&name=test");
        assert_eq!(conn.path(), "/api/user");
    }

    #[test]
    fn test_query() {
        let conn = create_base_conn("/api/user");
        assert!(conn.query().is_none());
        let conn = create_base_conn("/api/user?id=123");
        assert_eq!(conn.query(), Some("id=123"));
        let conn = create_base_conn("/api/user?id=123&name=test");
        assert_eq!(conn.query(), Some("id=123&name=test"));
    }

    #[test]
    fn test_get_argument() {
        let conn = create_base_conn("/api/user?id=123&name=test");
        assert_eq!(conn.get_argument("id"), Some("123".to_string()));
        assert_eq!(conn.get_argument("name"), Some("test".to_string()));
        assert_eq!(conn.get_argument("age"), None);
    }

    #[test]
    fn test_get_argument_url_encoded() {
        let conn = create_base_conn("/search?q=hello%20world&tag=%E4%B8%AD%E6%96%87");
        assert_eq!(conn.get_argument("q"), Some("hello world".to_string()));
        assert_eq!(conn.get_argument("tag"), Some("中文".to_string()));
    }

    #[test]
    fn test_get_argument_plus_as_space() {
        let conn = create_base_conn("/search?q=hello+world");
        assert_eq!(conn.get_argument("q"), Some("hello world".to_string()));
    }

    #[test]
    fn test_get_argument_value_containing_equals() {
        // Only the first `=` separates key from value.
        let conn = create_base_conn("/api?expr=a=b");
        assert_eq!(conn.get_argument("expr"), Some("a=b".to_string()));
    }

    #[test]
    fn test_get_arguments_multiple_values() {
        let conn = create_base_conn("/api?tag=foo&tag=bar&tag=baz");
        let values: Vec<_> = conn.get_arguments("tag").collect();
        assert_eq!(values, vec!["foo", "bar", "baz"]);
    }

    #[test]
    fn test_get_arguments_no_match() {
        let conn = create_base_conn("/api?tag=foo");
        assert_eq!(conn.get_arguments("missing").count(), 0);
        let conn = create_base_conn("/api");
        assert_eq!(conn.get_arguments("tag").count(), 0);
    }

    #[test]
    fn test_get_all_arguments() {
        let conn = create_base_conn("/api?id=123&name=hello%20world&flag");
        let args = conn.get_all_arguments();
        assert_eq!(args.get("id"), Some(&"123".to_string()));
        assert_eq!(args.get("name"), Some(&"hello world".to_string()));
        assert_eq!(args.get("flag"), Some(&String::new()));
    }

    #[test]
    fn test_get_all_arguments_first_value_wins() {
        let conn = create_base_conn("/api?tag=foo&tag=bar");
        let args = conn.get_all_arguments();
        assert_eq!(args.len(), 1);
        assert_eq!(args.get("tag"), Some(&"foo".to_string()));
    }

    #[test]
    fn test_empty_segments_ignored() {
        let conn = create_base_conn("/api?&&id=1&");
        assert_eq!(conn.get_argument("id"), Some("1".to_string()));
        assert_eq!(conn.get_all_arguments().len(), 1);
    }

    #[test]
    fn test_empty_values() {
        let conn = create_base_conn("/api?empty=&flag");
        assert_eq!(conn.get_argument("empty"), Some(String::new()));
        assert_eq!(conn.get_argument("flag"), Some(String::new()));
    }

    #[test]
    fn test_special_characters_in_key() {
        let conn = create_base_conn("/api?%24key=value");
        assert_eq!(conn.get_argument("$key"), Some("value".to_string()));
    }

    #[test]
    fn test_clone_shares_request_data() {
        let conn = create_base_conn("/api?id=1");
        let copy = conn.clone();
        assert_eq!(copy.get_addr(), conn.get_addr());
        assert_eq!(copy.path(), "/api");
        assert_eq!(copy.get_argument("id"), Some("1".to_string()));
    }
}