use std::{fmt, str::FromStr};
use headers::authorization::Credentials;
use http::HeaderValue;
use http_auth::ChallengeParser;
use ruma_common::{
CanonicalJsonObject, IdParseError, OwnedServerName, OwnedServerSigningKeyId, ServerName,
api::auth_scheme::AuthScheme,
http_headers::quote_ascii_string_if_required,
serde::{Base64, Base64DecodeError},
};
use ruma_signatures::{Ed25519KeyPair, KeyPair, PublicKeyMap};
use thiserror::Error;
use tracing::debug;
/// Authentication scheme for server-server requests authenticated with `X-Matrix` credentials
/// in the `Authorization` header.
///
/// Adding authentication signs the request and sets the header. Extracting authentication only
/// parses the header; it does not check the signature (see [`XMatrix::verify_request`]).
#[derive(Debug, Clone, Copy, Default)]
// Fieldless marker type; presumably exhaustive on purpose (nothing to add later) — confirm.
#[allow(clippy::exhaustive_structs)]
pub struct ServerSignatures;
impl AuthScheme for ServerSignatures {
    type Input<'a> = ServerSignaturesInput<'a>;
    type AddAuthenticationError = XMatrixFromRequestError;
    type Output = XMatrix;
    type ExtractAuthenticationError = XMatrixExtractError;

    /// Signs `request` with the key pair in `input` and stores the resulting `X-Matrix`
    /// credentials in the `Authorization` header, replacing any previous value.
    fn add_authentication<T: AsRef<[u8]>>(
        request: &mut http::Request<T>,
        input: ServerSignaturesInput<'_>,
    ) -> Result<(), Self::AddAuthenticationError> {
        let credentials = XMatrix::try_from_http_request(request, input)?;
        let header_value = HeaderValue::from(&credentials);
        request.headers_mut().insert(http::header::AUTHORIZATION, header_value);
        Ok(())
    }

    /// Parses the `X-Matrix` credentials from the `Authorization` header of `request`.
    ///
    /// The signature is not checked here; use [`XMatrix::verify_request`] for that.
    fn extract_authentication<T: AsRef<[u8]>>(
        request: &http::Request<T>,
    ) -> Result<Self::Output, Self::ExtractAuthenticationError> {
        let Some(header) = request.headers().get(http::header::AUTHORIZATION) else {
            return Err(XMatrixExtractError::MissingAuthorizationHeader);
        };
        XMatrix::try_from(header).map_err(XMatrixExtractError::Parse)
    }
}
/// Data needed by [`ServerSignatures`] to sign an outgoing request.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ServerSignaturesInput<'a> {
    /// The server sending the request; written as `origin` in the signed request object.
    pub origin: OwnedServerName,
    /// The server receiving the request; written as `destination` in the signed request object.
    pub destination: OwnedServerName,
    /// The key pair whose key ID and signature end up in the `X-Matrix` credentials.
    pub key_pair: &'a Ed25519KeyPair,
}
impl<'a> ServerSignaturesInput<'a> {
pub fn new(
origin: OwnedServerName,
destination: OwnedServerName,
key_pair: &'a Ed25519KeyPair,
) -> Self {
Self { origin, destination, key_pair }
}
}
/// Credentials carried by an `X-Matrix` `Authorization` header between servers.
///
/// The `Debug` output of this type deliberately leaves out `sig`.
#[derive(Clone)]
#[non_exhaustive]
pub struct XMatrix {
    /// The name of the server that sent (and signed) the request.
    pub origin: OwnedServerName,
    /// The name of the server the request is addressed to.
    ///
    /// `None` when the header has no `destination` parameter, which parsing accepts.
    pub destination: Option<OwnedServerName>,
    /// The ID of the origin server's key used to produce `sig`.
    pub key: OwnedServerSigningKeyId,
    /// The signature of the request object.
    pub sig: Base64,
}
impl XMatrix {
    /// Construct new `X-Matrix` credentials.
    ///
    /// The `destination` is always included (`Some`) in the returned credentials.
    pub fn new(
        origin: OwnedServerName,
        destination: OwnedServerName,
        key: OwnedServerSigningKeyId,
        sig: Base64,
    ) -> Self {
        Self { origin, destination: Some(destination), key, sig }
    }

