1#![doc = include_str!("../README.md")]
2
// Doctest-only item: the README is attached as this struct's documentation so that the code
// examples in the README are picked up when doctests are built. The struct does not exist in
// any other build configuration.
#[doc = include_str!("../README.md")]
#[cfg(doctest)]
pub struct ReadmeDoctests;
6
// Fail the build with an explicit message when neither ssh backend feature is selected.
#[cfg(not(any(feature = "libssh", feature = "ssh2")))]
compile_error!("Either feature \"libssh\" or \"ssh2\" must be enabled for this crate.");
9
10use std::collections::BTreeMap;
11use std::fmt;
12use std::io::{self, Write};
13use std::net::{IpAddr, SocketAddr};
14use std::path::PathBuf;
15use std::str::FromStr;
16use std::time::Duration;
17
18use async_compat::CompatExt;
19use async_trait::async_trait;
20use distant_core::net::auth::{AuthHandlerMap, DummyAuthHandler, Verifier};
21use distant_core::net::client::{Client, ClientConfig};
22use distant_core::net::common::{Host, InmemoryTransport, OneshotListener, Version};
23use distant_core::net::server::{Server, ServerRef};
24use distant_core::protocol::PROTOCOL_VERSION;
25use distant_core::{DistantApiServerHandler, DistantClient, DistantSingleKeyCredentials};
26use log::*;
27use smol::channel::Receiver as SmolReceiver;
28use tokio::sync::Mutex;
29use wezterm_ssh::{
30 ChildKiller, Config as WezConfig, MasterPty, PtySize, Session as WezSession,
31 SessionEvent as WezSessionEvent,
32};
33
34mod api;
35mod process;
36mod utils;
37
38use api::SshDistantApi;
39
/// Operating-system family of the machine on the remote end of an ssh session.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
pub enum SshFamily {
    /// The remote machine is treated as unix-like.
    Unix,

    /// The remote machine is treated as Windows.
    Windows,
}

impl SshFamily {
    /// Returns the lowercase name of the family: `"unix"` or `"windows"`.
    pub const fn as_static_str(&self) -> &'static str {
        if matches!(self, Self::Windows) {
            "windows"
        } else {
            "unix"
        }
    }
}
60
/// Ssh library used to drive the session.
///
/// A variant only exists when the cargo feature of the same name is enabled.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
pub enum SshBackend {
    /// Backend selected by the name `libssh`; requires the `libssh` feature.
    #[cfg(feature = "libssh")]
    LibSsh,

    /// Backend selected by the name `ssh2`; requires the `ssh2` feature.
    #[cfg(feature = "ssh2")]
    Ssh2,
}

impl SshBackend {
    /// Returns the lowercase name of the backend: `"libssh"` or `"ssh2"`.
    pub const fn as_static_str(&self) -> &'static str {
        match self {
            #[cfg(feature = "libssh")]
            Self::LibSsh => "libssh",

            #[cfg(feature = "ssh2")]
            Self::Ssh2 => "ssh2",
        }
    }
}
86
87impl Default for SshBackend {
88 fn default() -> Self {
93 #[cfg(feature = "ssh2")]
94 {
95 Self::Ssh2
96 }
97
98 #[cfg(not(feature = "ssh2"))]
99 {
100 Self::LibSsh
101 }
102 }
103}
104
105impl FromStr for SshBackend {
106 type Err = &'static str;
107
108 fn from_str(s: &str) -> Result<Self, Self::Err> {
109 match s {
110 #[cfg(feature = "ssh2")]
111 s if s.trim().eq_ignore_ascii_case("ssh2") => Ok(Self::Ssh2),
112
113 #[cfg(feature = "libssh")]
114 s if s.trim().eq_ignore_ascii_case("libssh") => Ok(Self::LibSsh),
115
116 _ => Err("SSH backend must be \"libssh\" or \"ssh2\""),
117 }
118 }
119}
120
121impl fmt::Display for SshBackend {
122 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
123 match self {
124 #[cfg(feature = "libssh")]
125 Self::LibSsh => write!(f, "libssh"),
126
127 #[cfg(feature = "ssh2")]
128 Self::Ssh2 => write!(f, "ssh2"),
129 }
130 }
131}
132
/// A single question asked of the user during ssh authentication.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SshAuthPrompt {
    /// Text of the prompt to present to the user; may contain several lines.
    pub prompt: String,

    /// Whether the answer may be visible while it is typed. The local handler reads the answer
    /// from stdin when this is `true` and uses a hidden password prompt when it is `false`.
    pub echo: bool,
}
144
/// An authentication request, passed to [`SshAuthHandler::on_authenticate`].
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SshAuthEvent {
    /// Username the authentication is for; may be empty.
    pub username: String,

    /// Free-form instructions to show to the user; may be empty.
    pub instructions: String,

    /// Prompts that need to be answered.
    pub prompts: Vec<SshAuthPrompt>,
}
159
/// Options applied on top of the default ssh configuration when connecting via [`Ssh::connect`].
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(default))]
pub struct SshOpts {
    /// Ssh backend to use; stored in the config as `wezterm_ssh_backend`.
    pub backend: SshBackend,

