// pyo3_python_tracing_subscriber/lib.rs

1use pyo3::prelude::*;
2use serde_json::json;
3use tracing_core::{span, Event, Subscriber};
4use tracing_serde::AsSerde;
5use tracing_subscriber::{
6    layer::{Context, Layer},
7    registry::LookupSpan,
8};
9
10/// `PythonCallbackLayerBridge` is an adapter allowing the
11/// [`tracing_subscriber::layer::Layer`] trait to be implemented by a Python
12/// object. Each trait method's arguments are serialized as JSON strings and
13/// passed to the corresponding method on the Python object if it exists.
14///
15/// The interface `PythonCallbackLayerBridge` expects Python objects to
16/// implement differs slightly from the `Layer` trait in Rust:
17/// - The Python implementation of `on_new_span` may return some state that will
18///   be stored in the new span's [`tracing_subscriber::registry::Extensions`].
19/// - When calling other trait methods, `PythonCallbackLayerBridge` will get
20///   that state from the current span and pass it back to Python as an
21///   additional positional argument.
22///
23/// The state is opaque to `PythonCallbackLayerBridge` but, for example, a layer
24/// for a Python tracing system could create a Python span for each Rust span
25/// and use a reference to the Python span as the state.
26///
27/// Currently only a subset of `Layer` methods are bridged to Python:
28/// - [`tracing_subscriber::layer::Layer::on_event`], with corresponding Python
29///   signature ```python def on_event(self, event: str, state: Any): ... ```
30/// - [`tracing_subscriber::layer::Layer::on_new_span`] ```python def
31///   on_new_span(self, span_attrs: str, span_id: str): ... ```
32/// - [`tracing_subscriber::layer::Layer::on_close`] ```python def
33///   on_close(self, span_id: str, state: Any): ... ```
34/// - [`tracing_subscriber::layer::Layer::on_record`] ```python def
35///   on_record(self, span_id: str, values: str, state: Any): ... ```
36pub struct PythonCallbackLayerBridge {
37    on_event: Option<Py<PyAny>>,
38    on_new_span: Option<Py<PyAny>>,
39    on_close: Option<Py<PyAny>>,
40    on_record: Option<Py<PyAny>>,
41}
42
43impl PythonCallbackLayerBridge {
44    pub fn new(py_impl: Bound<'_, PyAny>) -> PythonCallbackLayerBridge {
45        let on_event = py_impl.getattr("on_event").ok().map(Bound::unbind);
46        let on_close = py_impl.getattr("on_close").ok().map(Bound::unbind);
47        let on_new_span = py_impl.getattr("on_new_span").ok().map(Bound::unbind);
48        let on_record = py_impl.getattr("on_record").ok().map(Bound::unbind);
49
50        PythonCallbackLayerBridge {
51            on_event,
52            on_close,
53            on_new_span,
54            on_record,
55        }
56    }
57}
58
59impl<S> Layer<S> for PythonCallbackLayerBridge
60where
61    S: Subscriber + for<'a> LookupSpan<'a>,
62{
63    fn on_event(&self, event: &Event, ctx: Context<'_, S>) {
64        let Some(py_on_event) = &self.on_event else {
65            return;
66        };
67
68        let current_span = event
69            .parent()
70            .and_then(|id| ctx.span(id))
71            .or_else(|| ctx.lookup_current());
72        let extensions = current_span.as_ref().map(|span| span.extensions());
73        let json_event = json!(event.as_serde()).to_string();
74
75        Python::with_gil(|py| {
76            let py_state =
77                extensions.map(|ext| ext.get::<Py<PyAny>>().map(|state| state.clone_ref(py)));
78            let _ = py_on_event.bind(py).call((json_event, py_state), None);
79        })
80    }
81
82    fn on_new_span(&self, attrs: &span::Attributes<'_>, span_id: &span::Id, ctx: Context<'_, S>) {
83        let (Some(py_on_new_span), Some(current_span)) = (&self.on_new_span, ctx.span(span_id))
84        else {
85            return;
86        };
87
88        let json_attrs = json!(attrs.as_serde()).to_string();
89        let json_id = json!(span_id.as_serde()).to_string();
90        let mut extensions = current_span.extensions_mut();
91
92        Python::with_gil(|py| {
93            let Ok(py_state) = py_on_new_span.bind(py).call((json_attrs, json_id), None) else {
94                return;
95            };
96
97            extensions.insert::<Py<PyAny>>(py_state.unbind());
98        })
99    }
100
101    fn on_close(&self, span_id: span::Id, ctx: Context<'_, S>) {
102        let (Some(py_on_close), Some(current_span)) = (&self.on_close, ctx.span(&span_id)) else {
103            return;
104        };
105
106        let json_id = json!(span_id.as_serde()).to_string();
107        let py_state = current_span.extensions_mut().remove::<Py<PyAny>>();
108
109        Python::with_gil(|py| {
110            let _ = py_on_close.bind(py).call((json_id, py_state), None);
111        })
112    }
113
114    fn on_record(&self, span_id: &span::Id, values: &span::Record<'_>, ctx: Context<'_, S>) {
115        let (Some(py_on_record), Some(current_span)) = (&self.on_record, ctx.span(span_id)) else {
116            return;
117        };
118
119        let json_id = json!(span_id.as_serde()).to_string();
120        let json_values = json!(values.as_serde()).to_string();
121        let extensions = current_span.extensions();
122
123        Python::with_gil(|py| {
124            let py_state = extensions
125                .get::<Py<PyAny>>()
126                .map(|state| state.clone_ref(py));
127
128            let _ = py_on_record
129                .bind(py)
130                .call((json_id, json_values, py_state), None);
131        })
132    }
133}
134
#[cfg(test)]
mod tests {
    use std::{ops::RangeFrom, sync::Once};

