1use crate::crypto::{CHECKSUM_LEN, KDFALG, KEYNUMLEN, PKALG, SALT_LEN};
16use crate::error::{Error, Result};
17use base64ct::{Base64, Encoding as _};
18use core::str;
19use memchr::memchr;
20use static_assertions::assert_eq_size;
21use std::fs::File;
22use std::io::stdin;
23use std::io::stdout;
24use std::io::{Read, Write};
25use std::path::Component;
26use std::path::Path;
27use zeroize::Zeroize;
28use zeroize::Zeroizing;
29
/// Prefix that starts the first line of every key and signature file.
///
/// `parse_stream` rejects input whose first line does not begin with this
/// prefix, and `write_stream` emits it ahead of the caller-supplied comment.
pub const COMMENTHDR: &str = "untrusted comment: ";
/// Maximum comment length, in bytes.
///
/// NOTE(review): nothing in this part of the file enforces the limit —
/// presumably callers validate comments against it; confirm.
pub const MAX_COMMENT_LEN: usize = 1024;
34
35#[repr(C)]
39#[derive(Debug)]
40pub struct EncKey {
41 pub pkalg: [u8; 2],
43 pub kdfalg: [u8; 2],
45 pub kdfrounds: u32,
47 pub salt: [u8; SALT_LEN],
49 pub checksum: [u8; CHECKSUM_LEN],
51 pub keynum: [u8; KEYNUMLEN],
53 pub seckey: Zeroizing<[u8; 64]>,
55}
56
57impl Zeroize for EncKey {
58 fn zeroize(&mut self) {
59 self.pkalg.zeroize();
60 self.kdfalg.zeroize();
61 self.kdfrounds.zeroize();
62 self.salt.zeroize();
63 self.checksum.zeroize();
64 self.keynum.zeroize();
65 self.seckey.zeroize();
66 }
67}
68
69impl Drop for EncKey {
70 fn drop(&mut self) {
71 self.zeroize();
72 }
73}
74
/// Public-key record in its 42-byte serialized layout.
///
/// Field order matches the byte layout used by `PubKey::from_bytes` and
/// `PubKey::to_bytes`.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct PubKey {
    /// Signature algorithm identifier (bytes 0..2); must equal `PKALG` to parse.
    pub pkalg: [u8; 2],
    /// Key number (bytes 2..10).
    pub keynum: [u8; KEYNUMLEN],
    /// Public-key bytes (bytes 10..42).
    pub pubkey: [u8; 32],
}
86
87#[derive(Debug, Clone, Copy)]
89pub struct Sig {
90 pub pkalg: [u8; 2],
92 pub keynum: [u8; KEYNUMLEN],
94 pub sig: [u8; 64],
96}
97
98impl EncKey {
99 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
107 let pkalg = bytes
108 .get(0..2)
109 .ok_or(Error::InvalidKeyLength)?
110 .try_into()
111 .map_err(|_e| Error::InvalidKeyLength)?;
112 let kdfalg = bytes
113 .get(2..4)
114 .ok_or(Error::InvalidKeyLength)?
115 .try_into()
116 .map_err(|_e| Error::InvalidKeyLength)?;
117 let kdfrounds_bytes = bytes.get(4..8).ok_or(Error::InvalidKeyLength)?;
118 let kdfrounds = u32::from_be_bytes(
119 kdfrounds_bytes
120 .try_into()
121 .map_err(|_e| Error::InvalidKeyLength)?,
122 );
123 let salt = bytes
124 .get(8..24)
125 .ok_or(Error::InvalidKeyLength)?
126 .try_into()
127 .map_err(|_e| Error::InvalidKeyLength)?;
128 let checksum = bytes
129 .get(24..32)
130 .ok_or(Error::InvalidKeyLength)?
131 .try_into()
132 .map_err(|_e| Error::InvalidKeyLength)?;
133 let keynum = bytes
134 .get(32..40)
135 .ok_or(Error::InvalidKeyLength)?
136 .try_into()
137 .map_err(|_e| Error::InvalidKeyLength)?;
138 let seckey = Zeroizing::new(
139 bytes
140 .get(40..104)
141 .ok_or(Error::InvalidKeyLength)?
142 .try_into()
143 .map_err(|_e| Error::InvalidKeyLength)?,
144 );
145
146 if bytes.len() != 104 {
148 return Err(Error::InvalidKeyLength);
149 }
150
151 if pkalg != PKALG {
152 return Err(Error::UnsupportedPkAlgo);
153 }
154 if kdfalg != KDFALG {
155 return Err(Error::UnsupportedKdfAlgo);
156 }
157
158 Ok(Self {
159 pkalg,
160 kdfalg,
161 kdfrounds,
162 salt,
163 checksum,
164 keynum,
165 seckey,
166 })
167 }
168
169 #[must_use]
171 pub fn to_bytes(&self) -> Zeroizing<Vec<u8>> {
172 let mut out = Zeroizing::new(Vec::with_capacity(104));
173 out.extend_from_slice(&self.pkalg);
174 out.extend_from_slice(&self.kdfalg);
175 out.extend_from_slice(&self.kdfrounds.to_be_bytes());
176 out.extend_from_slice(&self.salt);
177 out.extend_from_slice(&self.checksum);
178 out.extend_from_slice(&self.keynum);
179 out.extend_from_slice(self.seckey.as_ref());
180 out
181 }
182}
183
184impl PubKey {
185 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
192 let pkalg = bytes
193 .get(0..2)
194 .ok_or(Error::InvalidKeyLength)?
195 .try_into()
196 .map_err(|_e| Error::InvalidKeyLength)?;
197 let keynum = bytes
198 .get(2..10)
199 .ok_or(Error::InvalidKeyLength)?
200 .try_into()
201 .map_err(|_e| Error::InvalidKeyLength)?;
202 let pubkey = bytes
203 .get(10..42)
204 .ok_or(Error::InvalidKeyLength)?
205 .try_into()
206 .map_err(|_e| Error::InvalidKeyLength)?;
207
208 if bytes.len() != 42 {
209 return Err(Error::InvalidKeyLength);
210 }
211
212 if pkalg != PKALG {
213 return Err(Error::UnsupportedPkAlgo);
214 }
215
216 Ok(Self {
217 pkalg,
218 keynum,
219 pubkey,
220 })
221 }
222
223 #[must_use]
225 pub fn to_bytes(&self) -> Vec<u8> {
226 let mut out = Vec::with_capacity(42);
227 out.extend_from_slice(&self.pkalg);
228 out.extend_from_slice(&self.keynum);
229 out.extend_from_slice(&self.pubkey);
230 out
231 }
232}
233
234impl Sig {
235 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
242 let pkalg = bytes
243 .get(0..2)
244 .ok_or(Error::InvalidKeyLength)?
245 .try_into()
246 .map_err(|_e| Error::InvalidKeyLength)?;
247 let keynum = bytes
248 .get(2..10)
249 .ok_or(Error::InvalidKeyLength)?
250 .try_into()
251 .map_err(|_e| Error::InvalidKeyLength)?;
252 let sig = bytes
253 .get(10..74)
254 .ok_or(Error::InvalidKeyLength)?
255 .try_into()
256 .map_err(|_e| Error::InvalidKeyLength)?;
257
258 if bytes.len() != 74 {
259 return Err(Error::InvalidKeyLength);
260 }
261
262 if pkalg != PKALG {
263 return Err(Error::UnsupportedPkAlgo);
264 }
265
266 Ok(Self { pkalg, keynum, sig })
267 }
268
269 #[must_use]
271 pub fn to_bytes(&self) -> Vec<u8> {
272 let mut out = Vec::with_capacity(74);
273 out.extend_from_slice(&self.pkalg);
274 out.extend_from_slice(&self.keynum);
275 out.extend_from_slice(&self.sig);
276 out
277 }
278}
279
280pub fn parse_stream<F, R, T>(mut reader: R, parse_fn: F) -> Result<(T, Vec<u8>)>
293where
294 R: Read,
295 F: Fn(&[u8]) -> Result<T>,
296{
297 const HEADER_LIMIT: usize = 4096;
302 let mut header_buf = vec![0_u8; HEADER_LIMIT];
303
304 let mut total_read = 0;
306 while total_read < HEADER_LIMIT {
307 let n = reader
308 .read(&mut header_buf[total_read..])
309 .map_err(Error::Io)?;
310 if n == 0 {
311 break;
312 }
313 total_read = total_read.checked_add(n).ok_or(Error::Overflow)?;
314 }
315 header_buf.truncate(total_read);
316
317 let n1 = memchr(b'\n', &header_buf).ok_or(Error::InvalidCommentHeader)?;
319 let header_bytes = &header_buf[..n1];
320
321 let prefix = COMMENTHDR.as_bytes();
323 if !header_bytes.starts_with(prefix) {
324 return Err(Error::InvalidCommentHeader);
325 }
326 let comment = header_bytes[prefix.len()..].to_vec();
327
328 let n2_start = n1.checked_add(1).ok_or(Error::Overflow)?;
329 let n2 = memchr(b'\n', &header_buf[n2_start..])
330 .unwrap_or_else(|| header_buf.len().saturating_sub(n2_start));
331
332 let b64_start = n2_start;
333 let b64_end = b64_start.checked_add(n2).ok_or(Error::Overflow)?;
334
335 if b64_end > header_buf.len() {
336 return Err(Error::InvalidCommentHeader);
337 }
338
339 let b64_bytes = &header_buf[b64_start..b64_end];
340
341 let b64_str = str::from_utf8(b64_bytes).map_err(|_e| Error::InvalidSignatureUtf8)?;
343 let decoded = Base64::decode_vec(b64_str.trim()).map_err(Error::Base64Decode)?;
344
345 let obj = parse_fn(&decoded)?;
346 Ok((obj, comment))
347}
348
349pub fn parse<T, F>(path: &Path, parse_fn: F) -> Result<(T, Vec<u8>)>
362where
363 F: Fn(&[u8]) -> Result<T>,
364{
365 let reader: Box<dyn Read> = if path.to_str() == Some("-") {
366 Box::new(stdin())
367 } else {
368 Box::new(open(path, false)?)
369 };
370
371 parse_stream(reader, parse_fn)
372}
373
374pub fn write_stream(mut writer: impl Write, comment: &[u8], data: &[u8]) -> Result<()> {
382 let encoded = Base64::encode_string(data);
383
384 let mut content = Vec::new();
385 content.extend_from_slice(COMMENTHDR.as_bytes());
386 content.extend_from_slice(comment);
387 content.push(b'\n');
388 content.extend_from_slice(encoded.as_bytes());
389 content.push(b'\n');
390
391 writer.write_all(&content).map_err(Error::Io)?;
392 Ok(())
393}
394
395pub fn write(path: &Path, comment: &[u8], data: &[u8]) -> Result<()> {
403 let writer: Box<dyn Write> = if path.to_str() == Some("-") {
404 Box::new(stdout())
405 } else {
406 Box::new(open(path, true)?)
407 };
408
409 write_stream(writer, comment, data)
410}
411
412pub fn open(path: &Path, write: bool) -> Result<File> {
416 if path.components().any(|p| p == Component::ParentDir) {
417 return Err(Error::InvalidPath);
418 }
419 #[cfg(target_os = "linux")]
420 {
421 safe_open(path, write)
422 }
423 #[cfg(not(target_os = "linux"))]
424 {
425 use std::fs::OpenOptions;
426
427 let mut opts = OpenOptions::new();
428 if write {
429 opts.write(true).create_new(true);
430 #[cfg(unix)]
431 {
432 use std::os::unix::fs::OpenOptionsExt;
433 opts.mode(0o600);
434 }
435 } else {
436 opts.read(true);
437 #[cfg(unix)]
438 {
439 use nix::fcntl::OFlag;
440 use std::os::unix::fs::OpenOptionsExt;
441 opts.custom_flags(OFlag::O_NOFOLLOW.bits());
442 }
443 }
444
445 opts.open(path).map_err(Error::Io)
446 }
447}
448
/// Linux-only implementation behind `open`, built on `openat2`.
///
/// Every lookup of the caller's `path` uses `RESOLVE_NO_SYMLINKS` and
/// `RESOLVE_NO_MAGICLINKS`.
///
/// * `write == true`: creates the file exclusively
///   (`O_WRONLY | O_CREAT | O_EXCL`) with mode `0o600`.
/// * `write == false`: first obtains an `O_PATH` handle, returns
///   `Error::InvalidPath` unless its metadata reports a regular file, then
///   re-opens that same handle read-only through `/proc/thread-self/fd/<fd>`.
///
/// NOTE(review): the read path depends on `/proc` being available — confirm
/// this is acceptable for all deployment targets.
#[cfg(target_os = "linux")]
fn safe_open(path: &Path, write: bool) -> Result<File> {
    use nix::fcntl::{openat2, OFlag, OpenHow, ResolveFlag, AT_FDCWD};
    use nix::sys::stat::Mode;
    use std::os::fd::AsRawFd;

    let mut how = OpenHow::new()
        .resolve(ResolveFlag::RESOLVE_NO_SYMLINKS | ResolveFlag::RESOLVE_NO_MAGICLINKS);

    if write {
        // Exclusive creation: fails if anything already exists at `path`.
        how = how.flags(OFlag::O_WRONLY | OFlag::O_CREAT | OFlag::O_EXCL);
        how = how.mode(Mode::from_bits_truncate(0o600));
        return Ok(openat2(AT_FDCWD, path, how).map(File::from)?);
    }

    // Step 1: take an `O_PATH` handle and inspect what it refers to.
    // Presumably this is done so that non-regular files (FIFOs, devices,
    // sockets) are never opened for I/O — TODO confirm.
    how = how.flags(OFlag::O_PATH | OFlag::O_NOFOLLOW);
    let file = openat2(AT_FDCWD, path, how).map(File::from)?;
    if !file.metadata()?.is_file() {
        return Err(Error::InvalidPath);
    }

    // Step 2: re-open the inspected handle for reading via its /proc entry.
    // The resolve restrictions are cleared for this second open.
    how = how.flags(OFlag::O_RDONLY).resolve(ResolveFlag::empty());
    let path = format!("/proc/thread-self/fd/{}", file.as_raw_fd());
    let file = openat2(AT_FDCWD, path.as_str(), how).map(File::from)?;

    Ok(file)
}
479
// Compile-time checks that each in-memory record is exactly as large as the
// serialized length its `from_bytes` accepts (104, 42 and 74 bytes).
assert_eq_size!(EncKey, [u8; 104]);
assert_eq_size!(PubKey, [u8; 42]);
assert_eq_size!(Sig, [u8; 74]);
483
/// Unit tests for record (de)serialization, the text file format, and the
/// path/file-type restrictions enforced by `open`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::OpenOptions;

    /// Round-trips an `EncKey` and checks each rejection path of `from_bytes`.
    #[test]
    fn test_enckey_serialization() -> crate::error::Result<()> {
        let enc = EncKey {
            pkalg: PKALG,
            kdfalg: KDFALG,
            kdfrounds: 42,
            salt: [1u8; SALT_LEN],
            checksum: [2u8; CHECKSUM_LEN],
            keynum: [3u8; KEYNUMLEN],
            seckey: Zeroizing::new([4u8; 64]),
        };
        let bytes = enc.to_bytes();
        let enc2 = EncKey::from_bytes(&bytes)?;
        assert_eq!(enc.kdfrounds, enc2.kdfrounds);
        assert_eq!(enc.salt, enc2.salt);

        // One byte short.
        assert!(matches!(
            EncKey::from_bytes(&bytes[..103]),
            Err(Error::InvalidKeyLength)
        ));

        // One byte too long.
        let mut long = bytes.clone();
        long.push(0);
        assert!(matches!(
            EncKey::from_bytes(&long),
            Err(Error::InvalidKeyLength)
        ));

        // Corrupt the signature-algorithm field.
        let mut bad_alg = bytes.clone();
        bad_alg[0] = b'X';
        assert!(matches!(
            EncKey::from_bytes(&bad_alg),
            Err(Error::UnsupportedPkAlgo)
        ));

        // Corrupt the KDF-algorithm field.
        let mut bad_kdf = bytes.clone();
        bad_kdf[2] = b'X';
        assert!(matches!(
            EncKey::from_bytes(&bad_kdf),
            Err(Error::UnsupportedKdfAlgo)
        ));

        Ok(())
    }

    /// Round-trips a `PubKey` and checks each rejection path of `from_bytes`.
    #[test]
    fn test_pubkey_serialization() -> crate::error::Result<()> {
        let pubk = PubKey {
            pkalg: PKALG,
            keynum: [1u8; KEYNUMLEN],
            pubkey: [2u8; 32],
        };
        let bytes = pubk.to_bytes();
        let pubk2 = PubKey::from_bytes(&bytes)?;
        assert_eq!(pubk.keynum, pubk2.keynum);

        // One byte short, then one byte too long.
        assert!(matches!(
            PubKey::from_bytes(&bytes[..41]),
            Err(Error::InvalidKeyLength)
        ));
        let mut long = bytes.clone();
        long.push(0);
        assert!(matches!(
            PubKey::from_bytes(&long),
            Err(Error::InvalidKeyLength)
        ));

        // Corrupt the signature-algorithm field.
        let mut bad_alg = bytes.clone();
        bad_alg[0] = b'X';
        assert!(matches!(
            PubKey::from_bytes(&bad_alg),
            Err(Error::UnsupportedPkAlgo)
        ));

        Ok(())
    }

    /// Round-trips a `Sig` and checks each rejection path of `from_bytes`.
    #[test]
    fn test_sig_serialization() -> crate::error::Result<()> {
        let sig = Sig {
            pkalg: PKALG,
            keynum: [1u8; KEYNUMLEN],
            sig: [0u8; 64],
        };
        let bytes = sig.to_bytes();
        let sig2 = Sig::from_bytes(&bytes)?;
        assert_eq!(sig.keynum, sig2.keynum);

        // One byte short.
        assert!(matches!(
            Sig::from_bytes(&bytes[..73]),
            Err(Error::InvalidKeyLength)
        ));

        // One byte too long.
        let mut long = bytes.clone();
        long.push(0);
        assert!(matches!(
            Sig::from_bytes(&long),
            Err(Error::InvalidKeyLength)
        ));

        // Corrupt the signature-algorithm field.
        let mut bad_alg = bytes.clone();
        bad_alg[0] = b'X';
        assert!(matches!(
            Sig::from_bytes(&bad_alg),
            Err(Error::UnsupportedPkAlgo)
        ));

        Ok(())
    }

    /// Exercises `write` + `parse` on real files, plus the header, newline and
    /// UTF-8 failure modes of the text format.
    #[test]
    #[cfg_attr(any(target_arch = "wasm32", target_arch = "wasm64"), ignore)]
    fn test_file_io() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let dir = tempfile::tempdir()?;
        let path = dir.path().join("secret.key");
        let data = b"secret data";

        write(&path, b"mycomment", data)?;

        let (read_data, comment) = parse::<Vec<u8>, _>(&path, |b| Ok(b.to_vec()))?;
        assert_eq!(read_data, data);
        assert_eq!(comment, b"mycomment");

        let missing = dir.path().join("missing");
        // A missing file surfaces as a different error variant depending on
        // which open path is compiled in.
        let result = parse::<Vec<u8>, _>(&missing, |_| Ok(vec![]));
        #[cfg(not(target_os = "linux"))]
        assert!(matches!(result, Err(Error::Io(_))));
        #[cfg(target_os = "linux")]
        assert!(matches!(result, Err(Error::Nix(_))));

        let bad_prefix = dir.path().join("bad_prefix");
        // First line does not start with the comment header.
        let mut f = OpenOptions::new()
            .write(true)
            .create_new(true)
            .open(&bad_prefix)?;
        f.write_all(b"invalid header\n")?;
        assert!(matches!(
            parse::<Vec<u8>, _>(&bad_prefix, |_| Ok(vec![])),
            Err(Error::InvalidCommentHeader)
        ));

        let no_newline = dir.path().join("no_newline");
        // Valid prefix, but the first line is never terminated.
        let mut f = OpenOptions::new()
            .write(true)
            .create_new(true)
            .open(&no_newline)?;
        f.write_all(b"untrusted comment: foo")?;
        assert!(matches!(
            parse::<Vec<u8>, _>(&no_newline, |_| Ok(vec![])),
            Err(Error::InvalidCommentHeader)
        ));

        let bad_utf8 = dir.path().join("bad_utf8");
        // Create the file through `write`, then re-open it (without
        // truncating) and overwrite it from the start with content whose
        // second line is not valid UTF-8.
        write(&bad_utf8, b"comment", b"")?;
        let mut f = OpenOptions::new().write(true).open(&bad_utf8)?;
        f.write_all(b"untrusted comment: comment\n\xFF\xFF\n")?;
        assert!(matches!(
            parse::<Vec<u8>, _>(&bad_utf8, |_| Ok(vec![])),
            Err(Error::InvalidSignatureUtf8)
        ));

        Ok(())
    }

    /// Opening a symlink for reading must fail.
    #[test]
    #[cfg(unix)]
    fn test_open_symlink_fail() -> std::result::Result<(), Box<dyn std::error::Error>> {
        use std::os::unix::fs::symlink;
        let dir = tempfile::tempdir()?;
        let target = dir.path().join("target");
        let link = dir.path().join("link");

        std::fs::write(&target, b"target")?;
        symlink(&target, &link)?;

        // The exact error variant differs between platforms, so only failure
        // itself is asserted.
        assert!(open(&link, false).is_err());

        Ok(())
    }

    /// Files created through `open(.., true)` get permission bits `0o600`.
    #[test]
    #[cfg(unix)]
    fn test_open_write_mode() -> std::result::Result<(), Box<dyn std::error::Error>> {
        use std::os::unix::fs::PermissionsExt;

        let dir = tempfile::tempdir()?;
        let path = dir.path().join("secret.key");

        let _f = open(&path, true)?;

        let metadata = std::fs::metadata(&path)?;
        let mode = metadata.permissions().mode();

        // Mask off the file-type bits; only the permission bits are compared.
        assert_eq!(mode & 0o777, 0o600);

        Ok(())
    }

    /// Paths containing `..` are rejected for both reading and writing.
    #[test]
    fn test_open_parent_dir_fail() {
        let path = Path::new("foo/../bar");
        assert!(matches!(open(path, false), Err(Error::InvalidPath)));
        assert!(matches!(open(path, true), Err(Error::InvalidPath)));
    }

    /// A directory is not a regular file and must be rejected.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_safe_open_not_file_fail() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let dir = tempfile::tempdir()?;
        let path = dir.path();

        // `path` is the temporary directory itself.
        assert!(matches!(open(path, false), Err(Error::InvalidPath)));

        Ok(())
    }

    /// `/proc/self/root` is expected to be refused with `ELOOP`.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_open_magiclink_fail() {
        let path = Path::new("/proc/self/root");
        assert!(matches!(
            open(path, false),
            Err(Error::Nix(nix::errno::Errno::ELOOP))
        ));
    }

    /// A character device is not a regular file and must be rejected.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_open_char_device_fail() {
        let path = Path::new("/dev/null");
        assert!(matches!(open(path, false), Err(Error::InvalidPath)));
    }

    /// A FIFO is not a regular file and must be rejected.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_open_fifo_fail() -> std::result::Result<(), Box<dyn std::error::Error>> {
        use nix::sys::stat::Mode;
        use nix::unistd::mkfifo;

        let dir = tempfile::tempdir()?;
        let path = dir.path().join("test.fifo");

        mkfifo(&path, Mode::S_IRUSR | Mode::S_IWUSR)?;

        // Expected to return an error rather than open the FIFO.
        assert!(matches!(open(&path, false), Err(Error::InvalidPath)));

        Ok(())
    }

    /// A Unix-domain socket is not a regular file and must be rejected.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_open_socket_fail() -> std::result::Result<(), Box<dyn std::error::Error>> {
        use std::os::unix::net::UnixListener;

        let dir = tempfile::tempdir()?;
        let path = dir.path().join("test.sock");

        // Binding the listener creates the socket file at `path`.
        let _listener = UnixListener::bind(&path)?;

        assert!(matches!(open(&path, false), Err(Error::InvalidPath)));

        Ok(())
    }
}