1use anyhow::{Context, Result};
28use std::path::{Path, PathBuf};
29use std::time::Duration;
30use tokio::time::timeout;
31use zeroize::Zeroizing;
32
33use super::tokio_client::AuthMethod;
34
/// Maximum time an interactive password or key-passphrase prompt waits for
/// the user before giving up.
const AUTH_PROMPT_TIMEOUT: Duration = Duration::from_secs(30);

/// Longest username accepted by `AuthContext::new`, measured in bytes
/// (`str::len`), not characters.
const MAX_USERNAME_LENGTH: usize = 256;

/// Longest hostname accepted by `AuthContext::new`, measured in bytes.
/// NOTE(review): 253 presumably mirrors the DNS limit on a full domain name — confirm.
const MAX_HOSTNAME_LENGTH: usize = 253;
43
/// Authentication settings for one SSH connection to `username@host`.
///
/// Construct with `AuthContext::new` (which validates `username` and `host`),
/// adjust with the `with_*` builder methods, then call `determine_method` to
/// obtain the concrete `AuthMethod`.
#[derive(Debug, Clone)]
pub struct AuthContext {
    /// Explicit private key file; `with_key_path` stores it canonicalized and
    /// verifies it is a regular file.
    pub key_path: Option<PathBuf>,
    /// Try the SSH agent (located through `SSH_AUTH_SOCK`) before key files.
    pub use_agent: bool,
    /// Skip all key-based methods and prompt for a password directly.
    pub use_password: bool,
    /// When no default key is usable in an interactive session, switch to a
    /// password prompt without first asking the user for consent.
    pub allow_password_fallback: bool,
    /// macOS only: look up key passphrases in, and save them to, the Keychain.
    #[cfg(target_os = "macos")]
    pub use_keychain: bool,
    /// Remote login name.
    pub username: String,
    /// Remote host name.
    pub host: String,
}
71
72impl AuthContext {
73 pub fn new(username: String, host: String) -> Result<Self> {
78 if username.is_empty() {
80 anyhow::bail!("Username cannot be empty");
81 }
82 if username.len() > MAX_USERNAME_LENGTH {
83 anyhow::bail!("Username too long (max {MAX_USERNAME_LENGTH} characters)");
84 }
85 if username.contains(['/', '\0', '\n', '\r']) {
86 anyhow::bail!("Username contains invalid characters");
87 }
88
89 if host.is_empty() {
91 anyhow::bail!("Hostname cannot be empty");
92 }
93 if host.len() > MAX_HOSTNAME_LENGTH {
94 anyhow::bail!("Hostname too long (max {MAX_HOSTNAME_LENGTH} characters)");
95 }
96 if host.contains(['\0', '\n', '\r']) {
97 anyhow::bail!("Hostname contains invalid characters");
98 }
99
100 Ok(Self {
101 key_path: None,
102 use_agent: false,
103 use_password: false,
104 allow_password_fallback: false,
105 #[cfg(target_os = "macos")]
106 use_keychain: false,
107 username,
108 host,
109 })
110 }
111
112 pub fn with_key_path(mut self, key_path: Option<PathBuf>) -> Result<Self> {
118 if let Some(path) = key_path {
119 let canonical_path = path
121 .canonicalize()
122 .with_context(|| format!("Failed to resolve SSH key path: {path:?}"))?;
123
124 if !canonical_path.is_file() {
126 anyhow::bail!("SSH key path is not a file: {canonical_path:?}");
127 }
128
129 self.key_path = Some(canonical_path);
130 } else {
131 self.key_path = None;
132 }
133 Ok(self)
134 }
135
136 pub fn with_agent(mut self, use_agent: bool) -> Self {
138 self.use_agent = use_agent;
139 self
140 }
141
142 pub fn with_password(mut self, use_password: bool) -> Self {
144 self.use_password = use_password;
145 self
146 }
147
148 pub fn with_password_fallback(mut self, allow: bool) -> Self {
153 self.allow_password_fallback = allow;
154 self
155 }
156
157 #[cfg(target_os = "macos")]
161 pub fn with_keychain(mut self, use_keychain: bool) -> Self {
162 self.use_keychain = use_keychain;
163 self
164 }
165
166 pub async fn determine_method(&self) -> Result<AuthMethod> {
194 let start_time = std::time::Instant::now();
196
197 let result = self.determine_method_internal().await;
198
199 let elapsed = start_time.elapsed();
201 if elapsed < Duration::from_millis(50) {
202 tokio::time::sleep(Duration::from_millis(50) - elapsed).await;
203 }
204
205 result
206 }
207
208 async fn determine_method_internal(&self) -> Result<AuthMethod> {
209 if self.use_password {
211 return self.password_auth().await;
212 }
213
214 if self.use_agent {
216 if let Some(auth) = self.agent_auth()? {
217 return Ok(auth);
218 }
219 }
220
221 if let Some(ref key_path) = self.key_path {
223 return self.key_file_auth(key_path).await;
224 }
225
226 #[cfg(not(target_os = "windows"))]
228 if self.use_agent {
229 if let Some(auth) = self.agent_auth()? {
230 return Ok(auth);
231 }
232 }
233
234 match self.default_key_auth().await {
236 Ok(auth) => Ok(auth),
237 Err(_) => {
238 if atty::is(atty::Stream::Stdin) {
241 let should_attempt_password = if self.allow_password_fallback {
244 tracing::info!("SSH key authentication failed, falling back to password authentication");
245
246 const FALLBACK_DELAY: Duration = Duration::from_secs(1);
248 tokio::time::sleep(FALLBACK_DELAY).await;
249 true
250 } else {
251 self.prompt_password_fallback_consent().await?
252 };
253
254 if should_attempt_password {
255 tracing::debug!("Attempting password authentication fallback");
256
257 tracing::warn!(
259 "Password authentication fallback attempted for {}@{} after key auth failure",
260 self.username,
261 self.host
262 );
263
264 self.password_auth().await
265 } else {
266 anyhow::bail!(
268 "SSH authentication failed: All key-based methods failed.\n\
269 \n\
270 Tried:\n\
271 - SSH agent: {}\n\
272 - Default SSH keys: Not found or not authorized\n\
273 \n\
274 User declined password authentication fallback.\n\
275 \n\
276 Solutions:\n\
277 - Use --password flag to explicitly enable password authentication\n\
278 - Start SSH agent and add keys with 'ssh-add'\n\
279 - Specify a key file with -i/--identity\n\
280 - Ensure ~/.ssh/id_ed25519 or ~/.ssh/id_rsa exists and is authorized",
281 if cfg!(target_os = "windows") {
282 "Not supported on Windows"
283 } else if std::env::var_os("SSH_AUTH_SOCK").is_some() {
284 "Available but no identities authorized"
285 } else {
286 "Not available (SSH_AUTH_SOCK not set)"
287 }
288 )
289 }
290 } else {
291 anyhow::bail!(
293 "SSH authentication failed: No authentication method available.\n\
294 \n\
295 Tried:\n\
296 - SSH agent: {}\n\
297 - Default SSH keys: Not found or not authorized\n\
298 \n\
299 Solutions:\n\
300 - Use --password for password authentication\n\
301 - Start SSH agent and add keys with 'ssh-add'\n\
302 - Specify a key file with -i/--identity\n\
303 - Ensure ~/.ssh/id_ed25519 or ~/.ssh/id_rsa exists and is authorized",
304 if cfg!(target_os = "windows") {
305 "Not supported on Windows"
306 } else if std::env::var_os("SSH_AUTH_SOCK").is_some() {
307 "Available but no identities authorized"
308 } else {
309 "Not available (SSH_AUTH_SOCK not set)"
310 }
311 )
312 }
313 }
314 }
315 }
316
317 async fn prompt_password_fallback_consent(&self) -> Result<bool> {
321 use std::io::{self, Write};
322
323 tracing::info!(
324 "All SSH key-based authentication methods failed for {}@{}",
325 self.username,
326 self.host
327 );
328
329 const FALLBACK_DELAY: Duration = Duration::from_secs(1);
332 tokio::time::sleep(FALLBACK_DELAY).await;
333
334 let consent_future = tokio::task::spawn_blocking({
336 let username = self.username.clone();
337 let host = self.host.clone();
338 move || -> Result<bool> {
339 println!("\n⚠️ SSH key authentication failed for {username}@{host}");
340 println!("Would you like to try password authentication? (yes/no): ");
341 io::stdout().flush()?;
342
343 let mut response = String::new();
344 io::stdin().read_line(&mut response)?;
345 let response = response.trim().to_lowercase();
346
347 Ok(response == "yes" || response == "y")
348 }
349 });
350
351 const CONSENT_TIMEOUT: Duration = Duration::from_secs(30);
353 timeout(CONSENT_TIMEOUT, consent_future)
354 .await
355 .context("Consent prompt timed out after 30 seconds")?
356 .context("Consent prompt task failed")?
357 }
358
359 async fn password_auth(&self) -> Result<AuthMethod> {
361 tracing::debug!("Using password authentication");
362
363 let prompt_future = tokio::task::spawn_blocking({
365 let username = self.username.clone();
366 let host = self.host.clone();
367 move || -> Result<Zeroizing<String>> {
368 let password = Zeroizing::new(
370 rpassword::prompt_password(format!("Enter password for {username}@{host}: "))
371 .with_context(|| "Failed to read password")?,
372 );
373 Ok(password)
374 }
375 });
376
377 let password = timeout(AUTH_PROMPT_TIMEOUT, prompt_future)
378 .await
379 .context("Password prompt timed out")?
380 .context("Password prompt task failed")??;
381
382 Ok(AuthMethod::with_password(&password))
383 }
384
385 #[cfg(not(target_os = "windows"))]
387 fn agent_auth(&self) -> Result<Option<AuthMethod>> {
388 match std::env::var_os("SSH_AUTH_SOCK") {
390 Some(socket_path) => {
391 let path = std::path::Path::new(&socket_path);
393 if path.exists() {
394 tracing::debug!("Using SSH agent for authentication");
395 Ok(Some(AuthMethod::Agent))
396 } else {
397 tracing::warn!("SSH_AUTH_SOCK points to non-existent socket");
398 Ok(None)
399 }
400 }
401 None => {
402 tracing::warn!(
403 "SSH agent requested but SSH_AUTH_SOCK environment variable not set"
404 );
405 Ok(None)
406 }
407 }
408 }
409
410 #[cfg(target_os = "windows")]
412 fn agent_auth(&self) -> Result<Option<AuthMethod>> {
413 anyhow::bail!("SSH agent authentication is not supported on Windows");
414 }
415
416 fn is_key_encrypted(key_contents: &str) -> bool {
420 key_contents.contains("ENCRYPTED")
421 || key_contents.contains("Proc-Type: 4,ENCRYPTED")
422 || key_contents.contains("DEK-Info:") }
424
425 async fn key_file_auth(&self, key_path: &Path) -> Result<AuthMethod> {
427 tracing::debug!("Authenticating with key: {:?}", key_path);
428
429 let key_contents = tokio::fs::read_to_string(key_path)
431 .await
432 .with_context(|| format!("Failed to read SSH key file: {key_path:?}"))?;
433
434 let passphrase = if Self::is_key_encrypted(&key_contents) {
435 tracing::debug!("Detected encrypted SSH key");
436
437 #[cfg(target_os = "macos")]
439 let keychain_passphrase = if self.use_keychain {
440 tracing::debug!("Attempting to retrieve passphrase from Keychain");
441 match super::keychain_macos::retrieve_passphrase(key_path).await {
442 Ok(Some(pass)) => {
443 tracing::info!("Successfully retrieved passphrase from Keychain");
444 Some(pass)
445 }
446 Ok(None) => {
447 tracing::debug!("No passphrase found in Keychain");
448 None
449 }
450 Err(err) => {
451 tracing::warn!("Failed to retrieve passphrase from Keychain: {err}");
452 None
453 }
454 }
455 } else {
456 None
457 };
458
459 #[cfg(not(target_os = "macos"))]
460 let keychain_passphrase: Option<Zeroizing<String>> = None;
461
462 if let Some(pass) = keychain_passphrase {
464 Some(pass)
465 } else {
466 tracing::debug!("Prompting for passphrase");
467
468 let key_path_str = key_path.display().to_string();
470 let prompt_future =
471 tokio::task::spawn_blocking(move || -> Result<Zeroizing<String>> {
472 let pass = Zeroizing::new(
474 rpassword::prompt_password(format!(
475 "Enter passphrase for key {key_path_str}: "
476 ))
477 .with_context(|| "Failed to read passphrase")?,
478 );
479 Ok(pass)
480 });
481
482 let pass = timeout(AUTH_PROMPT_TIMEOUT, prompt_future)
483 .await
484 .context("Passphrase prompt timed out")?
485 .context("Passphrase prompt task failed")??;
486
487 #[cfg(target_os = "macos")]
489 if self.use_keychain {
490 tracing::debug!("Storing passphrase in Keychain");
491 if let Err(err) = super::keychain_macos::store_passphrase(key_path, &pass).await
492 {
493 tracing::warn!("Failed to store passphrase in Keychain: {err}");
494 } else {
496 tracing::info!("Successfully stored passphrase in Keychain");
497 }
498 }
499
500 Some(pass)
501 }
502 } else {
503 None
504 };
505
506 drop(key_contents);
508
509 Ok(AuthMethod::with_key_file(
510 key_path,
511 passphrase.as_ref().map(|p| p.as_str()),
512 ))
513 }
514
515 async fn default_key_auth(&self) -> Result<AuthMethod> {
517 let home_dir = dirs::home_dir()
519 .ok_or_else(|| anyhow::anyhow!("Could not determine home directory"))?;
520
521 let ssh_dir = home_dir.join(".ssh");
522
523 if !ssh_dir.is_dir() {
525 anyhow::bail!(
526 "SSH directory not found: {ssh_dir:?}\n\
527 Please ensure ~/.ssh directory exists with proper permissions."
528 );
529 }
530
531 let default_keys = [
533 ssh_dir.join("id_ed25519"),
534 ssh_dir.join("id_rsa"),
535 ssh_dir.join("id_ecdsa"),
536 ssh_dir.join("id_dsa"),
537 ];
538
539 for default_key in &default_keys {
540 if default_key.exists() && default_key.is_file() {
541 let canonical_key = default_key
543 .canonicalize()
544 .with_context(|| format!("Failed to resolve key path: {default_key:?}"))?;
545
546 tracing::debug!("Using default key: {:?}", canonical_key);
547
548 let key_contents = tokio::fs::read_to_string(&canonical_key)
550 .await
551 .with_context(|| format!("Failed to read SSH key file: {canonical_key:?}"))?;
552
553 let passphrase = if Self::is_key_encrypted(&key_contents) {
554 tracing::debug!("Detected encrypted SSH key");
555
556 #[cfg(target_os = "macos")]
558 let keychain_passphrase = if self.use_keychain {
559 tracing::debug!("Attempting to retrieve passphrase from Keychain");
560 match super::keychain_macos::retrieve_passphrase(&canonical_key).await {
561 Ok(Some(pass)) => {
562 tracing::info!("Successfully retrieved passphrase from Keychain");
563 Some(pass)
564 }
565 Ok(None) => {
566 tracing::debug!("No passphrase found in Keychain");
567 None
568 }
569 Err(err) => {
570 tracing::warn!(
571 "Failed to retrieve passphrase from Keychain: {err}"
572 );
573 None
574 }
575 }
576 } else {
577 None
578 };
579
580 #[cfg(not(target_os = "macos"))]
581 let keychain_passphrase: Option<Zeroizing<String>> = None;
582
583 if let Some(pass) = keychain_passphrase {
585 Some(pass)
586 } else {
587 tracing::debug!("Prompting for passphrase");
588
589 let key_path_str = canonical_key.display().to_string();
590 let prompt_future =
591 tokio::task::spawn_blocking(move || -> Result<Zeroizing<String>> {
592 let pass = Zeroizing::new(
593 rpassword::prompt_password(format!(
594 "Enter passphrase for key {key_path_str}: "
595 ))
596 .with_context(|| "Failed to read passphrase")?,
597 );
598 Ok(pass)
599 });
600
601 let pass = timeout(AUTH_PROMPT_TIMEOUT, prompt_future)
602 .await
603 .context("Passphrase prompt timed out")?
604 .context("Passphrase prompt task failed")??;
605
606 #[cfg(target_os = "macos")]
608 if self.use_keychain {
609 tracing::debug!("Storing passphrase in Keychain");
610 if let Err(err) =
611 super::keychain_macos::store_passphrase(&canonical_key, &pass).await
612 {
613 tracing::warn!("Failed to store passphrase in Keychain: {err}");
614 } else {
616 tracing::info!("Successfully stored passphrase in Keychain");
617 }
618 }
619
620 Some(pass)
621 }
622 } else {
623 None
624 };
625
626 drop(key_contents);
628
629 return Ok(AuthMethod::with_key_file(
630 &canonical_key,
631 passphrase.as_ref().map(|p| p.as_str()),
632 ));
633 }
634 }
635
636 anyhow::bail!(
638 "SSH authentication failed: No authentication method available.\n\
639 \n\
640 Tried:\n\
641 - SSH agent: {}\n\
642 - Default SSH keys: Not found\n\
643 \n\
644 Solutions:\n\
645 - Use --password for password authentication\n\
646 - Start SSH agent and add keys with 'ssh-add'\n\
647 - Specify a key file with -i/--identity\n\
648 - Create a default SSH key with 'ssh-keygen'",
649 if cfg!(target_os = "windows") {
650 "Not supported on Windows"
651 } else if std::env::var_os("SSH_AUTH_SOCK").is_some() {
652 "Available but no identities"
653 } else {
654 "Not available (SSH_AUTH_SOCK not set)"
655 }
656 )
657 }
658}
659
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Helpers for tests that touch process-wide environment variables.
    #[cfg(not(target_os = "windows"))]
    mod env_support {
        use std::ffi::{OsStr, OsString};
        use std::sync::{Mutex, MutexGuard};

        /// Serializes tests that read or mutate `HOME` / `SSH_AUTH_SOCK`; the
        /// harness runs tests on parallel threads sharing one environment.
        static ENV_LOCK: Mutex<()> = Mutex::new(());

        /// Acquires `ENV_LOCK`, recovering it if an earlier test panicked while holding it.
        pub(super) fn lock_env() -> MutexGuard<'static, ()> {
            ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
        }

        /// Overrides one environment variable and restores its previous value
        /// (or absence) when dropped, including when the test panics.
        pub(super) struct EnvVarGuard {
            key: &'static str,
            original: Option<OsString>,
        }

        impl EnvVarGuard {
            pub(super) fn set(key: &'static str, value: impl AsRef<OsStr>) -> Self {
                let original = std::env::var_os(key);
                std::env::set_var(key, value);
                Self { key, original }
            }

            pub(super) fn remove(key: &'static str) -> Self {
                let original = std::env::var_os(key);
                std::env::remove_var(key);
                Self { key, original }
            }
        }

        impl Drop for EnvVarGuard {
            fn drop(&mut self) {
                match self.original.take() {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    #[cfg(not(target_os = "windows"))]
    use env_support::{lock_env, EnvVarGuard};

    /// Writes an unencrypted placeholder key into `dir` and returns its path.
    fn write_plain_key(dir: &TempDir) -> PathBuf {
        let key_path = dir.path().join("test_key");
        std::fs::write(
            &key_path,
            "-----BEGIN PRIVATE KEY-----\nfake key content\n-----END PRIVATE KEY-----",
        )
        .unwrap();
        key_path
    }

    #[tokio::test]
    async fn test_auth_context_creation() {
        let ctx = AuthContext::new("testuser".to_string(), "testhost".to_string()).unwrap();
        assert_eq!(ctx.username, "testuser");
        assert_eq!(ctx.host, "testhost");
        assert_eq!(ctx.key_path, None);
        assert!(!ctx.use_agent);
        assert!(!ctx.use_password);
    }

    #[tokio::test]
    async fn test_auth_context_validation() {
        let invalid_inputs = [
            (String::new(), "host".to_string()),
            ("user/name".to_string(), "host".to_string()),
            ("user\nname".to_string(), "host".to_string()),
            ("a".repeat(MAX_USERNAME_LENGTH + 1), "host".to_string()),
            ("user".to_string(), String::new()),
            ("user".to_string(), "ho\0st".to_string()),
            ("user".to_string(), "h".repeat(MAX_HOSTNAME_LENGTH + 1)),
        ];

        for (username, host) in invalid_inputs {
            let description = format!("{username:?}@{host:?}");
            assert!(
                AuthContext::new(username, host).is_err(),
                "expected rejection of {description}"
            );
        }

        // Values exactly at the limits are accepted.
        assert!(AuthContext::new(
            "a".repeat(MAX_USERNAME_LENGTH),
            "h".repeat(MAX_HOSTNAME_LENGTH)
        )
        .is_ok());
    }

    #[tokio::test]
    async fn test_auth_context_with_key_path() {
        let temp_dir = TempDir::new().unwrap();
        let key_path = temp_dir.path().join("test_key");
        std::fs::write(&key_path, "fake key content").unwrap();

        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_key_path(Some(key_path.clone()))
            .unwrap();

        assert!(ctx.key_path.is_some());
        assert!(ctx.key_path.unwrap().is_absolute());
    }

    #[tokio::test]
    async fn test_auth_context_with_invalid_key_path() {
        let temp_dir = TempDir::new().unwrap();
        let base = AuthContext::new("user".to_string(), "host".to_string()).unwrap();

        // A directory is not a usable key file.
        assert!(base
            .clone()
            .with_key_path(Some(temp_dir.path().to_path_buf()))
            .is_err());
        // A missing file cannot be canonicalized.
        assert!(base
            .clone()
            .with_key_path(Some(temp_dir.path().join("missing")))
            .is_err());
        // `None` clears the key path.
        assert_eq!(base.with_key_path(None).unwrap().key_path, None);
    }

    #[tokio::test]
    async fn test_auth_context_with_agent() {
        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_agent(true);

        assert!(ctx.use_agent);
    }

    #[tokio::test]
    async fn test_auth_context_with_password() {
        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_password(true);

        assert!(ctx.use_password);
    }

    #[tokio::test]
    async fn test_is_key_encrypted() {
        assert!(AuthContext::is_key_encrypted(
            "-----BEGIN ENCRYPTED PRIVATE KEY-----"
        ));
        assert!(AuthContext::is_key_encrypted("Proc-Type: 4,ENCRYPTED"));
        assert!(AuthContext::is_key_encrypted("DEK-Info: AES-128-CBC"));
        assert!(!AuthContext::is_key_encrypted(
            "-----BEGIN PRIVATE KEY-----"
        ));
        assert!(!AuthContext::is_key_encrypted("ssh-rsa AAAAB3..."));
        // OpenSSH-format key whose cipher name is "none".
        assert!(!AuthContext::is_key_encrypted(
            "-----BEGIN OPENSSH PRIVATE KEY-----\n\
             b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAAB\n\
             -----END OPENSSH PRIVATE KEY-----"
        ));
    }

    #[tokio::test]
    async fn test_determine_method_with_key_file() {
        let temp_dir = TempDir::new().unwrap();
        let key_path = write_plain_key(&temp_dir);

        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_key_path(Some(key_path))
            .unwrap();

        let auth = ctx.determine_method().await.unwrap();

        match auth {
            AuthMethod::PrivateKeyFile { key_file_path, .. } => {
                assert!(key_file_path.is_absolute());
            }
            _ => panic!("Expected PrivateKeyFile auth method"),
        }
    }

    #[cfg(not(target_os = "windows"))]
    #[tokio::test]
    async fn test_agent_auth_with_invalid_socket() {
        let _env = lock_env();
        let _sock = EnvVarGuard::set("SSH_AUTH_SOCK", "/tmp/nonexistent-ssh-agent.sock");

        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_agent(true);

        let auth = ctx.agent_auth().unwrap();
        assert!(auth.is_none());
    }

    #[tokio::test]
    async fn test_timing_attack_mitigation() {
        // An explicit unencrypted key exercises the fastest path without
        // consulting the environment or prompting, so only the padding can
        // account for the minimum duration.
        let temp_dir = TempDir::new().unwrap();
        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_key_path(Some(write_plain_key(&temp_dir)))
            .unwrap();

        let start = std::time::Instant::now();
        let result = ctx.determine_method().await;
        let duration = start.elapsed();

        assert!(result.is_ok());
        assert!(duration >= Duration::from_millis(50));
    }

    // `dirs::home_dir` does not consult `HOME` on Windows, so the fake home
    // directory would not take effect there.
    #[cfg(not(target_os = "windows"))]
    #[tokio::test]
    // Each #[tokio::test] owns its own runtime thread, so holding the std mutex
    // across `.await` only serializes env-mutating tests and cannot deadlock.
    #[allow(clippy::await_holding_lock)]
    async fn test_password_fallback_in_non_interactive() {
        if atty::is(atty::Stream::Stdin) {
            // With a terminal attached, the code under test would block on the
            // interactive consent prompt instead of failing fast.
            eprintln!("skipping test_password_fallback_in_non_interactive: stdin is a terminal");
            return;
        }

        let _env = lock_env();
        let temp_dir = TempDir::new().unwrap();
        std::fs::create_dir_all(temp_dir.path().join(".ssh")).unwrap();
        let _home = EnvVarGuard::set("HOME", temp_dir.path());
        let _sock = EnvVarGuard::remove("SSH_AUTH_SOCK");

        let ctx = AuthContext::new("user".to_string(), "host".to_string()).unwrap();

        let result = ctx.determine_method().await;
        let error_msg = result
            .expect_err("no key is available, so authentication must fail")
            .to_string();
        assert!(error_msg.contains("authentication"));
    }
}