mod format;
mod sign;
use crate::{
crypto::SigningKey,
header::{FieldBody, FieldName, HeaderField, HeaderFields},
message_hash::{BodyHasher, BodyHasherBuilder, BodyHasherStance},
parse,
signature::{
Canonicalization, CanonicalizationAlgorithm, DkimSignature, DomainName, Identity, Selector,
SigningAlgorithm, DKIM_SIGNATURE_NAME,
},
signer::format::LINE_WIDTH,
tag_list,
};
use std::{
cmp::Ordering,
collections::HashSet,
error::Error,
fmt::{self, Display, Formatter},
num::{NonZeroUsize, TryFromIntError},
time::Duration,
};
/// Limit on how much message body content is fed to the body hash.
///
/// Only `Exact` yields an explicit cap for the body hasher; `NoLimit` and
/// `MessageContent` both map to `None` in `to_usize`. How those two differ in
/// the produced signature is decided outside this file (presumably
/// `MessageContent` records the body length in an `l=` tag — TODO confirm).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum BodyLength {
    /// No explicit cap on hashed body content. This is the default.
    #[default]
    NoLimit,
    /// No explicit cap on hashed body content (see the type-level note).
    MessageContent,
    /// Cap hashed body content at this length.
    Exact(u64),
}

impl BodyLength {
    /// Converts the limit into the optional `usize` cap the body hasher takes.
    ///
    /// # Errors
    ///
    /// Returns `TryFromIntError` when an `Exact` length does not fit in `usize`
    /// on the current platform.
    fn to_usize(self) -> Result<Option<usize>, TryFromIntError> {
        if let Self::Exact(limit) = self {
            usize::try_from(limit).map(Some)
        } else {
            Ok(None)
        }
    }
}
/// Which signature timestamp a request should carry.
///
/// `validate_timestamps` compares `Exact` values directly against
/// `Expiration::Exact` values, so both are on the same scale (presumably Unix
/// seconds — TODO confirm).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum Timestamp {
    /// Do not include a timestamp.
    None,
    /// Use the current time (presumably taken at signing time). This is the default.
    #[default]
    Now,
    /// Use exactly this timestamp value.
    Exact(u64),
}
/// Which signature expiration a request should carry.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum Expiration {
    /// The signature does not expire. This is the default.
    #[default]
    Never,
    /// Expire after this duration (presumably relative to the signature
    /// timestamp — TODO confirm). `validate_timestamps` rejects durations whose
    /// whole-second part is zero.
    After(Duration),
    /// Expire at exactly this value, on the same scale as `Timestamp::Exact`.
    Exact(u64),
}
/// How the set of header fields to sign is chosen.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub enum HeaderSelection {
    /// Let the signer choose the headers. This is the default.
    #[default]
    Auto,
    /// Sign exactly these header names. `validate_request` requires the list to
    /// include `From` and rejects any name containing `;`.
    Manual(Vec<FieldName>),
}
/// Yields the names of the fields in `headers` for which `pred` returns `true`,
/// walking the fields from last to first.
///
/// Repeated field names are yielded once per occurrence. The bottom-up order
/// presumably mirrors DKIM's selection of repeated fields (RFC 6376 §5.4.2).
pub fn select_headers<'a, 'b: 'a>(
    headers: &'a HeaderFields,
    mut pred: impl FnMut(&FieldName) -> bool + 'b,
) -> impl DoubleEndedIterator<Item = &FieldName> + 'a {
    headers
        .as_ref()
        .iter()
        .rev()
        .map(|(field_name, _body)| field_name)
        .filter(move |&field_name| pred(field_name))
}
/// Returns the default list of header names to sign, starting with `From`.
pub fn default_signed_headers() -> Vec<FieldName> {
    const SIGNED: &[&str] = &[
        "From",
        "Reply-To",
        "Subject",
        "Date",
        "To",
        "Cc",
        "Resent-Date",
        "Resent-From",
        "Resent-To",
        "Resent-Cc",
        "In-Reply-To",
        "References",
        "List-Id",
        "List-Help",
        "List-Unsubscribe",
        "List-Subscribe",
        "List-Post",
        "List-Owner",
        "List-Archive",
    ];
    // Every entry is a constant, well-formed field name, so construction cannot fail.
    SIGNED
        .iter()
        .map(|&name| FieldName::new(name).unwrap())
        .collect()
}
/// Returns the default list of header names that are left unsigned.
pub fn default_unsigned_headers() -> Vec<FieldName> {
    const UNSIGNED: &[&str] = &["Return-Path", "Received", "Comments", "Keywords"];
    // Every entry is a constant, well-formed field name, so construction cannot fail.
    UNSIGNED
        .iter()
        .map(|&name| FieldName::new(name).unwrap())
        .collect()
}
/// Reasons a sign request (or a batch of them) is rejected before signing starts.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RequestError {
    /// A value does not fit the integer type it must be converted to
    /// (e.g. an exact body length that exceeds `usize`).
    Overflow,
    /// The message headers contain no `From` field.
    MissingFromHeader,
    /// More sign requests were supplied than a `Signer` accepts.
    TooManyRequests,
    /// No sign requests were supplied.
    EmptyRequests,
    /// The signing key's type does not match the signing algorithm's key type.
    IncompatibleKeyType,
    /// A manual header selection does not include `From`.
    FromHeaderNotSigned,
    /// A manually selected header name contains `;`.
    InvalidSignedFieldName,
    /// The identity's domain is neither the signing domain nor a subdomain of it.
    DomainMismatch,
    /// A relative expiration has a zero whole-second duration.
    ZeroExpirationDuration,
    /// The expiration cannot come after the timestamp.
    ExpirationNotAfterTimestamp,
    /// Extension tags are duplicated, malformed, or collide with generated tags.
    InvalidExtTags,
    /// The output header name is not `DKIM-Signature` (case-insensitively).
    InvalidDkimSignatureHeaderName,
    /// The output indentation is empty or contains non-whitespace characters.
    InvalidIndentationWhitespace,
}

