1mod tty;
15
16#[cfg(target_family = "windows")]
17mod win32;
18
19pub use crate::tty::Stream;
20use std::error::Error;
21use std::io::Read;
22
23#[cfg(target_family = "windows")]
24pub use crate::windows::prompt_password_stdin;
25
26#[cfg(target_family = "windows")]
27pub use crate::windows::prompt_password_tty;
28
29#[cfg(target_family = "windows")]
30pub use crate::tty::isatty;
31
32#[cfg(target_family = "unix")]
33pub use crate::unix::prompt_password_stdin;
34
35#[cfg(target_family = "unix")]
36pub use crate::unix::prompt_password_tty;
37
38#[cfg(target_family = "unix")]
39pub use crate::tty::isatty;
40
/// Errors that can occur while prompting for a password.
#[derive(Debug)]
pub enum PromptError {
    /// Restoring echo after the password was read failed, so the terminal or
    /// console may still be hiding typed input.
    EnableFailed(std::io::Error),
    /// Any other I/O failure: opening the terminal, changing its mode, writing
    /// the prompt, or reading the password.
    IOError(std::io::Error),
    /// `Stream::Stdin` was given as the stream to print the prompt on.
    InvalidArgument,
}

impl std::fmt::Display for PromptError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            PromptError::EnableFailed(e) => write!(f, "Could not re-enable echo: {}", e),
            // Forward to the inner error's `Display` explicitly; a bare `e.fmt(f)`
            // depends on which fmt traits happen to be in scope.
            PromptError::IOError(e) => std::fmt::Display::fmt(e, f),
            PromptError::InvalidArgument => write!(f, "Invalid argument Stdin"),
        }
    }
}

impl From<std::io::Error> for PromptError {
    /// Wraps a plain I/O error as [`PromptError::IOError`], so `?` works on `std::io` calls.
    fn from(e: std::io::Error) -> PromptError {
        PromptError::IOError(e)
    }
}

impl Error for PromptError {
    /// Both I/O-carrying variants expose the wrapped `std::io::Error` as their source.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            PromptError::EnableFailed(e) | PromptError::IOError(e) => Some(e),
            PromptError::InvalidArgument => None,
        }
    }
}
79
80fn print_stream(prompt: &str, stream: Stream) -> Result<(), PromptError> {
82 use std::io::Write;
83
84 if stream == Stream::Stdout {
85 print!("{}", prompt);
86 std::io::stdout().flush()?;
87 } else {
88 eprint!("{}", prompt);
89 std::io::stderr().flush()?;
90 }
91
92 Ok(())
93}
94
/// Removes exactly one trailing line terminator, `"\r\n"` or `"\n"`, from `input`.
///
/// A CRLF pair is preferred over a bare LF. Input that does not end in a line
/// terminator (including a lone trailing `'\r'`) is returned unchanged.
#[allow(dead_code)] // only the unix module and the tests call this
fn strip_newline(input: &str) -> &str {
    if let Some(without_crlf) = input.strip_suffix("\r\n") {
        return without_crlf;
    }

    match input.strip_suffix('\n') {
        Some(without_lf) => without_lf,
        None => input,
    }
}
103
/// Locates the first line terminator in a UTF-16 buffer.
///
/// Looks for the first LF (`0x000a`). If that LF is directly preceded by a CR
/// (`0x000d`), the index of the CR is returned; otherwise the index of the LF.
/// A CR that is not followed by an LF is not treated as a terminator.
/// Returns `None` when `input` contains no LF.
#[allow(dead_code)] // no caller in this file's platform modules at present
fn find_crlf(input: &[u16]) -> Option<usize> {
    const CR: u16 = 0x000d;
    const LF: u16 = 0x000a;

    let lf_index = input.iter().position(|&unit| unit == LF)?;
    let preceded_by_cr = lf_index > 0 && input[lf_index - 1] == CR;

    Some(if preceded_by_cr { lf_index - 1 } else { lf_index })
}
125
/// Reads one line from `source` and returns it as a `String`, including the
/// trailing `'\n'` when one was read.
///
/// Input is pulled in 64-byte chunks. Reading stops at the first `'\n'`; any
/// bytes that arrived after it in the same chunk are discarded. If `source`
/// reports end-of-file before a newline, the bytes read so far are returned
/// (possibly an empty string), mirroring `BufRead::read_line`. Reads that fail
/// with `ErrorKind::Interrupted` are retried.
///
/// With the `secure_zero` feature the accumulation and read buffers are wrapped
/// in `zeroize::Zeroizing`, so they are wiped when dropped.
///
/// # Errors
/// Propagates any other read error, and returns `ErrorKind::InvalidData` if the
/// bytes read are not valid UTF-8.
#[allow(dead_code)] // only the unix tty prompt and the tests call this
fn read_line<T: Read>(mut source: T) -> Result<String, std::io::Error> {
    #[cfg(feature = "secure_zero")]
    let mut data_read = zeroize::Zeroizing::new(Vec::<u8>::new());
    #[cfg(feature = "secure_zero")]
    let mut buffer = zeroize::Zeroizing::new([0u8; 64]);

    #[cfg(not(feature = "secure_zero"))]
    let mut data_read = Vec::<u8>::new();
    #[cfg(not(feature = "secure_zero"))]
    let mut buffer: [u8; 64] = [0; 64];

    loop {
        let n = match source.read(buffer.as_mut()) {
            // EOF: without this arm the loop would never terminate on a source
            // that ends before a newline.
            Ok(0) => break,
            Ok(n) => n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };

        let chunk = &buffer[..n];
        if let Some(pos) = find_lf(chunk) {
            data_read.extend_from_slice(&chunk[..=pos]);
            break;
        }
        data_read.extend_from_slice(chunk);
    }

    std::str::from_utf8(&data_read)
        .map(str::to_owned)
        .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Found invalid UTF-8"))
}

