use std::{collections::HashSet, str::FromStr};
use urlpattern::UrlPatternOptions;
pub use urlpattern::{UrlPatternInit, UrlPatternMatchInput, UrlPatternResult};
use {std::env, thiserror::Error, urlpattern::UrlPattern};
/// A set-based selection rule: either an allow-list or a deny-list of values.
///
/// When deserialized, the variant names are renamed to `snake_case`
/// (`include` / `exclude`).
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum IncludeExclude<T>
where
T: std::hash::Hash + Eq,
{
/// Allow-list: only the values contained in the set are selected.
Include(HashSet<T>),
/// Deny-list: every value except those contained in the set is selected.
Exclude(HashSet<T>),
}
impl<T> IncludeExclude<T>
where
    T: std::hash::Hash + Eq,
{
    /// Builds a rule that selects everything, expressed as an empty deny-list.
    #[must_use]
    #[allow(dead_code)]
    pub fn include_all() -> Self {
        let nothing_excluded = HashSet::default();
        Self::Exclude(nothing_excluded)
    }

    /// Builds a rule that selects nothing, expressed as an empty allow-list.
    #[must_use]
    pub fn include_none() -> Self {
        let nothing_included = HashSet::default();
        Self::Include(nothing_included)
    }
}
/// Something that can pick headers out of an HTTP header map and turn them
/// into string attributes.
pub trait HeaderAttributesFilter {
/// Returns the selected headers from `headers` as owned `(name, value)` pairs.
fn get_header_attributes(&self, headers: &http::HeaderMap) -> Vec<(String, String)>;
}
impl HeaderAttributesFilter for IncludeExclude<String> {
fn get_header_attributes(&self, headers: &http::HeaderMap) -> Vec<(String, String)> {
match self {
Self::Include(set) => {
let mut header_attributes = Vec::new();
for header_name in set {
if let Some(header_value) = headers
.get(header_name)
.and_then(|header_value| header_value.to_str().ok())
{
header_attributes.push((header_name.clone(), header_value.to_string()));
}
}
header_attributes
}
Self::Exclude(set) => {
let mut header_attributes = Vec::new();
for (header_name, header_value) in headers {
if !set.contains(header_name.as_str()) {
if let Ok(header_value) = header_value.to_str() {
header_attributes
.push((header_name.to_string(), header_value.to_string()));
}
}
}
header_attributes
}
}
}
}
/// Name of the environment variable that switches tracing on.
/// `TracingConfiguration::from_env` returns `Ok(None)` unless it holds one of
/// `true`, `t`, `1`, `yes`, `y` or `on` (compared after lowercasing).
pub static QCS_API_TRACING_ENABLED: &str = "QCS_API_TRACING_ENABLED";
/// Name of the environment variable that `TracingConfiguration::from_env`
/// reads to populate `propagate_otel_context`.
pub static QCS_API_PROPAGATE_OTEL_CONTEXT: &str = "QCS_API_PROPAGATE_OTEL_CONTEXT";
/// Name of the environment variable holding a comma-separated list of URL
/// patterns, read by `TracingFilter::from_env`.
pub static QCS_API_TRACING_FILTER: &str = "QCS_API_TRACING_FILTER";
/// Name of the environment variable that, when it holds one of the accepted
/// "true" spellings, makes `TracingFilter::from_env` build a negated filter.
pub static QCS_API_NEGATE_TRACING_FILTER: &str = "QCS_API_NEGATE_TRACING_FILTER";
/// Errors raised while turning a filter string into URL patterns.
#[derive(Error, Debug)]
pub enum TracingFilterError {
/// The string (or the origin derived from it) could not be parsed as a URL.
#[error("invalid url `{url}`: {error}")]
InvalidUrl {
/// The text that failed to parse.
url: String,
/// The underlying parser error.
error: url::ParseError,
},
/// The URL parsed, but its scheme is not one of `http`, `https` or `tcp`.
/// The payload is the offending scheme.
#[error("trace filtering only supports https, http, and tcp urls, found: `{0}`")]
UnsupportedUrlScheme(String),
}
/// Builder for [`TracingConfiguration`]; it carries the same four fields and
/// hands them over unchanged in `build`.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone)]
pub struct TracingConfigurationBuilder {
// Optional URL filter; `None` means no filter is configured.
filter: Option<TracingFilter>,
// Value returned later by `TracingConfiguration::propagate_otel_context`.
propagate_otel_context: bool,
// Selection rule applied to request headers.
request_headers: IncludeExclude<String>,
// Selection rule applied to response headers.
response_headers: IncludeExclude<String>,
}
impl Default for TracingConfigurationBuilder {
    /// Starts from the most restrictive settings: no filter, no context
    /// propagation, and no request or response headers selected.
    fn default() -> Self {
        let filter = None;
        let propagate_otel_context = false;
        let request_headers = IncludeExclude::include_none();
        let response_headers = IncludeExclude::include_none();
        Self {
            filter,
            propagate_otel_context,
            request_headers,
            response_headers,
        }
    }
}
impl From<TracingConfiguration> for TracingConfigurationBuilder {
fn from(tracing_configuration: TracingConfiguration) -> Self {
Self {
filter: tracing_configuration.filter,
propagate_otel_context: tracing_configuration.propagate_otel_context,
request_headers: tracing_configuration.request_headers,
response_headers: tracing_configuration.response_headers,
}
}
}
impl TracingConfigurationBuilder {
    #![allow(clippy::missing_const_for_fn)]

