async_ssh2_russh/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![warn(missing_docs)]
3#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/", env!("CARGO_PKG_README")))]
4
5use std::collections::HashMap;
6use std::ops::Deref;
7use std::path::Path;
8use std::sync::Arc;
9
10use russh::client::{connect, Config, Handle, Handler, Msg};
11use russh::keys::{load_secret_key, ssh_key, PrivateKeyWithHashAlg};
12use russh::{ChannelMsg, ChannelWriteHalf, CryptoVec};
13use tokio::io::AsyncWrite;
14use tokio::net::ToSocketAddrs;
15use tokio::sync::mpsc;
16use tokio::task::JoinHandle;
17
18// `pub` items
19#[cfg(feature = "sftp")]
20#[cfg_attr(docsrs, doc(cfg(feature = "sftp")))]
21pub mod sftp;
22use async_promise::Promise;
23#[doc(no_inline)]
24pub use russh::Error as SshError;
25#[cfg(feature = "sftp")]
26#[cfg_attr(docsrs, doc(cfg(feature = "sftp")))]
27pub use russh_sftp;
28use tracing::Instrument;
29pub use {async_promise, russh, tokio};
30mod read_stream;
31pub use read_stream::ReadStream;
32
33/// A handler that does NOT check the server's public key.
34///
35/// This should NOT be used unless you are certain that the SSH server is trusted and you are aware of the security
36/// implications of not verifying the server's public key, particularly the risk of man-in-the-middle (MITM) attacks.
37///
38/// This should only be used with public key authentication, as it provides
39/// [some protection against MITM attacks](https://security.stackexchange.com/questions/67242/does-public-key-auth-in-ssh-prevent-most-mitm-attacks).
40pub struct NoCheckHandler;
41impl Handler for NoCheckHandler {
42    type Error = SshError;
43
44    async fn check_server_key(&mut self, _server_public_key: &ssh_key::PublicKey) -> Result<bool, Self::Error> {
45        Ok(true)
46    }
47}
48
49/// An SSH session, which may open multiple [`AsyncChannel`]s.
50///
51/// This struct is a thin wrapper around [`russh::client::Handle`] which provides basic authentication and channel
52/// management for a SSH session. Implements [`Deref`] to allow access to the underlying [`russh::client::Handle`].
53pub struct AsyncSession<H: Handler> {
54    session: Handle<H>,
55}
56impl<H: 'static + Handler> AsyncSession<H> {
57    /// Connect to an SSH server using the provided configuration and handler, without beginning
58    /// authentication.
59    pub async fn connect_unauthenticated(
60        config: Arc<Config>,
61        addrs: impl ToSocketAddrs,
62        handler: H,
63    ) -> Result<Self, H::Error> {
64        let session = connect(config, addrs, handler).await?;
65        Ok(Self { session })
66    }
67
68    /// Opens an [`AsyncChannel`] in this session.
69    ///
70    /// [`AsyncChannel`] is the asnyc wrapper for [`russh::Channel`].
71    pub async fn open_channel(&self) -> Result<AsyncChannel, SshError> {
72        let russh_channel = self.session.channel_open_session().await?;
73        Ok(AsyncChannel::from(russh_channel))
74    }
75}
76
77impl AsyncSession<NoCheckHandler> {
78    /// Connect to an SSH server and authenticate with the given `user` and `key_path` via publickey
79    /// authentication.
80    ///
81    /// Uses [`NoCheckHandler`] to skip server public key verification, as publickey authentication provides protection
82    /// against MITM attacks.
83    pub async fn connect_publickey(
84        config: impl Into<Arc<Config>>,
85        addrs: impl ToSocketAddrs,
86        user: impl Into<String>,
87        key_path: impl AsRef<Path>,
88    ) -> Result<Self, SshError> {
89        let key_pair = load_secret_key(key_path, None)?;
90
91        let mut session = connect(config.into(), addrs, NoCheckHandler).await?;
92
93        // use publickey authentication.
94        let auth_res = session
95            .authenticate_publickey(
96                user,
97                PrivateKeyWithHashAlg::new(Arc::new(key_pair), session.best_supported_rsa_hash().await?.flatten()),
98            )
99            .await?;
100
101        if auth_res.success() {
102            Ok(Self { session })
103        } else {
104            Err(SshError::NotAuthenticated)
105        }
106    }
107}
108
109impl<H: Handler> Deref for AsyncSession<H> {
110    type Target = Handle<H>;
111    fn deref(&self) -> &Self::Target {
112        &self.session
113    }
114}
115
116/// An asynchronous SSH channel, one of possibly many within a single SSH [`AsyncSession`]. Each channel represents a
117/// separate command, shell, SFTP session, X11 forwarding, or other SSH subsystem.
118///
119/// This struct is a thin wrapper around [`russh::Channel`] which provides access to async read/write streams
120/// (stdout/stderr/stdin) and async event handling Implements [`Deref`] to allow access to the underlying
121/// [`russh::ChannelWriteHalf`].
122///
123/// # Shutdown Lifecycle
124///
125/// During shutdown, events _may_ be received in the following order. However this should not be relied upon, as the
126/// order may be different and none of these events are guaranteed to occur, except for [`Self::wait_close`] which will
127/// always happen last.
128///
129/// 1. [`Self::recv_success_failure`].
130/// 2. [`Self::recv_eof`] - Guarantees all stream data has been received, i.e. stdout/stderr will produce no more data.
131///    Channels may be closed without sending EOF; see [this StackOverflow answer](https://stackoverflow.com/a/23257958).
132/// 3. [`Self::recv_exit_status`] - The exit status of the command run, if applicable.
133/// 4. [`Self::wait_close`] - This channel is closed, no more events will occur.
134pub struct AsyncChannel {
135    write_half: ChannelWriteHalf<Msg>,
136    subscribe_send: mpsc::UnboundedSender<(Option<u32>, mpsc::UnboundedSender<CryptoVec>)>,
137    success_failure: Promise<bool>,
138    eof: Promise<()>,
139    exit_status: Promise<u32>,
140    reader: JoinHandle<()>,
141}
142
143impl From<russh::Channel<Msg>> for AsyncChannel {
144    fn from(inner: russh::Channel<Msg>) -> Self {
145        let (mut read_half, write_half) = inner.split();
146        let (mut resolve_success_failure, success_failure) = async_promise::channel();
147        let (mut resolve_eof, eof) = async_promise::channel();
148        let (mut resolve_exit_status, exit_status) = async_promise::channel();
149        let (subscribe_send, mut subscribe_recv) = mpsc::unbounded_channel();
150
151        let reader = async move {
152            // Map from `ext` to a sender for `CryptoVec`s of data.
153            type Subscribers = HashMap<Option<u32>, mpsc::UnboundedSender<CryptoVec>>;
154            let mut subscribers = Some(Subscribers::new());
155
156            #[tracing::instrument(level = "INFO", skip_all, fields(?ext))]
157            fn receive_data(subscribers: &Option<Subscribers>, ext: Option<u32>, data: CryptoVec) {
158                if let Some(subscribers) = &subscribers {
159                    if let Some(send) = subscribers.get(&ext) {
160                        if let Err(e) = send.send(data) {
161                            tracing::warn!("Failed to send data to subscriber: {e}");
162                        } else {
163                            tracing::debug!("Successfully sent data to subscriber.");
164                        }
165                    } else {
166                        tracing::debug!("No subscriber for ext, dropping data.");
167                    }
168                } else {
169                    tracing::warn!("Unexpectedly received data from server after receiving EOF.");
170                }
171            }
172
173            loop {
174                tokio::select! {
175                    biased;
176                    Some((ext, send)) = subscribe_recv.recv() => {
177                        if let Some(subscribers) = &mut subscribers {
178                            subscribers.insert(ext, send);
179                        } else {
180                            tracing::debug!(ext, "Received stream subscriber after EOF, ignoring.");
181                        }
182                    },
183                    opt_msg = read_half.wait() => {
184                        let Some(msg) = opt_msg else {
185                            // No more messages, exit!
186                            break;
187                        };
188
189                        tracing::info_span!("Message", ?msg).in_scope(|| {
190                            match msg {
191                                ChannelMsg::Data { data } => receive_data(&subscribers, None, data),
192                                ChannelMsg::ExtendedData { data, ext } => receive_data(&subscribers, Some(ext), data),
193                                ChannelMsg::Success | ChannelMsg::Failure => {
194                                    tracing::debug!("Resolving success/failure.");
195                                    let is_success = matches!(msg, ChannelMsg::Success);
196                                    if resolve_success_failure.resolve(is_success).is_err() {
197                                        tracing::warn!("Success/failure already resolved, ignoring.");
198                                    }
199                                }
200                                // The command has indicated no more `ChannelMsg::Data`/`ChannelMsg::ExtendedData` will be
201                                // sent.
202                                ChannelMsg::Eof => {
203                                    tracing::debug!("Resolving EOF and dropping stream subscribers.");
204                                    if resolve_eof.resolve(()).is_err() {
205                                        tracing::warn!("EOF already resolved, ignoring.");
206                                    }
207                                    // Disconnect all subscribers.
208                                    drop(std::mem::take(&mut subscribers));
209                                }
210                                // The command has returned an exit code
211                                ChannelMsg::ExitStatus { exit_status } => {
212                                    tracing::debug!(exit_status, "Resolving exit status.");
213                                    if resolve_exit_status.resolve(exit_status).is_err() {
214                                        tracing::warn!("Exit status already resolved, ignoring.");
215                                    }
216                                }
217                                // Other
218                                _ => {
219                                    tracing::trace!("Ignoring message.");
220                                }
221                            }
222                        });
223                    },
224                }
225            }
226            tracing::debug!("Channel read half finished, reader exiting.");
227            // Exiting causes the `self.reader` `JoinHandle` to close.
228        };
229        let reader = tokio::task::spawn(reader.instrument(tracing::info_span!("Reader")));
230
231        Self {
232            write_half,
233            subscribe_send,
234            success_failure,
235            eof,
236            exit_status,
237            reader,
238        }
239    }
240}
241
242impl AsyncChannel {
243    /// Returns the specified stream as a [`ReadStream`].
244    ///
245    /// Note that the returned stream will only receive data after this call, so call this before calling
246    /// [`exec`](ChannelWriteHalf::exec).
247    ///
248    /// When this is called for the same `ext` more than once, the later call will disconnect the
249    /// first.
250    pub fn read_stream(&self, ext: Option<u32>) -> ReadStream {
251        let (send, recv) = mpsc::unbounded_channel();
252        let _ = self.subscribe_send.send((ext, send));
253        ReadStream::from_recv(recv)
254    }
255
256    /// Returns stdout as a [`ReadStream`].
257    ///
258    /// Note that the returned stream will only receive data after this call, so call this before calling
259    /// [`exec`](ChannelWriteHalf::exec).
260    ///
261    /// When this is called more than once, the later call will disconnect the first.
262    pub fn stdout(&self) -> ReadStream {
263        self.read_stream(None)
264    }
265
266    /// Returns stderr as a [`ReadStream`].
267    ///
268    /// Note that the returned stream will only receive data after this call, so call this before calling
269    /// [`exec`](ChannelWriteHalf::exec).
270    ///
271    /// When this is called more than once, the later call will disconnect the first.
272    pub fn stderr(&self) -> ReadStream {
273        self.read_stream(Some(1))
274    }
275
276    /// Returns the specified stream as an [`impl AsyncWrite`](AsyncWrite).
277    ///
278    /// When this is called for the same `ext` more than once, writes to each may be interleaved.
279    #[expect(impl_trait_overcaptures, reason = "fix when upgrading to edition 2024.")]
280    pub fn write_stream(&self, ext: Option<u32>) -> impl AsyncWrite {
281        self.write_half.make_writer_ext(ext)
282    }
283
284    /// Returns stdin as an [`impl AsyncWrite`](AsyncWrite).
285    ///
286    /// When this is called more than once, writes to each may be interleaved.
287    #[expect(impl_trait_overcaptures, reason = "fix when upgrading to edition 2024.")]
288    pub fn stdin(&self) -> impl AsyncWrite {
289        self.write_stream(None)
290    }
291
292    /// Resolves when success or failure has been received, where `true` indicates success.
293    pub fn recv_success_failure(&self) -> &Promise<bool> {
294        &self.success_failure
295    }
296
297    /// Resolves when EOF has been received, indicating all stream data is complete.
298    ///
299    /// At that point, any streams from [`Self::stdout`]/[`Self::stderr`]/[`Self::read_stream`]
300    /// will return no additional data.
301    pub fn recv_eof(&self) -> &Promise<()> {
302        &self.eof
303    }
304
305    /// Resolves when the command exit status has been received.
306    pub fn recv_exit_status(&self) -> &Promise<u32> {
307        &self.exit_status
308    }
309
310    /// Returns when the channel has been closed.
311    ///
312    /// After this point, no more events will resolve.
313    pub async fn wait_close(&mut self) {
314        let _ = (&mut self.reader).await;
315    }
316
317    /// Returns if the channel has been closed. See [`Self::wait_close`].
318    pub fn is_closed(&self) -> bool {
319        self.reader.is_finished()
320    }
321}
322
323impl Deref for AsyncChannel {
324    type Target = ChannelWriteHalf<Msg>;
325    fn deref(&self) -> &Self::Target {
326        &self.write_half
327    }
328}