use crate::DialOutFunc;
use crate::core::PaddingFactory;
use crate::proxy::session::Session;
use crate::runtime::new_client_session;
use indexmap::IndexMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, RwLock};
use tokio::time::{interval, timeout};
/// How long `find_or_create_session` waits for a busy session to come back to
/// the idle pool before it dials a new one.
const IDLE_POOL_WAIT_TIMEOUT: Duration = Duration::from_millis(100);

/// How many sessions `create_stream` tries before giving up on opening a stream.
const OPEN_STREAM_ATTEMPTS: usize = 3;

/// Idle sessions keyed by sequence number, each paired with the instant it was
/// put into the pool.
type IdlePool = Mutex<IndexMap<u64, (Arc<Session>, Instant)>>;

/// Client that hands out streams over a pool of reusable sessions.
///
/// Sessions are dialed on demand through `dial_out`. When a session has no open
/// stream any more it is parked in `idle_sessions` for reuse. A background task
/// terminates sessions that have been idle for at least `idle_session_timeout`,
/// keeping at least `min_idle_sessions` pooled.
pub struct Client {
    /// Opens the transport connection for a new session.
    dial_out: DialOutFunc,
    /// Every session whose run task has not finished yet, keyed by sequence number.
    sessions: Arc<Mutex<IndexMap<u64, Arc<Session>>>>,
    /// Sessions available for reuse.
    idle_sessions: Arc<IdlePool>,
    /// Signalled whenever a session is pooled or a session's run task ends.
    idle_pool_notify: Arc<tokio::sync::Notify>,
    /// Source of session sequence numbers.
    session_seq_number: AtomicU64,
    /// Padding configuration passed to every new session.
    padding: Arc<RwLock<PaddingFactory>>,
    /// Idle age at which a pooled session is no longer reused.
    idle_session_timeout: Duration,
    /// Number of idle sessions the background cleanup never removes.
    min_idle_sessions: usize,
    /// Held only by the client. The cleanup task watches a `Weak` to it so it
    /// can tell when the client has been dropped.
    liveness: Arc<()>,
}

impl Client {
    /// Creates a client and spawns its background idle-session cleanup task.
    ///
    /// Every `idle_session_check_interval` the task terminates sessions that
    /// have been idle for at least `idle_session_timeout`. It never shrinks the
    /// pool below `min_idle_sessions`.
    ///
    /// After the client is dropped, nothing can reuse pooled sessions. The task
    /// then terminates everything pooled on each tick. It exits once no idle
    /// waiter is left that could return a session to the pool.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because the task is started
    /// with `tokio::spawn`. Note that `tokio::time::interval` panics on a zero
    /// period, so a zero `idle_session_check_interval` makes the cleanup task
    /// panic.
    pub fn new(
        dial_out: DialOutFunc,
        padding: Arc<RwLock<PaddingFactory>>,
        idle_session_check_interval: Duration,
        idle_session_timeout: Duration,
        min_idle_sessions: usize,
    ) -> Self {
        let client = Self {
            dial_out,
            sessions: Arc::new(Mutex::new(IndexMap::new())),
            idle_sessions: Arc::new(Mutex::new(IndexMap::new())),
            idle_pool_notify: Arc::new(tokio::sync::Notify::new()),
            session_seq_number: AtomicU64::new(0),
            padding,
            idle_session_timeout,
            min_idle_sessions,
            liveness: Arc::new(()),
        };
        let idle_sessions = client.idle_sessions.clone();
        let client_alive = Arc::downgrade(&client.liveness);
        let idle_timeout = client.idle_session_timeout;
        let min_idle = client.min_idle_sessions;
        tokio::spawn(async move {
            let mut ticker = interval(idle_session_check_interval);
            loop {
                ticker.tick().await;
                if client_alive.strong_count() > 0 {
                    Self::idle_cleanup(&idle_sessions, idle_timeout, min_idle).await;
                    continue;
                }
                // The client is gone, so nothing can reuse pooled sessions:
                // terminate all of them. Idle waiters spawned for in-flight
                // streams still hold the pool and may park sessions later, so
                // keep draining until this task holds the last reference.
                Self::idle_cleanup(&idle_sessions, Duration::ZERO, 0).await;
                if Arc::strong_count(&idle_sessions) == 1 {
                    break;
                }
            }
        });
        client
    }

