1use super::{
2 component_id::HttpMessageComponentId,
3 component_name::{DerivedComponentName, HttpMessageComponentName},
4 component_param::{handle_params_key_into, handle_params_sf, HttpMessageComponentParam},
5 component_value::HttpMessageComponentValue,
6};
7use crate::{
8 error::{HttpSigError, HttpSigResult},
9 trace::*,
10};
11
12#[derive(Debug, Clone)]
14pub struct HttpMessageComponent {
16 pub id: HttpMessageComponentId,
18 pub value: HttpMessageComponentValue,
20}
21
22impl TryFrom<&str> for HttpMessageComponent {
23 type Error = HttpSigError;
24 fn try_from(val: &str) -> Result<Self, Self::Error> {
27 let Some((id, value)) = val.split_once(':') else {
28 return Err(HttpSigError::InvalidComponent(format!(
29 "Invalid http message component: {val}"
30 )));
31 };
32 let id = id.trim();
33
34 if !(id.starts_with('"') && (id.ends_with('"') || id[1..].contains("\";"))) {
36 return Err(HttpSigError::InvalidComponentId(format!(
37 "Invalid http message component id: {id}"
38 )));
39 }
40
41 Ok(Self {
42 id: HttpMessageComponentId::try_from(id)?,
43 value: HttpMessageComponentValue::from(value.trim()),
44 })
45 }
46}
47
48impl TryFrom<(&HttpMessageComponentId, &[String])> for HttpMessageComponent {
49 type Error = HttpSigError;
50
51 fn try_from((id, field_values): (&HttpMessageComponentId, &[String])) -> Result<Self, Self::Error> {
53 match &id.name {
54 HttpMessageComponentName::HttpField(_) => build_http_field_component(id, field_values),
55 HttpMessageComponentName::Derived(_) => build_derived_component(id, field_values),
56 }
57 }
58}
59
60impl std::fmt::Display for HttpMessageComponent {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}: {}", self.id, self.value)
65 }
66}
67
68pub(super) fn build_derived_component(
71 id: &HttpMessageComponentId,
72 field_values: &[String],
73) -> HttpSigResult<HttpMessageComponent> {
74 let HttpMessageComponentName::Derived(derived_id) = &id.name else {
75 return Err(HttpSigError::InvalidComponent(
76 "invalid http message component name as derived component".to_string(),
77 ));
78 };
79 if field_values.is_empty() {
80 return Err(HttpSigError::InvalidComponent(
81 "derived component requires field values".to_string(),
82 ));
83 }
84 if !id
86 .params
87 .0
88 .iter()
89 .all(|p| matches!(p, HttpMessageComponentParam::Req | HttpMessageComponentParam::Name(_)))
90 {
91 return Err(HttpSigError::InvalidComponent(
92 "invalid parameter for derived component".to_string(),
93 ));
94 }
95
96 let value = match derived_id {
97 DerivedComponentName::Method => HttpMessageComponentValue::from(field_values[0].to_ascii_uppercase().as_ref()),
98 DerivedComponentName::TargetUri => HttpMessageComponentValue::from(field_values[0].to_string().as_ref()),
99 DerivedComponentName::Authority => HttpMessageComponentValue::from(field_values[0].to_ascii_lowercase().as_ref()),
100 DerivedComponentName::Scheme => HttpMessageComponentValue::from(field_values[0].to_ascii_lowercase().as_ref()),
101 DerivedComponentName::RequestTarget => HttpMessageComponentValue::from(field_values[0].to_string().as_ref()),
102 DerivedComponentName::Path => HttpMessageComponentValue::from(field_values[0].to_string().as_ref()),
103 DerivedComponentName::Query => HttpMessageComponentValue::from(field_values[0].to_string().as_ref()),
104 DerivedComponentName::Status => HttpMessageComponentValue::from(field_values[0].to_string().as_ref()),
105 DerivedComponentName::QueryParam => {
106 let name = id.params.0.iter().find_map(|p| match p {
107 HttpMessageComponentParam::Name(name) => Some(name),
108 _ => None,
109 });
110 if name.is_none() {
111 return Err(HttpSigError::InvalidComponent(
112 "query-param derived component requires name parameter".to_string(),
113 ));
114 };
115 let name = name.unwrap();
116 let kvs = field_values
117 .iter()
118 .filter(|v| v.contains('='))
119 .map(|v| v.split_once('=').unwrap())
120 .filter(|(k, _)| *k == name.as_str())
121 .map(|(_, v)| v)
122 .collect::<Vec<_>>();
123 HttpMessageComponentValue::from(kvs.join(", ").as_ref())
124 }
125 DerivedComponentName::SignatureParams => {
126 let value = field_values[0].to_string();
127 let opt_pair = value.trim().split_once('=');
128 if opt_pair.is_none() {
129 return Err(HttpSigError::InvalidComponent(
130 "invalid signature-params derived component".to_string(),
131 ));
132 }
133 let (key, value) = opt_pair.unwrap();
134 HttpMessageComponentValue::from((key, value))
135 }
136 };
137 let component = HttpMessageComponent { id: id.clone(), value };
138 Ok(component)
139}
140
141pub(super) fn build_http_field_component(
145 id: &HttpMessageComponentId,
146 field_values: &[String],
147) -> HttpSigResult<HttpMessageComponent> {
148 let mut field_values = field_values.to_vec();
149 let params = &id.params;
150
151 for p in params.0.iter() {
152 match p {
153 HttpMessageComponentParam::Sf => {
154 handle_params_sf(&mut field_values)?;
155 }
156 HttpMessageComponentParam::Key(key) => {
157 field_values = handle_params_key_into(&field_values, key)?;
158 }
159 HttpMessageComponentParam::Bs => {
160 return Err(HttpSigError::NotYetImplemented("`bs` is not supported yet".to_string()));
161 }
162 HttpMessageComponentParam::Req => {
163 debug!("`req` is given for http field component");
164 }
165 HttpMessageComponentParam::Tr => return Err(HttpSigError::NotYetImplemented("`tr` is not supported yet".to_string())),
166 HttpMessageComponentParam::Name(_) => {
167 return Err(HttpSigError::NotYetImplemented(
168 "`name` is only for derived component query-params".to_string(),
169 ));
170 }
171 }
172 }
173
174 let field_values_str = field_values.join(", ");
177
178 let component = HttpMessageComponent {
179 id: id.clone(),
180 value: HttpMessageComponentValue::from(field_values_str.as_ref()),
181 };
182 Ok(component)
183}
184
#[cfg(test)]
// Unit tests for parsing serialized component lines (`TryFrom<&str>`) and for building
// components from raw values (`build_http_field_component` / `build_derived_component`).
mod tests {
  use super::*;
  // Insertion-ordered set with the Fx hasher, used to compare `params.0` in assertions.
  type IndexSet<K> = indexmap::IndexSet<K, rustc_hash::FxBuildHasher>;