    /// Parse `X-Matrix` credentials from the value of an `Authorization` header.
    ///
    /// The first challenge whose scheme matches `X-Matrix` (case-insensitively) is used.
    /// Parameter names are matched case-insensitively too, and unknown parameters are ignored
    /// (they are only logged at debug level). `destination` is optional; `origin`, `key` and
    /// `sig` are required.
    ///
    /// # Errors
    ///
    /// * [`XMatrixParseError::ParseStr`] if the challenge parser reports a syntax error before
    ///   an `X-Matrix` challenge has been parsed.
    /// * [`XMatrixParseError::NotFound`] if the header contains no `X-Matrix` challenge.
    /// * [`XMatrixParseError::DuplicateParameter`] if a known parameter appears more than once.
    /// * [`XMatrixParseError::ParseId`] or [`XMatrixParseError::ParseBase64`] if a known
    ///   parameter has an invalid value.
    /// * [`XMatrixParseError::MissingParameter`] if `origin`, `key` or `sig` is absent.
    pub fn parse(s: impl AsRef<str>) -> Result<Self, XMatrixParseError> {
        // Store the parsed value of parameter `name` in `slot`, rejecting a second occurrence.
        // The duplicate check happens before the value is parsed, so a repeated parameter is
        // reported as a duplicate even if its second value is also invalid.
        fn set_once<T>(
            slot: &mut Option<T>,
            name: &str,
            parse_value: impl FnOnce() -> Result<T, XMatrixParseError>,
        ) -> Result<(), XMatrixParseError> {
            if slot.is_some() {
                return Err(XMatrixParseError::DuplicateParameter(name.to_owned()));
            }
            *slot = Some(parse_value()?);
            Ok(())
        }

        let parser = ChallengeParser::new(s.as_ref());
        let mut xmatrix = None;

        for challenge in parser {
            let challenge = challenge?;

            if challenge.scheme.eq_ignore_ascii_case(XMatrix::SCHEME) {
                xmatrix = Some(challenge);
                break;
            }
        }

        let Some(xmatrix) = xmatrix else {
            return Err(XMatrixParseError::NotFound);
        };

        let mut origin = None;
        let mut destination = None;
        let mut key = None;
        let mut sig = None;

        for (name, value) in xmatrix.params {
            if name.eq_ignore_ascii_case("origin") {
                set_once(&mut origin, "origin", || {
                    Ok(OwnedServerName::try_from(value.to_unescaped())?)
                })?;
            } else if name.eq_ignore_ascii_case("destination") {
                set_once(&mut destination, "destination", || {
                    Ok(OwnedServerName::try_from(value.to_unescaped())?)
                })?;
            } else if name.eq_ignore_ascii_case("key") {
                set_once(&mut key, "key", || {
                    Ok(OwnedServerSigningKeyId::try_from(value.to_unescaped())?)
                })?;
            } else if name.eq_ignore_ascii_case("sig") {
                set_once(&mut sig, "sig", || Ok(Base64::parse(value.to_unescaped())?))?;
            } else {
                debug!("Unknown parameter {name} in X-Matrix Authorization header");
            }
        }

        Ok(Self {
            origin: origin
                .ok_or_else(|| XMatrixParseError::MissingParameter("origin".to_owned()))?,
            destination,
            key: key.ok_or_else(|| XMatrixParseError::MissingParameter("key".to_owned()))?,
            sig: sig.ok_or_else(|| XMatrixParseError::MissingParameter("sig".to_owned()))?,
        })
    }

    /// Build the JSON object that is signed for `request`.
    ///
    /// The object contains the `destination`, `method`, `origin` and `uri` (path and query) of
    /// the request. When the body is not empty, it is added as `content`.
    ///
    /// # Errors
    ///
    /// Returns an error if the body is not empty and cannot be deserialized into a canonical
    /// JSON value.
    pub fn request_object<T: AsRef<[u8]>>(
        request: &http::Request<T>,
        origin: &ServerName,
        destination: &ServerName,
    ) -> Result<CanonicalJsonObject, serde_json::Error> {
        let body = request.body().as_ref();
        // Authority-form URIs (e.g. `CONNECT host:port`) have no path and query. This function
        // also runs on incoming, untrusted requests (via `verify_request`), so fall back to "/"
        // instead of panicking.
        let uri = request
            .uri()
            .path_and_query()
            .map_or("/", |path_and_query| path_and_query.as_str());

        let mut request_object = CanonicalJsonObject::from([
            ("destination".to_owned(), destination.as_str().into()),
            ("method".to_owned(), request.method().as_str().into()),
            ("origin".to_owned(), origin.as_str().into()),
            ("uri".to_owned(), uri.into()),
        ]);

        if !body.is_empty() {
            let content = serde_json::from_slice(body)?;
            request_object.insert("content".to_owned(), content);
        }

        Ok(request_object)
    }