    use serde_json::{Map, Value};
    use tracing::{info, instrument, warn_span};
    use tracing_subscriber::prelude::*;

    use super::*;

    /// Ensures the embedded Python interpreter is prepared exactly once.
    static INIT: Once = Once::new();

    /// Rust-implemented stand-in for a Python layer object. It records every
    /// callback the bridge makes so the tests can inspect them afterwards.
    #[pyclass]
    struct PythonLayer {
        /// Supplies the state returned from `on_new_span`: 0, 1, 2, ...
        span_ids: RangeFrom<u16>,
        /// `(message, level, state)` of each `on_event` call.
        pub events: Vec<(String, String, u16)>,
        /// Level, name and valued fields of each new span.
        pub new_spans: Vec<Value>,
        /// State passed to each `on_close` call, in call order.
        pub closed_spans: Vec<u16>,
        /// `(values, state)` of each `on_record` call.
        pub span_records: Vec<(Value, u16)>,
    }

    #[pymethods]
    impl PythonLayer {
        #[new]
        pub fn new() -> PythonLayer {
            PythonLayer {
                span_ids: 0..,
                events: vec![],
                new_spans: vec![],
                closed_spans: vec![],
                span_records: vec![],
            }
        }

        pub fn on_event(&mut self, event: String, state: u16) {
            let event: Map<String, Value> = serde_json::from_str(&event).unwrap();
            let as_owned_str = |value: Option<&Value>| -> String {
                value.unwrap().as_str().unwrap().to_owned()
            };

            let message = as_owned_str(event.get("message"));
            let level = as_owned_str(event.get("metadata").unwrap().get("level"));
            self.events.push((message, level, state));
        }

