1use std::collections::HashMap;
7use std::io;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use parking_lot::RwLock;
13use tracing::{debug, info, warn};
14
15use super::handler::{AuthContext, AuthHandler, AuthMethod, AuthResult};
16use crate::PublicKey;
17
/// One successfully parsed entry of an `authorized_keys` file.
#[derive(Debug, Clone)]
pub struct AuthorizedKey {
    /// Key type token as written on the line (for example `ssh-ed25519`).
    pub key_type: String,
    /// Base64 text of the key exactly as it appeared on the line.
    pub key_data: String,
    /// Bytes obtained by base64-decoding `key_data`.
    pub key_bytes: Vec<u8>,
    /// Trimmed text following the key data, if the line had any.
    pub comment: Option<String>,
    /// Options that preceded the key type, split on unquoted commas.
    /// Quotes and backslashes are kept verbatim inside each option.
    pub options: Vec<String>,
}
32
33impl AuthorizedKey {
34 pub fn to_public_key(&self) -> PublicKey {
36 PublicKey::new(&self.key_type, self.key_bytes.clone())
37 .with_comment(self.comment.clone().unwrap_or_default())
38 }
39
40 pub fn matches(&self, key: &PublicKey) -> bool {
42 self.key_type == key.key_type && self.key_bytes == key.data
43 }
44}
45
46pub fn parse_authorized_keys(content: &str) -> Vec<AuthorizedKey> {
69 let mut keys = Vec::new();
70
71 for line in content.lines() {
72 let line = line.trim();
73
74 if line.is_empty() || line.starts_with('#') {
76 continue;
77 }
78
79 if let Some(key) = parse_authorized_key_line(line) {
80 keys.push(key);
81 } else {
82 debug!(line = %line, "Failed to parse authorized_keys line");
83 }
84 }
85
86 keys
87}
88
89fn parse_authorized_key_line(line: &str) -> Option<AuthorizedKey> {
91 let line = line.trim();
92 if line.is_empty() || line.starts_with('#') {
93 return None;
94 }
95
96 const KEY_TYPES: &[&str] = &[
98 "ssh-ed25519",
99 "ssh-rsa",
100 "ecdsa-sha2-nistp256",
101 "ecdsa-sha2-nistp384",
102 "ecdsa-sha2-nistp521",
103 "ssh-dss",
104 "sk-ssh-ed25519@openssh.com",
105 "sk-ecdsa-sha2-nistp256@openssh.com",
106 ];
107
108 let (first, rest) = split_unquoted_whitespace(line)?;
110
111 if KEY_TYPES.contains(&first) {
113 parse_key_parts(first, rest, &[])
115 } else {
116 let options = split_options(first);
118
119 let (key_type, key_rest) = split_unquoted_whitespace(rest)?;
120
121 if !KEY_TYPES.contains(&key_type) {
122 return None;
123 }
124
125 parse_key_parts(key_type, key_rest, &options)
126 }
127}
128
/// Splits `input` at the first whitespace character that is not inside
/// double quotes.
///
/// A backslash makes the following character literal, so it can neither
/// toggle quoting nor act as the split point. The returned pair is the text
/// before the split point and the remainder with leading whitespace removed.
///
/// Returns `None` only for empty input. When no split point exists the whole
/// input is returned together with an empty remainder.
fn split_unquoted_whitespace(input: &str) -> Option<(&str, &str)> {
    let mut quoted = false;
    let mut skip_next = false;

    // Byte offset of the first unquoted, unescaped whitespace character.
    let boundary = input.char_indices().find_map(|(pos, c)| {
        if skip_next {
            skip_next = false;
            return None;
        }
        match c {
            '\\' => {
                skip_next = true;
                None
            }
            '"' => {
                quoted = !quoted;
                None
            }
            _ if c.is_whitespace() && !quoted => Some(pos),
            _ => None,
        }
    });

    match boundary {
        Some(pos) => Some((&input[..pos], input[pos..].trim_start())),
        None if input.is_empty() => None,
        None => Some((input, "")),
    }
}
160
/// Splits an option list on commas that are outside double quotes.
///
/// A backslash makes the following character literal, so an escaped quote
/// does not toggle quoting and an escaped comma does not split. Each piece is
/// trimmed and empty pieces are discarded; quotes and backslashes are kept
/// verbatim in the returned strings.
fn split_options(input: &str) -> Vec<String> {
    /// Appends the trimmed segment to `out` unless it is empty.
    fn push_trimmed(out: &mut Vec<String>, segment: &str) {
        let piece = segment.trim();
        if !piece.is_empty() {
            out.push(piece.to_string());
        }
    }

    let mut pieces = Vec::new();
    let mut segment_start = 0;
    let mut quoted = false;
    let mut skip_next = false;

    for (pos, c) in input.char_indices() {
        if skip_next {
            skip_next = false;
            continue;
        }
        match c {
            '\\' => skip_next = true,
            '"' => quoted = !quoted,
            ',' if !quoted => {
                push_trimmed(&mut pieces, &input[segment_start..pos]);
                // ',' is a single byte, so the next segment starts right after it.
                segment_start = pos + 1;
            }
            _ => {}
        }
    }

    push_trimmed(&mut pieces, &input[segment_start..]);
    pieces
}
202
203fn parse_key_parts(key_type: &str, rest: &str, options: &[String]) -> Option<AuthorizedKey> {
205 let mut parts = rest.trim().splitn(2, |c: char| c.is_whitespace());
206 let key_data = parts.next()?;
207
208 if key_data.is_empty() {
209 return None;
210 }
211
212 let key_bytes = match base64_decode(key_data) {
214 Ok(bytes) => bytes,
215 Err(e) => {
216 debug!(error = %e, "Failed to decode base64 key data");
217 return None;
218 }
219 };
220
221 let comment = parts.next().map(|s| s.trim().to_string());
222
223 Some(AuthorizedKey {
224 key_type: key_type.to_string(),
225 key_data: key_data.to_string(),
226 key_bytes,
227 comment,
228 options: options.to_vec(),
229 })
230}
231
/// Decodes standard-alphabet (`A-Z a-z 0-9 + /`) base64 with mandatory `=`
/// padding.
///
/// Carriage returns and line feeds are ignored wherever they appear. Input is
/// consumed in groups of four symbols; a group may end in `=` or `==`, and
/// nothing may follow a padded group.
///
/// # Errors
///
/// * `"invalid base64 character"` - a symbol outside the alphabet where a
///   data symbol is required (including `=` in the first two positions).
/// * `"invalid base64 length"` - the input ends in the middle of a group.
/// * `"invalid base64 padding"` - a single `=` not followed by a second one,
///   or any data after the padding.
fn base64_decode(input: &str) -> Result<Vec<u8>, &'static str> {
    /// Maps one alphabet symbol to its 6-bit value.
    fn sextet(symbol: u8) -> Result<u8, &'static str> {
        match symbol {
            b'A'..=b'Z' => Ok(symbol - b'A'),
            b'a'..=b'z' => Ok(symbol - b'a' + 26),
            b'0'..=b'9' => Ok(symbol - b'0' + 52),
            b'+' => Ok(62),
            b'/' => Ok(63),
            _ => Err("invalid base64 character"),
        }
    }

    let raw = input.as_bytes();
    let mut decoded = Vec::with_capacity(raw.len() * 3 / 4);
    let mut symbols = raw
        .iter()
        .copied()
        .filter(|&byte| byte != b'\n' && byte != b'\r');

    while let Some(first) = symbols.next() {
        let a = sextet(first)?;
        let b = sextet(symbols.next().ok_or("invalid base64 length")?)?;

        let third = symbols.next().ok_or("invalid base64 length")?;
        if third == b'=' {
            // "xx==": one output byte; the second '=' must follow and end the input.
            decoded.push((a << 2) | (b >> 4));
            if symbols.next() != Some(b'=') || symbols.next().is_some() {
                return Err("invalid base64 padding");
            }
            break;
        }
        let c = sextet(third)?;

        let fourth = symbols.next().ok_or("invalid base64 length")?;
        if fourth == b'=' {
            // "xxx=": two output bytes; nothing may follow the padding.
            decoded.push((a << 2) | (b >> 4));
            decoded.push((b << 4) | (c >> 2));
            if symbols.next().is_some() {
                return Err("invalid base64 padding");
            }
            break;
        }
        let d = sextet(fourth)?;

        decoded.push((a << 2) | (b >> 4));
        decoded.push((b << 4) | (c >> 2));
        decoded.push((c << 6) | d);
    }

    Ok(decoded)
}
295
/// Public-key authentication handler backed by `authorized_keys` files.
pub struct AuthorizedKeysAuth {
    /// Path of the keys file. In per-user mode this is a template in which
    /// every `%u` is replaced by the username at lookup time.
    keys_path: PathBuf,
    /// `true` when each user has their own keys file; `false` when a single
    /// file is shared by all users.
    per_user: bool,
    /// Parsed keys. In shared mode the single entry is stored under the empty
    /// string; in per-user mode entries are keyed by username.
    cache: Arc<RwLock<HashMap<String, Vec<AuthorizedKey>>>>,
}
321
322impl AuthorizedKeysAuth {
323 pub fn new(keys_path: impl AsRef<Path>) -> io::Result<Self> {
333 let path = expand_tilde(keys_path.as_ref());
334
335 let auth = Self {
336 keys_path: path.clone(),
337 per_user: false,
338 cache: Arc::new(RwLock::new(HashMap::new())),
339 };
340
341 auth.load_keys_sync()?;
343
344 Ok(auth)
345 }
346
347 pub fn per_user(keys_path: impl AsRef<Path>) -> Self {
352 let path = keys_path.as_ref().to_path_buf();
353
354 Self {
355 keys_path: path,
356 per_user: true,
357 cache: Arc::new(RwLock::new(HashMap::new())),
358 }
359 }
360
361 pub async fn reload(&self) -> io::Result<()> {
363 self.load_keys_sync()
364 }
365
366 fn load_keys_sync(&self) -> io::Result<()> {
368 if self.per_user {
369 self.cache.write().clear();
371 return Ok(());
372 }
373
374 let content = std::fs::read_to_string(&self.keys_path)?;
376 let keys = parse_authorized_keys(&content);
377
378 info!(
379 path = %self.keys_path.display(),
380 count = keys.len(),
381 "Loaded authorized keys"
382 );
383
384 let mut cache = self.cache.write();
385 cache.clear();
386 cache.insert(String::new(), keys); Ok(())
389 }
390
391 fn get_keys_for_user(&self, username: &str) -> Vec<AuthorizedKey> {
393 if self.per_user {
394 if username.contains('/')
397 || username.contains('\\')
398 || username.contains("..")
399 || username.contains('\0')
400 {
401 warn!(
402 username = %username,
403 "Rejected username with path traversal characters"
404 );
405 return Vec::new();
406 }
407
408 if let Some(keys) = self.cache.read().get(username) {
410 return keys.clone();
411 }
412
413 let path = self.keys_path.to_string_lossy().replace("%u", username);
415 let path = expand_tilde(Path::new(&path));
416
417 match std::fs::read_to_string(&path) {
418 Ok(content) => {
419 let keys = parse_authorized_keys(&content);
420 debug!(
421 username = %username,
422 path = %path.display(),
423 count = keys.len(),
424 "Loaded user authorized keys"
425 );
426 self.cache
427 .write()
428 .insert(username.to_string(), keys.clone());
429 keys
430 }
431 Err(e) => {
432 debug!(
433 username = %username,
434 path = %path.display(),
435 error = %e,
436 "Failed to load user authorized keys"
437 );
438 Vec::new()
439 }
440 }
441 } else {
442 self.cache.read().get("").cloned().unwrap_or_default()
444 }
445 }
446
447 pub fn cached_key_count(&self) -> usize {
449 self.cache.read().values().map(|v| v.len()).sum()
450 }
451
452 pub fn keys_path(&self) -> &Path {
454 &self.keys_path
455 }
456}
457
458#[async_trait]
459impl AuthHandler for AuthorizedKeysAuth {
460 async fn auth_publickey(&self, ctx: &AuthContext, key: &PublicKey) -> AuthResult {
461 debug!(
462 username = %ctx.username(),
463 remote_addr = %ctx.remote_addr(),
464 key_type = %key.key_type,
465 "AuthorizedKeysAuth: auth attempt"
466 );
467
468 let authorized_keys = self.get_keys_for_user(ctx.username());
469
470 for ak in &authorized_keys {
471 if ak.matches(key) {
472 info!(
473 username = %ctx.username(),
474 comment = ak.comment.as_deref().unwrap_or("<none>"),
475 "AuthorizedKeysAuth: accepted"
476 );
477 return AuthResult::Accept;
478 }
479 }
480
481 debug!(
482 username = %ctx.username(),
483 key_count = authorized_keys.len(),
484 "AuthorizedKeysAuth: no matching key"
485 );
486 AuthResult::Reject
487 }
488
489 fn supported_methods(&self) -> Vec<AuthMethod> {
490 vec![AuthMethod::PublicKey]
491 }
492}
493
494fn expand_tilde(path: &Path) -> PathBuf {
496 let home = std::env::var_os("HOME").map(PathBuf::from);
497 expand_tilde_with_home(path, home.as_deref())
498}
499
/// Replaces a leading `~` path component with `home`.
///
/// Only a first component that is exactly `~` is expanded, so `~/x` becomes
/// `<home>/x` and a bare `~` becomes `<home>`, while `~other/x` is returned
/// unchanged. When `home` is `None` the path is always returned unchanged.
///
/// The match is done on path components rather than on a lossy string, so
/// non-UTF-8 file names survive intact, and repeated separators after the
/// tilde (`~//x`) cannot produce an absolute remainder that would discard
/// the home directory when joined.
fn expand_tilde_with_home(path: &Path, home: Option<&Path>) -> PathBuf {
    match (home, path.strip_prefix("~")) {
        (Some(home_dir), Ok(rest)) => home_dir.join(rest),
        _ => path.to_path_buf(),
    }
}
510
// Unit tests for authorized_keys parsing, base64 decoding, tilde expansion
// and the cache-backed authentication handler.
#[cfg(test)]
mod tests {
    use super::super::SessionId;
    use super::*;
    use std::collections::HashMap;
    use std::net::SocketAddr;
    use std::sync::Arc;

