Skip to main content

aiway_protocol/context/
http_context.rs

1use crate::context::Route;
2use crate::context::parts::SerdeParts;
3#[cfg(feature = "model")]
4use crate::model::Provider;
5use bincode;
6use dashmap::DashMap;
7use serde::Serialize;
8use serde::de::DeserializeOwned;
9use std::any::Any;
10use std::fmt::Debug;
11use std::sync::Arc;
12use std::time::{SystemTime, UNIX_EPOCH};
13
/// Gateway HTTP context.
///
/// Bundles the routing decision with two per-request key/value stores: one for
/// serializable values (`state`) and one for values of arbitrary type (`any_state`).
#[derive(Debug)]
pub struct HttpContext {
    /// Routing information.
    routing: Routing,
    /// Runtime data; values must be serializable.
    state: State,
    /// Runtime data of any type, valid for the lifetime of the request.
    any_state: AnyState,
}
24
/// Routing data held by an [`HttpContext`]; both fields start out as `None`.
#[derive(Debug, Default)]
pub struct Routing {
    /// The matched route configuration.
    route: Option<Arc<Route>>,
    /// Routing target address: a domain name or IP (including the scheme), set by the load
    /// balancer.
    target: Option<String>,
}
32
/// Per-request key/value store for serializable values.
///
/// Each value is kept as its bincode-encoded bytes (see `State::insert` / `State::get`).
#[derive(Debug)]
pub struct State(DashMap<String, Vec<u8>>);
35
36impl Default for State {
37    fn default() -> Self {
38        let data = DashMap::new();
39        
40        // 插入请求 ID
41        let request_id = uuid::Uuid::new_v4().to_string();
42        let encoded_id = bincode::serialize(&request_id).unwrap();
43        data.insert(HttpContext::REQUEST_ID.to_string(), encoded_id);
44        
45        // 插入请求时间戳
46        let request_ts = SystemTime::now()
47            .duration_since(UNIX_EPOCH)
48            .unwrap()
49            .as_millis() as i64;
50        let encoded_ts = bincode::serialize(&request_ts).unwrap();
51        data.insert(HttpContext::REQUEST_TS.to_string(), encoded_ts);
52        
53        State(data)
54    }
55}
56
57impl State {
58    pub fn insert<T: Serialize>(&self, key: &str, value: T) -> Option<Vec<u8>> {
59        let encoded = bincode::serialize(&value).unwrap();
60        self.0.insert(key.to_string(), encoded)
61    }
62
63    pub fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
64        self.0
65            .get(key)
66            .and_then(|v| bincode::deserialize::<T>(v.value()).ok())
67    }
68
69    pub fn remove(&self, key: &str) -> Option<Vec<u8>> {
70        self.0.remove(key).map(|(_, v)| v)
71    }
72
73    pub fn exists(&self, key: &str) -> bool {
74        self.0.contains_key(key)
75    }
76}
77
/// Per-request key/value store for values of arbitrary type.
///
/// Values are type-erased as `Arc<dyn Any + Send + Sync>` and recovered by downcasting.
#[derive(Debug, Default)]
pub struct AnyState {
    data: DashMap<String, Arc<dyn Any + Send + Sync>>,
}
82
83impl AnyState {
84    pub fn insert<T: Any + Send + Sync>(
85        &self,
86        key: &str,
87        value: T,
88    ) -> Option<Arc<dyn Any + Send + Sync>> {
89        self.data.insert(key.to_string(), Arc::new(value))
90    }
91
92    pub fn get<T: Any + Send + Sync>(&self, key: &str) -> Option<Arc<T>> {
93        self.data
94            .get(key)
95            .and_then(|v| v.clone().downcast::<T>().ok())
96    }
97
98    pub fn remove<T: Any + Send + Sync>(&self, key: &str) -> Option<Arc<T>> {
99        self.data
100            .remove(key)
101            .and_then(|(_, v)| v.downcast::<T>().ok())
102    }
103
104    pub fn exists<T: Any + Send + Sync>(&self, key: &str) -> bool {
105        self.data.contains_key(key)
106    }
107}
108
109impl Default for HttpContext {
110    fn default() -> Self {
111        Self::new()
112    }
113}
114
115impl HttpContext {
116    /// 请求ID
117    pub const REQUEST_ID: &'static str = ":req:id";
118    /// 请求时间戳,毫秒
119    pub const REQUEST_TS: &'static str = ":req:ts";
120    /// 请求的Parts(原始值,未经过网关处理),格式为[http::request::Parts]。
121    /// 在请求生命周期内始终有效,不可变。
122    pub const REQUEST_RAW_PARTS: &'static str = ":req:raw:parts";
123    /// 最终响应给客户端的Parts,格式为[http::response::Parts]
124    pub const RESPONSE_SERDE_PARTS: &'static str = ":resp:parts";
125    /// 响应的Body大小
126    pub const RESPONSE_BODY_SIZE: &'static str = ":resp:parts:body_size";
127    /// 是否sse
128    pub const IS_SSE: &'static str = ":resp:is_sse";
129    /// 是否websocket
130    pub const IS_WEBSOCKET: &'static str = ":resp:is_ws";
131    /// 模型名称,仅适用于模型插件
132    pub const MODEL_PROXY_MODEL: &'static str = ":model_proxy:model";
133    /// 模型提供商,仅适用于模型插件
134    pub const MODEL_PROXY_PROVIDER: &'static str = ":model_proxy:provider";
135    pub fn new() -> Self {
136        Self {
137            routing: Default::default(),
138            state: Default::default(),
139            any_state: Default::default(),
140        }
141    }
142
143    #[inline]
144    pub fn set_route(&mut self, route: Arc<Route>) {
145        self.routing.route = Some(route);
146    }
147
148    #[inline]
149    pub fn get_route(&self) -> Option<Arc<Route>> {
150        self.routing.route.clone()
151    }
152
153    #[inline]
154    pub fn set_routing_url(&mut self, url: String) {
155        self.routing.target = Some(url);
156    }
157
158    #[inline]
159    pub fn get_routing_url(&self) -> Option<&String> {
160        self.routing.target.as_ref()
161    }
162
163    #[inline]
164    pub fn insert_state<T: Serialize>(&self, key: &str, value: T) {
165        self.state.insert(key, value);
166    }
167
168    #[inline]
169    pub fn get_state<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
170        self.state.get(key)
171    }
172
173    #[inline]
174    pub fn remove_state(&self, key: &str) {
175        self.state.remove(key);
176    }
177
178    #[inline]
179    pub fn insert_any_state<T: Any + Send + Sync>(&self, key: &str, value: T) {
180        self.any_state.insert(key, value);
181    }
182
183    #[inline]
184    pub fn get_any_state<T: Any + Send + Sync>(&self, key: &str) -> Option<Arc<T>> {
185        self.any_state.get(key)
186    }
187
188    #[inline]
189    pub fn remove_any_state<T: Any + Send + Sync>(&self, key: &str) {
190        self.any_state.remove::<T>(key);
191    }
192
193    #[inline]
194    pub fn exists_any_state<T: Any + Send + Sync>(&self, key: &str) -> bool {
195        self.any_state.exists::<T>(key)
196    }
197
198    /// 获取请求ID
199    pub fn request_id(&self) -> String {
200        //SAFE
201        self.state.get(Self::REQUEST_ID).unwrap()
202    }
203
204    /// 获取请求时间戳(网关收到请求的时间)
205    pub fn request_ts(&self) -> i64 {
206        //SAFE
207        self.state.get(Self::REQUEST_TS).unwrap()
208    }
209
210    pub fn request_raw_parts(&self) -> Option<SerdeParts> {
211        self.state.get(Self::REQUEST_RAW_PARTS)
212    }
213
214    /// 获取请求的模型名称
215    ///
216    /// 使用限制:仅在模型插件时可用
217    #[cfg(feature = "model")]
218    #[inline]
219    pub fn get_proxy_model_name(&self) -> Option<String> {
220        self.state.get(Self::MODEL_PROXY_MODEL)
221    }
222
223    /// 获取命中的模型代理商
224    ///
225    /// 使用限制:仅在模型插件时可用
226    #[cfg(feature = "model")]
227    #[inline]
228    pub fn get_proxy_model_provider(&self) -> Option<Provider> {
229        self.state.get(Self::MODEL_PROXY_PROVIDER)
230    }
231
232    #[inline]
233    pub fn is_sse(&self) -> bool {
234        self.state.exists(Self::IS_SSE)
235    }
236
237    #[inline]
238    pub fn is_websocket(&self) -> bool {
239        self.state.exists(Self::IS_WEBSOCKET)
240    }
241}