/// Returns the index of the first `'\n'` byte in `input`, or `None` if there is none.
#[allow(dead_code)] // only reached through `read_line`
fn find_lf(input: &[u8]) -> Option<usize> {
    input.iter().position(|&b| b == b'\n')
}
188
#[cfg(target_family = "windows")]
mod windows {
    use crate::win32::{BOOL, ENABLE_ECHO_INPUT, FALSE, INVALID_HANDLE_VALUE, STD_INPUT_HANDLE};
    use crate::win32::{
        ENABLE_LINE_INPUT, ENABLE_PROCESSED_INPUT, GetConsoleMode, GetStdHandle, ReadConsoleW,
        SetConsoleMode, WriteConsoleW,
    };
    use crate::{PromptError, Stream, print_stream};

    use std::fs::OpenOptions;
    use std::os::windows::io::AsRawHandle;
    use std::os::windows::raw::HANDLE;

    /// UTF-16 carriage return; ends the password line.
    const CR: u16 = 0x000d;
    /// UTF-16 line feed; ends the password line.
    const LF: u16 = 0x000a;
    /// UTF-16 backspace; erases the previously typed character.
    const BS: u16 = 0x0008;

    /// Puts the console input buffer behind `handle` into unechoed, unbuffered mode.
    ///
    /// Clears `ENABLE_ECHO_INPUT` and `ENABLE_LINE_INPUT` and sets
    /// `ENABLE_PROCESSED_INPUT`. Returns the previous mode so that
    /// [`enable_echo`] can restore it exactly.
    ///
    /// # Errors
    /// [`PromptError::IOError`] if the mode cannot be read or changed.
    fn disable_echo(handle: HANDLE) -> Result<u32, PromptError> {
        let mut original_mode: u32 = 0;
        // SAFETY: `original_mode` is a live, writable u32 for the duration of the call.
        let res = unsafe { GetConsoleMode(handle, &mut original_mode) };
        if res == FALSE {
            return Err(PromptError::IOError(std::io::Error::last_os_error()));
        }

        // With line input off, ReadConsoleW no longer waits for Enter, so
        // `read_console` assembles the line (and applies backspaces) itself.
        let mode =
            (original_mode & !ENABLE_ECHO_INPUT & !ENABLE_LINE_INPUT) | ENABLE_PROCESSED_INPUT;

        // SAFETY: plain FFI call taking the handle and mode by value; failure is
        // reported through the return value.
        let res = unsafe { SetConsoleMode(handle, mode) };
        if res == FALSE {
            return Err(PromptError::IOError(std::io::Error::last_os_error()));
        }

        Ok(original_mode)
    }

