use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use arc_swap::ArcSwap;
use futures::future::Either;
use tokio::sync::Notify;
use super::{Session, config::SessionManagerConfig};
use crate::{
consts::StateType,
error::Error,
stream::Stream,
transport::{TransportConnect, TransportStream},
};
/// Number of consecutive request-counter values that map to the same session
/// slot before `get_stream` moves on to the next one: the slot index is
/// `(counter / SESSION_ROUND_ROBIN_THRESHOLD) % session_count`.
const SESSION_ROUND_ROBIN_THRESHOLD: usize = 32;
/// Handle to a fixed-size set of sessions created through the transport
/// connector `C`.
///
/// All state lives behind a single `Arc`, so every handle that shares the
/// same `inner` operates on the same sessions.
pub struct SessionManager<C: TransportConnect> {
// Shared state: configuration, connector, session slots and shutdown signal.
inner: Arc<SessionManagerInner<C>>,
}
impl<C: TransportConnect> Clone for SessionManager<C> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
/// State shared by every `SessionManager` handle and by the background
/// rebuild tasks.
// NOTE(review): `epoch`, `rand_id` and `state` are only written in `new` in
// this part of the file, which is presumably why `dead_code` is allowed — confirm.
#[allow(dead_code)]
struct SessionManagerInner<C: TransportConnect> {
/// Manager configuration; supplies the session count and the rebuild interval.
sm_config: SessionManagerConfig,
/// Transport connector handed to `Session::client` when (re)creating sessions.
connect: C,
/// Remote address handed to `Session::client`.
addr: C::Address,
/// One atomically swappable slot per session; a slot's index is its session id.
sessions: Vec<ArcSwap<Session>>,
/// Monotonic request counter that drives session selection in `get_stream`.
count: AtomicUsize,
/// Initialised to 0 by `new`; not read in this part of the file.
epoch: u64,
/// Initialised to 0 by `new`; not read in this part of the file.
rand_id: u64,
/// Initialised to `StateType::DefaultState` by `new`; not read in this part of the file.
state: StateType,
/// Signalled by `close` to stop the background rebuild tasks.
shutdown_tx: Arc<Notify>,
}
impl<C: TransportConnect> SessionManager<C>
where
    C: TransportConnect + Send + Sync + 'static,
    C::Stream: Send,
    <C::Stream as TransportStream>::ReadHalf: Send + 'static,
    <C::Stream as TransportStream>::WriteHalf: Send + 'static,
    C::Address: Clone + Send + Sync,
{
    /// Connects `sm_config.session_num()` sessions to `addr` and starts one
    /// background task per session that re-creates it after it shuts down.
    ///
    /// Sessions are created sequentially; the background tasks are started
    /// with `tokio::spawn` and run until [`close`](Self::close) is called.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `Session::client`; no background
    /// task has been spawned at that point.
    pub async fn new(
        mut sm_config: SessionManagerConfig,
        connect: C,
        addr: C::Address,
    ) -> Result<Self, Error> {
        // Read the count once so the number of rebuild tasks always equals the
        // number of session slots, even though `Session::client` receives the
        // configuration mutably.
        let session_num = sm_config.session_num();
        let mut sessions = Vec::with_capacity(session_num);
        for i in 0..session_num {
            let session = Session::client(i, 0, 0, &mut sm_config, &connect, addr.clone()).await?;
            sessions.push(ArcSwap::from_pointee(session));
        }

        let shutdown = Arc::new(Notify::new());
        let sm = Self {
            inner: Arc::new(SessionManagerInner {
                sm_config,
                connect,
                addr,
                sessions,
                count: AtomicUsize::new(0),
                epoch: 0,
                rand_id: 0,
                state: StateType::DefaultState,
                shutdown_tx: Arc::clone(&shutdown),
            }),
        };

        for i in 0..session_num {
            let shutdown_rx = Arc::clone(&shutdown);
            let manager = sm.clone();
            tokio::spawn(async move { manager.rebuild_session(shutdown_rx, i).await });
        }
        Ok(sm)
    }

    /// Picks a session and asks it for a stream.
    ///
    /// Sessions are chosen round-robin in batches: each run of
    /// `SESSION_ROUND_ROBIN_THRESHOLD` consecutive counter values maps to the
    /// same session before selection advances to the next one.
    ///
    /// # Errors
    ///
    /// Returns `Error::SessionUnhealthy` when the manager has no sessions,
    /// otherwise whatever `Session::get_or_open_stream` returns.
    pub fn get_stream(&self) -> Result<Stream, Error> {
        let sessions = &self.inner.sessions;
        if sessions.is_empty() {
            return Err(Error::SessionUnhealthy);
        }
        // `fetch_add` wraps on overflow; the `+ 1` has to wrap as well,
        // otherwise an overflow-checked build panics once the previous value
        // is `usize::MAX`.
        let ticket = self
            .inner
            .count
            .fetch_add(1, Ordering::SeqCst)
            .wrapping_add(1);
        let i = (ticket / SESSION_ROUND_ROBIN_THRESHOLD) % sessions.len();
        sessions[i].load().get_or_open_stream(i)
    }

    /// Hands `stream` to the session currently stored in the slot named by
    /// `stream.session_id()`, which either keeps it or closes it.
    ///
    /// # Panics
    ///
    /// Panics if `stream.session_id()` is not a valid slot index of this
    /// manager.
    // NOTE(review): after a rebuild the slot may hold a newer session than the
    // one that opened `stream` — confirm `put_or_close_stream` copes with that.
    pub async fn put_back(&self, stream: Stream) {
        // Take a full `Arc` rather than a short-lived guard because the
        // pointer is held across an `.await`.
        let session = self.inner.sessions[stream.session_id()].load_full();
        session.put_or_close_stream(stream).await;
    }

    /// Signals the background rebuild tasks to stop, then closes every
    /// session in slot order.
    ///
    /// The shutdown signal is sent first so that the rebuild tasks do not
    /// re-create the sessions that are about to be closed.
    // NOTE(review): `notify_waiters` stores no permit, so a rebuild task that
    // has not yet created its `notified()` future (spawned but never polled)
    // would miss this signal — confirm whether that window matters.
    pub async fn close(&self) {
        self.inner.shutdown_tx.notify_waiters();
        for slot in &self.inner.sessions {
            slot.load_full().close().await;
        }
    }

    /// Background task for slot `i`: waits for the current session to signal
    /// shutdown, then reconnects until a replacement session is stored.
    ///
    /// The task returns as soon as `shutdown_rx` is signalled, including while
    /// it is sleeping between attempts or while a connection attempt is in
    /// flight (the attempt's future is dropped in that case).
    // NOTE(review): a session-shutdown notification sent before `notified()`
    // is created below may be missed, depending on how `Session` signals
    // `shutdown_notify` — confirm.
    async fn rebuild_session(self, shutdown_rx: Arc<Notify>, i: usize) {
        // Created once and kept for the whole task so that a `notify_waiters`
        // call from `close` is observed no matter which await point the task
        // is parked at.
        let mut shutdown_rx = Box::pin(shutdown_rx.notified());
        let rebuild_interval = self.inner.sm_config.config().rebuild_interval;

        loop {
            // `load_full` instead of `load`: the pointer is held across awaits
            // of unbounded duration.
            let session = self.inner.sessions[i].load_full();
            let session_down = std::pin::pin!(session.shared.shutdown_notify.notified());
            // `select` polls its first argument first, so shutdown wins ties.
            match futures::future::select(&mut shutdown_rx, session_down).await {
                Either::Left(_) => return,
                Either::Right(_) => {}
            }

            loop {
                let pause = std::pin::pin!(tokio::time::sleep(rebuild_interval));
                if let Either::Left(_) = futures::future::select(&mut shutdown_rx, pause).await {
                    return;
                }

                let mut config = self.inner.sm_config.clone();
                let connecting = std::pin::pin!(Session::client(
                    i,
                    0,
                    0,
                    &mut config,
                    &self.inner.connect,
                    self.inner.addr.clone(),
                ));
                let attempt = futures::future::select(&mut shutdown_rx, connecting).await;
                match attempt {
                    Either::Left(_) => return,
                    Either::Right((Ok(new_session), _)) => {
                        self.inner.sessions[i].store(Arc::new(new_session));
                        tracing::info!("rebuild session {i} success");
                        break;
                    }
                    Either::Right((Err(e), _)) => {
                        tracing::error!(
                            "rebuild session {i} error: {:?}, retry after {:?}",
                            e,
                            rebuild_interval
                        );
                    }
                }
            }
        }
    }
}