    /// Replaces the URL filter; passing `None` clears it.
    #[must_use]
    pub fn set_filter(self, filter: Option<TracingFilter>) -> Self {
        Self { filter, ..self }
    }

    /// Sets the flag later exposed by `TracingConfiguration::propagate_otel_context`.
    #[must_use]
    pub fn set_propagate_otel_context(self, propagate_otel_context: bool) -> Self {
        Self {
            propagate_otel_context,
            ..self
        }
    }

    /// Replaces the selection rule for request headers.
    #[must_use]
    pub fn set_request_headers(self, request_headers: IncludeExclude<String>) -> Self {
        Self {
            request_headers,
            ..self
        }
    }

    /// Replaces the selection rule for response headers.
    #[must_use]
    pub fn set_response_headers(self, response_headers: IncludeExclude<String>) -> Self {
        Self {
            response_headers,
            ..self
        }
    }

    /// Consumes the builder and produces the finished configuration.
    #[must_use]
    pub fn build(self) -> TracingConfiguration {
        let Self {
            filter,
            propagate_otel_context,
            request_headers,
            response_headers,
        } = self;
        TracingConfiguration {
            filter,
            propagate_otel_context,
            request_headers,
            response_headers,
        }
    }
}
/// Tracing settings: an optional URL filter, a context-propagation flag, and
/// the header selection rules for requests and responses.
#[derive(Debug, Clone)]
pub struct TracingConfiguration {
// Optional URL filter; when `None`, `is_enabled` returns `true` for every URL.
filter: Option<TracingFilter>,
// Exposed through `propagate_otel_context()`.
propagate_otel_context: bool,
// Selection rule applied to request headers.
request_headers: IncludeExclude<String>,
// Selection rule applied to response headers.
response_headers: IncludeExclude<String>,
}
impl Default for TracingConfiguration {
fn default() -> Self {
Self {
filter: None,
propagate_otel_context: false,
request_headers: IncludeExclude::Include(
[KEY_X_REQUEST_ID, KEY_X_REQUEST_RETRY_INDEX]
.iter()
.map(ToString::to_string)
.collect(),
),
response_headers: IncludeExclude::Include(
std::iter::once(KEY_X_REQUEST_ID.to_string()).collect(),
),
}
}
}
// Header name included by default for both requests and responses.
const KEY_X_REQUEST_ID: &str = "x-request-id";
// Header name included by default for requests only.
const KEY_X_REQUEST_RETRY_INDEX: &str = "x-request-retry-index";
impl TracingConfiguration {
    #![allow(clippy::missing_const_for_fn)]

