//! Database-backed `tracing` layer (`systemprompt_logging/layer/mod.rs`).
//!
//! Buffers log entries from tracing events and batch-writes them to the
//! database on a background task.

mod proxy;
pub(crate) mod visitor;

use std::io::Write;
use std::time::Duration;

use tokio::sync::mpsc;
use tracing::{Event, Subscriber};
use tracing_subscriber::Layer;
use tracing_subscriber::layer::Context;
use tracing_subscriber::registry::LookupSpan;

pub use proxy::ProxyDatabaseLayer;
use proxy::{build_log_entry, record_span_fields, update_span_fields};

use crate::models::{LogEntry, LogLevel};
use systemprompt_database::DbPool;
use systemprompt_identifiers::{ClientId, ContextId, TaskId};

const BUFFER_FLUSH_SIZE: usize = 100;
const BUFFER_FLUSH_INTERVAL_SECS: u64 = 10;
22
/// Messages accepted by the background batch-writer task.
enum LogCommand {
    /// A log entry to buffer for batched insertion. Boxed so the channel
    /// message stays pointer-sized regardless of `LogEntry`'s size.
    Entry(Box<LogEntry>),
    /// Flush the buffer immediately (sent right after an `Error`-level
    /// entry — see `DatabaseLayer::send_entry`).
    FlushNow,
}
27
/// A `tracing` layer that persists events to the database.
///
/// Holds only the sending half of an unbounded channel; the receiving half
/// lives in the batch-writer task spawned by [`DatabaseLayer::new`], so
/// emitting a log never blocks the caller on database I/O.
pub struct DatabaseLayer {
    sender: mpsc::UnboundedSender<LogCommand>,
}
31
// Manual `Debug` because `UnboundedSender<LogCommand>` carries no useful
// debug information; `finish_non_exhaustive` renders "DatabaseLayer { .. }".
impl std::fmt::Debug for DatabaseLayer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DatabaseLayer").finish_non_exhaustive()
    }
}
37
38impl DatabaseLayer {
39    pub fn new(db_pool: DbPool) -> Self {
40        let (sender, receiver) = mpsc::unbounded_channel();
41
42        tokio::spawn(Self::batch_writer(db_pool, receiver));
43
44        Self { sender }
45    }
46
47    async fn batch_writer(db_pool: DbPool, mut receiver: mpsc::UnboundedReceiver<LogCommand>) {
48        let mut buffer = Vec::with_capacity(BUFFER_FLUSH_SIZE);
49        let mut interval = tokio::time::interval(Duration::from_secs(BUFFER_FLUSH_INTERVAL_SECS));
50
51        loop {
52            tokio::select! {
53                Some(command) = receiver.recv() => {
54                    match command {
55                        LogCommand::Entry(entry) => {
56                            buffer.push(*entry);
57                            if buffer.len() >= BUFFER_FLUSH_SIZE {
58                                Self::flush(&db_pool, &mut buffer).await;
59                            }
60                        }
61                        LogCommand::FlushNow => {
62                            if !buffer.is_empty() {
63                                Self::flush(&db_pool, &mut buffer).await;
64                            }
65                        }
66                    }
67                }
68                _ = interval.tick() => {
69                    if !buffer.is_empty() {
70                        Self::flush(&db_pool, &mut buffer).await;
71                    }
72                }
73            }
74        }
75    }
76
77    async fn flush(db_pool: &DbPool, buffer: &mut Vec<LogEntry>) {
78        if let Err(e) = Self::batch_insert(db_pool, buffer).await {
79            writeln!(
80                std::io::stderr(),
81                "DATABASE LOG FLUSH FAILED ({} entries lost): {e}",
82                buffer.len()
83            )
84            .ok();
85        }
86        buffer.clear();
87    }
88
89    async fn batch_insert(
90        db_pool: &DbPool,
91        entries: &[LogEntry],
92    ) -> Result<(), crate::models::LoggingError> {
93        let pool = db_pool.write_pool_arc()?;
94        for entry in entries {
95            let metadata_json: Option<String> = entry
96                .metadata
97                .as_ref()
98                .map(serde_json::to_string)
99                .transpose()?;
100
101            let entry_id = entry.id.as_str();
102            let level_str = entry.level.to_string();
103            let user_id = entry.user_id.as_str();
104            let session_id = entry.session_id.as_str();
105            let task_id = entry.task_id.as_ref().map(TaskId::as_str);
106            let trace_id = entry.trace_id.as_str();
107            let context_id = entry.context_id.as_ref().map(ContextId::as_str);
108            let client_id = entry.client_id.as_ref().map(ClientId::as_str);
109
110            sqlx::query!(
111                r"
112                INSERT INTO logs (id, level, module, message, metadata, user_id, session_id, task_id, trace_id, context_id, client_id)
113                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
114                ",
115                entry_id,
116                level_str,
117                entry.module,
118                entry.message,
119                metadata_json,
120                user_id,
121                session_id,
122                task_id,
123                trace_id,
124                context_id,
125                client_id
126            )
127            .execute(pool.as_ref())
128            .await?;
129        }
130
131        Ok(())
132    }
133}
134
135impl DatabaseLayer {
136    fn send_entry(&self, entry: LogEntry) {
137        let is_error = entry.level == LogLevel::Error;
138        self.sender.send(LogCommand::Entry(Box::new(entry))).ok();
139        if is_error {
140            self.sender.send(LogCommand::FlushNow).ok();
141        }
142    }
143}
144
145impl<S> Layer<S> for DatabaseLayer
146where
147    S: Subscriber + for<'a> LookupSpan<'a>,
148{
149    fn on_new_span(
150        &self,
151        attrs: &tracing::span::Attributes<'_>,
152        id: &tracing::span::Id,
153        ctx: Context<'_, S>,
154    ) {
155        record_span_fields(attrs, id, &ctx);
156    }
157
158    fn on_record(
159        &self,
160        id: &tracing::span::Id,
161        values: &tracing::span::Record<'_>,
162        ctx: Context<'_, S>,
163    ) {
164        update_span_fields(id, values, &ctx);
165    }
166
167    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
168        self.send_entry(build_log_entry(event, &ctx));
169    }
170}