        pub fn on_new_span(&mut self, span_attrs: String, _span_id: String) -> u16 {
            let span_attrs: Map<String, Value> = serde_json::from_str(&span_attrs).unwrap();
            let metadata = span_attrs.get("metadata").unwrap().as_object().unwrap();

            let mut stripped_attrs = Map::new();
            for key in ["level", "name"] {
                stripped_attrs.insert(key.to_owned(), metadata.get(key).unwrap().clone());
            }

            // Keep only the declared fields that actually carry a value.
            let declared_fields = metadata.get("fields").unwrap().as_array().unwrap();
            stripped_attrs.extend(declared_fields.iter().filter_map(|field| {
                let field = field.as_str().unwrap();
                span_attrs
                    .get(field)
                    .map(|value| (field.to_owned(), value.clone()))
            }));

            self.new_spans.push(Value::Object(stripped_attrs));
            self.span_ids.next().unwrap()
        }

        pub fn on_close(&mut self, _span_id: String, state: u16) {
            self.closed_spans.push(state);
        }

        pub fn on_record(&mut self, _span_id: String, values: String, state: u16) {
            let values: Value = serde_json::from_str(&values).unwrap();
            self.span_records.push((values, state));
        }
    }

    /// Creates a fresh recording `PythonLayer` and wraps it in a bridge. The
    /// bridge is installed as this thread's default subscriber for as long as
    /// the returned guard is alive.
    fn initialize_tracing() -> (Py<PythonLayer>, tracing::dispatcher::DefaultGuard) {
        INIT.call_once(pyo3::prepare_freethreaded_python);

        let (recorder, bridge) = Python::with_gil(|py| {
            let recorder = Bound::new(py, PythonLayer::new()).unwrap();
            let bridge = PythonCallbackLayerBridge::new(recorder.clone().into_any());
            (recorder.unbind(), bridge)
        });

        let guard = tracing_subscriber::registry().with(bridge).set_default();
        (recorder, guard)
    }

    /// Instrumented function used by both tests. It opens a `func` span with
    /// an initially empty `data` field, emits one INFO event, then records
    /// `data`.
    #[instrument(fields(data))]
    fn func(arg1: u16, arg2: String) {
        info!("About to record something");
        tracing::Span::current().record("data", "some data");
    }

    /// Asserts that `recorder` saw exactly the given callbacks. Events, new
    /// spans, closed spans and span records are checked in that order.
    fn assert_recorded(
        recorder: &Py<PythonLayer>,
        events: &[(String, String, u16)],
        new_spans: &[Value],
        closed_spans: &[u16],
        span_records: &[(Value, u16)],
    ) {
        Python::with_gil(|py| {
            let recorded = recorder.borrow(py);
            assert_eq!(events, recorded.events.as_slice());
            assert_eq!(new_spans, recorded.new_spans.as_slice());
            assert_eq!(closed_spans, recorded.closed_spans.as_slice());
            assert_eq!(span_records, recorded.span_records.as_slice());
        });
    }

    #[test]
    fn test_simple_span() {
        let (recorder, _guard) = initialize_tracing();

        func(1337, "foo".to_string());

        assert_recorded(
            &recorder,
            &[("About to record something".to_owned(), "INFO".to_owned(), 0)],
            &[json!({"arg1": 1337, "arg2": "\"foo\"", "level": "INFO", "name": "func"})],
            &[0],
            &[(json!({"data": "some data"}), 0)],
        );
    }

    #[test]
    fn test_nested_span() {
        let (recorder, _guard) = initialize_tracing();

        // States are handed out in span-creation order: `outer` gets 0, `func`
        // gets 1. The temporary `outer` span is closed at the end of this
        // statement.
        warn_span!("outer").in_scope(|| func(1337, "bar".to_string()));

        assert_recorded(
            &recorder,
            &[("About to record something".to_owned(), "INFO".to_owned(), 1)],
            &[
                json!({"level": "WARN", "name": "outer"}),
                json!({"arg1": 1337, "arg2": "\"bar\"", "level": "INFO", "name": "func"}),
            ],
            &[1, 0],
            &[(json!({"data": "some data"}), 1)],
        );
    }
}
284}