audb_core/tools/
ssh.rs

1use russh::client::Handle;
2use russh::client::{self};
3use russh::keys::ssh_key;
4use russh::keys::PrivateKeyWithHashAlg;
5use russh::{ChannelMsg, Preferred};
6use russh_sftp::client::SftpSession;
7use russh_sftp::protocol::OpenFlags;
8use std::borrow::Cow;
9use std::fs::File;
10use std::fs;
11use std::path::Path;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::io::AsyncWriteExt;
15use anyhow::{anyhow, Result};
16
17use crate::tools::shell_escape::escape_single_quote;
18
/// Remote account that [`SshClient::connect`] logs in as.
const DEFAULT_USER: &str = "defaultuser";

/// SSH client and `russh` connection handler for talking to devices.
///
/// The handler carries no state of its own; all session state lives in the
/// `russh` `Handle` returned by `connect`.
#[derive(Debug, Default, Clone)]
pub struct SshClient {}
22
23impl client::Handler for SshClient {
24    type Error = russh::Error;
25
26    async fn check_server_key(&mut self, _server_public_key: &ssh_key::PublicKey) -> Result<bool, Self::Error> {
27        Ok(true)
28    }
29}
30
31impl SshClient {
32    pub fn connect(
33        host: &str,
34        port: u16,
35        key_path: &Path,
36    ) -> Result<Handle<SshClient>> {
37        tokio::task::block_in_place(|| {
38            tokio::runtime::Handle::current().block_on(Self::_connect(host, port, key_path))
39        })
40    }
41
42    pub fn exec(
43        session: &mut Handle<SshClient>,
44        command: &str,
45    ) -> Result<Vec<String>> {
46        tokio::task::block_in_place(|| {
47            tokio::runtime::Handle::current().block_on(Self::_exec(session, command))
48        })
49    }
50
51    /// Execute command as root using devel-su (Aurora OS)
52    ///
53    /// Uses the `echo 'password' | devel-su sh -c 'command'` pattern to automate
54    /// devel-su password input through SSH exec channels.
55    ///
56    /// # Security
57    /// This function properly escapes the password and command to prevent shell injection.
58    /// Both parameters are escaped for use in single-quote contexts.
59    pub fn exec_as_devel_su(
60        session: &mut Handle<SshClient>,
61        command: &str,
62        password: &str,
63    ) -> Result<Vec<String>> {
64        if password.is_empty() {
65            return Err(anyhow!(
66                "Root password not configured. Use 'audb device add' to set the root password."
67            ));
68        }
69
70        // Escape password and command for single-quote context to prevent shell injection
71        let password_escaped = escape_single_quote(password);
72        let command_escaped = escape_single_quote(command);
73
74        // Use echo pipe pattern: echo 'password' | devel-su sh -c 'command'
75        let devel_su_command = format!(
76            "echo '{}' | devel-su sh -c '{}'",
77            password_escaped, command_escaped
78        );
79
80        Self::exec(session, &devel_su_command)
81    }
82
83    /// Read file contents as base64 string via SSH exec
84    /// Useful for reading files owned by root when used with exec_as_devel_su
85    pub fn read_file_base64(
86        session: &mut Handle<SshClient>,
87        remote_path: &Path,
88        password: &str,
89    ) -> Result<String> {
90        let command = format!("base64 {}", remote_path.display());
91        let output = Self::exec_as_devel_su(session, &command, password)?;
92
93        if output.is_empty() {
94            return Err(anyhow!("File is empty or could not be read"));
95        }
96
97        // Join lines and remove whitespace
98        Ok(output.join("").replace(['\n', '\r'], ""))
99    }
100
101    pub fn upload(
102        session: &mut Handle<SshClient>,
103        local_path: &Path,
104        remote_path: &Path,
105    ) -> Result<()> {
106        tokio::task::block_in_place(|| {
107            tokio::runtime::Handle::current().block_on(Self::_upload(session, local_path, remote_path))
108        })
109    }
110
111    pub fn download(
112        session: &mut Handle<SshClient>,
113        remote_path: &Path,
114        local_path: &Path,
115    ) -> Result<()> {
116        tokio::task::block_in_place(|| {
117            tokio::runtime::Handle::current().block_on(Self::_download(session, remote_path, local_path))
118        })
119    }
120
121    pub fn test_connection(
122        host: &str,
123        port: u16,
124        key_path: &Path,
125    ) -> bool {
126        match Self::connect(host, port, key_path) {
127            Ok(mut session) => {
128                Self::exec(&mut session, "echo test").is_ok()
129            }
130            Err(_) => false,
131        }
132    }
133
134    async fn _connect(
135        host: &str,
136        port: u16,
137        key_path: &Path,
138    ) -> Result<Handle<SshClient>> {
139        Self::_connect_with_user(DEFAULT_USER, host, port, key_path).await
140    }
141
142    async fn _connect_with_user(
143        user: &str,
144        host: &str,
145        port: u16,
146        key_path: &Path,
147    ) -> Result<Handle<SshClient>> {
148        let timeout_session = Duration::from_secs(30);
149        let timeout_connect = Duration::from_secs(5);
150        let config = client::Config {
151            inactivity_timeout: Some(timeout_session),
152            preferred: Preferred {
153                kex: Cow::Owned(vec![
154                    russh::kex::CURVE25519_PRE_RFC_8731,
155                    russh::kex::EXTENSION_SUPPORT_AS_CLIENT,
156                ]),
157                ..Default::default()
158            },
159            ..<_>::default()
160        };
161        let config = Arc::new(config);
162        let sh = SshClient {};
163        let mut session = match tokio::time::timeout(timeout_connect, client::connect(config, (host, port), sh)).await?
164        {
165            Ok(session) => session,
166            Err(err) => return Err(anyhow!("Connection error: {}", err)),
167        };
168        let secret_key = Arc::new(russh::keys::load_secret_key(key_path, None)?);
169        let key_pair = PrivateKeyWithHashAlg::new(secret_key, session.best_supported_rsa_hash().await?.flatten());
170        let result = session.authenticate_publickey(user, key_pair).await?;
171        if !result.success() {
172            return Err(anyhow!("Failed to authenticate via SSH as {}", user));
173        }
174        Ok(session)
175    }
176
177    async fn _exec(
178        session: &mut Handle<SshClient>,
179        command: &str,
180    ) -> Result<Vec<String>> {
181        let mut code = None;
182        let mut stdout: Vec<String> = vec![];
183        let mut stderr: Vec<String> = vec![];
184        let mut channel = session.channel_open_session().await?;
185        channel.exec(true, command).await?;
186        loop {
187            let Some(msg) = channel.wait().await else {
188                break;
189            };
190            match msg {
191                ChannelMsg::Data { ref data } => {
192                    match str::from_utf8(data.as_ref()) {
193                        Ok(out_line) => {
194                            let line = out_line.trim().to_string();
195                            stdout.push(line)
196                        },
197                        Err(_) => return Err(anyhow!("Failed to process SSH connection data")),
198                    };
199                }
200                ChannelMsg::ExtendedData { ref data, ext } => {
201                    // ext == 1 means stderr
202                    if ext == 1 {
203                        match str::from_utf8(data.as_ref()) {
204                            Ok(err_line) => {
205                                let line = err_line.trim().to_string();
206                                stderr.push(line)
207                            },
208                            Err(_) => return Err(anyhow!("Failed to process SSH stderr data")),
209                        };
210                    }
211                }
212                ChannelMsg::ExitStatus { exit_status } => {
213                    code = Some(exit_status);
214                }
215                _ => {}
216            }
217        }
218        if let Some(code) = code {
219            if code != 0 {
220                let error_msg = if !stderr.is_empty() {
221                    stderr.join("\n")
222                } else if !stdout.is_empty() {
223                    stdout.join("\n")
224                } else {
225                    format!("Command failed with exit code {}", code)
226                };
227                return Err(anyhow!("{}", error_msg));
228            }
229        }
230        Ok(stdout)
231    }
232
233    async fn _upload(
234        session: &mut Handle<SshClient>,
235        local_path: &Path,
236        remote_path: &Path,
237    ) -> Result<()> {
238        let sftp_session = Self::_sftp_session(session).await?;
239
240        let file = File::open(local_path)?;
241        let size = file.metadata()?.len();
242        if size == 0 {
243            return Err(anyhow!("File is empty"));
244        }
245
246        let mut sftp_file = sftp_session
247            .open_with_flags(
248                remote_path.to_string_lossy().to_string(),
249                OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE | OpenFlags::READ,
250            )
251            .await?;
252
253        let data = fs::read(local_path)?;
254        sftp_file.write_all(&data).await?;
255
256        Ok(())
257    }
258
259
260    async fn _download(
261        session: &mut Handle<SshClient>,
262        remote_path: &Path,
263        local_path: &Path,
264    ) -> Result<()> {
265        let sftp_session = Self::_sftp_session(session).await?;
266
267        let mut sftp_file = sftp_session
268            .open_with_flags(
269                remote_path.to_string_lossy().to_string(),
270                OpenFlags::READ,
271            )
272            .await
273            .map_err(|e| anyhow!("Failed to open remote file {}: {}", remote_path.display(), e))?;
274
275        // Read file contents
276        use tokio::io::AsyncReadExt;
277        let mut data = Vec::new();
278        sftp_file.read_to_end(&mut data).await
279            .map_err(|e| anyhow!("Failed to read remote file: {}", e))?;
280
281        // Write to local file
282        fs::write(local_path, &data)
283            .map_err(|e| anyhow!("Failed to write local file {}: {}", local_path.display(), e))?;
284
285        Ok(())
286    }
287
288    async fn _sftp_session(session: &mut Handle<SshClient>) -> Result<SftpSession> {
289        let channel = session.channel_open_session().await?;
290        channel.request_subsystem(true, "sftp").await
291            .map_err(|e| anyhow!("Failed to request SFTP subsystem: {}", e))?;
292        Ok(SftpSession::new(channel.into_stream()).await?)
293    }
294}