// hex_patch/app/ssh/connection.rs
use std::{error::Error, fmt::Display, path::PathBuf, sync::Arc};

use russh::client::{self, AuthResult, Handler};
use russh::keys::key::PrivateKeyWithHashAlg;
use russh::keys::ssh_key::HashAlg;
use russh_sftp::client::SftpSession;

use crate::app::files::path;
8
9pub struct SSHClient;
10impl Handler for SSHClient {
11 type Error = russh::Error;
12
13 async fn check_server_key(
14 &mut self,
15 _server_public_key: &russh::keys::ssh_key::PublicKey,
16 ) -> Result<bool, Self::Error> {
17 Ok(true)
18 }
19}
20
21pub struct Connection {
22 runtime: tokio::runtime::Runtime,
23 sftp: SftpSession,
24 connection_str: String,
25}
26
27impl Connection {
28 fn get_key_files() -> Result<(PathBuf, PathBuf), String> {
29 let home_dir = dirs::home_dir().ok_or_else(|| "Home directory not found".to_string())?;
30
31 let ssh_dir = home_dir.join(".ssh");
32 if !ssh_dir.is_dir() {
33 return Err("SSH directory not found".into());
34 }
35 if ssh_dir.join("id_rsa").is_file() {
36 Ok((ssh_dir.join("id_rsa"), ssh_dir.join("id_rsa.pub")))
37 } else if ssh_dir.join("id_ed25519").is_file() {
38 Ok((ssh_dir.join("id_ed25519"), ssh_dir.join("id_ed25519.pub")))
39 } else if ssh_dir.join("id_ecdsa").is_file() {
40 Ok((ssh_dir.join("id_ecdsa"), ssh_dir.join("id_ecdsa.pub")))
41 } else if ssh_dir.join("id_dsa").is_file() {
42 Ok((ssh_dir.join("id_dsa"), ssh_dir.join("id_dsa.pub")))
43 } else {
44 Err("No private key found".into())
45 }
46 }
47
48 pub fn new(connection_str: &str, password: Option<&str>) -> Result<Self, Box<dyn Error>> {
49 let runtime = tokio::runtime::Builder::new_current_thread()
50 .enable_all()
51 .build()?;
52 let (username, host) = connection_str
53 .split_once('@')
54 .ok_or_else(|| Box::<dyn Error>::from("Invalid connection string"))?;
55
56 let (hostname, port) =
57 host.split_once(':')
58 .map_or(Ok((host, 22)), |(hostname, port)| {
59 port.parse::<u16>()
60 .map(|port| (hostname, port))
61 .map_err(|_| Box::<dyn Error>::from("Invalid port"))
62 })?;
63
64 let config = client::Config::default();
65
66 let mut session = runtime.block_on(client::connect(
67 config.into(),
68 (hostname, port),
69 SSHClient {},
70 ))?;
71 if let Some(password) = password {
72 if let AuthResult::Failure {
73 remaining_methods: _,
74 partial_success: _,
75 } = runtime.block_on(session.authenticate_password(username, password))?
76 {
77 return Err("Authentication failed".into());
78 }
79 } else {
80 let (private_key, _public_key) = Self::get_key_files()?;
81 let keypair = russh::keys::load_secret_key(private_key, None)?;
82 let keypair = PrivateKeyWithHashAlg::new(Arc::new(keypair), None);
83 if let AuthResult::Failure {
84 remaining_methods: _,
85 partial_success: _,
86 } = runtime.block_on(session.authenticate_publickey(username, keypair))?
87 {
88 return Err("Authentication failed".into());
89 }
90 }
91
92 let channel = runtime.block_on(session.channel_open_session())?;
93 runtime.block_on(channel.request_subsystem(true, "sftp"))?;
94
95 let sftp = runtime.block_on(SftpSession::new(channel.into_stream()))?;
96
97 Ok(Self {
98 runtime,
99 sftp,
100 connection_str: connection_str.to_string(),
101 })
102 }
103
104 pub fn separator(&self) -> char {
105 match self.runtime.block_on(self.sftp.canonicalize("/")) {
106 Ok(_) => '/',
107 Err(_) => '\\',
108 }
109 }
110
111 pub fn canonicalize(&self, path: &str) -> Result<String, Box<dyn Error>> {
112 Ok(self.runtime.block_on(self.sftp.canonicalize(path))?)
113 }
114
115 pub fn read(&self, path: &str) -> Result<Vec<u8>, Box<dyn Error>> {
116 let remote_file = self.runtime.block_on(self.sftp.read(path))?;
117 Ok(remote_file)
118 }
119
120 pub fn mkdirs(&self, path: &str) -> Result<(), Box<dyn Error>> {
121 self.runtime.block_on(async {
122 let mut paths = vec![path];
123 let mut current = path;
124 while let Some(parent) = path::parent(current) {
125 paths.push(parent);
126 current = parent;
127 }
128 paths.reverse();
129 for path in paths {
130 if self.sftp.read_dir(path).await.is_ok() {
131 continue;
132 };
133 self.sftp.create_dir(path).await?;
134 }
135 Ok::<(), Box<dyn Error>>(())
136 })?;
137 Ok(())
138 }
139
140 pub fn create(&self, path: &str) -> Result<(), Box<dyn Error>> {
141 self.runtime.block_on(self.sftp.create(path))?;
142 Ok(())
143 }
144
145 pub fn write(&self, path: &str, data: &[u8]) -> Result<(), Box<dyn Error>> {
146 self.runtime.block_on(self.sftp.write(path, data))?;
147 Ok(())
148 }
149
150 pub fn ls(&self, path: &str) -> Result<Vec<String>, Box<dyn Error>> {
151 let dir = self.runtime.block_on(self.sftp.read_dir(path))?;
152 dir.into_iter()
153 .map(|entry| Ok(path::join(path, &entry.file_name(), self.separator()).to_string()))
154 .collect()
155 }
156
157 pub fn is_file(&self, path: &str) -> bool {
158 self.runtime
159 .block_on(self.sftp.metadata(path))
160 .is_ok_and(|metadata| !metadata.is_dir())
161 }
162
163 pub fn is_dir(&self, path: &str) -> bool {
164 self.runtime
165 .block_on(self.sftp.metadata(path))
166 .is_ok_and(|metadata| metadata.is_dir())
167 }
168}
169
170impl Display for Connection {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 write!(f, "{}", self.connection_str)
173 }
174}