    /// Identity files; when non-empty they are joined with a single space and stored as the
    /// `identityfile` config value. Paths that are not valid UTF-8 are skipped.
    pub identity_files: Vec<PathBuf>,

    /// When set, stored as the `identitiesonly` config value (`yes` or `no`).
    pub identities_only: Option<bool>,

    /// When set, overrides the `port` config value.
    pub port: Option<u16>,

    /// When set, stored as the `proxycommand` config value.
    pub proxy_command: Option<String>,

    /// When set, overrides the `user` config value.
    pub user: Option<String>,

    /// Known-hosts files; when non-empty they are joined with a single space and stored as the
    /// `userknownhostsfile` config value. Paths that are not valid UTF-8 are skipped.
    pub user_known_hosts_files: Vec<PathBuf>,

    /// Stored in the config as `wezterm_ssh_verbose`.
    pub verbose: bool,

    /// Additional raw config entries. They are applied after every other option, so an entry
    /// here replaces a value set through one of the fields above.
    pub other: BTreeMap<String, String>,
}
204
/// Options controlling how a distant server is started on the remote machine.
#[derive(Clone, Debug)]
pub struct DistantLaunchOpts {
    /// Binary to execute on the remote machine.
    pub binary: String,

    /// Extra arguments for the server's listen command, given as one unsplit string.
    pub args: String,

    /// Maximum time to wait for the launched server's credentials and for the follow-up
    /// connection attempt.
    pub timeout: Duration,
}

impl Default for DistantLaunchOpts {
    /// Uses `distant` as the binary, no extra arguments, and a 15 second timeout.
    fn default() -> Self {
        let binary = "distant".to_owned();
        let args = String::default();
        let timeout = Duration::new(15, 0);

        Self {
            binary,
            args,
            timeout,
        }
    }
}
227
/// Callbacks invoked by [`Ssh::authenticate`] while an ssh session is being authenticated.
#[async_trait]
pub trait SshAuthHandler {
    /// Invoked when the session asks for authentication answers.
    ///
    /// Returns the answers to hand back to the session (presumably one per prompt, in prompt
    /// order, as the local handler does); an error aborts authentication.
    async fn on_authenticate(&self, event: SshAuthEvent) -> io::Result<Vec<String>>;

    /// Invoked when the session asks for the remote host to be verified; `host` is the
    /// verification message supplied by the session.
    ///
    /// Returns `true` to accept the host and `false` to reject it; an error aborts
    /// authentication.
    async fn on_verify_host(&self, host: &str) -> io::Result<bool>;

    /// Invoked with the banner text when the session reports a banner.
    async fn on_banner(&self, text: &str);