  // Every derived component name parses from its serialized form, and display round-trips.
  #[test]
  fn test_from_serialized_string_derived() {
    // (serialized id, serialized value, expected derived component name)
    let tuples = vec![
      ("\"@method\"", "POST", DerivedComponentName::Method),
      ("\"@target-uri\"", "https://example.com/", DerivedComponentName::TargetUri),
      ("\"@authority\"", "example.com", DerivedComponentName::Authority),
      ("\"@scheme\"", "https", DerivedComponentName::Scheme),
      ("\"@request-target\"", "/path?query", DerivedComponentName::RequestTarget),
      ("\"@path\"", "/path", DerivedComponentName::Path),
      ("\"@query\"", "query", DerivedComponentName::Query),
      ("\"@query-param\";name=\"key\"", "\"value\"", DerivedComponentName::QueryParam),
      ("\"@status\"", "200", DerivedComponentName::Status),
    ];
    for (id, value, name) in tuples {
      let comp = HttpMessageComponent::try_from(format!("{}: {}", id, value).as_ref()).unwrap();
      assert_eq!(comp.id.name, HttpMessageComponentName::Derived(name));
      // Ids without `;` must carry no parameters; ids with `;` must carry at least one.
      if !id.contains(';') {
        assert_eq!(comp.id.params.0, IndexSet::default());
      } else {
        assert!(!comp.id.params.0.is_empty());
      }
      assert_eq!(comp.value.as_field_value(), value);
      assert_eq!(comp.value.key(), None);
      // Parsing followed by `Display` reproduces the original line.
      assert_eq!(comp.to_string(), format!("{}: {}", id, value));
    }
  }