impl Display for RequestError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::Overflow => "integer too large",
            Self::MissingFromHeader => "no From header",
            Self::TooManyRequests => "too many sign requests",
            Self::EmptyRequests => "no sign requests",
            Self::IncompatibleKeyType => "incompatible key type",
            Self::FromHeaderNotSigned => "From header not signed",
            Self::InvalidSignedFieldName => "invalid signed header name",
            Self::DomainMismatch => "domain mismatch",
            Self::ZeroExpirationDuration => "zero expiration duration",
            Self::ExpirationNotAfterTimestamp => "expiration not after timestamp",
            Self::InvalidExtTags => "invalid extension tags",
            Self::InvalidDkimSignatureHeaderName => "invalid DKIM-Signature header name",
            Self::InvalidIndentationWhitespace => "invalid indentation whitespace",
        };
        f.write_str(description)
    }
}

impl Error for RequestError {}
/// Controls how the signature header is rendered.
pub struct OutputFormat {
    /// Name of the emitted header. `validate_request` requires it to equal
    /// `DKIM-Signature` ignoring ASCII case.
    pub header_name: String,
    /// Target line width for the rendered header (presumably used when folding
    /// long values — TODO confirm).
    pub line_width: NonZeroUsize,
    /// Whitespace prefix (presumably for continuation lines); must be non-empty
    /// and contain only WSP characters.
    pub indentation: String,
    /// Optional comparator over two strings (presumably tag names) that decides
    /// tag order; `None` keeps the formatter's own order.
    pub tag_order: Option<Box<dyn Fn(&str, &str) -> Ordering + Send + Sync>>,
    /// Output flag consumed by the formatter (presumably restricting output to
    /// ASCII — TODO confirm).
    pub ascii_only: bool,
}
impl Default for OutputFormat {
fn default() -> Self {
Self {
header_name: DKIM_SIGNATURE_NAME.into(),
line_width: LINE_WIDTH.try_into().unwrap(),
indentation: "\t".into(),
tag_order: None,
ascii_only: false,
}
}
}
/// Debug stand-in for a boxed closure, which has no `Debug` impl of its own.
struct ClosureDebug;

