logforth_diagnostic_task_local/
lib.rs

1// Copyright 2024 FastLabs Developers
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
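//! A diagnostic that stores key-value pairs in a task-local map.
//!
//! Attach pairs to a future with [`FutureExt::with_task_local_context`]. While that future is
//! being polled, [`TaskLocalDiagnostic`] yields those pairs to any visitor it is given.
//!
//! # Examples
//!
//! ```
//! use logforth_core::Diagnostic;
//! use logforth_core::kv::Visitor;
//! use logforth_diagnostic_task_local::FutureExt;
//!
//! let fut = async { log::info!("Hello, world!") };
//! // Like any future, the wrapped one does nothing until it is awaited or polled.
//! fut.with_task_local_context([("key".into(), "value".into())]);
//! ```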
27
28use std::cell::RefCell;
29use std::pin::Pin;
30use std::task::Context;
31use std::task::Poll;
32
33use logforth_core::Diagnostic;
34use logforth_core::Error;
35use logforth_core::kv::Key;
36use logforth_core::kv::Value;
37use logforth_core::kv::Visitor;
38
thread_local! {
    /// Per-thread stack of key-value pairs.
    ///
    /// Task-local futures append their pairs while they are being polled and remove them
    /// before `poll` returns. `TaskLocalDiagnostic` reads the whole stack in insertion order.
    static TASK_LOCAL_MAP: RefCell<Vec<(String, String)>> =
        const { RefCell::new(Vec::new()) };
}
42
/// A [`Diagnostic`] that reports the key-value pairs attached to futures via
/// [`FutureExt::with_task_local_context`].
///
/// The type itself is stateless. All pairs live in a thread-local stack that the wrapping
/// futures fill while they are being polled.
///
/// See [module-level documentation](self) for usage examples.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default)]
pub struct TaskLocalDiagnostic {}
49
50impl Diagnostic for TaskLocalDiagnostic {
51    fn visit(&self, visitor: &mut dyn Visitor) -> Result<(), Error> {
52        TASK_LOCAL_MAP.with(|map| {
53            let map = map.borrow();
54            for (key, value) in map.iter() {
55                let key = Key::new_ref(key.as_str());
56                let value = Value::from(value.as_str());
57                visitor.visit(key, value)?;
58            }
59            Ok(())
60        })
61    }
62}
63
64/// An extension trait for futures to run them with a task-local context.
65///
66/// See [module-level documentation](self) for usage examples.
67pub trait FutureExt: Future {
68    /// Run a future with a task-local context.
69    fn with_task_local_context(
70        self,
71        kvs: impl IntoIterator<Item = (String, String)>,
72    ) -> impl Future<Output = Self::Output>
73    where
74        Self: Sized,
75    {
76        TaskLocalFuture {
77            future: Some(self),
78            context: kvs.into_iter().collect(),
79        }
80    }
81}
82
83impl<F: Future> FutureExt for F {}
84
85#[pin_project::pin_project]
86struct TaskLocalFuture<F> {
87    #[pin]
88    future: Option<F>,
89    context: Vec<(String, String)>,
90}
91
92impl<F: Future> Future for TaskLocalFuture<F> {
93    type Output = F::Output;
94
95    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
96        let this = self.project();
97
98        let mut fut = this.future;
99        if let Some(future) = fut.as_mut().as_pin_mut() {
100            struct Guard {
101                n: usize,
102            }
103
104            impl Drop for Guard {
105                fn drop(&mut self) {
106                    TASK_LOCAL_MAP.with(|map| {
107                        let mut map = map.borrow_mut();
108                        for _ in 0..self.n {
109                            map.pop();
110                        }
111                    });
112                }
113            }
114
115            TASK_LOCAL_MAP.with(|map| {
116                let mut map = map.borrow_mut();
117                for (key, value) in this.context.iter() {
118                    map.push((key.clone(), value.clone()));
119                }
120            });
121
122            let n = this.context.len();
123            let guard = Guard { n };
124
125            let result = match future.poll(cx) {
126                Poll::Ready(output) => {
127                    fut.set(None);
128                    Poll::Ready(output)
129                }
130                Poll::Pending => Poll::Pending,
131            };
132
133            drop(guard);
134            return result;
135        }
136
137        unreachable!("TaskLocalFuture polled after completion");
138    }
139}