// korasi_cli/ssh.rs

use std::{fs::File, io::Read, path::Path, sync::Arc};

use anyhow::Context;
use async_trait::async_trait;
use russh::{
    client::{self, Msg},
    keys::{decode_secret_key, key},
    Channel, ChannelId, ChannelMsg, Disconnect,
};
use russh_sftp::{client::SftpSession, protocol::OpenFlags};
use tokio::io::{AsyncReadExt, AsyncWriteExt};

use crate::util::{biject_paths, calc_prefix};

/// TCP port dialed when connecting to a remote instance over SSH.
pub const SSH_PORT: u16 = 22;

16pub struct ClientSSH;
17
18#[async_trait]
19impl client::Handler for ClientSSH {
20    type Error = anyhow::Error;
21
22    async fn check_server_key(
23        &mut self,
24        server_public_key: &key::PublicKey,
25    ) -> Result<bool, Self::Error> {
26        tracing::debug!("check_server_key: {:?}", server_public_key);
27        Ok(true)
28    }
29
30    async fn data(
31        &mut self,
32        channel: ChannelId,
33        data: &[u8],
34        _session: &mut client::Session,
35    ) -> Result<(), Self::Error> {
36        tracing::debug!("data on channel {:?}: {}", channel, data.len());
37        Ok(())
38    }
39}
40
41pub struct Session {
42    session: client::Handle<ClientSSH>,
43}
44
45impl Session {
46    /// Returns reusable remote channel that can used as a SSH/SFTP tunnel.
47    ///
48    /// This prevents direct access to private session.
49    pub async fn channel_open_session(&self) -> Result<Channel<Msg>, russh::Error> {
50        self.session.channel_open_session().await
51    }
52
53    /// Load a secret key, deciphering it with the supplied password if necessary.
54    pub fn load_secret_key<P: AsRef<Path>>(
55        secret_: P,
56        password: Option<&str>,
57    ) -> Result<key::KeyPair, anyhow::Error> {
58        let mut secret_file = std::fs::File::open(secret_)?;
59        let mut secret = String::new();
60        secret_file.read_to_string(&mut secret)?;
61        Ok(decode_secret_key(&secret, password)?)
62    }
63
64    /// Connect to remote instance via SSH.
65    ///
66    /// The public DNS name is the emphemeral host address generated when
67    /// an EC2 instance starts.
68    pub async fn connect(
69        user: &str,
70        public_dns_name: String,
71        ssh_key: String,
72    ) -> anyhow::Result<Self> {
73        let config = russh::client::Config {
74            inactivity_timeout: Some(std::time::Duration::from_secs(1200)), // 20 min.
75            ..<_>::default()
76        };
77        let mut session =
78            russh::client::connect(Arc::new(config), (public_dns_name, SSH_PORT), ClientSSH {})
79                .await
80                .expect("Failed to establish SSH connection with remote instance.");
81        let key_pair = Self::load_secret_key(ssh_key, None).unwrap();
82
83        session
84            .authenticate_publickey(user, Arc::new(key_pair))
85            .await?;
86
87        Ok(Self { session })
88    }
89
90    /// Executes a remote command using SSH.
91    pub async fn exec(&self, command: &str) -> anyhow::Result<u32> {
92        let mut channel = self.channel_open_session().await?;
93
94        // No terminal resizing after the connection is established.
95        let (w, h) = termion::terminal_size()?;
96        // Request an interactive PTY from the server.
97        channel
98            .request_pty(
99                false,
100                &std::env::var("TERM").unwrap_or("xterm".into()),
101                w as u32,
102                h as u32,
103                0,
104                0,
105                &[], // ideally you want to pass the actual terminal modes here
106            )
107            .await?;
108
109        channel.exec(true, command).await?;
110
111        let mut stdin = tokio_fd::AsyncFd::try_from(0)?;
112        let mut stdout = tokio_fd::AsyncFd::try_from(1)?;
113        let mut stderr = tokio_fd::AsyncFd::try_from(2)?;
114
115        let code;
116        let mut buf = vec![0; 1024];
117        let mut stdin_closed = false;
118
119        loop {
120            tokio::select! {
121                r = stdin.read(&mut buf), if !stdin_closed => {
122                    match r {
123                        Ok(0) => {
124                            stdin_closed = true;
125                            channel.eof().await?;
126                        },
127                        // Send it to the server
128                        Ok(n) => channel.data(&buf[..n]).await?,
129                        Err(e) => return Err(e.into()),
130                    };
131                },
132                Some(msg) = channel.wait() => {
133                    match msg {
134                        // Write data to the terminal
135                        ChannelMsg::Data { ref data } => {
136                            stdout.write_all(data).await?;
137                            stdout.flush().await?;
138                        }
139                        ChannelMsg::ExitStatus { exit_status } => {
140                            code = Some(exit_status);
141                            if !stdin_closed {
142                                channel.eof().await?;
143                            }
144                            break;
145                        }
146                        // Get std error from remote command.
147                        ChannelMsg::ExtendedData { ref data, ext: _ } => {
148                            stderr.write_all(data).await?;
149                            stderr.flush().await?;
150                        }
151                        _ => {}
152                    }
153                },
154            }
155        }
156
157        Ok(code.expect("program did not exit cleanly"))
158    }
159
160    async fn open_sftp_session(&self) -> Result<SftpSession, russh_sftp::client::error::Error> {
161        let channel = self.session.channel_open_session().await.unwrap();
162        channel.request_subsystem(true, "sftp").await.unwrap();
163
164        SftpSession::new(channel.into_stream()).await
165    }
166
167    /// Upload files within `src` to `dst` directory using SFTP.
168    /// If `dst` is not specified, files will uploaded to $HOME/{cwd}.
169    /// The {cwd} folder will be created by default in this use case.
170    ///
171    /// Panics if dst is not a directory.
172    pub async fn upload(&self, src: Option<String>, dst: Option<String>) -> anyhow::Result<()> {
173        let src_path = match std::fs::canonicalize(src.unwrap_or(".".into())) {
174            Ok(pth) => pth,
175            // Bail early if the src path is fked.
176            Err(err) => anyhow::bail!("Failed to canonicalize src = {err}"),
177        };
178
179        let sftp = self.open_sftp_session().await?;
180
181        if dst.is_some() {
182            match sftp.metadata(dst.as_ref().unwrap_or(&".".into())).await {
183                Ok(attr) => {
184                    if !attr.is_dir() {
185                        anyhow::bail!("Dst must be a dir!");
186                    }
187                }
188                Err(err) => {
189                    tracing::error!("Error remote metadata = {err}");
190                    return Ok(());
191                }
192            }
193        }
194
195        let prefix = calc_prefix(src_path.clone())?;
196        let dst_abs_path = sftp
197            .canonicalize(&dst.unwrap_or(".".into()))
198            .await
199            .expect("Failed to canonicalize remote dst.");
200
201        // The .gitignore at src_path will be respected.
202        for result in biject_paths(
203            src_path.to_str().unwrap(),
204            prefix.to_str().unwrap_or(""),
205            &dst_abs_path,
206        ) {
207            match result {
208                Ok((local_pth, combined, is_dir)) => {
209                    if is_dir {
210                        let _ = sftp.create_dir(combined.to_str().unwrap().to_owned()).await;
211                    } else {
212                        let open_remote_file = sftp
213                            .open_with_flags(
214                                combined.to_str().unwrap(),
215                                OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE,
216                            )
217                            .await;
218                        if open_remote_file.is_err() {
219                            tracing::warn!("Failed to open file = {:?}", combined,);
220                        }
221
222                        // Overwrite remote file contents with local file contents.
223                        if let Ok(mut remote_file) = open_remote_file {
224                            let mut local_file = File::open(local_pth).unwrap();
225                            let mut buffer = Vec::new();
226                            local_file.read_to_end(&mut buffer).unwrap();
227                            remote_file.write_all(buffer.as_slice()).await.unwrap();
228                            let _ = remote_file.sync_all().await;
229                            remote_file.shutdown().await.unwrap();
230                        }
231                    }
232                }
233                Err(err) => tracing::error!("ERROR: {}", err),
234            }
235        }
236
237        sftp.close().await?;
238
239        Ok(())
240    }
241
242    /// Closes SSH session.
243    pub async fn close(&mut self) -> anyhow::Result<()> {
244        self.session
245            .disconnect(Disconnect::ByApplication, "", "English")
246            .await?;
247        Ok(())
248    }
249}