use std::borrow::Cow;
use crate::pipeline::persisted_documents::extract::HttpRequestContext;
use super::super::super::types::PersistedDocumentId;
use super::super::core::{DocumentIdSourceExtractor, ExtractionContext};
/// Document-id source that reads the id from a single URL query parameter.
pub(crate) struct UrlQueryParamExtractor {
    /// Name of the query parameter handed to `ExtractionContext::query_param`.
    pub(crate) name: String,
}
impl DocumentIdSourceExtractor for UrlQueryParamExtractor {
    /// Reads the configured query parameter and converts it into a [`PersistedDocumentId`].
    ///
    /// Returns `None` when `ctx.query_param` yields no value for `self.name`, or when
    /// `PersistedDocumentId::try_from` rejects the value (the conversion error is discarded).
    fn extract(&self, ctx: &ExtractionContext<'_>) -> Option<PersistedDocumentId> {
        let raw_value = ctx.query_param(&self.name)?;
        PersistedDocumentId::try_from(raw_value.as_ref()).ok()
    }
}
impl<'a> HttpRequestContext<'a> {
    /// Looks up the query parameter `name` in this request's parsed query string.
    ///
    /// Returns `None` when `self.query` is absent or when the lookup itself finds no value.
    /// The returned value's lifetime is tied to the request data (`'a`), not to `&self`.
    pub fn query_param(&self, name: &str) -> Option<Cow<'a, str>> {
        self.query.as_ref().and_then(|params| params.get(name))
    }
}
/// Borrowed view over a raw URL query string; pairs are only scanned when `get` is called.
pub(crate) struct QueryParams<'a> {
    // Undecoded query text. It must not start with `?`: a key is only recognized at
    // offset 0 or directly after an `&`.
    raw: &'a str,
}
impl<'a> QueryParams<'a> {
    /// Wraps a raw query string without parsing or copying it.
    ///
    /// `raw` must not include the leading `?`: a key is only recognized at offset 0 or
    /// directly after an `&`.
    pub(crate) fn new(raw: &'a str) -> Self {
        Self { raw }
    }

