Skip to main content

axum_apcore/
context.rs

// Axum context factory — extract an apcore Context from Axum requests.
//
// Provides an Axum extractor (`ApContext`) and a factory (`AxumContextFactory`)
// that map Axum request parts (extensions and headers) to an apcore Identity and Context.
5
6use axum::extract::FromRequestParts;
7use axum::http::request::Parts;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11use apcore::context::{Context, Identity};
12use apcore::trace_context::{TraceContext, TraceParent};
13
14use crate::errors::AxumApcoreError;
15
16/// Identity information stored in Axum request extensions.
17///
18/// Middleware (e.g., JWT auth) should insert this into request extensions
19/// before handlers run:
20///
21/// ```ignore
22/// use axum_apcore::RequestIdentity;
23/// req.extensions_mut().insert(RequestIdentity {
24///     id: "user-123".into(),
25///     identity_type: "user".into(),
26///     roles: vec!["admin".into()],
27///     attrs: Default::default(),
28/// });
29/// ```
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct RequestIdentity {
32    pub id: String,
33    #[serde(default = "default_identity_type")]
34    pub identity_type: String,
35    #[serde(default)]
36    pub roles: Vec<String>,
37    #[serde(default)]
38    pub attrs: HashMap<String, serde_json::Value>,
39}
40
/// Serde default for `RequestIdentity::identity_type`: the literal `"user"`.
fn default_identity_type() -> String {
    String::from("user")
}
44
45impl From<RequestIdentity> for Identity {
46    fn from(ri: RequestIdentity) -> Self {
47        Identity {
48            id: ri.id,
49            identity_type: ri.identity_type,
50            roles: ri.roles,
51            attrs: ri.attrs,
52        }
53    }
54}
55
56/// Axum extractor that produces an apcore `Context<serde_json::Value>`.
57///
58/// Extracts identity from request extensions (`RequestIdentity`) and
59/// W3C TraceContext from the `traceparent` header.
60///
61/// # Usage
62///
63/// ```ignore
64/// async fn handler(
65///     ApContext(ctx): ApContext,
66///     Json(input): Json<Value>,
67/// ) -> Result<Json<Value>, AxumApcoreError> {
68///     // ctx is a fully populated apcore Context
69///     Ok(Json(serde_json::json!({"trace_id": ctx.trace_id})))
70/// }
71/// ```
72pub struct ApContext(pub Context<serde_json::Value>);
73
74impl<S> FromRequestParts<S> for ApContext
75where
76    S: Send + Sync,
77{
78    type Rejection = AxumApcoreError;
79
80    fn from_request_parts(
81        parts: &mut Parts,
82        _state: &S,
83    ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
84        let factory = AxumContextFactory;
85        let result = factory.create_from_parts(parts);
86        async move { result.map(ApContext) }
87    }
88}
89
/// Factory for creating apcore contexts from Axum request parts.
///
/// It has no fields, so it is free to construct, copy, and share. `ApContext`
/// uses it internally; it can also be called directly from custom extractors
/// or middleware that already hold the request `Parts`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AxumContextFactory;
92
93impl AxumContextFactory {
94    /// Create an apcore Context from Axum request parts.
95    pub fn create_from_parts(
96        &self,
97        parts: &Parts,
98    ) -> Result<Context<serde_json::Value>, AxumApcoreError> {
99        let identity = self.extract_identity(parts);
100        let trace_context = self.extract_trace_context(parts);
101
102        let mut ctx = Context::new(identity);
103        ctx.trace_context = trace_context;
104
105        Ok(ctx)
106    }
107
108    /// Extract identity from request extensions, with fallback to anonymous.
109    fn extract_identity(&self, parts: &Parts) -> Identity {
110        if let Some(ri) = parts.extensions.get::<RequestIdentity>() {
111            ri.clone().into()
112        } else {
113            Identity {
114                id: "anonymous".into(),
115                identity_type: "anonymous".into(),
116                roles: vec![],
117                attrs: HashMap::new(),
118            }
119        }
120    }
121
122    /// Extract W3C TraceContext from the `traceparent` header.
123    fn extract_trace_context(&self, parts: &Parts) -> Option<TraceContext> {
124        let header = parts.headers.get("traceparent")?.to_str().ok()?;
125        let traceparent = TraceParent::parse(header).ok()?;
126        Some(TraceContext::new(traceparent))
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Request;

    /// Example `traceparent` value from the W3C Trace Context specification.
    const SAMPLE_TRACEPARENT: &str = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";

    #[test]
    fn test_request_identity_into_identity() {
        let ri = RequestIdentity {
            id: "user-1".into(),
            identity_type: "user".into(),
            roles: vec!["admin".into()],
            attrs: HashMap::new(),
        };
        let identity: Identity = ri.into();
        assert_eq!(identity.id, "user-1");
        assert_eq!(identity.identity_type, "user");
        assert_eq!(identity.roles, vec!["admin"]);
    }

    #[test]
    fn test_request_identity_deserialize_defaults() {
        // Only `id` is required; the other fields fall back to serde defaults.
        let ri: RequestIdentity = serde_json::from_str(r#"{"id":"user-7"}"#).unwrap();
        assert_eq!(ri.id, "user-7");
        assert_eq!(ri.identity_type, "user");
        assert!(ri.roles.is_empty());
        assert!(ri.attrs.is_empty());
    }

    #[test]
    fn test_extract_identity_anonymous_fallback() {
        let req = Request::builder().body(()).unwrap();
        let (parts, _) = req.into_parts();
        let factory = AxumContextFactory;
        let identity = factory.extract_identity(&parts);
        assert_eq!(identity.id, "anonymous");
        assert_eq!(identity.identity_type, "anonymous");
    }

    #[test]
    fn test_extract_identity_from_extensions() {
        let mut req = Request::builder().body(()).unwrap();
        req.extensions_mut().insert(RequestIdentity {
            id: "user-42".into(),
            identity_type: "service".into(),
            roles: vec!["reader".into()],
            attrs: HashMap::new(),
        });
        let (parts, _) = req.into_parts();
        let factory = AxumContextFactory;
        let identity = factory.extract_identity(&parts);
        assert_eq!(identity.id, "user-42");
        assert_eq!(identity.identity_type, "service");
    }

    #[test]
    fn test_create_from_parts() {
        let req = Request::builder().body(()).unwrap();
        let (parts, _) = req.into_parts();
        let factory = AxumContextFactory;
        let ctx = factory.create_from_parts(&parts).unwrap();
        assert_eq!(ctx.identity.id, "anonymous");
        assert!(!ctx.trace_id.is_empty());
    }

    #[test]
    fn test_create_from_parts_with_identity() {
        let mut req = Request::builder().body(()).unwrap();
        req.extensions_mut().insert(RequestIdentity {
            id: "user-9".into(),
            identity_type: "user".into(),
            roles: vec![],
            attrs: HashMap::new(),
        });
        let (parts, _) = req.into_parts();
        let factory = AxumContextFactory;
        let ctx = factory.create_from_parts(&parts).unwrap();
        assert_eq!(ctx.identity.id, "user-9");
    }

    #[test]
    fn test_extract_trace_context_none() {
        let req = Request::builder().body(()).unwrap();
        let (parts, _) = req.into_parts();
        let factory = AxumContextFactory;
        assert!(factory.extract_trace_context(&parts).is_none());
    }

    #[test]
    fn test_extract_trace_context_valid_header() {
        let req = Request::builder()
            .header("traceparent", SAMPLE_TRACEPARENT)
            .body(())
            .unwrap();
        let (parts, _) = req.into_parts();
        let factory = AxumContextFactory;
        assert!(factory.extract_trace_context(&parts).is_some());
    }

    #[test]
    fn test_extract_trace_context_malformed_header() {
        let req = Request::builder()
            .header("traceparent", "not-a-traceparent")
            .body(())
            .unwrap();
        let (parts, _) = req.into_parts();
        let factory = AxumContextFactory;
        assert!(factory.extract_trace_context(&parts).is_none());
    }
}