Skip to main content

aiway_protocol/context/
http_context.rs

use crate::context::request_context::RequestContext;
use crate::context::response_context::ResponseContext;
use dashmap::DashMap;
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde_json::Value;

8/// HTTP上下文
9#[derive(Debug, Default)]
10pub struct HttpContext {
11    /// 请求上下文,在请求阶段构建
12    pub request: RequestContext,
13    /// 响应上下文,在构建请求上下文时同步构建,在响应阶段更新
14    pub response: ResponseContext,
15    /// 内部扩展数据
16    pub inner_state: InnerState,
17    /// 自定义的扩展数据
18    pub state: DashMap<String, Value>,
19}
20
21impl HttpContext {
22    pub fn insert_state<T: Serialize>(&self, key: &str, value: T) {
23        self.state.insert(
24            key.to_string(),
25            serde_json::to_value(value).expect("Failed to serialize state value"),
26        );
27    }
28
29    pub fn get_state<T: DeserializeOwned>(
30        &self,
31        key: &str,
32    ) -> Result<Option<T>, serde_json::Error> {
33        self.state
34            .get(key)
35            .map(|v| serde_json::from_value(v.clone()))
36            .transpose()
37    }
38
39    pub fn remove_state(&self, key: &str) {
40        self.state.remove(key);
41    }
42}
43
44#[derive(Debug, Default)]
45pub struct InnerState(DashMap<String, Value>);
46
47impl InnerState {
48    #[cfg(feature = "model")]
49    const MODEL_PROVIDER: &'static str = "model_proxy:provider";
50    #[cfg(feature = "model")]
51    pub fn get_model_provider(&self) -> Option<crate::model::Provider> {
52        self.0.get(Self::MODEL_PROVIDER).and_then(|v| {
53            serde_json::from_value(v.value().clone())
54                .expect("Failed to deserialize model provider value")
55        })
56    }
57    #[cfg(feature = "model")]
58    pub fn set_model_provider(&self, provider: crate::model::Provider) {
59        self.0.insert(
60            Self::MODEL_PROVIDER.to_string(),
61            serde_json::to_value(provider).expect("Failed to serialize model provider value"),
62        );
63    }
64}