    /// Returns the decoded value of the first pair whose raw key equals `name`.
    ///
    /// Lookup rules:
    /// - Only the first pair with a matching key counts. If that pair has no `=` or an
    ///   empty value, the result is `None` even when a later pair carries a value.
    /// - Keys are compared byte-for-byte against the undecoded query text.
    /// - An empty `name` never matches.
    ///
    /// Decoding: every `+` becomes a space, then percent-escapes are decoded. Malformed
    /// escapes such as `%ZZ` are kept verbatim. Returns `None` if the decoded bytes are not
    /// valid UTF-8. Values that need no decoding are returned as `Cow::Borrowed`.
    pub(crate) fn get(&self, name: &str) -> Option<Cow<'a, str>> {
        let value = Self::find_first_value(self.raw, name)?;
        Self::decode_if_needed(value)
    }

    /// Returns the raw (still encoded) value of the first pair keyed `name`.
    ///
    /// Returns `None` when no such pair exists, or when the first one has no `=` or an
    /// empty value.
    fn find_first_value<'b>(query: &'b str, name: &str) -> Option<&'b str> {
        let bytes = query.as_bytes();
        let name_bytes = name.as_bytes();
        // An empty name has no first byte to search for and never matches.
        let first = *name_bytes.first()?;
        // Jump straight between candidate key starts instead of splitting every pair.
        for idx in memchr::memchr_iter(first, bytes) {
            if !Self::is_pair_boundary(bytes, idx) {
                continue;
            }
            let Some(key_end) = Self::match_key_at(bytes, idx, name_bytes) else {
                continue;
            };
            match bytes.get(key_end).copied() {
                // `...&key` at the very end, or `key&...`: the first occurrence has no value.
                None | Some(b'&') => return None,
                Some(b'=') => {}
                // A longer key that merely starts with `name` (e.g. `keys=`); keep scanning.
                Some(_) => continue,
            }
            // `=` is ASCII, so `key_end + 1` is a char boundary and at most `query.len()`.
            let rest = &query[key_end + 1..];
            let value = match memchr::memchr(b'&', rest.as_bytes()) {
                Some(end) => &rest[..end],
                None => rest,
            };
            return (!value.is_empty()).then_some(value);
        }
        None
    }

    /// Returns `true` if `idx` starts a `key=value` pair (start of input or just after `&`).
    fn is_pair_boundary(bytes: &[u8], idx: usize) -> bool {
        idx == 0 || bytes[idx - 1] == b'&'
    }

    /// Returns the index just past `name_bytes` if `bytes` contains it starting at `idx`.
    fn match_key_at(bytes: &[u8], idx: usize, name_bytes: &[u8]) -> Option<usize> {
        bytes
            .get(idx..)?
            .starts_with(name_bytes)
            .then_some(idx + name_bytes.len())
    }

    /// Form-decodes `value` (`+` becomes a space, then percent-escapes are decoded).
    ///
    /// Borrows when no decoding is needed. Returns `None` if the result is not valid UTF-8.
    fn decode_if_needed<'b>(value: &'b str) -> Option<Cow<'b, str>> {
        let bytes = value.as_bytes();
        let has_escapes = memchr::memchr(b'%', bytes).is_some();
        let Some(first_plus) = memchr::memchr(b'+', bytes) else {
            if !has_escapes {
                // Nothing to decode: hand back a borrow of the original query text.
                return Some(Cow::Borrowed(value));
            }
            return percent_encoding::percent_decode(bytes).decode_utf8().ok();
        };
        // `+` must become a space *before* percent-decoding, so that `%2B` survives as `+`.
        let spaced = Self::replace_plus(bytes, first_plus);
        if !has_escapes {
            // Swapping ASCII `+` for ASCII ` ` keeps valid UTF-8 valid, so reuse the buffer
            // instead of copying it a second time.
            return String::from_utf8(spaced).ok().map(Cow::Owned);
        }
        let decoded = percent_encoding::percent_decode(&spaced).decode_utf8().ok()?;
        Some(Cow::Owned(decoded.into_owned()))
    }

    /// Copies `input`, turning every `+` at or after `first_position` into a space.
    ///
    /// `first_position` is the offset of the first `+`, so earlier bytes are copied unchanged.
    fn replace_plus(input: &[u8], first_position: usize) -> Vec<u8> {
        let mut replaced = input.to_vec();
        for byte in &mut replaced[first_position..] {
            if *byte == b'+' {
                *byte = b' ';
            }
        }
        replaced
    }
}
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use super::QueryParams;

    /// Looks up `name` in `raw_query` and returns an owned copy of the decoded value.
    fn query_param(raw_query: &str, name: &str) -> Option<String> {
        QueryParams::new(raw_query)
            .get(name)
            .map(|value| value.into_owned())
    }

    #[test]
    fn query_params_lookup_rules() {
        let cases = [
            ("key=first&key=second", "key", Some("first")),
            ("key=&key=second", "key", None),
            ("key&key=second", "key", None),
            ("keys=1&key=value", "key", Some("value")),
            ("xkey=1&key=value", "key", Some("value")),
            ("foo=bar", "key", None),
            ("", "key", None),
            ("key=value", "", None),
            // Bare key at the end of the query, and `=` with nothing after it.
            ("key", "key", None),
            ("key=", "key", None),
            // A trailing `&` terminates the value like any other separator.
            ("key=value&", "key", Some("value")),
            ("a=1&key=value&b=2", "key", Some("value")),
            // A key shorter than `name` must not match a prefix of it.
            ("ke=1", "key", None),
        ];
        for (query, name, expected) in cases {
            let actual = query_param(query, name);
            assert_eq!(
                actual.as_deref(),
                expected,
                "query='{query}', name='{name}'"
            );
        }
    }

    #[test]
    fn query_params_decoding_rules() {
        let cases = [
            ("key=a+b", Some("a b")),
            ("key=a%2Bb", Some("a+b")),
            ("key=sha256%3Aabc", Some("sha256:abc")),
            ("key=abc%ZZ", Some("abc%ZZ")),
            // `+` is replaced before percent-decoding, so an encoded `%2B` stays a plus.
            ("key=a+b%2Bc", Some("a b+c")),
            ("key=caf%C3%A9", Some("café")),
            // Decoded bytes that are not valid UTF-8 yield no value.
            ("key=%FF", None),
            ("key=plain", Some("plain")),
        ];
        for (query, expected) in cases {
            let actual = query_param(query, "key");
            assert_eq!(actual.as_deref(), expected, "query='{query}'");
        }
    }

    #[test]
    fn query_params_borrow_values_without_escapes() {
        let params = QueryParams::new("key=plain&other=a+b");
        assert!(matches!(params.get("key"), Some(Cow::Borrowed("plain"))));
        assert!(matches!(params.get("other"), Some(Cow::Owned(s)) if s == "a b"));
    }
}