mod connect;
pub mod handler;
pub mod ipc;
#[cfg(test)]
mod tests;
mod transfer;
use std::fmt;
use std::net::SocketAddr;
use std::sync::Arc;
use russh::ChannelMsg;
use russh::client;
pub use self::connect::{Client, ClientOptions};
pub use crate::session::SessionState;
pub use crate::session::pty::PtyOptions;
use crate::error::{ClientError, Result};
use crate::session::pty::PtySize;
/// Snapshot of how far a transfer has progressed.
///
/// Both counters use the same unit as the transfer code that produces them
/// (presumably bytes — confirm against the producer).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TransferProgress {
    /// Amount moved so far.
    pub transferred: u64,
    /// Expected amount in total; `0` means the size is empty or unknown.
    pub total: u64,
}
/// A connection target after resolution: either a parsed ticket or a
/// wormhole code kept as its raw string.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ResolvedTarget {
    /// A parsed transport ticket.
    Ticket(crate::transport::ticket::Ticket),
    /// A wormhole code string (presumably resolved to a peer elsewhere — confirm).
    WormholeCode(String),
}
impl From<crate::transport::ticket::Ticket> for ResolvedTarget {
fn from(ticket: crate::transport::ticket::Ticket) -> Self {
Self::Ticket(ticket)
}
}
impl TransferProgress {
    /// Creates a progress snapshot from the amount moved so far and the expected total.
    pub(crate) fn new(transferred: u64, total: u64) -> Self {
        Self { transferred, total }
    }

    /// Returns completion as a whole percentage in `0..=100` (rounded down).
    ///
    /// A `total` of zero is reported as 100% (nothing left to transfer), and a
    /// `transferred` value larger than `total` is clamped to 100%.
    ///
    /// The ratio is computed in `u128`: multiplying a large `u64` count by 100
    /// would otherwise overflow (or saturate) and produce a wrong percentage.
    pub fn percent(&self) -> u8 {
        if self.total == 0 {
            return 100;
        }
        let pct = u128::from(self.transferred) * 100 / u128::from(self.total);
        // `min(100)` guarantees the value fits in a `u8`, so the cast cannot truncate.
        pct.min(100) as u8
    }
}
/// A client-side SSH session carried over an iroh connection.
///
/// Field declaration order is also drop order.
pub struct Session {
    /// Shared russh client handle. It sits behind an async `RwLock`: most
    /// operations take a read guard, while `remote_forward` takes the write guard.
    pub(crate) handle: Arc<tokio::sync::RwLock<client::Handle<handler::ClientHandler>>>,
    /// Client handler instance; used by `remote_forward` to register remote tunnels.
    pub(super) handler: handler::ClientHandler,
    /// Primary session channel, opened lazily by `ensure_channel` and taken on `disconnect`.
    channel: Option<russh::Channel<russh::client::Msg>>,
    /// Underlying iroh connection; closed and cleared by `disconnect`.
    connection: Option<iroh::endpoint::Connection>,
    /// Local iroh endpoint; closed and cleared by `disconnect`.
    endpoint: Option<iroh::Endpoint>,
    /// Metadata reported by the remote peer, if known.
    remote_metadata: Option<crate::transport::metadata::PeerMetadata>,
    /// Current lifecycle state, exposed through `state()`.
    state: SessionState,
}
impl fmt::Debug for Session {
    /// Prints only the session state and whether remote metadata is present;
    /// the SSH handle, channel and iroh transport fields are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_metadata = self.remote_metadata.is_some();
        let mut builder = f.debug_struct("Session");
        builder.field("state", &self.state);
        builder.field("has_metadata", &has_metadata);
        builder.finish()
    }
}
/// Everything collected from a command run by `Session::capture_exec`.
#[derive(Clone, Debug, Default)]
pub struct ExecOutput {
    /// Bytes the command wrote to its standard output.
    pub stdout: Vec<u8>,
    /// Bytes the command wrote to its standard error (SSH extended data type 1).
    pub stderr: Vec<u8>,
    /// Exit status reported by the server; stays `0` if none was reported.
    pub exit_status: u32,
}
impl Session {
    /// Returns the current lifecycle state of the session.
    pub fn state(&self) -> SessionState {
        self.state
    }