    /// Returns a builder initialised with `TracingConfigurationBuilder::default()`.
    #[must_use]
    pub fn builder() -> TracingConfigurationBuilder {
        TracingConfigurationBuilder::default()
    }

    /// Builds a configuration from environment variables.
    ///
    /// Returns `Ok(None)` when `QCS_API_TRACING_ENABLED` is not set to an
    /// accepted "true" spelling. Otherwise the filter comes from
    /// `TracingFilter::from_env`, the propagation flag from
    /// `QCS_API_PROPAGATE_OTEL_CONTEXT`, and the header rules from
    /// `Self::default()`.
    ///
    /// # Errors
    ///
    /// Propagates any [`TracingFilterError`] from `TracingFilter::from_env`.
    pub fn from_env() -> Result<Option<Self>, TracingFilterError> {
        if is_env_var_true(QCS_API_TRACING_ENABLED) {
            let filter = TracingFilter::from_env()?;
            let propagate_otel_context = is_env_var_true(QCS_API_PROPAGATE_OTEL_CONTEXT);
            let defaults = Self::default();
            Ok(Some(Self {
                filter,
                propagate_otel_context,
                request_headers: defaults.request_headers,
                response_headers: defaults.response_headers,
            }))
        } else {
            Ok(None)
        }
    }

    /// The configured URL filter, if any.
    #[must_use]
    pub fn filter(&self) -> Option<&TracingFilter> {
        match &self.filter {
            Some(filter) => Some(filter),
            None => None,
        }
    }

    /// The stored context-propagation flag.
    #[must_use]
    pub fn propagate_otel_context(&self) -> bool {
        self.propagate_otel_context
    }

    /// The selection rule for request headers.
    #[must_use]
    pub fn request_headers(&self) -> &IncludeExclude<String> {
        &self.request_headers
    }

    /// The selection rule for response headers.
    #[must_use]
    pub fn response_headers(&self) -> &IncludeExclude<String> {
        &self.response_headers
    }

    /// Whether tracing applies to `url`: always `true` without a filter,
    /// otherwise whatever the filter decides.
    #[must_use]
    pub fn is_enabled(&self, url: &UrlPatternMatchInput) -> bool {
        match &self.filter {
            None => true,
            Some(filter) => filter.is_enabled(url),
        }
    }
}
impl From<TracingFilter> for TracingFilterBuilder {
fn from(tracing_filter: TracingFilter) -> Self {
Self {
paths: tracing_filter.paths,
is_negated: tracing_filter.is_negated,
}
}
}
/// Builder for [`TracingFilter`]. The derived default is a non-negated
/// filter with no patterns.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Default, Clone)]
pub struct TracingFilterBuilder {
// When true, the built filter is enabled for inputs that match no pattern.
is_negated: bool,
// URL pattern initialisers the built filter tries in order.
paths: Vec<UrlPatternInit>,
}
impl TracingFilterBuilder {
    #![allow(clippy::missing_const_for_fn)]

    /// Sets whether the built filter inverts its match result.
    #[must_use]
    pub fn set_is_negated(self, is_negated: bool) -> Self {
        Self { is_negated, ..self }
    }

    /// Replaces the list of URL pattern initialisers.
    #[must_use]
    pub fn set_paths(self, paths: Vec<UrlPatternInit>) -> Self {
        Self { paths, ..self }
    }

    /// Parses each string in `paths` with `parse_constructor_string` and
    /// stores the results, replacing any previously set patterns.
    ///
    /// # Errors
    ///
    /// Returns the error for the first string that fails to parse; later
    /// strings are not examined.
    pub fn parse_strs_and_set_paths(self, paths: &[&str]) -> Result<Self, TracingFilterError> {
        let mut parsed = Vec::with_capacity(paths.len());
        for raw in paths {
            parsed.push(parse_constructor_string(raw)?);
        }
        Ok(self.set_paths(parsed))
    }