    /// Invoked with the error text when the session reports an error; authentication fails
    /// right after this callback returns.
    async fn on_error(&self, text: &str);
}
248
249pub struct LocalSshAuthHandler;
252
253#[async_trait]
254impl SshAuthHandler for LocalSshAuthHandler {
255 async fn on_authenticate(&self, event: SshAuthEvent) -> io::Result<Vec<String>> {
256 trace!("[local] on_authenticate({event:?})");
257 let task = tokio::task::spawn_blocking(move || {
258 if !event.username.is_empty() {
259 eprintln!("Authentication for {}", event.username);
260 }
261
262 if !event.instructions.is_empty() {
263 eprintln!("{}", event.instructions);
264 }
265
266 let mut answers = Vec::new();
267 for prompt in &event.prompts {
268 let mut prompt_lines = prompt.prompt.split('\n').collect::<Vec<_>>();
270
271 let prompt_line = prompt_lines.pop().unwrap();
273
274 for line in prompt_lines.into_iter() {
276 eprintln!("{line}");
277 }
278
279 let answer = if prompt.echo {
280 eprint!("{prompt_line}");
281 std::io::stderr().lock().flush()?;
282
283 let mut answer = String::new();
284 std::io::stdin().read_line(&mut answer)?;
285 answer
286 } else {
287 rpassword::prompt_password(prompt_line)?
288 };
289
290 answers.push(answer);
291 }
292 Ok(answers)
293 });
294
295 task.await
296 .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?
297 }
298
299 async fn on_verify_host(&self, host: &str) -> io::Result<bool> {
300 trace!("[local] on_verify_host({host})");
301 eprintln!("{host}");
302 let task = tokio::task::spawn_blocking(|| {
303 eprint!("Enter [y/N]> ");
304 std::io::stderr().lock().flush()?;
305
306 let mut answer = String::new();
307 std::io::stdin().read_line(&mut answer)?;
308
309 trace!("Verify? Answer = '{answer}'");
310 match answer.as_str().trim() {
311 "y" | "Y" | "yes" | "YES" => Ok(true),
312 _ => Ok(false),
313 }
314 });
315
316 task.await
317 .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?
318 }
319
320 async fn on_banner(&self, _text: &str) {
321 trace!("[local] on_banner({_text})");
322 }
323
324 async fn on_error(&self, _text: &str) {
325 trace!("[local] on_error({_text})");
326 }
327}
328
/// An ssh session to a remote machine, which must be authenticated before it can be used to
/// launch or act as a distant server.
pub struct Ssh {
    /// Underlying session handle.
    session: WezSession,
    /// Events emitted by the session; consumed while authenticating.
    events: SmolReceiver<WezSessionEvent>,
    /// Host given to [`Ssh::connect`].
    host: String,
    /// Port taken from the resolved ssh configuration.
    port: u16,
    /// Set once [`Ssh::authenticate`] has completed.
    authenticated: bool,

    /// Remote family, filled in by the first successful [`Ssh::detect_family`] call.
    cached_family: Mutex<Option<SshFamily>>,
}
340
341impl Ssh {
342 pub fn connect(host: impl AsRef<str>, opts: SshOpts) -> io::Result<Self> {
344 debug!(
345 "Establishing ssh connection to {} using {:?}",
346 host.as_ref(),
347 opts
348 );
349 let mut config = WezConfig::new();
350 config.add_default_config_files();
351
352 let mut config = config.for_host(host.as_ref());
354
355 if let Some(port) = opts.port.as_ref() {
357 config.insert("port".to_string(), port.to_string());
358 }
359 if let Some(user) = opts.user.as_ref() {
360 config.insert("user".to_string(), user.to_string());
361 }
362 if !opts.identity_files.is_empty() {
363 config.insert(
364 "identityfile".to_string(),
365 opts.identity_files
366 .iter()
367 .filter_map(|p| p.to_str())
368 .map(ToString::to_string)
369 .collect::<Vec<String>>()
370 .join(" "),
371 );
372 }
373 if let Some(yes) = opts.identities_only.as_ref() {
374 let value = if *yes {
375 "yes".to_string()
376 } else {
377 "no".to_string()
378 };
379 config.insert("identitiesonly".to_string(), value);
380 }
381 if let Some(cmd) = opts.proxy_command.as_ref() {
382 config.insert("proxycommand".to_string(), cmd.to_string());
383 }
384 if !opts.user_known_hosts_files.is_empty() {
385 config.insert(
386 "userknownhostsfile".to_string(),
387 opts.user_known_hosts_files
388 .iter()
389 .filter_map(|p| p.to_str())
390 .map(ToString::to_string)
391 .collect::<Vec<String>>()
392 .join(" "),
393 );
394 }
395
396 config.insert("wezterm_ssh_verbose".to_string(), opts.verbose.to_string());
398
399 config.insert("wezterm_ssh_backend".to_string(), opts.backend.to_string());
401
402 config.extend(opts.other);
404
405 let port = config
407 .get("port")
408 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "Missing port"))?
409 .parse::<u16>()
410 .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?;
411
412 trace!("WezSession::connect({:?})", config);
414 let (session, events) =
415 WezSession::connect(config).map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
416
417 Ok(Self {
418 session,
419 events,
420 host: host.as_ref().to_string(),
421 port,
422 authenticated: false,
423 cached_family: Mutex::new(None),
424 })
425 }
426
    /// Host this session was created for, exactly as given to [`Ssh::connect`].
    pub fn host(&self) -> &str {
        &self.host
    }

    /// Port taken from the ssh configuration when the session was created.
    pub fn port(&self) -> u16 {
        self.port
    }

    /// Whether [`Ssh::authenticate`] has completed for this session.
    #[inline]
    pub fn is_authenticated(&self) -> bool {
        self.authenticated
    }
