redact_crypto/
source.rs

1//! Sources provide some source material for creating a type. Currently, the only
2//! implementations available are sources of bytes. A source provides an interface
3//! for read/write operations on the set of bytes it covers.
4
5use crate::CryptoError;
6use base64::DecodeError;
7use once_cell::sync::OnceCell;
8use serde::{
9    de::{self, Deserializer},
10    Deserialize, Serialize, Serializer,
11};
12use std::{
13    convert::Into,
14    error::Error,
15    fmt::{self, Display, Formatter},
16    io::{self, ErrorKind},
17    path::PathBuf as StdPathBuf,
18    str::FromStr,
19};
20
/// Identifies which kind of byte source failed to produce its bytes.
#[derive(Debug)]
pub enum NotFoundKind {
    /// A file-backed source whose file is missing. Carries the path as
    /// text (file sources substitute `"<Invalid UTF8>"` when the path cannot
    /// be rendered as UTF-8).
    File(String),
    /// An in-memory vector source that currently holds no bytes.
    Vector,
}
26
27#[derive(Debug)]
28pub enum SourceError {
29    /// Error occurred while performing IO on the filesystem
30    FsIoError { source: io::Error },
31
32    /// Requested bytes were not found
33    NotFound { kind: NotFoundKind },
34
35    /// File path given has an invalid file name with no stem
36    FilePathHasNoFileStem { path: String },
37
38    /// File path given was invalid UTF-8
39    FilePathIsInvalidUTF8,
40
41    /// Error happened when decoding base64 string
42    Base64Decode { source: DecodeError },
43}
44
45impl Error for SourceError {
46    fn source(&self) -> Option<&(dyn Error + 'static)> {
47        match *self {
48            SourceError::FsIoError { ref source } => Some(source),
49            SourceError::NotFound { .. } => None,
50            SourceError::FilePathHasNoFileStem { .. } => None,
51            SourceError::FilePathIsInvalidUTF8 => None,
52            SourceError::Base64Decode { ref source } => Some(source),
53        }
54    }
55}
56
57impl Display for SourceError {
58    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
59        match *self {
60            SourceError::FsIoError { .. } => {
61                write!(f, "Error occured during file system I/O")
62            }
63            SourceError::NotFound { ref kind } => match kind {
64                NotFoundKind::File(path) => write!(f, "Path \"{}\" not found", path),
65                NotFoundKind::Vector => write!(f, "Vector byte source contains no bytes"),
66            },
67            SourceError::FilePathHasNoFileStem { ref path } => {
68                write!(
69                    f,
70                    "File path \"{}\" was invalid as the file name has no stem",
71                    path
72                )
73            }
74            SourceError::FilePathIsInvalidUTF8 => {
75                write!(f, "Given file path was not valid UTF-8")
76            }
77            SourceError::Base64Decode { .. } => {
78                write!(f, "Error occurred while decoding string from base64")
79            }
80        }
81    }
82}
83
84impl From<SourceError> for CryptoError {
85    fn from(mse: SourceError) -> Self {
86        match mse {
87            SourceError::NotFound { .. } => CryptoError::NotFound {
88                source: Box::new(mse),
89            },
90            _ => CryptoError::InternalError {
91                source: Box::new(mse),
92            },
93        }
94    }
95}
96
97pub trait HasByteSource {
98    fn byte_source(&self) -> ByteSource;
99}
100
101/// Enumerates all the different types of sources.
102/// Currently supported:
103/// - Bytes: sources that are represented as a byte array
104#[derive(Serialize, Deserialize, Debug, Clone)]
105#[serde(tag = "t", content = "c")]
106pub enum Source {
107    Byte(ByteSource),
108}
109
110/// Enumerates all the different types of byte-type sources.
111/// Currently supported:
112/// - Fs: data stored on the filesystem
113/// - Vector: data stored in a vector of bytes
114#[derive(Serialize, Deserialize, Debug, Clone)]
115#[serde(tag = "t", content = "c")]
116pub enum ByteSource {
117    Fs(FsByteSource),
118    Vector(VectorByteSource),
119}
120
121impl ByteSource {
122    /// Sets the bytes of the source to the given value
123    pub fn set(&mut self, key: &[u8]) -> Result<(), SourceError> {
124        match self {
125            ByteSource::Fs(fsbks) => fsbks.set(key),
126            ByteSource::Vector(vbks) => vbks.set(key),
127        }
128    }
129
130    /// Gets the bytes stored by the source
131    pub fn get(&self) -> Result<&[u8], SourceError> {
132        match self {
133            ByteSource::Fs(fsbks) => fsbks.get(),
134            ByteSource::Vector(vbks) => vbks.get(),
135        }
136    }
137}
138
139impl From<&[u8]> for ByteSource {
140    fn from(value: &[u8]) -> Self {
141        ByteSource::Vector(value.into())
142    }
143}
144
145impl From<&str> for ByteSource {
146    fn from(value: &str) -> Self {
147        ByteSource::Vector(value.into())
148    }
149}
150
151/// Represents a valid path
152#[derive(Serialize, Deserialize, Debug, Clone)]
153pub struct Path {
154    path: StdPathBuf,
155    stem: String,
156}
157
158impl Path {
159    pub fn file_stem(&self) -> &str {
160        &self.stem
161    }
162}
163
164impl<'a> From<&'a Path> for &'a StdPathBuf {
165    fn from(path: &'a Path) -> Self {
166        &path.path
167    }
168}
169
170impl FromStr for Path {
171    type Err = SourceError;
172
173    fn from_str(path: &str) -> Result<Self, Self::Err> {
174        let path: StdPathBuf = path.into();
175        let stem = path
176            .file_stem()
177            .ok_or(SourceError::FilePathHasNoFileStem {
178                path: path
179                    .clone()
180                    .into_os_string()
181                    .into_string()
182                    .unwrap_or_else(|_| "<Invalid UTF8>".to_owned()),
183            })?
184            .to_str()
185            .ok_or(SourceError::FilePathIsInvalidUTF8)?
186            .to_owned();
187
188        Ok(Self { path, stem })
189    }
190}
191
192/// A source that is a path to a file on the filesystem. The contents
193/// of the file are cached on the first call to get(), and can be refreshed
194/// by calling the reload() method.
195#[derive(Serialize, Deserialize, Debug, Clone)]
196pub struct FsByteSource {
197    path: Path,
198    #[serde(skip)]
199    cached: OnceCell<VectorByteSource>,
200}
201
202impl FromStr for FsByteSource {
203    type Err = SourceError;
204
205    fn from_str(s: &str) -> Result<Self, Self::Err> {
206        let path = Path::from_str(s)?;
207        Ok(FsByteSource::new(path))
208    }
209}
210
211impl FsByteSource {
212    /// Creates an `FsBytesSource` from a path on the filesystem
213    pub fn new(path: Path) -> Self {
214        let cached = OnceCell::new();
215        FsByteSource { path, cached }
216    }
217
218    /// Reads a `VectorBytesSource` from a path on the filesystem
219    fn read_from_path(path: &Path) -> Result<VectorByteSource, SourceError> {
220        let path_ref: &StdPathBuf = path.into();
221        let path_str = path
222            .path
223            .clone()
224            .into_os_string()
225            .into_string()
226            .unwrap_or_else(|_| "<Invalid UTF8>".to_owned());
227
228        // Mock this
229        let read_bytes = std::fs::read(path_ref).map_err(|e| match e.kind() {
230            ErrorKind::NotFound => SourceError::NotFound {
231                kind: NotFoundKind::File(path_str),
232            },
233            _ => SourceError::FsIoError { source: e },
234        })?;
235        let bytes =
236            base64::decode(read_bytes).map_err(|e| SourceError::Base64Decode { source: e })?;
237        Ok(VectorByteSource { value: Some(bytes) })
238    }
239
240    /// Empties the cache, triggering a reload of the file on the next
241    /// call to get. Note that this function does not perform any file
242    /// I/O.
243    pub fn reload(&mut self) {
244        self.cached.take();
245    }
246
247    /// Re-writes the file at the path to the given bytes
248    pub fn set(&mut self, value: &[u8]) -> Result<(), SourceError> {
249        let path_ref: &StdPathBuf = (&self.path).into();
250        let path_str = self
251            .path
252            .path
253            .clone()
254            .into_os_string()
255            .into_string()
256            .unwrap_or_else(|_| "<Invalid UTF8>".to_owned());
257        let path_parent = path_ref.parent();
258        let bytes = base64::encode(value);
259
260        // If the path contains parent directories, try to create the chain of
261        // directories first before making the file
262        if let Some(path) = path_parent {
263            std::fs::create_dir_all(path).map_err(|source| SourceError::FsIoError { source })?;
264        }
265
266        // Write the given bytes to the file at the given path
267        std::fs::write(path_ref, bytes)
268            .map(|_| self.reload())
269            .map_err(|source| match source.kind() {
270                std::io::ErrorKind::NotFound => SourceError::NotFound {
271                    kind: NotFoundKind::File(path_str),
272                },
273                _ => SourceError::FsIoError { source },
274            })
275    }
276
277    /// Returns the bytes stored at the path
278    pub fn get(&self) -> Result<&[u8], SourceError> {
279        self.cached
280            .get_or_try_init(|| Self::read_from_path(&self.path))?
281            .get()
282    }
283
284    /// Returns the path where the key is stored
285    pub fn path(&self) -> &Path {
286        &self.path
287    }
288}
289
290/// A source that is an array of bytes in memory
291#[derive(Serialize, Deserialize, Debug, Clone)]
292pub struct VectorByteSource {
293    #[serde(
294        serialize_with = "byte_vector_serialize",
295        deserialize_with = "byte_vector_deserialize"
296    )]
297    value: Option<Vec<u8>>,
298}
299
300/// Custom serialization function base64-encodes the bytes before storage
301fn byte_vector_serialize<S>(bytes: &Option<Vec<u8>>, s: S) -> Result<S::Ok, S::Error>
302where
303    S: Serializer,
304{
305    match bytes {
306        Some(bytes) => {
307            let b64_encoded = base64::encode(bytes);
308            s.serialize_some(&Some(b64_encoded))
309        }
310        None => s.serialize_none(),
311    }
312}
313
314/// Custom deserialization function base64-decodes the bytes before passing them back
315fn byte_vector_deserialize<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
316where
317    D: Deserializer<'de>,
318{
319    let b64_encoded: Option<String> = de::Deserialize::deserialize(deserializer)?;
320    match b64_encoded {
321        Some(bytes) => Ok(Some(base64::decode(bytes).map_err(de::Error::custom)?)),
322        None => Ok(None),
323    }
324}
325
326impl VectorByteSource {
327    /// Creates a new `VectorBytesSource` from the given byte array
328    pub fn new(value: Option<&[u8]>) -> Self {
329        match value {
330            Some(value) => VectorByteSource {
331                value: Some(value.to_vec()),
332            },
333            None => VectorByteSource { value: None },
334        }
335    }
336
337    /// Re-writes the source to the given bytes
338    pub fn set(&mut self, key: &[u8]) -> Result<(), SourceError> {
339        self.value = Some(key.to_owned());
340        Ok(())
341    }
342
343    /// Returns the stored bytes
344    pub fn get(&self) -> Result<&[u8], SourceError> {
345        match self.value {
346            Some(ref bytes) => Ok(bytes.as_ref()),
347            None => Err(SourceError::NotFound {
348                kind: NotFoundKind::Vector,
349            }),
350        }
351    }
352}
353
354impl From<&[u8]> for VectorByteSource {
355    fn from(value: &[u8]) -> Self {
356        Self::new(Some(value))
357    }
358}
359
360impl From<&str> for VectorByteSource {
361    fn from(value: &str) -> Self {
362        Self::new(Some(value.as_ref()))
363    }
364}