use super::pager::{BufferedOutput, PagerConfig, PagerExitStatus};
use std::io::{self, Write};
/// Where result output written through `OutputStreams::write_result` ends up.
enum StdoutBackend {
/// Results are written straight to the boxed writer.
Direct(Box<dyn Write + Send>),
/// Results are handed to a `BufferedOutput` from the sibling `pager` module.
Pager(BufferedOutput),
}
/// Pair of output channels: one for results (stdout side) and one for
/// diagnostics (stderr side).
pub struct OutputStreams {
// Backend that receives content passed to `write_result`.
stdout: StdoutBackend,
// Writer that receives content passed to `write_diagnostic`.
stderr: Box<dyn Write + Send>,
}
impl OutputStreams {
    /// Creates streams that write results directly to the process stdout and
    /// diagnostics to the process stderr.
    #[must_use]
    pub fn new() -> Self {
        Self {
            stdout: StdoutBackend::Direct(Box::new(io::stdout())),
            stderr: Box::new(io::stderr()),
        }
    }

    /// Creates streams whose results are collected by a `BufferedOutput`
    /// built from `config`; diagnostics still go to the process stderr.
    #[must_use]
    pub fn with_pager(config: PagerConfig) -> Self {
        Self {
            stdout: StdoutBackend::Pager(BufferedOutput::new(config)),
            stderr: Box::new(io::stderr()),
        }
    }

    /// Creates streams backed by caller-supplied writers (test builds only).
    ///
    /// Both writers are used in direct mode; no pager is involved.
    #[cfg(test)]
    // Nothing in this file calls this constructor, so the lint would
    // otherwise fire in test builds.
    #[allow(dead_code)]
    pub fn with_writers<W1, W2>(stdout: W1, stderr: W2) -> Self
    where
        W1: Write + Send + 'static,
        W2: Write + Send + 'static,
    {
        Self {
            stdout: StdoutBackend::Direct(Box::new(stdout)),
            stderr: Box::new(stderr),
        }
    }

    /// Writes `content` followed by a newline to the result channel.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the underlying writer or buffer.
    pub fn write_result(&mut self, content: &str) -> io::Result<()> {
        match &mut self.stdout {
            StdoutBackend::Direct(writer) => writeln!(writer, "{content}"),
            StdoutBackend::Pager(buffer) => {
                // The buffer takes string slices, so the line terminator is
                // appended as a separate write.
                buffer.write(content)?;
                buffer.write("\n")
            }
        }
    }

    /// Writes `content` followed by a newline to the diagnostic channel.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the stderr writer.
    pub fn write_diagnostic(&mut self, content: &str) -> io::Result<()> {
        writeln!(self.stderr, "{content}")
    }

    /// Flushes the diagnostic writer.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the stderr writer's `flush`.
    // Nothing in this file calls this method; the allow keeps the lint quiet.
    #[allow(dead_code)]
    pub fn flush_stderr(&mut self) -> io::Result<()> {
        self.stderr.flush()
    }

    /// Consumes the streams and finalizes result output.
    ///
    /// In direct mode the writer is flushed and `PagerExitStatus::Success` is
    /// returned. In pager mode the call is delegated to the buffer's own
    /// `finish`, whose status is returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from flushing the direct writer or from the
    /// buffer's `finish`.
    pub fn finish(self) -> io::Result<PagerExitStatus> {
        match self.stdout {
            StdoutBackend::Direct(mut writer) => {
                // Flush explicitly: dropping a buffered writer discards any
                // flush error, which would hide lost output from the caller.
                writer.flush()?;
                Ok(PagerExitStatus::Success)
            }
            StdoutBackend::Pager(buffer) => buffer.finish(),
        }
    }