impl fmt::Debug for ClosureDebug {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("<closure>")
    }
}
impl fmt::Debug for OutputFormat {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            header_name,
            line_width,
            indentation,
            tag_order,
            ascii_only,
        } = self;
        // The comparator closure cannot be printed; show only whether one is set.
        let tag_order = tag_order.as_ref().map(|_| ClosureDebug);
        f.debug_struct("OutputFormat")
            .field("header_name", header_name)
            .field("line_width", line_width)
            .field("indentation", indentation)
            .field("tag_order", &tag_order)
            .field("ascii_only", ascii_only)
            .finish()
    }
}
/// Everything needed to produce one DKIM signature.
#[derive(Debug)]
pub struct SignRequest<T> {
    /// Key used for signing; its key type must match `algorithm`'s key type.
    pub signing_key: T,
    /// Signing algorithm; also determines the body hash algorithm.
    pub algorithm: SigningAlgorithm,
    /// Canonicalization settings; the `body` part configures body hashing.
    pub canonicalization: Canonicalization,
    /// Signing domain.
    pub domain: DomainName,
    /// Which header fields to sign.
    pub header_selection: HeaderSelection,
    /// Optional identity, whose domain must equal `domain` or be a subdomain of it.
    pub identity: Option<Identity>,
    /// Limit on hashed body content.
    pub body_length: BodyLength,
    /// Selector for the signing key.
    pub selector: Selector,
    /// Timestamp to record.
    pub timestamp: Timestamp,
    /// Expiration to record; must be consistent with `timestamp`.
    pub expiration: Expiration,
    /// Flag consumed by the signing step (presumably enabling a copy of the
    /// signed headers, DKIM `z=` — TODO confirm).
    pub copy_headers: bool,
    /// Extra `(name, value)` tags. Names must be unique, valid tag names, and
    /// not tags the formatter emits itself; values must be valid tag values.
    pub ext_tags: Vec<(String, String)>,
    /// Rendering options for the output header.
    pub format: OutputFormat,
}
impl<T> SignRequest<T> {
pub fn new(
domain: DomainName,
selector: Selector,
algorithm: SigningAlgorithm,
signing_key: T,
) -> Self {
use CanonicalizationAlgorithm::*;
let canonicalization = Canonicalization::from((Relaxed, Simple));
let five_days = Duration::from_secs(60 * 60 * 24 * 5);
Self {
signing_key,
algorithm,
canonicalization,
domain,
header_selection: Default::default(),
identity: None,
body_length: BodyLength::NoLimit,
selector,
timestamp: Timestamp::Now,
expiration: Expiration::After(five_days),
copy_headers: false,
ext_tags: vec![],
format: Default::default(),
}
}
}
fn validate_request<T: AsRef<SigningKey>>(request: &SignRequest<T>) -> Result<(), RequestError> {
if request.signing_key.as_ref().key_type() != request.algorithm.key_type() {
return Err(RequestError::IncompatibleKeyType);
}
if let HeaderSelection::Manual(signed_headers) = &request.header_selection {
if !signed_headers.iter().any(|name| *name == "From") {
return Err(RequestError::FromHeaderNotSigned);
}
if signed_headers.iter().any(|name| name.as_ref().contains(';')) {
return Err(RequestError::InvalidSignedFieldName);
}
}
if let Some(identity) = &request.identity {
if !identity.domain.eq_or_subdomain_of(&request.domain) {
return Err(RequestError::DomainMismatch);
}
}
validate_timestamps(request.timestamp, request.expiration)?;
let mut tags_seen = HashSet::new();
if request.ext_tags.iter().any(|(name, value)| {
!tags_seen.insert(name)
|| !tag_list::is_tag_name(name)
|| !tag_list::is_tag_value(value)
|| format::is_output_tag(name)
}) {
return Err(RequestError::InvalidExtTags);
}
if !request.format.header_name.eq_ignore_ascii_case(DKIM_SIGNATURE_NAME) {
return Err(RequestError::InvalidDkimSignatureHeaderName);
}
let indent = &request.format.indentation;
if indent.is_empty() || indent.chars().any(|c| !parse::is_wsp(c)) {
return Err(RequestError::InvalidIndentationWhitespace);
}
Ok(())
}
/// Rejects timestamp/expiration pairs that are invalid regardless of the current time.
///
/// A zero-second `After` duration is rejected first, counting whole seconds
/// only. Then "expiration not after timestamp" is reported for:
/// - an exact expiration at or before an exact timestamp;
/// - `Now` with exact expiration 0;
/// - an `Exact(u64::MAX)` timestamp with a relative expiration, since no later
///   `u64` value exists.
fn validate_timestamps(timestamp: Timestamp, expiration: Expiration) -> Result<(), RequestError> {
    if matches!(expiration, Expiration::After(d) if d.as_secs() == 0) {
        return Err(RequestError::ZeroExpirationDuration);
    }

    let expires_too_early = match (timestamp, expiration) {
        (Timestamp::Exact(t), Expiration::Exact(x)) => t >= x,
        (Timestamp::Now, Expiration::Exact(x)) => x == 0,
        (Timestamp::Exact(t), Expiration::After(_)) => t == u64::MAX,
        _ => false,
    };

    if expires_too_early {
        Err(RequestError::ExpirationNotAfterTimestamp)
    } else {
        Ok(())
    }
}
/// Outcome of one sign request: the rendered signature, or why signing failed.
pub type SigningResult = Result<SigningOutput, SigningError>;
/// Reasons an individual, already validated sign request fails during signing.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SigningError {
    /// An integer computation or conversion did not fit (presumably e.g. a
    /// computed expiration — TODO confirm against the signing step).
    Overflow,
    /// The message body did not provide enough content (presumably for a
    /// requested exact body length — TODO confirm).
    InsufficientContent,
    /// Producing the cryptographic signature failed.
    SigningFailure,
}

