use std::sync::Arc;

use chrono::{DateTime, Utc};
use error_stack::{Report, ResultExt};
use sqlx::{PgExecutor, PgPool, QueryBuilder};
use uuid::Uuid;

use super::{
    logging::{ProxyLogEntry, ProxyLogEvent},
    DbProvider, ProxyDatabase,
};
use crate::{
    config::{AliasConfig, ApiKeyConfig},
    workflow_events::{
        RunStartEvent, RunUpdateEvent, StepEventData, StepStartData, StepStateData, WorkflowEvent,
    },
    Error,
};

const POSTGRESQL_MIGRATIONS: &[&str] = &[
    include_str!("../../migrations/20240419_chronicle_proxy_init_postgresql.sql"),
    include_str!("../../migrations/20240424_chronicle_proxy_data_tables_postgresql.sql"),
    include_str!("../../migrations/20240625_chronicle_proxy_steps_postgresql.sql"),
];

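/// A [`ProxyDatabase`] implementation backed by a PostgreSQL connection pool.
///
/// Writes notify the per-run `pg_notify` channel `chronicle_run:{run_id}`. A consumer
/// could watch for updates like this (a sketch, not an API provided by this module):
///
/// ```ignore
/// let mut listener = sqlx::postgres::PgListener::connect_with(&pool).await?;
/// listener.listen(&format!("chronicle_run:{run_id}")).await?;
/// let notification = listener.recv().await?;
/// ```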
#[derive(Debug)]
pub struct PostgresDatabase {
    pool: PgPool,
}

impl PostgresDatabase {
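    /// Wrap an existing [`PgPool`] and return it as a shareable [`ProxyDatabase`] handle.
    /// Run [`run_default_migrations`] first if the Chronicle tables may not exist yet.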
    pub fn new(pool: PgPool) -> Arc<dyn ProxyDatabase> {
        Arc::new(Self { pool })
    }

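    /// Record a step as started. An empty tag list is stored as NULL, and duplicate
    /// delivery of the same event is ignored via `ON CONFLICT DO NOTHING`.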
    async fn write_step_start(
        &self,
        tx: impl PgExecutor<'_>,
        event: StepEventData<StepStartData>,
    ) -> Result<(), sqlx::Error> {
        let tags = if event.data.tags.is_empty() {
            None
        } else {
            Some(event.data.tags)
        };

        sqlx::query(
            r##"
            INSERT INTO chronicle_steps (
                id, run_id, type, parent_step, name, input, status, tags, info, span_id, start_time
            )
            VALUES (
                $1, $2, $3, $4, $5, $6, 'started', $7, $8, $9, $10
            )
            ON CONFLICT DO NOTHING;
            "##,
        )
        .bind(event.step_id)
        .bind(event.run_id)
        .bind(event.data.typ)
        .bind(event.data.parent_step)
        .bind(event.data.name)
        .bind(event.data.input)
        .bind(tags)
        .bind(event.data.info)
        .bind(event.data.span_id)
        .bind(event.time.unwrap_or_else(Utc::now))
        .execute(tx)
        .await?;
        Ok(())
    }

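    /// Mark a step as finished or errored, setting its output and end time. The `info`
    /// column is merged: a NULL (or JSON `null`) side yields the other side, and two
    /// objects are combined with `||`, the incoming keys winning on conflict.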
    async fn write_step_end(
        &self,
        tx: impl PgExecutor<'_>,
        step_id: Uuid,
        run_id: Uuid,
        status: &str,
        output: serde_json::Value,
        info: Option<serde_json::Value>,
        timestamp: Option<DateTime<Utc>>,
    ) -> Result<(), sqlx::Error> {
        sqlx::query(
            r##"
            UPDATE chronicle_steps
            SET status = $1,
                output = $2,
                info = CASE
                    WHEN NULLIF(info, 'null'::jsonb) IS NULL THEN $3
                    WHEN NULLIF($3, 'null'::jsonb) IS NULL THEN info
                    ELSE info || $3
                END,
                end_time = $4
            WHERE run_id = $5 AND id = $6
            "##,
        )
        .bind(status)
        .bind(output)
        .bind(info)
        .bind(timestamp.unwrap_or_else(Utc::now))
        .bind(run_id)
        .bind(step_id)
        .execute(tx)
        .await?;
        Ok(())
    }

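    /// Update only the status of an existing step.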
    async fn write_step_status(
        &self,
        tx: impl PgExecutor<'_>,
        event: StepEventData<StepStateData>,
    ) -> Result<(), sqlx::Error> {
        sqlx::query(
            "UPDATE chronicle_steps
            SET status = $1
            WHERE run_id = $2 AND id = $3",
        )
        .bind(event.data.state)
        .bind(event.run_id)
        .bind(event.step_id)
        .execute(tx)
        .await?;

        Ok(())
    }

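    /// Insert a run when it starts. If the run already exists, only `status` and
    /// `updated_at` are refreshed.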
    async fn write_run_start(
        &self,
        tx: impl PgExecutor<'_>,
        event: RunStartEvent,
    ) -> Result<(), sqlx::Error> {
        let tags = if event.tags.is_empty() {
            None
        } else {
            Some(event.tags)
        };

        sqlx::query(
            r##"
            INSERT INTO chronicle_runs (
                id, name, description, application, environment, input, status,
                trace_id, span_id, tags, info, updated_at, created_at
            )
            VALUES (
                $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $12
            )
            ON CONFLICT(id) DO UPDATE SET
                status = EXCLUDED.status,
                updated_at = EXCLUDED.updated_at;
            "##,
        )
        .bind(event.id)
        .bind(event.name)
        .bind(event.description)
        .bind(event.application)
        .bind(event.environment)
        .bind(event.input)
        .bind(event.status.as_deref().unwrap_or("started"))
        .bind(event.trace_id)
        .bind(event.span_id)
        .bind(tags)
        .bind(event.info)
        .bind(event.time.unwrap_or_else(Utc::now))
        .execute(tx)
        .await?;
        Ok(())
    }

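    /// Apply a run update, merging `info` the same way as [`Self::write_step_end`].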
    async fn write_run_update(
        &self,
        tx: impl PgExecutor<'_>,
        event: RunUpdateEvent,
    ) -> Result<(), sqlx::Error> {
        sqlx::query(
            "UPDATE chronicle_runs
            SET status = $1,
                output = $2,
                info = CASE
                    WHEN NULLIF(info, 'null'::jsonb) IS NULL THEN $3
                    WHEN NULLIF($3, 'null'::jsonb) IS NULL THEN info
                    ELSE info || $3
                END,
                updated_at = $4
            WHERE id = $5",
        )
        .bind(event.status.as_deref().unwrap_or("finished"))
        .bind(event.output)
        .bind(event.info)
        .bind(event.time.unwrap_or_else(Utc::now))
        .bind(event.id)
        .execute(tx)
        .await?;
        Ok(())
    }

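    /// Append one parenthesized VALUES tuple for `item` to `builder`. The bind order
    /// here must stay in sync with the column list in `EVENT_INSERT_PREFIX`.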
    fn add_event_values(builder: &mut QueryBuilder<'_, sqlx::Postgres>, item: ProxyLogEvent) {
        let (rmodel, rprovider, rbody, rmeta) = match item
            .response
            .map(|r| (r.body.model.clone(), r.provider, r.body, r.info.meta))
        {
            Some((rmodel, rprovider, rbody, rmeta)) => {
                (rmodel, Some(rprovider), Some(rbody), rmeta)
            }
            None => (None, None, None, None),
        };

        let model = rmodel
            .or_else(|| item.request.as_ref().and_then(|r| r.model.clone()))
            .unwrap_or_default();

        let extra = item.options.metadata.extra.filter(|m| !m.is_empty());

        let mut tuple = builder.separated(",");
        tuple
            .push_unseparated("(")
            .push_bind(item.id)
            .push_bind(item.event_type)
            .push_bind(item.options.internal_metadata.organization_id)
            .push_bind(item.options.internal_metadata.project_id)
            .push_bind(item.options.internal_metadata.user_id)
            .push_bind(sqlx::types::Json(item.request))
            .push_bind(sqlx::types::Json(rbody))
            .push_bind(sqlx::types::Json(item.error))
            .push_bind(rprovider)
            .push_bind(model)
            .push_bind(item.options.metadata.application)
            .push_bind(item.options.metadata.environment)
            .push_bind(item.options.metadata.organization_id)
            .push_bind(item.options.metadata.project_id)
            .push_bind(item.options.metadata.user_id)
            .push_bind(item.options.metadata.workflow_id)
            .push_bind(item.options.metadata.workflow_name)
            .push_bind(item.options.metadata.run_id)
            .push_bind(item.options.metadata.step_id)
            .push_bind(item.options.metadata.step_index.map(|i| i as i32))
            .push_bind(item.options.metadata.prompt_id)
            .push_bind(item.options.metadata.prompt_version.map(|i| i as i32))
            .push_bind(sqlx::types::Json(extra))
            .push_bind(rmeta)
            .push_bind(item.num_retries.map(|n| n as i32))
            .push_bind(item.was_rate_limited)
            .push_bind(item.latency.map(|d| d.as_millis() as i64))
            .push_bind(item.total_latency.map(|d| d.as_millis() as i64))
            .push_bind(item.timestamp)
            .push_unseparated(")");
    }
}

#[async_trait::async_trait]
impl ProxyDatabase for PostgresDatabase {
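    /// Load custom provider definitions from `providers_table`.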
    async fn load_providers_from_database(
        &self,
        providers_table: &str,
    ) -> Result<Vec<DbProvider>, Report<crate::Error>> {
        let rows: Vec<DbProvider> = sqlx::query_as(&format!(
            "SELECT name, label, url, api_key, format, headers, prefix, api_key_source
            FROM {providers_table}"
        ))
        .fetch_all(&self.pool)
        .await
        .change_context(Error::LoadingDatabase)
        .attach_printable("Failed to load providers from database")?;

        Ok(rows)
    }

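    /// Load model aliases from `alias_table`, aggregating each alias's provider/model
    /// entries from `providers_table` in their configured sort order.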
    async fn load_aliases_from_database(
        &self,
        alias_table: &str,
        providers_table: &str,
    ) -> Result<Vec<AliasConfig>, Report<Error>> {
        let results = sqlx::query_as::<_, AliasConfig>(&format!(
            "SELECT name, random_order,
                array_agg(jsonb_build_object(
                    'provider', ap.provider,
                    'model', ap.model,
                    'api_key_name', ap.api_key_name
                ) ORDER BY ap.sort) AS models
            FROM {alias_table} al
            JOIN {providers_table} ap ON ap.alias_id = al.id
            GROUP BY al.id",
        ))
        .fetch_all(&self.pool)
        .await
        .change_context(Error::LoadingDatabase)
        .attach_printable("Failed to load aliases from database")?;

        Ok(results)
    }

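    /// Load API key configurations from `table`.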
    async fn load_api_key_configs_from_database(
        &self,
        table: &str,
    ) -> Result<Vec<ApiKeyConfig>, Report<Error>> {
        let rows: Vec<ApiKeyConfig> =
            sqlx::query_as(&format!("SELECT name, source, value FROM {table}"))
                .fetch_all(&self.pool)
                .await
                .change_context(Error::LoadingDatabase)
                .attach_printable("Failed to load API keys from database")?;

        Ok(rows)
    }

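    /// Write a batch of log entries in one transaction: proxied and generic workflow
    /// events are collected into a single multi-row INSERT, run and step events are
    /// written as they are encountered, and every run touched by the batch is notified
    /// through `pg_notify`.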
    async fn write_log_batch(&self, entries: Vec<ProxyLogEntry>) -> Result<(), sqlx::Error> {
        let mut event_builder = sqlx::QueryBuilder::new(super::logging::EVENT_INSERT_PREFIX);
        let mut tx = self.pool.begin().await?;
        let mut first_event = true;

        let mut run_ids = ahash::AHashSet::new();

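        // Walk the batch once: buffer event rows for a single INSERT and write
        // run/step changes directly to the transaction.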
        for entry in entries.into_iter() {
            match entry {
                ProxyLogEntry::Proxied(item) => {
                    if first_event {
                        first_event = false;
                    } else {
                        event_builder.push(",");
                    }

                    if let Some(run_id) = item.options.metadata.run_id {
                        run_ids.insert(run_id);
                    }

                    Self::add_event_values(&mut event_builder, *item);
                }
                ProxyLogEntry::Workflow(WorkflowEvent::Event(event)) => {
                    if first_event {
                        first_event = false;
                    } else {
                        event_builder.push(",");
                    }

                    run_ids.insert(event.run_id);

                    let item = ProxyLogEvent::from_payload(Uuid::now_v7(), event);
                    Self::add_event_values(&mut event_builder, item);
                }
                ProxyLogEntry::Workflow(WorkflowEvent::StepStart(event)) => {
                    run_ids.insert(event.run_id);
                    self.write_step_start(&mut *tx, event).await?;
                }
                ProxyLogEntry::Workflow(WorkflowEvent::StepEnd(event)) => {
                    run_ids.insert(event.run_id);
                    self.write_step_end(
                        &mut *tx,
                        event.step_id,
                        event.run_id,
                        "finished",
                        event.data.output,
                        event.data.info,
                        event.time,
                    )
                    .await?;
                }
                ProxyLogEntry::Workflow(WorkflowEvent::StepState(event)) => {
                    run_ids.insert(event.run_id);
                    self.write_step_status(&mut *tx, event).await?;
                }
                ProxyLogEntry::Workflow(WorkflowEvent::StepError(event)) => {
                    run_ids.insert(event.run_id);
                    self.write_step_end(
                        &mut *tx,
                        event.step_id,
                        event.run_id,
                        "error",
                        event.data.error,
                        None,
                        event.time,
                    )
                    .await?;
                }
                ProxyLogEntry::Workflow(WorkflowEvent::RunStart(event)) => {
                    run_ids.insert(event.id);
                    self.write_run_start(&mut *tx, event).await?;
                }
                ProxyLogEntry::Workflow(WorkflowEvent::RunUpdate(event)) => {
                    run_ids.insert(event.id);
                    self.write_run_update(&mut *tx, event).await?;
                }
            }
        }

        if !first_event {
            let query = event_builder.build();
            query.execute(&mut *tx).await?;
        }

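        // Notify the per-run channels; pg_notify queues the notifications for
        // delivery when the transaction commits.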
        if !run_ids.is_empty() {
            let mut notify_builder =
                QueryBuilder::new("SELECT pg_notify(ch, '') FROM UNNEST(ARRAY[");
            let mut sep = notify_builder.separated(',');
            for run_id in run_ids {
                sep.push_bind(format!("chronicle_run:{run_id}"));
            }
            sep.push_unseparated("]) AS ch");

            notify_builder
                .build()
                .persistent(false)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }
}

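/// Create the `chronicle_meta` bookkeeping table if needed, apply any migrations that
/// have not yet run, and record the new migration version.
///
/// A minimal startup sketch; the `DATABASE_URL` environment variable is an assumption
/// here, not something this module requires:
///
/// ```ignore
/// let pool = sqlx::PgPool::connect(&std::env::var("DATABASE_URL").unwrap()).await?;
/// run_default_migrations(&pool).await?;
/// let db = PostgresDatabase::new(pool);
/// ```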
pub async fn run_default_migrations(pool: &PgPool) -> Result<(), sqlx::Error> {
    let mut tx = pool.begin().await?;

    sqlx::raw_sql(
        "CREATE TABLE IF NOT EXISTS chronicle_meta (
            key text PRIMARY KEY,
            value jsonb
        );",
    )
    .execute(&mut *tx)
    .await?;

    let migration_version = sqlx::query_scalar::<_, i32>(
        "SELECT cast(value as int) FROM chronicle_meta WHERE key='migration_version'",
    )
    .fetch_optional(&mut *tx)
    .await?
    .unwrap_or(0) as usize;

    tracing::info!("Migration version is {}", migration_version);

    let start_migration = migration_version.min(POSTGRESQL_MIGRATIONS.len());
    for (i, migration) in POSTGRESQL_MIGRATIONS[start_migration..].iter().enumerate() {
        tracing::info!("Running migration {}", start_migration + i);
        sqlx::raw_sql(migration).execute(&mut *tx).await?;
    }

    let new_version = POSTGRESQL_MIGRATIONS.len();

    // Upsert so the version row is created on the first run and updated on later runs.
    sqlx::query(
        "INSERT INTO chronicle_meta (key, value) VALUES ('migration_version', $1::jsonb)
        ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value",
    )
    .bind(new_version.to_string())
    .execute(&mut *tx)
    .await?;

    tx.commit().await?;
    Ok(())
}

#[cfg(test)]
mod test {
    use std::time::Duration;

    use chrono::{DateTime, TimeZone, Utc};
    use serde_json::json;
    use sqlx::{postgres::PgListener, PgPool, Row};
    use uuid::Uuid;

    use crate::database::{
        postgres::run_default_migrations,
        testing::{test_events, TEST_EVENT1_ID, TEST_RUN_ID, TEST_STEP1_ID, TEST_STEP2_ID},
    };

    fn dt(secs: i64) -> DateTime<Utc> {
        Utc.timestamp_opt(secs, 0).unwrap()
    }

    #[sqlx::test(migrations = false)]
    async fn test_database_writes(pool: PgPool) {
        filigree::tracing_config::test::init();
        run_default_migrations(&pool).await.unwrap();

        let mut listener = PgListener::connect_with(&pool).await.unwrap();
        let listen_task = tokio::task::spawn(async move {
            listener
                .listen(&format!("chronicle_run:{TEST_RUN_ID}"))
                .await
                .unwrap();
            listener.recv().await
        });

        let db = super::PostgresDatabase::new(pool.clone());

        db.write_log_batch(test_events())
            .await
            .expect("Writing events");

        let runs = sqlx::query(
            "SELECT id, name, description, application, environment,
                input, output, status, trace_id, span_id,
                tags, info, updated_at, created_at
            FROM chronicle_runs",
        )
        .fetch_all(&pool)
        .await
        .expect("Fetching runs");
        assert_eq!(runs.len(), 1);
        let run = &runs[0];

        assert_eq!(run.get::<Uuid, _>(0), TEST_RUN_ID, "run id");
        assert_eq!(run.get::<String, _>(1), "test run", "name");
        assert_eq!(
            run.get::<Option<String>, _>(2),
            Some("test description".to_string()),
            "description"
        );
        assert_eq!(
            run.get::<Option<String>, _>(3),
            Some("test application".to_string()),
            "application"
        );
        assert_eq!(
            run.get::<Option<String>, _>(4),
            Some("test environment".to_string()),
            "environment"
        );
        assert_eq!(
            run.get::<Option<serde_json::Value>, _>(5),
            Some(json!({"query":"abc"})),
            "input"
        );
        assert_eq!(
            run.get::<Option<serde_json::Value>, _>(6),
            Some(json!({"result":"success"})),
            "output"
        );
        assert_eq!(run.get::<String, _>(7), "finished", "status");
        assert_eq!(
            run.get::<Option<String>, _>(8),
            Some("0123456789abcdef".to_string()),
            "trace_id"
        );
        assert_eq!(
            run.get::<Option<String>, _>(9),
            Some("12345678".to_string()),
            "span_id"
        );
        assert_eq!(
            run.get::<Vec<String>, _>(10),
            vec!["tag1".to_string(), "tag2".to_string()],
            "tags"
        );
        assert_eq!(
            run.get::<Option<serde_json::Value>, _>(11),
            Some(json!({"info1":"value1","info2":"new_value", "info3":"value3"})),
            "info"
        );
        assert_eq!(run.get::<DateTime<Utc>, _>(12), dt(5), "updated_at");
        assert_eq!(run.get::<DateTime<Utc>, _>(13), dt(1), "created_at");

        let steps = sqlx::query(
            "SELECT id, run_id, type, parent_step, name,
                input, output, status, span_id, tags, info, start_time, end_time
            FROM chronicle_steps
            ORDER BY start_time ASC",
        )
        .fetch_all(&pool)
        .await
        .expect("Fetching steps");
        assert_eq!(steps.len(), 2);

        let step1 = &steps[0];
        assert_eq!(step1.get::<Uuid, _>(0), TEST_STEP1_ID, "id");
        assert_eq!(step1.get::<Uuid, _>(1), TEST_RUN_ID, "run_id");
        assert_eq!(step1.get::<String, _>(2), "step_type", "type");
        assert_eq!(step1.get::<Option<String>, _>(3), None, "parent_step");
        assert_eq!(step1.get::<String, _>(4), "source_node1", "name");
        assert_eq!(
            step1.get::<Option<serde_json::Value>, _>(5),
            Some(json!({ "task_param": "value"})),
            "input"
        );
        assert_eq!(
            step1.get::<Option<serde_json::Value>, _>(6),
            Some(json!({ "result": "success" })),
            "output"
        );
        assert_eq!(step1.get::<String, _>(7), "finished", "status");
        assert_eq!(
            step1.get::<Option<String>, _>(8),
            Some("11111111".to_string()),
            "span_id"
        );
        assert_eq!(
            step1.get::<Option<Vec<String>>, _>(9),
            Some(vec!["dag".to_string(), "node".to_string()]),
            "tags"
        );
        assert_eq!(
            step1.get::<Option<serde_json::Value>, _>(10),
            Some(json!({"model": "a_model", "info3": "value3"})),
            "info"
        );
        assert_eq!(step1.get::<DateTime<Utc>, _>(11), dt(2), "start_time");
        assert_eq!(step1.get::<DateTime<Utc>, _>(12), dt(5), "end_time");

        let step2 = &steps[1];
        assert_eq!(step2.get::<Uuid, _>(0), TEST_STEP2_ID, "id");
        assert_eq!(step2.get::<Uuid, _>(1), TEST_RUN_ID, "run_id");
        assert_eq!(step2.get::<String, _>(2), "llm", "type");
        assert_eq!(
            step2.get::<Option<Uuid>, _>(3),
            Some(TEST_STEP1_ID),
            "parent_step"
        );
        assert_eq!(step2.get::<String, _>(4), "source_node2", "name");
        assert_eq!(
            step2.get::<Option<serde_json::Value>, _>(5),
            Some(json!({ "task_param2": "value"})),
            "input"
        );
        assert_eq!(
            step2.get::<Option<serde_json::Value>, _>(6),
            Some(json!({ "message": "an error" })),
            "output"
        );
        assert_eq!(step2.get::<String, _>(7), "error", "status");
        assert_eq!(
            step2.get::<Option<String>, _>(8),
            Some("22222222".to_string()),
            "span_id"
        );
        assert_eq!(step2.get::<Option<Vec<String>>, _>(9), None, "tags");
        assert_eq!(
            step2.get::<Option<serde_json::Value>, _>(10),
            Some(json!({"model": "a_model"})),
            "info"
        );
        assert_eq!(step2.get::<DateTime<Utc>, _>(11), dt(3), "start_time");
        assert_eq!(step2.get::<DateTime<Utc>, _>(12), dt(5), "end_time");

        let events = sqlx::query(
            "SELECT id, event_type, step_id, run_id, meta, error, created_at
            FROM chronicle_events
            ORDER BY created_at ASC",
        )
        .fetch_all(&pool)
        .await
        .expect("Fetching events");
        assert_eq!(events.len(), 2);

        let event = &events[0];
        assert_eq!(event.get::<Uuid, _>(0), TEST_EVENT1_ID, "id");
        assert_eq!(event.get::<String, _>(1), "query", "event_type");
        assert_eq!(event.get::<Uuid, _>(2), TEST_STEP2_ID, "step_id");
        assert_eq!(event.get::<Uuid, _>(3), TEST_RUN_ID, "run_id");
        assert_eq!(
            event.get::<Option<serde_json::Value>, _>(4),
            Some(json!({"some_key": "some_value"})),
            "meta"
        );
        assert_eq!(
            event.get::<Option<serde_json::Value>, _>(5),
            Some(json!(null)),
            "error"
        );
        assert_eq!(event.get::<DateTime<Utc>, _>(6), dt(4), "created_at");

        let event2 = &events[1];
        assert_eq!(event2.get::<String, _>(1), "an_event", "event_type");
        assert_eq!(event2.get::<Uuid, _>(2), TEST_STEP2_ID, "step_id");
        assert_eq!(event2.get::<Uuid, _>(3), TEST_RUN_ID, "run_id");
        assert_eq!(
            event2.get::<Option<serde_json::Value>, _>(4),
            Some(json!({"key": "value"})),
            "meta"
        );
        assert_eq!(
            event2.get::<Option<serde_json::Value>, _>(5),
            Some(json!({ "message": "something went wrong"})),
            "error"
        );
        assert_eq!(event2.get::<DateTime<Utc>, _>(6), dt(5), "created_at");

        let listened = tokio::time::timeout(Duration::from_secs(1), listen_task)
            .await
            .expect("Listener not notified")
            .expect("Listener panicked")
            .expect("Listener encountered an error");
        assert_eq!(listened.channel(), format!("chronicle_run:{TEST_RUN_ID}"));
        assert_eq!(listened.payload(), "", "payload");
    }
}
704}