    /// Restores the console mode `orig` previously returned by [`disable_echo`].
    ///
    /// # Errors
    /// [`PromptError::EnableFailed`] if the mode cannot be set.
    fn enable_echo(orig: u32, handle: HANDLE) -> Result<(), PromptError> {
        // SAFETY: plain FFI call taking the handle and mode by value; failure is
        // reported through the return value.
        let res = unsafe { SetConsoleMode(handle, orig) };
        if res == FALSE {
            return Err(PromptError::EnableFailed(std::io::Error::last_os_error()));
        }

        Ok(())
    }

    /// Prompts on `stream` and reads a password from the process's standard-input
    /// console with echo disabled.
    ///
    /// `prompt` (if any) is printed after echo is turned off, and a newline is
    /// printed on `stream` once reading finishes. Once echo has been disabled, the
    /// original console mode is restored on every path, including when printing
    /// the prompt fails.
    ///
    /// # Errors
    /// * [`PromptError::InvalidArgument`] if `stream` is `Stream::Stdin`.
    /// * [`PromptError::EnableFailed`] if the console mode cannot be restored;
    ///   this takes precedence over any other error.
    /// * [`PromptError::IOError`] for every other failure.
    pub fn prompt_password_stdin(
        prompt: Option<&str>,
        stream: Stream,
    ) -> Result<String, PromptError> {
        if stream == Stream::Stdin {
            return Err(PromptError::InvalidArgument);
        }

        // SAFETY: GetStdHandle takes no pointers; its result is validated below.
        let handle: HANDLE = unsafe { GetStdHandle(STD_INPUT_HANDLE) };
        if handle.is_null() || handle == INVALID_HANDLE_VALUE {
            return Err(PromptError::IOError(std::io::Error::last_os_error()));
        }

        let restore = disable_echo(handle)?;

        // Everything that can fail while echo is off is chained into `result`, so
        // the console mode is restored below no matter which step failed.
        let result = prompt
            .map_or(Ok(()), |p| print_stream(p, stream))
            .and_then(|()| read_console(handle));

        enable_echo(restore, handle)?;
        print_stream("\n", stream)?;

        result
    }

    /// Prompts on the console screen buffer (`CONOUT$`) and reads a password from
    /// the console input buffer (`CONIN$`) with echo disabled, even when the
    /// process's standard handles are redirected.
    ///
    /// The prompt is written before echo is disabled. `"\r\n"` is written after
    /// reading, on success and on failure.
    ///
    /// # Errors
    /// * [`PromptError::EnableFailed`] if the console mode cannot be restored;
    ///   this takes precedence over any other error.
    /// * [`PromptError::IOError`] if a console cannot be opened or any console
    ///   call fails.
    pub fn prompt_password_tty(prompt: Option<&str>) -> Result<String, PromptError> {
        let console_in = OpenOptions::new().read(true).write(true).open("CONIN$")?;
        let console_out = OpenOptions::new().write(true).open("CONOUT$")?;
        // Both files stay open until the end of this function, keeping the handles valid.
        let input = console_in.as_raw_handle();
        let output = console_out.as_raw_handle();

        if let Some(p) = prompt {
            write_console(output, p)?;
        }

        let restore = disable_echo(input)?;
        let result = read_console(input);

        enable_echo(restore, input)?;
        write_console(output, "\r\n")?;

        result
    }