impl Display for SigningError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::Overflow => "integer too large",
            Self::InsufficientContent => "not enough message body content",
            Self::SigningFailure => "signing failed",
        };
        f.write_str(description)
    }
}

impl Error for SigningError {}
/// Display adapter that renders a header as `name`, a colon, then `value`,
/// with nothing inserted between them.
struct SigningOutputHeaderDisplay<'a> {
    name: &'a str,
    value: &'a str,
}

impl Display for SigningOutputHeaderDisplay<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(self.name)?;
        f.write_str(":")?;
        f.write_str(self.value)
    }
}
/// A successfully produced signature header.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct SigningOutput {
    /// Header field name (presumably taken from the request's output format — TODO confirm).
    pub header_name: String,
    /// Rendered header field value, i.e. everything after the colon.
    pub header_value: String,
    /// The signature in structured form.
    pub signature: DkimSignature,
}
impl SigningOutput {
    /// Returns a displayable `name:value` rendering of the header.
    ///
    /// No space or line ending is added; `header_value` is written verbatim
    /// after the colon.
    pub fn format_header(&self) -> impl Display + '_ {
        SigningOutputHeaderDisplay {
            name: &self.header_name,
            value: &self.header_value,
        }
    }

    /// Converts this output into a `HeaderField` (name, body) pair.
    ///
    /// # Panics
    ///
    /// Panics if `header_name` is not a valid field name or `header_value` is
    /// not a valid field body. Outputs produced by the signer are expected to
    /// satisfy both. Request validation already requires the name to equal
    /// `DKIM-Signature` ignoring ASCII case. Because the fields are public, a
    /// hand-built or modified output can still violate this.
    pub fn to_header_field(&self) -> HeaderField {
        let name = FieldName::new(self.header_name.as_str())
            .expect("signing output header name must be a valid field name");
        let body = FieldBody::new(self.header_value.as_bytes())
            .expect("signing output header value must be a valid field body");
        (name, body)
    }
}
/// A validated sign request waiting for the body hash to complete.
struct SignerTask<T> {
    /// The request as supplied to `Signer::prepare_signing`.
    request: SignRequest<T>,
}
/// Signs one message for several requests.
///
/// Requests are validated up front, the body is fed in chunks, and the
/// signatures are produced in a final step.
pub struct Signer<T> {
    /// One task per accepted request, in the order the requests were supplied.
    tasks: Vec<SignerTask<T>>,
    /// The message's header fields.
    headers: HeaderFields,
    /// Body hasher configured for every request's body hash settings.
    body_hasher: BodyHasher,
}
impl<T> Signer<T>
where
    T: AsRef<SigningKey>,
{
    /// Largest number of sign requests accepted by `prepare_signing`.
    const MAX_REQUESTS: usize = 10;

    /// Validates `requests` against `headers` and sets up body hashing for them.
    ///
    /// # Errors
    ///
    /// - `MissingFromHeader` if `headers` has no `From` field.
    /// - `TooManyRequests` once more than the request limit is supplied.
    /// - Any error from per-request validation.
    /// - `Overflow` if a body length does not fit `usize`.
    /// - `EmptyRequests` if no requests were supplied.
    pub fn prepare_signing<I>(headers: HeaderFields, requests: I) -> Result<Self, RequestError>
    where
        I: IntoIterator<Item = SignRequest<T>>,
    {
        let has_from = headers.as_ref().iter().any(|(name, _)| *name == "From");
        if !has_from {
            return Err(RequestError::MissingFromHeader);
        }

        let mut tasks = Vec::new();
        let mut builder = BodyHasherBuilder::new(false);

        for (index, request) in requests.into_iter().enumerate() {
            if index >= Self::MAX_REQUESTS {
                return Err(RequestError::TooManyRequests);
            }
            validate_request(&request)?;

            let length_limit = request
                .body_length
                .to_usize()
                .map_err(|_| RequestError::Overflow)?;
            builder.register_canonicalization(
                length_limit,
                request.algorithm.hash_algorithm(),
                request.canonicalization.body,
            );

            tasks.push(SignerTask { request });
        }

        if tasks.is_empty() {
            return Err(RequestError::EmptyRequests);
        }

        Ok(Self {
            tasks,
            headers,
            body_hasher: builder.build(),
        })
    }

    /// Feeds one chunk of the message body to the body hasher and returns the
    /// hasher's stance (presumably whether more body input is wanted — TODO confirm).
    pub fn process_body_chunk(&mut self, chunk: &[u8]) -> BodyHasherStance {
        self.body_hasher.hash_chunk(chunk)
    }

    /// Finishes body hashing and signs every request.
    ///
    /// Requests are signed one after another; the result vector holds one entry
    /// per request, in the order the requests were supplied.
    pub async fn sign(self) -> Vec<SigningResult> {
        let Self {
            tasks,
            headers,
            body_hasher,
        } = self;
        let body_hashes = body_hasher.finish();

        let mut results = Vec::with_capacity(tasks.len());
        for SignerTask { request } in tasks {
            results.push(sign::perform_signing(request, &headers, &body_hashes).await);
        }
        results
    }
}
/// Signs a complete message in one call.
///
/// Equivalent to `Signer::prepare_signing`, then a single `process_body_chunk`
/// with the whole body, then `Signer::sign`.
///
/// # Errors
///
/// Returns the `RequestError` from `Signer::prepare_signing`. Per-request
/// signing failures are reported inside the returned vector instead.
pub async fn sign<I, T>(
    header: HeaderFields,
    body: &[u8],
    requests: I,
) -> Result<Vec<SigningResult>, RequestError>
where
    I: IntoIterator<Item = SignRequest<T>>,
    T: AsRef<SigningKey>,
{
    let mut signer = Signer::prepare_signing(header, requests)?;
    // The entire body is supplied at once, so the hasher's stance is not needed.
    let _ = signer.process_body_chunk(body);
    let results = signer.sign().await;
    Ok(results)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::header::FieldBody;
    use std::collections::HashSet;

    #[test]
    fn select_headers_ok() {
        let headers = make_header_fields(["From", "Aa", "Bb", "Aa", "Dd"]);
        let names = make_field_names(["from", "aa", "bb", "cc"]);
        let selection = select_headers(&headers, move |name| names.contains(name));
        assert!(selection.map(|n| n.as_ref()).eq(["Aa", "Bb", "Aa", "From"]));
    }

    #[test]
    fn select_headers_no_match_and_order() {
        let headers = make_header_fields(["From", "Aa"]);
        // A predicate that rejects everything yields nothing.
        assert_eq!(select_headers(&headers, |_| false).count(), 0);
        // Selection walks bottom-up, so the last field comes first.
        let first = select_headers(&headers, |_| true).next().unwrap();
        assert!(*first == "Aa");
    }

    fn make_header_fields(names: impl IntoIterator<Item = &'static str>) -> HeaderFields {
        let names: Vec<_> = names
            .into_iter()
            .map(|name| (FieldName::new(name).unwrap(), FieldBody::new(*b"").unwrap()))
            .collect();
        HeaderFields::new(names).unwrap()
    }

    fn make_field_names(names: impl IntoIterator<Item = &'static str>) -> HashSet<FieldName> {
        names
            .into_iter()
            .map(|name| FieldName::new(name).unwrap())
            .collect()
    }

    #[test]
    fn default_signed_headers_include_from() {
        // A manual selection built from the defaults must pass the From check.
        assert!(default_signed_headers().iter().any(|name| *name == "From"));
    }

    #[test]
    fn default_header_lists_are_disjoint() {
        let signed = default_signed_headers();
        assert!(default_unsigned_headers()
            .iter()
            .all(|name| !signed.contains(name)));
    }

    #[test]
    fn body_length_to_usize_ok() {
        assert_eq!(BodyLength::NoLimit.to_usize(), Ok(None));
        assert_eq!(BodyLength::MessageContent.to_usize(), Ok(None));
        assert_eq!(BodyLength::Exact(42).to_usize(), Ok(Some(42)));
    }

    #[test]
    fn validate_timestamps_ok() {
        use super::{
            Expiration as X,
            RequestError::{ExpirationNotAfterTimestamp as ENAT, ZeroExpirationDuration as ZED},
            Timestamp as T,
        };
        let secs = Duration::from_secs;
        assert_eq!(validate_timestamps(T::None, X::Never), Ok(()));
        assert_eq!(validate_timestamps(T::None, X::After(secs(0))), Err(ZED));
        assert_eq!(validate_timestamps(T::None, X::After(secs(3))), Ok(()));
        assert_eq!(validate_timestamps(T::None, X::Exact(0)), Ok(()));
        assert_eq!(validate_timestamps(T::None, X::Exact(3)), Ok(()));
        assert_eq!(validate_timestamps(T::Now, X::Never), Ok(()));
        assert_eq!(validate_timestamps(T::Now, X::After(secs(0))), Err(ZED));
        assert_eq!(validate_timestamps(T::Now, X::After(secs(3))), Ok(()));
        assert_eq!(validate_timestamps(T::Now, X::Exact(0)), Err(ENAT));
        assert_eq!(validate_timestamps(T::Now, X::Exact(3)), Ok(()));
        assert_eq!(validate_timestamps(T::Exact(3), X::Never), Ok(()));
        assert_eq!(validate_timestamps(T::Exact(3), X::After(secs(0))), Err(ZED));
        assert_eq!(validate_timestamps(T::Exact(3), X::After(secs(3))), Ok(()));
        assert_eq!(validate_timestamps(T::Exact(u64::MAX), X::After(secs(3))), Err(ENAT));

        // Exact/exact pairs: the expiration must be strictly after the timestamp.
        assert_eq!(validate_timestamps(T::Exact(3), X::Exact(2)), Err(ENAT));
        assert_eq!(validate_timestamps(T::Exact(3), X::Exact(3)), Err(ENAT));
        assert_eq!(validate_timestamps(T::Exact(3), X::Exact(4)), Ok(()));
        assert_eq!(validate_timestamps(T::Exact(0), X::Exact(0)), Err(ENAT));
        assert_eq!(validate_timestamps(T::Exact(u64::MAX), X::Never), Ok(()));

        // Only whole seconds count, so a sub-second duration is treated as zero;
        // the zero-duration check also takes precedence over the overflow check.
        let millis = Duration::from_millis;
        assert_eq!(validate_timestamps(T::Now, X::After(millis(500))), Err(ZED));
        assert_eq!(validate_timestamps(T::Exact(u64::MAX), X::After(secs(0))), Err(ZED));
    }
}