use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use futures::stream::{FuturesUnordered, StreamExt};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::Mutex;
use crate::config::SessionConfig;
use crate::error::{ExpectError, Result};
use crate::expect::{Pattern, PatternSet};
use crate::types::Match;
/// Identifier assigned to a session by [`MultiSessionManager::add`]; ids start at 0 and
/// increase by one for every session added to the same manager.
pub type SessionId = usize;
/// Identifies a session whose output matched, and what it matched.
///
/// Returned by the multi-session expect operations and by `PatternSelector::select`.
#[derive(Debug, Clone)]
pub struct SelectResult {
    /// Id of the session that produced the match.
    pub session_id: SessionId,
    /// The match reported by that session.
    pub matched: Match,
    /// Index of the pattern that matched.
    /// NOTE(review): every producer in this module currently sets this to 0; the real index
    /// is not recovered from the session's match.
    pub pattern_index: usize,
}
/// Per-session outcome of [`MultiSessionManager::send_all`].
#[derive(Debug, Clone)]
pub struct SendResult {
    /// Session the data was addressed to.
    pub session_id: SessionId,
    /// `true` when the send completed without error.
    pub success: bool,
    /// Failure text when `success` is `false`: the session's error message, or
    /// `"Session not active"` for sessions that were skipped. `None` on success.
    pub error: Option<String>,
}
/// Readiness categories for a managed session.
///
/// NOTE(review): nothing in this module constructs or matches on `ReadyType`; the variant
/// descriptions below are inferred from the names only — confirm against the actual users.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadyType {
    /// Presumably: one of the awaited patterns matched.
    Matched,
    /// Presumably: data is available to read.
    Readable,
    /// Presumably: the session can accept writes.
    Writable,
    /// Presumably: the session has been closed.
    Closed,
    /// Presumably: the session reported an error.
    Error,
}
/// A session as stored by [`MultiSessionManager`], together with its label and active flag.
///
/// Inactive sessions are reported as failures (not driven) by the manager's multi-session
/// operations: `send_all`, `expect_any_of`, `expect_all_of` and `PatternSelector::select`.
struct ManagedSession<T: AsyncReadExt + AsyncWriteExt + Unpin + Send> {
    /// The wrapped session.
    session: crate::session::Session<T>,
    /// Label supplied when the session was added.
    label: String,
    /// `true` on insertion; toggled with `MultiSessionManager::set_active`.
    active: bool,
}
impl<T: AsyncReadExt + AsyncWriteExt + Unpin + Send> fmt::Debug for ManagedSession<T> {
    // Prints the label and active flag only; the wrapped session is elided and rendered as
    // a trailing `..`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("ManagedSession");
        builder.field("label", &self.label);
        builder.field("active", &self.active);
        builder.finish_non_exhaustive()
    }
}
/// Owns a set of sessions keyed by [`SessionId`] and drives them individually or together.
///
/// Each session sits behind its own `Arc<Mutex<_>>`, so operations on different sessions can
/// proceed concurrently while operations on the same session are serialized by its lock.
pub struct MultiSessionManager<T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static> {
    /// Registered sessions.
    sessions: HashMap<SessionId, Arc<Mutex<ManagedSession<T>>>>,
    /// Id handed out by the next `add`; only ever incremented.
    next_id: SessionId,
    /// Fallback timeout (30 s from `new`, changed with `with_timeout`).
    default_timeout: Duration,
    /// Set by `with_config`.
    /// NOTE(review): not read by any code visible in this module — confirm it is used.
    default_config: SessionConfig,
}
impl<T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static> fmt::Debug
    for MultiSessionManager<T>
{
    // Summary view: how many sessions are registered, the next id to be assigned and the
    // default timeout. Session contents and the stored config are deliberately omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let session_count = self.sessions.len();
        let mut summary = f.debug_struct("MultiSessionManager");
        summary.field("session_count", &session_count);
        summary.field("next_id", &self.next_id);
        summary.field("default_timeout", &self.default_timeout);
        summary.finish()
    }
}
impl<T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static> Default for MultiSessionManager<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static> MultiSessionManager<T> {
#[must_use]
pub fn new() -> Self {
Self {
sessions: HashMap::new(),
next_id: 0,
default_timeout: Duration::from_secs(30),
default_config: SessionConfig::default(),
}
}
#[must_use]
pub const fn with_timeout(mut self, timeout: Duration) -> Self {
self.default_timeout = timeout;
self
}
#[must_use]
pub fn with_config(mut self, config: SessionConfig) -> Self {
self.default_config = config;
self
}
pub fn add(
&mut self,
session: crate::session::Session<T>,
label: impl Into<String>,
) -> SessionId {
let id = self.next_id;
self.next_id += 1;
let managed = ManagedSession {
session,
label: label.into(),
active: true,
};
self.sessions.insert(id, Arc::new(Mutex::new(managed)));
id
}
#[allow(clippy::unused_async)]
pub async fn remove(&mut self, id: SessionId) -> Option<crate::session::Session<T>> {
if let Some(arc) = self.sessions.remove(&id) {
match Arc::try_unwrap(arc) {
Ok(mutex) => Some(mutex.into_inner().session),
Err(arc) => {
self.sessions.insert(id, arc);
None
}
}
} else {
None
}
}
#[must_use]
pub fn len(&self) -> usize {
self.sessions.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.sessions.is_empty()
}
#[must_use]
pub fn session_ids(&self) -> Vec<SessionId> {
self.sessions.keys().copied().collect()
}
pub async fn label(&self, id: SessionId) -> Option<String> {
if let Some(arc) = self.sessions.get(&id) {
let guard = arc.lock().await;
Some(guard.label.clone())
} else {
None
}
}
pub async fn is_active(&self, id: SessionId) -> bool {
if let Some(arc) = self.sessions.get(&id) {
let guard = arc.lock().await;
guard.active
} else {
false
}
}
pub async fn set_active(&self, id: SessionId, active: bool) {
if let Some(arc) = self.sessions.get(&id) {
let mut guard = arc.lock().await;
guard.active = active;
}
}
pub async fn active_ids(&self) -> Vec<SessionId> {
let mut active = Vec::new();
for &id in self.sessions.keys() {
if self.is_active(id).await {
active.push(id);
}
}
active
}
pub async fn send(&self, id: SessionId, data: &[u8]) -> Result<()> {
let arc = self
.sessions
.get(&id)
.ok_or(ExpectError::SessionNotFound { id })?;
let mut guard = arc.lock().await;
guard.session.send(data).await
}
pub async fn send_line(&self, id: SessionId, line: &str) -> Result<()> {
let arc = self
.sessions
.get(&id)
.ok_or(ExpectError::SessionNotFound { id })?;
let mut guard = arc.lock().await;
guard.session.send_line(line).await
}
pub async fn send_all(&self, data: &[u8]) -> Vec<SendResult> {
let mut futures = FuturesUnordered::new();
for (&id, arc) in &self.sessions {
let arc = Arc::clone(arc);
let data = data.to_vec();
futures.push(async move {
let mut guard = arc.lock().await;
if !guard.active {
return SendResult {
session_id: id,
success: false,
error: Some("Session not active".to_string()),
};
}
match guard.session.send(&data).await {
Ok(()) => SendResult {
session_id: id,
success: true,
error: None,
},
Err(e) => SendResult {
session_id: id,
success: false,
error: Some(e.to_string()),
},
}
});
}
let mut results = Vec::new();
while let Some(result) = futures.next().await {
results.push(result);
}
results
}
pub async fn expect(&self, id: SessionId, pattern: impl Into<Pattern>) -> Result<Match> {
let arc = self
.sessions
.get(&id)
.ok_or(ExpectError::SessionNotFound { id })?;
let mut guard = arc.lock().await;
guard.session.expect(pattern).await
}
#[allow(clippy::type_complexity)]
pub async fn expect_any(&self, pattern: impl Into<Pattern>) -> Result<SelectResult> {
let pattern = pattern.into();
self.expect_any_of(&[pattern]).await
}
#[allow(clippy::type_complexity)]
pub async fn expect_any_of(&self, patterns: &[Pattern]) -> Result<SelectResult> {
if self.sessions.is_empty() {
return Err(ExpectError::NoSessions);
}
let pattern_set = PatternSet::from_patterns(patterns.to_vec());
let mut futures: FuturesUnordered<
Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>>,
> = FuturesUnordered::new();
for (&id, arc) in &self.sessions {
let arc = Arc::clone(arc);
let patterns = pattern_set.clone();
let future: Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>> =
Box::pin(async move {
let mut guard = arc.lock().await;
if !guard.active {
return (id, Err(ExpectError::SessionClosed));
}
match guard.session.expect_any(&patterns).await {
Ok(m) => (id, Ok((m, 0))), Err(e) => (id, Err(e)),
}
});
futures.push(future);
}
let mut last_error: Option<ExpectError> = None;
while let Some((session_id, result)) = futures.next().await {
match result {
Ok((matched, pattern_index)) => {
return Ok(SelectResult {
session_id,
matched,
pattern_index,
});
}
Err(e) => {
if !matches!(e, ExpectError::Timeout { .. }) {
last_error = Some(e);
}
}
}
}
Err(last_error.unwrap_or_else(|| ExpectError::Timeout {
duration: self.default_timeout,
pattern: "multi-session expect".to_string(),
buffer: String::new(),
}))
}
pub async fn expect_all(&self, pattern: impl Into<Pattern>) -> Result<Vec<SelectResult>> {
let pattern = pattern.into();
self.expect_all_of(&[pattern]).await
}
#[allow(clippy::type_complexity)]
pub async fn expect_all_of(&self, patterns: &[Pattern]) -> Result<Vec<SelectResult>> {
if self.sessions.is_empty() {
return Err(ExpectError::NoSessions);
}
let pattern_set = PatternSet::from_patterns(patterns.to_vec());
let mut futures: FuturesUnordered<
Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>>,
> = FuturesUnordered::new();
for (&id, arc) in &self.sessions {
let arc = Arc::clone(arc);
let patterns = pattern_set.clone();
let future: Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>> =
Box::pin(async move {
let mut guard = arc.lock().await;
if !guard.active {
return (id, Err(ExpectError::SessionClosed));
}
match guard.session.expect_any(&patterns).await {
Ok(m) => (id, Ok((m, 0))),
Err(e) => (id, Err(e)),
}
});
futures.push(future);
}
let mut results = Vec::new();
let mut errors = Vec::new();
while let Some((session_id, result)) = futures.next().await {
match result {
Ok((matched, pattern_index)) => {
results.push(SelectResult {
session_id,
matched,
pattern_index,
});
}
Err(e) => {
errors.push((session_id, e));
}
}
}
if let Some((id, error)) = errors.into_iter().next() {
return Err(ExpectError::MultiSessionError {
session_id: id,
error: Box::new(error),
});
}
Ok(results)
}
pub async fn with_session<F, R>(&self, id: SessionId, f: F) -> Result<R>
where
F: FnOnce(&mut crate::session::Session<T>) -> R,
{
let arc = self
.sessions
.get(&id)
.ok_or(ExpectError::SessionNotFound { id })?;
let mut guard = arc.lock().await;
Ok(f(&mut guard.session))
}
pub async fn with_session_async<F, Fut, R>(&self, id: SessionId, f: F) -> Result<R>
where
F: FnOnce(&mut crate::session::Session<T>) -> Fut,
Fut: Future<Output = R>,
{
let arc = self
.sessions
.get(&id)
.ok_or(ExpectError::SessionNotFound { id })?;
let mut guard = arc.lock().await;
Ok(f(&mut guard.session).await)
}
}
/// Builder describing which patterns to wait for on which session; run it with
/// `PatternSelector::select`.
#[derive(Debug, Default)]
pub struct PatternSelector {
    /// Patterns registered for specific sessions.
    patterns: HashMap<SessionId, Vec<Pattern>>,
    /// Patterns used for sessions that have no entry in `patterns`.
    default_patterns: Vec<Pattern>,
    /// Overall deadline for `select`; the manager's default timeout is used when `None`.
    timeout: Option<Duration>,
}
impl PatternSelector {
    /// Creates an empty selector: no session patterns, no default patterns, no timeout.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `pattern` to the patterns waited for on session `id`.
    #[must_use]
    pub fn session(mut self, id: SessionId, pattern: impl Into<Pattern>) -> Self {
        self.patterns.entry(id).or_default().push(pattern.into());
        self
    }

