pyo3_tracing_subscriber/
contextmanager.rs

1// Copyright 2023 Rigetti Computing
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use pyo3::prelude::*;
15
16use crate::{common::ToPythonError, py_wrap_error, wrap_error};
17
18use super::export_process::{
19    ExportProcess, ExportProcessConfig, RustTracingShutdownError, RustTracingStartError,
20};
21
22#[pyclass]
23#[derive(Clone, Debug, Default)]
24pub(crate) struct GlobalTracingConfig {
25    pub(crate) export_process: ExportProcessConfig,
26}
27
28#[pymethods]
29impl GlobalTracingConfig {
30    #[new]
31    #[pyo3(signature = (/, export_process = None))]
32    #[allow(clippy::pedantic)]
33    fn new(export_process: Option<ExportProcessConfig>) -> PyResult<Self> {
34        let export_process = export_process.unwrap_or_default();
35        Ok(Self { export_process })
36    }
37}
38
39#[pyclass]
40#[derive(Clone, Debug)]
41pub(crate) struct CurrentThreadTracingConfig {
42    pub(crate) export_process: ExportProcessConfig,
43}
44
45#[pymethods]
46impl CurrentThreadTracingConfig {
47    #[new]
48    #[pyo3(signature = (/, export_process = None))]
49    #[allow(clippy::pedantic)]
50    fn new(export_process: Option<ExportProcessConfig>) -> PyResult<Self> {
51        let export_process = export_process.unwrap_or_default();
52        Ok(Self { export_process })
53    }
54}
55
56#[derive(FromPyObject, Debug)]
57pub(crate) enum TracingConfig {
58    Global(GlobalTracingConfig),
59    CurrentThread(CurrentThreadTracingConfig),
60}
61
62impl Default for TracingConfig {
63    fn default() -> Self {
64        Self::Global(GlobalTracingConfig::default())
65    }
66}
67
/// Represents the current state of the context manager. This state is used to ensure the
/// context manager methods are invoked in the correct order and multiplicity.
///
/// `Tracing::__enter__` moves `Initialized` → `Starting` → `Entered`; `Tracing::__exit__`
/// moves the state to `Exited`.
#[derive(Debug)]
enum ContextManagerState {
    /// Constructed but not yet entered; holds the configuration used to start exporting.
    Initialized(TracingConfig),
    /// Successfully entered; holds the running export process that `__exit__` shuts down.
    Entered(ExportProcess),
    /// Placeholder while `__enter__` starts the export process. If starting fails the state
    /// stays here, because the configuration has already been consumed.
    Starting,
    /// `__exit__` has been called; the context manager cannot be entered again.
    Exited,
}
77
/// A Python class that implements the context manager interface. It is initialized with a
/// configuration. Upon entry it builds and installs the configured tracing subscriber. Upon exit
/// it shuts down the tracing subscriber.
///
/// Both `with` (`__enter__`/`__exit__`) and `async with` (`__aenter__`/`__aexit__`) are
/// supported. An instance is meant to be entered once and exited once.
#[pyclass]
#[derive(Debug)]
pub struct Tracing {
    /// Lifecycle state used to reject out-of-order or repeated enter/exit calls.
    state: ContextManagerState,
}
86
/// Errors reported when the [`Tracing`] context manager methods are called out of order.
#[derive(thiserror::Error, Debug)]
enum ContextManagerError {
    /// `__enter__` was called when the context manager was not in its initial (configured)
    /// state, e.g. because it had already been entered.
    #[error("entered tracing context manager with no configuration defined; ensure contextmanager only enters once")]
    EnterWithoutConfiguration,
    /// `__exit__` was called without a running export process, e.g. before entering or after
    /// a previous exit.
    #[error("exited tracing context manager with no export process defined; ensure contextmanager only exits once after being entered")]
    ExitWithoutExportProcess,
}
94
// Python-facing error plumbing for `ContextManagerError`.
//
// `Tracing` below builds a `RustContextManagerError` from a `ContextManagerError` via `From`
// and converts it to a `PyErr` with `ToPythonError::to_py_err`.
// NOTE(review): presumably `py_wrap_error!` also defines a Python exception named
// `TracingContextManagerError` (subclassing `RuntimeError`) for the `contextmanager`
// submodule — confirm against the macro definitions.
wrap_error!(RustContextManagerError(ContextManagerError));
py_wrap_error!(
    contextmanager,
    RustContextManagerError,
    TracingContextManagerError,
    pyo3::exceptions::PyRuntimeError
);
102
103#[pymethods]
104impl Tracing {
105    #[new]
106    #[pyo3(signature = (/, config = None))]
107    #[allow(clippy::pedantic)]
108    fn new(config: Option<TracingConfig>) -> PyResult<Self> {
109        let config = config.unwrap_or_default();
110        Ok(Self {
111            state: ContextManagerState::Initialized(config),
112        })
113    }
114
115    fn __enter__(&mut self) -> PyResult<()> {
116        let state = std::mem::replace(&mut self.state, ContextManagerState::Starting);
117        if let ContextManagerState::Initialized(config) = state {
118            self.state = ContextManagerState::Entered(
119                ExportProcess::start(config)
120                    .map_err(RustTracingStartError::from)
121                    .map_err(ToPythonError::to_py_err)?,
122            );
123        } else {
124            return Err(RustContextManagerError::from(
125                ContextManagerError::EnterWithoutConfiguration,
126            )
127            .to_py_err())?;
128        }
129        Ok(())
130    }
131
132    fn __aenter__<'a>(&'a mut self, py: Python<'a>) -> PyResult<&'a PyAny> {
133        self.__enter__()?;
134        pyo3_asyncio::tokio::future_into_py(py, async { Ok(()) })
135    }
136
137    fn __exit__(
138        &mut self,
139        _exc_type: Option<&PyAny>,
140        _exc_value: Option<&PyAny>,
141        _traceback: Option<&PyAny>,
142    ) -> PyResult<()> {
143        let state = std::mem::replace(&mut self.state, ContextManagerState::Exited);
144        if let ContextManagerState::Entered(export_process) = state {
145            let py_rt = pyo3_asyncio::tokio::get_runtime();
146            // Why block and not run this in a future within aexit? The `shutdown`
147            // method returns a Tokio runtime, which cannot be dropped within another
148            // runtime. Additionally, `pyo3_asyncio::tokio::future_into_py` futures
149            // must resolve to something that implements `IntoPy`.
150            let export_runtime = py_rt.block_on(async move {
151                export_process
152                    .shutdown()
153                    .await
154                    .map_err(RustTracingShutdownError::from)
155                    .map_err(ToPythonError::to_py_err)
156            })?;
157            if let Some(export_runtime) = export_runtime {
158                // This immediately shuts the runtime down. The expectation here is that the
159                // process shutdown is responsible for cleaning up all background tasks and
160                // shutting down gracefully.
161                export_runtime.shutdown_background();
162            }
163        } else {
164            return Err(RustContextManagerError::from(
165                ContextManagerError::ExitWithoutExportProcess,
166            )
167            .to_py_err())?;
168        }
169
170        Ok(())
171    }
172
173    fn __aexit__<'a>(
174        &'a mut self,
175        py: Python<'a>,
176        exc_type: Option<&PyAny>,
177        exc_value: Option<&PyAny>,
178        traceback: Option<&PyAny>,
179    ) -> PyResult<&'a PyAny> {
180        self.__exit__(exc_type, exc_value, traceback)?;
181        pyo3_asyncio::tokio::future_into_py(py, async { Ok(()) })
182    }
183}
184
#[cfg(feature = "layer-otel-otlp-file")]
#[cfg(test)]
mod test {
    use std::{
        env::temp_dir,
        fs::File,
        io::{BufRead, BufReader},
        path::{Path, PathBuf},
        thread::sleep,
        time::{Duration, SystemTime, UNIX_EPOCH},
    };