    /// Like [`finish`](Self::finish), but converts a status that carries an
    /// exit code into an error.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from `finish`, or a `CliError::pager_exit(code)`
    /// when the resulting status reports an exit code.
    pub fn finish_checked(self) -> anyhow::Result<()> {
        let status = self.finish()?;
        if let Some(code) = status.exit_code() {
            return Err(crate::error::CliError::pager_exit(code).into());
        }
        Ok(())
    }
}
impl Default for OutputStreams {
fn default() -> Self {
Self::new()
}
}
/// Test-only handle onto the byte buffers that back a captured
/// `OutputStreams` (see `TestOutputStreams::new`).
#[cfg(test)]
pub struct TestOutputStreams {
/// Shared buffer receiving everything written to the result channel.
pub stdout: std::sync::Arc<std::sync::Mutex<Vec<u8>>>,
/// Shared buffer receiving everything written to the diagnostic channel.
pub stderr: std::sync::Arc<std::sync::Mutex<Vec<u8>>>,
}
#[cfg(test)]
impl TestOutputStreams {
    /// Builds an `OutputStreams` whose two channels append into in-memory
    /// buffers, and returns it together with a handle for inspecting them.
    #[must_use]
    pub fn new() -> (Self, OutputStreams) {
        let out_buf = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
        let err_buf = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));

        // The streams get their own handles; the originals move into the
        // inspection struct below.
        let streams = OutputStreams {
            stdout: StdoutBackend::Direct(Box::new(SharedWriter(std::sync::Arc::clone(
                &out_buf,
            )))),
            stderr: Box::new(SharedWriter(std::sync::Arc::clone(&err_buf))),
        };
        let capture = Self {
            stdout: out_buf,
            stderr: err_buf,
        };

        (capture, streams)
    }

    /// Returns everything captured on the result channel, decoded lossily as
    /// UTF-8.
    ///
    /// # Panics
    ///
    /// Panics if the buffer's mutex is poisoned.
    #[must_use]
    pub fn stdout_string(&self) -> String {
        let bytes = self.stdout.lock().unwrap();
        String::from_utf8_lossy(&bytes).into_owned()
    }

    /// Returns everything captured on the diagnostic channel, decoded lossily
    /// as UTF-8.
    ///
    /// # Panics
    ///
    /// Panics if the buffer's mutex is poisoned.
    #[must_use]
    pub fn stderr_string(&self) -> String {
        let bytes = self.stderr.lock().unwrap();
        String::from_utf8_lossy(&bytes).into_owned()
    }
}
/// Test-only writer that appends every write to a shared, mutex-guarded
/// byte buffer.
#[cfg(test)]
struct SharedWriter(std::sync::Arc<std::sync::Mutex<Vec<u8>>>);

#[cfg(test)]
impl Write for SharedWriter {
    /// Appends all of `buf` to the shared buffer and reports it as fully
    /// written. Panics if the mutex is poisoned.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let Self(shared) = self;
        let mut sink = shared.lock().unwrap();
        // `Vec<u8>`'s own `Write` impl appends the whole slice and returns
        // its length.
        Write::write(&mut *sink, buf)
    }

    /// Nothing to flush: writes land in the buffer immediately.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing the direct-mode streams must not panic.
    #[test]
    fn test_output_streams_creation() {
        let _ = OutputStreams::new();
    }

    /// `Default` must be constructible just like `new`.
    #[test]
    fn test_default() {
        let _ = OutputStreams::default();
    }

    /// Results and diagnostics land on their own channels, each with a
    /// trailing newline.
    #[test]
    fn test_output_streams_capture() {
        let (captured, mut out) = TestOutputStreams::new();

        out.write_result("hello").unwrap();
        out.write_diagnostic("world").unwrap();

        assert_eq!(captured.stdout_string(), "hello\n");
        assert_eq!(captured.stderr_string(), "world\n");
    }

    /// Finishing direct-mode streams reports a successful status.
    #[test]
    fn test_finish_non_pager_returns_success() {
        let outcome = OutputStreams::new().finish().unwrap();
        assert!(outcome.is_success());
    }
}