    // Builds an auth context for `username` with a fixed loopback address
    // and session id.
    fn make_context(username: &str) -> AuthContext {
        let addr: SocketAddr = "127.0.0.1:22".parse().unwrap();
        AuthContext::new(username, addr, SessionId(1))
    }

    // A plain `<type> <base64> <comment>` line yields one key without options.
    #[test]
    fn test_parse_simple_key() {
        let line = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIG1cILnhxkg+kMsGsVJP7hQnfKSPPIP/8GSXTE2n/8SE user@example.com";
        let keys = parse_authorized_keys(line);

        assert_eq!(keys.len(), 1);
        let key = &keys[0];
        assert_eq!(key.key_type, "ssh-ed25519");
        assert_eq!(key.comment, Some("user@example.com".to_string()));
        assert!(key.options.is_empty());
    }

    // Options placed before the key type are captured individually.
    #[test]
    fn test_parse_key_with_options() {
        let line = "no-pty,command=\"/bin/git-shell\" ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIG1cILnhxkg+kMsGsVJP7hQnfKSPPIP/8GSXTE2n/8SE git@server";
        let keys = parse_authorized_keys(line);

        assert_eq!(keys.len(), 1);
        let key = &keys[0];
        assert_eq!(key.key_type, "ssh-ed25519");
        assert!(key.options.contains(&"no-pty".to_string()));
        assert!(key.options.iter().any(|o| o.starts_with("command=")));
    }