    /// Opens a new stream, reusing an idle session when one is available.
    ///
    /// Up to `OPEN_STREAM_ATTEMPTS` sessions are tried. A session on which
    /// `open_stream` fails is terminated before the next attempt. On success,
    /// an idle waiter is spawned to return the session to the pool later.
    ///
    /// # Errors
    ///
    /// Returns immediately with the error from finding or dialing a session.
    /// Returns the last `open_stream` error once every attempt has failed.
    pub async fn create_stream(&self) -> Result<Arc<Session>, std::io::Error> {
        let mut last_error = None;
        for _ in 0..OPEN_STREAM_ATTEMPTS {
            let (session, seq) = self.find_or_create_session().await?;
            let opened = session.open_stream().await;
            match opened {
                Ok(stream) => {
                    self.spawn_idle_waiter(session, seq);
                    return Ok(stream);
                }
                Err(error) => {
                    log::warn!("Failed to open stream on session {seq}: {error}, retrying...");
                    // Best effort: a session that cannot open streams is not kept.
                    let _ = session.terminate().await;
                    last_error = Some(error);
                }
            }
        }
        Err(last_error.unwrap_or_else(|| std::io::Error::other("Failed to create stream")))
    }

    /// Returns a pooled session if one is usable, otherwise dials a new one.
    ///
    /// If the pool is empty but sessions are live, it first waits up to
    /// `IDLE_POOL_WAIT_TIMEOUT` for one of them to be returned.
    async fn find_or_create_session(&self) -> Result<(Arc<Session>, u64), std::io::Error> {
        // Subscribe before inspecting the pool. `notify_waiters` only reaches
        // `Notified` futures that already exist, so creating this after the
        // check could miss a session returned in between.
        let session_returned = self.idle_pool_notify.notified();
        if let Some(found) = self.pick_session_from_idle_pool().await {
            return Ok(found);
        }
        let has_live_sessions = !self.sessions.lock().await.is_empty();
        if has_live_sessions {
            log::trace!("Client: idle pool empty; waiting briefly for a session to return");
            if timeout(IDLE_POOL_WAIT_TIMEOUT, session_returned).await.is_err() {
                log::trace!(
                    "Client: idle pool wait timed out after {:?}; creating a new session",
                    IDLE_POOL_WAIT_TIMEOUT
                );
            }
            if let Some(found) = self.pick_session_from_idle_pool().await {
                return Ok(found);
            }
        }
        self.create_session().await
    }

    /// Spawns a task that puts `session` into the idle pool once it has no open
    /// stream, and wakes anyone waiting for a pooled session.
    ///
    /// Nothing is pooled in these cases:
    /// - the session is terminated;
    /// - a stream is open again when the task wakes;
    /// - `seq` is already in the pool.
    fn spawn_idle_waiter(&self, session: Arc<Session>, seq: u64) {
        let idle_sessions = self.idle_sessions.clone();
        let idle_pool_notify = self.idle_pool_notify.clone();
        tokio::spawn(async move {
            let ptr = Arc::as_ptr(&session) as usize;
            if session.is_terminated().await {
                log::trace!("Client: idle waiter sees terminated session seq={} ptr=0x{:x}", seq, ptr);
                return;
            }
            if !session.is_stream_open().await {
                // The stream already finished: pool right away without waiting.
                let mut idles = idle_sessions.lock().await;
                if idles.contains_key(&seq) {
                    log::trace!("Client: idle waiter found session already pooled seq={} ptr=0x{:x}", seq, ptr);
                    return;
                }
                log::trace!("Client: idle waiter pooled session immediately seq={} ptr=0x{:x}", seq, ptr);
                idles.insert(seq, (session, Instant::now()));
                idle_pool_notify.notify_waiters();
                return;
            }
            log::trace!("Client: idle waiter waiting for session seq={} ptr=0x{:x}", seq, ptr);
            session.wait_for_idle().await;
            log::trace!("Client: idle waiter woke for session seq={} ptr=0x{:x}", seq, ptr);
            if session.is_terminated().await {
                log::trace!("Client: idle waiter woke to terminated session seq={} ptr=0x{:x}", seq, ptr);
                return;
            }
            if session.is_stream_open().await {
                log::trace!("Client: idle waiter woke but stream reopened seq={} ptr=0x{:x}", seq, ptr);
                return;
            }
            let mut idles = idle_sessions.lock().await;
            if idles.contains_key(&seq) {
                log::trace!("Client: idle waiter found session pooled after wake seq={} ptr=0x{:x}", seq, ptr);
                return;
            }
            log::trace!("Client: idle waiter returning session to pool seq={} ptr=0x{:x}", seq, ptr);
            idles.insert(seq, (session, Instant::now()));
            idle_pool_notify.notify_waiters();
        });
    }