    /// Appends `patterns` to the patterns waited for on session `id`.
    ///
    /// Even an empty `patterns` creates an entry for `id`. That entry then hides the default
    /// patterns for the session, so [`Self::select`] skips it.
    #[must_use]
    pub fn session_patterns(mut self, id: SessionId, patterns: Vec<Pattern>) -> Self {
        self.patterns.entry(id).or_default().extend(patterns);
        self
    }

    /// Adds a pattern used for every session that has no session-specific entry.
    #[must_use]
    pub fn default_pattern(mut self, pattern: impl Into<Pattern>) -> Self {
        self.default_patterns.push(pattern.into());
        self
    }

    /// Sets the overall deadline for [`Self::select`], overriding the manager's default.
    #[must_use]
    pub const fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Patterns [`Self::select`] waits for on session `id`: its own entry if one exists,
    /// otherwise the default patterns.
    #[must_use]
    pub fn patterns_for(&self, id: SessionId) -> &[Pattern] {
        self.patterns
            .get(&id)
            .map_or(&self.default_patterns, Vec::as_slice)
    }

    /// Races every session of `manager` against its own patterns and returns the first match.
    ///
    /// Sessions with no applicable patterns are not waited on. The whole race is bounded by
    /// the selector's timeout, or by the manager's default timeout when none was set. The
    /// per-session expects presumably also apply their own timeouts.
    ///
    /// # Errors
    /// - [`ExpectError::NoSessions`] when the manager has no sessions.
    /// - [`ExpectError::Timeout`] when the deadline passes first. It is also returned
    ///   immediately when no session had patterns.
    /// - If every session finished with an error before the deadline: the last non-timeout
    ///   error (inactive sessions give [`ExpectError::SessionClosed`]). If all of those errors
    ///   were timeouts, [`ExpectError::Timeout`].
    #[allow(clippy::type_complexity)]
    pub async fn select<T>(&self, manager: &MultiSessionManager<T>) -> Result<SelectResult>
    where
        T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
    {
        if manager.is_empty() {
            return Err(ExpectError::NoSessions);
        }
        let timeout = self.timeout.unwrap_or(manager.default_timeout);
        // Captures only the `Copy` duration, so it is itself `Copy` and usable on both
        // failure paths below.
        let timeout_error = move || ExpectError::Timeout {
            duration: timeout,
            pattern: "pattern selector".to_string(),
            buffer: String::new(),
        };
        let mut pending: FuturesUnordered<
            Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>>,
        > = FuturesUnordered::new();
        for (&id, managed) in &manager.sessions {
            let patterns = self.patterns_for(id);
            if patterns.is_empty() {
                continue;
            }
            let managed = Arc::clone(managed);
            let pattern_set = PatternSet::from_patterns(patterns.to_vec());
            let future: Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>> =
                Box::pin(async move {
                    let mut guard = managed.lock().await;
                    if !guard.active {
                        return (id, Err(ExpectError::SessionClosed));
                    }
                    // NOTE(review): the session's `expect_any` yields only a `Match`, so the
                    // matching pattern's index is not recovered and is always reported as 0.
                    let outcome: Result<(Match, usize)> = guard
                        .session
                        .expect_any(&pattern_set)
                        .await
                        .map(|matched| (matched, 0));
                    (id, outcome)
                });
            pending.push(future);
        }
        let race = async {
            let mut last_error = None;
            while let Some((session_id, outcome)) = pending.next().await {
                match outcome {
                    Ok((matched, pattern_index)) => {
                        return Ok(SelectResult {
                            session_id,
                            matched,
                            pattern_index,
                        });
                    }
                    // A per-session timeout only means that session lost the race.
                    Err(ExpectError::Timeout { .. }) => {}
                    // Keep the real cause instead of masking it behind a synthetic timeout.
                    Err(error) => last_error = Some(error),
                }
            }
            Err(last_error.unwrap_or_else(timeout_error))
        };
        tokio::time::timeout(timeout, race)
            .await
            .map_err(|_| timeout_error())?
    }
}
#[cfg(test)]
mod tests {
    use tokio::io::DuplexStream;

