chronicle_proxy/database/
sqlite.rs

1//! SQLite database logging support
2use std::sync::Arc;
3
4use chrono::{DateTime, Utc};
5use error_stack::{Report, ResultExt};
6use itertools::Itertools;
7use sqlx::{QueryBuilder, SqliteExecutor, SqlitePool};
8use uuid::Uuid;
9
10use super::{
11    logging::{ProxyLogEntry, ProxyLogEvent},
12    DbProvider, ProxyDatabase,
13};
14use crate::{
15    config::{AliasConfig, AliasConfigProvider, ApiKeyConfig},
16    workflow_events::{
17        RunStartEvent, RunUpdateEvent, StepEventData, StepStartData, StepStateData, WorkflowEvent,
18    },
19    Error,
20};
21
22const SQLITE_MIGRATIONS: &[&'static str] = &[
23    include_str!("../../migrations/20240419_chronicle_proxy_init_sqlite.sql"),
24    include_str!("../../migrations/20240424_chronicle_proxy_data_tables_sqlite.sql"),
25    include_str!("../../migrations/20240625_chronicle_proxy_steps_sqlite.sql"),
26];
27
/// Log events to an SQLite database
#[derive(Debug)]
pub struct SqliteDatabase {
    // Connection pool shared by all read and write operations.
    pool: SqlitePool,
}
33
34impl SqliteDatabase {
35    /// Create a new [SqliteDatabase]
36    pub fn new(pool: SqlitePool) -> Arc<dyn ProxyDatabase> {
37        Arc::new(Self { pool })
38    }
39
40    async fn write_step_start(
41        &self,
42        tx: impl SqliteExecutor<'_>,
43        event: StepEventData<StepStartData>,
44    ) -> Result<(), sqlx::Error> {
45        let tags = if event.data.tags.is_empty() {
46            None
47        } else {
48            Some(event.data.tags.join("|"))
49        };
50
51        sqlx::query(
52            r##"
53            INSERT INTO chronicle_steps (
54                id, run_id, type, parent_step, name, input, status, tags, info, span_id, start_time
55            )
56            VALUES (
57                $1, $2, $3, $4, $5, $6, 'started', $7, $8, $9, $10
58            )
59            ON CONFLICT DO NOTHING;
60            "##,
61        )
62        .bind(event.step_id.to_string())
63        .bind(event.run_id.to_string())
64        .bind(event.data.typ)
65        .bind(event.data.parent_step.map(|s| s.to_string()))
66        .bind(event.data.name)
67        .bind(event.data.input)
68        .bind(tags)
69        .bind(event.data.info)
70        .bind(event.data.span_id)
71        .bind(event.time.unwrap_or_else(|| Utc::now()).timestamp())
72        .execute(tx)
73        .await?;
74        Ok(())
75    }
76
77    async fn write_step_end(
78        &self,
79        tx: impl SqliteExecutor<'_>,
80        step_id: Uuid,
81        run_id: Uuid,
82        status: &str,
83        output: serde_json::Value,
84        info: Option<serde_json::Value>,
85        timestamp: Option<DateTime<Utc>>,
86    ) -> Result<(), sqlx::Error> {
87        sqlx::query(
88            r##"
89            UPDATE chronicle_steps
90            SET status = $1,
91                output = $2,
92                info = CASE
93                    WHEN NULLIF(info, 'null') IS NULL THEN $3
94                    WHEN NULLIF($3, 'null') IS NULL THEN info
95                    ELSE json_patch(info, $3)
96                    END,
97                end_time = $4
98            WHERE run_id = $5 AND id = $6
99        "##,
100        )
101        .bind(status)
102        .bind(output)
103        .bind(info)
104        .bind(timestamp.unwrap_or_else(|| chrono::Utc::now()).timestamp())
105        .bind(run_id.to_string())
106        .bind(step_id.to_string())
107        .execute(tx)
108        .await?;
109        Ok(())
110    }
111
112    async fn write_step_status(
113        &self,
114        tx: impl SqliteExecutor<'_>,
115        event: StepEventData<StepStateData>,
116    ) -> Result<(), sqlx::Error> {
117        sqlx::query(
118            "UPDATE chronicle_steps
119                    SET status=$1
120                    WHERE run_id=$2 AND id=$3",
121        )
122        .bind(event.data.state)
123        .bind(event.run_id.to_string())
124        .bind(event.step_id.to_string())
125        .execute(tx)
126        .await?;
127
128        Ok(())
129    }
130
131    async fn write_run_start(
132        &self,
133        tx: impl SqliteExecutor<'_>,
134        event: RunStartEvent,
135    ) -> Result<(), sqlx::Error> {
136        let tags = if event.tags.is_empty() {
137            None
138        } else {
139            Some(event.tags.join("|"))
140        };
141
142        sqlx::query(
143            r##"
144            INSERT INTO chronicle_runs (
145                id, name, description, application, environment, input, status,
146                    trace_id, span_id, tags, info, updated_at, created_at
147            )
148            VALUES (
149                $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $12
150            )
151            ON CONFLICT(id) DO UPDATE SET
152                status = EXCLUDED.status,
153                updated_at = EXCLUDED.updated_at;
154            "##,
155        )
156        .bind(event.id.to_string())
157        .bind(event.name)
158        .bind(event.description)
159        .bind(event.application)
160        .bind(event.environment)
161        .bind(event.input)
162        .bind(event.status.as_deref().unwrap_or("started"))
163        .bind(event.trace_id)
164        .bind(event.span_id)
165        .bind(tags)
166        .bind(event.info)
167        .bind(event.time.unwrap_or_else(|| Utc::now()).timestamp())
168        .execute(tx)
169        .await?;
170        Ok(())
171    }
172
173    async fn write_run_update(
174        &self,
175        tx: impl SqliteExecutor<'_>,
176        event: RunUpdateEvent,
177    ) -> Result<(), sqlx::Error> {
178        sqlx::query(
179            "UPDATE chronicle_runs
180            SET status = $1,
181                output = $2,
182                info = CASE
183                    WHEN NULLIF(info, 'null') IS NULL THEN $3
184                    WHEN NULLIF($3, 'null') IS NULL THEN info
185                    ELSE json_patch(info, $3)
186                    END,
187                updated_at = $4
188            WHERE id = $5",
189        )
190        .bind(event.status.as_deref().unwrap_or("finished"))
191        .bind(event.output)
192        .bind(event.info)
193        .bind(event.time.unwrap_or_else(|| Utc::now()).timestamp())
194        .bind(event.id.to_string())
195        .execute(tx)
196        .await?;
197        Ok(())
198    }
199
200    fn add_event_values(builder: &mut QueryBuilder<'_, sqlx::Sqlite>, item: ProxyLogEvent) {
201        let (rmodel, rprovider, rbody, rmeta) = match item
202            .response
203            .map(|r| (r.body.model.clone(), r.provider, r.body, r.info.meta))
204        {
205            Some((rmodel, rprovider, rbody, rmeta)) => {
206                (rmodel, Some(rprovider), Some(rbody), rmeta)
207            }
208            None => (None, None, None, None),
209        };
210
211        let model = rmodel
212            .or_else(|| item.request.as_ref().and_then(|r| r.model.clone()))
213            .unwrap_or_default();
214
215        let extra = item.options.metadata.extra.filter(|m| !m.is_empty());
216
217        let mut tuple = builder.separated(",");
218        tuple
219            .push_unseparated("(")
220            // sqlx encodes UUIDs as binary blobs by default with Sqlite, which is often nice
221            // but not what we want here.
222            .push_bind(item.id.to_string())
223            .push_bind(item.event_type)
224            .push_bind(item.options.internal_metadata.organization_id)
225            .push_bind(item.options.internal_metadata.project_id)
226            .push_bind(item.options.internal_metadata.user_id)
227            .push_bind(sqlx::types::Json(item.request))
228            .push_bind(sqlx::types::Json(rbody))
229            .push_bind(sqlx::types::Json(item.error))
230            .push_bind(rprovider)
231            .push_bind(model)
232            .push_bind(item.options.metadata.application)
233            .push_bind(item.options.metadata.environment)
234            .push_bind(item.options.metadata.organization_id)
235            .push_bind(item.options.metadata.project_id)
236            .push_bind(item.options.metadata.user_id)
237            .push_bind(item.options.metadata.workflow_id)
238            .push_bind(item.options.metadata.workflow_name)
239            .push_bind(item.options.metadata.run_id.map(|u| u.to_string()))
240            .push_bind(item.options.metadata.step_id.map(|u| u.to_string()))
241            .push_bind(item.options.metadata.step_index.map(|i| i as i32))
242            .push_bind(item.options.metadata.prompt_id)
243            .push_bind(item.options.metadata.prompt_version.map(|i| i as i32))
244            .push_bind(sqlx::types::Json(extra))
245            .push_bind(rmeta)
246            .push_bind(item.num_retries.map(|n| n as i32))
247            .push_bind(item.was_rate_limited)
248            .push_bind(item.latency.map(|d| d.as_millis() as i64))
249            .push_bind(item.total_latency.map(|d| d.as_millis() as i64))
250            .push_bind(item.timestamp.timestamp())
251            .push_unseparated(")");
252    }
253}
254
255#[async_trait::async_trait]
256impl ProxyDatabase for SqliteDatabase {
257    async fn load_providers_from_database(
258        &self,
259        providers_table: &str,
260    ) -> Result<Vec<DbProvider>, Report<crate::Error>> {
261        let rows: Vec<DbProvider> = sqlx::query_as(&format!(
262            "SELECT name, label, url, api_key, format, headers, prefix, api_key_source
263        FROM {providers_table}"
264        ))
265        .fetch_all(&self.pool)
266        .await
267        .change_context(Error::LoadingDatabase)
268        .attach_printable("Failed to load providers from database")?;
269
270        Ok(rows)
271    }
272
273    // SQLite's JSON support sucks in 3.44 which sqlx currently includes, so it's not
274    // possible to just do this via json_group_array(json_object(...)) right now since
275    // the values in the array are strings of JSON instead of normal JSON. Next version
276    // of sqlx will include 3.45 which works better and we can make this look more like
277    // the Postgres version.
278    async fn load_aliases_from_database(
279        &self,
280        alias_table: &str,
281        providers_table: &str,
282    ) -> Result<Vec<AliasConfig>, Report<Error>> {
283        #[derive(sqlx::FromRow)]
284        struct AliasRow {
285            id: i64,
286            name: String,
287            random_order: bool,
288        }
289
290        let aliases: Vec<AliasRow> = sqlx::query_as(&format!(
291            "SELECT id, name, random_order FROM {alias_table} ORDER BY id"
292        ))
293        .fetch_all(&self.pool)
294        .await
295        .change_context(Error::LoadingDatabase)?;
296
297        #[derive(sqlx::FromRow, Debug)]
298        struct DbAliasConfigProvider {
299            alias_id: i64,
300            provider: String,
301            model: String,
302            api_key_name: Option<String>,
303        }
304        let models: Vec<DbAliasConfigProvider> = sqlx::query_as(&format!(
305            "SELECT alias_id, provider, model, api_key_name
306        FROM {providers_table}
307        JOIN {alias_table} ON {alias_table}.id = {providers_table}.alias_id
308        ORDER BY alias_id, sort"
309        ))
310        .fetch_all(&self.pool)
311        .await
312        .change_context(Error::LoadingDatabase)?;
313
314        let mut output = Vec::with_capacity(aliases.len());
315        let mut aliases = aliases.into_iter();
316        let mut models = models.into_iter().peekable();
317
318        while let Some(alias) = aliases.next() {
319            let models = models
320                .by_ref()
321                .peeking_take_while(|model| model.alias_id == alias.id)
322                .map(|model| AliasConfigProvider {
323                    provider: model.provider,
324                    model: model.model,
325                    api_key_name: model.api_key_name,
326                })
327                .collect();
328            output.push(AliasConfig {
329                name: alias.name,
330                random_order: alias.random_order,
331                models,
332            });
333        }
334
335        Ok(output)
336    }
337
338    async fn load_api_key_configs_from_database(
339        &self,
340        table: &str,
341    ) -> Result<Vec<ApiKeyConfig>, Report<Error>> {
342        let rows: Vec<ApiKeyConfig> =
343            sqlx::query_as(&format!("SELECT name, source, value FROM {table}"))
344                .fetch_all(&self.pool)
345                .await
346                .change_context(Error::LoadingDatabase)
347                .attach_printable("Failed to load API keys from database")?;
348
349        Ok(rows)
350    }
351
352    async fn write_log_batch(&self, entries: Vec<ProxyLogEntry>) -> Result<(), sqlx::Error> {
353        let mut event_builder = sqlx::QueryBuilder::new(super::logging::EVENT_INSERT_PREFIX);
354        let mut tx = self.pool.begin().await?;
355        let mut first_event = true;
356
357        for entry in entries.into_iter() {
358            tracing::debug!(?entry, "Processing event");
359            match entry {
360                ProxyLogEntry::Proxied(item) => {
361                    if first_event {
362                        first_event = false;
363                    } else {
364                        event_builder.push(",");
365                    }
366
367                    Self::add_event_values(&mut event_builder, *item);
368                }
369                ProxyLogEntry::Workflow(WorkflowEvent::Event(event)) => {
370                    if first_event {
371                        first_event = false;
372                    } else {
373                        event_builder.push(",");
374                    }
375
376                    let item = ProxyLogEvent::from_payload(Uuid::now_v7(), event);
377                    Self::add_event_values(&mut event_builder, item);
378                }
379                ProxyLogEntry::Workflow(WorkflowEvent::StepStart(event)) => {
380                    self.write_step_start(&mut *tx, event).await?;
381                }
382
383                ProxyLogEntry::Workflow(WorkflowEvent::StepEnd(event)) => {
384                    self.write_step_end(
385                        &mut *tx,
386                        event.step_id,
387                        event.run_id,
388                        "finished",
389                        event.data.output,
390                        event.data.info,
391                        event.time,
392                    )
393                    .await?;
394                }
395                ProxyLogEntry::Workflow(WorkflowEvent::StepState(event)) => {
396                    self.write_step_status(&mut *tx, event).await?;
397                }
398                ProxyLogEntry::Workflow(WorkflowEvent::StepError(event)) => {
399                    self.write_step_end(
400                        &mut *tx,
401                        event.step_id,
402                        event.run_id,
403                        "error",
404                        event.data.error,
405                        None,
406                        event.time,
407                    )
408                    .await?;
409                }
410                ProxyLogEntry::Workflow(WorkflowEvent::RunStart(event)) => {
411                    self.write_run_start(&mut *tx, event).await?;
412                }
413                ProxyLogEntry::Workflow(WorkflowEvent::RunUpdate(event)) => {
414                    self.write_run_update(&mut *tx, event).await?;
415                }
416            }
417        }
418
419        if !first_event {
420            let query = event_builder.build();
421            query.execute(&mut *tx).await?;
422        }
423
424        tx.commit().await?;
425
426        Ok(())
427    }
428}
429
430/// Run database migrations specific to the proxy. These migrations are designed for a simple setup with
431/// single-tenant use. You may want to add multi-tenant features or partitioning, and can integrate
432/// the files from the `migrations` directory into your project to accomplish that.
433pub async fn run_default_migrations(pool: &SqlitePool) -> Result<(), sqlx::Error> {
434    let mut tx = pool.begin().await?;
435    sqlx::raw_sql(
436        "CREATE TABLE IF NOT EXISTS chronicle_meta (
437          key text PRIMARY KEY,
438          value text
439        );",
440    )
441    .execute(&mut *tx)
442    .await?;
443
444    let migration_version = sqlx::query_scalar::<_, i32>(
445        "SELECT cast(value as int) FROM chronicle_meta WHERE key='migration_version'",
446    )
447    .fetch_optional(&mut *tx)
448    .await?
449    .unwrap_or(0) as usize;
450
451    tracing::info!("Migration version is {}", migration_version);
452
453    let start_migration = migration_version.min(SQLITE_MIGRATIONS.len());
454    for (i, migration) in SQLITE_MIGRATIONS[start_migration..].iter().enumerate() {
455        tracing::info!("Running migration {}", start_migration + i);
456        sqlx::raw_sql(migration).execute(&mut *tx).await?;
457    }
458
459    let new_version = SQLITE_MIGRATIONS.len();
460
461    sqlx::query("UPDATE chronicle_meta SET value=$1 WHERE key='migration_version'")
462        .bind(new_version.to_string())
463        .execute(&mut *tx)
464        .await?;
465
466    tx.commit().await?;
467    Ok(())
468}
469
#[cfg(test)]
mod test {
    use serde_json::json;
    use sqlx::Row;

