Skip to main content

systemprompt_logging/layer/
mod.rs

1//! `tracing` subscriber layer that persists events to the database.
2//!
3//! [`DatabaseLayer`] buffers log events off the hot path and batch-inserts them
4//! from a background task, flushing on a size threshold, a timer, or
5//! immediately on an error. [`ProxyDatabaseLayer`] is the proxy-side variant.
6
7mod proxy;
8mod visitor;
9
10use std::io::Write;
11use std::sync::Arc;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::time::Duration;
14
15use tokio::sync::mpsc;
16use tracing::{Event, Subscriber};
17use tracing_subscriber::Layer;
18use tracing_subscriber::layer::Context;
19use tracing_subscriber::registry::LookupSpan;
20
21pub use proxy::ProxyDatabaseLayer;
22use proxy::{build_log_entry, record_span_fields, update_span_fields};
23
24use crate::models::{LogEntry, LogLevel};
25use systemprompt_database::DbPool;
26use systemprompt_identifiers::{ClientId, ContextId, TaskId};
27
/// Number of buffered entries at which the writer flushes without waiting for
/// the timer.
const BUFFER_FLUSH_SIZE: usize = 100;

/// Period, in seconds, of the timer that flushes a non-empty buffer.
const BUFFER_FLUSH_INTERVAL_SECS: u64 = 10;

/// Maximum number of commands the log channel may hold. Once a burst outpaces
/// the database writer and the channel reaches this depth, further entries are
/// dropped instead of queued, which keeps a logging backlog from growing the
/// heap without bound.
const CHANNEL_CAPACITY: usize = 8192;
35
36enum LogCommand {
37    Entry(Box<LogEntry>),
38    FlushNow,
39}
40
41/// Bounded sender to the database writer task. On a full channel the entry is
42/// dropped and [`LogChannel::dropped`] is incremented; the send never blocks,
43/// so logging stays off the hot path even under burst.
44struct LogChannel {
45    sender: mpsc::Sender<LogCommand>,
46    dropped: Arc<AtomicU64>,
47}
48
49impl LogChannel {
50    fn new(capacity: usize) -> (Self, mpsc::Receiver<LogCommand>) {
51        let (sender, receiver) = mpsc::channel(capacity);
52        let channel = Self {
53            sender,
54            dropped: Arc::new(AtomicU64::new(0)),
55        };
56        (channel, receiver)
57    }
58
59    fn send(&self, command: LogCommand) {
60        if let Err(mpsc::error::TrySendError::Full(_)) = self.sender.try_send(command) {
61            self.dropped.fetch_add(1, Ordering::Relaxed);
62        }
63    }
64
65    fn dropped(&self) -> u64 {
66        self.dropped.load(Ordering::Relaxed)
67    }
68}
69
70pub struct DatabaseLayer {
71    channel: LogChannel,
72}
73
74impl std::fmt::Debug for DatabaseLayer {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        f.debug_struct("DatabaseLayer")
77            .field("dropped", &self.channel.dropped())
78            .finish_non_exhaustive()
79    }
80}
81
82impl DatabaseLayer {
83    pub fn new(db_pool: DbPool) -> Self {
84        let (channel, receiver) = LogChannel::new(CHANNEL_CAPACITY);
85
86        tokio::spawn(Self::batch_writer(db_pool, receiver));
87
88        Self { channel }
89    }
90
91    async fn batch_writer(db_pool: DbPool, mut receiver: mpsc::Receiver<LogCommand>) {
92        let mut buffer = Vec::with_capacity(BUFFER_FLUSH_SIZE);
93        let mut interval = tokio::time::interval(Duration::from_secs(BUFFER_FLUSH_INTERVAL_SECS));
94        let mut failed_total: u64 = 0;
95
96        loop {
97            tokio::select! {
98                Some(command) = receiver.recv() => {
99                    match command {
100                        LogCommand::Entry(entry) => {
101                            buffer.push(*entry);
102                            if buffer.len() >= BUFFER_FLUSH_SIZE {
103                                Self::flush(&db_pool, &mut buffer, &mut failed_total).await;
104                            }
105                        }
106                        LogCommand::FlushNow => {
107                            if !buffer.is_empty() {
108                                Self::flush(&db_pool, &mut buffer, &mut failed_total).await;
109                            }
110                        }
111                    }
112                }
113                _ = interval.tick() => {
114                    if !buffer.is_empty() {
115                        Self::flush(&db_pool, &mut buffer, &mut failed_total).await;
116                    }
117                }
118            }
119        }
120    }
121
122    async fn flush(db_pool: &DbPool, buffer: &mut Vec<LogEntry>, failed_total: &mut u64) {
123        if let Err(e) = Self::batch_insert(db_pool, buffer).await {
124            let lost = u64::try_from(buffer.len()).unwrap_or(u64::MAX);
125            *failed_total = failed_total.saturating_add(lost);
126            writeln!(
127                std::io::stderr(),
128                "DATABASE LOG FLUSH FAILED ({lost} entries lost this flush, {failed_total} total lost since start): {e}"
129            )
130            .ok();
131        }
132        buffer.clear();
133    }
134
135    async fn batch_insert(
136        db_pool: &DbPool,
137        entries: &[LogEntry],
138    ) -> Result<(), crate::models::LoggingError> {
139        let pool = db_pool.write_pool_arc()?;
140        for entry in entries {
141            let metadata_json: Option<String> = entry
142                .metadata
143                .as_ref()
144                .map(serde_json::to_string)
145                .transpose()?;
146
147            let entry_id = entry.id.as_str();
148            let level_str = entry.level.to_string();
149            let user_id = entry.user_id.as_str();
150            let session_id = entry.session_id.as_str();
151            let task_id = entry.task_id.as_ref().map(TaskId::as_str);
152            let trace_id = entry.trace_id.as_str();
153            let context_id = entry.context_id.as_ref().map(ContextId::as_str);
154            let client_id = entry.client_id.as_ref().map(ClientId::as_str);
155
156            sqlx::query!(
157                r"
158                INSERT INTO logs (id, level, module, message, metadata, user_id, session_id, task_id, trace_id, context_id, client_id)
159                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
160                ",
161                entry_id,
162                level_str,
163                entry.module,
164                entry.message,
165                metadata_json,
166                user_id,
167                session_id,
168                task_id,
169                trace_id,
170                context_id,
171                client_id
172            )
173            .execute(pool.as_ref())
174            .await?;
175        }
176
177        Ok(())
178    }
179}
180
181impl DatabaseLayer {
182    fn send_entry(&self, entry: LogEntry) {
183        let is_error = entry.level == LogLevel::Error;
184        self.channel.send(LogCommand::Entry(Box::new(entry)));
185        if is_error {
186            self.channel.send(LogCommand::FlushNow);
187        }
188    }
189}
190
191impl<S> Layer<S> for DatabaseLayer
192where
193    S: Subscriber + for<'a> LookupSpan<'a>,
194{
195    fn on_new_span(
196        &self,
197        attrs: &tracing::span::Attributes<'_>,
198        id: &tracing::span::Id,
199        ctx: Context<'_, S>,
200    ) {
201        record_span_fields(attrs, id, &ctx);
202    }
203
204    fn on_record(
205        &self,
206        id: &tracing::span::Id,
207        values: &tracing::span::Record<'_>,
208        ctx: Context<'_, S>,
209    ) {
210        update_span_fields(id, values, &ctx);
211    }
212
213    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
214        if let Some(entry) = build_log_entry(event, &ctx) {
215            self.send_entry(entry);
216        }
217    }
218}