agent_chain_core/tracers/context.rs

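//! Context management for tracers.
//!
//! This module keeps the active tracing callback and run collector in thread-local
//! storage; the guards returned by the enabling functions restore the previously
//! active value when dropped. It also holds a process-wide registry of configure hooks.
//! Mirrors `langchain_core.tracers.context`.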
5
6use std::cell::RefCell;
7use std::sync::Arc;
8use uuid::Uuid;
9
10use crate::tracers::run_collector::RunCollectorCallbackHandler;
11use crate::tracers::schemas::Run;
12
13// Thread-local storage for the tracing callback handler.
14thread_local! {
15    static TRACING_V2_CALLBACK: RefCell<Option<Arc<dyn TracingCallback>>> = const { RefCell::new(None) };
16    static RUN_COLLECTOR: RefCell<Option<Arc<std::sync::Mutex<RunCollectorCallbackHandler>>>> = const { RefCell::new(None) };
17}
18
19/// Trait for tracing callbacks that can be stored in context.
20pub trait TracingCallback: Send + Sync {
21    /// Get the project name.
22    fn project_name(&self) -> Option<&str>;
23
24    /// Get the example ID.
25    fn example_id(&self) -> Option<Uuid>;
26
27    /// Get the latest run.
28    fn latest_run(&self) -> Option<&Run>;
29
30    /// Get the run URL.
31    fn get_run_url(&self) -> Option<String>;
32}
33
34/// Guard that resets the tracing callback when dropped.
35pub struct TracingV2Guard {
36    previous: Option<Arc<dyn TracingCallback>>,
37}
38
39impl Drop for TracingV2Guard {
40    fn drop(&mut self) {
41        TRACING_V2_CALLBACK.with(|cell| {
42            *cell.borrow_mut() = self.previous.take();
43        });
44    }
45}
46
47/// Guard that resets the run collector when dropped.
48pub struct RunCollectorGuard {
49    previous: Option<Arc<std::sync::Mutex<RunCollectorCallbackHandler>>>,
50}
51
52impl Drop for RunCollectorGuard {
53    fn drop(&mut self) {
54        RUN_COLLECTOR.with(|cell| {
55            *cell.borrow_mut() = self.previous.take();
56        });
57    }
58}
59
60/// Enable tracing v2 in the current context.
61///
62/// # Arguments
63///
64/// * `callback` - The tracing callback to use.
65///
66/// # Returns
67///
68/// A guard that will reset the callback when dropped.
69pub fn tracing_v2_enabled(callback: Arc<dyn TracingCallback>) -> TracingV2Guard {
70    let previous = TRACING_V2_CALLBACK.with(|cell| {
71        let mut borrow = cell.borrow_mut();
72        let prev = borrow.take();
73        *borrow = Some(callback);
74        prev
75    });
76
77    TracingV2Guard { previous }
78}
79
80/// Check if tracing v2 is enabled.
81pub fn tracing_v2_is_enabled() -> bool {
82    TRACING_V2_CALLBACK.with(|cell| cell.borrow().is_some())
83}
84
85/// Get the current tracing callback.
86pub fn get_tracing_callback() -> Option<Arc<dyn TracingCallback>> {
87    TRACING_V2_CALLBACK.with(|cell| cell.borrow().clone())
88}
89
90/// Collect runs in the current context.
91///
92/// # Arguments
93///
94/// * `collector` - The run collector to use.
95///
96/// # Returns
97///
98/// A guard that will reset the collector when dropped.
99pub fn collect_runs(
100    collector: RunCollectorCallbackHandler,
101) -> (
102    RunCollectorGuard,
103    Arc<std::sync::Mutex<RunCollectorCallbackHandler>>,
104) {
105    let collector = Arc::new(std::sync::Mutex::new(collector));
106    let collector_clone = collector.clone();
107
108    let previous = RUN_COLLECTOR.with(|cell| {
109        let mut borrow = cell.borrow_mut();
110        let prev = borrow.take();
111        *borrow = Some(collector);
112        prev
113    });
114
115    (RunCollectorGuard { previous }, collector_clone)
116}
117
118/// Get the current run collector.
119pub fn get_run_collector() -> Option<Arc<std::sync::Mutex<RunCollectorCallbackHandler>>> {
120    RUN_COLLECTOR.with(|cell| cell.borrow().clone())
121}
122
123/// Configuration hook for registering callback handlers.
124#[derive(Debug, Clone)]
125pub struct ConfigureHook {
126    /// Whether the hook is inheritable.
127    pub inheritable: bool,
128    /// The environment variable to check.
129    pub env_var: Option<String>,
130}
131
132impl ConfigureHook {
133    /// Create a new configure hook.
134    pub fn new(inheritable: bool, env_var: Option<String>) -> Self {
135        Self {
136            inheritable,
137            env_var,
138        }
139    }
140}
141
142/// Registry for configure hooks.
143#[derive(Debug, Default)]
144pub struct ConfigureHookRegistry {
145    hooks: Vec<ConfigureHook>,
146}
147
148impl ConfigureHookRegistry {
149    /// Create a new configure hook registry.
150    pub fn new() -> Self {
151        Self::default()
152    }
153
154    /// Register a configure hook.
155    pub fn register(&mut self, hook: ConfigureHook) {
156        self.hooks.push(hook);
157    }
158
159    /// Get all registered hooks.
160    pub fn hooks(&self) -> &[ConfigureHook] {
161        &self.hooks
162    }
163}
164
165/// Global configure hook registry.
166static CONFIGURE_HOOKS: std::sync::LazyLock<std::sync::Mutex<ConfigureHookRegistry>> =
167    std::sync::LazyLock::new(|| std::sync::Mutex::new(ConfigureHookRegistry::new()));
168
169/// Register a configure hook.
170///
171/// # Arguments
172///
173/// * `inheritable` - Whether the hook is inheritable.
174/// * `env_var` - The environment variable to check.
175pub fn register_configure_hook(inheritable: bool, env_var: Option<String>) {
176    if let Ok(mut registry) = CONFIGURE_HOOKS.lock() {
177        registry.register(ConfigureHook::new(inheritable, env_var));
178    }
179}
180
181#[cfg(test)]
182mod tests {
183    use super::*;
184
185    struct TestCallback {
186        project: String,
187    }
188
189    impl TracingCallback for TestCallback {
190        fn project_name(&self) -> Option<&str> {
191            Some(&self.project)
192        }
193
194        fn example_id(&self) -> Option<Uuid> {
195            None
196        }
197
198        fn latest_run(&self) -> Option<&Run> {
199            None
200        }
201
202        fn get_run_url(&self) -> Option<String> {
203            None
204        }
205    }
206
207    #[test]
208    fn test_tracing_v2_enabled() {
209        assert!(!tracing_v2_is_enabled());
210
211        let callback = Arc::new(TestCallback {
212            project: "test".to_string(),
213        });
214
215        {
216            let _guard = tracing_v2_enabled(callback.clone());
217            assert!(tracing_v2_is_enabled());
218
219            let cb = get_tracing_callback().unwrap();
220            assert_eq!(cb.project_name(), Some("test"));
221        }
222
223        assert!(!tracing_v2_is_enabled());
224    }
225
226    #[test]
227    fn test_collect_runs() {
228        let collector = RunCollectorCallbackHandler::new(None);
229
230        {
231            let (_guard, collector_arc) = collect_runs(collector);
232
233            let current = get_run_collector();
234            assert!(current.is_some());
235
236            // Verify it's the same collector
237            let collector_locked = collector_arc.lock().unwrap();
238            assert!(collector_locked.is_empty());
239        }
240
241        assert!(get_run_collector().is_none());
242    }
243
244    #[test]
245    fn test_register_configure_hook() {
246        register_configure_hook(false, None);
247        register_configure_hook(true, Some("LANGCHAIN_TRACING_V2".to_string()));
248
249        let registry = CONFIGURE_HOOKS.lock().unwrap();
250        assert!(registry.hooks().len() >= 2);
251    }
252}