Skip to main content

aiway_protocol/context/
request_ext.rs

1use crate::common::constants::MODEL_API_PREFIX;
2use crate::common::header::Headers;
3use http::request::Parts;
4use http::{HeaderName, HeaderValue, Method};
5use std::collections::HashMap;
6use std::str::FromStr;
7
8pub trait RequestExt {
9    fn get_request_id(&self) -> String;
10    fn get_request_header(&self, key: &str) -> Option<String>;
11    fn set_request_header(&mut self, key: &str, value: &str);
12
13    fn all_request_headers(&self) -> HashMap<String, String>;
14    fn get_path(&self) -> String;
15    fn set_path(&mut self, path: &str);
16    fn get_method(&self) -> &Method;
17    fn get_host(&self) -> String;
18    fn route_match_key(&self) -> String;
19
20    fn query(&self) -> Option<String>;
21
22    fn is_model_request(&self) -> bool;
23}
24
25impl RequestExt for Parts {
26    fn get_request_id(&self) -> String {
27        self.get_request_header(Headers::REQUEST_ID)
28            .expect("request_id not set")
29    }
30
31    fn get_request_header(&self, key: &str) -> Option<String> {
32        self.headers
33            .get(key)
34            .map(|s| s.to_str().unwrap().to_string())
35    }
36
37    fn set_request_header(&mut self, key: &str, value: &str) {
38        if let (Ok(name), Ok(value)) = (HeaderName::from_str(key), HeaderValue::from_str(value)) {
39            self.headers.insert(name, value);
40        }
41    }
42
43    fn all_request_headers(&self) -> HashMap<String, String> {
44        self.headers
45            .iter()
46            .map(|(k, v)| (k.to_string(), v.to_str().unwrap().to_string()))
47            .collect()
48    }
49
50    fn get_path(&self) -> String {
51        self.uri.path().to_string()
52    }
53
54    fn set_path(&mut self, path: &str) {
55        let old_uri = self.uri.clone();
56        let mut parts = old_uri.into_parts();
57
58        // 构建新的 path_and_query
59        let new_path = if let Some(pq) = parts.path_and_query {
60            match pq.query() {
61                Some(query) => format!("{}?{}", path, query),
62                None => path.to_string(),
63            }
64        } else {
65            path.to_string()
66        };
67
68        parts.path_and_query = Some(new_path.parse().unwrap());
69
70        if let Ok(new_uri) = http::Uri::from_parts(parts) {
71            self.uri = new_uri;
72        }
73    }
74
75    fn get_method(&self) -> &Method {
76        &self.method
77    }
78
79    fn get_host(&self) -> String {
80        if self.version == http::Version::HTTP_2 {
81            self.get_request_header(":authority").unwrap()
82        } else {
83            self.get_request_header("host").unwrap()
84        }
85    }
86
87    fn route_match_key(&self) -> String {
88        format!("{}{}", self.get_host(), self.get_path())
89    }
90
91    fn query(&self) -> Option<String> {
92        self.uri.query().map(|s| s.to_string())
93    }
94
95    fn is_model_request(&self) -> bool {
96        self.get_path().starts_with(MODEL_API_PREFIX)
97    }
98}