Skip to main content

logforth_diagnostic_task_local/
lib.rs

1// Copyright 2024 FastLabs Developers
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
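//! A diagnostic that attaches key-value pairs to log records emitted while a
//! future is being polled.
//!
//! Wrap a future with [`FutureExt::with_task_local_context`]; while the wrapped
//! future is being polled, its pairs are visible to [`TaskLocalDiagnostic`].
//!
//! # Examples
//!
//! ```
//! use logforth_diagnostic_task_local::FutureExt;
//!
//! let fut = async { log::info!("Hello, world!") };
//! // The returned future is lazy: it must still be awaited or spawned to run.
//! let _fut = fut.with_task_local_context([("key".into(), "value".into())]);
//! ```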
27
28#![cfg_attr(docsrs, feature(doc_cfg))]
29#![deny(missing_docs)]
30
31use std::cell::RefCell;
32use std::pin::Pin;
33use std::task::Context;
34use std::task::Poll;
35
36use logforth_core::Diagnostic;
37use logforth_core::Error;
38use logforth_core::kv::Key;
39use logforth_core::kv::Value;
40use logforth_core::kv::Visitor;
41
thread_local! {
    /// Key-value pairs in scope on the current thread, in insertion order.
    ///
    /// Task-local futures append their pairs here before polling the future they
    /// wrap and remove them afterwards; the diagnostic reads whatever is present.
    static TASK_LOCAL_MAP: RefCell<Vec<(String, String)>> = const {
        RefCell::new(Vec::new())
    };
}
45
/// Diagnostic that reports the key-value pairs of the task-local context
/// currently installed on this thread.
///
/// The type carries no state; all data lives in a thread-local map.
///
/// See [module-level documentation](self) for usage examples.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default)]
pub struct TaskLocalDiagnostic {}
52
53impl Diagnostic for TaskLocalDiagnostic {
54    fn visit(&self, visitor: &mut dyn Visitor) -> Result<(), Error> {
55        TASK_LOCAL_MAP.with(|map| {
56            let map = map.borrow();
57            for (key, value) in map.iter() {
58                let key = Key::borrowed(key.as_str());
59                let value = Value::str(value.as_str());
60                visitor.visit(key.view(), value.view())?;
61            }
62            Ok(())
63        })
64    }
65}
66
67/// An extension trait for futures to run them with a task-local context.
68///
69/// See [module-level documentation](self) for usage examples.
70pub trait FutureExt: Future {
71    /// Run a future with a task-local context.
72    fn with_task_local_context(
73        self,
74        kvs: impl IntoIterator<Item = (String, String)>,
75    ) -> impl Future<Output = Self::Output>
76    where
77        Self: Sized,
78    {
79        TaskLocalFuture {
80            future: Some(self),
81            context: kvs.into_iter().collect(),
82        }
83    }
84}
85
86impl<F: Future> FutureExt for F {}
87
88#[pin_project::pin_project]
89struct TaskLocalFuture<F> {
90    #[pin]
91    future: Option<F>,
92    context: Vec<(String, String)>,
93}
94
95impl<F: Future> Future for TaskLocalFuture<F> {
96    type Output = F::Output;
97
98    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
99        let this = self.project();
100
101        let mut fut = this.future;
102        if let Some(future) = fut.as_mut().as_pin_mut() {
103            struct Guard {
104                n: usize,
105            }
106
107            impl Drop for Guard {
108                fn drop(&mut self) {
109                    TASK_LOCAL_MAP.with(|map| {
110                        let mut map = map.borrow_mut();
111                        for _ in 0..self.n {
112                            map.pop();
113                        }
114                    });
115                }
116            }
117
118            TASK_LOCAL_MAP.with(|map| {
119                let mut map = map.borrow_mut();
120                for (key, value) in this.context.iter() {
121                    map.push((key.clone(), value.clone()));
122                }
123            });
124
125            let n = this.context.len();
126            let guard = Guard { n };
127
128            let result = match future.poll(cx) {
129                Poll::Ready(output) => {
130                    fut.set(None);
131                    Poll::Ready(output)
132                }
133                Poll::Pending => Poll::Pending,
134            };
135
136            drop(guard);
137            return result;
138        }
139
140        unreachable!("TaskLocalFuture polled after completion");
141    }
142}