  // `@query-param` keeps its `name` parameter after parsing.
  #[test]
  fn test_from_serialized_string_derived_query_params() {
    let (id, value, name) = ("\"@query-param\";name=\"key\"", "\"value\"", DerivedComponentName::QueryParam);
    let comp = HttpMessageComponent::try_from(format!("{}: {}", id, value).as_ref()).unwrap();
    assert_eq!(comp.id.name, HttpMessageComponentName::Derived(name));
    assert_eq!(
      comp.id.params.0.get(&HttpMessageComponentParam::Name("key".to_string())),
      Some(&HttpMessageComponentParam::Name("key".to_string()))
    );
    assert_eq!(comp.value.as_field_value(), value);
    assert_eq!(comp.value.key(), None);
    assert_eq!(comp.to_string(), format!("{}: {}", id, value));
  }

  // HTTP field components parse and round-trip. Parsing accepts `bs`/`tr` even though
  // building a field component with them is not implemented.
  #[test]
  fn test_from_serialized_string_http_field() {
    // (serialized id, serialized value, expected field name)
    let tuples = vec![
      ("\"example-header\"", "example-value", "example-header"),
      ("\"example-header\";bs;tr", "example-value", "example-header"),
      ("\"example-header\";bs", "example-value", "example-header"),
      ("\"x-empty-header\"", "", "x-empty-header"),
    ];
    for (id, value, inner_name) in tuples {
      let comp = HttpMessageComponent::try_from(format!("{}: {}", id, value).as_ref()).unwrap();
      assert_eq!(comp.id.name, HttpMessageComponentName::HttpField(inner_name.to_string()));
      if !id.contains(';') {
        assert_eq!(comp.id.params.0, IndexSet::default());
      } else {
        assert!(!comp.id.params.0.is_empty());
      }
      assert_eq!(comp.value.as_field_value(), value);
      assert_eq!(comp.to_string(), format!("{}: {}", id, value));
    }
  }

  // Flag parameters are kept in the order they were written.
  #[test]
  fn test_from_serialized_string_http_field_params() {
    let comp = HttpMessageComponent::try_from("\"example-header\";bs;tr: example-value").unwrap();
    assert_eq!(
      comp.id.name,
      HttpMessageComponentName::HttpField("example-header".to_string())
    );
    assert_eq!(
      comp.id.params.0,
      vec![HttpMessageComponentParam::Bs, HttpMessageComponentParam::Tr]
        .into_iter()
        .collect::<IndexSet<_>>()
    );
  }

  // A `key="..."` parameter is parsed into `HttpMessageComponentParam::Key`.
  #[test]
  fn test_from_serialized_string_http_field_params_key() {
    let comp = HttpMessageComponent::try_from("\"example-header\";key=\"hoge\": example-value").unwrap();
    assert_eq!(
      comp.id.name,
      HttpMessageComponentName::HttpField("example-header".to_string())
    );
    assert_eq!(
      comp.id.params.0,
      vec![HttpMessageComponentParam::Key("hoge".to_string())]
        .into_iter()
        .collect::<IndexSet<_>>()
    );
  }

  // Derived components accept `req` but reject field-only parameters such as `bs` and `key`.
  #[test]
  fn test_field_params_derived_component() {
    let comp = HttpMessageComponent::try_from("\"@method\";req: POST");
    assert!(comp.is_ok());
    let comp = HttpMessageComponent::try_from("\"@method\";bs: POST");
    assert!(comp.is_err());
    let comp = HttpMessageComponent::try_from("\"@method\";key=\"hoge\": POST");
    assert!(comp.is_err());
  }