    /// Returns the metadata reported by the remote peer, if it is known.
    pub fn remote_metadata(&self) -> Option<&crate::transport::metadata::PeerMetadata> {
        self.remote_metadata.as_ref()
    }

    /// Requests a pseudo-terminal on the primary channel, opening the channel if needed.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::ChannelOpenFailed` if the channel cannot be opened, or
    /// `ClientError::PtyRequestFailed` if the PTY request fails.
    pub async fn request_pty(&mut self, options: PtyOptions) -> Result<()> {
        let size = options.size();
        let channel = self.ensure_channel().await?;
        channel
            .request_pty(
                true,
                options.term(),
                size.cols as u32,
                size.rows as u32,
                size.pixel_width as u32,
                size.pixel_height as u32,
                options.modes_slice(),
            )
            .await
            .map_err(|e| ClientError::PtyRequestFailed { source: e })?;
        Ok(())
    }

    /// Starts a shell on the primary channel and marks the session `ShellReady`.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::ChannelOpenFailed` if the channel cannot be opened, or
    /// `ClientError::ShellRequestFailed` if the shell request fails.
    pub async fn start_shell(&mut self) -> Result<()> {
        let channel = self.ensure_channel().await?;
        channel
            .request_shell(true)
            .await
            .map_err(|e| ClientError::ShellRequestFailed { source: e })?;
        self.state = SessionState::ShellReady;
        Ok(())
    }

    /// Executes `command` on the primary channel and marks the session `ShellReady`.
    ///
    /// Output arrives as events from [`Session::next_event`]. To collect the
    /// output in one call, use [`Session::capture_exec`] instead.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::ChannelOpenFailed` if the channel cannot be opened, or
    /// `ClientError::ExecFailed` if the exec request fails.
    pub async fn exec(&mut self, command: &str) -> Result<()> {
        let channel = self.ensure_channel().await?;
        channel
            .exec(true, command)
            .await
            .map_err(|e| ClientError::ExecFailed { source: e })?;
        self.state = SessionState::ShellReady;
        Ok(())
    }

    /// Returns the primary session channel, opening it on first use.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::ChannelOpenFailed` if a channel has to be opened and that fails.
    pub(crate) async fn ensure_channel(
        &mut self,
    ) -> Result<&mut russh::Channel<russh::client::Msg>> {
        let channel = match self.channel.take() {
            Some(existing) => existing,
            None => self.open_new_session_channel().await?,
        };
        Ok(self.channel.insert(channel))
    }

    /// Opens a brand-new session channel on the SSH connection.
    ///
    /// The read guard on the shared handle is released when this returns. As a
    /// result, work done on the returned channel does not block
    /// [`Session::remote_forward`], which needs the write lock.
    async fn open_new_session_channel(&self) -> Result<russh::Channel<russh::client::Msg>> {
        let handle = self.handle.read().await;
        let channel = handle
            .channel_open_session()
            .await
            .map_err(|e| ClientError::ChannelOpenFailed { source: e })?;
        Ok(channel)
    }

