use std::str::FromStr;
use std::task::{Context as StdContext, Poll};
use futures::FutureExt;
use futures::future::BoxFuture;
use http::header::{ACCEPT, ToStrError};
use mediatype::{Name, ReadParams};
use miden_node_utils::{ErrorReport, FlattenResult};
use miden_protocol::{Word, WordError};
use semver::{Comparator, Version, VersionReq};
use tower::{Layer, Service};
/// Controls how the `genesis` media-type parameter is treated during negotiation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GenesisNegotiation {
    /// A `genesis` parameter, if present, must match; it may also be omitted.
    Optional,
    /// A matching `genesis` parameter must be present on the accepted media range.
    Mandatory,
}
/// Tower layer that checks the `Accept` header of incoming requests against the supported
/// RPC version range and the expected genesis commitment.
#[derive(Clone)]
pub struct AcceptHeaderLayer {
    /// Versions accepted in a media range's `version` parameter.
    supported_versions: VersionReq,
    /// Value a client-supplied `genesis` parameter must equal.
    genesis_commitment: Word,
    /// Method names (final segment of the request path) whose requests must carry a
    /// matching `genesis` parameter.
    require_genesis_methods: Vec<&'static str>,
}
#[derive(Debug, thiserror::Error)]
enum AcceptHeaderError {
#[error("header value could not be parsed as a UTF8 string")]
InvalidUtf8(#[source] ToStrError),
#[error("accept header's media type could not be parsed")]
InvalidMediaType(#[source] mediatype::MediaTypeError),
#[error("a Q value was invalid")]
InvalidQValue(#[source] QParsingError),
#[error("version value failed to parse")]
InvalidVersion(#[source] semver::Error),
#[error("genesis value failed to parse")]
InvalidGenesis(#[source] WordError),
#[error("server does not support any of the specified application/vnd.miden content types")]
NoSupportedMediaRange,
}
impl AcceptHeaderLayer {
    /// Creates a layer accepting any patch release of `rpc_version`'s `major.minor` line and
    /// checking client-supplied `genesis` parameters against `genesis_commitment`.
    ///
    /// No method requires a `genesis` parameter until registered through
    /// [`Self::with_genesis_enforced_method`].
    pub fn new(rpc_version: &Version, genesis_commitment: Word) -> Self {
        // `=MAJOR.MINOR` (no patch component) matches every patch of that minor release.
        let same_minor_release = Comparator {
            op: semver::Op::Exact,
            major: rpc_version.major,
            minor: Some(rpc_version.minor),
            patch: None,
            pre: semver::Prerelease::default(),
        };

        Self {
            supported_versions: VersionReq { comparators: vec![same_minor_release] },
            genesis_commitment,
            require_genesis_methods: Vec::new(),
        }
    }

    /// Requires a matching `genesis` parameter for requests whose path ends in `method`.
    pub fn with_genesis_enforced_method(mut self, method: &'static str) -> Self {
        let methods = &mut self.require_genesis_methods;
        methods.push(method);
        self
    }
}
impl<S> Layer<S> for AcceptHeaderLayer {
    type Service = AcceptHeaderService<S>;

    /// Wraps `inner` with a service holding its own copy of this layer's configuration.
    fn layer(&self, inner: S) -> Self::Service {
        let verifier = self.clone();
        AcceptHeaderService { inner, verifier }
    }
}
impl AcceptHeaderLayer {
    const VERSION: Name<'static> = Name::new_unchecked("version");
    const GENESIS: Name<'static> = Name::new_unchecked("genesis");
    const GRPC: Name<'static> = Name::new_unchecked("grpc");

    /// Checks an `Accept` header value against this server's requirements.
    ///
    /// Media ranges are examined in order and negotiation succeeds at the first acceptable one.
    /// An empty list is accepted only in [`GenesisNegotiation::Optional`] mode.
    ///
    /// # Errors
    ///
    /// - Returns a parsing error if an entry cannot be parsed as a media range.
    /// - Returns a parsing error if a candidate range has a malformed `q`, `version` or `genesis`
    ///   parameter. This is only checked for entries examined before an acceptable one is found.
    /// - Returns [`AcceptHeaderError::NoSupportedMediaRange`] if no entry is acceptable.
    fn negotiate(
        &self,
        accept: &str,
        genesis_mode: GenesisNegotiation,
    ) -> Result<(), AcceptHeaderError> {
        let mut media_ranges = mediatype::MediaTypeList::new(accept).peekable();

        if media_ranges.peek().is_none() {
            return match genesis_mode {
                GenesisNegotiation::Optional => Ok(()),
                GenesisNegotiation::Mandatory => Err(AcceptHeaderError::NoSupportedMediaRange),
            };
        }

        for parsed in media_ranges {
            let media_range = parsed.map_err(AcceptHeaderError::InvalidMediaType)?;
            if self.accepts_media_range(&media_range, genesis_mode)? {
                return Ok(());
            }
        }

        Err(AcceptHeaderError::NoSupportedMediaRange)
    }

    /// Decides whether a single media range can be served.
    ///
    /// Checks, in this order:
    /// 1. The type is `*/*`, `*/vnd.miden` or `application/vnd.miden`.
    /// 2. The range has no suffix, or the `+grpc` suffix.
    /// 3. The `q` weight (default 1) is non-zero.
    /// 4. Any `version` matches `supported_versions`.
    /// 5. Any `genesis` equals `genesis_commitment`. In mandatory mode, `genesis` must be present.
    ///
    /// Returns `Ok(false)` when a check fails. Returns an error when a parameter that is reached
    /// by these checks cannot be parsed.
    fn accepts_media_range(
        &self,
        media_range: &mediatype::MediaType<'_>,
        genesis_mode: GenesisNegotiation,
    ) -> Result<bool, AcceptHeaderError> {
        let is_candidate_type = matches!(
            (media_range.ty.as_str(), media_range.subty.as_str()),
            ("*", "*") | ("*" | "application", "vnd.miden")
        );
        if !is_candidate_type {
            return Ok(false);
        }

        if let Some(suffix) = &media_range.suffix
            && *suffix != Self::GRPC
        {
            return Ok(false);
        }

        let weight = match media_range.get_param(mediatype::names::Q) {
            Some(raw) => QValue::from_str(raw.unquoted_str().as_ref())
                .map_err(AcceptHeaderError::InvalidQValue)?,
            None => QValue::default(),
        };
        if weight.is_zero() {
            return Ok(false);
        }

        if let Some(raw) = media_range.get_param(Self::VERSION) {
            let version = Version::parse(raw.unquoted_str().as_ref())
                .map_err(AcceptHeaderError::InvalidVersion)?;
            if !self.supported_versions.matches(&version) {
                return Ok(false);
            }
        }

        let genesis = match media_range.get_param(Self::GENESIS) {
            Some(raw) => Some(
                Word::try_from(raw.unquoted_str().as_ref())
                    .map_err(AcceptHeaderError::InvalidGenesis)?,
            ),
            None => None,
        };

        Ok(match genesis {
            Some(commitment) => commitment == self.genesis_commitment,
            None => matches!(genesis_mode, GenesisNegotiation::Optional),
        })
    }
}
/// Service produced by [`AcceptHeaderLayer`].
///
/// It rejects requests whose `Accept` header fails negotiation and forwards the rest to
/// `inner`.
#[derive(Clone)]
pub struct AcceptHeaderService<S> {
    /// Wrapped service that receives requests passing negotiation.
    inner: S,
    /// Negotiation configuration copied from the layer.
    verifier: AcceptHeaderLayer,
}
impl<S, B> Service<http::Request<B>> for AcceptHeaderService<S>
where
    S: Service<http::Request<B>, Response = http::Response<B>> + Clone + Send + 'static,
    S::Error: Send + 'static,
    S::Future: Send + 'static,
    B: Default + Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is fully delegated to the inner service.
    fn poll_ready(&mut self, cx: &mut StdContext<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Validates the request's `Accept` header before forwarding it to the inner service.
    ///
    /// - `OPTIONS` requests are forwarded untouched.
    /// - Requests without an `Accept` header are forwarded. The exception is a method (the final
    ///   path segment) registered as requiring `genesis`, which is rejected.
    /// - Otherwise, every `Accept` field line is combined into one list and negotiated.
    ///
    /// A rejection is answered directly with an `INVALID_ARGUMENT` gRPC status, without calling
    /// the inner service.
    fn call(&mut self, request: http::Request<B>) -> Self::Future {
        if request.method() == http::Method::OPTIONS {
            return self.inner.call(request).boxed();
        }

        let path = request.uri().path();
        let method_name = path.rsplit('/').next().unwrap_or_default();
        let requires_genesis = self.verifier.require_genesis_methods.contains(&method_name);

        if !request.headers().contains_key(ACCEPT) {
            if requires_genesis {
                let response = tonic::Status::invalid_argument(
                    "Accept header with 'genesis' parameter is required for write RPC methods",
                )
                .into_http();
                return futures::future::ready(Ok(response)).boxed();
            }
            return self.inner.call(request).boxed();
        }

        let mode = if requires_genesis {
            GenesisNegotiation::Mandatory
        } else {
            GenesisNegotiation::Optional
        };

        // `Accept` is a list-valued field: a client may spread its media ranges over several
        // field lines. RFC 9110 §5.3 defines that as equivalent to one comma-separated line, so
        // all of them are combined. Previously only the first line was considered.
        let result = request
            .headers()
            .get_all(ACCEPT)
            .iter()
            .map(|value| value.to_str())
            .collect::<Result<Vec<_>, _>>()
            .map_err(AcceptHeaderError::InvalidUtf8)
            .map(|values| self.verifier.negotiate(&values.join(", "), mode))
            .flatten_result();

        match result {
            Ok(()) => self.inner.call(request).boxed(),
            Err(err) => {
                let response = tonic::Status::invalid_argument(err.as_report()).into_http();
                futures::future::ready(Ok(response)).boxed()
            },
        }
    }
}
#[derive(Debug, PartialEq, thiserror::Error)]
enum QParsingError {
#[error("Q value contained too many decimal digits")]
TooManyDigits,
#[error("invalid format")]
BadFormat,
#[error("invalid decimal digits")]
InvalidDecimalDigits,
}
/// A media-range quality weight, stored as a whole number of thousandths.
#[derive(PartialEq, Debug)]
struct QValue {
    /// The weight scaled by 1000: `q=0.5` is stored as `500` and `q=1` as `1000`.
    kilo: u16,
}
impl Default for QValue {
    /// The weight used when a media range has no `q` parameter: full weight, `q=1`.
    fn default() -> Self {
        let full_weight = 1_000;
        Self { kilo: full_weight }
    }
}
impl QValue {
    /// Builds a weight directly from thousandths (test-only helper).
    #[cfg(test)]
    const fn new(kilo: u16) -> Self {
        Self { kilo }
    }

    /// Returns `true` for `q=0`. Negotiation skips media ranges with this weight.
    fn is_zero(&self) -> bool {
        matches!(self, Self { kilo: 0 })
    }
}
impl FromStr for QValue {
    type Err = QParsingError;

    /// Parses a qvalue as defined by RFC 9110:
    /// `( "0" [ "." 0*3DIGIT ] ) / ( "1" [ "." 0*3("0") ] )`.
    ///
    /// # Errors
    ///
    /// - [`QParsingError::TooManyDigits`] if more than three characters follow the period.
    /// - [`QParsingError::InvalidDecimalDigits`] if the fractional part of a `0.` value contains
    ///   anything other than ASCII digits. Signs such as `+` are rejected here;
    ///   `u16::from_str` would have accepted them.
    /// - [`QParsingError::BadFormat`] for every other shape, such as a leading period, a missing
    ///   period, or a value above 1.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const MAX_DECIMALS: usize = 3;

        let kilo = match s.as_bytes() {
            [b'0'] => 0,
            [b'1'] => 1_000,
            [b'0', b'.', decimals @ ..] => {
                if decimals.len() > MAX_DECIMALS {
                    return Err(QParsingError::TooManyDigits);
                }
                if !decimals.iter().all(u8::is_ascii_digit) {
                    return Err(QParsingError::InvalidDecimalDigits);
                }
                // Right-pad with zeros to three digits, so "0.1" becomes 100 thousandths.
                decimals
                    .iter()
                    .copied()
                    .chain(std::iter::repeat(b'0'))
                    .take(MAX_DECIMALS)
                    .fold(0_u16, |acc, digit| acc * 10 + u16::from(digit - b'0'))
            },
            [b'1', b'.', decimals @ ..] => {
                // Anything other than zeros after "1." would exceed the maximum weight of 1.
                if !decimals.iter().all(|&digit| digit == b'0') {
                    return Err(QParsingError::BadFormat);
                }
                // The grammar allows at most three zeros, mirroring the "0." branch.
                if decimals.len() > MAX_DECIMALS {
                    return Err(QParsingError::TooManyDigits);
                }
                1_000
            },
            _ => return Err(QParsingError::BadFormat),
        };

        Ok(Self { kilo })
    }
}
#[cfg(test)]
mod tests {
    use miden_protocol::Word;
    use semver::Version;

    // Imported relative to the parent module so the tests keep working if the module moves.
    use super::{AcceptHeaderLayer, GenesisNegotiation, QParsingError, QValue};

    const TEST_GENESIS_COMMITMENT: &str =
        "0x00000000000000000000000000000000000000000000000000000000deadbeef";
    const TEST_RPC_VERSION: Version = Version::new(0, 2, 3);

    impl AcceptHeaderLayer {
        /// Layer configured with [`TEST_RPC_VERSION`] and [`TEST_GENESIS_COMMITMENT`].
        fn for_tests() -> Self {
            Self::new(&TEST_RPC_VERSION, Word::try_from(TEST_GENESIS_COMMITMENT).unwrap())
        }
    }

    #[rstest::rstest]
    #[case::empty("")]
    #[case::wildcard("*/*")]
    #[case::media_type_only("application/vnd.miden")]
    #[case::with_grpc_suffix("application/vnd.miden+grpc")]
    #[case::with_quality("application/vnd.miden; q=0.3")]
    #[case::version_exact("application/vnd.miden; version=0.2.3")]
    #[case::version_patch_bump("application/vnd.miden; version=0.2.4")]
    #[case::version_patch_down("application/vnd.miden; version=0.2.2")]
    #[case::matching_network(
        "application/vnd.miden; genesis=0x00000000000000000000000000000000000000000000000000000000deadbeef"
    )]
    #[case::matching_network_and_version(
        "application/vnd.miden; genesis=0x00000000000000000000000000000000000000000000000000000000deadbeef; version=0.2.3"
    )]
    #[case::parameter_order_swapped(
        "application/vnd.miden; version=0.2.3; genesis=0x00000000000000000000000000000000000000000000000000000000deadbeef;"
    )]
    #[case::trailing_semicolon("application/vnd.miden; ")]
    #[case::trailing_comma("application/vnd.miden, ")]
    #[case::multiple_types("application/vnd.miden; version=2.0.0, application/vnd.miden")]
    #[case::unrelated_type_first("text/html, application/vnd.miden")]
    #[case::zero_weight_then_wildcard("application/vnd.miden; q=0, */*")]
    #[case::quoted_quality(r#"application/vnd.miden; q="1""#)]
    #[case::quoted_version(r#"application/vnd.miden; version="0.2.3""#)]
    #[case::quoted_network(r#"application/vnd.miden; genesis="0x00000000000000000000000000000000000000000000000000000000deadbeef""#)]
    #[test]
    fn request_should_pass(#[case] accept: &'static str) {
        AcceptHeaderLayer::for_tests()
            .negotiate(accept, GenesisNegotiation::Optional)
            .unwrap();
    }

    #[rstest::rstest]
    #[case::obsolete_format("application/vnd.miden+grpc.0.2.3")]
    #[case::with_non_grpc_suffix("application/vnd.miden+not")]
    #[case::invalid_version("application/vnd.miden; version=0x123")]
    #[case::invalid_genesis("application/vnd.miden; genesis=aaa")]
    #[case::mismatched_network(
        "application/vnd.miden; genesis=0x00000000000000000000000000000000000000000000000000000000deadbeee"
    )]
    #[case::version_too_old("application/vnd.miden; version=0.1.0")]
    #[case::version_too_new("application/vnd.miden; version=0.3.0")]
    #[case::zero_weighting("application/vnd.miden; q=0.0")]
    #[case::wildcard_subtype("application/*")]
    #[case::unrelated_type_only("application/json")]
    #[test]
    fn request_should_be_rejected(#[case] accept: &'static str) {
        AcceptHeaderLayer::for_tests()
            .negotiate(accept, GenesisNegotiation::Optional)
            .unwrap_err();
    }

    #[test]
    fn write_requires_genesis_param_missing_or_empty_or_mismatch() {
        let layer = AcceptHeaderLayer::for_tests();
        assert!(
            layer
                .negotiate("application/vnd.miden", GenesisNegotiation::Mandatory)
                .is_err()
        );
        assert!(layer.negotiate("", GenesisNegotiation::Mandatory).is_err());
        let mismatched = "application/vnd.miden; genesis=0x00000000000000000000000000000000000000000000000000000000deadbeee";
        assert!(layer.negotiate(mismatched, GenesisNegotiation::Mandatory).is_err());
    }

    #[rstest::rstest]
    #[case::matching_network(
        "application/vnd.miden; genesis=0x00000000000000000000000000000000000000000000000000000000deadbeef"
    )]
    #[case::matching_network_and_version(
        "application/vnd.miden; genesis=0x00000000000000000000000000000000000000000000000000000000deadbeef; version=0.2.3"
    )]
    #[case::unrelated_type_first(
        "text/html, application/vnd.miden; genesis=0x00000000000000000000000000000000000000000000000000000000deadbeef"
    )]
    #[test]
    fn request_with_mandatory_genesis_should_pass(#[case] accept: &'static str) {
        AcceptHeaderLayer::for_tests()
            .negotiate(accept, GenesisNegotiation::Mandatory)
            .unwrap();
    }

    #[rstest::rstest]
    #[case::missing_network("application/vnd.miden;")]
    #[case::missing_network_wildcard("*/*")]
    #[test]
    fn request_with_mandatory_genesis_should_be_rejected(#[case] accept: &'static str) {
        AcceptHeaderLayer::for_tests()
            .negotiate(accept, GenesisNegotiation::Mandatory)
            .unwrap_err();
    }

    #[rstest::rstest]
    #[case::one("1", Ok(QValue::new(1_000)))]
    #[case::one_period("1.", Ok(QValue::new(1_000)))]
    #[case::one_single_zero("1.0", Ok(QValue::new(1_000)))]
    #[case::one_full("1.000", Ok(QValue::new(1_000)))]
    #[case::zero("0", Ok(QValue::new(0)))]
    #[case::zero_period("0.", Ok(QValue::new(0)))]
    #[case::zeros("0.000", Ok(QValue::new(0)))]
    #[case::first_decimal("0.1", Ok(QValue::new(100)))]
    #[case::second_decimal("0.01", Ok(QValue::new(10)))]
    #[case::third_decimal("0.001", Ok(QValue::new(1)))]
    #[case::half("0.5", Ok(QValue::new(500)))]
    #[case::digits_123("0.123", Ok(QValue::new(123)))]
    #[case::digits_456("0.456", Ok(QValue::new(456)))]
    #[case::digits_789("0.789", Ok(QValue::new(789)))]
    #[case::too_many_digits("0.1234", Err(QParsingError::TooManyDigits))]
    #[case::invalid_digit("0.a", Err(QParsingError::InvalidDecimalDigits))]
    #[case::extra_period("0..0", Err(QParsingError::InvalidDecimalDigits))]
    #[case::leading_period(".0", Err(QParsingError::BadFormat))]
    #[case::missing_period("0123", Err(QParsingError::BadFormat))]
    #[case::barely_too_large("1.001", Err(QParsingError::BadFormat))]
    #[case::too_large_by_far("2.0", Err(QParsingError::BadFormat))]
    #[test]
    fn qvalue_parsing(#[case] s: &'static str, #[case] expected: Result<QValue, QParsingError>) {
        use std::str::FromStr;
        assert_eq!(QValue::from_str(s), expected);
    }

    #[test]
    fn qvalue_default_is_one() {
        assert_eq!(QValue::default(), QValue::new(1_000));
    }
}