  // A single field value is used as-is; the id parser also accepts an unquoted name.
  #[test]
  fn test_build_http_field_component() {
    let id = HttpMessageComponentId::try_from("content-type").unwrap();
    let field_values = vec!["application/json".to_owned()];
    let component = build_http_field_component(&id, &field_values).unwrap();
    assert_eq!(component.id, id);
    assert_eq!(component.value, HttpMessageComponentValue::from("application/json"));
    assert_eq!(component.to_string(), "\"content-type\": application/json");
  }
  // Multiple field values are joined with ", ".
  #[test]
  fn test_build_http_field_component_multiple_values() {
    let id = HttpMessageComponentId::try_from("\"content-type\"").unwrap();
    let field_values = vec!["application/json".to_owned(), "application/json-patch+json".to_owned()];
    let component = build_http_field_component(&id, &field_values).unwrap();
    assert_eq!(component.id, id);
    assert_eq!(
      component.value,
      HttpMessageComponentValue::from("application/json, application/json-patch+json")
    );
    assert_eq!(
      component.to_string(),
      "\"content-type\": application/json, application/json-patch+json"
    );
  }
  // `sf` normalizes the values (here the space after `;` is removed).
  #[test]
  fn test_build_http_field_component_sf() {
    let id = HttpMessageComponentId::try_from("\"content-type\";sf").unwrap();
    let field_values = vec![
      "application/json; patched=true".to_owned(),
      "application/json-patch+json;patched".to_owned(),
    ];
    let component = build_http_field_component(&id, &field_values).unwrap();
    assert_eq!(component.id, id);
    assert_eq!(
      component.value,
      HttpMessageComponentValue::from("application/json;patched=true, application/json-patch+json;patched")
    );
    assert_eq!(
      component.to_string(),
      "\"content-type\";sf: application/json;patched=true, application/json-patch+json;patched"
    );
  }
  // `key` extracts the value stored under the given dictionary key.
  #[test]
  fn test_build_http_field_component_key() {
    let id = HttpMessageComponentId::try_from("\"example-header\";key=\"patched\"").unwrap();
    let field_values = vec!["patched=12345678".to_owned()];
    let component = build_http_field_component(&id, &field_values).unwrap();
    assert_eq!(component.id, id);
    assert_eq!(component.value, HttpMessageComponentValue::from("12345678"));
    assert_eq!(component.to_string(), "\"example-header\";key=\"patched\": 12345678");
  }
  // `key` collects matches across several values and skips non-matching keys.
  #[test]
  fn test_build_http_field_component_key_multiple_values() {
    let id = HttpMessageComponentId::try_from("\"example-header\";key=\"patched\"").unwrap();
    let field_values = vec![
      "patched=12345678".to_owned(),
      "patched=87654321".to_owned(),
      "not-patched=12345678".to_owned(),
    ];
    let component = build_http_field_component(&id, &field_values).unwrap();
    assert_eq!(component.id, id);
    assert_eq!(component.value, HttpMessageComponentValue::from("12345678, 87654321"));
    assert_eq!(
      component.to_string(),
      "\"example-header\";key=\"patched\": 12345678, 87654321"
    );
  }

  // Derived components built from raw values: `@method` and `@target-uri`.
  #[test]
  fn test_build_derived_component() {
    let id = HttpMessageComponentId::try_from("@method").unwrap();
    let field_values = vec!["GET".to_owned()];
    let component = build_derived_component(&id, &field_values).unwrap();
    assert_eq!(component.id, id);
    assert_eq!(component.value, HttpMessageComponentValue::from("GET"));
    assert_eq!(component.to_string(), "\"@method\": GET");

    let id = HttpMessageComponentId::try_from("@target-uri").unwrap();
    let field_values = vec!["https://example.com/foo".to_owned()];
    let component = build_derived_component(&id, &field_values).unwrap();
    assert_eq!(component.id, id);
    assert_eq!(component.value, HttpMessageComponentValue::from("https://example.com/foo"));
    assert_eq!(component.to_string(), "\"@target-uri\": https://example.com/foo");
  }
  // Despite its name, this exercises `build_derived_component` for `@query-param`: only the
  // value of the pair named `var` is kept, still percent-encoded.
  #[test]
  fn test_build_http_field_component_query_param() {
    let id = HttpMessageComponentId::try_from("\"@query-param\";name=\"var\"").unwrap();
    let query_param = "var=this%20is%20a%20big%0Amultiline%20value&bar=with+plus+whitespace&fa%C3%A7ade%22%3A%20=something&ok";
    let field_values = query_param.split('&').map(|v| v.to_owned()).collect::<Vec<_>>();
    let component = build_derived_component(&id, &field_values).unwrap();
    assert_eq!(component.id, id);
    assert_eq!(
      component.value,
      HttpMessageComponentValue::from("this%20is%20a%20big%0Amultiline%20value")
    );
    assert_eq!(
      component.to_string(),
      "\"@query-param\";name=\"var\": this%20is%20a%20big%0Amultiline%20value"
    );
  }

  // The id parser itself rejects `key` on a derived component.
  #[test]
  fn test_disallow_invalid_params() {
    let id = HttpMessageComponentId::try_from("\"@method\";key=\"patched\"");
    assert!(id.is_err());
  }
}
401}