    /// Writes `prompt` to the console screen buffer `console_out` as UTF-16.
    ///
    /// # Errors
    /// [`PromptError::IOError`] if `WriteConsoleW` fails.
    fn write_console(console_out: HANDLE, prompt: &str) -> Result<(), PromptError> {
        let wide: Vec<u16> = prompt.encode_utf16().collect();
        // SAFETY: `wide` outlives the call and holds exactly `wide.len()` UTF-16
        // units; the chars-written pointer is optional and the reserved one must be null.
        let res: BOOL = unsafe {
            WriteConsoleW(
                console_out,
                wide.as_ptr() as *const core::ffi::c_void,
                wide.len() as u32,
                std::ptr::null_mut(),
                std::ptr::null(),
            )
        };

        if res == FALSE {
            return Err(PromptError::IOError(std::io::Error::last_os_error()));
        }

        Ok(())
    }

    /// Returns `true` if `input` contains a carriage return or a line feed.
    fn contains_crlf(input: &[u16]) -> bool {
        input.iter().any(|&unit| unit == CR || unit == LF)
    }

    /// Applies the editing keys in raw console input.
    ///
    /// Copies UTF-16 units up to (not including) the first CR or LF. A backspace
    /// removes the previous character: a whole surrogate pair for characters
    /// outside the BMP, so the result never ends in a lone high surrogate.
    fn ignore_ctrl_chars(input: &[u16]) -> Vec<u16> {
        let mut res: Vec<u16> = Vec::with_capacity(input.len());
        for &unit in input {
            match unit {
                CR | LF => break,
                BS => {
                    let removed_low_surrogate =
                        res.pop().is_some_and(|u| (0xDC00..=0xDFFF).contains(&u));
                    if removed_low_surrogate
                        && res.last().is_some_and(|u| (0xD800..=0xDBFF).contains(u))
                    {
                        res.pop();
                    }
                }
                _ => res.push(unit),
            }
        }

        res
    }