    // Whitespace inside a quoted option value must not end the option list.
    #[test]
    fn test_parse_key_with_quoted_option_spaces() {
        let line = "command=\"echo hello world\",no-pty ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIG1cILnhxkg+kMsGsVJP7hQnfKSPPIP/8GSXTE2n/8SE user@example.com";
        let keys = parse_authorized_keys(line);

        assert_eq!(keys.len(), 1);
        let key = &keys[0];
        assert_eq!(key.key_type, "ssh-ed25519");
        assert!(key.options.contains(&"no-pty".to_string()));
        assert!(
            key.options
                .iter()
                .any(|o| o == "command=\"echo hello world\"")
        );
    }

    // A comma inside a quoted option value must not split the option.
    #[test]
    fn test_parse_key_with_quoted_option_commas() {
        let line = "command=\"echo a,b\",no-pty ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIG1cILnhxkg+kMsGsVJP7hQnfKSPPIP/8GSXTE2n/8SE user@example.com";
        let keys = parse_authorized_keys(line);

        assert_eq!(keys.len(), 1);
        let key = &keys[0];
        assert_eq!(key.key_type, "ssh-ed25519");
        assert!(key.options.contains(&"no-pty".to_string()));
        assert!(key.options.iter().any(|o| o == "command=\"echo a,b\""));
    }