    /// Runs `command` on a dedicated, freshly opened channel and collects its output.
    ///
    /// Data is gathered until the server closes the channel or the message stream
    /// ends. Extended data of type 1 is collected as stderr; other extended data
    /// types are ignored. The primary channel used by [`Session::exec`] is not
    /// touched.
    ///
    /// If no `ExitStatus` message arrives, `exit_status` stays `0`; `ExitSignal`
    /// messages are not inspected here.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::ChannelOpenFailed` or `ClientError::ExecFailed` if the
    /// channel cannot be opened or the exec request fails.
    pub async fn capture_exec(&mut self, command: &str) -> Result<ExecOutput> {
        // The handle lock is only held while opening the channel, not for the
        // whole (possibly long-running) command.
        let mut channel = self.open_new_session_channel().await?;
        channel
            .exec(true, command)
            .await
            .map_err(|e| ClientError::ExecFailed { source: e })?;
        let mut output = ExecOutput::default();
        loop {
            match channel.wait().await {
                Some(ChannelMsg::Data { data }) => {
                    output.stdout.extend_from_slice(&data);
                }
                // Extended data type 1 is SSH_EXTENDED_DATA_STDERR (RFC 4254, section 5.2).
                Some(ChannelMsg::ExtendedData { data, ext: 1 }) => {
                    output.stderr.extend_from_slice(&data);
                }
                Some(ChannelMsg::ExitStatus { exit_status }) => {
                    output.exit_status = exit_status;
                }
                Some(ChannelMsg::Close) | None => break,
                _ => {}
            }
        }
        Ok(output)
    }

    /// Starts local port forwarding: connections accepted on `local_addr` are
    /// tunnelled to `remote_host:remote_port` through `direct-tcpip` channels.
    ///
    /// Returns the join handle of the accept loop together with the address the
    /// listener actually bound to. This is useful when `local_addr` asked for
    /// port 0. Aborting the join handle stops accepting new connections.
    /// Tunnels that are already running are spawned as separate, untracked tasks
    /// and keep running.
    ///
    /// Each tunnel copies data in both directions until both sides are done.
    /// When one side reaches EOF, the other side is shut down for writing, but
    /// data still flowing the other way is delivered.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::TunnelFailed` if the local listener cannot be bound
    /// or its address cannot be read.
    pub async fn local_forward(
        &self,
        local_addr: impl tokio::net::ToSocketAddrs,
        remote_host: String,
        remote_port: u32,
    ) -> Result<(tokio::task::JoinHandle<()>, SocketAddr)> {
        let listener = tokio::net::TcpListener::bind(local_addr)
            .await
            .map_err(|e| ClientError::TunnelFailed {
                details: format!("failed to bind local listener: {}", e),
            })?;
        let bound_addr = listener
            .local_addr()
            .map_err(|e| ClientError::TunnelFailed {
                details: format!("failed to resolve bound local address: {}", e),
            })?;
        let handle = Arc::clone(&self.handle);
        let join_handle = tokio::spawn(async move {
            tracing::info!(
                "Local port forwarding active on {:?}",
                listener.local_addr()
            );
            loop {
                let (stream, addr) = match listener.accept().await {
                    Ok(accepted) => accepted,
                    Err(err) => {
                        tracing::warn!("Local port forwarding listener stopped: {}", err);
                        break;
                    }
                };
                tracing::debug!("Accepted local connection for tunnel from {:?}", addr);
                let handle = Arc::clone(&handle);
                let remote_host = remote_host.clone();
                tokio::spawn(async move {
                    let guard = handle.read().await;
                    let opened = guard
                        .channel_open_direct_tcpip(
                            &remote_host,
                            remote_port,
                            &addr.ip().to_string(),
                            u32::from(addr.port()),
                        )
                        .await;
                    // Release the read lock before the long-lived copy below;
                    // holding it would stall `remote_forward`, which needs the
                    // write lock, for as long as this tunnel stays open.
                    drop(guard);
                    let channel = match opened {
                        Ok(c) => c,
                        Err(err) => {
                            tracing::warn!(
                                "Failed to open direct-tcpip channel for {}: {}: {}",
                                remote_host,
                                remote_port,
                                err
                            );
                            return;
                        }
                    };
                    let (mut reader, mut writer) = tokio::io::split(stream);
                    let (mut channel_reader, mut channel_writer) =
                        tokio::io::split(channel.into_stream());
                    // Run both directions to completion. Racing them would drop
                    // whichever direction was still in flight when the other hit
                    // EOF. `shutdown` forwards that EOF to the opposite side.
                    let upstream = async {
                        let copied = tokio::io::copy(&mut reader, &mut channel_writer).await;
                        let _ = tokio::io::AsyncWriteExt::shutdown(&mut channel_writer).await;
                        copied
                    };
                    let downstream = async {
                        let copied = tokio::io::copy(&mut channel_reader, &mut writer).await;
                        let _ = tokio::io::AsyncWriteExt::shutdown(&mut writer).await;
                        copied
                    };
                    let (to_remote, from_remote) = tokio::join!(upstream, downstream);
                    tracing::debug!(
                        "Tunnel from {:?} finished (to remote: {:?}, from remote: {:?})",
                        addr,
                        to_remote,
                        from_remote
                    );
                });
            }
        });
        Ok((join_handle, bound_addr))
    }