    /// Sign `request` from `input.origin` to `input.destination` with `input.key_pair`.
    ///
    /// The object built by [`XMatrix::request_object`] is serialized with `serde_json` and
    /// signed. The returned credentials always include the destination.
    ///
    /// # Errors
    ///
    /// * [`XMatrixFromRequestError::IntoJson`] if the request object cannot be built or
    ///   serialized.
    /// * [`XMatrixFromRequestError::SigningKeyId`] if the key pair's key ID is not a valid
    ///   server signing key ID.
    pub fn try_from_http_request<T: AsRef<[u8]>>(
        request: &http::Request<T>,
        input: ServerSignaturesInput<'_>,
    ) -> Result<Self, XMatrixFromRequestError> {
        let ServerSignaturesInput { origin, destination, key_pair } = input;

        let request_object = Self::request_object(request, &origin, &destination)?;
        let serialized_request_object = serde_json::to_vec(&request_object)?;
        let (key_id, signature) = key_pair.sign(&serialized_request_object).into_parts();

        let key = OwnedServerSigningKeyId::try_from(key_id.as_str())
            .map_err(XMatrixFromRequestError::SigningKeyId)?;
        let sig = Base64::new(signature);

        Ok(Self { origin, destination: Some(destination), key, sig })
    }

    /// Verify that these credentials are a valid signature of `request` sent to `destination`.
    ///
    /// The request object is rebuilt from `request`, `self.origin` and `destination`. The
    /// credentials are attached to it as `signatures`, and it is checked against
    /// `public_key_map` with `ruma_signatures::verify_json`.
    ///
    /// # Errors
    ///
    /// * [`XMatrixVerificationError::DestinationMismatch`] if the credentials name a different
    ///   destination. Credentials without a destination are not rejected by this check.
    /// * [`XMatrixVerificationError::Signature`] if the request object cannot be built, e.g.
    ///   because the body is not valid JSON, or if signature verification fails.
    pub fn verify_request<T: AsRef<[u8]>>(
        &self,
        request: &http::Request<T>,
        destination: &ServerName,
        public_key_map: &PublicKeyMap,
    ) -> Result<(), XMatrixVerificationError> {
        if self
            .destination
            .as_deref()
            .is_some_and(|xmatrix_destination| xmatrix_destination != destination)
        {
            return Err(XMatrixVerificationError::DestinationMismatch);
        }

        let mut request_object = Self::request_object(request, &self.origin, destination)
            .map_err(|error| ruma_signatures::VerificationError::Json(error.into()))?;
        let entity_signature =
            CanonicalJsonObject::from([(self.key.to_string(), self.sig.encode().into())]);
        let signatures =
            CanonicalJsonObject::from([(self.origin.to_string(), entity_signature.into())]);
        request_object.insert("signatures".to_owned(), signatures.into());

        Ok(ruma_signatures::verify_json(public_key_map, &request_object)?)
    }
}
impl fmt::Debug for XMatrix {
    /// Formats the credentials for diagnostics, leaving out the signature.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `sig` is bound to `_` on purpose: it must not show up in debug output.
        let Self { origin, destination, key, sig: _ } = self;

        let mut debug = f.debug_struct("XMatrix");
        debug.field("origin", origin);
        debug.field("destination", destination);
        debug.field("key", key);
        debug.finish_non_exhaustive()
    }
}
impl fmt::Display for XMatrix {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let Self { origin, destination, key, sig } = self;
let origin = quote_ascii_string_if_required(origin.as_str());
let key = quote_ascii_string_if_required(key.as_str());
let sig = sig.encode();
let sig = quote_ascii_string_if_required(&sig);
write!(f, r#"{} "#, Self::SCHEME)?;
if let Some(destination) = destination {
let destination = quote_ascii_string_if_required(destination.as_str());
write!(f, r#"destination={destination},"#)?;
}
write!(f, "key={key},origin={origin},sig={sig}")
}
}
impl FromStr for XMatrix {
type Err = XMatrixParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::parse(s)
}
}
impl TryFrom<&HeaderValue> for XMatrix {
    type Error = XMatrixParseError;

