1use axum::extract::FromRequestParts;
7use axum::http::request::Parts;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11use apcore::context::{Context, Identity};
12use apcore::trace_context::{TraceContext, TraceParent};
13
14use crate::errors::AxumApcoreError;
15
/// Caller identity carried on a request.
///
/// [`AxumContextFactory`] looks for a value of this type in the request
/// extensions and converts it into an apcore [`Identity`] via the `From`
/// implementation below. The type is also (de)serializable with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestIdentity {
    /// Identifier of the caller. Required when deserializing.
    pub id: String,
    /// Kind of identity. Falls back to `"user"` when the field is absent
    /// from the deserialized input.
    #[serde(default = "default_identity_type")]
    pub identity_type: String,
    /// Role names attached to the identity. Defaults to an empty list when
    /// absent from the deserialized input.
    #[serde(default)]
    pub roles: Vec<String>,
    /// Free-form JSON attributes keyed by name. Defaults to an empty map
    /// when absent from the deserialized input.
    #[serde(default)]
    pub attrs: HashMap<String, serde_json::Value>,
}
40
/// Serde default for [`RequestIdentity::identity_type`]: the literal `"user"`.
fn default_identity_type() -> String {
    String::from("user")
}
44
45impl From<RequestIdentity> for Identity {
46 fn from(ri: RequestIdentity) -> Self {
47 Identity {
48 id: ri.id,
49 identity_type: ri.identity_type,
50 roles: ri.roles,
51 attrs: ri.attrs,
52 }
53 }
54}
55
/// Axum extractor that wraps an apcore [`Context`] carrying
/// `serde_json::Value` data.
///
/// The wrapped context is exposed as the public tuple field `.0`. Instances
/// are produced by the `FromRequestParts` implementation below, which
/// delegates to [`AxumContextFactory::create_from_parts`].
pub struct ApContext(pub Context<serde_json::Value>);
73
74impl<S> FromRequestParts<S> for ApContext
75where
76 S: Send + Sync,
77{
78 type Rejection = AxumApcoreError;
79
80 fn from_request_parts(
81 parts: &mut Parts,
82 _state: &S,
83 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
84 let factory = AxumContextFactory;
85 let result = factory.create_from_parts(parts);
86 async move { result.map(ApContext) }
87 }
88}
89
/// Stateless factory that builds apcore contexts from axum request parts.
///
/// The type carries no data, so it is trivially `Copy` and can be created
/// either with the unit-struct literal `AxumContextFactory` or with
/// `AxumContextFactory::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AxumContextFactory;
92
93impl AxumContextFactory {
94 pub fn create_from_parts(
96 &self,
97 parts: &Parts,
98 ) -> Result<Context<serde_json::Value>, AxumApcoreError> {
99 let identity = self.extract_identity(parts);
100 let trace_context = self.extract_trace_context(parts);
101
102 let mut ctx = Context::new(identity);
103 ctx.trace_context = trace_context;
104
105 Ok(ctx)
106 }
107
108 fn extract_identity(&self, parts: &Parts) -> Identity {
110 if let Some(ri) = parts.extensions.get::<RequestIdentity>() {
111 ri.clone().into()
112 } else {
113 Identity {
114 id: "anonymous".into(),
115 identity_type: "anonymous".into(),
116 roles: vec![],
117 attrs: HashMap::new(),
118 }
119 }
120 }
121
122 fn extract_trace_context(&self, parts: &Parts) -> Option<TraceContext> {
124 let header = parts.headers.get("traceparent")?.to_str().ok()?;
125 let traceparent = TraceParent::parse(header).ok()?;
126 Some(TraceContext::new(traceparent))
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Request;

    /// Parts of a request made with the default builder and an empty body.
    fn blank_parts() -> Parts {
        let (parts, _body) = Request::builder().body(()).unwrap().into_parts();
        parts
    }

    #[test]
    fn test_request_identity_into_identity() {
        let converted = Identity::from(RequestIdentity {
            id: "user-1".into(),
            identity_type: "user".into(),
            roles: vec!["admin".into()],
            attrs: HashMap::new(),
        });

        assert_eq!(converted.id, "user-1");
        assert_eq!(converted.identity_type, "user");
        assert_eq!(converted.roles, vec!["admin"]);
    }

    #[test]
    fn test_extract_identity_anonymous_fallback() {
        let parts = blank_parts();

        let resolved = AxumContextFactory.extract_identity(&parts);

        assert_eq!(resolved.id, "anonymous");
        assert_eq!(resolved.identity_type, "anonymous");
    }

    #[test]
    fn test_extract_identity_from_extensions() {
        // Attach an identity the way upstream middleware would.
        let mut parts = blank_parts();
        parts.extensions.insert(RequestIdentity {
            id: "user-42".into(),
            identity_type: "service".into(),
            roles: vec!["reader".into()],
            attrs: HashMap::new(),
        });

        let resolved = AxumContextFactory.extract_identity(&parts);

        assert_eq!(resolved.id, "user-42");
        assert_eq!(resolved.identity_type, "service");
    }

    #[test]
    fn test_create_from_parts() {
        let parts = blank_parts();

        let context = AxumContextFactory.create_from_parts(&parts).unwrap();

        assert_eq!(context.identity.id, "anonymous");
        assert!(!context.trace_id.is_empty());
    }

    #[test]
    fn test_extract_trace_context_none() {
        let parts = blank_parts();

        let trace = AxumContextFactory.extract_trace_context(&parts);

        assert!(trace.is_none());
    }
}