1use anyhow::{Context, Result};
28use std::path::{Path, PathBuf};
29use std::time::Duration;
30use tokio::time::timeout;
31use zeroize::Zeroizing;
32
33use super::tokio_client::AuthMethod;
34
/// Upper bound on how long an interactive password or key-passphrase prompt is awaited
/// before the authentication attempt is abandoned.
const AUTH_PROMPT_TIMEOUT: Duration = Duration::from_secs(30);
37
/// Upper bound on the combined "connect to the SSH agent and list its identities" round trip.
#[cfg(not(target_os = "windows"))]
const AGENT_TIMEOUT: Duration = Duration::from_secs(5);
42
43#[cfg(not(target_os = "windows"))]
52async fn agent_has_identities() -> bool {
53 use russh::keys::agent::client::AgentClient;
54
55 let result = timeout(AGENT_TIMEOUT, async {
56 let mut agent = AgentClient::connect_env().await?;
57 agent.request_identities().await
58 })
59 .await;
60
61 match result {
62 Ok(Ok(identities)) => {
63 let has_keys = !identities.is_empty();
64 if has_keys {
65 tracing::debug!("SSH agent has {} loaded identities", identities.len());
66 } else {
67 tracing::debug!("SSH agent is running but has no loaded identities");
68 }
69 has_keys
70 }
71 Ok(Err(e)) => {
72 tracing::warn!("Failed to communicate with SSH agent: {e}");
73 false
74 }
75 Err(_) => {
76 tracing::warn!("SSH agent operation timed out after {:?}", AGENT_TIMEOUT);
77 false
78 }
79 }
80}
81
/// Longest username accepted by `AuthContext::new`; compared against the string's byte length.
const MAX_USERNAME_LENGTH: usize = 256;

