use crate::codec::ConfabCodec;
use crate::errors::{InetError, InterfaceError, IoError};
use crate::events::Event;
use crate::input::{readline_stream, Input, StartupScript};
use crate::tls;
use crate::util::{now_hms, CharEncoding};
use futures_util::{SinkExt, Stream, StreamExt};
use rustyline_async::{Readline, SharedWriter};
use std::fs::File;
use std::io::{self, Write};
use std::num::NonZeroUsize;
use std::process::ExitCode;
use tokio::net::TcpStream;
use tokio_util::{codec::Framed, either::Either};
/// Outcome of a completed `ioloop` run.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum ConnectState {
    /// The input stream was exhausted while the connection was still open.
    Open,
    /// The connection's incoming stream ended.
    Closed,
}
/// Drives one confab session: connect, run the optional startup script, then
/// hand control to an interactive readline prompt until either side finishes.
pub(crate) struct Runner {
    /// Input consumed before the interactive prompt is started, if any.
    pub(crate) startup_script: Option<StartupScript>,
    /// Destination for user-facing events (terminal and optional transcript).
    pub(crate) reporter: Reporter,
    /// Parameters used to open the connection.
    pub(crate) connector: Connector,
}

impl Runner {
    /// Runs the session to completion.
    ///
    /// Network (`IoError::Inet`) failures are reported to the user as an
    /// error event and mapped to [`ExitCode::FAILURE`].
    ///
    /// # Errors
    ///
    /// Terminal/interface failures (`IoError::Interface`) are returned to
    /// the caller. So is a failure to report a network error.
    pub(crate) async fn run(mut self) -> Result<ExitCode, InterfaceError> {
        match self.try_run().await {
            Ok(()) => Ok(ExitCode::SUCCESS),
            Err(IoError::Interface(e)) => Err(e),
            Err(IoError::Inet(e)) => {
                self.reporter.report(Event::error(anyhow::Error::new(e)))?;
                Ok(ExitCode::FAILURE)
            }
        }
    }

    /// Connects, runs the startup script (if any), then the interactive
    /// prompt, closing the connection on every exit path.
    ///
    /// While the prompt is active, the reporter writes through readline's
    /// `SharedWriter`. The writer is switched back to stdout before returning
    /// from that phase.
    async fn try_run(&mut self) -> Result<(), IoError> {
        let mut frame = self.connector.connect(&mut self.reporter).await?;

        if let Some(script) = self.startup_script.take() {
            match ioloop(&mut frame, script, &mut self.reporter).await {
                // Script exhausted with the connection still up: go interactive.
                Ok(ConnectState::Open) => (),
                Ok(ConnectState::Closed) => {
                    frame.close().await?;
                    self.reporter.report(Event::disconnect())?;
                    return Ok(());
                }
                Err(e) => {
                    // Best-effort close; the loop's error is the one to surface.
                    let _ = frame.close().await;
                    return Err(e);
                }
            }
        }

        let (mut rl, shared) = match init_readline() {
            Ok(pair) => pair,
            Err(e) => {
                // Previously the connection was dropped without being closed
                // here; close it like every other error path does.
                let _ = frame.close().await;
                return Err(e.into());
            }
        };
        self.reporter.set_writer(Box::new(shared));

        let loop_result = ioloop(&mut frame, readline_stream(&mut rl), &mut self.reporter)
            .await
            .map(|_| ());
        // Always attempt to close, but keep the first error encountered.
        let close_result = frame.close().await.map_err(IoError::from);
        let result = loop_result.and(close_result).and_then(|()| {
            self.reporter
                .report(Event::disconnect())
                .map_err(IoError::from)
        });

        let _ = rl.flush();
        self.reporter.set_writer(Box::new(io::stdout()));
        result
    }
}
/// Emits [`Event`]s to a terminal writer and, optionally, to a transcript file.
pub(crate) struct Reporter {
    /// Destination for human-readable event messages.
    pub(crate) writer: Box<dyn Write + Send>,
    /// Transcript receiving one `Event::to_json` line per event. It is set to
    /// `None` after the first failed write.
    pub(crate) transcript: Option<File>,
    /// Passed to `Event::to_message`. It also controls whether a timestamp
    /// precedes the transcript-error notice.
    pub(crate) show_times: bool,
}

impl Reporter {
    /// Swaps the destination used for human-readable output.
    fn set_writer(&mut self, writer: Box<dyn Write + Send>) {
        self.writer = writer;
    }

    /// Reports `event`, wrapping any terminal write failure in
    /// `InterfaceError::Write`.
    fn report(&mut self, event: Event) -> Result<(), InterfaceError> {
        self.report_inner(event).map_err(InterfaceError::Write)
    }

    /// Writes `event` to the terminal and, if a transcript is open, appends
    /// its JSON form there.
    ///
    /// A transcript write failure is not returned. Instead the transcript is
    /// closed and a notice is written to the terminal. Only errors writing to
    /// `self.writer` propagate.
    fn report_inner(&mut self, event: Event) -> Result<(), io::Error> {
        let message = event.to_message(self.show_times);
        writeln!(self.writer, "{}", message)?;

        let transcript_result = match self.transcript.as_mut() {
            Some(fp) => writeln!(fp, "{}", event.to_json()),
            None => return Ok(()),
        };

        if let Err(err) = transcript_result {
            // Stop writing to a transcript that has already failed once.
            self.transcript = None;
            if self.show_times {
                write!(self.writer, "[{}] ", now_hms())?;
            }
            writeln!(self.writer, "! Error writing to transcript: {err}")?;
        }
        Ok(())
    }