    use super::*;

    /// Returns a connected in-memory pipe. Tests hand the first half to a session and bind
    /// the second to `_server` so the peer stays open for the rest of the test.
    fn create_mock_transport() -> (DuplexStream, DuplexStream) {
        tokio::io::duplex(1024)
    }

    #[tokio::test]
    async fn manager_add_remove() {
        let mut manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();
        let (client, _server) = create_mock_transport();
        let session = crate::session::Session::new(client, SessionConfig::default());
        let id = manager.add(session, "test");
        assert_eq!(manager.len(), 1);
        assert_eq!(manager.label(id).await, Some("test".to_string()));
        let removed = manager.remove(id).await;
        assert!(removed.is_some());
        assert!(manager.is_empty());
    }

    #[tokio::test]
    async fn manager_assigns_sequential_ids() {
        let mut manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();
        let (first_client, _first_server) = create_mock_transport();
        let (second_client, _second_server) = create_mock_transport();
        let first = manager.add(
            crate::session::Session::new(first_client, SessionConfig::default()),
            "first",
        );
        let second = manager.add(
            crate::session::Session::new(second_client, SessionConfig::default()),
            "second",
        );
        assert_eq!((first, second), (0, 1));
        let mut ids = manager.session_ids();
        ids.sort_unstable();
        assert_eq!(ids, vec![first, second]);
    }