441
442 pub async fn authenticate(&mut self, handler: impl SshAuthHandler) -> io::Result<()> {
444 if self.authenticated {
446 return Ok(());
447 }
448
449 while let Ok(event) = self.events.recv().await {
452 match event {
453 WezSessionEvent::Banner(banner) => {
454 trace!("ssh banner: {banner:?}");
455 if let Some(banner) = banner {
456 handler.on_banner(banner.as_ref()).await;
457 }
458 }
459 WezSessionEvent::HostVerify(verify) => {
460 trace!("ssh host verify: {verify:?}");
461 let verified = handler.on_verify_host(verify.message.as_str()).await?;
462 verify
463 .answer(verified)
464 .compat()
465 .await
466 .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
467 }
468 WezSessionEvent::Authenticate(mut auth) => {
469 trace!("ssh authenticate: {auth:?}");
470 let ev = SshAuthEvent {
471 username: auth.username.clone(),
472 instructions: auth.instructions.clone(),
473 prompts: auth
474 .prompts
475 .drain(..)
476 .map(|p| SshAuthPrompt {
477 prompt: p.prompt,
478 echo: p.echo,
479 })
480 .collect(),
481 };
482
483 let answers = handler.on_authenticate(ev).await?;
484 auth.answer(answers)
485 .compat()
486 .await
487 .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
488 }
489 WezSessionEvent::Error(err) => {
490 trace!("ssh error: {err:?}");
491 handler.on_error(&err).await;
492 return Err(io::Error::new(io::ErrorKind::PermissionDenied, err));
493 }
494 WezSessionEvent::Authenticated => {
495 trace!("ssh authenticated");
496 break;
497 }
498 }
499 }
500
501 self.authenticated = true;
503
504 Ok(())
505 }
506
507 pub async fn detect_family(&self) -> io::Result<SshFamily> {
511 if !self.authenticated {
513 return Err(io::Error::new(
514 io::ErrorKind::PermissionDenied,
515 "Not authenticated",
516 ));
517 }
518
519 let mut family = self.cached_family.lock().await;
520
521 if family.is_none() {
523 let is_windows = utils::is_windows(&self.session).await?;
526
527 *family = Some(if is_windows {
528 SshFamily::Windows
529 } else {
530 SshFamily::Unix
531 });
532 }
533
534 Ok(family.unwrap())
536 }
537
538 pub async fn launch_and_connect(self, opts: DistantLaunchOpts) -> io::Result<DistantClient> {
541 trace!("ssh::launch_and_colnnnect({:?})", opts);
542
543 if !self.authenticated {
545 return Err(io::Error::new(
546 io::ErrorKind::PermissionDenied,
547 "Not authenticated",
548 ));
549 }
550
551 let timeout = opts.timeout;
552
553 debug!("Looking up host {} @ port {}", self.host, self.port);
562 let mut candidate_ips = tokio::net::lookup_host(format!("{}:{}", self.host, self.port))
563 .await
564 .map_err(|x| {
565 io::Error::new(
566 x.kind(),
567 format!("{} needs to be resolvable outside of ssh: {}", self.host, x),
568 )
569 })?
570 .map(|addr| addr.ip())
571 .collect::<Vec<IpAddr>>();
572 candidate_ips.sort_unstable();
573 candidate_ips.dedup();
574 if candidate_ips.is_empty() {
575 return Err(io::Error::new(
576 io::ErrorKind::AddrNotAvailable,
577 format!("Unable to resolve {}:{}", self.host, self.port),
578 ));
579 }
580
581 let credentials = self.launch(opts).await?;
582 let key = credentials.key;
583
584 let mut err = None;
586 for ip in candidate_ips {
587 let addr = SocketAddr::new(ip, credentials.port);
588 debug!("Attempting to connect to distant server @ {}", addr);
589 match Client::tcp(addr)
590 .auth_handler(AuthHandlerMap::new().with_static_key(key.clone()))
591 .connect_timeout(timeout)
592 .version(Version::new(
593 PROTOCOL_VERSION.major,
594 PROTOCOL_VERSION.minor,
595 PROTOCOL_VERSION.patch,
596 ))
597 .connect()
598 .await
599 {
600 Ok(client) => return Ok(client),
601 Err(x) => err = Some(x),
602 }
603 }
604
605 Err(err.expect("Err set above"))
607 }
608
609 pub async fn launch(self, opts: DistantLaunchOpts) -> io::Result<DistantSingleKeyCredentials> {
612 trace!("ssh::launch({:?})", opts);
613
614 if !self.authenticated {
616 return Err(io::Error::new(
617 io::ErrorKind::PermissionDenied,
618 "Not authenticated",
619 ));
620 }
621
622 let family = self.detect_family().await?;
623 trace!("Detected family: {}", family.as_static_str());
624
625 let host = self
626 .host()
627 .parse::<Host>()
628 .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?;
629
630 let (mut pty, mut child) = self
631 .session
632 .request_pty("xterm-256color", PtySize::default(), None, None)
633 .compat()
634 .await
635 .map_err(utils::to_other_error)?;
636
637 let mut args = vec![
639 String::from("server"),
640 String::from("listen"),
641 String::from("--daemon"),
642 String::from("--host"),
643 String::from("ssh"),
644 ];
645 args.extend(match family {
646 SshFamily::Windows => winsplit::split(&opts.args),
647 SshFamily::Unix => shell_words::split(&opts.args)
648 .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?,
649 });
650
651 let cmd = format!("{} {}", opts.binary, args.join(" "));
653 debug!("Executing {cmd}");
654 pty.write_all(format!("{cmd}\r\n").as_bytes())?;
655
656 let credentials = {
658 let mut reader = pty.try_clone_reader().map_err(utils::to_other_error)?;
660 let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1);
661 let read_task = tokio::task::spawn_blocking(move || {
662 let mut buf = [0u8; 1024];
663 while let Ok(n) = reader.read(&mut buf) {
664 if n == 0 {
665 break;
666 }
667 let _ = tx.blocking_send(buf[..n].to_vec());
668 }
669 });
670
671 let start_instant = std::time::Instant::now();
679 let timeout = opts.timeout;
680 tokio::spawn(async move {
681 let mut stdout = Vec::new();
682 loop {
683 while let Ok(bytes) = rx.try_recv() {
685 trace!("Received {} more bytes over stdout", bytes.len());
686 stdout.extend_from_slice(&bytes);
687
688 if let Some(mut credentials) =
689 DistantSingleKeyCredentials::find_lax(&String::from_utf8_lossy(&stdout))
690 {
691 credentials.host = host;
692 read_task.abort();
693 return Ok(credentials);
694 }
695 }
696
697 if start_instant.elapsed() >= timeout {
699 stdout.retain(|b| {
702 b.is_ascii() && (b.is_ascii_whitespace() || !b.is_ascii_control())
703 });
704
705 read_task.abort();
706 return Err(io::Error::new(
707 io::ErrorKind::BrokenPipe,
708 format!(
709 "Failed to spawn server: '{}'",
710 shell_words::quote(&String::from_utf8_lossy(&stdout))
711 ),
712 ));
713 }
714
715 tokio::time::sleep(Duration::from_millis(50)).await;
717 }
718 })
719 };
720
721 trace!("Waiting for credentials to appear");
723 let credentials = credentials.await??;
724 debug!("Got credentials");
725
726 drop(pty);
728 let _ = child.kill();
729
730 Ok(credentials)
731 }
732
733 pub async fn into_distant_client(self) -> io::Result<DistantClient> {
736 Ok(self.into_distant_pair().await?.0)
737 }
738
739 pub async fn into_distant_pair(self) -> io::Result<(DistantClient, ServerRef)> {
741 if !self.authenticated {
743 return Err(io::Error::new(
744 io::ErrorKind::PermissionDenied,
745 "Not authenticated",
746 ));
747 }
748
749 let Self {
750 session: wez_session,
751 ..
752 } = self;
753
754 let (t1, t2) = InmemoryTransport::pair(1);
755 let server = Server::new()
756 .handler(DistantApiServerHandler::new(SshDistantApi::new(
757 wez_session,
758 )))
759 .verifier(Verifier::none())
760 .start(OneshotListener::from_value(t2))?;
761 let client = Client::build()
762 .auth_handler(DummyAuthHandler)
763 .config(ClientConfig::default().with_maximum_silence_duration())
764 .connector(t1)
765 .connect()
766 .await?;
767 Ok((client, server))
768 }
769}