    // Comments and blank lines between entries are skipped.
    #[test]
    fn test_parse_multiple_keys() {
        let content = r#"
# Comment line
ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIG1cILnhxkg+kMsGsVJP7hQnfKSPPIP/8GSXTE2n/8SE user1@example.com
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAAgQC1 user2@example.com

# Another comment
ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHUFrQ== user3@example.com
        "#;

        let keys = parse_authorized_keys(content);
        assert_eq!(keys.len(), 3);
    }

    // Input consisting only of comments and blank lines yields no keys.
    #[test]
    fn test_parse_empty_and_comments() {
        let content = r#"
# Only comments

# and empty lines

        "#;

        let keys = parse_authorized_keys(content);
        assert!(keys.is_empty());
    }

    // Covers one-pad, one-pad (short) and two-pad inputs.
    #[test]
    fn test_base64_decode() {
        let decoded = base64_decode("SGVsbG8=").unwrap();
        assert_eq!(decoded, b"Hello");

        let decoded = base64_decode("SGk=").unwrap();
        assert_eq!(decoded, b"Hi");

        let decoded = base64_decode("QQ==").unwrap();
        assert_eq!(decoded, b"A");
    }

    // Misplaced, incomplete or trailing-data padding must be rejected.
    #[test]
    fn test_base64_decode_rejects_malformed_padding() {
        assert!(base64_decode("=Q==").is_err());
        assert!(base64_decode("QQ=").is_err());
        assert!(base64_decode("QQ=A").is_err());
        assert!(base64_decode("SGk=A").is_err());
        assert!(base64_decode("SGk=Zg==").is_err());
    }