    #[tokio::test]
    async fn manager_active_state() {
        let mut manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();
        let (client, _server) = create_mock_transport();
        let session = crate::session::Session::new(client, SessionConfig::default());
        let id = manager.add(session, "test");
        assert!(manager.is_active(id).await);
        manager.set_active(id, false).await;
        assert!(!manager.is_active(id).await);
        let active = manager.active_ids().await;
        assert!(active.is_empty());
    }

    #[tokio::test]
    async fn send_all_reports_inactive_sessions() {
        let mut manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();
        let (client, _server) = create_mock_transport();
        let session = crate::session::Session::new(client, SessionConfig::default());
        let id = manager.add(session, "idle");
        manager.set_active(id, false).await;
        let results = manager.send_all(b"hello").await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].session_id, id);
        assert!(!results[0].success);
        assert_eq!(results[0].error.as_deref(), Some("Session not active"));
    }

    #[tokio::test]
    async fn unknown_session_ids() {
        let mut manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();
        assert!(matches!(
            manager.send(7, b"data").await,
            Err(ExpectError::SessionNotFound { .. })
        ));
        assert!(matches!(
            manager.expect(7, "x").await,
            Err(ExpectError::SessionNotFound { .. })
        ));
        assert_eq!(manager.label(7).await, None);
        assert!(!manager.is_active(7).await);
        assert!(manager.remove(7).await.is_none());
    }

    #[test]
    fn pattern_selector_build() {
        let selector = PatternSelector::new()
            .session(0, "login:")
            .session(0, "password:")
            .session(1, "prompt>")
            .default_pattern("$");
        assert_eq!(selector.patterns_for(0).len(), 2);
        assert_eq!(selector.patterns_for(1).len(), 1);
        // Unregistered ids fall back to the single default pattern.
        assert_eq!(selector.patterns_for(99).len(), 1);
    }

    #[tokio::test]
    async fn expect_any_no_sessions() {
        let manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();
        let result = manager.expect_any("test").await;
        assert!(matches!(result, Err(ExpectError::NoSessions)));
    }

    #[tokio::test]
    async fn expect_all_no_sessions() {
        let manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();
        let result = manager.expect_all("test").await;
        assert!(matches!(result, Err(ExpectError::NoSessions)));
    }

    #[tokio::test]
    async fn select_no_sessions() {
        let manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();
        let result = PatternSelector::new()
            .default_pattern("$")
            .select(&manager)
            .await;
        assert!(matches!(result, Err(ExpectError::NoSessions)));
    }
}