flare_rpc_core/interceptor/
ctxinterceprot.rs1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use tonic::{Request, Status};
4use flare_core::context::{AppContext, AppContextBuilder};
5use tonic::metadata::{MetadataValue, MetadataMap, MetadataKey};
6use std::str::FromStr;
7
// gRPC metadata keys used to carry `AppContext` fields between client and server.

/// Metadata key for the context's remote address.
const REMOTE_ADDR_KEY: &str = "remote-addr";
/// Metadata key for the context's user id.
const USER_ID_KEY: &str = "user-id";
/// Metadata key for the context's platform (written via `to_string`, parsed back as `i32`).
const PLATFORM_KEY: &str = "platform";
/// Metadata key for the context's client id.
const CLIENT_ID_KEY: &str = "client-id";
/// Metadata key for the context's language.
const LANGUAGE_KEY: &str = "language";
/// Metadata key for the context's connection id.
const CONN_ID_KEY: &str = "conn-id";
/// Metadata key for the context's client message id.
const CLIENT_MSG_ID_KEY: &str = "client-msg-id";
/// Prefix put in front of every key of the context's free-form values map
/// when it is written to metadata; the server removes it again when decoding.
const VALUES_PREFIX: &str = "ctx-val-";
16
17#[cfg(feature = "client")]
18use {
19 tower::{Service, Layer},
20 std::future::Future,
21 std::pin::Pin,
22};
23
/// Shared holder for the `AppContext` that the client interceptor copies into
/// outgoing request metadata.
///
/// The slot lives behind an `Arc<Mutex<_>>`, so every clone of this config
/// reads and writes the same stored context. It starts out empty (`None`).
#[cfg(feature = "client")]
#[derive(Clone)]
pub struct AppContextConfig {
    /// Context to attach to requests, or `None` when nothing has been set.
    context: Arc<Mutex<Option<AppContext>>>,
}
29
#[cfg(feature = "client")]
impl AppContextConfig {
    /// Creates a configuration with no context stored yet.
    pub fn new() -> Self {
        Self {
            context: Arc::new(Mutex::new(None)),
        }
    }

    /// Stores `context`, replacing whatever was stored before.
    ///
    /// A poisoned mutex is recovered rather than ignored: the slot only ever
    /// holds a whole `Option<AppContext>`, so there is no partially-updated
    /// state to protect against, and silently dropping the new context would
    /// leave callers sending stale (or no) context without any indication.
    pub fn set_context(&self, context: AppContext) {
        let mut slot = self
            .context
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *slot = Some(context);
    }
}
44
/// The default configuration is the same as [`AppContextConfig::new`]: no
/// context stored.
#[cfg(feature = "client")]
impl Default for AppContextConfig {
    fn default() -> Self {
        AppContextConfig::new()
    }
}
51
/// Tower layer that wraps a service in an [`AppContextInterceptor`].
///
/// Every service produced by this layer shares the same configuration.
#[cfg(feature = "client")]
#[derive(Clone)]
pub struct AppContextLayer {
    /// Configuration handed to each interceptor created by `layer`.
    config: Arc<AppContextConfig>,
}
57
#[cfg(feature = "client")]
impl AppContextLayer {
    /// Builds a layer around `config`.
    ///
    /// The config is moved behind an `Arc` so that all interceptors created
    /// by this layer refer to the same instance.
    pub fn new(config: AppContextConfig) -> Self {
        let config = Arc::new(config);
        Self { config }
    }
}
66
#[cfg(feature = "client")]
impl<S> Layer<S> for AppContextLayer {
    type Service = AppContextInterceptor<S>;