    use crate::database::{
        sqlite::run_default_migrations,
        testing::{test_events, TEST_EVENT1_ID, TEST_RUN_ID, TEST_STEP1_ID, TEST_STEP2_ID},
    };

    // End-to-end check: run migrations against a fresh SQLite database, write
    // the shared fixture batch from `testing::test_events()`, then assert on
    // the resulting rows in chronicle_runs, chronicle_steps, and
    // chronicle_events. The expected timestamps (1..=5) and field values come
    // from that fixture — see `database::testing`.
    #[sqlx::test(migrations = false)]
    async fn test_database_writes(pool: sqlx::SqlitePool) {
        filigree::tracing_config::test::init();
        run_default_migrations(&pool).await.unwrap();

        let db = super::SqliteDatabase::new(pool.clone());

        db.write_log_batch(test_events())
            .await
            .expect("Writing events");

        // --- Run row: the fixture creates exactly one run then finishes it. ---
        let runs = sqlx::query(
            "SELECT id, name, description, application, environment,
                input, output, status, trace_id, span_id,
                tags, info, updated_at, created_at
                FROM chronicle_runs",
        )
        .fetch_all(&pool)
        .await
        .expect("Fetching runs");
        assert_eq!(runs.len(), 1);
        let run = &runs[0];

        // Columns are checked by position in the SELECT list above.
        assert_eq!(run.get::<String, _>(0), TEST_RUN_ID.to_string(), "run id");
        assert_eq!(run.get::<String, _>(1), "test run", "name");
        assert_eq!(
            run.get::<Option<String>, _>(2),
            Some("test description".to_string()),
            "description"
        );
        assert_eq!(
            run.get::<Option<String>, _>(3),
            Some("test application".to_string()),
            "application"
        );
        assert_eq!(
            run.get::<Option<String>, _>(4),
            Some("test environment".to_string()),
            "environment"
        );
        assert_eq!(
            run.get::<Option<serde_json::Value>, _>(5),
            Some(json!({"query":"abc"})),
            "input"
        );
        assert_eq!(
            run.get::<Option<serde_json::Value>, _>(6),
            Some(json!({"result":"success"})),
            "output"
        );
        assert_eq!(run.get::<String, _>(7), "finished", "status");
        assert_eq!(
            run.get::<Option<String>, _>(8),
            Some("0123456789abcdef".to_string()),
            "trace_id"
        );
        assert_eq!(
            run.get::<Option<String>, _>(9),
            Some("12345678".to_string()),
            "span_id"
        );
        // Tags are stored as a single `|`-joined string in SQLite.
        assert_eq!(
            run.get::<Option<String>, _>(10),
            Some("tag1|tag2".to_string()),
            "tags"
        );
        // info2 overwritten and info3 added by the run update's json_patch merge.
        assert_eq!(
            run.get::<Option<serde_json::Value>, _>(11),
            Some(json!({"info1":"value1","info2":"new_value", "info3":"value3"})),
            "info"
        );
        assert_eq!(run.get::<i64, _>(12), 5, "updated_at");
        assert_eq!(run.get::<i64, _>(13), 1, "created_at");

        // --- Step rows: one finished step and one errored step. ---
        let steps = sqlx::query(
            "SELECT id, run_id, type, parent_step, name,
                input, output, status, span_id, tags, info, start_time, end_time
                FROM chronicle_steps",
        )
        .fetch_all(&pool)
        .await
        .expect("Fetching steps");
        assert_eq!(steps.len(), 2);

        let step1 = &steps[0];
        assert_eq!(step1.get::<String, _>(0), TEST_STEP1_ID.to_string(), "id");
        assert_eq!(step1.get::<String, _>(1), TEST_RUN_ID.to_string(), "run_id");
        assert_eq!(step1.get::<String, _>(2), "step_type", "type");
        assert_eq!(step1.get::<Option<String>, _>(3), None, "parent_step");
        assert_eq!(step1.get::<String, _>(4), "source_node1", "name");
        assert_eq!(
            step1.get::<Option<serde_json::Value>, _>(5),
            Some(json!({ "task_param": "value"})),
            "input"
        );
        assert_eq!(
            step1.get::<Option<serde_json::Value>, _>(6),
            Some(json!({ "result": "success" })),
            "output"
        );
        assert_eq!(step1.get::<String, _>(7), "finished", "status");
        assert_eq!(
            step1.get::<Option<String>, _>(8),
            Some("11111111".to_string()),
            "span_id"
        );
        assert_eq!(
            step1.get::<Option<String>, _>(9),
            Some("dag|node".to_string()),
            "tags"
        );
        // info3 merged in by the step end event's json_patch.
        assert_eq!(
            step1.get::<Option<serde_json::Value>, _>(10),
            Some(json!({"model": "a_model", "info3": "value3"})),
            "info"
        );
        assert_eq!(step1.get::<i64, _>(11), 2, "start_time");
        assert_eq!(step1.get::<i64, _>(12), 5, "end_time");

        // Step 2 is a child of step 1 that ended with an error; the error
        // payload is stored in the output column.
        let step2 = &steps[1];
        assert_eq!(step2.get::<String, _>(0), TEST_STEP2_ID.to_string(), "id");
        assert_eq!(step2.get::<String, _>(1), TEST_RUN_ID.to_string(), "run_id");
        assert_eq!(step2.get::<String, _>(2), "llm", "type");
        assert_eq!(
            step2.get::<Option<String>, _>(3),
            Some(TEST_STEP1_ID.to_string()),
            "parent_step"
        );
        assert_eq!(step2.get::<String, _>(4), "source_node2", "name");
        assert_eq!(
            step2.get::<Option<serde_json::Value>, _>(5),
            Some(json!({ "task_param2": "value"})),
            "input"
        );
        assert_eq!(
            step2.get::<Option<serde_json::Value>, _>(6),
            Some(json!({ "message": "an error" })),
            "output"
        );
        assert_eq!(step2.get::<String, _>(7), "error", "status");
        assert_eq!(
            step2.get::<Option<String>, _>(8),
            Some("22222222".to_string()),
            "span_id"
        );
        assert_eq!(step2.get::<Option<String>, _>(9), None, "tags");
        assert_eq!(
            step2.get::<Option<serde_json::Value>, _>(10),
            Some(json!({"model": "a_model"})),
            "info"
        );
        assert_eq!(step2.get::<i64, _>(11), 3, "start_time");
        assert_eq!(step2.get::<i64, _>(12), 5, "end_time");

        // --- Event rows: written by the batched multi-row INSERT. ---
        let events = sqlx::query(
            "SELECT id, event_type, step_id, run_id, meta, error, created_at
                FROM chronicle_events
                ORDER BY created_at ASC",
        )
        .fetch_all(&pool)
        .await
        .expect("Fetching steps");
        assert_eq!(events.len(), 2);

        let event = &events[0];
        assert_eq!(event.get::<String, _>(0), TEST_EVENT1_ID.to_string(), "id");
        assert_eq!(event.get::<String, _>(1), "query", "event_type");
        assert_eq!(
            event.get::<String, _>(2),
            TEST_STEP2_ID.to_string(),
            "step_id"
        );
        assert_eq!(event.get::<String, _>(3), TEST_RUN_ID.to_string(), "run_id");
        assert_eq!(
            event.get::<Option<serde_json::Value>, _>(4),
            Some(json!({"some_key": "some_value"})),
            "meta"
        );
        // The error column holds JSON `null` (not SQL NULL) for this event,
        // since it is serialized through sqlx::types::Json.
        assert_eq!(
            event.get::<Option<serde_json::Value>, _>(5),
            Some(json!(null)),
            "error"
        );
        assert_eq!(event.get::<i64, _>(6), 4, "created_at");

        // The second event (no fixed fixture id — it gets a fresh now_v7 uuid)
        // carries both metadata and an error payload.
        let event2 = &events[1];
        assert_eq!(event2.get::<String, _>(1), "an_event", "event_type");
        assert_eq!(
            event2.get::<String, _>(2),
            TEST_STEP2_ID.to_string(),
            "step_id"
        );
        assert_eq!(
            event2.get::<String, _>(3),
            TEST_RUN_ID.to_string(),
            "run_id"
        );
        assert_eq!(
            event2.get::<Option<serde_json::Value>, _>(4),
            Some(json!({"key": "value"})),
            "meta"
        );
        assert_eq!(
            event2.get::<Option<serde_json::Value>, _>(5),
            Some(json!({ "message": "something went wrong"})),
            "error"
        );
        assert_eq!(event2.get::<i64, _>(6), 5, "created_at");
    }
}
689}