use std::{fs::File, io::Read, path::Path, sync::Arc};

use anyhow::Context;
use async_trait::async_trait;
use russh::{
    client::{self, Msg},
    keys::{decode_secret_key, key},
    Channel, ChannelId, ChannelMsg, Disconnect,
};
use russh_sftp::{client::SftpSession, protocol::OpenFlags};
use tokio::io::{AsyncReadExt, AsyncWriteExt};

use crate::util::{biject_paths, calc_prefix};
13
14pub const SSH_PORT: u16 = 22;
15
16pub struct ClientSSH;
17
18#[async_trait]
19impl client::Handler for ClientSSH {
20 type Error = anyhow::Error;
21
22 async fn check_server_key(
23 &mut self,
24 server_public_key: &key::PublicKey,
25 ) -> Result<bool, Self::Error> {
26 tracing::debug!("check_server_key: {:?}", server_public_key);
27 Ok(true)
28 }
29
30 async fn data(
31 &mut self,
32 channel: ChannelId,
33 data: &[u8],
34 _session: &mut client::Session,
35 ) -> Result<(), Self::Error> {
36 tracing::debug!("data on channel {:?}: {}", channel, data.len());
37 Ok(())
38 }
39}
40
41pub struct Session {
42 session: client::Handle<ClientSSH>,
43}
44
45impl Session {
46 pub async fn channel_open_session(&self) -> Result<Channel<Msg>, russh::Error> {
50 self.session.channel_open_session().await
51 }
52
53 pub fn load_secret_key<P: AsRef<Path>>(
55 secret_: P,
56 password: Option<&str>,
57 ) -> Result<key::KeyPair, anyhow::Error> {
58 let mut secret_file = std::fs::File::open(secret_)?;
59 let mut secret = String::new();
60 secret_file.read_to_string(&mut secret)?;
61 Ok(decode_secret_key(&secret, password)?)
62 }
63
64 pub async fn connect(
69 user: &str,
70 public_dns_name: String,
71 ssh_key: String,
72 ) -> anyhow::Result<Self> {
73 let config = russh::client::Config {
74 inactivity_timeout: Some(std::time::Duration::from_secs(1200)), ..<_>::default()
76 };
77 let mut session =
78 russh::client::connect(Arc::new(config), (public_dns_name, SSH_PORT), ClientSSH {})
79 .await
80 .expect("Failed to establish SSH connection with remote instance.");
81 let key_pair = Self::load_secret_key(ssh_key, None).unwrap();
82
83 session
84 .authenticate_publickey(user, Arc::new(key_pair))
85 .await?;
86
87 Ok(Self { session })
88 }
89
90 pub async fn exec(&self, command: &str) -> anyhow::Result<u32> {
92 let mut channel = self.channel_open_session().await?;
93
94 let (w, h) = termion::terminal_size()?;
96 channel
98 .request_pty(
99 false,
100 &std::env::var("TERM").unwrap_or("xterm".into()),
101 w as u32,
102 h as u32,
103 0,
104 0,
105 &[], )
107 .await?;
108
109 channel.exec(true, command).await?;
110
111 let mut stdin = tokio_fd::AsyncFd::try_from(0)?;
112 let mut stdout = tokio_fd::AsyncFd::try_from(1)?;
113 let mut stderr = tokio_fd::AsyncFd::try_from(2)?;
114
115 let code;
116 let mut buf = vec![0; 1024];
117 let mut stdin_closed = false;
118
119 loop {
120 tokio::select! {
121 r = stdin.read(&mut buf), if !stdin_closed => {
122 match r {
123 Ok(0) => {
124 stdin_closed = true;
125 channel.eof().await?;
126 },
127 Ok(n) => channel.data(&buf[..n]).await?,
129 Err(e) => return Err(e.into()),
130 };
131 },
132 Some(msg) = channel.wait() => {
133 match msg {
134 ChannelMsg::Data { ref data } => {
136 stdout.write_all(data).await?;
137 stdout.flush().await?;
138 }
139 ChannelMsg::ExitStatus { exit_status } => {
140 code = Some(exit_status);
141 if !stdin_closed {
142 channel.eof().await?;
143 }
144 break;
145 }
146 ChannelMsg::ExtendedData { ref data, ext: _ } => {
148 stderr.write_all(data).await?;
149 stderr.flush().await?;
150 }
151 _ => {}
152 }
153 },
154 }
155 }
156
157 Ok(code.expect("program did not exit cleanly"))
158 }
159
160 async fn open_sftp_session(&self) -> Result<SftpSession, russh_sftp::client::error::Error> {
161 let channel = self.session.channel_open_session().await.unwrap();
162 channel.request_subsystem(true, "sftp").await.unwrap();
163
164 SftpSession::new(channel.into_stream()).await
165 }
166
167 pub async fn upload(&self, src: Option<String>, dst: Option<String>) -> anyhow::Result<()> {
173 let src_path = match std::fs::canonicalize(src.unwrap_or(".".into())) {
174 Ok(pth) => pth,
175 Err(err) => anyhow::bail!("Failed to canonicalize src = {err}"),
177 };
178
179 let sftp = self.open_sftp_session().await?;
180
181 if dst.is_some() {
182 match sftp.metadata(dst.as_ref().unwrap_or(&".".into())).await {
183 Ok(attr) => {
184 if !attr.is_dir() {
185 anyhow::bail!("Dst must be a dir!");
186 }
187 }
188 Err(err) => {
189 tracing::error!("Error remote metadata = {err}");
190 return Ok(());
191 }
192 }
193 }
194
195 let prefix = calc_prefix(src_path.clone())?;
196 let dst_abs_path = sftp
197 .canonicalize(&dst.unwrap_or(".".into()))
198 .await
199 .expect("Failed to canonicalize remote dst.");
200
201 for result in biject_paths(
203 src_path.to_str().unwrap(),
204 prefix.to_str().unwrap_or(""),
205 &dst_abs_path,
206 ) {
207 match result {
208 Ok((local_pth, combined, is_dir)) => {
209 if is_dir {
210 let _ = sftp.create_dir(combined.to_str().unwrap().to_owned()).await;
211 } else {
212 let open_remote_file = sftp
213 .open_with_flags(
214 combined.to_str().unwrap(),
215 OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE,
216 )
217 .await;
218 if open_remote_file.is_err() {
219 tracing::warn!("Failed to open file = {:?}", combined,);
220 }
221
222 if let Ok(mut remote_file) = open_remote_file {
224 let mut local_file = File::open(local_pth).unwrap();
225 let mut buffer = Vec::new();
226 local_file.read_to_end(&mut buffer).unwrap();
227 remote_file.write_all(buffer.as_slice()).await.unwrap();
228 let _ = remote_file.sync_all().await;
229 remote_file.shutdown().await.unwrap();
230 }
231 }
232 }
233 Err(err) => tracing::error!("ERROR: {}", err),
234 }
235 }
236
237 sftp.close().await?;
238
239 Ok(())
240 }
241
242 pub async fn close(&mut self) -> anyhow::Result<()> {
244 self.session
245 .disconnect(Disconnect::ByApplication, "", "English")
246 .await?;
247 Ok(())
248 }
249}