    /// Consumes the builder and produces the finished filter.
    #[must_use]
    pub fn build(self) -> TracingFilter {
        let Self { is_negated, paths } = self;
        TracingFilter { is_negated, paths }
    }
}
/// A list of URL patterns plus a negation flag, used to decide whether
/// tracing is enabled for a given URL.
#[derive(Debug, Clone, Default)]
pub struct TracingFilter {
// When false, `is_enabled` is true iff some pattern matches; when true, iff none does.
is_negated: bool,
// Pattern initialisers; each is compiled into a `UrlPattern` at match time.
paths: Vec<UrlPatternInit>,
}
impl TracingFilter {
    /// Returns a builder initialised with `TracingFilterBuilder::default()`.
    #[must_use]
    #[allow(clippy::missing_const_for_fn)]
    pub fn builder() -> TracingFilterBuilder {
        TracingFilterBuilder::default()
    }

    /// Builds a filter from environment variables.
    ///
    /// Returns `Ok(None)` when `QCS_API_TRACING_FILTER` cannot be read.
    /// Otherwise its value is split on `,` and each piece is parsed as a
    /// pattern; the filter is negated when `QCS_API_NEGATE_TRACING_FILTER`
    /// holds an accepted "true" spelling.
    ///
    /// # Errors
    ///
    /// Returns the [`TracingFilterError`] of the first piece that fails to parse.
    pub fn from_env() -> Result<Option<Self>, TracingFilterError> {
        match env::var(QCS_API_TRACING_FILTER) {
            Ok(filter) => {
                // An unset or unreadable variable already reads as "not true",
                // so no separate presence check is required.
                let is_negated = is_env_var_true(QCS_API_NEGATE_TRACING_FILTER);
                let pieces: Vec<&str> = filter.split(',').collect();
                let built = Self::builder()
                    .parse_strs_and_set_paths(&pieces)?
                    .set_is_negated(is_negated)
                    .build();
                Ok(Some(built))
            }
            Err(_) => Ok(None),
        }
    }

    /// Tries each stored pattern in order and returns the first successful
    /// match. A pattern that fails to compile or execute is skipped (and
    /// logged when the `tracing` feature is on).
    fn first_match(&self, input: &UrlPatternMatchInput) -> Option<UrlPatternResult> {
        for init in &self.paths {
            let options = UrlPatternOptions { ignore_case: false };
            let outcome = <UrlPattern>::parse(init.clone(), options)
                .and_then(|pattern| pattern.exec(input.clone()));
            match outcome {
                Ok(Some(found)) => return Some(found),
                Ok(None) => {}
                Err(e) => {
                    #[cfg(feature = "tracing")]
                    tracing::error!("error matching url pattern: {}", e);
                }
            }
        }
        None
    }

