use std::time::Duration;
use futures_util::{SinkExt, StreamExt};
use serde_json::{Map, Value};
use tokio::{
io::{AsyncRead, AsyncWrite},
select,
sync::{
mpsc::{self, Receiver},
oneshot::Sender,
},
time,
};
use tokio_tungstenite::{
WebSocketStream,
tungstenite::{Bytes, Message},
};
use crate::{
apperror::AppError,
channel_handler::ChannelHandler,
command_channel::CommandChannelWriteHalf,
model::{packet::Packet, site_info},
rules::Rule,
};
/// Session state before login: the command channel is set up, but `login`
/// has not yet been called. `login` consumes this value and yields a `Session`.
pub(crate) struct NotLoggedInSession {
    /// Writing half of the command channel; moved into the `Session` on a successful login.
    command_channel_write: CommandChannelWriteHalf,
}
/// Site metadata handed out to callers of `Session::get_site_info`.
///
/// Field order is significant for the derived `Debug` output and is kept as-is.
#[derive(Debug, PartialEq)]
pub struct SiteInfo {
    name: String,
    domain: String,
    description: Option<String>,
    enabled: bool,
}

impl SiteInfo {
    /// Returns the site name and domain joined by a single dot: `"{name}.{domain}"`.
    pub fn get_domain_name(&self) -> String {
        format!("{}.{}", self.name, self.domain)
    }

    /// Returns an owned copy of the optional site description.
    pub fn get_description(&self) -> Option<String> {
        self.description.as_ref().cloned()
    }

    /// Returns whether the site is flagged as enabled.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }
}

impl From<site_info::SiteInfo> for SiteInfo {
    /// Moves the name, domain, description and enabled flag out of the model type.
    fn from(value: site_info::SiteInfo) -> Self {
        let site_info::SiteInfo {
            name,
            domain,
            description,
            enabled,
            ..
        } = value;
        Self {
            name,
            domain,
            description,
            enabled,
        }
    }
}
/// A logged-in session, produced by `NotLoggedInSession::login`. Its methods
/// delegate to the command channel's writing half.
pub struct Session {
    /// Writing half of the command channel, carried over from the pre-login session.
    command_channel_write: CommandChannelWriteHalf,
}
/// Credentials and client identification sent with the login request.
pub(crate) struct LoginInfo {
    /// Authentication token.
    pub token: String,
    /// Optional human-readable client name.
    pub client_name: Option<String>,
    /// Optional client identifier.
    pub client_id: Option<String>,
    /// Optional MAC address string (format not validated here).
    pub mac: Option<String>,
}

impl LoginInfo {
    /// Bundles the given values into a `LoginInfo` without validation.
    pub fn new(
        token: String,
        client_name: Option<String>,
        client_id: Option<String>,
        mac: Option<String>,
    ) -> Self {
        Self {
            mac,
            client_id,
            client_name,
            token,
        }
    }
}
/// Drives one WebSocket connection.
///
/// Outbound packets queued on the sender that was handed to `ChannelHandler::new`
/// are written to the socket. Inbound binary packets are deserialized and routed
/// through the `ChannelHandler`.
pub(crate) struct Handler {
    /// Serialized packets waiting to be written to the socket.
    to_server_rx: Receiver<Vec<u8>>,
    /// Destination for deserialized inbound packets.
    channel_handler: ChannelHandler,
}

impl Handler {
    /// Creates a handler from the outbound packet queue and the inbound packet router.
    fn new(to_server_rx: Receiver<Vec<u8>>, channel_handler: ChannelHandler) -> Self {
        Handler {
            to_server_rx,
            channel_handler,
        }
    }