    /// Reads console input from `console_in` until a CR or LF arrives and returns
    /// the text before it, with backspaces applied.
    ///
    /// Expects the mode set by [`disable_echo`] (line input off), so each
    /// `ReadConsoleW` call may return only part of the line.
    ///
    /// # Errors
    /// [`PromptError::IOError`] if `ReadConsoleW` fails or the input is not valid UTF-16.
    fn read_console(console_in: HANDLE) -> Result<String, PromptError> {
        #[cfg(feature = "secure_zero")]
        let mut input = zeroize::Zeroizing::new(Vec::<u16>::new());
        #[cfg(feature = "secure_zero")]
        let mut buffer = zeroize::Zeroizing::new([0u16; 64]);

        // NOTE(review): without `secure_zero` the console is read one UTF-16 unit
        // per call, whereas the `secure_zero` path reads up to 64. Confirm that the
        // difference is intended.
        #[cfg(not(feature = "secure_zero"))]
        let mut input: Vec<u16> = Vec::new();
        #[cfg(not(feature = "secure_zero"))]
        let mut buffer: [u16; 1] = [0; 1];

        loop {
            let mut num_read: u32 = 0;
            let num_read_ptr: *mut u32 = &mut num_read;
            // SAFETY: `buffer` is valid for `buffer.len()` u16 writes and `num_read`
            // outlives the call; the input-control argument is optional (null).
            let res: BOOL = unsafe {
                ReadConsoleW(
                    console_in,
                    buffer.as_mut_ptr() as *mut std::ffi::c_void,
                    buffer.len() as u32,
                    num_read_ptr,
                    std::ptr::null(),
                )
            };

            if res == FALSE {
                return Err(PromptError::IOError(std::io::Error::last_os_error()));
            }

            // Defensive clamp: never trust the reported count beyond the buffer size.
            let received = std::cmp::min(num_read, buffer.len() as u32) as usize;
            let chunk = &buffer[..received];
            input.extend_from_slice(chunk);
            if contains_crlf(chunk) {
                break;
            }
        }

        // Wrap the cleaned copy as well, so it is wiped even on the invalid-UTF-16 path.
        #[cfg(feature = "secure_zero")]
        let cleaned_input = zeroize::Zeroizing::new(ignore_ctrl_chars(input.as_slice()));
        #[cfg(not(feature = "secure_zero"))]
        let cleaned_input = ignore_ctrl_chars(input.as_slice());

        String::from_utf16(&cleaned_input).map_err(|_| {
            PromptError::IOError(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Found invalid UTF-16",
            ))
        })
    }
}
444
445#[cfg(target_family = "unix")]
446mod unix {
447 use crate::{PromptError, Stream, print_stream, read_line, strip_newline};
448
449 use libc::{ECHO, STDIN_FILENO, TCSANOW, tcgetattr, tcsetattr, termios};
450 use std::ffi::CStr;
451 use std::fs::File;
452 use std::io::Write;
453 use std::mem::MaybeUninit;
454 use std::os::fd::{AsRawFd, FromRawFd};
455
456 fn set_echo(echo: bool, fd: i32) -> Result<(), PromptError> {
457 let mut tty = MaybeUninit::<termios>::uninit();
458 unsafe {
459 if tcgetattr(fd, tty.as_mut_ptr()) != 0 {
460 return Err(PromptError::IOError(std::io::Error::last_os_error()));
461 }
462 }
463
464 let mut tty = unsafe { tty.assume_init() };
465
466 if !echo {
467 tty.c_lflag &= !ECHO;
468 } else {
469 tty.c_lflag |= ECHO;
470 }
471
472 unsafe {
473 let tty_ptr: *const termios = &tty;
474 if tcsetattr(fd, TCSANOW, tty_ptr) != 0 {
475 let err = std::io::Error::last_os_error();
476 if echo {
477 return Err(PromptError::EnableFailed(err));
478 } else {
479 return Err(PromptError::IOError(err));
480 }
481 }
482 }
483
484 Ok(())
485 }
486
487 pub fn prompt_password_stdin(
502 prompt: Option<&str>,
503 stream: Stream,
504 ) -> Result<String, PromptError> {
505 if stream == Stream::Stdin {
506 return Err(PromptError::InvalidArgument);
507 }
508
509 set_echo(false, STDIN_FILENO)?;
511
512 if let Some(p) = prompt {
513 print_stream(p, stream)?;
514 }
515
516 let mut pass = String::new();
517 let stdin = std::io::stdin();
518 match stdin.read_line(&mut pass) {
519 Ok(_) => {}
520 Err(e) => {
521 if prompt.is_some() {
522 print_stream("\n", stream)?;
523 }
524
525 set_echo(true, STDIN_FILENO)?;
526 return Err(PromptError::IOError(e));
527 }
528 };
529
530 if prompt.is_some() {
531 print_stream("\n", stream)?;
532 }
533
534 set_echo(true, STDIN_FILENO)?;
536
537 let pass = strip_newline(&pass).to_string();
538
539 Ok(pass)
540 }
541
542 pub fn prompt_password_tty(prompt: Option<&str>) -> Result<String, PromptError> {
545 let flags = if prompt.is_some() {
546 libc::O_RDWR | libc::O_NOCTTY
547 } else {
548 libc::O_RDONLY | libc::O_NOCTTY
549 };
550
551 let raw_tty = unsafe {
552 libc::open(
553 CStr::from_bytes_with_nul_unchecked(b"/dev/tty\0").as_ptr(),
554 flags,
555 )
556 };
557
558 if raw_tty == -1 {
559 let err = std::io::Error::last_os_error();
560 return Err(PromptError::IOError(err));
561 }
562
563 let mut tty = unsafe { File::from_raw_fd(raw_tty) };
564
565 if let Some(p) = prompt {
566 write_tty(p, &mut tty)?;
567 }
568
569 let tty_fd = tty.as_raw_fd();
570 set_echo(false, tty_fd)?;
571 let password = match read_line(&mut tty) {
572 Ok(p) => p,
573 Err(e) => {
574 if prompt.is_some() {
575 if let Err(e) = write_tty("\n", &mut tty) {
576 set_echo(true, tty_fd)?;
577 return Err(e.into());
578 }
579 }
580 set_echo(true, tty_fd)?;
581 return Err(e.into());
582 }
583 };
584
585 #[cfg(feature = "secure_zero")]
586 let password = zeroize::Zeroizing::new(password);
587
588 if prompt.is_some() {
589 if let Err(e) = write_tty("\n", &mut tty) {
590 set_echo(true, tty_fd)?;
591 return Err(e.into());
592 }
593 }
594
595 set_echo(true, tty_fd)?;
596
597 let password = strip_newline(&password).to_string();
598
599 Ok(password)
600 }
601
602 fn write_tty<T: Write>(prompt: &str, tty: &mut T) -> Result<(), std::io::Error> {
603 tty.write_all(prompt.as_bytes())?;
604 tty.flush()?;
605
606 Ok(())
607 }
608}
609
#[cfg(test)]
mod tests {
    use super::{find_crlf, find_lf, read_line, strip_newline};