/// Longest hostname accepted by `AuthContext::new`; compared against the string's byte length.
// NOTE(review): 253 presumably mirrors the DNS full-name limit — confirm.
const MAX_HOSTNAME_LENGTH: usize = 253;
87
/// Inputs that drive the choice of SSH authentication method for one `username@host` target.
///
/// Build it with `AuthContext::new`, adjust it with the `with_*` builder methods, then call
/// `determine_method` to obtain the `AuthMethod` to use.
#[derive(Debug, Clone)]
pub struct AuthContext {
    /// Explicit private-key file; stored canonicalized by `with_key_path`.
    pub key_path: Option<PathBuf>,
    /// When set, the SSH agent is consulted before any key file.
    pub use_agent: bool,
    /// When set, password authentication is selected immediately, ahead of every other method.
    pub use_password: bool,
    /// When set, a failed key-based attempt falls back to a password prompt without first
    /// asking the user for consent (only when stdin is interactive).
    pub allow_password_fallback: bool,
    /// When set, key passphrases are looked up in, and saved to, the macOS Keychain.
    #[cfg(target_os = "macos")]
    pub use_keychain: bool,
    /// Remote login name; validated by `new` and shown in prompts and log messages.
    pub username: String,
    /// Remote host; validated by `new` and shown in prompts and log messages.
    pub host: String,
}
115
116impl AuthContext {
117 pub fn new(username: String, host: String) -> Result<Self> {
122 if username.is_empty() {
124 anyhow::bail!("Username cannot be empty");
125 }
126 if username.len() > MAX_USERNAME_LENGTH {
127 anyhow::bail!("Username too long (max {MAX_USERNAME_LENGTH} characters)");
128 }
129 if username.contains(['/', '\0', '\n', '\r']) {
130 anyhow::bail!("Username contains invalid characters");
131 }
132
133 if host.is_empty() {
135 anyhow::bail!("Hostname cannot be empty");
136 }
137 if host.len() > MAX_HOSTNAME_LENGTH {
138 anyhow::bail!("Hostname too long (max {MAX_HOSTNAME_LENGTH} characters)");
139 }
140 if host.contains(['\0', '\n', '\r']) {
141 anyhow::bail!("Hostname contains invalid characters");
142 }
143
144 Ok(Self {
145 key_path: None,
146 use_agent: false,
147 use_password: false,
148 allow_password_fallback: false,
149 #[cfg(target_os = "macos")]
150 use_keychain: false,
151 username,
152 host,
153 })
154 }
155
156 pub fn with_key_path(mut self, key_path: Option<PathBuf>) -> Result<Self> {
162 if let Some(path) = key_path {
163 let canonical_path = path
165 .canonicalize()
166 .with_context(|| format!("Failed to resolve SSH key path: {path:?}"))?;
167
168 if !canonical_path.is_file() {
170 anyhow::bail!("SSH key path is not a file: {canonical_path:?}");
171 }
172
173 self.key_path = Some(canonical_path);
174 } else {
175 self.key_path = None;
176 }
177 Ok(self)
178 }
179
180 pub fn with_agent(mut self, use_agent: bool) -> Self {
182 self.use_agent = use_agent;
183 self
184 }
185
186 pub fn with_password(mut self, use_password: bool) -> Self {
188 self.use_password = use_password;
189 self
190 }
191
192 pub fn with_password_fallback(mut self, allow: bool) -> Self {
197 self.allow_password_fallback = allow;
198 self
199 }
200
201 #[cfg(target_os = "macos")]
205 pub fn with_keychain(mut self, use_keychain: bool) -> Self {
206 self.use_keychain = use_keychain;
207 self
208 }
209
210 pub async fn determine_method(&self) -> Result<AuthMethod> {
238 let start_time = std::time::Instant::now();
240
241 let result = self.determine_method_internal().await;
242
243 let elapsed = start_time.elapsed();
245 if elapsed < Duration::from_millis(50) {
246 tokio::time::sleep(Duration::from_millis(50) - elapsed).await;
247 }
248
249 result
250 }
251
252 async fn determine_method_internal(&self) -> Result<AuthMethod> {
253 if self.use_password {
255 return self.password_auth().await;
256 }
257
258 if self.use_agent
260 && let Some(auth) = self.agent_auth()?
261 {
262 return Ok(auth);
263 }
264
265 if let Some(ref key_path) = self.key_path {
267 return self.key_file_auth(key_path).await;
268 }
269
270 #[cfg(not(target_os = "windows"))]
273 if !self.use_agent {
274 if let Some(auth) = self.agent_auth()? {
276 tracing::debug!(
277 "Using SSH agent (auto-detected) - agent will try all registered keys"
278 );
279 return Ok(auth);
280 }
281 }
282
283 match self.default_key_auth().await {
285 Ok(auth) => Ok(auth),
286 Err(_) => {
287 if atty::is(atty::Stream::Stdin) {
290 let should_attempt_password = if self.allow_password_fallback {
293 tracing::info!(
294 "SSH key authentication failed, falling back to password authentication"
295 );
296
297 const FALLBACK_DELAY: Duration = Duration::from_secs(1);
299 tokio::time::sleep(FALLBACK_DELAY).await;
300 true
301 } else {
302 self.prompt_password_fallback_consent().await?
303 };
304
305 if should_attempt_password {
306 tracing::debug!("Attempting password authentication fallback");
307
308 tracing::warn!(
310 "Password authentication fallback attempted for {}@{} after key auth failure",
311 self.username,
312 self.host
313 );
314
315 self.password_auth().await
316 } else {
317 anyhow::bail!(
319 "SSH authentication failed: All key-based methods failed.\n\
320 \n\
321 Tried:\n\
322 - SSH agent: {}\n\
323 - Default SSH keys: Not found or not authorized\n\
324 \n\
325 User declined password authentication fallback.\n\
326 \n\
327 Solutions:\n\
328 - Use --password flag to explicitly enable password authentication\n\
329 - Start SSH agent and add keys with 'ssh-add'\n\
330 - Specify a key file with -i/--identity\n\
331 - Ensure ~/.ssh/id_ed25519 or ~/.ssh/id_rsa exists and is authorized",
332 if cfg!(target_os = "windows") {
333 "Not supported on Windows"
334 } else if std::env::var_os("SSH_AUTH_SOCK").is_some() {
335 "Available but no identities authorized"
336 } else {
337 "Not available (SSH_AUTH_SOCK not set)"
338 }
339 )
340 }
341 } else {
342 anyhow::bail!(
344 "SSH authentication failed: No authentication method available.\n\
345 \n\
346 Tried:\n\
347 - SSH agent: {}\n\
348 - Default SSH keys: Not found or not authorized\n\
349 \n\
350 Solutions:\n\
351 - Use --password for password authentication\n\
352 - Start SSH agent and add keys with 'ssh-add'\n\
353 - Specify a key file with -i/--identity\n\
354 - Ensure ~/.ssh/id_ed25519 or ~/.ssh/id_rsa exists and is authorized",
355 if cfg!(target_os = "windows") {
356 "Not supported on Windows"
357 } else if std::env::var_os("SSH_AUTH_SOCK").is_some() {
358 "Available but no identities authorized"
359 } else {
360 "Not available (SSH_AUTH_SOCK not set)"
361 }
362 )
363 }
364 }
365 }
366 }
367
368 async fn prompt_password_fallback_consent(&self) -> Result<bool> {
372 use std::io::{self, Write};
373
374 tracing::info!(
375 "All SSH key-based authentication methods failed for {}@{}",
376 self.username,
377 self.host
378 );
379
380 const FALLBACK_DELAY: Duration = Duration::from_secs(1);
383 tokio::time::sleep(FALLBACK_DELAY).await;
384
385 let consent_future = tokio::task::spawn_blocking({
387 let username = self.username.clone();
388 let host = self.host.clone();
389 move || -> Result<bool> {
390 println!("\n⚠️ SSH key authentication failed for {username}@{host}");
391 println!("Would you like to try password authentication? (yes/no): ");
392 io::stdout().flush()?;
393
394 let mut response = String::new();
395 io::stdin().read_line(&mut response)?;
396 let response = response.trim().to_lowercase();
397
398 Ok(response == "yes" || response == "y")
399 }
400 });
401
402 const CONSENT_TIMEOUT: Duration = Duration::from_secs(30);
404 timeout(CONSENT_TIMEOUT, consent_future)
405 .await
406 .context("Consent prompt timed out after 30 seconds")?
407 .context("Consent prompt task failed")?
408 }
409
410 async fn password_auth(&self) -> Result<AuthMethod> {
412 tracing::debug!("Using password authentication");
413
414 let prompt_future = tokio::task::spawn_blocking({
416 let username = self.username.clone();
417 let host = self.host.clone();
418 move || -> Result<Zeroizing<String>> {
419 let password = Zeroizing::new(
421 rpassword::prompt_password(format!("Enter password for {username}@{host}: "))
422 .with_context(|| "Failed to read password")?,
423 );
424 Ok(password)
425 }
426 });
427
428 let password = timeout(AUTH_PROMPT_TIMEOUT, prompt_future)
429 .await
430 .context("Password prompt timed out")?
431 .context("Password prompt task failed")??;
432
433 Ok(AuthMethod::with_password(&password))
434 }
435
436 #[cfg(not(target_os = "windows"))]
445 fn agent_auth(&self) -> Result<Option<AuthMethod>> {
446 match std::env::var_os("SSH_AUTH_SOCK") {
448 Some(socket_path) => {
449 let path = std::path::Path::new(&socket_path);
451 if path.exists() {
452 let has_identities = std::thread::spawn(|| {
455 tokio::runtime::Builder::new_current_thread()
456 .enable_all()
457 .build()
458 .map(|rt| rt.block_on(agent_has_identities()))
459 .unwrap_or(false)
460 })
461 .join()
462 .unwrap_or(false);
463
464 if has_identities {
465 tracing::debug!("Using SSH agent for authentication");
466 Ok(Some(AuthMethod::Agent))
467 } else {
468 tracing::debug!(
469 "SSH agent is running but has no loaded identities, falling back to key files"
470 );
471 Ok(None)
472 }
473 } else {
474 tracing::warn!("SSH_AUTH_SOCK points to non-existent socket");
475 Ok(None)
476 }
477 }
478 None => {
479 tracing::warn!(
480 "SSH agent requested but SSH_AUTH_SOCK environment variable not set"
481 );
482 Ok(None)
483 }
484 }
485 }
486
487 #[cfg(target_os = "windows")]
489 fn agent_auth(&self) -> Result<Option<AuthMethod>> {
490 anyhow::bail!("SSH agent authentication is not supported on Windows");
491 }
492
493 fn is_key_encrypted(key_contents: &str) -> bool {
497 key_contents.contains("ENCRYPTED")
498 || key_contents.contains("Proc-Type: 4,ENCRYPTED")
499 || key_contents.contains("DEK-Info:") }
501
502 async fn key_file_auth(&self, key_path: &Path) -> Result<AuthMethod> {
504 tracing::debug!("Authenticating with key: {:?}", key_path);
505
506 let key_contents = tokio::fs::read_to_string(key_path)
508 .await
509 .with_context(|| format!("Failed to read SSH key file: {key_path:?}"))?;
510
511 let passphrase = if Self::is_key_encrypted(&key_contents) {
512 tracing::debug!("Detected encrypted SSH key");
513
514 #[cfg(target_os = "macos")]
516 let keychain_passphrase = if self.use_keychain {
517 tracing::debug!("Attempting to retrieve passphrase from Keychain");
518 match super::keychain_macos::retrieve_passphrase(key_path).await {
519 Ok(Some(pass)) => {
520 tracing::info!("Successfully retrieved passphrase from Keychain");
521 Some(pass)
522 }
523 Ok(None) => {
524 tracing::debug!("No passphrase found in Keychain");
525 None
526 }
527 Err(err) => {
528 tracing::warn!("Failed to retrieve passphrase from Keychain: {err}");
529 None
530 }
531 }
532 } else {
533 None
534 };
535
536 #[cfg(not(target_os = "macos"))]
537 let keychain_passphrase: Option<Zeroizing<String>> = None;
538
539 if let Some(pass) = keychain_passphrase {
541 Some(pass)
542 } else {
543 tracing::debug!("Prompting for passphrase");
544
545 let key_path_str = key_path.display().to_string();
547 let prompt_future =
548 tokio::task::spawn_blocking(move || -> Result<Zeroizing<String>> {
549 let pass = Zeroizing::new(
551 rpassword::prompt_password(format!(
552 "Enter passphrase for key {key_path_str}: "
553 ))
554 .with_context(|| "Failed to read passphrase")?,
555 );
556 Ok(pass)
557 });
558
559 let pass = timeout(AUTH_PROMPT_TIMEOUT, prompt_future)
560 .await
561 .context("Passphrase prompt timed out")?
562 .context("Passphrase prompt task failed")??;
563
564 #[cfg(target_os = "macos")]
566 if self.use_keychain {
567 tracing::debug!("Storing passphrase in Keychain");
568 if let Err(err) = super::keychain_macos::store_passphrase(key_path, &pass).await
569 {
570 tracing::warn!("Failed to store passphrase in Keychain: {err}");
571 } else {
573 tracing::info!("Successfully stored passphrase in Keychain");
574 }
575 }
576
577 Some(pass)
578 }
579 } else {
580 None
581 };
582
583 drop(key_contents);
585
586 Ok(AuthMethod::with_key_file(
587 key_path,
588 passphrase.as_ref().map(|p| p.as_str()),
589 ))
590 }
591
592 async fn default_key_auth(&self) -> Result<AuthMethod> {
594 let home_dir = dirs::home_dir()
596 .ok_or_else(|| anyhow::anyhow!("Could not determine home directory"))?;
597
598 let ssh_dir = home_dir.join(".ssh");
599
600 if !ssh_dir.is_dir() {
602 anyhow::bail!(
603 "SSH directory not found: {ssh_dir:?}\n\
604 Please ensure ~/.ssh directory exists with proper permissions."
605 );
606 }
607
608 let default_keys = [
610 ssh_dir.join("id_ed25519"),
611 ssh_dir.join("id_rsa"),
612 ssh_dir.join("id_ecdsa"),
613 ssh_dir.join("id_dsa"),
614 ];
615
616 for default_key in &default_keys {
617 if default_key.exists() && default_key.is_file() {
618 let canonical_key = default_key
620 .canonicalize()
621 .with_context(|| format!("Failed to resolve key path: {default_key:?}"))?;
622
623 tracing::debug!("Using default key: {:?}", canonical_key);
624
625 let key_contents = tokio::fs::read_to_string(&canonical_key)
627 .await
628 .with_context(|| format!("Failed to read SSH key file: {canonical_key:?}"))?;
629
630 let passphrase = if Self::is_key_encrypted(&key_contents) {
631 tracing::debug!("Detected encrypted SSH key");
632
633 #[cfg(target_os = "macos")]
635 let keychain_passphrase = if self.use_keychain {
636 tracing::debug!("Attempting to retrieve passphrase from Keychain");
637 match super::keychain_macos::retrieve_passphrase(&canonical_key).await {
638 Ok(Some(pass)) => {
639 tracing::info!("Successfully retrieved passphrase from Keychain");
640 Some(pass)
641 }
642 Ok(None) => {
643 tracing::debug!("No passphrase found in Keychain");
644 None
645 }
646 Err(err) => {
647 tracing::warn!(
648 "Failed to retrieve passphrase from Keychain: {err}"
649 );
650 None
651 }
652 }
653 } else {
654 None
655 };
656
657 #[cfg(not(target_os = "macos"))]
658 let keychain_passphrase: Option<Zeroizing<String>> = None;
659
660 if let Some(pass) = keychain_passphrase {
662 Some(pass)
663 } else {
664 tracing::debug!("Prompting for passphrase");
665
666 let key_path_str = canonical_key.display().to_string();
667 let prompt_future =
668 tokio::task::spawn_blocking(move || -> Result<Zeroizing<String>> {
669 let pass = Zeroizing::new(
670 rpassword::prompt_password(format!(
671 "Enter passphrase for key {key_path_str}: "
672 ))
673 .with_context(|| "Failed to read passphrase")?,
674 );
675 Ok(pass)
676 });
677
678 let pass = timeout(AUTH_PROMPT_TIMEOUT, prompt_future)
679 .await
680 .context("Passphrase prompt timed out")?
681 .context("Passphrase prompt task failed")??;
682
683 #[cfg(target_os = "macos")]
685 if self.use_keychain {
686 tracing::debug!("Storing passphrase in Keychain");
687 if let Err(err) =
688 super::keychain_macos::store_passphrase(&canonical_key, &pass).await
689 {
690 tracing::warn!("Failed to store passphrase in Keychain: {err}");
691 } else {
693 tracing::info!("Successfully stored passphrase in Keychain");
694 }
695 }
696
697 Some(pass)
698 }
699 } else {
700 None
701 };
702
703 drop(key_contents);
705
706 return Ok(AuthMethod::with_key_file(
707 &canonical_key,
708 passphrase.as_ref().map(|p| p.as_str()),
709 ));
710 }
711 }
712
713 anyhow::bail!(
715 "SSH authentication failed: No authentication method available.\n\
716 \n\
717 Tried:\n\
718 - SSH agent: {}\n\
719 - Default SSH keys: Not found\n\
720 \n\
721 Solutions:\n\
722 - Use --password for password authentication\n\
723 - Start SSH agent and add keys with 'ssh-add'\n\
724 - Specify a key file with -i/--identity\n\
725 - Create a default SSH key with 'ssh-keygen'",
726 if cfg!(target_os = "windows") {
727 "Not supported on Windows"
728 } else if std::env::var_os("SSH_AUTH_SOCK").is_some() {
729 "Available but no identities"
730 } else {
731 "Not available (SSH_AUTH_SOCK not set)"
732 }
733 )
734 }
735}
736
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::EnvGuard;
    use serial_test::serial;
    use tempfile::TempDir;

    /// Builds a context from string slices, returning the validation result untouched.
    fn context(user: &str, host: &str) -> Result<AuthContext> {
        AuthContext::new(user.to_string(), host.to_string())
    }

    /// Writes `contents` to a file named `test_key` inside `dir` and returns its path.
    fn write_test_key(dir: &TempDir, contents: &str) -> PathBuf {
        let path = dir.path().join("test_key");
        std::fs::write(&path, contents).unwrap();
        path
    }

    /// A new context stores the identity and leaves every optional method off.
    #[tokio::test]
    async fn test_auth_context_creation() {
        let created = context("testuser", "testhost").unwrap();

        assert_eq!(created.username, "testuser");
        assert_eq!(created.host, "testhost");
        assert_eq!(created.key_path, None);
        assert!(!created.use_agent);
        assert!(!created.use_password);
    }

    /// Empty, malformed and over-long identities are rejected.
    #[tokio::test]
    async fn test_auth_context_validation() {
        let long_username = "a".repeat(MAX_USERNAME_LENGTH + 1);
        let rejected = [
            ("", "host"),
            ("user/name", "host"),
            ("user", ""),
            (long_username.as_str(), "host"),
        ];

        for (user, host) in rejected {
            assert!(context(user, host).is_err());
        }
    }

    /// An existing key file is accepted and stored as an absolute path.
    #[tokio::test]
    async fn test_auth_context_with_key_path() {
        let temp_dir = TempDir::new().unwrap();
        let key_file = write_test_key(&temp_dir, "fake key content");

        let configured = context("user", "host")
            .unwrap()
            .with_key_path(Some(key_file.clone()))
            .unwrap();

        let stored = configured.key_path;
        assert!(stored.is_some());
        assert!(stored.unwrap().is_absolute());
    }

    /// A directory is not a valid key path.
    #[tokio::test]
    async fn test_auth_context_with_invalid_key_path() {
        let temp_dir = TempDir::new().unwrap();
        let directory = temp_dir.path().to_path_buf();

        let outcome = context("user", "host")
            .unwrap()
            .with_key_path(Some(directory));

        assert!(outcome.is_err());
    }

    /// `with_agent(true)` turns the agent flag on.
    #[tokio::test]
    async fn test_auth_context_with_agent() {
        let configured = context("user", "host").unwrap().with_agent(true);

        assert!(configured.use_agent);
    }

    /// `with_password(true)` turns the password flag on.
    #[tokio::test]
    async fn test_auth_context_with_password() {
        let configured = context("user", "host").unwrap().with_password(true);

        assert!(configured.use_password);
    }

    /// Textual encryption markers are recognised; plain keys and public keys are not.
    #[tokio::test]
    async fn test_is_key_encrypted() {
        let encrypted_samples = [
            "-----BEGIN ENCRYPTED PRIVATE KEY-----",
            "Proc-Type: 4,ENCRYPTED",
            "DEK-Info: AES-128-CBC",
        ];
        for sample in encrypted_samples {
            assert!(AuthContext::is_key_encrypted(sample));
        }

        let plain_samples = ["-----BEGIN PRIVATE KEY-----", "ssh-rsa AAAAB3..."];
        for sample in plain_samples {
            assert!(!AuthContext::is_key_encrypted(sample));
        }
    }

    /// An explicit, unencrypted key file yields a private-key-file method.
    #[tokio::test]
    async fn test_determine_method_with_key_file() {
        let temp_dir = TempDir::new().unwrap();
        let key_file = write_test_key(
            &temp_dir,
            "-----BEGIN PRIVATE KEY-----\nfake key content\n-----END PRIVATE KEY-----",
        );

        let configured = context("user", "host")
            .unwrap()
            .with_key_path(Some(key_file.clone()))
            .unwrap();

        let selected = configured.determine_method().await.unwrap();

        let AuthMethod::PrivateKeyFile { key_file_path, .. } = selected else {
            panic!("Expected PrivateKeyFile auth method");
        };
        assert!(key_file_path.is_absolute());
    }

    /// A socket path that does not exist yields no agent method rather than an error.
    #[cfg(not(target_os = "windows"))]
    #[tokio::test]
    #[serial]
    async fn test_agent_auth_with_invalid_socket() {
        let _sock = EnvGuard::set("SSH_AUTH_SOCK", "/tmp/nonexistent-ssh-agent.sock");

        let configured = context("user", "host").unwrap().with_agent(true);

        let selected = configured.agent_auth().unwrap();
        assert!(selected.is_none());
    }

    /// `determine_method` never returns in under 50 ms, whatever its outcome.
    #[tokio::test]
    async fn test_timing_attack_mitigation() {
        let configured = context("user", "host").unwrap();

        let started = std::time::Instant::now();
        let _ = configured.determine_method().await;
        let took = started.elapsed();

        assert!(took >= Duration::from_millis(50));
    }

    /// With an empty `~/.ssh` and no agent, method selection fails with an
    /// authentication error.
    #[tokio::test]
    #[serial]
    async fn test_password_fallback_in_non_interactive() {
        let temp_dir = TempDir::new().unwrap();
        let home = temp_dir.path();
        std::fs::create_dir_all(home.join(".ssh")).unwrap();
        let _home = EnvGuard::set("HOME", home.to_str().unwrap());
        let _sock = EnvGuard::remove("SSH_AUTH_SOCK");

        let configured = context("user", "host").unwrap();

        let outcome = configured.determine_method().await;
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("authentication"));
    }
}