    /// Runs the connection until it fails or is closed.
    ///
    /// Sends a WebSocket ping every 30 s. If no pong arrives between two
    /// consecutive 60 s pong checks, the connection is considered dead.
    ///
    /// # Errors
    /// This function never returns `Ok`. It returns `Err` when:
    /// - the pong check fails;
    /// - the outbound queue is closed;
    /// - writing a packet fails;
    /// - the peer closes the connection or the socket errors;
    /// - an inbound packet fails to deserialize;
    /// - a text (or other unsupported) frame is received.
    ///
    /// # Panics
    /// `tokio::time::interval` panics if `event_poll_interval` is zero.
    pub async fn run_loop<T: AsyncRead + AsyncWrite + Unpin>(
        mut self,
        mut socket: WebSocketStream<T>,
        event_poll_interval: Duration,
    ) -> Result<(), AppError> {
        const PING_INTERVAL_SECS: u64 = 30;
        let mut ping = time::interval(Duration::from_secs(PING_INTERVAL_SECS));
        let mut pong_timeout = time::interval(Duration::from_secs(PING_INTERVAL_SECS * 2));
        // With the default `Burst` behaviour, a stall in this loop (slow `route` or
        // `send`) makes missed ticks fire back-to-back. The pong check would then
        // clear the flag and immediately re-check it with no ping sent in between,
        // producing a spurious "ping timeout". `Delay` fires one late tick and
        // then resumes the normal spacing.
        ping.set_missed_tick_behavior(time::MissedTickBehavior::Delay);
        pong_timeout.set_missed_tick_behavior(time::MissedTickBehavior::Delay);
        // Starts `true` because the first pong check fires immediately.
        let mut pong_received = true;
        let mut event_poll = time::interval(event_poll_interval);
        loop {
            select! {
                _ = event_poll.tick() => {}
                _ = ping.tick() => {
                    // Best effort: a failed ping is only logged. A dead connection
                    // is detected by the pong check instead.
                    if let Err(e) = socket.send(Message::Ping(Bytes::new())).await {
                        log::warn!("ping failed {:}", e);
                    }
                }
                _ = pong_timeout.tick() => {
                    if !pong_received {
                        return Err(AppError::new("ping timeout"));
                    }
                    pong_received = false;
                }
                packet = self.to_server_rx.recv() => {
                    let Some(packet) = packet else {
                        return Err(AppError::new("handler terminated"));
                    };
                    if let Err(e) = socket.send(Message::Binary(Bytes::from_owner(packet))).await {
                        return Err(AppError::new(format!("failed to send packet: {}", e)));
                    }
                }
                message = socket.next() => {
                    let Some(message) = message else {
                        return Err(AppError::new("connection closed"));
                    };
                    let message =
                        message.map_err(|e| AppError::new(format!("connection error: {}", e)))?;
                    match message {
                        Message::Binary(bytes) => {
                            let packet = Packet::deserialize(&bytes)?;
                            self.channel_handler.route(packet).await;
                        }
                        Message::Pong(_) => {
                            pong_received = true;
                            log::trace!("pong received");
                        }
                        // tungstenite answers pings itself but still yields them.
                        // A server keep-alive must not tear the connection down.
                        Message::Ping(_) => {
                            log::trace!("ping received");
                        }
                        // tungstenite queues the close reply. Keep polling so the
                        // reply is flushed; the stream then ends and the `None`
                        // case above reports "connection closed".
                        Message::Close(frame) => {
                            log::debug!("close frame received: {:?}", frame);
                        }
                        _ => {
                            return Err(AppError::new(
                                "received a packet that was not of type binary",
                            ));
                        }
                    }
                }
            }
        }
    }
}
impl NotLoggedInSession {
    /// Wires up a new connection for `rule`.
    ///
    /// Creates the outbound packet queue (capacity 8) and the `ChannelHandler`,
    /// then spawns the command channel's reading half. Returns the `Handler`
    /// that must be driven with `run_loop`, together with the pre-login session.
    pub(crate) fn new(rule: Rule) -> (Handler, NotLoggedInSession) {
        let (outbound_tx, outbound_rx) = mpsc::channel::<Vec<u8>>(8);
        let (channel_handler, command_channel) = ChannelHandler::new(outbound_tx, rule);
        let (reader, writer) = command_channel.split();
        reader.spawn();

        let handler = Handler::new(outbound_rx, channel_handler);
        let pending = NotLoggedInSession {
            command_channel_write: writer,
        };
        (handler, pending)
    }

    /// Sends `login_info` over the command channel and, on success, upgrades to a `Session`.
    ///
    /// The response value is only trace-logged. Errors from `send_info` are
    /// propagated unchanged.
    pub(crate) async fn login(self, login_info: LoginInfo) -> Result<Session, AppError> {
        let Self {
            command_channel_write,
        } = self;
        let response = command_channel_write.send_info(login_info).await?;
        log::trace!("response: {:?}", response);
        Ok(Session {
            command_channel_write,
        })
    }
}
impl Session {
pub async fn post_upload(&self, data: Map<String, Value>) -> Result<Value, AppError> {
self.command_channel_write.post_upload(data).await
}
pub async fn get_download(&self) -> Result<Value, AppError> {
self.command_channel_write.get_download().await
}
pub async fn get_site_info(&self) -> Result<SiteInfo, AppError> {
self.command_channel_write
.get_site_info()
.await
.map(|site_info| site_info.into())
}
}
/// Starts a session on an already-established WebSocket.
///
/// Login runs on a detached task. On success the `Session` is delivered
/// through `on_session` when one was provided. A login failure is only logged,
/// and `on_session` is then dropped without a value. Meanwhile this future
/// drives the connection via `Handler::run_loop` and returns its result.
pub(crate) async fn start<T: AsyncRead + AsyncWrite + Unpin>(
    login_info: LoginInfo,
    rule: Rule,
    socket: WebSocketStream<T>,
    event_poll_interval: Duration,
    on_session: Option<Sender<Session>>,
) -> Result<(), AppError> {
    let (handler, pending) = NotLoggedInSession::new(rule);

    tokio::spawn(async move {
        let session = match pending.login(login_info).await {
            Ok(session) => session,
            Err(e) => {
                log::warn!("login failed {e}");
                return;
            }
        };
        log::info!("logged in");

        let Some(notify) = on_session else {
            return;
        };
        if notify.send(session).is_err() {
            log::warn!("failed to inform on session handler")
        }
    });

    handler.run_loop(socket, event_poll_interval).await
}