1use std::io::{Read, Write};
2use std::net::TcpStream;
3use std::path::{Path, PathBuf};
4use std::time::Duration;
5
6use base64::{Engine as _, engine::general_purpose::STANDARD};
7use miette::{Result, miette};
8use ssh2::{CheckResult, HostKeyType, KnownHostFileKind, Session};
9
10use crate::ui;
11
12#[derive(Clone)]
13pub struct SshSession {
14 pub session: Session,
15 pub user: String,
16 pub host: String,
17 pub port: u16,
18 pub password: Option<String>,
19 is_container: bool,
20}
21
22impl SshSession {
23 pub fn new(ssh_target: &str, port: Option<u16>, key_path: Option<&PathBuf>) -> Result<Self> {
28 let (user, host, parsed_port) = parse_ssh_target(ssh_target)?;
29 let port = port.or(parsed_port).unwrap_or(22);
31
32 let tcp = TcpStream::connect((host.as_str(), port))
33 .map_err(|e| miette!("Failed to connect to {host}:{port}: {e}"))?;
34
35 let mut session =
36 Session::new().map_err(|e| miette!("Failed to create SSH session: {e}"))?;
37
38 session.set_tcp_stream(tcp);
39 session
40 .handshake()
41 .map_err(|e| miette!("SSH handshake failed: {e}"))?;
42
43 verify_host_key(&session, &host, port)?;
47
48 if let Some(key_path) = key_path {
49 if session
50 .userauth_pubkey_file(&user, None, key_path, None)
51 .is_ok()
52 {
53 ui::info(&format!(
54 "Authenticated with SSH key: {}",
55 key_path.display()
56 ));
57 } else {
58 return Err(miette!(
59 "Failed to authenticate with provided key: {}",
60 key_path.display()
61 ));
62 }
63 } else if session.userauth_agent(&user).is_ok() {
64 ui::info("Authenticated via SSH agent");
65 } else {
66 let mut authenticated = false;
67 for key_path in find_ssh_keys() {
68 if session
69 .userauth_pubkey_file(&user, None, &key_path, None)
70 .is_ok()
71 {
72 ui::info(&format!(
73 "Authenticated with SSH key: {}",
74 key_path.display()
75 ));
76 authenticated = true;
77 break;
78 }
79 }
80
81 if !authenticated {
82 return Err(miette!(
83 "SSH authentication failed. Please ensure you have a valid SSH key configured"
84 ));
85 }
86 }
87
88 let mut ssh = Self {
89 session,
90 user,
91 host,
92 port,
93 password: None,
94 is_container: false,
95 };
96
97 ssh.password = ssh.test_sudo()?;
98
99 ssh.is_container = ssh
102 .execute_command_raw("[ -f /run/.containerenv ] || [ -f /.dockerenv ]", None)
103 .is_ok();
104
105 Ok(ssh)
106 }
107
108 pub fn exec(&self, command: &str) -> Result<String> {
116 if let Some(password) = &self.password
117 && command.starts_with("sudo ")
118 {
119 return self.execute_command_with_sudo(command, password, None);
120 }
121
122 self.execute_command_raw(command, None)
123 }
124
125 pub fn exec_stream(&self, command: &str) -> Result<i32> {
131 let (command, sudo_password) = if self.password.is_some() && command.starts_with("sudo ") {
135 (wrap_sudo_command(command), self.password.clone())
136 } else {
137 (command.to_string(), None)
138 };
139
140 let mut channel = self
141 .session
142 .channel_session()
143 .map_err(|e| miette!("Failed to open channel: {e}"))?;
144
145 channel
146 .request_pty("xterm", None, None)
147 .map_err(|e| miette!("Failed to request PTY: {e}"))?;
148
149 channel
150 .exec(&command)
151 .map_err(|e| miette!("Failed to execute command: {e}"))?;
152
153 if let Some(password) = sudo_password {
154 channel
155 .write_all(format!("{password}\n").as_bytes())
156 .map_err(|e| miette!("Failed to send sudo password: {e}"))?;
157 channel.flush().ok();
158 }
159
160 let _raw_guard = RawModeGuard::enter()
164 .map_err(|e| miette!("Failed to enable raw terminal mode: {e}"))?;
165
166 self.session.set_blocking(false);
168
169 let mut buf = [0u8; 4096];
170 let mut stdin_buf = [0u8; 256];
171 let mut stdout = std::io::stdout();
172 let stdin_fd = libc::STDIN_FILENO;
173
174 loop {
175 if channel.eof() {
176 break;
177 }
178
179 match channel.read(&mut buf) {
181 Ok(0) => break,
182 Ok(n) => {
183 stdout.write_all(&buf[..n]).ok();
184 stdout.flush().ok();
185 }
186 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
187 Err(e) => {
188 self.session.set_blocking(true);
189 return Err(miette!("Failed to read from channel: {e}"));
190 }
191 }
192
193 let mut pollfd = libc::pollfd {
195 fd: stdin_fd,
196 events: libc::POLLIN,
197 revents: 0,
198 };
199
200 let poll_result = unsafe { libc::poll(&raw mut pollfd, 1, 0) };
201
202 if poll_result > 0 && (pollfd.revents & libc::POLLIN) != 0 {
203 let n =
204 unsafe { libc::read(stdin_fd, stdin_buf.as_mut_ptr().cast(), stdin_buf.len()) };
205
206 if n > 0 {
207 #[allow(clippy::cast_sign_loss)]
208 channel.write_all(&stdin_buf[..n as usize]).ok();
209 channel.flush().ok();
210 }
211 }
212
213 std::thread::sleep(std::time::Duration::from_millis(10));
214 }
215
216 self.session.set_blocking(true);
217
218 channel.wait_close().ok();
219 let exit_status = channel.exit_status().unwrap_or(-1);
220
221 Ok(exit_status)
222 }
223
224 #[allow(dead_code)]
225 pub fn exec_timeout(&self, command: &str, timeout: Duration) -> Result<String> {
233 if let Some(password) = &self.password
234 && command.starts_with("sudo ")
235 {
236 return self.execute_command_with_sudo(command, password, Some(timeout));
237 }
238
239 self.execute_command_raw(command, Some(timeout))
240 }
241
242 pub fn upload_file(&self, local_path: &std::path::Path, remote_path: &str) -> Result<()> {
250 const CHUNK_SIZE: usize = 8192;
251
252 let file_size = std::fs::metadata(local_path)
253 .map_err(|e| miette!("Failed to get file metadata: {}", e))?
254 .len();
255
256 let pb = ui::progress_bar(
257 file_size,
258 &format!(
259 "Uploading {}",
260 local_path.file_name().unwrap().to_string_lossy()
261 ),
262 );
263
264 let mut file = std::fs::File::open(local_path)
265 .map_err(|e| miette!("Failed to open file {}: {}", local_path.display(), e))?;
266
267 let mut file_data = Vec::new();
268 file.read_to_end(&mut file_data)
269 .map_err(|e| miette!("Failed to read file: {e}"))?;
270
271 let mut channel = self
272 .session
273 .scp_send(
274 std::path::Path::new(remote_path),
275 0o755, file_data.len() as u64,
277 None,
278 )
279 .map_err(|e| miette!("Failed to create SCP channel: {e}"))?;
280
281 for chunk in file_data.chunks(CHUNK_SIZE) {
282 channel
283 .write_all(chunk)
284 .map_err(|e| miette!("Failed to write file data: {e}"))?;
285 pb.inc(chunk.len() as u64);
286 }
287
288 channel
289 .send_eof()
290 .map_err(|e| miette!("Failed to send EOF: {e}"))?;
291
292 channel
293 .wait_eof()
294 .map_err(|e| miette!("Failed to wait for EOF: {e}"))?;
295
296 channel
297 .close()
298 .map_err(|e| miette!("Failed to close SCP channel: {e}"))?;
299
300 channel
301 .wait_close()
302 .map_err(|e| miette!("Failed to wait for channel close: {e}"))?;
303
304 pb.finish_with_message(format!(
305 "✓ Uploaded {}",
306 local_path.file_name().unwrap().to_string_lossy()
307 ));
308
309 Ok(())
310 }
311
312 pub(crate) fn is_container(&self) -> bool {
313 self.is_container
314 }
315
316 fn test_sudo(&self) -> Result<Option<String>> {
317 if self.execute_command_raw("sudo -n true", None).is_ok() {
319 return Ok(None);
320 }
321
322 let password = ui::password(&format!("[sudo] password for {}", self.user))?;
323
324 match self.execute_command_with_sudo("true", &password, None) {
325 Ok(_) => Ok(Some(password)),
326 Err(_) => Err(miette!("Invalid sudo password")),
327 }
328 }
329
330 fn execute_command_raw(&self, command: &str, timeout: Option<Duration>) -> Result<String> {
331 self.execute_command_raw_with_stdin(command, None, timeout)
332 }
333
334 fn execute_command_raw_with_stdin(
335 &self,
336 command: &str,
337 stdin: Option<&str>,
338 timeout: Option<Duration>,
339 ) -> Result<String> {
340 let session = &self.session;
341
342 if let Some(timeout) = timeout {
343 session.set_timeout(
344 u32::try_from(timeout.as_millis()).map_err(|e| miette!("Invalid timeout: {e}"))?,
345 );
346 }
347
348 let mut channel = session
349 .channel_session()
350 .map_err(|e| miette!("Failed to open channel: {e}"))?;
351
352 channel
353 .exec(command)
354 .map_err(|e| miette!("Failed to execute command '{command}': {e}"))?;
355
356 if let Some(data) = stdin {
357 channel
358 .write_all(data.as_bytes())
359 .map_err(|e| miette!("Failed to write to command stdin: {e}"))?;
360 channel
361 .send_eof()
362 .map_err(|e| miette!("Failed to send EOF: {e}"))?;
363 }
364
365 let mut output = String::new();
366 channel
367 .read_to_string(&mut output)
368 .map_err(|e| miette!("Failed to read command output: {e}"))?;
369
370 let mut stderr = String::new();
371 channel
372 .stderr()
373 .read_to_string(&mut stderr)
374 .map_err(|e| miette!("Failed to read stderr: {e}"))?;
375
376 channel
377 .wait_close()
378 .map_err(|e| miette!("Failed to close channel: {e}"))?;
379
380 let exit_status = channel
381 .exit_status()
382 .map_err(|e| miette!("Failed to get exit status: {e}"))?;
383
384 if timeout.is_some() {
385 session.set_timeout(0);
386 }
387
388 if exit_status == 0 {
389 Ok(output)
390 } else {
391 let error_msg = if !stderr.is_empty() {
392 stderr.trim()
393 } else if !output.is_empty() {
394 output.trim()
395 } else {
396 "Command failed with no output"
397 };
398 Err(miette!(
399 "Command '{command}' failed with exit code {exit_status}: {error_msg}"
400 ))
401 }
402 }
403
404 fn execute_command_with_sudo(
405 &self,
406 command: &str,
407 password: &str,
408 timeout: Option<Duration>,
409 ) -> Result<String> {
410 let sudo_command = wrap_sudo_command(command);
413 self.execute_command_raw_with_stdin(&sudo_command, Some(&format!("{password}\n")), timeout)
414 }
415}
416
417pub(crate) fn parse_ssh_target(target: &str) -> Result<(String, String, Option<u16>)> {
418 let (user, host_port) = target
419 .split_once('@')
420 .ok_or_else(|| miette!("Invalid SSH target format. Expected user@host[:port]"))?;
421
422 if user.is_empty() {
423 return Err(miette!("User cannot be empty"));
424 }
425
426 let (host, port) = if let Some(rest) = host_port.strip_prefix('[') {
427 let (addr, after) = rest
429 .split_once(']')
430 .ok_or_else(|| miette!("Invalid IPv6 SSH target: missing ']'"))?;
431 let port = match after.strip_prefix(':') {
432 Some(p) => Some(p.parse::<u16>().map_err(|_| miette!("Invalid port: {p}"))?),
433 None if after.is_empty() => None,
434 None => {
435 return Err(miette!(
436 "Unexpected characters after IPv6 address: {after:?}"
437 ));
438 }
439 };
440 (addr.to_string(), port)
441 } else if let Some((h, p)) = host_port.rsplit_once(':')
442 && !h.contains(':')
443 {
444 (
447 h.to_string(),
448 Some(p.parse::<u16>().map_err(|_| miette!("Invalid port: {p}"))?),
449 )
450 } else {
451 (host_port.to_string(), None)
453 };
454
455 if host.is_empty() {
456 return Err(miette!("Host cannot be empty"));
457 }
458
459 Ok((user.to_string(), host, port))
460}
461
462fn known_hosts_path() -> Option<PathBuf> {
468 if let Ok(path) = std::env::var("MAKIATTO_KNOWN_HOSTS") {
469 return Some(PathBuf::from(path));
470 }
471 dirs::home_dir().map(|home| home.join(".ssh").join("known_hosts"))
472}
473
474fn verify_host_key(session: &Session, host: &str, port: u16) -> Result<()> {
478 let Some(kh_path) = known_hosts_path() else {
479 return Err(miette!(
480 "Cannot determine known_hosts location for host key verification"
481 ));
482 };
483
484 let mut known_hosts = session
485 .known_hosts()
486 .map_err(|e| miette!("Failed to initialise known_hosts: {e}"))?;
487
488 if kh_path.exists() {
489 known_hosts
490 .read_file(&kh_path, KnownHostFileKind::OpenSSH)
491 .map_err(|e| miette!("Failed to read {}: {e}", kh_path.display()))?;
492 }
493
494 let (key, key_type) = session
495 .host_key()
496 .ok_or_else(|| miette!("Server did not present a host key"))?;
497
498 match known_hosts.check_port(host, port, key) {
499 CheckResult::Match => Ok(()),
500 CheckResult::Mismatch => Err(miette!(
501 "SSH host key mismatch for {host}:{port} — possible machine-in-the-middle. \
502 If the host key legitimately changed, remove the stale entry from {}.",
503 kh_path.display()
504 )),
505 CheckResult::Failure => Err(miette!("Host key verification failed for {host}:{port}")),
506 CheckResult::NotFound => {
507 append_known_host(&kh_path, host, port, key, key_type)?;
508 ui::info(&format!(
509 "Trusting new host key for {host}:{port} (added to {})",
510 kh_path.display()
511 ));
512 Ok(())
513 }
514 }
515}
516
/// Rewrites `command` so that sudo reads the password from stdin without printing a prompt.
///
/// A single leading `sudo ` is replaced by `sudo -S -p '' `. A command without that
/// prefix is wrapped as a whole. Only the first `sudo` is rewritten; later ones in a
/// compound command are left untouched.
fn wrap_sudo_command(command: &str) -> String {
    const WRAPPED_PREFIX: &str = "sudo -S -p '' ";

    let body = match command.strip_prefix("sudo ") {
        Some(rest) => rest,
        None => command,
    };

    let mut wrapped = String::with_capacity(WRAPPED_PREFIX.len() + body.len());
    wrapped.push_str(WRAPPED_PREFIX);
    wrapped.push_str(body);
    wrapped
}
527
528fn known_host_line(host: &str, port: u16, key: &[u8], key_type: HostKeyType) -> Option<String> {
533 let key_type_str = match key_type {
534 HostKeyType::Rsa => "ssh-rsa",
535 HostKeyType::Dss => "ssh-dss",
536 HostKeyType::Ecdsa256 => "ecdsa-sha2-nistp256",
537 HostKeyType::Ecdsa384 => "ecdsa-sha2-nistp384",
538 HostKeyType::Ecdsa521 => "ecdsa-sha2-nistp521",
539 HostKeyType::Ed25519 => "ssh-ed25519",
540 HostKeyType::Unknown => return None,
541 };
542
543 let host_field = if port == 22 {
544 host.to_string()
545 } else {
546 format!("[{host}]:{port}")
547 };
548
549 Some(format!(
550 "{host_field} {key_type_str} {}",
551 STANDARD.encode(key)
552 ))
553}
554
555fn append_known_host(
557 path: &Path,
558 host: &str,
559 port: u16,
560 key: &[u8],
561 key_type: HostKeyType,
562) -> Result<()> {
563 let line = known_host_line(host, port, key, key_type)
564 .ok_or_else(|| miette!("Unknown host key type; refusing to record it"))?;
565
566 if let Some(parent) = path.parent() {
567 std::fs::create_dir_all(parent)
568 .map_err(|e| miette!("Failed to create {}: {e}", parent.display()))?;
569 }
570
571 let mut file = std::fs::OpenOptions::new()
572 .create(true)
573 .append(true)
574 .open(path)
575 .map_err(|e| miette!("Failed to open {}: {e}", path.display()))?;
576
577 file.write_all(format!("{line}\n").as_bytes())
578 .map_err(|e| miette!("Failed to write to {}: {e}", path.display()))?;
579
580 Ok(())
581}
582
583struct RawModeGuard {
585 original: libc::termios,
586}
587
588impl RawModeGuard {
589 fn enter() -> std::io::Result<Self> {
590 unsafe {
591 let mut original: libc::termios = std::mem::zeroed();
592 if libc::tcgetattr(libc::STDIN_FILENO, &raw mut original) != 0 {
593 return Err(std::io::Error::last_os_error());
594 }
595
596 let mut raw = original;
597 libc::cfmakeraw(&raw mut raw);
598 raw.c_lflag |= libc::ISIG;
600
601 if libc::tcsetattr(libc::STDIN_FILENO, libc::TCSANOW, &raw const raw) != 0 {
602 return Err(std::io::Error::last_os_error());
603 }
604
605 Ok(Self { original })
606 }
607 }
608}
609
610impl Drop for RawModeGuard {
611 fn drop(&mut self) {
612 unsafe {
613 libc::tcsetattr(libc::STDIN_FILENO, libc::TCSANOW, &raw const self.original);
614 }
615 }
616}
617
618fn find_ssh_keys() -> Vec<PathBuf> {
619 let mut keys = Vec::new();
620
621 if let Some(home_dir) = dirs::home_dir() {
622 let key_names = ["id_ed25519", "id_rsa", "id_dsa", "id_ecdsa"];
623
624 for key_name in &key_names {
625 let key_path = home_dir.join(".ssh").join(key_name);
626 if key_path.exists() {
627 keys.push(key_path);
628 }
629 }
630 }
631
632 keys
633}
634
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `(user, host, port)` triple that `parse_ssh_target` returns.
    fn target(user: &str, host: &str, port: Option<u16>) -> (String, String, Option<u16>) {
        (user.to_owned(), host.to_owned(), port)
    }

    #[test]
    fn test_parse_ssh_target_without_port() {
        let parsed = parse_ssh_target("root@192.168.1.1").unwrap();
        assert_eq!(parsed, target("root", "192.168.1.1", None));
    }

    #[test]
    fn test_parse_ssh_target_with_port() {
        let parsed = parse_ssh_target("root@192.168.1.1:2222").unwrap();
        assert_eq!(parsed, target("root", "192.168.1.1", Some(2222)));
    }

    #[test]
    fn test_parse_ssh_target_ipv6() {
        let cases = [
            ("root@[2001:db8::1]:2222", target("root", "2001:db8::1", Some(2222))),
            ("root@[::1]", target("root", "::1", None)),
            // Unbracketed IPv6 literals cannot carry a port, so the whole thing is the host.
            ("root@2001:db8::1", target("root", "2001:db8::1", None)),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_ssh_target(input).unwrap(), expected, "input: {input}");
        }
    }

    #[test]
    fn test_parse_ssh_target_invalid_format() {
        for bad in ["invalid", "@host", "user@", "user@host:notaport"] {
            assert!(parse_ssh_target(bad).is_err(), "expected {bad:?} to be rejected");
        }
    }

    #[test]
    fn test_wrap_sudo_command_strips_leading_sudo() {
        let wrapped = wrap_sudo_command("sudo systemctl restart makiatto");
        assert_eq!(wrapped, "sudo -S -p '' systemctl restart makiatto");
    }

    #[test]
    fn test_wrap_sudo_command_wraps_bare_command() {
        let wrapped = wrap_sudo_command("true");
        assert_eq!(wrapped, "sudo -S -p '' true");
    }

    #[test]
    fn test_wrap_sudo_command_only_first_sudo_in_compound() {
        // Later `sudo`s in a compound command are left as-is.
        let wrapped = wrap_sudo_command("sudo apt update && sudo apt install -y x");
        assert_eq!(wrapped, "sudo -S -p '' apt update && sudo apt install -y x");
    }

    #[test]
    fn test_known_host_line_default_port() {
        // base64 of [0, 1, 2] is "AAEC".
        let line = known_host_line("example.com", 22, b"\x00\x01\x02", HostKeyType::Rsa);
        assert_eq!(line.as_deref(), Some("example.com ssh-rsa AAEC"));
    }

    #[test]
    fn test_known_host_line_custom_port_is_bracketed() {
        let line = known_host_line("10.0.0.1", 2222, b"\x00\x01\x02", HostKeyType::Ed25519);
        assert_eq!(line.as_deref(), Some("[10.0.0.1]:2222 ssh-ed25519 AAEC"));
    }

    #[test]
    fn test_known_host_line_unknown_type_is_rejected() {
        assert_eq!(known_host_line("h", 22, b"abc", HostKeyType::Unknown), None);
    }

    #[test]
    fn test_known_hosts_path_honours_env_override() {
        const VAR: &str = "MAKIATTO_KNOWN_HOSTS";
        let saved = std::env::var(VAR).ok();

        // SAFETY: mutating the environment is only sound without concurrent access from
        // other threads; no other test in this module reads or writes the environment.
        unsafe { std::env::set_var(VAR, "/tmp/custom_known_hosts") };
        assert_eq!(
            known_hosts_path(),
            Some(PathBuf::from("/tmp/custom_known_hosts"))
        );

        // Put the variable back exactly as it was.
        if let Some(value) = saved {
            unsafe { std::env::set_var(VAR, value) };
        } else {
            unsafe { std::env::remove_var(VAR) };
        }
    }
}