Skip to main content

systemprompt_logging/layer/
mod.rs

1//! `tracing` subscriber layer that persists events to the database.
2//!
3//! [`DatabaseLayer`] buffers log events off the hot path and batch-inserts them
4//! from a background task, flushing on a size threshold, a timer, or
5//! immediately on an error. [`ProxyDatabaseLayer`] is the proxy-side variant.
6
7mod proxy;
8mod visitor;
9
10use std::io::Write;
11use std::time::Duration;
12
13use tokio::sync::mpsc;
14use tracing::{Event, Subscriber};
15use tracing_subscriber::Layer;
16use tracing_subscriber::layer::Context;
17use tracing_subscriber::registry::LookupSpan;
18
19pub use proxy::ProxyDatabaseLayer;
20use proxy::{build_log_entry, record_span_fields, update_span_fields};
21
22use crate::models::{LogEntry, LogLevel};
23use systemprompt_database::DbPool;
24use systemprompt_identifiers::{ClientId, ContextId, TaskId};
25
/// Buffered-entry count at which the background writer inserts immediately.
const BUFFER_FLUSH_SIZE: usize = 100;

/// Period, in seconds, of the timer that writes out a partially filled buffer.
const BUFFER_FLUSH_INTERVAL_SECS: u64 = 10;
28
29enum LogCommand {
30    Entry(Box<LogEntry>),
31    FlushNow,
32}
33
34pub struct DatabaseLayer {
35    sender: mpsc::UnboundedSender<LogCommand>,
36}
37
38impl std::fmt::Debug for DatabaseLayer {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.debug_struct("DatabaseLayer").finish_non_exhaustive()
41    }
42}
43
44impl DatabaseLayer {
45    pub fn new(db_pool: DbPool) -> Self {
46        let (sender, receiver) = mpsc::unbounded_channel();
47
48        tokio::spawn(Self::batch_writer(db_pool, receiver));
49
50        Self { sender }
51    }
52
53    async fn batch_writer(db_pool: DbPool, mut receiver: mpsc::UnboundedReceiver<LogCommand>) {
54        let mut buffer = Vec::with_capacity(BUFFER_FLUSH_SIZE);
55        let mut interval = tokio::time::interval(Duration::from_secs(BUFFER_FLUSH_INTERVAL_SECS));
56
57        loop {
58            tokio::select! {
59                Some(command) = receiver.recv() => {
60                    match command {
61                        LogCommand::Entry(entry) => {
62                            buffer.push(*entry);
63                            if buffer.len() >= BUFFER_FLUSH_SIZE {
64                                Self::flush(&db_pool, &mut buffer).await;
65                            }
66                        }
67                        LogCommand::FlushNow => {
68                            if !buffer.is_empty() {
69                                Self::flush(&db_pool, &mut buffer).await;
70                            }
71                        }
72                    }
73                }
74                _ = interval.tick() => {
75                    if !buffer.is_empty() {
76                        Self::flush(&db_pool, &mut buffer).await;
77                    }
78                }
79            }
80        }
81    }
82
83    async fn flush(db_pool: &DbPool, buffer: &mut Vec<LogEntry>) {
84        if let Err(e) = Self::batch_insert(db_pool, buffer).await {
85            writeln!(
86                std::io::stderr(),
87                "DATABASE LOG FLUSH FAILED ({} entries lost): {e}",
88                buffer.len()
89            )
90            .ok();
91        }
92        buffer.clear();
93    }
94
95    async fn batch_insert(
96        db_pool: &DbPool,
97        entries: &[LogEntry],
98    ) -> Result<(), crate::models::LoggingError> {
99        let pool = db_pool.write_pool_arc()?;
100        for entry in entries {
101            let metadata_json: Option<String> = entry
102                .metadata
103                .as_ref()
104                .map(serde_json::to_string)
105                .transpose()?;
106
107            let entry_id = entry.id.as_str();
108            let level_str = entry.level.to_string();
109            let user_id = entry.user_id.as_str();
110            let session_id = entry.session_id.as_str();
111            let task_id = entry.task_id.as_ref().map(TaskId::as_str);
112            let trace_id = entry.trace_id.as_str();
113            let context_id = entry.context_id.as_ref().map(ContextId::as_str);
114            let client_id = entry.client_id.as_ref().map(ClientId::as_str);
115
116            sqlx::query!(
117                r"
118                INSERT INTO logs (id, level, module, message, metadata, user_id, session_id, task_id, trace_id, context_id, client_id)
119                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
120                ",
121                entry_id,
122                level_str,
123                entry.module,
124                entry.message,
125                metadata_json,
126                user_id,
127                session_id,
128                task_id,
129                trace_id,
130                context_id,
131                client_id
132            )
133            .execute(pool.as_ref())
134            .await?;
135        }
136
137        Ok(())
138    }
139}
140
141impl DatabaseLayer {
142    fn send_entry(&self, entry: LogEntry) {
143        let is_error = entry.level == LogLevel::Error;
144        self.sender.send(LogCommand::Entry(Box::new(entry))).ok();
145        if is_error {
146            self.sender.send(LogCommand::FlushNow).ok();
147        }
148    }
149}
150
151impl<S> Layer<S> for DatabaseLayer
152where
153    S: Subscriber + for<'a> LookupSpan<'a>,
154{
155    fn on_new_span(
156        &self,
157        attrs: &tracing::span::Attributes<'_>,
158        id: &tracing::span::Id,
159        ctx: Context<'_, S>,
160    ) {
161        record_span_fields(attrs, id, &ctx);
162    }
163
164    fn on_record(
165        &self,
166        id: &tracing::span::Id,
167        values: &tracing::span::Record<'_>,
168        ctx: Context<'_, S>,
169    ) {
170        update_span_fields(id, values, &ctx);
171    }
172
173    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
174        if let Some(entry) = build_log_entry(event, &ctx) {
175            self.send_entry(entry);
176        }
177    }
178}