use crate::node::Node;
use crate::ssh::tokio_client::CommandOutput;
use tokio::sync::mpsc;
/// Upper bound, in bytes, on how much output a single `RollingBuffer` retains (10 MiB).
/// Older bytes are discarded once this limit is exceeded.
const MAX_BUFFER_SIZE: usize = 10 << 20;
/// Byte buffer that keeps only the most recent `MAX_BUFFER_SIZE` bytes of a stream,
/// evicting the oldest bytes first.
#[derive(Debug)]
struct RollingBuffer {
    /// Currently retained bytes, oldest first.
    data: Vec<u8>,
    /// Running count of every byte handed to `append`, including evicted ones.
    total_bytes_received: usize,
    /// Running count of bytes evicted to respect the size cap (never reset).
    bytes_dropped: usize,
}
impl RollingBuffer {
    /// Creates an empty buffer with no bytes received or dropped.
    fn new() -> Self {
        Self {
            data: Vec::new(),
            total_bytes_received: 0,
            bytes_dropped: 0,
        }
    }

    /// Appends `new_data`, discarding the oldest bytes so that at most
    /// `MAX_BUFFER_SIZE` bytes are retained.
    ///
    /// Trimming happens *before* copying:
    /// - An oversized chunk contributes only its tail.
    /// - Existing bytes that would be evicted anyway are removed first.
    ///
    /// As a result, the backing `Vec` never has to grow past `MAX_BUFFER_SIZE`. The
    /// eviction itself (`Vec::drain` from the front) shifts the retained bytes, so
    /// each overflowing append costs O(retained length).
    fn append(&mut self, new_data: &[u8]) {
        self.total_bytes_received = self.total_bytes_received.saturating_add(new_data.len());

        // Only the last MAX_BUFFER_SIZE bytes of an oversized chunk can survive.
        let skipped = new_data.len().saturating_sub(MAX_BUFFER_SIZE);
        let kept = &new_data[skipped..];

        // Evict just enough existing bytes to make room for `kept`. This never exceeds
        // `self.data.len()`, because `kept.len() <= MAX_BUFFER_SIZE`.
        let evicted = (self.data.len() + kept.len()).saturating_sub(MAX_BUFFER_SIZE);
        self.data.drain(..evicted);
        self.data.extend_from_slice(kept);

        let overflow = skipped + evicted;
        if overflow > 0 {
            self.bytes_dropped = self.bytes_dropped.saturating_add(overflow);
            tracing::warn!(
                "Buffer overflow: dropped {} bytes (total dropped: {})",
                overflow,
                self.bytes_dropped
            );
        }
    }

    /// Borrows the currently retained bytes, oldest first.
    fn as_slice(&self) -> &[u8] {
        &self.data
    }

    /// Moves the retained bytes out, leaving the buffer empty.
    ///
    /// The `bytes_dropped` counter is left untouched, so `has_overflow` keeps
    /// reporting past evictions.
    fn take(&mut self) -> Vec<u8> {
        std::mem::take(&mut self.data)
    }