    /// Whether tracing applies to `input`: true when some pattern matches,
    /// or, for a negated filter, when none does.
    #[must_use]
    pub fn is_enabled(&self, input: &UrlPatternMatchInput) -> bool {
        let matched = self.first_match(input).is_some();
        matched != self.is_negated
    }
}
impl FromStr for TracingFilter {
type Err = TracingFilterError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let paths: Vec<UrlPatternInit> = s
.split(',')
.map(parse_constructor_string)
.collect::<Result<Vec<UrlPatternInit>, TracingFilterError>>()?;
Ok(Self {
is_negated: false,
paths,
})
}
}
fn parse_constructor_string(filter: &str) -> Result<UrlPatternInit, TracingFilterError> {
url::Url::options()
.parse(filter)
.map_err(|error| TracingFilterError::InvalidUrl {
url: filter.to_string(),
error,
})
.and_then(validate_url_scheme)
.and_then(|fully_specified_url| {
url_origin_to_url(&fully_specified_url)
.map(|base_url| url_to_url_pattern_init(&fully_specified_url, Some(base_url)))
})
.or_else(|original_error| {
let baseless_url_pattern_init = url::Url::options()
.base_url(Some(
&url::Url::parse("https://api.qcs.rigetti.com")
.expect("base url bootstrap value should always parse"),
))
.parse(filter)
.map(|url_with_bootstrapped_base_url| {
url_to_url_pattern_init(&url_with_bootstrapped_base_url, None)
});
baseless_url_pattern_init.map_err(|_| original_error)
})
}
fn url_origin_to_url(value: &url::Url) -> Result<url::Url, TracingFilterError> {
value
.origin()
.unicode_serialization()
.parse()
.map_err(|error| TracingFilterError::InvalidUrl {
url: value.to_string(),
error,
})
}
/// Assembles a `UrlPatternInit` from two sources.
///
/// Protocol, credentials, host and port come from `base_url` and are all
/// `None` when no base is given. Path, query and fragment come from `value`.
/// An empty username or path is stored as `None`.
fn url_to_url_pattern_init(value: &url::Url, base_url: Option<url::Url>) -> UrlPatternInit {
    let base = base_url.as_ref();

    let protocol = base.map(|v| v.scheme().to_string());
    let username = base
        .map(|v| v.username())
        .filter(|name| !name.is_empty())
        .map(String::from);
    let password = base.and_then(|v| v.password()).map(String::from);
    let hostname = base.and_then(|v| v.host_str()).map(String::from);
    let port = base.and_then(|v| v.port()).map(|p| p.to_string());

    let path = value.path();
    let pathname = if path.is_empty() {
        None
    } else {
        Some(path.to_string())
    };
    let search = value.query().map(String::from);
    let hash = value.fragment().map(String::from);

    UrlPatternInit {
        protocol,
        username,
        password,
        hostname,
        port,
        pathname,
        search,
        hash,
        base_url,
    }
}
/// Passes `value` through when its scheme is `http`, `https` or `tcp`;
/// otherwise returns `UnsupportedUrlScheme` carrying the scheme text.
fn validate_url_scheme(value: url::Url) -> Result<url::Url, TracingFilterError> {
    let is_supported = matches!(value.scheme(), "http" | "https" | "tcp");
    if !is_supported {
        return Err(TracingFilterError::UnsupportedUrlScheme(
            value.scheme().to_string(),
        ));
    }
    Ok(value)
}
/// Reports whether the environment variable `var` is set to an accepted
/// "true" spelling: `true`, `t`, `1`, `yes`, `y` or `on`, compared after
/// lowercasing. An unset or unreadable variable counts as false.
fn is_env_var_true(var: &str) -> bool {
    match env::var(var) {
        Ok(raw) => {
            let normalized = raw.to_lowercase();
            ["true", "t", "1", "yes", "y", "on"].contains(&normalized.as_str())
        }
        Err(_) => false,
    }
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
// Each case is (filter string, URL under test, whether the filter should match).
#[rstest]
// Fully-specified filter: same host and path shape is expected to match.
#[case(
"https://api.qcs.rigetti.com/v1/users/:id",
"https://api.qcs.rigetti.com/v1/users/10",
true
)]
// Fully-specified filter: a different host is expected not to match.
#[case(
"https://api.qcs.rigetti.com/v1/users/:id",
"https://api.dev.qcs.rigetti.com/v1/users/10",
false
)]
// Path-only filters: matching is expected to depend on the path alone.
#[case("/v1/users/:id", "https://api.qcs.rigetti.com/v1/users/10", true)]
#[case("/v1/users/:id", "https://api.qcs.rigetti.com/v1/groups/10", false)]
// `tcp` URLs, with and without a path.
#[case("tcp://localhost:5555", "tcp://localhost:5555", true)]
#[case("tcp://localhost:5555/my_rpc", "tcp://localhost:5555/my_rpc", true)]
#[case("tcp://localhost:5555/my_rpc", "tcp://localhost:5555/other_rpc", false)]
#[case("/my_rpc", "tcp://localhost:5555/my_rpc", true)]
fn test_tracing_filter(#[case] filter: &str, #[case] url: &str, #[case] matches: bool) {
let input = UrlPatternMatchInput::Url(url::Url::parse(url).unwrap());
let mut tracing_filter = TracingFilter::from_str(filter).unwrap();
assert_eq!(tracing_filter.is_enabled(&input), matches);
// Rebuild the same filter with negation on; the verdict must flip.
tracing_filter = TracingFilterBuilder::from(tracing_filter)
.set_is_negated(true)
.build();
assert_eq!(tracing_filter.is_enabled(&input), !matches);
}
}