1use std::sync::Arc;
3
4use chrono::{DateTime, Utc};
5use error_stack::{Report, ResultExt};
6use itertools::Itertools;
7use sqlx::{QueryBuilder, SqliteExecutor, SqlitePool};
8use uuid::Uuid;
9
10use super::{
11 logging::{ProxyLogEntry, ProxyLogEvent},
12 DbProvider, ProxyDatabase,
13};
14use crate::{
15 config::{AliasConfig, AliasConfigProvider, ApiKeyConfig},
16 workflow_events::{
17 RunStartEvent, RunUpdateEvent, StepEventData, StepStartData, StepStateData, WorkflowEvent,
18 },
19 Error,
20};
21
22const SQLITE_MIGRATIONS: &[&'static str] = &[
23 include_str!("../../migrations/20240419_chronicle_proxy_init_sqlite.sql"),
24 include_str!("../../migrations/20240424_chronicle_proxy_data_tables_sqlite.sql"),
25 include_str!("../../migrations/20240625_chronicle_proxy_steps_sqlite.sql"),
26];
27
/// A [`ProxyDatabase`] implementation backed by a SQLite connection pool.
#[derive(Debug)]
pub struct SqliteDatabase {
    // Shared sqlx pool; cheap to clone, used by all queries.
    pool: SqlitePool,
}
33
34impl SqliteDatabase {
35 pub fn new(pool: SqlitePool) -> Arc<dyn ProxyDatabase> {
37 Arc::new(Self { pool })
38 }
39
40 async fn write_step_start(
41 &self,
42 tx: impl SqliteExecutor<'_>,
43 event: StepEventData<StepStartData>,
44 ) -> Result<(), sqlx::Error> {
45 let tags = if event.data.tags.is_empty() {
46 None
47 } else {
48 Some(event.data.tags.join("|"))
49 };
50
51 sqlx::query(
52 r##"
53 INSERT INTO chronicle_steps (
54 id, run_id, type, parent_step, name, input, status, tags, info, span_id, start_time
55 )
56 VALUES (
57 $1, $2, $3, $4, $5, $6, 'started', $7, $8, $9, $10
58 )
59 ON CONFLICT DO NOTHING;
60 "##,
61 )
62 .bind(event.step_id.to_string())
63 .bind(event.run_id.to_string())
64 .bind(event.data.typ)
65 .bind(event.data.parent_step.map(|s| s.to_string()))
66 .bind(event.data.name)
67 .bind(event.data.input)
68 .bind(tags)
69 .bind(event.data.info)
70 .bind(event.data.span_id)
71 .bind(event.time.unwrap_or_else(|| Utc::now()).timestamp())
72 .execute(tx)
73 .await?;
74 Ok(())
75 }
76
77 async fn write_step_end(
78 &self,
79 tx: impl SqliteExecutor<'_>,
80 step_id: Uuid,
81 run_id: Uuid,
82 status: &str,
83 output: serde_json::Value,
84 info: Option<serde_json::Value>,
85 timestamp: Option<DateTime<Utc>>,
86 ) -> Result<(), sqlx::Error> {
87 sqlx::query(
88 r##"
89 UPDATE chronicle_steps
90 SET status = $1,
91 output = $2,
92 info = CASE
93 WHEN NULLIF(info, 'null') IS NULL THEN $3
94 WHEN NULLIF($3, 'null') IS NULL THEN info
95 ELSE json_patch(info, $3)
96 END,
97 end_time = $4
98 WHERE run_id = $5 AND id = $6
99 "##,
100 )
101 .bind(status)
102 .bind(output)
103 .bind(info)
104 .bind(timestamp.unwrap_or_else(|| chrono::Utc::now()).timestamp())
105 .bind(run_id.to_string())
106 .bind(step_id.to_string())
107 .execute(tx)
108 .await?;
109 Ok(())
110 }
111
112 async fn write_step_status(
113 &self,
114 tx: impl SqliteExecutor<'_>,
115 event: StepEventData<StepStateData>,
116 ) -> Result<(), sqlx::Error> {
117 sqlx::query(
118 "UPDATE chronicle_steps
119 SET status=$1
120 WHERE run_id=$2 AND id=$3",
121 )
122 .bind(event.data.state)
123 .bind(event.run_id.to_string())
124 .bind(event.step_id.to_string())
125 .execute(tx)
126 .await?;
127
128 Ok(())
129 }
130
131 async fn write_run_start(
132 &self,
133 tx: impl SqliteExecutor<'_>,
134 event: RunStartEvent,
135 ) -> Result<(), sqlx::Error> {
136 let tags = if event.tags.is_empty() {
137 None
138 } else {
139 Some(event.tags.join("|"))
140 };
141
142 sqlx::query(
143 r##"
144 INSERT INTO chronicle_runs (
145 id, name, description, application, environment, input, status,
146 trace_id, span_id, tags, info, updated_at, created_at
147 )
148 VALUES (
149 $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $12
150 )
151 ON CONFLICT(id) DO UPDATE SET
152 status = EXCLUDED.status,
153 updated_at = EXCLUDED.updated_at;
154 "##,
155 )
156 .bind(event.id.to_string())
157 .bind(event.name)
158 .bind(event.description)
159 .bind(event.application)
160 .bind(event.environment)
161 .bind(event.input)
162 .bind(event.status.as_deref().unwrap_or("started"))
163 .bind(event.trace_id)
164 .bind(event.span_id)
165 .bind(tags)
166 .bind(event.info)
167 .bind(event.time.unwrap_or_else(|| Utc::now()).timestamp())
168 .execute(tx)
169 .await?;
170 Ok(())
171 }
172
173 async fn write_run_update(
174 &self,
175 tx: impl SqliteExecutor<'_>,
176 event: RunUpdateEvent,
177 ) -> Result<(), sqlx::Error> {
178 sqlx::query(
179 "UPDATE chronicle_runs
180 SET status = $1,
181 output = $2,
182 info = CASE
183 WHEN NULLIF(info, 'null') IS NULL THEN $3
184 WHEN NULLIF($3, 'null') IS NULL THEN info
185 ELSE json_patch(info, $3)
186 END,
187 updated_at = $4
188 WHERE id = $5",
189 )
190 .bind(event.status.as_deref().unwrap_or("finished"))
191 .bind(event.output)
192 .bind(event.info)
193 .bind(event.time.unwrap_or_else(|| Utc::now()).timestamp())
194 .bind(event.id.to_string())
195 .execute(tx)
196 .await?;
197 Ok(())
198 }
199
200 fn add_event_values(builder: &mut QueryBuilder<'_, sqlx::Sqlite>, item: ProxyLogEvent) {
201 let (rmodel, rprovider, rbody, rmeta) = match item
202 .response
203 .map(|r| (r.body.model.clone(), r.provider, r.body, r.info.meta))
204 {
205 Some((rmodel, rprovider, rbody, rmeta)) => {
206 (rmodel, Some(rprovider), Some(rbody), rmeta)
207 }
208 None => (None, None, None, None),
209 };
210
211 let model = rmodel
212 .or_else(|| item.request.as_ref().and_then(|r| r.model.clone()))
213 .unwrap_or_default();
214
215 let extra = item.options.metadata.extra.filter(|m| !m.is_empty());
216
217 let mut tuple = builder.separated(",");
218 tuple
219 .push_unseparated("(")
220 .push_bind(item.id.to_string())
223 .push_bind(item.event_type)
224 .push_bind(item.options.internal_metadata.organization_id)
225 .push_bind(item.options.internal_metadata.project_id)
226 .push_bind(item.options.internal_metadata.user_id)
227 .push_bind(sqlx::types::Json(item.request))
228 .push_bind(sqlx::types::Json(rbody))
229 .push_bind(sqlx::types::Json(item.error))
230 .push_bind(rprovider)
231 .push_bind(model)
232 .push_bind(item.options.metadata.application)
233 .push_bind(item.options.metadata.environment)
234 .push_bind(item.options.metadata.organization_id)
235 .push_bind(item.options.metadata.project_id)
236 .push_bind(item.options.metadata.user_id)
237 .push_bind(item.options.metadata.workflow_id)
238 .push_bind(item.options.metadata.workflow_name)
239 .push_bind(item.options.metadata.run_id.map(|u| u.to_string()))
240 .push_bind(item.options.metadata.step_id.map(|u| u.to_string()))
241 .push_bind(item.options.metadata.step_index.map(|i| i as i32))
242 .push_bind(item.options.metadata.prompt_id)
243 .push_bind(item.options.metadata.prompt_version.map(|i| i as i32))
244 .push_bind(sqlx::types::Json(extra))
245 .push_bind(rmeta)
246 .push_bind(item.num_retries.map(|n| n as i32))
247 .push_bind(item.was_rate_limited)
248 .push_bind(item.latency.map(|d| d.as_millis() as i64))
249 .push_bind(item.total_latency.map(|d| d.as_millis() as i64))
250 .push_bind(item.timestamp.timestamp())
251 .push_unseparated(")");
252 }
253}
254
255#[async_trait::async_trait]
256impl ProxyDatabase for SqliteDatabase {
257 async fn load_providers_from_database(
258 &self,
259 providers_table: &str,
260 ) -> Result<Vec<DbProvider>, Report<crate::Error>> {
261 let rows: Vec<DbProvider> = sqlx::query_as(&format!(
262 "SELECT name, label, url, api_key, format, headers, prefix, api_key_source
263 FROM {providers_table}"
264 ))
265 .fetch_all(&self.pool)
266 .await
267 .change_context(Error::LoadingDatabase)
268 .attach_printable("Failed to load providers from database")?;
269
270 Ok(rows)
271 }
272
273 async fn load_aliases_from_database(
279 &self,
280 alias_table: &str,
281 providers_table: &str,
282 ) -> Result<Vec<AliasConfig>, Report<Error>> {
283 #[derive(sqlx::FromRow)]
284 struct AliasRow {
285 id: i64,
286 name: String,
287 random_order: bool,
288 }
289
290 let aliases: Vec<AliasRow> = sqlx::query_as(&format!(
291 "SELECT id, name, random_order FROM {alias_table} ORDER BY id"
292 ))
293 .fetch_all(&self.pool)
294 .await
295 .change_context(Error::LoadingDatabase)?;
296
297 #[derive(sqlx::FromRow, Debug)]
298 struct DbAliasConfigProvider {
299 alias_id: i64,
300 provider: String,
301 model: String,
302 api_key_name: Option<String>,
303 }
304 let models: Vec<DbAliasConfigProvider> = sqlx::query_as(&format!(
305 "SELECT alias_id, provider, model, api_key_name
306 FROM {providers_table}
307 JOIN {alias_table} ON {alias_table}.id = {providers_table}.alias_id
308 ORDER BY alias_id, sort"
309 ))
310 .fetch_all(&self.pool)
311 .await
312 .change_context(Error::LoadingDatabase)?;
313
314 let mut output = Vec::with_capacity(aliases.len());
315 let mut aliases = aliases.into_iter();
316 let mut models = models.into_iter().peekable();
317
318 while let Some(alias) = aliases.next() {
319 let models = models
320 .by_ref()
321 .peeking_take_while(|model| model.alias_id == alias.id)
322 .map(|model| AliasConfigProvider {
323 provider: model.provider,
324 model: model.model,
325 api_key_name: model.api_key_name,
326 })
327 .collect();
328 output.push(AliasConfig {
329 name: alias.name,
330 random_order: alias.random_order,
331 models,
332 });
333 }
334
335 Ok(output)
336 }
337
338 async fn load_api_key_configs_from_database(
339 &self,
340 table: &str,
341 ) -> Result<Vec<ApiKeyConfig>, Report<Error>> {
342 let rows: Vec<ApiKeyConfig> =
343 sqlx::query_as(&format!("SELECT name, source, value FROM {table}"))
344 .fetch_all(&self.pool)
345 .await
346 .change_context(Error::LoadingDatabase)
347 .attach_printable("Failed to load API keys from database")?;
348
349 Ok(rows)
350 }
351
352 async fn write_log_batch(&self, entries: Vec<ProxyLogEntry>) -> Result<(), sqlx::Error> {
353 let mut event_builder = sqlx::QueryBuilder::new(super::logging::EVENT_INSERT_PREFIX);
354 let mut tx = self.pool.begin().await?;
355 let mut first_event = true;
356
357 for entry in entries.into_iter() {
358 tracing::debug!(?entry, "Processing event");
359 match entry {
360 ProxyLogEntry::Proxied(item) => {
361 if first_event {
362 first_event = false;
363 } else {
364 event_builder.push(",");
365 }
366
367 Self::add_event_values(&mut event_builder, *item);
368 }
369 ProxyLogEntry::Workflow(WorkflowEvent::Event(event)) => {
370 if first_event {
371 first_event = false;
372 } else {
373 event_builder.push(",");
374 }
375
376 let item = ProxyLogEvent::from_payload(Uuid::now_v7(), event);
377 Self::add_event_values(&mut event_builder, item);
378 }
379 ProxyLogEntry::Workflow(WorkflowEvent::StepStart(event)) => {
380 self.write_step_start(&mut *tx, event).await?;
381 }
382
383 ProxyLogEntry::Workflow(WorkflowEvent::StepEnd(event)) => {
384 self.write_step_end(
385 &mut *tx,
386 event.step_id,
387 event.run_id,
388 "finished",
389 event.data.output,
390 event.data.info,
391 event.time,
392 )
393 .await?;
394 }
395 ProxyLogEntry::Workflow(WorkflowEvent::StepState(event)) => {
396 self.write_step_status(&mut *tx, event).await?;
397 }
398 ProxyLogEntry::Workflow(WorkflowEvent::StepError(event)) => {
399 self.write_step_end(
400 &mut *tx,
401 event.step_id,
402 event.run_id,
403 "error",
404 event.data.error,
405 None,
406 event.time,
407 )
408 .await?;
409 }
410 ProxyLogEntry::Workflow(WorkflowEvent::RunStart(event)) => {
411 self.write_run_start(&mut *tx, event).await?;
412 }
413 ProxyLogEntry::Workflow(WorkflowEvent::RunUpdate(event)) => {
414 self.write_run_update(&mut *tx, event).await?;
415 }
416 }
417 }
418
419 if !first_event {
420 let query = event_builder.build();
421 query.execute(&mut *tx).await?;
422 }
423
424 tx.commit().await?;
425
426 Ok(())
427 }
428}
429
430pub async fn run_default_migrations(pool: &SqlitePool) -> Result<(), sqlx::Error> {
434 let mut tx = pool.begin().await?;
435 sqlx::raw_sql(
436 "CREATE TABLE IF NOT EXISTS chronicle_meta (
437 key text PRIMARY KEY,
438 value text
439 );",
440 )
441 .execute(&mut *tx)
442 .await?;
443
444 let migration_version = sqlx::query_scalar::<_, i32>(
445 "SELECT cast(value as int) FROM chronicle_meta WHERE key='migration_version'",
446 )
447 .fetch_optional(&mut *tx)
448 .await?
449 .unwrap_or(0) as usize;
450
451 tracing::info!("Migration version is {}", migration_version);
452
453 let start_migration = migration_version.min(SQLITE_MIGRATIONS.len());
454 for (i, migration) in SQLITE_MIGRATIONS[start_migration..].iter().enumerate() {
455 tracing::info!("Running migration {}", start_migration + i);
456 sqlx::raw_sql(migration).execute(&mut *tx).await?;
457 }
458
459 let new_version = SQLITE_MIGRATIONS.len();
460
461 sqlx::query("UPDATE chronicle_meta SET value=$1 WHERE key='migration_version'")
462 .bind(new_version.to_string())
463 .execute(&mut *tx)
464 .await?;
465
466 tx.commit().await?;
467 Ok(())
468}
469
#[cfg(test)]
mod test {
    use serde_json::json;
    use sqlx::Row;

