Skip to main content

iii_sdk/
context.rs

1use crate::logger::Logger;
2
3#[derive(Clone)]
4pub struct Context {
5    pub logger: Logger,
6}
7
8tokio::task_local! {
9    static CONTEXT: Context;
10}
11
12pub async fn with_context<F, Fut, T>(context: Context, f: F) -> T
13where
14    F: FnOnce() -> Fut,
15    Fut: std::future::Future<Output = T>,
16{
17    CONTEXT.scope(context, f()).await
18}
19
20pub fn get_context() -> Context {
21    CONTEXT
22        .try_with(|ctx| ctx.clone())
23        .unwrap_or_else(|_| Context {
24            logger: Logger::default(),
25        })
26}
27
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use serde_json::Value;

    use super::*;

    /// The logger installed with `with_context` must be the one that
    /// `get_context` hands back from inside the scoped future.
    #[tokio::test]
    async fn get_context_returns_scoped_context() {
        // Shared sink that records every (path, params) pair the logger emits.
        let recorded: Arc<Mutex<Vec<(String, Value)>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&recorded);
        let invoker = Arc::new(move |path: &str, params: Value| {
            sink.lock().unwrap().push((path.to_string(), params));
        });

        let scoped = Context {
            logger: Logger::new(
                Some(invoker),
                Some("trace-ctx".to_string()),
                Some("fn-ctx".to_string()),
            ),
        };

        with_context(scoped, || async {
            get_context().logger.info("inside", None);
        })
        .await;

        let recorded = recorded.lock().unwrap();
        assert_eq!(recorded.len(), 1);
        let (path, params) = &recorded[0];
        assert_eq!(path.as_str(), "logger.info");
        assert_eq!(params["trace_id"], "trace-ctx");
        assert_eq!(params["function_name"], "fn-ctx");
        assert_eq!(params["message"], "inside");
    }
}