    /// Sends `data` on the primary channel, opening the channel if needed.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::ChannelOpenFailed` or `ClientError::DataSendFailed`.
    pub async fn send(&mut self, data: &[u8]) -> Result<()> {
        let channel = self.ensure_channel().await?;
        channel
            .data(data)
            .await
            .map_err(|e| ClientError::DataSendFailed { source: e }.into())
    }

    /// Sends EOF on the primary channel, opening the channel if needed.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::ChannelOpenFailed` or `ClientError::EofSendFailed`.
    pub async fn eof(&mut self) -> Result<()> {
        let channel = self.ensure_channel().await?;
        channel
            .eof()
            .await
            .map_err(|e| ClientError::EofSendFailed { source: e }.into())
    }

    /// Sends a window-change request with the new terminal `size` on the primary channel.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::ChannelOpenFailed` or `ClientError::WindowChangeFailed`.
    pub async fn resize(&mut self, size: PtySize) -> Result<()> {
        let channel = self.ensure_channel().await?;
        channel
            .window_change(
                size.cols as u32,
                size.rows as u32,
                size.pixel_width as u32,
                size.pixel_height as u32,
            )
            .await
            .map_err(|e| ClientError::WindowChangeFailed { source: e }.into())
    }

    /// Waits for the next message on the primary channel and converts it to a
    /// [`SessionEvent`].
    ///
    /// Returns `Ok(None)` if no primary channel has been opened yet. It also
    /// returns `Ok(None)` when the channel's message stream ends, in which case
    /// the session is marked `SessionState::Closed`.
    pub async fn next_event(&mut self) -> Result<Option<SessionEvent>> {
        let Some(channel) = self.channel.as_mut() else {
            return Ok(None);
        };
        match channel.wait().await {
            Some(msg) => {
                tracing::debug!("Received low-level SSH message: {:?}", msg);
                Ok(Some(SessionEvent::from(msg)))
            }
            None => {
                tracing::debug!("Low-level SSH event stream ended (None)");
                self.state = SessionState::Closed;
                Ok(None)
            }
        }
    }

    /// Tears the session down. It closes the primary channel, sends an SSH
    /// disconnect, closes the iroh connection and endpoint, and marks the session
    /// `SessionState::Closed`.
    ///
    /// Transport cleanup and the state change happen even if the SSH disconnect
    /// message cannot be sent. A failed disconnect therefore does not leave the
    /// iroh connection or endpoint open.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::DisconnectFailed` if sending the SSH disconnect
    /// failed. The transport has already been torn down by then.
    pub async fn disconnect(&mut self) -> Result<()> {
        if let Some(channel) = self.channel.take() {
            // Best effort: the whole connection is going away regardless.
            let _ = channel.close().await;
        }
        let handle = self.handle.read().await;
        let ssh_result = handle
            .disconnect(russh::Disconnect::ByApplication, "", "en-US")
            .await;
        drop(handle);
        if let Some(conn) = self.connection.take() {
            conn.close(0u32.into(), b"Session disconnected");
        }
        if let Some(endpoint) = self.endpoint.take() {
            endpoint.close().await;
        }
        self.state = SessionState::Closed;
        ssh_result.map_err(|e| ClientError::DisconnectFailed { source: e })?;
        Ok(())
    }

