Skip to main content

apalis_core/task/
metadata.rs

//! Task metadata extension trait and implementations.
//!
//! The [`MetadataExt`] trait allows injecting and extracting metadata associated with tasks.
//! This module also defines the [`Meta`] extractor wrapper and, behind the `tracing` feature,
//! the `TracingContext` metadata type.
//!
//! # Overview
//! - `MetadataExt<T>`: a trait for extracting and injecting metadata of type `T`.
//! - `Meta<T>`: a wrapper that implements `FromRequest`, so metadata of type `T` can be
//!   extracted from a task's context.
//!
//! # Usage
//! Implement `MetadataExt<YourMetadata>` for your task context type to enable easy extraction
//! and injection of `YourMetadata`. This allows middleware and services to access and modify
//! task metadata in a type-safe manner.
13use crate::task::Task;
14use crate::task_fn::FromRequest;
15use std::ops::Deref;
16
/// Extractor wrapper around a metadata value of type `T`.
///
/// The value is produced from a task's context via the `FromRequest` implementation in this
/// module. It can be read either through the public `.0` field or transparently via [`Deref`].
#[derive(Clone, Debug)]
pub struct Meta<T>(pub T);

impl<T> Deref for Meta<T> {
    type Target = T;

    /// Borrows the wrapped metadata value.
    fn deref(&self) -> &T {
        let Meta(inner) = self;
        inner
    }
}
27
/// Extension trait for reading and writing metadata of type `T` on a task context.
///
/// Implement `MetadataExt<T>` for a context type so that middleware and services can
/// extract and inject `T` in a type-safe way.
pub trait MetadataExt<T> {
    /// Error produced when extraction or injection fails.
    type Error;

    /// Extracts a metadata value of type `T` from `self`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the value cannot be produced; the exact conditions are
    /// defined by each implementation.
    fn extract(&self) -> Result<T, Self::Error>;

    /// Stores `value` as metadata of type `T` in `self`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the value cannot be stored; the exact conditions are
    /// defined by each implementation.
    fn inject(&mut self, value: T) -> Result<(), Self::Error>;
}
38
39impl<T, Args: Send + Sync, Ctx: MetadataExt<T> + Send + Sync, IdType: Send + Sync>
40    FromRequest<Task<Args, Ctx, IdType>> for Meta<T>
41{
42    type Error = Ctx::Error;
43
44    async fn from_request(task: &Task<Args, Ctx, IdType>) -> Result<Self, Self::Error> {
45        task.parts.ctx.extract().map(Meta)
46    }
47}
48
/// Metadata that carries a tracing context (trace/span identifiers, flags and state)
/// alongside a task.
#[cfg(feature = "tracing")]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Debug, Default)]
pub struct TracingContext {
    /// Trace identifier, if one has been set.
    trace_id: Option<String>,
    /// Span identifier, if one has been set.
    span_id: Option<String>,
    /// Trace flags, if they have been set.
    trace_flags: Option<u8>,
    /// Trace state, if it has been set.
    trace_state: Option<String>,
}
59
#[cfg(feature = "tracing")]
impl TracingContext {
    /// Creates a `TracingContext` with every field unset.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns this context with its trace ID set to `trace_id`.
    #[must_use]
    pub fn with_trace_id(self, trace_id: impl Into<String>) -> Self {
        Self {
            trace_id: Some(trace_id.into()),
            ..self
        }
    }

    /// Returns this context with its span ID set to `span_id`.
    #[must_use]
    pub fn with_span_id(self, span_id: impl Into<String>) -> Self {
        Self {
            span_id: Some(span_id.into()),
            ..self
        }
    }

    /// Returns this context with its trace flags set to `trace_flags`.
    #[must_use]
    pub fn with_trace_flags(self, trace_flags: u8) -> Self {
        Self {
            trace_flags: Some(trace_flags),
            ..self
        }
    }