    /// Parses the credentials from a raw `Authorization` header value.
    ///
    /// Fails with [`XMatrixParseError::ToStr`] when `HeaderValue::to_str` rejects the value,
    /// and otherwise behaves like [`XMatrix::parse`].
    fn try_from(value: &HeaderValue) -> Result<Self, Self::Error> {
        let header_str = value.to_str()?;
        Self::parse(header_str)
    }
}
impl From<&XMatrix> for HeaderValue {
fn from(value: &XMatrix) -> Self {
value.to_string().try_into().expect("header format is static")
}
}
impl Credentials for XMatrix {
const SCHEME: &'static str = "X-Matrix";
fn decode(value: &HeaderValue) -> Option<Self> {
value.try_into().ok()
}
fn encode(&self) -> HeaderValue {
self.into()
}
}
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum XMatrixFromRequestError {
#[error("failed to construct request object to sign: {0}")]
IntoJson(#[from] serde_json::Error),
#[error("invalid signing key ID: {0}")]
SigningKeyId(IdParseError),
}
/// Errors that can occur when parsing `X-Matrix` credentials.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum XMatrixParseError {
    /// The header value could not be converted to a string (`HeaderValue::to_str` failed).
    #[error(transparent)]
    ToStr(#[from] http::header::ToStrError),

    /// The challenge parser rejected the header syntax; holds the parser's error message.
    #[error("{0}")]
    ParseStr(String),

    /// The header contains no challenge with the `X-Matrix` scheme.
    #[error("X-Matrix credentials not found")]
    NotFound,

    /// The `origin`, `destination` or `key` parameter is not a valid identifier.
    #[error(transparent)]
    ParseId(#[from] IdParseError),

    /// The `sig` parameter is not valid base64.
    #[error(transparent)]
    ParseBase64(#[from] Base64DecodeError),

    /// A required parameter is absent; holds the lowercase parameter name.
    #[error("missing parameter '{0}'")]
    MissingParameter(String),

    /// A known parameter appears more than once; holds the lowercase parameter name.
    #[error("duplicate parameter '{0}'")]
    DuplicateParameter(String),
}
impl<'a> From<http_auth::parser::Error<'a>> for XMatrixParseError {
fn from(value: http_auth::parser::Error<'a>) -> Self {
Self::ParseStr(value.to_string())
}
}
/// Errors that can occur when extracting [`XMatrix`] credentials from a request.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum XMatrixExtractError {
    /// The request has no `Authorization` header.
    #[error("no Authorization HTTP header found, but this endpoint requires a server signature")]
    MissingAuthorizationHeader,

    /// The `Authorization` header could not be parsed as `X-Matrix` credentials.
    #[error("failed to parse header value: {0}")]
    Parse(#[from] XMatrixParseError),
}
/// Errors that can occur in [`XMatrix::verify_request`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum XMatrixVerificationError {
    /// The credentials name a destination other than the one passed for verification.
    #[error("destination in XMatrix doesn't match the one to verify")]
    DestinationMismatch,

    /// The request object could not be built or the signature did not verify.
    #[error("signature verification failed: {0}")]
    Signature(#[from] ruma_signatures::VerificationError),
}
#[cfg(test)]
mod tests {
    use headers::{HeaderValue, authorization::Credentials};
    use ruma_common::{OwnedServerName, serde::Base64};

    use super::{XMatrix, XMatrixParseError};

    // Header without `destination`, as sent before the parameter existed.
    #[test]
    fn xmatrix_auth_pre_1_3() {
        let header = HeaderValue::from_static(
            "X-Matrix origin=\"origin.hs.example.com\",key=\"ed25519:key1\",sig=\"dGVzdA==\"",
        );
        let origin = "origin.hs.example.com".try_into().unwrap();
        let key = "ed25519:key1".try_into().unwrap();
        let sig = Base64::new(b"test".to_vec());
        let credentials = XMatrix::try_from(&header).unwrap();
        assert_eq!(credentials.origin, origin);
        assert_eq!(credentials.destination, None);
        assert_eq!(credentials.key, key);
        assert_eq!(credentials.sig, sig);

        let credentials = XMatrix { origin, destination: None, key, sig };
        assert_eq!(
            credentials.encode(),
            "X-Matrix key=\"ed25519:key1\",origin=origin.hs.example.com,sig=dGVzdA"
        );
    }

    // Header with `destination`; encoding puts it first.
    #[test]
    fn xmatrix_auth_1_3() {
        let header = HeaderValue::from_static(
            "X-Matrix origin=\"origin.hs.example.com\",destination=\"destination.hs.example.com\",key=\"ed25519:key1\",sig=\"dGVzdA==\"",
        );
        let origin: OwnedServerName = "origin.hs.example.com".try_into().unwrap();
        let destination: OwnedServerName = "destination.hs.example.com".try_into().unwrap();
        let key = "ed25519:key1".try_into().unwrap();
        let sig = Base64::new(b"test".to_vec());
        let credentials = XMatrix::try_from(&header).unwrap();
        assert_eq!(credentials.origin, origin);
        assert_eq!(credentials.destination, Some(destination.clone()));
        assert_eq!(credentials.key, key);
        assert_eq!(credentials.sig, sig);

        let credentials = XMatrix::new(origin, destination, key, sig);
        assert_eq!(
            credentials.encode(),
            "X-Matrix destination=destination.hs.example.com,key=\"ed25519:key1\",origin=origin.hs.example.com,sig=dGVzdA"
        );
    }

    // Quoted values with escapes are unescaped when parsing and re-escaped when encoding.
    #[test]
    fn xmatrix_quoting() {
        let header = HeaderValue::from_static(
            r#"X-Matrix origin="example.com:1234",key="abc\"def\\:ghi",sig=dGVzdA,"#,
        );
        let origin: OwnedServerName = "example.com:1234".try_into().unwrap();
        let key = r#"abc"def\:ghi"#.try_into().unwrap();
        let sig = Base64::new(b"test".to_vec());
        let credentials = XMatrix::try_from(&header).unwrap();
        assert_eq!(credentials.origin, origin);
        assert_eq!(credentials.destination, None);
        assert_eq!(credentials.key, key);
        assert_eq!(credentials.sig, sig);

        let credentials = XMatrix { origin, destination: None, key, sig };
        assert_eq!(
            credentials.encode(),
            r#"X-Matrix key="abc\"def\\:ghi",origin="example.com:1234",sig=dGVzdA"#
        );
    }

    // Whitespace around parameter separators is tolerated.
    #[test]
    fn xmatrix_auth_1_3_with_extra_spaces() {
        let header = HeaderValue::from_static(
            "X-Matrix origin=\"origin.hs.example.com\" , destination=\"destination.hs.example.com\",key=\"ed25519:key1\", sig=\"dGVzdA\"",
        );
        let credentials = XMatrix::try_from(&header).unwrap();
        let sig = Base64::new(b"test".to_vec());

        assert_eq!(credentials.origin, "origin.hs.example.com");
        assert_eq!(credentials.destination.unwrap(), "destination.hs.example.com");
        assert_eq!(credentials.key, "ed25519:key1");
        assert_eq!(credentials.sig, sig);
    }

    // A header whose only challenge uses another scheme has no X-Matrix credentials.
    #[test]
    fn xmatrix_not_found() {
        let header = HeaderValue::from_static("Basic realm=\"example\"");
        let error = XMatrix::try_from(&header).unwrap_err();
        assert!(matches!(error, XMatrixParseError::NotFound), "unexpected error: {error:?}");
    }

    // `sig` is required.
    #[test]
    fn xmatrix_missing_parameter() {
        let header = HeaderValue::from_static(
            "X-Matrix origin=\"origin.hs.example.com\",key=\"ed25519:key1\"",
        );
        let error = XMatrix::try_from(&header).unwrap_err();
        assert!(
            matches!(error, XMatrixParseError::MissingParameter(ref name) if name == "sig"),
            "unexpected error: {error:?}"
        );
    }

    // A repeated known parameter is rejected rather than silently overwritten.
    #[test]
    fn xmatrix_duplicate_parameter() {
        let header = HeaderValue::from_static(
            "X-Matrix origin=\"origin.hs.example.com\",key=\"ed25519:key1\",key=\"ed25519:key2\",sig=\"dGVzdA\"",
        );
        let error = XMatrix::try_from(&header).unwrap_err();
        assert!(
            matches!(error, XMatrixParseError::DuplicateParameter(ref name) if name == "key"),
            "unexpected error: {error:?}"
        );
    }
}