    /// Asks the server to forward `remote_host:remote_port` back to this client,
    /// then registers `local_host:local_port` as the local destination with the
    /// handler.
    ///
    /// Takes the handle's write lock, so it waits for any outstanding read guards.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::TunnelFailed` if the server rejects the request.
    pub async fn remote_forward(
        &self,
        remote_host: String,
        remote_port: u32,
        local_host: String,
        local_port: u16,
    ) -> Result<()> {
        let mut handle = self.handle.write().await;
        handle
            .tcpip_forward(remote_host.clone(), remote_port)
            .await
            .map_err(|e| ClientError::TunnelFailed {
                details: format!("server rejected remote forward request: {}", e),
            })?;
        // NOTE(review): registration happens only after the server accepted the
        // request; confirm the handler cannot receive a forwarded connection
        // before this call completes.
        self.handler
            .register_remote_tunnel(remote_host, remote_port, local_host, local_port);
        Ok(())
    }

    /// Asks the remote side for path completions of `path` over a transfer stream.
    ///
    /// # Errors
    ///
    /// - Transport errors from writing the request or reading the reply.
    /// - `ClientError::TransferRejected` if the remote side replies with an error frame.
    /// - `ClientError::DownloadFailed` if the reply is any other unexpected frame.
    pub async fn remote_completion(&mut self, path: &str) -> Result<Vec<String>> {
        let mut stream = self.open_transfer_stream("completion unavailable").await?;
        crate::transport::transfer::write_completion_request(
            &mut stream,
            &crate::transport::transfer::CompletionRequest {
                path: path.to_string(),
            },
        )
        .await
        .map_err(crate::error::TransportError::from)?;
        match crate::transport::transfer::read_next_frame(&mut stream)
            .await
            .map_err(crate::error::TransportError::from)?
        {
            crate::transport::transfer::TransferFrame::CompletionResponse(res) => Ok(res.matches),
            crate::transport::transfer::TransferFrame::Error(failure) => {
                Err(ClientError::TransferRejected {
                    details: failure.to_string(),
                }
                .into())
            }
            other => Err(ClientError::DownloadFailed {
                details: format!("unexpected completion frame: {other:?}"),
            }
            .into()),
        }
    }

    /// Consumes the session and disconnects it; see [`Session::disconnect`].
    pub async fn close(mut self) -> Result<()> {
        self.disconnect().await
    }
}
/// A channel event as exposed by [`Session::next_event`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SessionEvent {
    /// Regular channel data.
    Data(Vec<u8>),
    /// Extended data together with its SSH extended-data type code.
    ExtendedData(Vec<u8>, u32),
    /// Exit status reported for the remote command.
    ExitStatus(u32),
    /// The remote process ended because of a signal.
    ExitSignal {
        /// Signal name as formatted by the conversion from the raw message.
        signal: String,
        /// Whether a core dump was reported.
        core_dumped: bool,
        /// Error message supplied by the server.
        error_message: String,
        /// Language tag of `error_message`.
        lang_tag: String,
    },
    /// The channel was closed.
    Closed,
    /// A message with no dedicated representation here.
    Ignore,
}
impl From<ChannelMsg> for SessionEvent {
fn from(msg: ChannelMsg) -> Self {
match msg {
ChannelMsg::Data { data } => Self::Data(data.to_vec()),
ChannelMsg::ExtendedData { data, ext } => Self::ExtendedData(data.to_vec(), ext),
ChannelMsg::ExitStatus { exit_status } => Self::ExitStatus(exit_status),
ChannelMsg::ExitSignal {
signal_name,
core_dumped,
error_message,
lang_tag,
} => Self::ExitSignal {
signal: format!("{:?}", signal_name),
core_dumped,
error_message: error_message.to_string(),
lang_tag: lang_tag.to_string(),
},
ChannelMsg::Eof => Self::Ignore,
ChannelMsg::Close => Self::Closed,
_ => Self::Ignore,
}
}
}