    /// Echoes a `^C` line to the terminal writer.
    fn echo_ctrlc(&mut self) -> Result<(), InterfaceError> {
        writeln!(self.writer, "^C").map_err(InterfaceError::Write)
    }
}
/// Parameters for opening a connection to the remote server.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct Connector {
    /// Whether to upgrade the TCP stream with `tls::connect`.
    pub(crate) tls: bool,
    /// Host to connect to.
    pub(crate) host: String,
    /// TCP port to connect to.
    pub(crate) port: u16,
    /// Name passed to `tls::connect`; `host` is used when this is `None`.
    pub(crate) servername: Option<String>,
    /// Encoding handed to the line codec.
    pub(crate) encoding: CharEncoding,
    /// Maximum line length handed to the line codec.
    pub(crate) max_line_length: NonZeroUsize,
    /// Value passed to the codec's `crlf` option.
    pub(crate) crlf: bool,
}

impl Connector {
    /// Opens a TCP connection, optionally upgraded to TLS, and frames it with
    /// this connector's codec.
    ///
    /// Each stage is reported through `reporter`: connect start/finish, then
    /// TLS start/finish when enabled.
    async fn connect(&self, reporter: &mut Reporter) -> Result<Connection, IoError> {
        reporter.report(Event::connect_start(&self.host, self.port))?;
        let tcp = TcpStream::connect((&*self.host, self.port))
            .await
            .map_err(InetError::Connect)?;
        let peer = tcp.peer_addr().map_err(InetError::PeerAddr)?;
        reporter.report(Event::connect_finish(peer))?;

        let stream = if !self.tls {
            Either::Left(tcp)
        } else {
            reporter.report(Event::tls_start())?;
            let server_name = self.servername.as_ref().unwrap_or(&self.host);
            let secured = tls::connect(tcp, server_name)
                .await
                .map_err(InetError::Tls)?;
            reporter.report(Event::tls_finish())?;
            Either::Right(secured)
        };

        Ok(Connection(Framed::new(stream, self.codec())))
    }

    /// Builds a codec configured from this connector's line settings.
    fn codec(&self) -> ConfabCodec {
        let max_length = self.max_line_length.get();
        let base = ConfabCodec::new_with_max_length(max_length);
        base.encoding(self.encoding).crlf(self.crlf)
    }
}
/// A line-framed connection over either plain TCP or TLS.
#[derive(Debug)]
struct Connection(Framed<Either<TcpStream, tls::TlsStream>, ConfabCodec>);

impl Connection {
    /// Waits for the next decoded line. Yields `None` once the framed stream
    /// ends.
    async fn recv(&mut self) -> Option<Result<String, InetError>> {
        let item = self.0.next().await?;
        Some(item.map_err(InetError::Recv))
    }

    /// Prepares `line` with the codec, sends it, and returns the prepared text
    /// that was actually sent.
    async fn send(&mut self, line: String) -> Result<String, InetError> {
        let codec = self.0.codec();
        let prepared = codec.prepare_line(line);
        self.0.send(&prepared).await.map_err(InetError::Send)?;
        Ok(prepared)
    }

    /// Flushes and closes the sink side of the connection.
    async fn close(&mut self) -> Result<(), InetError> {
        let outcome = SinkExt::<&str>::close(&mut self.0).await;
        outcome.map_err(InetError::Close)
    }
}
/// Pumps traffic between `frame` and `input` until one of them ends.
///
/// Lines received from `frame` are reported as `recv` events. Lines from
/// `input` are sent through `frame` and reported, in their prepared form, as
/// `send` events. A Ctrl-C input echoes `^C`.
///
/// Returns `ConnectState::Closed` when `frame`'s incoming stream ends, and
/// `ConnectState::Open` when `input` is exhausted. The first error from
/// either side, or from reporting, is returned.
async fn ioloop<S>(
    frame: &mut Connection,
    input: S,
    reporter: &mut Reporter,
) -> Result<ConnectState, IoError>
where
    S: Stream<Item = Result<Input, InterfaceError>> + Send,
{
    tokio::pin!(input);
    loop {
        tokio::select! {
            incoming = frame.recv() => {
                let Some(received) = incoming else {
                    return Ok(ConnectState::Closed);
                };
                let msg = received?;
                reporter.report(Event::recv(msg))?;
            }
            typed = input.next() => {
                let Some(item) = typed else {
                    return Ok(ConnectState::Open);
                };
                match item? {
                    Input::Line(line) => {
                        let sent = frame.send(line).await?;
                        reporter.report(Event::send(sent))?;
                    }
                    Input::CtrlC => reporter.echo_ctrlc()?,
                }
            }
        }
    }
}
/// Creates the interactive `confab> ` prompt and its shared output writer.
///
/// `should_print_line_on(false, false)` is applied to the new `Readline`
/// before it is returned.
fn init_readline() -> Result<(Readline, SharedWriter), InterfaceError> {
    let prompt = String::from("confab> ");
    let (mut readline, writer) = Readline::new(prompt).map_err(InterfaceError::Init)?;
    readline.should_print_line_on(false, false);
    Ok((readline, writer))
}