    /// Returns this context with its trace state set to `trace_state`.
    #[must_use]
    pub fn with_trace_state(self, trace_state: impl Into<String>) -> Self {
        Self {
            trace_state: Some(trace_state.into()),
            ..self
        }
    }

    /// Borrows the trace ID; `None` when it was never set.
    #[must_use]
    pub fn trace_id(&self) -> &Option<String> {
        let Self { trace_id, .. } = self;
        trace_id
    }

    /// Borrows the span ID; `None` when it was never set.
    #[must_use]
    pub fn span_id(&self) -> &Option<String> {
        let Self { span_id, .. } = self;
        span_id
    }

    /// Borrows the trace flags; `None` when they were never set.
    #[must_use]
    pub fn trace_flags(&self) -> &Option<u8> {
        let Self { trace_flags, .. } = self;
        trace_flags
    }

    /// Borrows the trace state; `None` when it was never set.
    #[must_use]
    pub fn trace_state(&self) -> &Option<String> {
        let Self { trace_state, .. } = self;
        trace_state
    }
}
120
#[cfg(test)]
// `ExampleService` and several imports only exist to prove that the `Meta` extractor
// type-checks inside a tower middleware; they are not all exercised at runtime.
#[allow(unused)]
mod tests {
    use std::{convert::Infallible, fmt::Debug, task::Poll, time::Duration};

    use crate::{
        error::BoxDynError,
        task::{
            Task,
            metadata::{Meta, MetadataExt},
        },
        task_fn::FromRequest,
    };
    use futures_core::future::BoxFuture;
    use tower::Service;

    /// Example middleware that extracts `ExampleConfig` metadata before delegating.
    #[derive(Debug, Clone)]
    struct ExampleService<S> {
        service: S,
    }

    /// Example metadata type stored in a task's context.
    #[derive(Debug, Clone, Default)]
    struct ExampleConfig {
        timeout: Duration,
    }

    /// Read-only context that always yields a one-second `ExampleConfig`.
    struct SampleStore;

    impl MetadataExt<ExampleConfig> for SampleStore {
        type Error = Infallible;
        fn extract(&self) -> Result<ExampleConfig, Self::Error> {
            Ok(ExampleConfig {
                timeout: Duration::from_secs(1),
            })
        }
        fn inject(&mut self, _: ExampleConfig) -> Result<(), Self::Error> {
            // This store is read-only; nothing in these tests injects into it.
            unreachable!()
        }
    }

    impl<S, Args: Send + Sync + 'static, Ctx: Send + Sync + 'static, IdType: Send + Sync + 'static>
        Service<Task<Args, Ctx, IdType>> for ExampleService<S>
    where
        S: Service<Task<Args, Ctx, IdType>> + Clone + Send + 'static,
        Ctx: MetadataExt<ExampleConfig> + Send,
        Ctx::Error: Debug,
        S::Future: Send + 'static,
    {
        type Response = S::Response;
        type Error = S::Error;
        type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.service.poll_ready(cx)
        }

        fn call(&mut self, request: Task<Args, Ctx, IdType>) -> Self::Future {
            // `poll_ready` drove `self.service` to readiness, not a fresh clone. Move the
            // ready instance into the future and leave the clone behind for the next call.
            let clone = self.service.clone();
            let mut svc = std::mem::replace(&mut self.service, clone);

            Box::pin(async move {
                let _config: Meta<ExampleConfig> = request.extract().await.unwrap();
                svc.call(request).await
            })
        }
    }

    #[test]
    fn sample_store_extracts_config() {
        let config = SampleStore.extract().unwrap();
        assert_eq!(config.timeout, Duration::from_secs(1));
    }

    #[test]
    fn meta_derefs_to_inner_value() {
        let meta = Meta(ExampleConfig {
            timeout: Duration::from_secs(2),
        });
        assert_eq!(meta.timeout, Duration::from_secs(2));
        assert_eq!(meta.0.timeout, Duration::from_secs(2));
    }
}
186}