use super::ConnectionState;
use crate::api::BotApi;
use futures_util::stream::{SplitSink, SplitStream};
use serde_json::Value;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::time::{Duration, sleep};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use tracing::{debug, error, info};
/// Boxed connect callback.
///
/// Takes a [`Session`] by value and returns a pinned, boxed, `Send` future that
/// resolves to `Result<(), BotError>`. The callback itself must be `Send + Sync`.
type ConnectFn = Box<
dyn Fn(
Session,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<(), crate::error::BotError>> + Send>,
> + Send
+ Sync,
>;
/// Boxed dispatch callback taking a `&str` and a JSON [`Value`]; it returns nothing
/// and must be `Send + Sync`.
// NOTE(review): the `&str` is presumably an event name and the `Value` its payload —
// confirm against the callers that build this callback.
type DispatchFn = Box<dyn Fn(&str, Value) + Send + Sync>;
/// `SplitSink` over a `WebSocketStream<MaybeTlsStream<TcpStream>>` that accepts
/// tungstenite `Message` values, i.e. the sending half of a split WebSocket stream.
pub type WsSink =
SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Message>;
/// `SplitStream` over a `WebSocketStream<MaybeTlsStream<TcpStream>>`, i.e. the
/// receiving half of a split WebSocket stream.
pub type WsStream = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
/// Describes a single connection session.
#[derive(Debug, Clone)]
pub struct Session {
    /// Session identifier; reset to an empty string by [`Session::mark_for_reconnect`].
    pub session_id: String,
    /// Shard pair.
    // NOTE(review): presumably `(shard_id, shard_count)` — confirm against callers.
    pub shard: (u32, u32),
    /// URL associated with this session.
    pub url: String,
    /// `false` for a freshly created session; set by [`Session::mark_for_reconnect`].
    pub needs_reconnect: bool,
}

impl Session {
    /// Creates a session from its id, shard pair and URL, with `needs_reconnect`
    /// initially `false`.
    pub fn new(session_id: String, shard: (u32, u32), url: String) -> Self {
        let needs_reconnect = false;
        Session {
            session_id,
            shard,
            url,
            needs_reconnect,
        }
    }

    /// Flags the session as needing a reconnect and discards its session id.
    ///
    /// `shard` and `url` are left untouched.
    pub fn mark_for_reconnect(&mut self) {
        let Self {
            session_id,
            needs_reconnect,
            ..
        } = self;
        *needs_reconnect = true;
        *session_id = String::new();
    }
}
/// Owns a queue of [`Session`]s together with the connect/dispatch callbacks and a
/// shared, mutex-guarded `ConnectionState`.
// NOTE(review): `allow(unused)` presumably silences warnings for `connect_fn` and
// `dispatch_fn`, which are stored here but not called by the code visible in this
// part of the file — confirm before removing it.
#[allow(unused)]
pub struct ConnectionSession {
/// Number of sessions started per batch by `multi_run`.
max_async: usize,
/// Connect callback supplied to `new`.
connect_fn: ConnectFn,
/// Dispatch callback supplied to `new`.
dispatch_fn: DispatchFn,
/// Sessions queued via `add_session` and consumed by `multi_run`.
sessions: Vec<Session>,
/// Shared connection state; a clone of this handle is returned by `state()`.
state: Arc<Mutex<ConnectionState>>,
}
impl ConnectionSession {
pub fn new<F, D>(max_async: usize, connect_fn: F, dispatch_fn: D, api: BotApi) -> Self
where
F: Fn(
Session,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<(), crate::error::BotError>> + Send>,
> + Send
+ Sync
+ 'static,
D: Fn(&str, Value) + Send + Sync + 'static,
{
Self {
max_async,
connect_fn: Box::new(connect_fn),
dispatch_fn: Box::new(dispatch_fn),
sessions: Vec::new(),
state: Arc::new(Mutex::new(ConnectionState::new(api))),
}
}
pub fn add_session(&mut self, session: Session) {
self.sessions.push(session);
}
pub async fn multi_run(mut self, session_interval: u64) -> Result<(), crate::error::BotError> {
if self.sessions.is_empty() {
return Ok(());
}
let mut index = 0;
let mut tasks = Vec::new();
while !self.sessions.is_empty() {
debug!("Session list loop running");
let time_interval = session_interval * (index + 1);
info!(
"Max concurrent connections: {}, Starting sessions: {}",
self.max_async,
self.sessions.len()
);
for _ in 0..self.max_async {
if self.sessions.is_empty() {
break;
}
let session = self.sessions.remove(0);
tasks.push(tokio::spawn(async move {
debug!("Would connect session: {:?}", session);
sleep(Duration::from_secs(time_interval)).await;
}));
}
index += self.max_async as u64;
}
for task in tasks {
if let Err(e) = task.await {
error!("Task execution failed: {:?}", e);
}
}
Ok(())
}
pub fn state(&self) -> Arc<Mutex<ConnectionState>> {
self.state.clone()
}
}