    use tokio::runtime::Builder;

    use crate::{
        contextmanager::{CurrentThreadTracingConfig, GlobalTracingConfig, TracingConfig},
        export_process::{ExportProcess, ExportProcessConfig, SimpleConfig},
        subscriber::TracingSubscriberRegistryConfig,
    };
    use opentelemetry_proto::tonic::trace::v1 as otlp;

    /// Number of `example` spans each test emits.
    const N_SPANS: usize = 5;
    /// How long each `example` span is held open.
    const SPAN_DURATION: Duration = Duration::from_millis(100);
    /// How much longer than `SPAN_DURATION` an exported span may last.
    const SPAN_GRACE: Duration = Duration::from_millis(50);
    /// Upper bound on the async portion of each test (optional span emission plus shutdown).
    const TEST_TIMEOUT: Duration = Duration::from_secs(1);

    /// Emits one `example` span that lasts at least `SPAN_DURATION`.
    #[tracing::instrument]
    fn example() {
        sleep(SPAN_DURATION);
    }

    /// Emits `N_SPANS` `example` spans.
    fn emit_example_spans() {
        for _ in 0..N_SPANS {
            example();
        }
    }

    /// Returns a path in the system temporary directory named after `prefix` and the current
    /// time in nanoseconds, so that separate tests and runs use different files.
    fn get_tempfile(prefix: &str) -> PathBuf {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("should be able to get current time")
            .as_nanos();
        temp_dir().join(format!("{prefix}-{timestamp}.txt"))
    }