    /// Wraps `inner` in an interceptor that shares this layer's configuration.
    fn layer(&self, inner: S) -> Self::Service {
        let config = Arc::clone(&self.config);
        AppContextInterceptor { inner, config }
    }
}
78
/// Service wrapper that writes the configured `AppContext` into the metadata
/// of each request before passing it to the wrapped service.
#[cfg(feature = "client")]
#[derive(Clone)]
pub struct AppContextInterceptor<S> {
    /// The wrapped service that actually performs the call.
    inner: S,
    /// Shared configuration holding the context to attach, if any.
    config: Arc<AppContextConfig>,
}
85
/// Client-side interceptor: before forwarding a request, copies the currently
/// configured `AppContext` (if any) into the request's gRPC metadata.
#[cfg(feature = "client")]
impl<S, B> Service<Request<B>> for AppContextInterceptor<S>
where
    S: Service<Request<B>, Response = tonic::Response<B>, Error = Status> + Clone + Send + 'static,
    S::Future: Send + 'static,
    B: Send + 'static,
{
    type Response = tonic::Response<B>;
    type Error = Status;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Attaches the configured context to `request` and forwards it to the
    /// wrapped service.
    ///
    /// When no context has been set, the request is forwarded unchanged.
    fn call(&mut self, mut request: Request<B>) -> Self::Future {
        {
            // Recover from a poisoned lock instead of silently sending the
            // request without its context; the slot only holds a whole
            // `Option<AppContext>`, so the stored value is always coherent.
            // The guard is scoped so it is released before the future is built.
            let guard = self
                .config
                .context
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            if let Some(ctx) = guard.as_ref() {
                build_req_metadata_form_ctx(ctx, &mut request);
            }
        }

        // `poll_ready` was driven on `self.inner`, so that exact instance must
        // handle the call; a fresh clone is not guaranteed to be ready. Leave
        // the clone behind for the next `poll_ready`/`call` cycle and move the
        // ready service into the future.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);
        Box::pin(async move {
            inner.call(request).await
        })
    }
}
114
/// Copies the fields of `ctx` into the gRPC metadata of `request`.
///
/// Written keys: remote address, user id, platform, client id, language,
/// connection id and client message id, followed by one entry per item of the
/// context's values map under `VALUES_PREFIX + key`.
///
/// This is best-effort: optional fields that are `None` are skipped, and any
/// value (or values-map key) that cannot be encoded as ASCII metadata is
/// silently left out. If the values map cannot be locked, no values entries
/// are written.
#[cfg(feature = "client")]
pub fn build_req_metadata_form_ctx<B>(ctx: &AppContext, request: &mut Request<B>) {
    /// Inserts `raw` under `name` when it is a valid ASCII metadata value;
    /// otherwise leaves the map untouched.
    fn put(map: &mut MetadataMap, name: &'static str, raw: &str) {
        if let Ok(encoded) = MetadataValue::from_str(raw) {
            map.insert(name, encoded);
        }
    }

    let metadata = request.metadata_mut();

    put(metadata, REMOTE_ADDR_KEY, &ctx.remote_addr());

    if let Some(user_id) = ctx.user_id() {
        put(metadata, USER_ID_KEY, &user_id);
    }

    if let Some(platform) = ctx.platform() {
        put(metadata, PLATFORM_KEY, &platform.to_string());
    }

    if let Some(client_id) = ctx.client_id() {
        put(metadata, CLIENT_ID_KEY, &client_id);
    }

    if let Some(language) = ctx.language() {
        put(metadata, LANGUAGE_KEY, &language);
    }

    put(metadata, CONN_ID_KEY, &ctx.conn_id());
    put(metadata, CLIENT_MSG_ID_KEY, &ctx.client_msg_id());

    // Free-form values travel under a common prefix so the receiving side can
    // tell them apart from the fixed keys above.
    if let Ok(extras) = ctx.values().lock() {
        for (name, raw) in extras.iter() {
            let prefixed = format!("{}{}", VALUES_PREFIX, name);
            if let Ok(header) = MetadataKey::from_bytes(prefixed.as_bytes()) {
                if let Ok(encoded) = MetadataValue::try_from(raw.as_str()) {
                    metadata.insert(header, encoded);
                }
            }
        }
    }
}
166
/// Rebuilds an `AppContext` from incoming gRPC metadata.
///
/// Fixed keys (remote address, user id, platform, client id, language,
/// connection id, client message id) are read when present; a missing remote
/// address defaults to `"127.0.0.1"`. Every ASCII metadata entry whose key
/// starts with `VALUES_PREFIX` is copied into the context's values map with
/// that prefix removed exactly once.
///
/// # Errors
///
/// Returns `Status::internal` when a present value is not visible ASCII, when
/// the platform value does not parse as `i32`, or when the builder fails.
///
/// NOTE(review): malformed client-supplied metadata is reported as `internal`;
/// `invalid_argument` may be a better fit, but the code is kept as-is so that
/// existing callers observe the same status.
#[cfg(feature = "server")]
pub fn build_context_from_metadata(metadata: &MetadataMap) -> Result<AppContext, Status> {
    // Reads an optional ASCII entry as an owned string, mapping an undecodable
    // value to the given internal-error message.
    let text = |key: &'static str, err: &'static str| -> Result<Option<String>, Status> {
        metadata
            .get(key)
            .map(|value| {
                value
                    .to_str()
                    .map(str::to_owned)
                    .map_err(|_| Status::internal(err))
            })
            .transpose()
    };

    let mut builder = AppContextBuilder::new();

    let remote_addr = text(REMOTE_ADDR_KEY, "Invalid remote_addr format")?
        .unwrap_or_else(|| "127.0.0.1".to_string());
    builder = builder.remote_addr(remote_addr);

    if let Some(user_id) = text(USER_ID_KEY, "Invalid user_id format")? {
        builder = builder.user_id(user_id);
    }

    if let Some(platform_str) = text(PLATFORM_KEY, "Invalid platform format")? {
        let platform_val = platform_str
            .parse::<i32>()
            .map_err(|_| Status::internal("Invalid platform value"))?;
        builder = builder.platform(platform_val);
    }

    if let Some(client_id) = text(CLIENT_ID_KEY, "Invalid client_id format")? {
        builder = builder.client_id(client_id);
    }

    if let Some(language) = text(LANGUAGE_KEY, "Invalid language format")? {
        builder = builder.with_language(Some(language));
    }

    if let Some(conn_id) = text(CONN_ID_KEY, "Invalid conn_id format")? {
        builder = builder.with_conn_id(conn_id);
    }

    if let Some(client_msg_id) = text(CLIENT_MSG_ID_KEY, "Invalid client_msg_id format")? {
        builder = builder.with_client_msg_id(client_msg_id);
    }

    // Collect the prefixed free-form values into a plain map first; it is
    // wrapped in the shared `Arc<Mutex<_>>` only once fully populated, so no
    // locking (and no lock-failure path) is needed here.
    let mut values_map = HashMap::new();
    for item in metadata.iter() {
        if let tonic::metadata::KeyAndValueRef::Ascii(k, v) = item {
            // `strip_prefix` removes the prefix once. `trim_start_matches`
            // would also eat a user key that itself begins with the prefix
            // (e.g. "ctx-val-ctx-val-x" must decode to "ctx-val-x", not "x").
            if let Some(actual_key) = k.as_str().strip_prefix(VALUES_PREFIX) {
                let value = v
                    .to_str()
                    .map_err(|_| Status::internal("Invalid value format"))?;
                values_map.insert(actual_key.to_string(), value.to_string());
            }
        }
    }
    builder = builder.values(Arc::new(Mutex::new(values_map)));

    builder
        .build()
        .map_err(|e| Status::internal(format!("Failed to build AppContext: {}", e)))
}