    #[test]
    fn test_strip_newline() {
        assert_eq!(strip_newline("hello\r\n"), "hello");
        assert_eq!(strip_newline("hello\n"), "hello");
        assert_eq!(strip_newline("hello"), "hello");
        // Only one terminator is removed, and a lone '\r' is left alone.
        assert_eq!(strip_newline("hello\n\n"), "hello\n");
        assert_eq!(strip_newline("hello\r"), "hello\r");
        assert_eq!(strip_newline(""), "");
    }

    #[test]
    fn test_find_lf() {
        let input = [0x41, 0x42, 0x43, 0x0a];
        let input2 = [0x41, 0x42, 0x43];
        assert_eq!(find_lf(&input), Some(3));
        assert_eq!(find_lf(&input2), None);
        assert_eq!(find_lf(&[]), None);
    }

    #[test]
    fn test_find_crlf() {
        // CRLF reports the CR's index; a bare LF reports its own index.
        assert_eq!(find_crlf(&[0x41, 0x0d, 0x0a]), Some(1));
        assert_eq!(find_crlf(&[0x41, 0x0a]), Some(1));
        assert_eq!(find_crlf(&[0x0d, 0x41, 0x0a]), Some(2));
        assert_eq!(find_crlf(&[0x41, 0x42]), None);
    }

    #[test]
    fn test_read_line() -> Result<(), String> {
        let line = "Hello\n".to_string();
        let pass = match read_line(line.as_bytes()) {
            Ok(p) => p,
            Err(e) => return Err(e.to_string()),
        };
        assert_eq!(pass, line);

        Ok(())
    }

    #[test]
    fn test_read_line_stops_at_first_newline() -> Result<(), String> {
        let pass = read_line("first\nsecond\n".as_bytes()).map_err(|e| e.to_string())?;
        assert_eq!(pass, "first\n");

        Ok(())
    }

    #[test]
    fn test_read_line_spans_multiple_reads() -> Result<(), String> {
        // Longer than the 64-byte internal buffer, so several reads are needed.
        let line = format!("{}\n", "a".repeat(100));
        let pass = read_line(line.as_bytes()).map_err(|e| e.to_string())?;
        assert_eq!(pass, line);

        Ok(())
    }

    #[test]
    fn test_read_line_rejects_invalid_utf8() {
        let err = read_line(&[0xffu8, b'\n'][..]).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    #[cfg_attr(not(feature = "secure_zero"), ignore)]
    fn test_read_line_secure_zero() -> Result<(), String> {
        let line = "Hello\n".to_string();
        let pass = match read_line(line.as_bytes()) {
            Ok(p) => p,
            Err(e) => return Err(e.to_string()),
        };
        assert_eq!(pass, line);

        Ok(())
    }
}