    /// Returns `true` if any bytes have ever been evicted from this buffer.
    fn has_overflow(&self) -> bool {
        self.bytes_dropped > 0
    }
}
/// Lifecycle state of a command executing on a single node.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ExecutionStatus {
    /// Initial state of a freshly created `NodeStream`, before its first poll.
    Pending,
    /// Set by the first poll; output may still be arriving.
    Running,
    /// Finished; `NodeStream::poll` chooses this when the channel closes with a
    /// zero exit code or with no exit code reported.
    Completed,
    /// Finished unsuccessfully. The payload is a human-readable reason, for example
    /// `"Exit code: 1"` as produced by `NodeStream::poll`.
    Failed(String),
}
/// Output and completion state for one node's remote command.
///
/// Output arrives over an mpsc channel and is accumulated into size-capped
/// stdout/stderr buffers by `NodeStream::poll`.
pub struct NodeStream {
    /// The node this stream belongs to.
    pub node: Node,
    /// Receiving half of the channel that delivers command output.
    receiver: mpsc::Receiver<CommandOutput>,
    /// Accumulated stdout bytes (oldest evicted past the size cap).
    stdout_buffer: RollingBuffer,
    /// Accumulated stderr bytes (oldest evicted past the size cap).
    stderr_buffer: RollingBuffer,
    /// Current lifecycle state.
    status: ExecutionStatus,
    /// Exit code, once one has been received or set.
    exit_code: Option<u32>,
    /// `true` once the sending side of the channel has been observed as disconnected.
    closed: bool,
}
impl NodeStream {
pub fn new(node: Node, receiver: mpsc::Receiver<CommandOutput>) -> Self {
Self {
node,
receiver,
stdout_buffer: RollingBuffer::new(),
stderr_buffer: RollingBuffer::new(),
status: ExecutionStatus::Pending,
exit_code: None,
closed: false,
}
}
pub fn poll(&mut self) -> bool {
let mut received_data = false;
if self.status == ExecutionStatus::Pending {
self.status = ExecutionStatus::Running;
}
loop {
match self.receiver.try_recv() {
Ok(output) => {
received_data = true;
match output {
CommandOutput::StdOut(data) => {
self.stdout_buffer.append(&data);
if self.stdout_buffer.has_overflow() {
tracing::warn!(
"Node {} stdout buffer overflow - old data discarded",
self.node.host
);
}
}
CommandOutput::StdErr(data) => {
self.stderr_buffer.append(&data);
if self.stderr_buffer.has_overflow() {
tracing::warn!(
"Node {} stderr buffer overflow - old data discarded",
self.node.host
);
}
}
CommandOutput::ExitCode(code) => {
self.exit_code = Some(code);
tracing::debug!("Node {} received exit code: {}", self.node.host, code);
}
}
}
Err(mpsc::error::TryRecvError::Empty) => {
break;
}
Err(mpsc::error::TryRecvError::Disconnected) => {
self.closed = true;
if !matches!(self.status, ExecutionStatus::Failed(_)) {
if let Some(code) = self.exit_code {
if code != 0 {
self.status =
ExecutionStatus::Failed(format!("Exit code: {}", code));
} else {
self.status = ExecutionStatus::Completed;
}
} else {
self.status = ExecutionStatus::Completed;
}
}
tracing::debug!("Channel disconnected for node {}", self.node.host);
break;
}
}
}
received_data
}
pub fn stdout(&self) -> &[u8] {
self.stdout_buffer.as_slice()
}
pub fn stderr(&self) -> &[u8] {
self.stderr_buffer.as_slice()
}
pub fn take_stdout(&mut self) -> Vec<u8> {
self.stdout_buffer.take()
}
pub fn take_stderr(&mut self) -> Vec<u8> {
self.stderr_buffer.take()
}
pub fn status(&self) -> &ExecutionStatus {
&self.status
}
pub fn set_status(&mut self, status: ExecutionStatus) {
self.status = status;
}
pub fn exit_code(&self) -> Option<u32> {
self.exit_code
}
pub fn set_exit_code(&mut self, code: u32) {
self.exit_code = Some(code);
}
pub fn is_closed(&self) -> bool {
self.closed
}
pub fn is_complete(&self) -> bool {
matches!(
self.status,
ExecutionStatus::Completed | ExecutionStatus::Failed(_)
) && self.closed
}
}
/// Owns one `NodeStream` per node and aggregates their polling and completion state.
pub struct MultiNodeStreamManager {
    /// Streams in insertion order.
    streams: Vec<NodeStream>,
}
impl MultiNodeStreamManager {
pub fn new() -> Self {
Self {
streams: Vec::new(),
}
}
pub fn add_stream(&mut self, node: Node, receiver: mpsc::Receiver<CommandOutput>) {
self.streams.push(NodeStream::new(node, receiver));
}
pub fn poll_all(&mut self) -> bool {
let mut any_received = false;
for stream in &mut self.streams {
if stream.poll() {
any_received = true;
}
}
any_received
}
pub fn streams(&self) -> &[NodeStream] {
&self.streams
}
pub fn streams_mut(&mut self) -> &mut [NodeStream] {
&mut self.streams
}
pub fn all_complete(&self) -> bool {
!self.streams.is_empty() && self.streams.iter().all(|s| s.is_complete())
}
pub fn completed_count(&self) -> usize {
self.streams
.iter()
.filter(|s| matches!(s.status(), ExecutionStatus::Completed) && s.is_closed())
.count()
}
pub fn failed_count(&self) -> usize {
self.streams
.iter()
.filter(|s| matches!(s.status(), ExecutionStatus::Failed(_)))
.count()
}
pub fn total_count(&self) -> usize {
self.streams.len()
}
}
impl Default for MultiNodeStreamManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use russh::CryptoVec;

    #[test]
    fn test_node_stream_creation() {
        let node = Node::new("localhost".to_string(), 22, "test".to_string());
        let (_tx, rx) = mpsc::channel(100);
        let stream = NodeStream::new(node, rx);
        assert_eq!(stream.status(), &ExecutionStatus::Pending);
        assert_eq!(stream.exit_code(), None);
        assert!(!stream.is_closed());
        assert!(!stream.is_complete());
    }

    #[tokio::test]
    async fn test_node_stream_polling() {
        let node = Node::new("localhost".to_string(), 22, "test".to_string());
        let (tx, rx) = mpsc::channel(100);
        let mut stream = NodeStream::new(node, rx);
        let data = CryptoVec::from(b"test output".to_vec());
        tx.send(CommandOutput::StdOut(data)).await.unwrap();
        assert!(stream.poll());
        assert_eq!(stream.stdout(), b"test output");
        assert_eq!(stream.status(), &ExecutionStatus::Running);
    }

    #[tokio::test]
    async fn test_node_stream_take_buffers() {
        let node = Node::new("localhost".to_string(), 22, "test".to_string());
        let (tx, rx) = mpsc::channel(100);
        let mut stream = NodeStream::new(node, rx);
        let data = CryptoVec::from(b"test".to_vec());
        tx.send(CommandOutput::StdOut(data)).await.unwrap();
        stream.poll();
        let stdout = stream.take_stdout();
        assert_eq!(stdout, b"test");
        assert!(stream.stdout().is_empty());
    }

    #[tokio::test]
    async fn test_node_stream_completion() {
        let node = Node::new("localhost".to_string(), 22, "test".to_string());
        let (tx, rx) = mpsc::channel(100);
        let mut stream = NodeStream::new(node, rx);
        drop(tx);
        stream.poll();
        assert!(stream.is_closed());
        assert!(stream.is_complete());
        assert_eq!(stream.status(), &ExecutionStatus::Completed);
    }

    /// A non-zero exit code followed by a disconnect must yield a `Failed` status.
    #[tokio::test]
    async fn test_node_stream_nonzero_exit_code_marks_failed() {
        let node = Node::new("localhost".to_string(), 22, "test".to_string());
        let (tx, rx) = mpsc::channel(100);
        let mut stream = NodeStream::new(node, rx);
        tx.send(CommandOutput::ExitCode(2)).await.unwrap();
        drop(tx);
        assert!(stream.poll());
        assert_eq!(stream.exit_code(), Some(2));
        assert!(stream.is_complete());
        assert_eq!(
            stream.status(),
            &ExecutionStatus::Failed("Exit code: 2".to_string())
        );
    }

    /// An externally set `Failed` status must survive the disconnect in `poll`.
    #[tokio::test]
    async fn test_node_stream_failed_status_not_overwritten() {
        let node = Node::new("localhost".to_string(), 22, "test".to_string());
        let (tx, rx) = mpsc::channel(100);
        let mut stream = NodeStream::new(node, rx);
        stream.set_status(ExecutionStatus::Failed("connection lost".to_string()));
        drop(tx);
        stream.poll();
        assert!(stream.is_complete());
        assert_eq!(
            stream.status(),
            &ExecutionStatus::Failed("connection lost".to_string())
        );
    }

    /// Appending past the cap evicts exactly the oldest overflowing bytes.
    #[test]
    fn test_rolling_buffer_drops_oldest_bytes_on_overflow() {
        let mut buffer = RollingBuffer::new();
        buffer.append(&vec![b'a'; MAX_BUFFER_SIZE]);
        assert!(!buffer.has_overflow());

        buffer.append(b"xyz");
        assert!(buffer.has_overflow());
        assert_eq!(buffer.bytes_dropped, 3);
        assert_eq!(buffer.total_bytes_received, MAX_BUFFER_SIZE + 3);
        assert_eq!(buffer.as_slice().len(), MAX_BUFFER_SIZE);
        assert!(buffer.as_slice().ends_with(b"xyz"));
    }

    /// A single chunk larger than the cap keeps only its tail.
    #[test]
    fn test_rolling_buffer_oversized_chunk_keeps_tail() {
        let mut buffer = RollingBuffer::new();
        buffer.append(b"old");
        let mut chunk = vec![b'a'; 5];
        chunk.resize(5 + MAX_BUFFER_SIZE, b'b');
        buffer.append(&chunk);
        assert_eq!(buffer.bytes_dropped, 8);
        assert_eq!(buffer.as_slice().len(), MAX_BUFFER_SIZE);
        assert!(buffer.as_slice().iter().all(|&byte| byte == b'b'));
    }

    #[tokio::test]
    async fn test_multi_node_stream_manager() {
        let mut manager = MultiNodeStreamManager::new();
        let node1 = Node::new("host1".to_string(), 22, "node1".to_string());
        let (_tx1, rx1) = mpsc::channel(100);
        manager.add_stream(node1, rx1);
        let node2 = Node::new("host2".to_string(), 22, "node2".to_string());
        let (_tx2, rx2) = mpsc::channel(100);
        manager.add_stream(node2, rx2);
        assert_eq!(manager.total_count(), 2);
        assert_eq!(manager.completed_count(), 0);
    }

    #[tokio::test]
    async fn test_multi_node_stream_poll_all() {
        let mut manager = MultiNodeStreamManager::new();
        let node1 = Node::new("host1".to_string(), 22, "node1".to_string());
        let (tx1, rx1) = mpsc::channel(100);
        manager.add_stream(node1, rx1);
        let data = CryptoVec::from(b"output1".to_vec());
        tx1.send(CommandOutput::StdOut(data)).await.unwrap();
        assert!(manager.poll_all());
        assert_eq!(manager.streams()[0].stdout(), b"output1");
    }

    #[tokio::test]
    async fn test_multi_node_stream_all_complete() {
        let mut manager = MultiNodeStreamManager::new();
        let node1 = Node::new("host1".to_string(), 22, "node1".to_string());
        let (tx1, rx1) = mpsc::channel(100);
        manager.add_stream(node1, rx1);
        let node2 = Node::new("host2".to_string(), 22, "node2".to_string());
        let (tx2, rx2) = mpsc::channel(100);
        manager.add_stream(node2, rx2);
        drop(tx1);
        drop(tx2);
        manager.poll_all();
        assert!(manager.all_complete());
        assert_eq!(manager.completed_count(), 2);
    }

    /// An empty manager is never "all complete".
    #[test]
    fn test_multi_node_stream_empty_not_complete() {
        let manager = MultiNodeStreamManager::new();
        assert!(!manager.all_complete());
        assert_eq!(manager.total_count(), 0);
    }

    /// Success and failure are counted separately once both channels close.
    #[tokio::test]
    async fn test_multi_node_stream_failed_count() {
        let mut manager = MultiNodeStreamManager::new();
        let node1 = Node::new("host1".to_string(), 22, "node1".to_string());
        let (tx1, rx1) = mpsc::channel(100);
        manager.add_stream(node1, rx1);
        let node2 = Node::new("host2".to_string(), 22, "node2".to_string());
        let (tx2, rx2) = mpsc::channel(100);
        manager.add_stream(node2, rx2);
        tx1.send(CommandOutput::ExitCode(0)).await.unwrap();
        tx2.send(CommandOutput::ExitCode(1)).await.unwrap();
        drop(tx1);
        drop(tx2);
        assert!(manager.poll_all());
        assert!(manager.all_complete());
        assert_eq!(manager.completed_count(), 1);
        assert_eq!(manager.failed_count(), 1);
    }
}