1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
//! Base64 encoding and decoding support.
use crate::cvt_n;
use crate::error::ErrorStack;
use libc::c_int;

/// Encodes a slice of bytes to a base64 string.
///
/// The output uses `=` padding, so its length is always a multiple of 4.
///
/// This corresponds to [`EVP_EncodeBlock`].
///
/// # Panics
///
/// Panics if the input length or computed output length overflow a signed C integer.
///
/// [`EVP_EncodeBlock`]: https://www.openssl.org/docs/man1.1.1/man3/EVP_EncodeBlock.html
pub fn encode_block(src: &[u8]) -> String {
    assert!(src.len() <= c_int::MAX as usize);
    let src_len = src.len() as c_int;

    let len = encoded_len(src_len).expect("base64 encoded length overflows c_int");
    let mut out = Vec::with_capacity(len as usize);

    // SAFETY: `encoded_len` ensures space for 4 output characters
    // for every 3 input bytes including padding and nul terminator.
    // `EVP_EncodeBlock` will write only single byte ASCII characters.
    // `EVP_EncodeBlock` will only write to not read from `out`.
    unsafe {
        let out_len = ffi::EVP_EncodeBlock(out.as_mut_ptr(), src.as_ptr(), src_len);
        // The returned length excludes the NUL terminator, so the terminator
        // stays outside the initialized region and is not part of the string.
        out.set_len(out_len as usize);
        String::from_utf8_unchecked(out)
    }
}

/// Decodes a base64-encoded string to bytes.
///
/// Leading and trailing whitespace is ignored. An input that is empty after
/// trimming decodes to an empty vector.
///
/// This corresponds to [`EVP_DecodeBlock`].
///
/// # Errors
///
/// Returns the OpenSSL error stack if `EVP_DecodeBlock` rejects the input. For
/// example, this happens when the trimmed length is not a multiple of 4 or the
/// input contains characters outside the base64 alphabet.
///
/// # Panics
///
/// Panics if the trimmed input length overflows a signed C integer.
///
/// [`EVP_DecodeBlock`]: https://www.openssl.org/docs/man1.1.1/man3/EVP_DecodeBlock.html
pub fn decode_block(src: &str) -> Result<Vec<u8>, ErrorStack> {
    let src = src.trim();

    // Empty input is answered here rather than handed to `EVP_DecodeBlock`;
    // see https://github.com/openssl/openssl/issues/12143
    if src.is_empty() {
        return Ok(vec![]);
    }

    assert!(src.len() <= c_int::MAX as usize);
    let src_len = src.len() as c_int;

    // For a non-negative length this is at most 3/4 of the input plus 3, so it
    // cannot overflow.
    let len = decoded_len(src_len).expect("decoded length cannot overflow c_int");
    let mut out = Vec::with_capacity(len as usize);

    // SAFETY: `decoded_len` ensures space for 3 output bytes
    // for every 4 input characters including padding.
    // `EVP_DecodeBlock` can write fewer bytes after stripping
    // leading and trailing whitespace, but never more.
    // `EVP_DecodeBlock` will only write to not read from `out`.
    unsafe {
        let out_len = cvt_n(ffi::EVP_DecodeBlock(
            out.as_mut_ptr(),
            src.as_ptr(),
            src_len,
        ))?;
        out.set_len(out_len as usize);
    }

    // `EVP_DecodeBlock` always emits 3 bytes per 4 input characters, padding
    // with zero bits. Drop the extra byte that each trailing `=` (at most two)
    // produced.
    if src.ends_with('=') {
        out.pop();
        if src.ends_with("==") {
            out.pop();
        }
    }

    Ok(out)
}

/// Buffer size needed by `EVP_EncodeBlock` for `src_len` input bytes.
///
/// Each started group of three input bytes needs four output characters, and
/// one more byte is reserved for the trailing NUL. Returns `None` when that
/// size does not fit in a `c_int`.
fn encoded_len(src_len: c_int) -> Option<c_int> {
    let whole_groups = (src_len / 3).checked_mul(4)?;
    let padded = match src_len % 3 {
        0 => whole_groups,
        _ => whole_groups.checked_add(4)?,
    };
    padded.checked_add(1)
}

/// Upper bound on the bytes `EVP_DecodeBlock` writes for `src_len` input
/// characters.
///
/// Each started group of four input characters yields three output bytes.
/// Returns `None` when that size does not fit in a `c_int`.
fn decoded_len(src_len: c_int) -> Option<c_int> {
    (src_len / 4).checked_mul(3).and_then(|whole| {
        if src_len % 4 == 0 {
            Some(whole)
        } else {
            whole.checked_add(3)
        }
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_block() {
        assert_eq!("".to_string(), encode_block(b""));
        assert_eq!("Zg==".to_string(), encode_block(b"f"));
        assert_eq!("Zm8=".to_string(), encode_block(b"fo"));
        assert_eq!("Zm9v".to_string(), encode_block(b"foo"));
        assert_eq!("Zm9vYg==".to_string(), encode_block(b"foob"));
        assert_eq!("Zm9vYmE=".to_string(), encode_block(b"fooba"));
        assert_eq!("Zm9vYmFy".to_string(), encode_block(b"foobar"));
    }

    #[test]
    fn test_decode_block() {
        assert_eq!(b"".to_vec(), decode_block("").unwrap());
        assert_eq!(b"f".to_vec(), decode_block("Zg==").unwrap());
        assert_eq!(b"fo".to_vec(), decode_block("Zm8=").unwrap());
        assert_eq!(b"foo".to_vec(), decode_block("Zm9v").unwrap());
        assert_eq!(b"foob".to_vec(), decode_block("Zm9vYg==").unwrap());
        assert_eq!(b"fooba".to_vec(), decode_block("Zm9vYmE=").unwrap());
        assert_eq!(b"foobar".to_vec(), decode_block("Zm9vYmFy").unwrap());
    }

    #[test]
    fn test_strip_whitespace() {
        assert_eq!(b"foobar".to_vec(), decode_block(" Zm9vYmFy\n").unwrap());
        assert_eq!(b"foob".to_vec(), decode_block(" Zm9vYg==\n").unwrap());
        // Whitespace-only input takes the empty-input early return.
        assert_eq!(b"".to_vec(), decode_block(" \n\t ").unwrap());
    }

    #[test]
    fn test_decode_invalid() {
        // Trimmed length is not a multiple of 4.
        assert!(decode_block("Zg=").is_err());
        // Characters outside the base64 alphabet.
        assert!(decode_block("Zm9v!!!!").is_err());
    }

    #[test]
    fn test_round_trip() {
        // Every prefix length exercises the no-padding, one-`=` and two-`=` cases,
        // including inputs whose final bytes are zero.
        let data: Vec<u8> = (0..=255).collect();
        for len in 0..=data.len() {
            let input = &data[..len];
            assert_eq!(input.to_vec(), decode_block(&encode_block(input)).unwrap());
        }
    }
}