flare_rpc_core/interceptor/ctxinterceprot.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use tonic::{Request, Status};
4use flare_core::context::{AppContext, AppContextBuilder};
5use tonic::metadata::{MetadataValue, MetadataMap, MetadataKey};
6use std::str::FromStr;
7
// gRPC metadata keys used to carry `AppContext` fields from the client to the server.
// Every key below is lowercase ASCII and none ends in `-bin`, so all are plain ASCII keys.

/// Key carrying `AppContext::remote_addr`.
const REMOTE_ADDR_KEY: &str = "remote-addr";
/// Key carrying `AppContext::user_id` (written only when the context has one).
const USER_ID_KEY: &str = "user-id";
/// Key carrying `AppContext::platform` as its decimal string form.
const PLATFORM_KEY: &str = "platform";
/// Key carrying `AppContext::client_id` (written only when the context has one).
const CLIENT_ID_KEY: &str = "client-id";
/// Key carrying `AppContext::language` (written only when the context has one).
const LANGUAGE_KEY: &str = "language";
/// Key carrying `AppContext::conn_id`.
const CONN_ID_KEY: &str = "conn-id";
/// Key carrying `AppContext::client_msg_id`.
const CLIENT_MSG_ID_KEY: &str = "client-msg-id";

/// Prefix put in front of every key of `AppContext::values()`; the original key follows it.
const VALUES_PREFIX: &str = "ctx-val-";
16
17#[cfg(feature = "client")]
18use {
19    tower::{Service, Layer},
20    std::future::Future,
21    std::pin::Pin,
22};
23
/// Shared slot holding the [`AppContext`] that outgoing client requests are stamped with.
///
/// Cloning is cheap, and every clone refers to the same slot (the state lives behind an
/// `Arc`), so a context set through any clone is visible through all of them.
#[cfg(feature = "client")]
#[derive(Clone)]
pub struct AppContextConfig {
    /// `None` until [`AppContextConfig::set_context`] is called.
    context: Arc<Mutex<Option<AppContext>>>,
}

#[cfg(feature = "client")]
impl AppContextConfig {
    /// Creates a config with no context stored.
    pub fn new() -> Self {
        Self {
            context: Arc::new(Mutex::new(None)),
        }
    }

    /// Stores `context`, replacing any context stored earlier.
    ///
    /// If another thread panicked while holding the lock, the guard is recovered from the
    /// `PoisonError` and the slot is overwritten anyway. Assigning a whole `Option` cannot
    /// leave it half-updated. Ignoring the update instead would silently keep stale (or no)
    /// context on every later request.
    pub fn set_context(&self, context: AppContext) {
        let mut slot = self
            .context
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *slot = Some(context);
    }
}

#[cfg(feature = "client")]
impl Default for AppContextConfig {
    /// Equivalent to [`AppContextConfig::new`]: no context stored.
    fn default() -> Self {
        Self::new()
    }
}
51
/// Tower [`Layer`] that wraps a service in an [`AppContextInterceptor`].
///
/// All interceptors produced by one layer share the same [`AppContextConfig`].
#[cfg(feature = "client")]
#[derive(Clone)]
pub struct AppContextLayer {
    /// Shared handle handed to every interceptor this layer builds.
    config: Arc<AppContextConfig>,
}

#[cfg(feature = "client")]
impl AppContextLayer {
    /// Builds a layer that owns `config` behind an `Arc`.
    pub fn new(config: AppContextConfig) -> Self {
        let shared = Arc::new(config);
        AppContextLayer { config: shared }
    }
}

#[cfg(feature = "client")]
impl<S> Layer<S> for AppContextLayer {
    type Service = AppContextInterceptor<S>;

    /// Wraps `inner`; the returned interceptor shares this layer's config handle.
    fn layer(&self, inner: S) -> Self::Service {
        let config = Arc::clone(&self.config);
        AppContextInterceptor { config, inner }
    }
}
78
/// Tower service built by `AppContextLayer`.
///
/// For each request, it copies the context stored in its [`AppContextConfig`] (if any) into
/// the request metadata, then forwards the request to `inner`.
#[cfg(feature = "client")]
#[derive(Clone)]
pub struct AppContextInterceptor<S> {
    inner: S,
    config: Arc<AppContextConfig>,
}

