// aiway_protocol/context/http_context.rs
1use crate::context::Route;
2use crate::context::parts::SerdeParts;
3#[cfg(feature = "model")]
4use crate::model::Provider;
5use dashmap::DashMap;
6use serde::Serialize;
7use serde::de::DeserializeOwned;
8use serde_json::Value;
9use std::any::Any;
10use std::fmt::Debug;
11use std::sync::Arc;
12use std::time::{SystemTime, UNIX_EPOCH};
13
14/// 网关HTTP上下文
15#[derive(Debug)]
16pub struct HttpContext {
17    /// 路由信息
18    routing: Routing,
19    /// 运行时数据,必须是可序列化的。
20    state: State,
21    /// 运行时数据,可以是任何数据,在请求生命周期内有效。
22    any_state: AnyState,
23}
24
25#[derive(Debug, Default)]
26pub struct Routing {
27    /// 匹配到的路由配置信息
28    route: Option<Arc<Route>>,
29    /// 路由目标地址,可以是域名或IP(包含协议头),由负载均衡器设置。
30    target: Option<String>,
31}
32
33#[derive(Debug)]
34pub struct State(DashMap<String, Value>);
35
36impl Default for State {
37    fn default() -> Self {
38        let data = DashMap::from_iter([
39            (
40                HttpContext::REQUEST_ID.to_string(),
41                uuid::Uuid::new_v4().to_string().into(),
42            ),
43            (
44                HttpContext::REQUEST_TS.to_string(),
45                (SystemTime::now()
46                    .duration_since(UNIX_EPOCH)
47                    .unwrap()
48                    .as_millis() as i64)
49                    .into(),
50            ),
51        ]);
52        State(data)
53    }
54}
55
56impl State {
57    pub fn insert<T: Serialize>(&self, key: &str, value: T) -> Option<Value> {
58        self.0
59            .insert(key.to_string(), serde_json::to_value(value).unwrap())
60    }
61
62    pub fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
63        self.0
64            .get(key)
65            .and_then(|v| serde_json::from_value::<T>(v.clone()).ok())
66    }
67
68    pub fn remove(&self, key: &str) -> Option<Value> {
69        self.0.remove(key).map(|(_, v)| v)
70    }
71
72    pub fn exists(&self, key: &str) -> bool {
73        self.0.contains_key(key)
74    }
75}
76
77#[derive(Debug, Default)]
78pub struct AnyState {
79    data: DashMap<String, Arc<dyn Any + Send + Sync>>,
80}
81
82impl AnyState {
83    pub fn insert<T: Any + Send + Sync>(
84        &self,
85        key: &str,
86        value: T,
87    ) -> Option<Arc<dyn Any + Send + Sync>> {
88        self.data.insert(key.to_string(), Arc::new(value))
89    }
90
91    pub fn get<T: Any + Send + Sync>(&self, key: &str) -> Option<Arc<T>> {
92        self.data
93            .get(key)
94            .and_then(|v| v.clone().downcast::<T>().ok())
95    }
96
97    pub fn remove<T: Any + Send + Sync>(&self, key: &str) -> Option<Arc<T>> {
98        self.data
99            .remove(key)
100            .and_then(|(_, v)| v.downcast::<T>().ok())
101    }
102
103    pub fn exists<T: Any + Send + Sync>(&self, key: &str) -> bool {
104        self.data.contains_key(key)
105    }
106}
107
108impl Default for HttpContext {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114impl HttpContext {
115    /// 请求ID
116    pub const REQUEST_ID: &'static str = ":req:id";
117    /// 请求时间戳,毫秒
118    pub const REQUEST_TS: &'static str = ":req:ts";
119    /// 请求的Parts(原始值,未经过网关处理),格式为[http::request::Parts]。
120    /// 在请求生命周期内始终有效,不可变。
121    pub const REQUEST_RAW_PARTS: &'static str = ":req:raw:parts";
122    /// 最终响应给客户端的Parts,格式为[http::response::Parts]
123    pub const RESPONSE_SERDE_PARTS: &'static str = ":resp:parts";
124    /// 响应的Body大小
125    pub const RESPONSE_BODY_SIZE: &'static str = ":resp:parts:body_size";
126    /// 是否sse
127    pub const IS_SSE: &'static str = ":resp:is_sse";
128    /// 是否websocket
129    pub const IS_WEBSOCKET: &'static str = ":resp:is_ws";
130    /// 模型名称,仅适用于模型插件
131    pub const MODEL_PROXY_MODEL: &'static str = ":model_proxy:model";
132    /// 模型提供商,仅适用于模型插件
133    pub const MODEL_PROXY_PROVIDER: &'static str = ":model_proxy:provider";
134    pub fn new() -> Self {
135        Self {
136            routing: Default::default(),
137            state: Default::default(),
138            any_state: Default::default(),
139        }
140    }
141
142    #[inline]
143    pub fn set_route(&mut self, route: Arc<Route>) {
144        self.routing.route = Some(route);
145    }
146
147    #[inline]
148    pub fn get_route(&self) -> Option<Arc<Route>> {
149        self.routing.route.clone()
150    }
151
152    #[inline]
153    pub fn set_routing_url(&mut self, url: String) {
154        self.routing.target = Some(url);
155    }
156
157    #[inline]
158    pub fn get_routing_url(&self) -> Option<&String> {
159        self.routing.target.as_ref()
160    }
161
162    #[inline]
163    pub fn insert_state<T: Serialize>(&self, key: &str, value: T) {
164        self.state.insert(key, value);
165    }
166
167    #[inline]
168    pub fn get_state<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
169        self.state.get(key)
170    }
171
172    #[inline]
173    pub fn remove_state(&self, key: &str) {
174        self.state.remove(key);
175    }
176
177    #[inline]
178    pub fn insert_any_state<T: Any + Send + Sync>(&self, key: &str, value: T) {
179        self.any_state.insert(key, value);
180    }
181
182    #[inline]
183    pub fn get_any_state<T: Any + Send + Sync>(&self, key: &str) -> Option<Arc<T>> {
184        self.any_state.get(key)
185    }
186
187    #[inline]
188    pub fn remove_any_state<T: Any + Send + Sync>(&self, key: &str) {
189        self.any_state.remove::<T>(key);
190    }
191
192    #[inline]
193    pub fn exists_any_state<T: Any + Send + Sync>(&self, key: &str) -> bool {
194        self.any_state.exists::<T>(key)
195    }
196
197    /// 获取请求ID
198    pub fn request_id(&self) -> String {
199        //SAFE
200        self.state.get(Self::REQUEST_ID).unwrap()
201    }
202
203    /// 获取请求时间戳(网关收到请求的时间)
204    pub fn request_ts(&self) -> i64 {
205        //SAFE
206        self.state.get(Self::REQUEST_TS).unwrap()
207    }
208
209    pub fn request_raw_parts(&self) -> Option<SerdeParts> {
210        self.state.get(Self::REQUEST_RAW_PARTS)
211    }
212
213    /// 获取请求的模型名称
214    ///
215    /// 使用限制:仅在模型插件时可用
216    #[cfg(feature = "model")]
217    #[inline]
218    pub fn get_proxy_model_name(&self) -> Option<String> {
219        self.state.get(Self::MODEL_PROXY_MODEL)
220    }
221
222    /// 获取命中的模型代理商
223    ///
224    /// 使用限制:仅在模型插件时可用
225    #[cfg(feature = "model")]
226    #[inline]
227    pub fn get_proxy_model_provider(&self) -> Option<Provider> {
228        self.state.get(Self::MODEL_PROXY_PROVIDER)
229    }
230
231    #[inline]
232    pub fn is_sse(&self) -> bool {
233        self.state.exists(Self::IS_SSE)
234    }
235
236    #[inline]
237    pub fn is_websocket(&self) -> bool {
238        self.state.exists(Self::IS_WEBSOCKET)
239    }
240}