    /// Removes and returns the most recently pooled usable session.
    ///
    /// Entries are popped from the end of the pool. Terminated sessions are
    /// skipped. Sessions idle for at least `idle_session_timeout` are removed
    /// and terminated. Returns `None` if the pool runs out first.
    async fn pick_session_from_idle_pool(&self) -> Option<(Arc<Session>, u64)> {
        let mut stale = Vec::new();
        let picked = {
            let mut idle_sessions = self.idle_sessions.lock().await;
            loop {
                let Some((seq, (session, idle_since))) = idle_sessions.pop() else {
                    break None;
                };
                if session.is_terminated().await {
                    continue;
                }
                if idle_since.elapsed() >= self.idle_session_timeout {
                    log::trace!("Dropping stale idle session {seq} before reuse");
                    stale.push(session);
                    continue;
                }
                break Some((session, seq));
            }
        };
        // Terminate outside the lock so a slow shutdown doesn't block other pool users.
        for session in stale {
            let _ = session.terminate().await;
        }
        if let Some((session, seq)) = &picked {
            let ptr = Arc::as_ptr(session) as usize;
            log::trace!("Client: reusing idle session seq={} ptr=0x{:x}", seq, ptr);
        }
        picked
    }

    /// Dials a connection, starts a session on it and registers it in `sessions`.
    ///
    /// A task is spawned to run the session. When `run` returns, the session
    /// is removed from `sessions` and pool waiters are woken.
    ///
    /// # Errors
    ///
    /// Returns the dial-out error or the error from `ensure_started`.
    async fn create_session(&self) -> Result<(Arc<Session>, u64), std::io::Error> {
        log::info!("Client: creating new session (dial out)");
        let conn = match (self.dial_out)().await {
            Ok(conn) => {
                log::debug!("Client: dial out succeeded");
                conn
            }
            Err(e) => {
                log::warn!("Client: dial out failed: {e}");
                return Err(e);
            }
        };
        let session = Arc::new(new_client_session(conn, self.padding.clone()).await);
        session.ensure_started().await?;
        let seq = self.session_seq_number.fetch_add(1, Ordering::SeqCst);
        self.sessions.lock().await.insert(seq, session.clone());
        let ptr = Arc::as_ptr(&session) as usize;
        log::trace!("Client: created session seq={} ptr=0x{:x}", seq, ptr);
        let runner = session.clone();
        let sessions = self.sessions.clone();
        let idle_pool_notify = self.idle_pool_notify.clone();
        tokio::spawn(async move {
            let result = runner.run().await;
            log::debug!("Session {seq} ended: {result:?}");
            sessions.lock().await.swap_remove(&seq);
            // Wake pool waiters so they re-check instead of waiting on a dead session.
            idle_pool_notify.notify_waiters();
        });
        Ok((session, seq))
    }

    /// Removes and terminates sessions idle for at least `idle_timeout`.
    ///
    /// Entries are considered in pool order, earliest pooled first. At least
    /// `min_idle` sessions are always left in the pool.
    async fn idle_cleanup(idle_sessions: &IdlePool, idle_timeout: Duration, min_idle: usize) {
        let expired = {
            let mut idles = idle_sessions.lock().await;
            let mut removable = idles.len().saturating_sub(min_idle);
            if removable == 0 {
                return;
            }
            let now = Instant::now();
            let mut expired = Vec::new();
            // `retain` keeps the remaining entries in order, unlike `swap_remove`.
            idles.retain(|_, (session, idle_since)| {
                let remove = removable > 0 && now.duration_since(*idle_since) >= idle_timeout;
                if remove {
                    removable -= 1;
                    expired.push(Arc::clone(session));
                }
                !remove
            });
            expired
        };
        // Terminate outside the lock so a slow shutdown doesn't block the pool.
        for session in expired {
            let _ = session.terminate().await;
        }
    }