    // `matches` requires both the key type and the key bytes to agree.
    #[test]
    fn test_authorized_key_matches() {
        let ak = AuthorizedKey {
            key_type: "ssh-ed25519".to_string(),
            key_data: "AAAA".to_string(),
            key_bytes: vec![1, 2, 3, 4],
            comment: None,
            options: vec![],
        };

        let matching_key = PublicKey::new("ssh-ed25519", vec![1, 2, 3, 4]);
        assert!(ak.matches(&matching_key));

        let wrong_type = PublicKey::new("ssh-rsa", vec![1, 2, 3, 4]);
        assert!(!ak.matches(&wrong_type));

        let wrong_data = PublicKey::new("ssh-ed25519", vec![5, 6, 7, 8]);
        assert!(!ak.matches(&wrong_data));
    }

    // `~/` is expanded only when a home directory is supplied; other paths
    // pass through unchanged.
    #[test]
    fn test_expand_tilde() {
        let home = Path::new("/home/testuser");

        let expanded = expand_tilde_with_home(Path::new("~/.ssh/authorized_keys"), Some(home));
        assert_eq!(
            expanded,
            PathBuf::from("/home/testuser/.ssh/authorized_keys")
        );

        let expanded = expand_tilde_with_home(Path::new("/etc/ssh/keys"), Some(home));
        assert_eq!(expanded, PathBuf::from("/etc/ssh/keys"));

        let expanded = expand_tilde_with_home(Path::new("~/.ssh/authorized_keys"), None);
        assert_eq!(expanded, PathBuf::from("~/.ssh/authorized_keys"));
    }