    /// Builds a simple export process configuration whose OTLP file layer writes to
    /// `file_path` using the filter `error,pyo3_tracing_subscriber=info`.
    fn simple_export_process_config(file_path: &Path) -> ExportProcessConfig {
        let layer_config = Box::new(crate::layers::otel_otlp_file::Config {
            file_path: Some(
                file_path
                    .to_str()
                    .expect("temporary file path should be valid UTF-8")
                    .to_owned(),
            ),
            filter: Some("error,pyo3_tracing_subscriber=info".to_string()),
            instrumentation_library: None,
        });
        let subscriber = Box::new(TracingSubscriberRegistryConfig { layer_config });
        ExportProcessConfig::Simple(SimpleConfig {
            subscriber: crate::subscriber::PyConfig {
                subscriber_config: subscriber,
            },
        })
    }

    /// On a fresh current-thread runtime, runs `before_shutdown` and then shuts
    /// `export_process` down, failing if this exceeds `TEST_TIMEOUT`. Asserts that shutdown
    /// hands back no runtime.
    fn shutdown_within_timeout(export_process: ExportProcess, before_shutdown: impl FnOnce()) {
        let runtime = Builder::new_current_thread().enable_time().build().unwrap();
        let _guard = runtime.enter();
        let export_runtime = runtime
            .block_on(tokio::time::timeout(TEST_TIMEOUT, async move {
                before_shutdown();
                export_process.shutdown().await
            }))
            .expect("export process should shut down before the timeout")
            .expect("export process should shut down cleanly");
        assert!(export_runtime.is_none());
    }

    /// Reads every span exported to `path` (one OTLP `TracesData` JSON document per non-blank
    /// line), then removes the file.
    fn take_exported_spans(path: &Path) -> Vec<otlp::Span> {
        let file = File::open(path).expect("exported span file should exist");
        let spans: Vec<otlp::Span> = BufReader::new(file)
            .lines()
            .map(|line| line.expect("should be able to read exported span file"))
            .filter(|line| !line.trim().is_empty())
            .flat_map(|line| {
                let traces_data: otlp::TracesData = serde_json::from_str(line.trim())
                    .expect("each line should be an OTLP TracesData JSON document");
                traces_data
                    .resource_spans
                    .into_iter()
                    .flat_map(|resource_span| resource_span.scope_spans)
                    .flat_map(|scope_span| scope_span.spans)
            })
            .collect();
        // Best-effort cleanup so repeated runs do not accumulate files in the temp directory;
        // failing to remove it must not fail the test.
        let _ = std::fs::remove_file(path);
        spans
    }

    /// Asserts that exactly `N_SPANS` `example` spans were exported and that each one's
    /// duration lies within `[SPAN_DURATION, SPAN_DURATION + SPAN_GRACE]`.
    fn assert_example_spans(spans: &[otlp::Span]) {
        assert_eq!(spans.len(), N_SPANS);
        let max_duration = SPAN_DURATION + SPAN_GRACE;
        for span in spans {
            assert_eq!(span.name, "example");
            let elapsed = span
                .end_time_unix_nano
                .checked_sub(span.start_time_unix_nano)
                .map(Duration::from_nanos)
                .expect("span should not end before it starts");
            assert!(
                elapsed >= SPAN_DURATION,
                "span lasted {elapsed:?}, expected at least {SPAN_DURATION:?}"
            );
            assert!(
                elapsed <= max_duration,
                "span lasted {elapsed:?}, expected at most {max_duration:?}"
            );
        }
    }

    /// Test that a global simple export process can be started and stopped and that it
    /// exports accurate spans as configured.
    #[test]
    fn test_global_simple() {
        let temporary_file_path = get_tempfile("test_global_simple");
        let config = TracingConfig::Global(GlobalTracingConfig {
            export_process: simple_export_process_config(&temporary_file_path),
        });
        let export_process = ExportProcess::start(config).unwrap();
        // Spans are emitted from within the test runtime, before shutdown.
        shutdown_within_timeout(export_process, emit_example_spans);

        assert_example_spans(&take_exported_spans(&temporary_file_path));
    }

    /// Test that a current thread simple export process can be started and stopped and that it
    /// exports accurate spans as configured.
    #[test]
    fn test_current_thread_simple() {
        let temporary_file_path = get_tempfile("test_current_thread_simple");
        let config = TracingConfig::CurrentThread(CurrentThreadTracingConfig {
            export_process: simple_export_process_config(&temporary_file_path),
        });
        let export_process = ExportProcess::start(config).unwrap();
        // Spans are emitted on this thread outside of any runtime.
        emit_example_spans();
        shutdown_within_timeout(export_process, || {});

        assert_example_spans(&take_exported_spans(&temporary_file_path));
    }
}