    /// Terminates every live session and empties the idle pool.
    ///
    /// Termination errors are ignored (best effort).
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub async fn close(&self) -> Result<(), std::io::Error> {
        // Pooled sessions came from `create_session`, so any whose run task is
        // still going is also in `sessions` and gets terminated below.
        self.idle_sessions.lock().await.clear();
        // Snapshot, then release the lock: each session's run task takes this
        // lock to deregister itself and should not wait on our terminations.
        let live: Vec<Arc<Session>> = self.sessions.lock().await.values().cloned().collect();
        for session in live {
            let _ = session.terminate().await;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::Client;
    use crate::core::{Command, Frame};
    use crate::proxy::session::{DEFAULT_SID, Session};
    use crate::runtime::{DefaultPaddingFactory, ProtocolHost};
    use crate::{AsyncReadWrite, DialOutFunc};
    use std::sync::{Arc, Mutex as StdMutex};
    use std::time::Duration;
    use tokio::io::{DuplexStream, duplex};
    use tokio::task::yield_now;
    use tokio::time::timeout;

    /// Peer halves of every dialed in-memory pipe. A test keeps this alive so
    /// the far end of each connection is not dropped while the test runs.
    type ParkedPeers = Arc<StdMutex<Vec<DuplexStream>>>;

    /// Builds a client whose dial-out yields 1 KiB in-memory duplex pipes.
    /// It uses a 60 s cleanup interval, a 60 s idle timeout and no minimum idle sessions.
    fn pipe_client() -> (Client, ParkedPeers) {
        let parked: ParkedPeers = Arc::new(StdMutex::new(Vec::new()));
        let sink = Arc::clone(&parked);
        let dial_out: DialOutFunc = Box::new(move || {
            let sink = Arc::clone(&sink);
            Box::pin(async move {
                let (client_io, peer_io) = duplex(1024);
                sink.lock().expect("peer store lock poisoned").push(peer_io);
                Ok(Box::new(client_io) as Box<dyn AsyncReadWrite>)
            })
        });
        let client = Client::new(
            dial_out,
            DefaultPaddingFactory::load(),
            Duration::from_secs(60),
            Duration::from_secs(60),
            0,
        );
        (client, parked)
    }

    /// Polls the idle pool, yielding between attempts, until it hands back a
    /// session. Panics with `expectation` if none appears within one second.
    async fn wait_for_pooled_session(client: &Client, expectation: &str) -> (Arc<Session>, u64) {
        let poll_pool = async {
            loop {
                match client.pick_session_from_idle_pool().await {
                    Some(found) => return found,
                    None => yield_now().await,
                }
            }
        };
        timeout(Duration::from_secs(1), poll_pool).await.expect(expectation)
    }

    #[tokio::test]
    async fn local_fin_does_not_return_session_to_idle_pool_before_remote_fin() {
        let (client, _peers) = pipe_client();
        let stream = client.create_stream().await.expect("stream should be created");

        // Half-close from our side only.
        let local_fin = Frame::new(Command::Fin, DEFAULT_SID);
        stream.write_frame(local_fin).await.expect("local FIN should be sent");
        stream
            .mark_local_stream_closed(DEFAULT_SID)
            .await
            .expect("local FIN should close only the local half");
        yield_now().await;
        let pool_is_empty = client.idle_sessions.lock().await.is_empty();
        assert!(pool_is_empty, "local FIN alone must not make the session reusable");

        // The peer's FIN completes the close, so the session may be pooled.
        stream
            .close_logical_stream(DEFAULT_SID)
            .await
            .expect("remote FIN should close the logical stream");
        let (reused, reused_seq) =
            wait_for_pooled_session(&client, "session should become idle after remote FIN").await;
        assert_eq!(reused_seq, 0, "the first session should be returned to the idle pool");
        let still_idle = !reused.is_stream_open().await;
        assert!(still_idle, "reused session should still be idle when removed from the idle pool");
    }

    #[tokio::test]
    async fn remote_fin_does_not_return_session_to_idle_pool_before_local_fin() {
        let (client, _peers) = pipe_client();
        let stream = client.create_stream().await.expect("stream should be created");

        // The peer half-closes first; our reader must see EOF.
        stream
            .close_logical_stream(DEFAULT_SID)
            .await
            .expect("remote FIN should close only the remote half");
        let mut buf = [0u8; 1];
        let eof_len = timeout(Duration::from_secs(1), stream.read(&mut buf))
            .await
            .expect("remote FIN should wake the reader")
            .expect("reader should observe EOF after remote FIN");
        assert_eq!(eof_len, 0, "remote FIN should surface as EOF to the local reader");
        yield_now().await;
        let pool_is_empty = client.idle_sessions.lock().await.is_empty();
        assert!(pool_is_empty, "remote FIN alone must not make the session reusable");

        // Our FIN closes the remaining half, so the session may be pooled.
        let local_fin = Frame::new(Command::Fin, DEFAULT_SID);
        stream
            .write_frame(local_fin)
            .await
            .expect("local FIN should be sent after observing remote EOF");
        stream
            .mark_local_stream_closed(DEFAULT_SID)
            .await
            .expect("local FIN should close the remaining local half");
        let (reused, reused_seq) =
            wait_for_pooled_session(&client, "session should become idle after both halves close").await;
        assert_eq!(reused_seq, 0, "the first session should be returned to the idle pool");
        let still_idle = !reused.is_stream_open().await;
        assert!(still_idle, "reused session should still be idle when removed from the idle pool");
    }
}