hex_patch/app/ssh/connection.rs

1use std::{error::Error, fmt::Display, path::PathBuf, sync::Arc};
2
3use russh::client::{self, AuthResult, Handler};
4use russh::keys::key::PrivateKeyWithHashAlg;
5use russh_sftp::client::SftpSession;
6
7use crate::app::files::path;
8
9pub struct SSHClient;
10impl Handler for SSHClient {
11    type Error = russh::Error;
12
13    async fn check_server_key(
14        &mut self,
15        _server_public_key: &russh::keys::ssh_key::PublicKey,
16    ) -> Result<bool, Self::Error> {
17        Ok(true)
18    }
19}
20
/// A blocking SFTP connection to a remote host.
///
/// All remote operations are async under the hood; each public method drives them to
/// completion with `runtime.block_on`, so callers see a synchronous API.
pub struct Connection {
    // Current-thread Tokio runtime (built in `Connection::new`) used to run every SFTP call.
    // NOTE(review): fields drop in declaration order, so this runtime is dropped before
    // `sftp` — confirm the SFTP session does not need the runtime while being dropped.
    runtime: tokio::runtime::Runtime,
    // SFTP session opened over the SSH "sftp" subsystem.
    sftp: SftpSession,
    // The connection string passed to `Connection::new`; shown by the `Display` impl.
    connection_str: String,
}
26
27impl Connection {
28    fn get_key_files() -> Result<(PathBuf, PathBuf), String> {
29        let home_dir = dirs::home_dir().ok_or_else(|| "Home directory not found".to_string())?;
30
31        let ssh_dir = home_dir.join(".ssh");
32        if !ssh_dir.is_dir() {
33            return Err("SSH directory not found".into());
34        }
35        if ssh_dir.join("id_rsa").is_file() {
36            Ok((ssh_dir.join("id_rsa"), ssh_dir.join("id_rsa.pub")))
37        } else if ssh_dir.join("id_ed25519").is_file() {
38            Ok((ssh_dir.join("id_ed25519"), ssh_dir.join("id_ed25519.pub")))
39        } else if ssh_dir.join("id_ecdsa").is_file() {
40            Ok((ssh_dir.join("id_ecdsa"), ssh_dir.join("id_ecdsa.pub")))
41        } else if ssh_dir.join("id_dsa").is_file() {
42            Ok((ssh_dir.join("id_dsa"), ssh_dir.join("id_dsa.pub")))
43        } else {
44            Err("No private key found".into())
45        }
46    }
47
48    pub fn new(connection_str: &str, password: Option<&str>) -> Result<Self, Box<dyn Error>> {
49        let runtime = tokio::runtime::Builder::new_current_thread()
50            .enable_all()
51            .build()?;
52        let (username, host) = connection_str
53            .split_once('@')
54            .ok_or_else(|| Box::<dyn Error>::from("Invalid connection string"))?;
55
56        let (hostname, port) =
57            host.split_once(':')
58                .map_or(Ok((host, 22)), |(hostname, port)| {
59                    port.parse::<u16>()
60                        .map(|port| (hostname, port))
61                        .map_err(|_| Box::<dyn Error>::from("Invalid port"))
62                })?;
63
64        let config = client::Config::default();
65
66        let mut session = runtime.block_on(client::connect(
67            config.into(),
68            (hostname, port),
69            SSHClient {},
70        ))?;
71        if let Some(password) = password {
72            if let AuthResult::Failure {
73                remaining_methods: _,
74            } = runtime.block_on(session.authenticate_password(username, password))?
75            {
76                return Err("Authentication failed".into());
77            }
78        } else {
79            let (private_key, _public_key) = Self::get_key_files()?;
80            let keypair = russh::keys::load_secret_key(private_key, None)?;
81            let keypair = PrivateKeyWithHashAlg::new(Arc::new(keypair), None);
82            if let AuthResult::Failure {
83                remaining_methods: _,
84            } = runtime.block_on(session.authenticate_publickey(username, keypair))?
85            {
86                return Err("Authentication failed".into());
87            }
88        }
89
90        let channel = runtime.block_on(session.channel_open_session())?;
91        runtime.block_on(channel.request_subsystem(true, "sftp"))?;
92
93        let sftp = runtime.block_on(SftpSession::new(channel.into_stream()))?;
94
95        Ok(Self {
96            runtime,
97            sftp,
98            connection_str: connection_str.to_string(),
99        })
100    }
101
102    pub fn separator(&self) -> char {
103        match self.runtime.block_on(self.sftp.canonicalize("/")) {
104            Ok(_) => '/',
105            Err(_) => '\\',
106        }
107    }
108
109    pub fn canonicalize(&self, path: &str) -> Result<String, Box<dyn Error>> {
110        Ok(self.runtime.block_on(self.sftp.canonicalize(path))?)
111    }
112
113    pub fn read(&self, path: &str) -> Result<Vec<u8>, Box<dyn Error>> {
114        let remote_file = self.runtime.block_on(self.sftp.read(path))?;
115        Ok(remote_file)
116    }
117
118    pub fn mkdirs(&self, path: &str) -> Result<(), Box<dyn Error>> {
119        self.runtime.block_on(async {
120            let mut paths = vec![path];
121            let mut current = path;
122            while let Some(parent) = path::parent(current) {
123                paths.push(parent);
124                current = parent;
125            }
126            paths.reverse();
127            for path in paths {
128                if self.sftp.read_dir(path).await.is_ok() {
129                    continue;
130                };
131                self.sftp.create_dir(path).await?;
132            }
133            Ok::<(), Box<dyn Error>>(())
134        })?;
135        Ok(())
136    }
137
138    pub fn create(&self, path: &str) -> Result<(), Box<dyn Error>> {
139        self.runtime.block_on(self.sftp.create(path))?;
140        Ok(())
141    }
142
143    pub fn write(&self, path: &str, data: &[u8]) -> Result<(), Box<dyn Error>> {
144        self.runtime.block_on(self.sftp.write(path, data))?;
145        Ok(())
146    }
147
148    pub fn ls(&self, path: &str) -> Result<Vec<String>, Box<dyn Error>> {
149        let dir = self.runtime.block_on(self.sftp.read_dir(path))?;
150        dir.into_iter()
151            .map(|entry| Ok(path::join(path, &entry.file_name(), self.separator()).to_string()))
152            .collect()
153    }
154
155    pub fn is_file(&self, path: &str) -> bool {
156        self.runtime
157            .block_on(self.sftp.metadata(path))
158            .is_ok_and(|metadata| !metadata.is_dir())
159    }
160
161    pub fn is_dir(&self, path: &str) -> bool {
162        self.runtime
163            .block_on(self.sftp.metadata(path))
164            .is_ok_and(|metadata| metadata.is_dir())
165    }
166}
167
168impl Display for Connection {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        write!(f, "{}", self.connection_str)
171    }
172}