async_ssh2_russh/lib.rs
1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![warn(missing_docs)]
3#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/", env!("CARGO_PKG_README")))]
4
5use std::collections::HashMap;
6use std::ops::Deref;
7use std::path::Path;
8use std::sync::Arc;
9
10use russh::client::{connect, Config, Handle, Handler, Msg};
11use russh::keys::{load_secret_key, ssh_key, PrivateKeyWithHashAlg};
12use russh::{ChannelMsg, ChannelWriteHalf, CryptoVec};
13use tokio::io::AsyncWrite;
14use tokio::net::ToSocketAddrs;
15use tokio::sync::mpsc;
16use tokio::task::JoinHandle;
17
18// `pub` items
19#[cfg(feature = "sftp")]
20#[cfg_attr(docsrs, doc(cfg(feature = "sftp")))]
21pub mod sftp;
22use async_promise::Promise;
23#[doc(no_inline)]
24pub use russh::Error as SshError;
25#[cfg(feature = "sftp")]
26#[cfg_attr(docsrs, doc(cfg(feature = "sftp")))]
27pub use russh_sftp;
28use tracing::Instrument;
29pub use {async_promise, russh, tokio};
30mod read_stream;
31pub use read_stream::ReadStream;
32
33/// A handler that does NOT check the server's public key.
34///
35/// This should NOT be used unless you are certain that the SSH server is trusted and you are aware of the security
36/// implications of not verifying the server's public key, particularly the risk of man-in-the-middle (MITM) attacks.
37///
38/// This should only be used with public key authentication, as it provides
39/// [some protection against MITM attacks](https://security.stackexchange.com/questions/67242/does-public-key-auth-in-ssh-prevent-most-mitm-attacks).
40pub struct NoCheckHandler;
41impl Handler for NoCheckHandler {
42 type Error = SshError;
43
44 async fn check_server_key(&mut self, _server_public_key: &ssh_key::PublicKey) -> Result<bool, Self::Error> {
45 Ok(true)
46 }
47}
48
49/// An SSH session, which may open multiple [`AsyncChannel`]s.
50///
51/// This struct is a thin wrapper around [`russh::client::Handle`] which provides basic authentication and channel
52/// management for a SSH session. Implements [`Deref`] to allow access to the underlying [`russh::client::Handle`].
53pub struct AsyncSession<H: Handler> {
54 session: Handle<H>,
55}
56impl<H: 'static + Handler> AsyncSession<H> {
57 /// Connect to an SSH server using the provided configuration and handler, without beginning
58 /// authentication.
59 pub async fn connect_unauthenticated(
60 config: Arc<Config>,
61 addrs: impl ToSocketAddrs,
62 handler: H,
63 ) -> Result<Self, H::Error> {
64 let session = connect(config, addrs, handler).await?;
65 Ok(Self { session })
66 }
67
68 /// Opens an [`AsyncChannel`] in this session.
69 ///
70 /// [`AsyncChannel`] is the asnyc wrapper for [`russh::Channel`].
71 pub async fn open_channel(&self) -> Result<AsyncChannel, SshError> {
72 let russh_channel = self.session.channel_open_session().await?;
73 Ok(AsyncChannel::from(russh_channel))
74 }
75}
76
77impl AsyncSession<NoCheckHandler> {
78 /// Connect to an SSH server and authenticate with the given `user` and `key_path` via publickey
79 /// authentication.
80 ///
81 /// Uses [`NoCheckHandler`] to skip server public key verification, as publickey authentication provides protection
82 /// against MITM attacks.
83 pub async fn connect_publickey(
84 config: impl Into<Arc<Config>>,
85 addrs: impl ToSocketAddrs,
86 user: impl Into<String>,
87 key_path: impl AsRef<Path>,
88 ) -> Result<Self, SshError> {
89 let key_pair = load_secret_key(key_path, None)?;
90
91 let mut session = connect(config.into(), addrs, NoCheckHandler).await?;
92
93 // use publickey authentication.
94 let auth_res = session
95 .authenticate_publickey(
96 user,
97 PrivateKeyWithHashAlg::new(Arc::new(key_pair), session.best_supported_rsa_hash().await?.flatten()),
98 )
99 .await?;
100
101 if auth_res.success() {
102 Ok(Self { session })
103 } else {
104 Err(SshError::NotAuthenticated)
105 }
106 }
107}
108
109impl<H: Handler> Deref for AsyncSession<H> {
110 type Target = Handle<H>;
111 fn deref(&self) -> &Self::Target {
112 &self.session
113 }
114}
115
116/// An asynchronous SSH channel, one of possibly many within a single SSH [`AsyncSession`]. Each channel represents a
117/// separate command, shell, SFTP session, X11 forwarding, or other SSH subsystem.
118///
119/// This struct is a thin wrapper around [`russh::Channel`] which provides access to async read/write streams
120/// (stdout/stderr/stdin) and async event handling Implements [`Deref`] to allow access to the underlying
121/// [`russh::ChannelWriteHalf`].
122///
123/// # Shutdown Lifecycle
124///
125/// During shutdown, events _may_ be received in the following order. However this should not be relied upon, as the
126/// order may be different and none of these events are guaranteed to occur, except for [`Self::wait_close`] which will
127/// always happen last.
128///
129/// 1. [`Self::recv_success_failure`].
130/// 2. [`Self::recv_eof`] - Guarantees all stream data has been received, i.e. stdout/stderr will produce no more data.
131/// Channels may be closed without sending EOF; see [this StackOverflow answer](https://stackoverflow.com/a/23257958).
132/// 3. [`Self::recv_exit_status`] - The exit status of the command run, if applicable.
133/// 4. [`Self::wait_close`] - This channel is closed, no more events will occur.
134pub struct AsyncChannel {
135 write_half: ChannelWriteHalf<Msg>,
136 subscribe_send: mpsc::UnboundedSender<(Option<u32>, mpsc::UnboundedSender<CryptoVec>)>,
137 success_failure: Promise<bool>,
138 eof: Promise<()>,
139 exit_status: Promise<u32>,
140 reader: JoinHandle<()>,
141}
142
143impl From<russh::Channel<Msg>> for AsyncChannel {
144 fn from(inner: russh::Channel<Msg>) -> Self {
145 let (mut read_half, write_half) = inner.split();
146 let (mut resolve_success_failure, success_failure) = async_promise::channel();
147 let (mut resolve_eof, eof) = async_promise::channel();
148 let (mut resolve_exit_status, exit_status) = async_promise::channel();
149 let (subscribe_send, mut subscribe_recv) = mpsc::unbounded_channel();
150
151 let reader = async move {
152 // Map from `ext` to a sender for `CryptoVec`s of data.
153 type Subscribers = HashMap<Option<u32>, mpsc::UnboundedSender<CryptoVec>>;
154 let mut subscribers = Some(Subscribers::new());
155
156 #[tracing::instrument(level = "INFO", skip_all, fields(?ext))]
157 fn receive_data(subscribers: &Option<Subscribers>, ext: Option<u32>, data: CryptoVec) {
158 if let Some(subscribers) = &subscribers {
159 if let Some(send) = subscribers.get(&ext) {
160 if let Err(e) = send.send(data) {
161 tracing::warn!("Failed to send data to subscriber: {e}");
162 } else {
163 tracing::debug!("Successfully sent data to subscriber.");
164 }
165 } else {
166 tracing::debug!("No subscriber for ext, dropping data.");
167 }
168 } else {
169 tracing::warn!("Unexpectedly received data from server after receiving EOF.");
170 }
171 }
172
173 loop {
174 tokio::select! {
175 biased;
176 Some((ext, send)) = subscribe_recv.recv() => {
177 if let Some(subscribers) = &mut subscribers {
178 subscribers.insert(ext, send);
179 } else {
180 tracing::debug!(ext, "Received stream subscriber after EOF, ignoring.");
181 }
182 },
183 opt_msg = read_half.wait() => {
184 let Some(msg) = opt_msg else {
185 // No more messages, exit!
186 break;
187 };
188
189 tracing::info_span!("Message", ?msg).in_scope(|| {
190 match msg {
191 ChannelMsg::Data { data } => receive_data(&subscribers, None, data),
192 ChannelMsg::ExtendedData { data, ext } => receive_data(&subscribers, Some(ext), data),
193 ChannelMsg::Success | ChannelMsg::Failure => {
194 tracing::debug!("Resolving success/failure.");
195 let is_success = matches!(msg, ChannelMsg::Success);
196 if resolve_success_failure.resolve(is_success).is_err() {
197 tracing::warn!("Success/failure already resolved, ignoring.");
198 }
199 }
200 // The command has indicated no more `ChannelMsg::Data`/`ChannelMsg::ExtendedData` will be
201 // sent.
202 ChannelMsg::Eof => {
203 tracing::debug!("Resolving EOF and dropping stream subscribers.");
204 if resolve_eof.resolve(()).is_err() {
205 tracing::warn!("EOF already resolved, ignoring.");
206 }
207 // Disconnect all subscribers.
208 drop(std::mem::take(&mut subscribers));
209 }
210 // The command has returned an exit code
211 ChannelMsg::ExitStatus { exit_status } => {
212 tracing::debug!(exit_status, "Resolving exit status.");
213 if resolve_exit_status.resolve(exit_status).is_err() {
214 tracing::warn!("Exit status already resolved, ignoring.");
215 }
216 }
217 // Other
218 _ => {
219 tracing::trace!("Ignoring message.");
220 }
221 }
222 });
223 },
224 }
225 }
226 tracing::debug!("Channel read half finished, reader exiting.");
227 // Exiting causes the `self.reader` `JoinHandle` to close.
228 };
229 let reader = tokio::task::spawn(reader.instrument(tracing::info_span!("Reader")));
230
231 Self {
232 write_half,
233 subscribe_send,
234 success_failure,
235 eof,
236 exit_status,
237 reader,
238 }
239 }
240}
241
242impl AsyncChannel {
243 /// Returns the specified stream as a [`ReadStream`].
244 ///
245 /// Note that the returned stream will only receive data after this call, so call this before calling
246 /// [`exec`](ChannelWriteHalf::exec).
247 ///
248 /// When this is called for the same `ext` more than once, the later call will disconnect the
249 /// first.
250 pub fn read_stream(&self, ext: Option<u32>) -> ReadStream {
251 let (send, recv) = mpsc::unbounded_channel();
252 let _ = self.subscribe_send.send((ext, send));
253 ReadStream::from_recv(recv)
254 }
255
256 /// Returns stdout as a [`ReadStream`].
257 ///
258 /// Note that the returned stream will only receive data after this call, so call this before calling
259 /// [`exec`](ChannelWriteHalf::exec).
260 ///
261 /// When this is called more than once, the later call will disconnect the first.
262 pub fn stdout(&self) -> ReadStream {
263 self.read_stream(None)
264 }
265
266 /// Returns stderr as a [`ReadStream`].
267 ///
268 /// Note that the returned stream will only receive data after this call, so call this before calling
269 /// [`exec`](ChannelWriteHalf::exec).
270 ///
271 /// When this is called more than once, the later call will disconnect the first.
272 pub fn stderr(&self) -> ReadStream {
273 self.read_stream(Some(1))
274 }
275
276 /// Returns the specified stream as an [`impl AsyncWrite`](AsyncWrite).
277 ///
278 /// When this is called for the same `ext` more than once, writes to each may be interleaved.
279 #[expect(impl_trait_overcaptures, reason = "fix when upgrading to edition 2024.")]
280 pub fn write_stream(&self, ext: Option<u32>) -> impl AsyncWrite {
281 self.write_half.make_writer_ext(ext)
282 }
283
284 /// Returns stdin as an [`impl AsyncWrite`](AsyncWrite).
285 ///
286 /// When this is called more than once, writes to each may be interleaved.
287 #[expect(impl_trait_overcaptures, reason = "fix when upgrading to edition 2024.")]
288 pub fn stdin(&self) -> impl AsyncWrite {
289 self.write_stream(None)
290 }
291
292 /// Resolves when success or failure has been received, where `true` indicates success.
293 pub fn recv_success_failure(&self) -> &Promise<bool> {
294 &self.success_failure
295 }
296
297 /// Resolves when EOF has been received, indicating all stream data is complete.
298 ///
299 /// At that point, any streams from [`Self::stdout`]/[`Self::stderr`]/[`Self::read_stream`]
300 /// will return no additional data.
301 pub fn recv_eof(&self) -> &Promise<()> {
302 &self.eof
303 }
304
305 /// Resolves when the command exit status has been received.
306 pub fn recv_exit_status(&self) -> &Promise<u32> {
307 &self.exit_status
308 }
309
310 /// Returns when the channel has been closed.
311 ///
312 /// After this point, no more events will resolve.
313 pub async fn wait_close(&mut self) {
314 let _ = (&mut self.reader).await;
315 }
316
317 /// Returns if the channel has been closed. See [`Self::wait_close`].
318 pub fn is_closed(&self) -> bool {
319 self.reader.is_finished()
320 }
321}
322
323impl Deref for AsyncChannel {
324 type Target = ChannelWriteHalf<Msg>;
325 fn deref(&self) -> &Self::Target {
326 &self.write_half
327 }
328}