use crate::{Error, Result};
use core::str;
use encoding::{Base64Writer, Encode};
#[cfg(feature = "alloc")]
use {alloc::string::String, encoding::CheckedSum};
/// An OpenSSH public key line split into its parts, all borrowed from the input:
/// `<algorithm-id> <base64-data>[ <comment>]`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct SshFormat<'a> {
    /// Algorithm identifier taken from the first segment, e.g. `ssh-ed25519`.
    pub(crate) algorithm_id: &'a str,

    /// Key data from the second segment, still Base64-encoded (not decoded here).
    pub(crate) base64_data: &'a [u8],

    /// Free-form text after the key data; may be empty.
    ///
    /// The `dead_code` allowance covers builds without the `alloc` feature, where this
    /// field is apparently never read.
    #[cfg_attr(not(feature = "alloc"), allow(dead_code))]
    pub(crate) comment: &'a str,
}
impl<'a> SshFormat<'a> {
pub(crate) fn decode(mut bytes: &'a [u8]) -> Result<Self> {
let algorithm_id = decode_segment_str(&mut bytes)?;
let base64_data = decode_segment(&mut bytes)?;
let comment = str::from_utf8(bytes)?.trim_end();
if algorithm_id.is_empty() || base64_data.is_empty() {
return Err(encoding::Error::Length.into());
}
Ok(Self {
algorithm_id,
base64_data,
comment,
})
}
pub(crate) fn encode<'o, K>(
algorithm_id: &str,
key: &K,
comment: &str,
out: &'o mut [u8],
) -> Result<&'o str>
where
K: Encode<Error = Error>,
{
let mut offset = 0;
encode_str(out, &mut offset, algorithm_id)?;
encode_str(out, &mut offset, " ")?;
let mut writer = Base64Writer::new(&mut out[offset..])?;
key.encode(&mut writer)?;
let base64_len = writer.finish()?.len();
offset = offset
.checked_add(base64_len)
.ok_or(encoding::Error::Length)?;
if !comment.is_empty() {
encode_str(out, &mut offset, " ")?;
encode_str(out, &mut offset, comment)?;
}
Ok(str::from_utf8(&out[..offset])?)
}
#[cfg(feature = "alloc")]
pub(crate) fn encode_string<K>(algorithm_id: &str, key: &K, comment: &str) -> Result<String>
where
K: Encode<Error = Error>,
{
let encoded_len = [
2, algorithm_id.len(),
base64_len_approx(key.encoded_len()?),
comment.len(),
]
.checked_sum()?;
let mut out = vec![0u8; encoded_len];
let actual_len = Self::encode(algorithm_id, key, comment, &mut out)?.len();
out.truncate(actual_len);
Ok(String::from_utf8(out)?)
}
}
/// Buffer-size bound for `input_len` bytes once Base64-encoded.
///
/// Returns `4 * ceil(input_len / 3)`, the length of padded Base64 output: every started
/// 3-byte input chunk becomes 4 output bytes. The result saturates at `usize::MAX`
/// instead of overflowing. An absurdly large `input_len` therefore yields a bound that a
/// checked size sum will reject, rather than a wrapped (too small) value in release
/// builds or an overflow panic in debug builds.
#[cfg(feature = "alloc")]
// Division and remainder use the non-zero constant 3 and cannot panic. The addition
// cannot overflow (`usize::MAX / 3 + 1`). The multiplication saturates.
#[allow(clippy::integer_arithmetic)]
fn base64_len_approx(input_len: usize) -> usize {
    let chunks = input_len / 3 + usize::from(input_len % 3 != 0);
    chunks.saturating_mul(4)
}
fn decode_segment<'a>(bytes: &mut &'a [u8]) -> Result<&'a [u8]> {
let start = *bytes;
let mut len = 0usize;
loop {
match *bytes {
[b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'-' | b'/' | b'=' | b'@' | b'.', rest @ ..] =>
{
*bytes = rest;
len = len.checked_add(1).ok_or(encoding::Error::Length)?;
}
[b' ', rest @ ..] => {
*bytes = rest;
return start
.get(..len)
.ok_or_else(|| encoding::Error::Length.into());
}
[_, ..] => {
return Err(encoding::Error::CharacterEncoding.into());
}
[] => {
return start
.get(..len)
.ok_or_else(|| encoding::Error::Length.into());
}
}
}
}
fn decode_segment_str<'a>(bytes: &mut &'a [u8]) -> Result<&'a str> {
str::from_utf8(decode_segment(bytes)?).map_err(|_| encoding::Error::CharacterEncoding.into())
}
/// Copy the bytes of `s` into `out` starting at `*offset`, then advance `*offset` past
/// them.
///
/// # Errors
///
/// Returns `Length` if the new offset would overflow `usize` or if `s` does not fit in
/// the rest of `out`. Neither `out` nor `*offset` is modified in that case.
fn encode_str(out: &mut [u8], offset: &mut usize, s: &str) -> Result<()> {
    let src = s.as_bytes();
    let end = offset
        .checked_add(src.len())
        .ok_or(encoding::Error::Length)?;
    let dst = out
        .get_mut(*offset..end)
        .ok_or(encoding::Error::Length)?;
    dst.copy_from_slice(src);
    *offset = end;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::SshFormat;

    const EXAMPLE_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILM+rvN+ot98qgEN796jTiQfZfG1KaT0PtFDJ/XFSqti user@example.com";

    /// The Base64 segment of [`EXAMPLE_KEY`].
    const EXAMPLE_BASE64: &[u8] =
        b"AAAAC3NzaC1lZDI1NTE5AAAAILM+rvN+ot98qgEN796jTiQfZfG1KaT0PtFDJ/XFSqti";

    #[test]
    fn decode() {
        let encapsulation = SshFormat::decode(EXAMPLE_KEY.as_bytes()).unwrap();
        assert_eq!(encapsulation.algorithm_id, "ssh-ed25519");
        assert_eq!(encapsulation.base64_data, EXAMPLE_BASE64);
        assert_eq!(encapsulation.comment, "user@example.com");
    }

    #[test]
    fn decode_without_comment() {
        let line = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILM+rvN+ot98qgEN796jTiQfZfG1KaT0PtFDJ/XFSqti";
        let parsed = SshFormat::decode(line.as_bytes()).unwrap();
        assert_eq!(parsed.algorithm_id, "ssh-ed25519");
        assert_eq!(parsed.base64_data, EXAMPLE_BASE64);
        assert_eq!(parsed.comment, "");
    }

    #[test]
    fn decode_trims_trailing_whitespace_from_comment() {
        let line = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILM+rvN+ot98qgEN796jTiQfZfG1KaT0PtFDJ/XFSqti user@example.com\r\n";
        let parsed = SshFormat::decode(line.as_bytes()).unwrap();
        assert_eq!(parsed.base64_data, EXAMPLE_BASE64);
        assert_eq!(parsed.comment, "user@example.com");
    }

    #[test]
    fn decode_keeps_spaces_inside_comment() {
        let line = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILM+rvN+ot98qgEN796jTiQfZfG1KaT0PtFDJ/XFSqti two words";
        let parsed = SshFormat::decode(line.as_bytes()).unwrap();
        assert_eq!(parsed.comment, "two words");
    }

    #[test]
    fn decode_rejects_missing_segments() {
        assert!(SshFormat::decode(b"").is_err());
        assert!(SshFormat::decode(b"ssh-ed25519").is_err());
        assert!(SshFormat::decode(b"ssh-ed25519 ").is_err());
    }

    #[test]
    fn decode_rejects_bytes_outside_segment_alphabet() {
        assert!(SshFormat::decode(b"ssh_ed25519 AAAA").is_err());
        assert!(SshFormat::decode(b"ssh-ed25519 AAAA\tuser").is_err());
    }
}