Skip to main content

systemprompt_logging/layer/
proxy.rs

1use std::io::Write;
2use std::sync::{Arc, OnceLock};
3
4use chrono::Utc;
5use tracing::{Event, Subscriber};
6use tracing_subscriber::Layer;
7use tracing_subscriber::layer::Context;
8use tracing_subscriber::registry::LookupSpan;
9
10use super::DatabaseLayer;
11use super::visitor::{FieldVisitor, SpanContext, SpanFields, SpanVisitor, extract_span_context};
12use crate::models::{LogEntry, LogLevel};
13use systemprompt_database::DbPool;
14use systemprompt_identifiers::{ClientId, ContextId, LogId, SessionId, TaskId, TraceId, UserId};
15
16#[derive(Clone)]
17pub struct ProxyDatabaseLayer {
18    inner: Arc<OnceLock<DatabaseLayer>>,
19}
20
21impl std::fmt::Debug for ProxyDatabaseLayer {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("ProxyDatabaseLayer")
24            .field("attached", &self.inner.get().is_some())
25            .finish()
26    }
27}
28
29impl Default for ProxyDatabaseLayer {
30    fn default() -> Self {
31        Self::new()
32    }
33}
34
35impl ProxyDatabaseLayer {
36    pub fn new() -> Self {
37        Self {
38            inner: Arc::new(OnceLock::new()),
39        }
40    }
41
42    pub fn attach(&self, db_pool: DbPool) {
43        if self.inner.set(DatabaseLayer::new(db_pool)).is_err() {
44            writeln!(
45                std::io::stderr(),
46                "ProxyDatabaseLayer: database layer already attached, ignoring duplicate"
47            )
48            .ok();
49        }
50    }
51}
52
53impl<S> Layer<S> for ProxyDatabaseLayer
54where
55    S: Subscriber + for<'a> LookupSpan<'a>,
56{
57    fn on_new_span(
58        &self,
59        attrs: &tracing::span::Attributes<'_>,
60        id: &tracing::span::Id,
61        ctx: Context<'_, S>,
62    ) {
63        if let Some(db) = self.inner.get() {
64            db.on_new_span(attrs, id, ctx);
65        } else {
66            record_span_fields(attrs, id, &ctx);
67        }
68    }
69
70    fn on_record(
71        &self,
72        id: &tracing::span::Id,
73        values: &tracing::span::Record<'_>,
74        ctx: Context<'_, S>,
75    ) {
76        if let Some(db) = self.inner.get() {
77            db.on_record(id, values, ctx);
78        } else {
79            update_span_fields(id, values, &ctx);
80        }
81    }
82
83    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
84        if let Some(db) = self.inner.get() {
85            db.on_event(event, ctx);
86        }
87    }
88}
89
90pub(super) fn record_span_fields<S>(
91    attrs: &tracing::span::Attributes<'_>,
92    id: &tracing::span::Id,
93    ctx: &Context<'_, S>,
94) where
95    S: Subscriber + for<'a> LookupSpan<'a>,
96{
97    let Some(span) = ctx.span(id) else {
98        return;
99    };
100    let mut fields = SpanFields::default();
101    let mut context = SpanContext::default();
102    let mut visitor = SpanVisitor {
103        context: &mut context,
104    };
105    attrs.record(&mut visitor);
106
107    fields.user = context.user;
108    fields.session = context.session;
109    fields.task = context.task;
110    fields.trace = context.trace;
111    fields.context = context.context;
112    fields.client = context.client;
113
114    let mut extensions = span.extensions_mut();
115    extensions.insert(fields);
116}
117
118pub(super) fn update_span_fields<S>(
119    id: &tracing::span::Id,
120    values: &tracing::span::Record<'_>,
121    ctx: &Context<'_, S>,
122) where
123    S: Subscriber + for<'a> LookupSpan<'a>,
124{
125    if let Some(span) = ctx.span(id) {
126        let mut extensions = span.extensions_mut();
127        if let Some(fields) = extensions.get_mut::<SpanFields>() {
128            let mut context = SpanContext {
129                user: fields.user.clone(),
130                session: fields.session.clone(),
131                task: fields.task.clone(),
132                trace: fields.trace.clone(),
133                context: fields.context.clone(),
134                client: fields.client.clone(),
135            };
136            let mut visitor = SpanVisitor {
137                context: &mut context,
138            };
139            values.record(&mut visitor);
140
141            fields.user = context.user;
142            fields.session = context.session;
143            fields.task = context.task;
144            fields.trace = context.trace;
145            fields.context = context.context;
146            fields.client = context.client;
147        }
148    }
149}
150
151pub(super) fn build_log_entry<S>(event: &Event<'_>, ctx: &Context<'_, S>) -> Option<LogEntry>
152where
153    S: Subscriber + for<'a> LookupSpan<'a>,
154{
155    let level = *event.metadata().level();
156    let module = event.metadata().target().to_owned();
157
158    let mut visitor = FieldVisitor::default();
159    event.record(&mut visitor);
160
161    let span_context = ctx
162        .current_span()
163        .id()
164        .and_then(|id| ctx.span(id))
165        .map(extract_span_context)?;
166
167    let log_level = match level {
168        tracing::Level::ERROR => LogLevel::Error,
169        tracing::Level::WARN => LogLevel::Warn,
170        tracing::Level::INFO => LogLevel::Info,
171        tracing::Level::DEBUG => LogLevel::Debug,
172        tracing::Level::TRACE => LogLevel::Trace,
173    };
174
175    let user_id = UserId::new(span_context.user.as_ref()?.clone());
176    let session_id = SessionId::new(span_context.session.as_ref()?.clone());
177    let trace_id = TraceId::new(span_context.trace.as_ref()?.clone());
178
179    Some(LogEntry {
180        id: LogId::generate(),
181        timestamp: Utc::now(),
182        level: log_level,
183        module,
184        message: visitor.message,
185        metadata: visitor.fields,
186        user_id,
187        session_id,
188        task_id: span_context.task.as_ref().map(|s| TaskId::new(s.clone())),
189        trace_id,
190        context_id: span_context
191            .context
192            .as_ref()
193            .and_then(|s| ContextId::try_new(s.clone()).ok()),
194        client_id: span_context
195            .client
196            .as_ref()
197            .map(|s| ClientId::new(s.clone())),
198    })
199}