    use crate::database::{
        sqlite::run_default_migrations,
        testing::{test_events, TEST_EVENT1_ID, TEST_RUN_ID, TEST_STEP1_ID, TEST_STEP2_ID},
    };

    /// End-to-end check that a batch of test events produces the expected
    /// rows in `chronicle_runs`, `chronicle_steps`, and `chronicle_events`.
    #[sqlx::test(migrations = false)]
    async fn test_database_writes(pool: sqlx::SqlitePool) {
        filigree::tracing_config::test::init();
        run_default_migrations(&pool).await.unwrap();

        let db = super::SqliteDatabase::new(pool.clone());

        db.write_log_batch(test_events())
            .await
            .expect("Writing events");

        // Verify the single run row, column by column.
        let runs = sqlx::query(
            "SELECT id, name, description, application, environment,
                input, output, status, trace_id, span_id,
                tags, info, updated_at, created_at
            FROM chronicle_runs",
        )
        .fetch_all(&pool)
        .await
        .expect("Fetching runs");
        assert_eq!(runs.len(), 1);
        let run = &runs[0];

        assert_eq!(run.get::<String, _>(0), TEST_RUN_ID.to_string(), "run id");
        assert_eq!(run.get::<String, _>(1), "test run", "name");
        assert_eq!(
            run.get::<Option<String>, _>(2),
            Some("test description".to_string()),
            "description"
        );
        assert_eq!(
            run.get::<Option<String>, _>(3),
            Some("test application".to_string()),
            "application"
        );
        assert_eq!(
            run.get::<Option<String>, _>(4),
            Some("test environment".to_string()),
            "environment"
        );
        assert_eq!(
            run.get::<Option<serde_json::Value>, _>(5),
            Some(json!({"query":"abc"})),
            "input"
        );
        assert_eq!(
            run.get::<Option<serde_json::Value>, _>(6),
            Some(json!({"result":"success"})),
            "output"
        );
        assert_eq!(run.get::<String, _>(7), "finished", "status");
        assert_eq!(
            run.get::<Option<String>, _>(8),
            Some("0123456789abcdef".to_string()),
            "trace_id"
        );
        assert_eq!(
            run.get::<Option<String>, _>(9),
            Some("12345678".to_string()),
            "span_id"
        );
        assert_eq!(
            run.get::<Option<String>, _>(10),
            Some("tag1|tag2".to_string()),
            "tags"
        );
        // Info from the update should have been json_patch-merged into the
        // info written at run start.
        assert_eq!(
            run.get::<Option<serde_json::Value>, _>(11),
            Some(json!({"info1":"value1","info2":"new_value", "info3":"value3"})),
            "info"
        );
        assert_eq!(run.get::<i64, _>(12), 5, "updated_at");
        assert_eq!(run.get::<i64, _>(13), 1, "created_at");

        // Verify both step rows.
        let steps = sqlx::query(
            "SELECT id, run_id, type, parent_step, name,
                input, output, status, span_id, tags, info, start_time, end_time
            FROM chronicle_steps",
        )
        .fetch_all(&pool)
        .await
        .expect("Fetching steps");
        assert_eq!(steps.len(), 2);

        let step1 = &steps[0];
        assert_eq!(step1.get::<String, _>(0), TEST_STEP1_ID.to_string(), "id");
        assert_eq!(step1.get::<String, _>(1), TEST_RUN_ID.to_string(), "run_id");
        assert_eq!(step1.get::<String, _>(2), "step_type", "type");
        assert_eq!(step1.get::<Option<String>, _>(3), None, "parent_step");
        assert_eq!(step1.get::<String, _>(4), "source_node1", "name");
        assert_eq!(
            step1.get::<Option<serde_json::Value>, _>(5),
            Some(json!({ "task_param": "value"})),
            "input"
        );
        assert_eq!(
            step1.get::<Option<serde_json::Value>, _>(6),
            Some(json!({ "result": "success" })),
            "output"
        );
        assert_eq!(step1.get::<String, _>(7), "finished", "status");
        assert_eq!(
            step1.get::<Option<String>, _>(8),
            Some("11111111".to_string()),
            "span_id"
        );
        assert_eq!(
            step1.get::<Option<String>, _>(9),
            Some("dag|node".to_string()),
            "tags"
        );
        assert_eq!(
            step1.get::<Option<serde_json::Value>, _>(10),
            Some(json!({"model": "a_model", "info3": "value3"})),
            "info"
        );
        assert_eq!(step1.get::<i64, _>(11), 2, "start_time");
        assert_eq!(step1.get::<i64, _>(12), 5, "end_time");

        // Step 2 ends in error: the error payload lands in `output`.
        let step2 = &steps[1];
        assert_eq!(step2.get::<String, _>(0), TEST_STEP2_ID.to_string(), "id");
        assert_eq!(step2.get::<String, _>(1), TEST_RUN_ID.to_string(), "run_id");
        assert_eq!(step2.get::<String, _>(2), "llm", "type");
        assert_eq!(
            step2.get::<Option<String>, _>(3),
            Some(TEST_STEP1_ID.to_string()),
            "parent_step"
        );
        assert_eq!(step2.get::<String, _>(4), "source_node2", "name");
        assert_eq!(
            step2.get::<Option<serde_json::Value>, _>(5),
            Some(json!({ "task_param2": "value"})),
            "input"
        );
        assert_eq!(
            step2.get::<Option<serde_json::Value>, _>(6),
            Some(json!({ "message": "an error" })),
            "output"
        );
        assert_eq!(step2.get::<String, _>(7), "error", "status");
        assert_eq!(
            step2.get::<Option<String>, _>(8),
            Some("22222222".to_string()),
            "span_id"
        );
        assert_eq!(step2.get::<Option<String>, _>(9), None, "tags");
        assert_eq!(
            step2.get::<Option<serde_json::Value>, _>(10),
            Some(json!({"model": "a_model"})),
            "info"
        );
        assert_eq!(step2.get::<i64, _>(11), 3, "start_time");
        assert_eq!(step2.get::<i64, _>(12), 5, "end_time");

        // Verify both event rows, ordered by creation time.
        let events = sqlx::query(
            "SELECT id, event_type, step_id, run_id, meta, error, created_at
            FROM chronicle_events
            ORDER BY created_at ASC",
        )
        .fetch_all(&pool)
        .await
        // Fixed copy-paste: this message previously read "Fetching steps".
        .expect("Fetching events");
        assert_eq!(events.len(), 2);

        let event = &events[0];
        assert_eq!(event.get::<String, _>(0), TEST_EVENT1_ID.to_string(), "id");
        assert_eq!(event.get::<String, _>(1), "query", "event_type");
        assert_eq!(
            event.get::<String, _>(2),
            TEST_STEP2_ID.to_string(),
            "step_id"
        );
        assert_eq!(event.get::<String, _>(3), TEST_RUN_ID.to_string(), "run_id");
        assert_eq!(
            event.get::<Option<serde_json::Value>, _>(4),
            Some(json!({"some_key": "some_value"})),
            "meta"
        );
        assert_eq!(
            event.get::<Option<serde_json::Value>, _>(5),
            Some(json!(null)),
            "error"
        );
        assert_eq!(event.get::<i64, _>(6), 4, "created_at");

        let event2 = &events[1];
        assert_eq!(event2.get::<String, _>(1), "an_event", "event_type");
        assert_eq!(
            event2.get::<String, _>(2),
            TEST_STEP2_ID.to_string(),
            "step_id"
        );
        assert_eq!(
            event2.get::<String, _>(3),
            TEST_RUN_ID.to_string(),
            "run_id"
        );
        assert_eq!(
            event2.get::<Option<serde_json::Value>, _>(4),
            Some(json!({"key": "value"})),
            "meta"
        );
        assert_eq!(
            event2.get::<Option<serde_json::Value>, _>(5),
            Some(json!({ "message": "something went wrong"})),
            "error"
        );
        assert_eq!(event2.get::<i64, _>(6), 5, "created_at");
    }
}
689}