#[cfg(feature = "client")]
impl<S, B> Service<Request<B>> for AppContextInterceptor<S>
where
    // `Clone` is no longer needed by `call`; it is kept so this public impl's bounds are unchanged.
    S: Service<Request<B>, Response = tonic::Response<B>, Error = Status> + Clone + Send + 'static,
    S::Future: Send + 'static,
    B: Send + 'static,
{
    type Response = tonic::Response<B>;
    type Error = Status;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// Readiness is exactly that of the wrapped service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Stamps `request` with the stored context (if any) and forwards it to `inner`.
    fn call(&mut self, mut request: Request<B>) -> Self::Future {
        {
            // A panic in another thread while holding the lock does not make the stored
            // `Option<AppContext>` unusable, so recover the guard rather than silently
            // sending the request without context. The guard is dropped at the end of
            // this scope, before the inner service is invoked.
            let slot = self
                .config
                .context
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            if let Some(ctx) = slot.as_ref() {
                build_req_metadata_form_ctx(ctx, &mut request);
            }
        }

        // Call the very instance that `poll_ready` was driven on. Cloning `inner` and calling
        // the clone would use an instance that never reported readiness. The `Service`
        // contract forbids that, and it bypasses back-pressure (buffers, concurrency limits).
        // All work above is synchronous, so no clone is needed to build the future.
        Box::pin(self.inner.call(request))
    }
}
114
/// Writes the fields of `ctx` into the metadata of `request`.
///
/// The receiving side can rebuild the context from these keys with
/// `build_context_from_metadata`.
///
/// Which fields are written:
/// - `remote_addr`, `conn_id` and `client_msg_id` are always attempted.
/// - `user_id`, `platform`, `client_id` and `language` are written only when set.
/// - Each entry of `ctx.values()` is written under `ctx-val-<key>`.
///
/// Existing metadata under the same keys is overwritten.
///
/// Writing is best-effort. A value that is not valid ASCII metadata is skipped, as is a
/// `values` key that does not form a valid ASCII metadata key. Neither fails the request.
///
/// NOTE(review): `MetadataKey::from_bytes` is expected to normalise keys to lowercase (as
/// `http::HeaderName` does), and to reject keys ending in `-bin` for ASCII metadata. So
/// mixed-case `values` keys will not round-trip exactly. Confirm whether callers rely on case.
///
/// The `form` in the name reads as a typo for `from`; it is kept so existing callers compile.
#[cfg(feature = "client")]
pub fn build_req_metadata_form_ctx<B>(ctx: &AppContext, request: &mut Request<B>) {
    let metadata = request.metadata_mut();

    insert_ascii_metadata(metadata, REMOTE_ADDR_KEY, &ctx.remote_addr());

    if let Some(user_id) = ctx.user_id() {
        insert_ascii_metadata(metadata, USER_ID_KEY, &user_id);
    }

    if let Some(platform) = ctx.platform() {
        insert_ascii_metadata(metadata, PLATFORM_KEY, &platform.to_string());
    }

    if let Some(client_id) = ctx.client_id() {
        insert_ascii_metadata(metadata, CLIENT_ID_KEY, &client_id);
    }

    if let Some(language) = ctx.language() {
        insert_ascii_metadata(metadata, LANGUAGE_KEY, &language);
    }

    insert_ascii_metadata(metadata, CONN_ID_KEY, &ctx.conn_id());
    insert_ascii_metadata(metadata, CLIENT_MSG_ID_KEY, &ctx.client_msg_id());

    // A poisoned lock still guards a usable map. Recover it instead of silently dropping
    // every custom value. The handle is bound first so the guard never outlives a temporary.
    let shared_values = ctx.values();
    let values = shared_values
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    for (key, value) in values.iter() {
        let header = format!("{}{}", VALUES_PREFIX, key);
        if let (Ok(key), Ok(val)) = (
            MetadataKey::from_bytes(header.as_bytes()),
            MetadataValue::from_str(value),
        ) {
            metadata.insert(key, val);
        }
    }
}