    // Conversion carries over key type, key bytes and comment.
    #[test]
    fn test_authorized_key_to_public_key() {
        let ak = AuthorizedKey {
            key_type: "ssh-ed25519".to_string(),
            key_data: "AAAA".to_string(),
            key_bytes: vec![1, 2, 3],
            comment: Some("user@example.com".to_string()),
            options: vec!["no-pty".to_string()],
        };

        let pk = ak.to_public_key();
        assert_eq!(pk.key_type, "ssh-ed25519");
        assert_eq!(pk.data, vec![1, 2, 3]);
        assert_eq!(pk.comment, Some("user@example.com".to_string()));
    }

    // Shared mode: keys stored under the empty-string cache entry apply to
    // any username, and no file is touched.
    #[tokio::test]
    async fn test_authorized_keys_auth_uses_cached_keys() {
        let ak = AuthorizedKey {
            key_type: "ssh-ed25519".to_string(),
            key_data: "AAAA".to_string(),
            key_bytes: vec![1, 2, 3],
            comment: None,
            options: vec![],
        };

        let cache = HashMap::from([(String::new(), vec![ak.clone()])]);
        let auth = AuthorizedKeysAuth {
            keys_path: PathBuf::from("/ignored"),
            per_user: false,
            cache: Arc::new(RwLock::new(cache)),
        };

        let ctx = make_context("alice");
        let key = PublicKey::new("ssh-ed25519", vec![1, 2, 3]);
        assert!(matches!(
            auth.auth_publickey(&ctx, &key).await,
            AuthResult::Accept
        ));

        let wrong_key = PublicKey::new("ssh-ed25519", vec![9, 9, 9]);
        assert!(matches!(
            auth.auth_publickey(&ctx, &wrong_key).await,
            AuthResult::Reject
        ));
        assert_eq!(auth.cached_key_count(), 1);
    }

    // Per-user mode: a cached entry for "alice" authenticates alice only.
    // The lookup for "bob" misses the cache and falls back to a file under
    // "/ignored/", which is expected not to exist.
    #[tokio::test]
    async fn test_authorized_keys_auth_per_user_cache() {
        let ak = AuthorizedKey {
            key_type: "ssh-ed25519".to_string(),
            key_data: "AAAA".to_string(),
            key_bytes: vec![4, 5, 6],
            comment: None,
            options: vec![],
        };

        let cache = HashMap::from([("alice".to_string(), vec![ak.clone()])]);
        let auth = AuthorizedKeysAuth {
            keys_path: PathBuf::from("/ignored/%u"),
            per_user: true,
            cache: Arc::new(RwLock::new(cache)),
        };

        let ctx = make_context("alice");
        let key = PublicKey::new("ssh-ed25519", vec![4, 5, 6]);
        assert!(matches!(
            auth.auth_publickey(&ctx, &key).await,
            AuthResult::Accept
        ));

        let ctx = make_context("bob");
        assert!(matches!(
            auth.auth_publickey(&ctx, &key).await,
            AuthResult::Reject
        ));
        assert_eq!(auth.cached_key_count(), 1);
    }
}