1use crate::logger::Logger;
2
3#[derive(Clone)]
4pub struct Context {
5 pub logger: Logger,
6}
7
// Task-local slot holding the `Context` for the currently running scope.
// In this module it is installed by `with_context` (through `CONTEXT.scope`) and read by
// `get_context` (through `try_with`), which falls back to a default when it is not set.
tokio::task_local! {
    static CONTEXT: Context;
}
11
12pub async fn with_context<F, Fut, T>(context: Context, f: F) -> T
13where
14 F: FnOnce() -> Fut,
15 Fut: std::future::Future<Output = T>,
16{
17 CONTEXT.scope(context, f()).await
18}
19
20pub fn get_context() -> Context {
21 CONTEXT
22 .try_with(|ctx| ctx.clone())
23 .unwrap_or_else(|_| Context {
24 logger: Logger::default(),
25 })
26}
27
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use serde_json::Value;

    use super::*;

    /// A log call made inside `with_context` must go through the scoped logger and carry
    /// its trace id and function name.
    #[tokio::test]
    async fn get_context_returns_scoped_context() {
        // Every (path, params) pair the logger passes to its invoker is collected here.
        let recorded: Arc<Mutex<Vec<(String, Value)>>> = Arc::default();
        let sink = Arc::clone(&recorded);
        let record_call = Arc::new(move |path: &str, params: Value| {
            sink.lock().unwrap().push((path.to_owned(), params));
        });

        let scoped_logger = Logger::new(
            Some(record_call),
            Some(String::from("trace-ctx")),
            Some(String::from("fn-ctx")),
        );
        let scoped = Context {
            logger: scoped_logger,
        };

        with_context(scoped, || async {
            let active = get_context();
            active.logger.info("inside", None);
        })
        .await;

        let recorded = recorded.lock().unwrap();
        assert_eq!(recorded.len(), 1);
        let (path, params) = &recorded[0];
        assert_eq!(path, "logger.info");
        assert_eq!(params["trace_id"], "trace-ctx");
        assert_eq!(params["function_name"], "fn-ctx");
        assert_eq!(params["message"], "inside");
    }
}