/// Inserts `value` under `key` when it is valid ASCII metadata; otherwise leaves `metadata`
/// unchanged.
#[cfg(feature = "client")]
fn insert_ascii_metadata(metadata: &mut MetadataMap, key: &'static str, value: &str) {
    if let Ok(val) = MetadataValue::from_str(value) {
        metadata.insert(key, val);
    }
}
166
/// Rebuilds an [`AppContext`] from gRPC request metadata.
///
/// This is the inverse of the client-side `build_req_metadata_form_ctx`. Key mapping:
///
/// - `remote-addr` becomes `remote_addr`, defaulting to `"127.0.0.1"` when the key is absent.
/// - `user-id`, `platform` (parsed as `i32`), `client-id`, `language`, `conn-id` and
///   `client-msg-id` become the matching builder fields, only when present.
/// - Every ASCII entry whose key starts with `ctx-val-` goes into the context's `values`
///   map. It is keyed by what remains after exactly one `ctx-val-` prefix is removed.
///
/// # Errors
///
/// - [`Status::invalid_argument`] when a present value cannot be read as an ASCII string
///   (`MetadataValue::to_str` fails), or when `platform` is not an `i32`. The data comes
///   from the caller, so it is reported as a caller error rather than an internal one.
/// - [`Status::internal`] when `AppContextBuilder::build` rejects the collected fields.
#[cfg(feature = "server")]
pub fn build_context_from_metadata(metadata: &MetadataMap) -> Result<AppContext, Status> {
    let mut builder = AppContextBuilder::new();

    // NOTE(review): the peer address comes from a header the caller controls, so it can be
    // spoofed. If it matters for security or auditing, prefer the transport peer address
    // (e.g. `tonic::Request::remote_addr`). Confirm the intended source.
    let remote_addr =
        metadata_ascii_str(metadata, REMOTE_ADDR_KEY, "remote_addr")?.unwrap_or("127.0.0.1");
    builder = builder.remote_addr(remote_addr.to_string());

    if let Some(user_id) = metadata_ascii_str(metadata, USER_ID_KEY, "user_id")? {
        builder = builder.user_id(user_id.to_string());
    }

    if let Some(platform) = metadata_ascii_str(metadata, PLATFORM_KEY, "platform")? {
        let platform = platform
            .parse::<i32>()
            .map_err(|_| Status::invalid_argument("Invalid platform value"))?;
        builder = builder.platform(platform);
    }

    if let Some(client_id) = metadata_ascii_str(metadata, CLIENT_ID_KEY, "client_id")? {
        builder = builder.client_id(client_id.to_string());
    }

    if let Some(language) = metadata_ascii_str(metadata, LANGUAGE_KEY, "language")? {
        builder = builder.with_language(Some(language.to_string()));
    }

    if let Some(conn_id) = metadata_ascii_str(metadata, CONN_ID_KEY, "conn_id")? {
        builder = builder.with_conn_id(conn_id.to_string());
    }

    if let Some(client_msg_id) =
        metadata_ascii_str(metadata, CLIENT_MSG_ID_KEY, "client_msg_id")?
    {
        builder = builder.with_client_msg_id(client_msg_id.to_string());
    }

    // Fill a plain local map and wrap it once at the end. A mutex created here and not yet
    // shared cannot be contended or poisoned, so locking it while filling was pure overhead.
    let mut values = HashMap::new();
    for entry in metadata.iter() {
        if let tonic::metadata::KeyAndValueRef::Ascii(key, value) = entry {
            // `strip_prefix` removes the prefix exactly once. `trim_start_matches` would strip
            // it repeatedly, turning `ctx-val-ctx-val-x` into `x` instead of `ctx-val-x`.
            if let Some(name) = key.as_str().strip_prefix(VALUES_PREFIX) {
                let value = value
                    .to_str()
                    .map_err(|_| Status::invalid_argument("Invalid value format"))?;
                values.insert(name.to_string(), value.to_string());
            }
        }
    }
    builder = builder.values(Arc::new(Mutex::new(values)));

    builder
        .build()
        .map_err(|e| Status::internal(format!("Failed to build AppContext: {}", e)))
}

/// Returns the value stored under `key` as `&str`, or `Ok(None)` when the key is absent.
///
/// Returns an `InvalidArgument` status reading `Invalid <field> format` when the value
/// cannot be read as an ASCII string.
#[cfg(feature = "server")]
fn metadata_ascii_str<'a>(
    metadata: &'a MetadataMap,
    key: &'static str,
    field: &str,
) -> Result<Option<&'a str>, Status> {
    metadata
        .get(key)
        .map(|value| {
            value
                .to_str()
                .map_err(|_| Status::invalid_argument(format!("Invalid {} format", field)))
        })
        .transpose()
}