use std;
use std::collections::{HashMap, VecDeque};
use std::num::Wrapping;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::Bytes;
use client::GexParams;
use futures::future::Future;
use log::{debug, error, info, warn};
use msg::{is_kex_msg, validate_client_msg_strict_kex};
use russh_util::runtime::JoinHandle;
use russh_util::time::Instant;
use ssh_key::{Certificate, PrivateKey};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, ToSocketAddrs};
use tokio::pin;
use tokio::sync::{broadcast, mpsc};
use crate::cipher::{clear, OpeningKey};
use crate::kex::dh::groups::{DhGroup, BUILTIN_SAFE_DH_GROUPS, DH_GROUP14};
use crate::kex::{KexProgress, SessionKexState};
use crate::session::*;
use crate::ssh_read::*;
use crate::sshbuffer::*;
use crate::{*};
mod kex;
mod session;
pub use self::session::*;
mod encrypted;
/// Server-wide configuration shared by every connection (passed around as `Arc<Config>`).
pub struct Config {
    /// Identification string written to the client before anything else is exchanged.
    pub server_id: SshId,
    /// Authentication methods; defaults to `MethodSet::all()`.
    /// NOTE(review): consumption is not visible in this file — presumably the set offered to clients.
    pub methods: auth::MethodSet,
    /// Defaults to 1 second. NOTE(review): presumably the delay applied before answering a
    /// rejected authentication attempt — not read in this file, confirm in the session code.
    pub auth_rejection_time: std::time::Duration,
    /// Defaults to `None`. NOTE(review): presumably an override of `auth_rejection_time` for the
    /// first attempt — confirm.
    pub auth_rejection_time_initial: Option<std::time::Duration>,
    /// Server private keys (presumably the host keys). Redacted as `***` in the `Debug` output.
    pub keys: Vec<PrivateKey>,
    /// Connection limits; defaults to `Limits::default()`.
    pub limits: Limits,
    /// Copied into each session's `target_window_size`; defaults to 2 MiB.
    pub window_size: u32,
    /// Defaults to 32768. `run_on_socket` logs an error at startup when it exceeds 65535,
    /// but still starts.
    pub maximum_packet_size: u32,
    /// Passed to each session `Handle`; defaults to 100.
    pub channel_buffer_size: usize,
    /// Capacity of the mpsc channel between a session `Handle` and its session; defaults to 10.
    pub event_buffer_size: usize,
    /// Algorithm preferences; defaults to `Default::default()`.
    pub preferred: Preferred,
    /// Defaults to 10. NOTE(review): enforcement is not visible in this file — confirm.
    pub max_auth_attempts: usize,
    /// Defaults to 10 minutes. Also bounds how long the client may take to send its
    /// identification string (see `read_ssh_id`).
    pub inactivity_timeout: Option<std::time::Duration>,
    /// Defaults to `None`. NOTE(review): keepalive handling is not visible in this file.
    pub keepalive_interval: Option<std::time::Duration>,
    /// Defaults to 3. NOTE(review): presumably the number of unanswered keepalives tolerated.
    pub keepalive_max: usize,
    /// When `true`, `run_on_socket` enables `TCP_NODELAY` on accepted sockets; defaults to `false`.
    pub nodelay: bool,
}
impl Default for Config {
    /// Baseline configuration: every auth method, no keys, a 2 MiB window, 32 KiB
    /// packets, a 10 minute inactivity timeout and keepalives disabled.
    fn default() -> Config {
        // Identification advertised to clients: "SSH-2.0-<crate name>_<crate version>".
        const SERVER_ID: &str = concat!(
            "SSH-2.0-",
            env!("CARGO_PKG_NAME"),
            "_",
            env!("CARGO_PKG_VERSION")
        );

        Self {
            server_id: SshId::Standard(Cow::Borrowed(SERVER_ID)),
            methods: auth::MethodSet::all(),
            auth_rejection_time: std::time::Duration::from_secs(1),
            auth_rejection_time_initial: None,
            keys: Vec::new(),
            limits: Limits::default(),
            window_size: 2 * 1024 * 1024,
            maximum_packet_size: 32 * 1024,
            channel_buffer_size: 100,
            event_buffer_size: 10,
            preferred: Default::default(),
            max_auth_attempts: 10,
            inactivity_timeout: Some(std::time::Duration::from_secs(10 * 60)),
            keepalive_interval: None,
            keepalive_max: 3,
            nodelay: false,
        }
    }
}
/// Debug output for `Config`.
///
/// Every field is shown except `keys`, which is replaced by `***` so private keys
/// never end up in logs. `nodelay` was previously missing from the output.
impl Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Config")
            .field("server_id", &self.server_id)
            .field("methods", &self.methods)
            .field("auth_rejection_time", &self.auth_rejection_time)
            .field(
                "auth_rejection_time_initial",
                &self.auth_rejection_time_initial,
            )
            // Never print private key material.
            .field("keys", &"***")
            .field("window_size", &self.window_size)
            .field("maximum_packet_size", &self.maximum_packet_size)
            .field("channel_buffer_size", &self.channel_buffer_size)
            .field("event_buffer_size", &self.event_buffer_size)
            .field("limits", &self.limits)
            .field("preferred", &self.preferred)
            .field("max_auth_attempts", &self.max_auth_attempts)
            .field("inactivity_timeout", &self.inactivity_timeout)
            .field("keepalive_interval", &self.keepalive_interval)
            .field("keepalive_max", &self.keepalive_max)
            .field("nodelay", &self.nodelay)
            .finish()
    }
}
/// Iterator over keyboard-interactive answers handed to `Handler::auth_keyboard_interactive`.
///
/// Wraps an iterator of `Option<Bytes>` and ends at the first `None` entry, as well as
/// when the wrapped iterator is exhausted.
pub struct Response<'a>(&'a mut (dyn Iterator<Item = Option<Bytes>> + Send));
impl Iterator for Response<'_> {
    type Item = Bytes;
    fn next(&mut self) -> Option<Self::Item> {
        match self.0.next() {
            Some(Some(answer)) => Some(answer),
            // Either the inner iterator is done or it yielded a `None` entry.
            Some(None) | None => None,
        }
    }
}
use std::borrow::Cow;
/// Result of an authentication callback on [`Handler`].
#[derive(Debug, PartialEq, Eq)]
pub enum Auth {
    /// The attempt is rejected.
    Reject {
        /// Optional replacement for the method set offered next; `None` leaves it unchanged
        /// (NOTE(review): confirm against the auth-handling code).
        proceed_with_methods: Option<MethodSet>,
        /// Value reported as "partial success" to the client — NOTE(review): confirm usage.
        partial_success: bool,
    },
    /// The attempt is accepted.
    Accept,
    /// The handler does not support the requested method.
    UnsupportedMethod,
    /// A further keyboard-interactive round: a name, instructions and a list of prompts.
    /// NOTE(review): the `bool` paired with each prompt is presumably the echo flag — confirm.
    Partial {
        name: Cow<'static, str>,
        instructions: Cow<'static, str>,
        prompts: Cow<'static, [(Cow<'static, str>, bool)]>,
    },
}
impl Auth {
    /// A plain rejection: no method-set override and `partial_success` left `false`.
    pub fn reject() -> Self {
        let proceed_with_methods = None;
        let partial_success = false;
        Self::Reject {
            proceed_with_methods,
            partial_success,
        }
    }
}
#[cfg_attr(feature = "async-trait", async_trait::async_trait)]
pub trait Handler: Sized {
type Error: From<crate::Error> + Send;
#[allow(unused_variables)]
fn auth_none(&mut self, user: &str) -> impl Future<Output = Result<Auth, Self::Error>> + Send {
async { Ok(Auth::reject()) }
}
#[allow(unused_variables)]
fn auth_password(
&mut self,
user: &str,
password: &str,
) -> impl Future<Output = Result<Auth, Self::Error>> + Send {
async { Ok(Auth::reject()) }
}
#[allow(unused_variables)]
fn auth_publickey_offered(
&mut self,
user: &str,
public_key: &ssh_key::PublicKey,
) -> impl Future<Output = Result<Auth, Self::Error>> + Send {
async { Ok(Auth::Accept) }
}
#[allow(unused_variables)]
fn auth_publickey(
&mut self,
user: &str,
public_key: &ssh_key::PublicKey,
) -> impl Future<Output = Result<Auth, Self::Error>> + Send {
async { Ok(Auth::reject()) }
}
#[allow(unused_variables)]
fn auth_openssh_certificate(
&mut self,
user: &str,
certificate: &Certificate,
) -> impl Future<Output = Result<Auth, Self::Error>> + Send {
async { Ok(Auth::reject()) }
}
#[allow(unused_variables)]
fn auth_keyboard_interactive<'a>(
&'a mut self,
user: &str,
submethods: &str,
response: Option<Response<'a>>,
) -> impl Future<Output = Result<Auth, Self::Error>> + Send {
async { Ok(Auth::reject()) }
}
#[allow(unused_variables)]
fn auth_succeeded(
&mut self,
session: &mut Session,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
#[allow(unused_variables)]
fn authentication_banner(
&mut self,
) -> impl Future<Output = Result<Option<String>, Self::Error>> + Send {
async { Ok(None) }
}
#[allow(unused_variables)]
fn channel_close(
&mut self,
channel: ChannelId,
session: &mut Session,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
#[allow(unused_variables)]
fn channel_eof(
&mut self,
channel: ChannelId,
session: &mut Session,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
#[allow(unused_variables)]
fn channel_open_session(
&mut self,
channel: Channel<Msg>,
session: &mut Session,
) -> impl Future<Output = Result<bool, Self::Error>> + Send {
async { Ok(false) }
}
#[allow(unused_variables)]
fn channel_open_x11(
&mut self,
channel: Channel<Msg>,
originator_address: &str,
originator_port: u32,
session: &mut Session,
) -> impl Future<Output = Result<bool, Self::Error>> + Send {
async { Ok(false) }
}
#[allow(unused_variables)]
fn channel_open_direct_tcpip(
&mut self,
channel: Channel<Msg>,
host_to_connect: &str,
port_to_connect: u32,
originator_address: &str,
originator_port: u32,
session: &mut Session,
) -> impl Future<Output = Result<bool, Self::Error>> + Send {
async { Ok(false) }
}
#[allow(unused_variables)]
fn channel_open_forwarded_tcpip(
&mut self,
channel: Channel<Msg>,
host_to_connect: &str,
port_to_connect: u32,
originator_address: &str,
originator_port: u32,
session: &mut Session,
) -> impl Future<Output = Result<bool, Self::Error>> + Send {
async { Ok(false) }
}
#[allow(unused_variables)]
fn channel_open_direct_streamlocal(
&mut self,
channel: Channel<Msg>,
socket_path: &str,
session: &mut Session,
) -> impl Future<Output = Result<bool, Self::Error>> + Send {
async { Ok(false) }
}
#[allow(unused_variables)]
fn channel_open_confirmation(
&mut self,
id: ChannelId,
max_packet_size: u32,
window_size: u32,
session: &mut Session,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
#[allow(unused_variables)]
fn data(
&mut self,
channel: ChannelId,
data: &[u8],
session: &mut Session,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
#[allow(unused_variables)]
fn extended_data(
&mut self,
channel: ChannelId,
code: u32,
data: &[u8],
session: &mut Session,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
#[allow(unused_variables)]
fn window_adjusted(
&mut self,
channel: ChannelId,
new_size: u32,
session: &mut Session,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
#[allow(unused_variables)]
fn adjust_window(&mut self, channel: ChannelId, current: u32) -> u32 {
current
}
#[allow(unused_variables, clippy::too_many_arguments)]
fn pty_request(
&mut self,
channel: ChannelId,
term: &str,
col_width: u32,
row_height: u32,
pix_width: u32,
pix_height: u32,
modes: &[(Pty, u32)],
session: &mut Session,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
#[allow(unused_variables)]
fn x11_request(
&mut self,
channel: ChannelId,
single_connection: bool,
x11_auth_protocol: &str,
x11_auth_cookie: &str,
x11_screen_number: u32,
session: &mut Session,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
#[allow(unused_variables)]
fn env_request(
&mut self,
channel: ChannelId,
variable_name: &str,
variable_value: &str,
session: &mut Session,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
#[allow(unused_variables)]
fn shell_request(
&mut self,
channel: ChannelId,
session: &mut Session,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
#[allow(unused_variables)]
fn exec_request(
&mut self,
channel: ChannelId,
data: &[u8],
session: &mut Session,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
#[allow(unused_variables)]
fn subsystem_request(
&mut self,
channel: ChannelId,
name: &str,
session: &mut Session,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
#[allow(unused_variables)]
fn window_change_request(
&mut self,
channel: ChannelId,
col_width: u32,
row_height: u32,
pix_width: u32,
pix_height: u32,
session: &mut Session,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
#[allow(unused_variables)]
fn agent_request(
&mut self,
channel: ChannelId,
session: &mut Session,
) -> impl Future<Output = Result<bool, Self::Error>> + Send {
async { Ok(false) }
}
#[allow(unused_variables)]
fn signal(
&mut self,
channel: ChannelId,
signal: Sig,
session: &mut Session,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
#[allow(unused_variables)]
fn tcpip_forward(
&mut self,
address: &str,
port: &mut u32,
session: &mut Session,
) -> impl Future<Output = Result<bool, Self::Error>> + Send {
async { Ok(false) }
}
#[allow(unused_variables)]
fn cancel_tcpip_forward(
&mut self,
address: &str,
port: u32,
session: &mut Session,
) -> impl Future<Output = Result<bool, Self::Error>> + Send {
async { Ok(false) }
}
#[allow(unused_variables)]
fn streamlocal_forward(
&mut self,
socket_path: &str,
session: &mut Session,
) -> impl Future<Output = Result<bool, Self::Error>> + Send {
async { Ok(false) }
}
#[allow(unused_variables)]
fn cancel_streamlocal_forward(
&mut self,
socket_path: &str,
session: &mut Session,
) -> impl Future<Output = Result<bool, Self::Error>> + Send {
async { Ok(false) }
}
#[allow(unused_variables)]
fn lookup_dh_gex_group(
&mut self,
gex_params: &GexParams,
) -> impl Future<Output = Result<Option<DhGroup>, Self::Error>> + Send {
async {
let mut best_group = &DH_GROUP14;
for group in BUILTIN_SAFE_DH_GROUPS.iter() {
if group.bit_size() >= gex_params.min_group_size()
&& group.bit_size() <= gex_params.max_group_size()
{
best_group = *group;
break;
}
}
for group in BUILTIN_SAFE_DH_GROUPS.iter() {
if group.bit_size() > gex_params.preferred_group_size() {
best_group = *group;
break;
}
}
Ok(Some(best_group.clone()))
}
}
}
/// Cloneable handle used to stop a [`RunningServer`].
pub struct RunningServerHandle {
    shutdown_tx: broadcast::Sender<String>,
}
impl RunningServerHandle {
    /// Broadcasts a shutdown request carrying `reason`.
    ///
    /// A send failure (no subscriber left) is deliberately ignored.
    pub fn shutdown(&self, reason: String) {
        let sender = &self.shutdown_tx;
        let _ = sender.send(reason);
    }
}
/// Future returned by [`Server::run_on_socket`].
///
/// Polling it drives the wrapped server future. [`RunningServer::handle`] hands out
/// shutdown handles.
pub struct RunningServer<F: Future<Output = std::io::Result<()>> + Unpin + Send> {
    inner: F,
    shutdown_tx: broadcast::Sender<String>,
}
impl<F: Future<Output = std::io::Result<()>> + Unpin + Send> RunningServer<F> {
    /// Returns a handle sharing this server's shutdown channel.
    pub fn handle(&self) -> RunningServerHandle {
        let shutdown_tx = broadcast::Sender::clone(&self.shutdown_tx);
        RunningServerHandle { shutdown_tx }
    }
}
impl<F: Future<Output = std::io::Result<()>> + Unpin + Send> Future for RunningServer<F> {
    type Output = std::io::Result<()>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `F: Unpin`, so it can be re-pinned in place and polled directly.
        let inner = Pin::new(&mut self.inner);
        inner.poll(cx)
    }
}
/// Creates a [`Handler`] per accepted connection and drives the accept loop.
#[cfg_attr(feature = "async-trait", async_trait::async_trait)]
pub trait Server {
    /// Handler instantiated for every accepted connection.
    type Handler: Handler + Send + 'static;
    /// Builds the handler for a new connection. `run_on_socket` always passes
    /// `Some(peer_addr)`, taken from `accept()`.
    fn new_client(&mut self, peer_addr: Option<std::net::SocketAddr>) -> Self::Handler;
    /// Receives errors from connections that failed during setup or ended with an error.
    /// The default implementation ignores them.
    fn handle_session_error(&mut self, _error: <Self::Handler as Handler>::Error) {}
    /// Accepts connections on `socket` and runs each one on a spawned task.
    ///
    /// The returned [`RunningServer`] resolves to `Ok(())` once shutdown is requested
    /// through one of its handles, or to `Err` as soon as `accept()` fails. On shutdown,
    /// each live connection task asks its session to disconnect with
    /// `Disconnect::ByApplication` and the broadcast reason.
    fn run_on_socket(
        &mut self,
        config: Arc<Config>,
        socket: &TcpListener,
    ) -> RunningServer<impl Future<Output = std::io::Result<()>> + Unpin + Send>
    where
        Self: Send,
    {
        let (shutdown_tx, mut shutdown_rx) = broadcast::channel(1);
        // A second sender moves into the accept loop so every connection can subscribe
        // to the shutdown broadcast.
        let shutdown_tx2 = shutdown_tx.clone();
        let fut = async move {
            // Only logged: the server still starts with an oversized value.
            if config.maximum_packet_size > 65535 {
                error!(
                    "Maximum packet size ({:?}) should not larger than a TCP packet (65535)",
                    config.maximum_packet_size
                );
            }
            // Connection tasks send their errors here, so that `handle_session_error` runs
            // on `self` inside this loop rather than in the spawned task.
            let (error_tx, mut error_rx) = mpsc::unbounded_channel();
            loop {
                tokio::select! {
                    // Any result (including a lag/closed error) stops the accept loop.
                    _ = shutdown_rx.recv() => {
                        debug!("Server shutdown requested");
                        return Ok(());
                    },
                    accept_result = socket.accept() => {
                        match accept_result {
                            Ok((socket, peer_addr)) => {
                                let mut shutdown_rx = shutdown_tx2.subscribe();
                                let config = config.clone();
                                let handler = self.new_client(Some(peer_addr));
                                let error_tx = error_tx.clone();
                                russh_util::runtime::spawn(async move {
                                    // Best effort: a failure is only logged.
                                    if config.nodelay {
                                        if let Err(e) = socket.set_nodelay(true) {
                                            warn!("set_nodelay() failed: {e:?}");
                                        }
                                    }
                                    let session = match run_stream(config, socket, handler).await {
                                        Ok(s) => s,
                                        Err(e) => {
                                            debug!("Connection setup failed");
                                            let _ = error_tx.send(e);
                                            return
                                        }
                                    };
                                    let handle = session.handle();
                                    tokio::select! {
                                        // Shutdown: disconnect with the broadcast reason, or an
                                        // empty reason if the receive itself failed.
                                        reason = shutdown_rx.recv() => {
                                            if handle.disconnect(
                                                Disconnect::ByApplication,
                                                reason.unwrap_or_else(|_| "".into()),
                                                "".into()
                                            ).await.is_err() {
                                                debug!("Failed to send disconnect message");
                                            }
                                        },
                                        result = session => {
                                            if let Err(e) = result {
                                                debug!("Connection closed with error");
                                                let _ = error_tx.send(e);
                                            } else {
                                                debug!("Connection closed");
                                            }
                                        }
                                    }
                                });
                            }
                            // Accept errors are fatal for the whole server future.
                            Err(e) => {
                                return Err(e);
                            }
                        }
                    },
                    Some(error) = error_rx.recv() => {
                        self.handle_session_error(error);
                    }
                }
            }
        };
        RunningServer {
            inner: Box::pin(fut),
            shutdown_tx,
        }
    }
    /// Binds a `TcpListener` to `addrs` and runs [`Server::run_on_socket`] on it until
    /// that future completes. Bind and accept errors are returned.
    fn run_on_address<A: ToSocketAddrs + Send>(
        &mut self,
        config: Arc<Config>,
        addrs: A,
    ) -> impl Future<Output = std::io::Result<()>> + Send
    where
        Self: Send,
    {
        async {
            let socket = TcpListener::bind(addrs).await?;
            self.run_on_socket(config, &socket).await?;
            Ok(())
        }
    }
}
/// Clears `buffer` and reads the next packet from `reader` using `opening_key`.
///
/// Ownership of the reader, buffer and key is handed back alongside the value
/// returned by `cipher::read`, so the caller can start the next read.
async fn start_reading<R: AsyncRead + Unpin>(
    mut reader: R,
    mut buffer: SSHBuffer,
    mut opening_key: Box<dyn OpeningKey + Send>,
) -> Result<(usize, R, SSHBuffer, Box<dyn OpeningKey + Send>), Error> {
    buffer.buffer.clear();
    let bytes_read = cipher::read(&mut reader, &mut buffer, opening_key.as_mut()).await?;
    Ok((bytes_read, reader, buffer, opening_key))
}
pub struct RunningSession<H: Handler> {
handle: Handle,
join: JoinHandle<Result<(), H::Error>>,
}
impl<H: Handler> RunningSession<H> {
pub fn handle(&self) -> Handle {
self.handle.clone()
}
}
impl<H: Handler> Future for RunningSession<H> {
type Output = Result<(), H::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match Future::poll(Pin::new(&mut self.join), cx) {
Poll::Ready(r) => Poll::Ready(match r {
Ok(Ok(x)) => Ok(x),
Err(e) => Err(crate::Error::from(e).into()),
Ok(Err(e)) => Err(e),
}),
Poll::Pending => Poll::Pending,
}
}
}
/// Starts a server session over an already-connected `stream`.
///
/// The steps are, in order:
/// 1. Write the configured identification string.
/// 2. Read the client's identification string.
/// 3. Begin the initial key exchange.
/// 4. Spawn the session loop with `handler`.
///
/// The returned [`RunningSession`] both exposes a [`Handle`] and resolves when the
/// session task ends.
pub async fn run_stream<H, R>(
    config: Arc<Config>,
    mut stream: R,
    handler: H,
) -> Result<RunningSession<H>, H::Error>
where
    H: Handler + Send + 'static,
    R: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    // Our identification goes out before anything is read.
    let mut id_buffer = SSHBuffer::new();
    id_buffer.send_ssh_id(&config.server_id);
    map_err!(stream.write_all(&id_buffer.buffer[..]).await)?;

    let mut reader = SshRead::new(stream);

    // Channel through which `Handle`s talk to the session.
    let (event_tx, event_rx) = tokio::sync::mpsc::channel(config.event_buffer_size);
    let handle = server::session::Handle {
        sender: event_tx,
        channel_buffer_size: config.channel_buffer_size,
    };

    let common = read_ssh_id(config, &mut reader).await?;
    let target_window_size = common.config.window_size;
    let mut session = Session {
        target_window_size,
        common,
        receiver: event_rx,
        sender: handle.clone(),
        pending_reads: Vec::new(),
        pending_len: 0,
        channels: HashMap::new(),
        open_global_requests: VecDeque::new(),
        kex: SessionKexState::Idle,
    };
    session.begin_rekey()?;

    let join = russh_util::runtime::spawn(session.run(reader, handler));
    Ok(RunningSession { handle, join })
}
/// Reads the client's identification string and builds the initial, unencrypted
/// [`CommonSession`] for the connection.
///
/// When `config.inactivity_timeout` is set, it also bounds how long the client may take
/// to send its identification; expiry is returned as an error.
async fn read_ssh_id<R: AsyncRead + Unpin>(
    config: Arc<Config>,
    read: &mut SshRead<R>,
) -> Result<CommonSession<Arc<Config>>, Error> {
    let id_future = read.read_ssh_id();
    let sshid = match config.inactivity_timeout {
        Some(limit) => tokio::time::timeout(limit, id_future).await??,
        None => id_future.await?,
    };

    Ok(CommonSession {
        config,
        remote_sshid: sshid.into(),
        // Cleartext until the first key exchange completes.
        packet_writer: PacketWriter::clear(),
        remote_to_local: Box::new(clear::Key),
        encrypted: None,
        strict_kex: false,
        auth_user: String::new(),
        auth_method: None,
        auth_attempts: 0,
        wants_reply: false,
        disconnected: false,
        buffer: Vec::new(),
        alive_timeouts: 0,
        received_data: false,
    })
}
/// Dispatches one incoming packet for a server session.
///
/// The steps are, in order:
/// 1. Under strict key exchange, before encryption is established, check that the
///    message type is allowed at this sequence number.
/// 2. Drop IGNORE / UNIMPLEMENTED / DEBUG messages.
/// 3. If the client sends KEXINIT while no exchange is running, start a re-key.
/// 4. Feed key-exchange messages to the in-progress exchange.
/// 5. Hand every other packet to `server_read_encrypted`.
///
/// When an exchange completes, one of two things happens. If the session already had
/// keys, the new keys are installed and packets queued in `pending_reads` are replayed.
/// Otherwise the session moves to `WaitingAuthServiceRequest` and `maybe_send_ext_info`
/// is called. With strict kex, `pkt.seqn` is reset to 0 afterwards.
async fn reply<H: Handler + Send>(
    session: &mut Session,
    handler: &mut H,
    pkt: &mut IncomingSshPacket,
) -> Result<(), H::Error> {
    if let Some(message_type) = pkt.buffer.first() {
        debug!(
            "< msg type {message_type:?}, seqn {:?}, len {}",
            pkt.seqn.0,
            pkt.buffer.len()
        );
        if session.common.strict_kex && session.common.encrypted.is_none() {
            // The counter is one past this packet, hence the decrement (presumably it
            // was advanced when the packet was read — confirm in the reader). Use
            // wrapping arithmetic, consistent with the `Wrapping` counter, instead of
            // a raw subtraction that panics on 0 in debug builds.
            let seqno = pkt.seqn.0.wrapping_sub(1);
            validate_client_msg_strict_kex(*message_type, seqno as usize)?;
        }
        if [msg::IGNORE, msg::UNIMPLEMENTED, msg::DEBUG].contains(message_type) {
            return Ok(());
        }
    }
    if pkt.buffer.first() == Some(&msg::KEXINIT) && session.kex == SessionKexState::Idle {
        info!("Client has initiated re-key");
        session.begin_rekey()?;
    }
    let is_kex_msg = pkt.buffer.first().cloned().map(is_kex_msg).unwrap_or(false);
    if is_kex_msg {
        if let SessionKexState::InProgress(kex) = session.kex.take() {
            let progress = kex
                .step(Some(pkt), &mut session.common.packet_writer, handler)
                .await?;
            match progress {
                KexProgress::NeedsReply { kex, reset_seqn } => {
                    debug!("kex impl continues: {kex:?}");
                    session.kex = SessionKexState::InProgress(kex);
                    if reset_seqn {
                        debug!("kex impl requests seqno reset");
                        session.common.reset_seqn();
                    }
                }
                KexProgress::Done { newkeys, .. } => {
                    debug!("kex impl has completed");
                    session.common.strict_kex =
                        session.common.strict_kex || newkeys.names.strict_kex();
                    if let Some(ref mut enc) = session.common.encrypted {
                        // Re-key of an established session: flush what was held back,
                        // replay packets queued during the exchange, then switch keys.
                        enc.last_rekey = Instant::now();
                        session.common.packet_writer.buffer().bytes = 0;
                        enc.flush_all_pending()?;
                        let mut pending = std::mem::take(&mut session.pending_reads);
                        for p in pending.drain(..) {
                            session.process_packet(handler, &p).await?;
                        }
                        // Give the (now empty) allocation back for reuse.
                        session.pending_reads = pending;
                        session.pending_len = 0;
                        session.common.newkeys(newkeys);
                        session.flush()?;
                    } else {
                        // First exchange: the connection becomes encrypted and waits for
                        // the client's service request.
                        session.common.encrypted(
                            EncryptedState::WaitingAuthServiceRequest {
                                sent: false,
                                accepted: false,
                            },
                            newkeys,
                        );
                        session.maybe_send_ext_info()?;
                    }
                    session.kex = SessionKexState::Idle;
                    if session.common.strict_kex {
                        pkt.seqn = Wrapping(0);
                    }
                    debug!("kex done");
                }
            }
            session.flush()?;
            return Ok(());
        }
    }
    session.server_read_encrypted(handler, pkt).await
}