Skip to main content

systemprompt_logging/layer/
mod.rs

1mod visitor;
2
3use std::io::Write;
4use std::time::Duration;
5
6use chrono::Utc;
7use tokio::sync::mpsc;
8use tracing::{Event, Subscriber};
9use tracing_subscriber::Layer;
10use tracing_subscriber::layer::Context;
11use tracing_subscriber::registry::LookupSpan;
12
13use crate::models::{LogEntry, LogLevel};
14use systemprompt_database::DbPool;
15use systemprompt_identifiers::{ClientId, ContextId, LogId, SessionId, TaskId, TraceId, UserId};
16use visitor::{FieldVisitor, SpanContext, SpanFields, SpanVisitor, extract_span_context};
17
/// Buffered entry count at which the batch writer flushes immediately
/// (also used as the buffer's initial capacity).
const BUFFER_FLUSH_SIZE: usize = 100;
/// Period, in seconds, of the timer that flushes a partially filled buffer.
const BUFFER_FLUSH_INTERVAL_SECS: u64 = 10;
20
21enum LogCommand {
22    Entry(Box<LogEntry>),
23    FlushNow,
24}
25
26pub struct DatabaseLayer {
27    sender: mpsc::UnboundedSender<LogCommand>,
28}
29
30impl std::fmt::Debug for DatabaseLayer {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("DatabaseLayer").finish_non_exhaustive()
33    }
34}
35
36impl DatabaseLayer {
37    pub fn new(db_pool: DbPool) -> Self {
38        let (sender, receiver) = mpsc::unbounded_channel();
39
40        tokio::spawn(Self::batch_writer(db_pool, receiver));
41
42        Self { sender }
43    }
44
45    async fn batch_writer(db_pool: DbPool, mut receiver: mpsc::UnboundedReceiver<LogCommand>) {
46        let mut buffer = Vec::with_capacity(BUFFER_FLUSH_SIZE);
47        let mut interval = tokio::time::interval(Duration::from_secs(BUFFER_FLUSH_INTERVAL_SECS));
48
49        loop {
50            tokio::select! {
51                Some(command) = receiver.recv() => {
52                    match command {
53                        LogCommand::Entry(entry) => {
54                            buffer.push(*entry);
55                            if buffer.len() >= BUFFER_FLUSH_SIZE {
56                                Self::flush(&db_pool, &mut buffer).await;
57                            }
58                        }
59                        LogCommand::FlushNow => {
60                            if !buffer.is_empty() {
61                                Self::flush(&db_pool, &mut buffer).await;
62                            }
63                        }
64                    }
65                }
66                _ = interval.tick() => {
67                    if !buffer.is_empty() {
68                        Self::flush(&db_pool, &mut buffer).await;
69                    }
70                }
71            }
72        }
73    }
74
75    async fn flush(db_pool: &DbPool, buffer: &mut Vec<LogEntry>) {
76        if let Err(e) = Self::batch_insert(db_pool, buffer).await {
77            let msg = e.to_string();
78            if !msg.contains("does not exist") {
79                let _ = writeln!(std::io::stderr(), "Failed to flush logs: {e}");
80            }
81        }
82        buffer.clear();
83    }
84
85    async fn batch_insert(db_pool: &DbPool, entries: &[LogEntry]) -> anyhow::Result<()> {
86        let pool = db_pool.write_pool_arc()?;
87        for entry in entries {
88            let metadata_json: Option<String> = entry
89                .metadata
90                .as_ref()
91                .map(serde_json::to_string)
92                .transpose()?;
93
94            let entry_id = entry.id.as_str();
95            let level_str = entry.level.to_string();
96            let user_id = entry.user_id.as_str();
97            let session_id = entry.session_id.as_str();
98            let task_id = entry.task_id.as_ref().map(TaskId::as_str);
99            let trace_id = entry.trace_id.as_str();
100            let context_id = entry.context_id.as_ref().map(ContextId::as_str);
101            let client_id = entry.client_id.as_ref().map(ClientId::as_str);
102
103            sqlx::query!(
104                r"
105                INSERT INTO logs (id, level, module, message, metadata, user_id, session_id, task_id, trace_id, context_id, client_id)
106                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
107                ",
108                entry_id,
109                level_str,
110                entry.module,
111                entry.message,
112                metadata_json,
113                user_id,
114                session_id,
115                task_id,
116                trace_id,
117                context_id,
118                client_id
119            )
120            .execute(pool.as_ref())
121            .await?;
122        }
123
124        Ok(())
125    }
126}
127
128impl<S> Layer<S> for DatabaseLayer
129where
130    S: Subscriber + for<'a> LookupSpan<'a>,
131{
132    fn on_new_span(
133        &self,
134        attrs: &tracing::span::Attributes<'_>,
135        id: &tracing::span::Id,
136        ctx: Context<'_, S>,
137    ) {
138        let Some(span) = ctx.span(id) else {
139            return;
140        };
141        let mut fields = SpanFields::default();
142        let mut context = SpanContext::default();
143        let mut visitor = SpanVisitor {
144            context: &mut context,
145        };
146        attrs.record(&mut visitor);
147
148        fields.user = context.user;
149        fields.session = context.session;
150        fields.task = context.task;
151        fields.trace = context.trace;
152        fields.context = context.context;
153        fields.client = context.client;
154
155        let mut extensions = span.extensions_mut();
156        extensions.insert(fields);
157    }
158
159    fn on_record(
160        &self,
161        id: &tracing::span::Id,
162        values: &tracing::span::Record<'_>,
163        ctx: Context<'_, S>,
164    ) {
165        if let Some(span) = ctx.span(id) {
166            let mut extensions = span.extensions_mut();
167            if let Some(fields) = extensions.get_mut::<SpanFields>() {
168                let mut context = SpanContext {
169                    user: fields.user.clone(),
170                    session: fields.session.clone(),
171                    task: fields.task.clone(),
172                    trace: fields.trace.clone(),
173                    context: fields.context.clone(),
174                    client: fields.client.clone(),
175                };
176                let mut visitor = SpanVisitor {
177                    context: &mut context,
178                };
179                values.record(&mut visitor);
180
181                fields.user = context.user;
182                fields.session = context.session;
183                fields.task = context.task;
184                fields.trace = context.trace;
185                fields.context = context.context;
186                fields.client = context.client;
187            }
188        }
189    }
190
191    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
192        let level = *event.metadata().level();
193        let module = event.metadata().target().to_string();
194
195        let mut visitor = FieldVisitor::default();
196        event.record(&mut visitor);
197
198        let span_context = ctx
199            .current_span()
200            .id()
201            .and_then(|id| ctx.span(id))
202            .map(extract_span_context);
203
204        let log_level = match level {
205            tracing::Level::ERROR => LogLevel::Error,
206            tracing::Level::WARN => LogLevel::Warn,
207            tracing::Level::INFO => LogLevel::Info,
208            tracing::Level::DEBUG => LogLevel::Debug,
209            tracing::Level::TRACE => LogLevel::Trace,
210        };
211
212        let is_error = log_level == LogLevel::Error;
213
214        let entry = LogEntry {
215            id: LogId::generate(),
216            timestamp: Utc::now(),
217            level: log_level,
218            module,
219            message: visitor.message,
220            metadata: visitor.fields,
221            user_id: span_context
222                .as_ref()
223                .and_then(|c| c.user.as_ref())
224                .map_or_else(UserId::system, |s| UserId::new(s.clone())),
225            session_id: span_context
226                .as_ref()
227                .and_then(|c| c.session.as_ref())
228                .map_or_else(SessionId::system, |s| SessionId::new(s.clone())),
229            task_id: span_context
230                .as_ref()
231                .and_then(|c| c.task.as_ref())
232                .map(|s| TaskId::new(s.clone())),
233            trace_id: span_context
234                .as_ref()
235                .and_then(|c| c.trace.as_ref())
236                .map_or_else(TraceId::system, |s| TraceId::new(s.clone())),
237            context_id: span_context
238                .as_ref()
239                .and_then(|c| c.context.as_ref())
240                .map(|s| ContextId::new(s.clone())),
241            client_id: span_context
242                .as_ref()
243                .and_then(|c| c.client.as_ref())
244                .map(|s| ClientId::new(s.clone())),
245        };
246
247        let _ = self.sender.send(LogCommand::Entry(Box::new(entry)));
248
249        if is_error {
250            let _ = self.sender.send(LogCommand::FlushNow);
251        }
252    }
253}