Skip to main content

cossh/auth/
secret.rs

1//! Sensitive in-memory string/buffer helpers.
2//!
3//! These wrappers reduce accidental secret exposure in logs and ensure buffers
4//! are zeroized when dropped.
5
6pub use secrecy::ExposeSecret;
7use secrecy::SecretString;
8use std::fmt;
9use std::str;
10use zeroize::Zeroize;
11
12#[derive(Default, Clone)]
13/// Redacted wrapper around a secret string.
14pub struct SensitiveString(SecretString);
15
16impl SensitiveString {
17    /// Create a new sensitive string from owned or borrowed input.
18    pub fn new(value: impl Into<String>) -> Self {
19        Self(SecretString::new(value.into().into_boxed_str()))
20    }
21
22    /// Build directly from an owned `String`.
23    pub fn from_owned_string(value: String) -> Self {
24        Self(SecretString::new(value.into_boxed_str()))
25    }
26
27    /// Decode UTF-8 bytes into a sensitive string.
28    pub fn from_utf8_bytes(value: Vec<u8>) -> Result<Self, std::string::FromUtf8Error> {
29        String::from_utf8(value).map(Self::from_owned_string)
30    }
31}
32
33impl ExposeSecret<str> for SensitiveString {
34    fn expose_secret(&self) -> &str {
35        self.0.expose_secret()
36    }
37}
38
39impl PartialEq for SensitiveString {
40    fn eq(&self, other: &Self) -> bool {
41        self.expose_secret() == other.expose_secret()
42    }
43}
44
45impl Eq for SensitiveString {}
46
47impl fmt::Debug for SensitiveString {
48    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        self.0.fmt(f)
50    }
51}
52
53impl From<String> for SensitiveString {
54    fn from(value: String) -> Self {
55        Self::new(value)
56    }
57}
58
59impl From<&str> for SensitiveString {
60    fn from(value: &str) -> Self {
61        Self::new(value)
62    }
63}
64
65pub fn sensitive_string(value: impl Into<String>) -> SensitiveString {
66    SensitiveString::new(value)
67}
68
/// Editable secret buffer used by interactive prompts.
///
/// `bytes` is a zero-filled working area whose length acts as the usable
/// capacity. Only the first `len` bytes hold live content.
#[derive(Default)]
pub struct SensitiveBuffer {
    /// Backing storage; bytes past `len` are kept zeroed by the buffer's methods.
    bytes: Vec<u8>,
    /// Number of live bytes at the front of `bytes`.
    len: usize,
}
75
76impl SensitiveBuffer {
77    /// Create an empty secret buffer.
78    pub fn new() -> Self {
79        Self::default()
80    }
81
82    pub fn is_empty(&self) -> bool {
83        self.len == 0
84    }
85
86    pub fn char_len(&self) -> usize {
87        self.as_str().map(|value| value.chars().count()).unwrap_or(0)
88    }
89
90    pub fn clear(&mut self) {
91        self.bytes.zeroize();
92        self.len = 0;
93    }
94
95    pub fn insert_char(&mut self, cursor_chars: usize, ch: char) {
96        let mut encoded = [0u8; 4];
97        let encoded = ch.encode_utf8(&mut encoded).as_bytes();
98        let insert_at = self.byte_index_for_char(cursor_chars);
99        self.secure_reserve(encoded.len());
100        self.bytes.copy_within(insert_at..self.len, insert_at + encoded.len());
101        self.bytes[insert_at..insert_at + encoded.len()].copy_from_slice(encoded);
102        self.len += encoded.len();
103    }
104
105    pub fn backspace_char(&mut self, cursor_chars: usize) -> usize {
106        if cursor_chars == 0 {
107            return 0;
108        }
109        let end = self.byte_index_for_char(cursor_chars);
110        let start = self.byte_index_for_char(cursor_chars - 1);
111        self.remove_range(start, end);
112        cursor_chars - 1
113    }
114
115    pub fn delete_char(&mut self, cursor_chars: usize) -> usize {
116        let len = self.char_len();
117        if cursor_chars >= len {
118            return len;
119        }
120        let start = self.byte_index_for_char(cursor_chars);
121        let end = self.byte_index_for_char(cursor_chars + 1);
122        self.remove_range(start, end);
123        cursor_chars
124    }
125
126    pub fn masked(&self) -> String {
127        "*".repeat(self.char_len())
128    }
129
130    pub fn as_str(&self) -> Result<&str, str::Utf8Error> {
131        str::from_utf8(&self.bytes[..self.len])
132    }
133
134    pub fn into_sensitive_string(mut self) -> Result<SensitiveString, std::string::FromUtf8Error> {
135        let len = self.len;
136        let bytes = std::mem::take(&mut self.bytes);
137        self.len = 0;
138        let mut active = bytes;
139        active.truncate(len);
140        SensitiveString::from_utf8_bytes(active)
141    }
142
143    fn byte_index_for_char(&self, char_index: usize) -> usize {
144        let Ok(text) = self.as_str() else {
145            return self.len;
146        };
147        if char_index == 0 {
148            return 0;
149        }
150
151        let max = text.chars().count();
152        let clamped = char_index.min(max);
153        if clamped == max {
154            return self.len;
155        }
156
157        text.char_indices().nth(clamped).map_or(self.len, |(byte_index, _)| byte_index)
158    }
159
160    fn secure_reserve(&mut self, additional: usize) {
161        let required = self.len.saturating_add(additional);
162        if required <= self.bytes.len() {
163            return;
164        }
165
166        let doubled = self.bytes.len().saturating_mul(2).max(8);
167        let new_capacity = doubled.max(required);
168        let mut new_bytes = vec![0u8; new_capacity];
169        new_bytes[..self.len].copy_from_slice(&self.bytes[..self.len]);
170        self.bytes.zeroize();
171        self.bytes = new_bytes;
172    }
173
174    fn remove_range(&mut self, start: usize, end: usize) {
175        if start >= end || end > self.len {
176            return;
177        }
178        let removed = end - start;
179        self.bytes.copy_within(end..self.len, start);
180        let tail_start = self.len - removed;
181        self.bytes[tail_start..self.len].zeroize();
182        self.len -= removed;
183    }
184}
185
186impl fmt::Debug for SensitiveBuffer {
187    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        f.write_str("[REDACTED]")
189    }
190}
191
192impl Drop for SensitiveBuffer {
193    fn drop(&mut self) {
194        self.bytes.zeroize();
195        self.len = 0;
196    }
197}
198
199pub mod serde_sensitive_string {
200    use super::{SensitiveString, sensitive_string};
201    use secrecy::ExposeSecret;
202    use serde::de::{self, Visitor};
203    use serde::{Deserializer, Serializer};
204    use std::fmt;
205
206    /// Serialize a sensitive string as plain text for protocol payloads.
207    pub fn serialize<S>(value: &SensitiveString, serializer: S) -> Result<S::Ok, S::Error>
208    where
209        S: Serializer,
210    {
211        serializer.serialize_str(value.expose_secret())
212    }
213
214    /// Deserialize a sensitive string from plain text.
215    pub fn deserialize<'de, D>(deserializer: D) -> Result<SensitiveString, D::Error>
216    where
217        D: Deserializer<'de>,
218    {
219        struct SensitiveStringVisitor;
220
221        impl<'de> Visitor<'de> for SensitiveStringVisitor {
222            type Value = SensitiveString;
223
224            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
225                formatter.write_str("a secret string")
226            }
227
228            fn visit_borrowed_str<E>(self, value: &'de str) -> Result<Self::Value, E>
229            where
230                E: de::Error,
231            {
232                Ok(sensitive_string(value))
233            }
234
235            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
236            where
237                E: de::Error,
238            {
239                Ok(sensitive_string(value))
240            }
241
242            fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
243            where
244                E: de::Error,
245            {
246                Ok(SensitiveString::from_owned_string(value))
247            }
248        }
249
250        deserializer.deserialize_string(SensitiveStringVisitor)
251    }
252}
253
// Unit tests live in a separate file, located via `#[path]` (resolved relative
// to this source file's directory).
#[cfg(test)]
#[path = "../test/auth/secret.rs"]
mod tests;