// aiway_protocol/context/request_ext.rs
use crate::common::constants::MODEL_API_PREFIX;
use crate::common::header::Headers;
use http::request::Parts;
use http::uri::PathAndQuery;
use http::{HeaderName, HeaderValue, Method};
use std::collections::HashMap;
use std::str::FromStr;

8pub trait RequestExt {
9 fn get_request_id(&self) -> String;
10 fn get_request_header(&self, key: &str) -> Option<String>;
11 fn set_request_header(&mut self, key: &str, value: &str);
12
13 fn all_request_headers(&self) -> HashMap<String, String>;
14 fn get_path(&self) -> String;
15 fn set_path(&mut self, path: &str);
16 fn get_method(&self) -> &Method;
17 fn get_host(&self) -> String;
18 fn route_match_key(&self) -> String;
19
20 fn query(&self) -> Option<String>;
21
22 fn is_model_request(&self) -> bool;
23}
24
25impl RequestExt for Parts {
26 fn get_request_id(&self) -> String {
27 self.get_request_header(Headers::REQUEST_ID)
28 .expect("request_id not set")
29 }
30
31 fn get_request_header(&self, key: &str) -> Option<String> {
32 self.headers
33 .get(key)
34 .map(|s| s.to_str().unwrap().to_string())
35 }
36
37 fn set_request_header(&mut self, key: &str, value: &str) {
38 if let (Ok(name), Ok(value)) = (HeaderName::from_str(key), HeaderValue::from_str(value)) {
39 self.headers.insert(name, value);
40 }
41 }
42
43 fn all_request_headers(&self) -> HashMap<String, String> {
44 self.headers
45 .iter()
46 .map(|(k, v)| (k.to_string(), v.to_str().unwrap().to_string()))
47 .collect()
48 }
49
50 fn get_path(&self) -> String {
51 self.uri.path().to_string()
52 }
53
54 fn set_path(&mut self, path: &str) {
55 let old_uri = self.uri.clone();
56 let mut parts = old_uri.into_parts();
57
58 let new_path = if let Some(pq) = parts.path_and_query {
60 match pq.query() {
61 Some(query) => format!("{}?{}", path, query),
62 None => path.to_string(),
63 }
64 } else {
65 path.to_string()
66 };
67
68 parts.path_and_query = Some(new_path.parse().unwrap());
69
70 if let Ok(new_uri) = http::Uri::from_parts(parts) {
71 self.uri = new_uri;
72 }
73 }
74
75 fn get_method(&self) -> &Method {
76 &self.method
77 }
78
79 fn get_host(&self) -> String {
80 if self.version == http::Version::HTTP_2 {
81 self.get_request_header(":authority").unwrap()
82 } else {
83 self.get_request_header("host").unwrap()
84 }
85 }
86
87 fn route_match_key(&self) -> String {
88 format!("{}{}", self.get_host(), self.get_path())
89 }
90
91 fn query(&self) -> Option<String> {
92 self.uri.query().map(|s| s.to_string())
93 }
94
95 fn is_model_request(&self) -> bool {
96 self.get_path().starts_with(MODEL_API_PREFIX)
97 }
98}