planetary_db/
postgres.rs

1//! Implementation of a TES database using PostgreSQL.
2
3use std::time::Duration;
4
5use anyhow::Context;
6use anyhow::Result;
7use anyhow::anyhow;
8use chrono::DateTime;
9use chrono::Utc;
10use diesel::Connection;
11use diesel_async::AsyncConnection;
12use diesel_async::AsyncPgConnection;
13use diesel_async::pooled_connection::AsyncDieselConnectionManager;
14use diesel_async::pooled_connection::deadpool::Pool;
15use diesel_async::scoped_futures::ScopedFutureExt;
16use diesel_migrations::EmbeddedMigrations;
17use diesel_migrations::HarnessWithOutput;
18use diesel_migrations::MigrationHarness;
19use diesel_migrations::embed_migrations;
20use futures::future::BoxFuture;
21use secrecy::ExposeSecret;
22use secrecy::SecretString;
23use tes::v1::types::requests::DEFAULT_PAGE_SIZE;
24use tes::v1::types::requests::GetTaskParams;
25use tes::v1::types::requests::ListTasksParams;
26use tes::v1::types::requests::Task as TesTask;
27use tes::v1::types::requests::View;
28use tes::v1::types::responses::ExecutorLog;
29use tes::v1::types::responses::OutputFile;
30use tes::v1::types::responses::Task;
31use tes::v1::types::responses::TaskLog;
32use tes::v1::types::responses::TaskResponse;
33use tes::v1::types::task::Input;
34use tes::v1::types::task::Output;
35use tes::v1::types::task::State;
36use tracing::debug;
37use tracing::info;
38
39use super::Database;
40use super::DatabaseResult;
41use super::TaskIo;
42use crate::TerminatedContainer;
43
44pub(crate) mod models;
45#[allow(clippy::missing_docs_in_private_items)]
46pub(crate) mod schema;
47
48/// Used to embed the migrations into the binary so they can be applied at
49/// runtime.
50const MIGRATIONS: EmbeddedMigrations = embed_migrations!("src/postgres/migrations");
51
/// How long the pool-maintenance task sleeps between passes that evict
/// connections which have been idle for too long.
const POOL_RETAIN_INTERVAL: Duration = Duration::from_secs(30);

/// How long a pooled connection may go unused before it is evicted.
///
/// This is measured from the connection's last use, not from when it was
/// opened.
const MAX_CONNECTION_AGE: Duration = Duration::from_secs(60);

/// Upper bound on the number of connections held by the connection pool.
///
/// The limit is fixed because every pod maintains its own pool.
const MAX_POOL_SIZE: usize = 10;
64
/// Zips two iterators of possibly different lengths.
///
/// Pairs keep being produced until *both* inputs are exhausted. Once the
/// shorter input runs out, `Default::default()` is substituted for its
/// missing items.
fn zip_longest<A, B>(a: A, b: B) -> impl Iterator<Item = (A::Item, B::Item)>
where
    A: IntoIterator,
    A::Item: Default,
    B: IntoIterator,
    B::Item: Default,
{
    let mut left = a.into_iter();
    let mut right = b.into_iter();

    std::iter::from_fn(move || {
        // Always advance both sides so the two iterators stay in lockstep
        let lhs = left.next();
        let rhs = right.next();

        if lhs.is_none() && rhs.is_none() {
            None
        } else {
            Some((lhs.unwrap_or_default(), rhs.unwrap_or_default()))
        }
    })
}
82
83/// Formats the Postgres database URL.
84pub fn format_database_url(
85    user: &str,
86    password: &SecretString,
87    host: &str,
88    port: i32,
89    database_name: &str,
90    app_name: &str,
91) -> String {
92    format!(
93        "postgres://{user}:{password}@{host}:{port}/{database_name}?application_name={app_name}",
94        password = password.expose_secret(),
95    )
96}
97
98/// Represents a PostgreSQL database error.
99#[derive(Debug, thiserror::Error)]
100pub enum Error {
101    /// The provided TES task identifier was not found.
102    #[error("task `{0}` was not found")]
103    TaskNotFound(String),
104    /// A diesel connection pool error occurred.
105    #[error(transparent)]
106    Pool(#[from] diesel_async::pooled_connection::deadpool::PoolError),
107    /// A diesel error occurred.
108    #[error(transparent)]
109    Diesel(#[from] diesel::result::Error),
110}
111
112/// Converts a task model into a TES task.
113fn into_task<T, C>(task: T, containers: Vec<C>) -> Task
114where
115    T: Into<(Task, Vec<OutputFile>, Vec<String>)>,
116    C: Into<ExecutorLog>,
117{
118    let (mut task, outputs, system_logs) = task.into();
119    let executor_logs: Vec<_> = containers.into_iter().map(Into::into).collect();
120
121    if !outputs.is_empty() || !executor_logs.is_empty() || !system_logs.is_empty() {
122        let start_time = executor_logs.first().and_then(|e| e.start_time);
123        let end_time = executor_logs.last().and_then(|e| e.end_time);
124
125        task.logs = Some(vec![TaskLog {
126            logs: executor_logs,
127            metadata: None,
128            start_time,
129            end_time,
130            outputs,
131            system_logs: if system_logs.is_empty() {
132                None
133            } else {
134                Some(system_logs)
135            },
136        }]);
137    }
138
139    task
140}
141
142/// Implements a planetary database using a PostgreSQL server.
143pub struct PostgresDatabase {
144    /// The database URL.
145    url: SecretString,
146    /// The database connection pool.
147    pool: Pool<AsyncPgConnection>,
148}
149
150impl PostgresDatabase {
151    /// Constructs a new PostgreSQL database with the given database URL.
152    pub fn new(url: SecretString) -> Result<Self> {
153        let config = AsyncDieselConnectionManager::new(url.expose_secret());
154        debug!("creating database connection pool with {MAX_POOL_SIZE} slots");
155
156        let pool = Pool::builder(config)
157            .max_size(MAX_POOL_SIZE)
158            .build()
159            .context("failed to initialize PostgreSQL connection pool")?;
160
161        let p = pool.clone();
162
163        // Span a task that is responsible for removing connections from the pool that
164        // exceed
165        tokio::spawn(async move {
166            loop {
167                tokio::time::sleep(POOL_RETAIN_INTERVAL).await;
168
169                let res = p.retain(|_, metrics| metrics.last_used() < MAX_CONNECTION_AGE);
170
171                debug!(
172                    "removed {removed} and retained {retained} connections(s) from the database \
173                     connection pool",
174                    removed = res.removed.len(),
175                    retained = res.retained
176                );
177            }
178        });
179
180        Ok(Self { url, pool })
181    }
182
183    /// Runs any pending migrations for the database.
184    pub async fn run_pending_migrations(&self) -> Result<()> {
185        struct Writer;
186        impl std::io::Write for Writer {
187            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
188                let buf = String::from_utf8_lossy(buf);
189                info!("{buf}", buf = buf.trim_end());
190                Ok(buf.len())
191            }
192
193            fn flush(&mut self) -> std::io::Result<()> {
194                Ok(())
195            }
196        }
197
198        // Required to use a direct connection here as `diesel-migration` doesn't
199        // support async
200        let mut conn = diesel::pg::PgConnection::establish(self.url.expose_secret())?;
201        HarnessWithOutput::new(&mut conn, std::io::LineWriter::new(Writer))
202            .run_pending_migrations(MIGRATIONS)
203            .map_err(|e| anyhow!("failed to run pending database migrations: {e}"))?;
204
205        Ok(())
206    }
207}
208
209#[async_trait::async_trait]
210impl Database for PostgresDatabase {
211    async fn insert_task(&self, task: &TesTask) -> DatabaseResult<String> {
212        use diesel_async::RunQueryDsl;
213
214        let task = models::NewTask::new(task);
215
216        // Insert the task
217        let mut conn = self.pool.get().await.map_err(Error::Pool)?;
218        diesel::insert_into(schema::tasks::table)
219            .values(&task)
220            .execute(&mut conn)
221            .await
222            .map_err(Error::Diesel)?;
223
224        Ok(task.tes_id)
225    }
226
227    async fn get_task(&self, tes_id: &str, params: GetTaskParams) -> DatabaseResult<TaskResponse> {
228        use diesel::*;
229        use diesel_async::RunQueryDsl;
230
231        let mut conn = self.pool.get().await.map_err(Error::Pool)?;
232
233        match params.view {
234            View::Minimal => Ok(TaskResponse::Minimal(
235                schema::tasks::table
236                    .select(models::MinimalTask::as_select())
237                    .filter(schema::tasks::tes_id.eq(tes_id))
238                    .first(&mut conn)
239                    .await
240                    .optional()
241                    .map_err(Error::Diesel)?
242                    .ok_or_else(|| Error::TaskNotFound(tes_id.to_string()))?
243                    .into(),
244            )),
245            View::Basic => {
246                let task = schema::tasks::table
247                    .select(models::BasicTask::as_select())
248                    .filter(schema::tasks::tes_id.eq(tes_id))
249                    .first(&mut conn)
250                    .await
251                    .optional()
252                    .map_err(Error::Diesel)?
253                    .ok_or_else(|| Error::TaskNotFound(tes_id.to_string()))?;
254
255                let containers = models::BasicContainer::belonging_to(&task)
256                    .select(models::BasicContainer::as_select())
257                    .filter(schema::containers::executor_index.is_not_null())
258                    .order_by(schema::containers::executor_index)
259                    .load(&mut conn)
260                    .await
261                    .map_err(Error::Diesel)?;
262
263                Ok(TaskResponse::Basic(into_task(task, containers)))
264            }
265            View::Full => {
266                let task = schema::tasks::table
267                    .select(models::FullTask::as_select())
268                    .filter(schema::tasks::tes_id.eq(tes_id))
269                    .first(&mut conn)
270                    .await
271                    .optional()
272                    .map_err(Error::Diesel)?
273                    .ok_or_else(|| Error::TaskNotFound(tes_id.to_string()))?;
274
275                let containers = models::FullContainer::belonging_to(&task)
276                    .select(models::FullContainer::as_select())
277                    .filter(schema::containers::executor_index.is_not_null())
278                    .order_by(schema::containers::executor_index)
279                    .load(&mut conn)
280                    .await
281                    .map_err(Error::Diesel)?;
282
283                Ok(TaskResponse::Full(into_task(task, containers)))
284            }
285        }
286    }
287
288    async fn get_tasks(
289        &self,
290        params: ListTasksParams,
291    ) -> DatabaseResult<(Vec<TaskResponse>, Option<String>)> {
292        use diesel::*;
293        use diesel_async::RunQueryDsl;
294
295        let mut query = schema::tasks::table.into_boxed();
296
297        // Add the name prefix to the query
298        if let Some(prefix) = &params.name_prefix {
299            query = query.filter(schema::tasks::name.like(format!("{prefix}%")));
300        }
301
302        // Add the state to the query
303        if let Some(state) = params.state {
304            query = query.filter(schema::tasks::state.eq(models::TaskState::from(state)));
305        }
306
307        // Add the page token to the query
308        let offset = if let Some(page_token) = params.page_token {
309            let offset: i64 = page_token
310                .parse()
311                .map_err(|_| super::Error::InvalidPageToken(page_token.clone()))?;
312
313            if offset < 0 {
314                return Err(super::Error::InvalidPageToken(page_token));
315            }
316
317            query = query.offset(offset);
318            offset
319        } else {
320            0
321        };
322
323        // Add the tags to the query
324        for (k, v) in zip_longest(
325            params.tag_keys.unwrap_or_default(),
326            params.tag_values.unwrap_or_default(),
327        ) {
328            if !v.is_empty() {
329                query = query.filter(
330                    schema::tasks::tags.contains(models::Json(models::TagFilter::new(k, v))),
331                );
332            } else {
333                query = query.filter(schema::tasks::tags.has_key(k));
334            }
335        }
336
337        // Add the page size to the query and order by the id
338        let page_size = params.page_size.unwrap_or(DEFAULT_PAGE_SIZE);
339        query = query.limit(page_size as i64).order_by(schema::tasks::id);
340
341        let mut conn = self.pool.get().await.map_err(Error::Pool)?;
342
343        match params.view.unwrap_or_default() {
344            View::Minimal => {
345                let tasks = query
346                    .select(models::MinimalTask::as_select())
347                    .load(&mut conn)
348                    .await
349                    .map_err(Error::Diesel)?;
350
351                let token = if tasks.len() < page_size as usize {
352                    None
353                } else {
354                    Some((offset as usize + tasks.len()).to_string())
355                };
356
357                Ok((
358                    tasks
359                        .into_iter()
360                        .map(|t| TaskResponse::Minimal(t.into()))
361                        .collect(),
362                    token,
363                ))
364            }
365            View::Basic => {
366                let tasks: Vec<_> = query
367                    .select(models::BasicTask::as_select())
368                    .load(&mut conn)
369                    .await
370                    .map_err(Error::Diesel)?
371                    .into_iter()
372                    .collect();
373
374                let token = if tasks.len() < page_size as usize {
375                    None
376                } else {
377                    Some((offset as usize + tasks.len()).to_string())
378                };
379
380                Ok((
381                    models::BasicContainer::belonging_to(&tasks)
382                        .select(models::BasicContainer::as_select())
383                        .filter(schema::containers::executor_index.is_not_null())
384                        .order_by(schema::containers::executor_index)
385                        .load(&mut conn)
386                        .await
387                        .map_err(Error::Diesel)?
388                        .grouped_by(&tasks)
389                        .into_iter()
390                        .zip(tasks)
391                        .map(|(containers, task)| TaskResponse::Basic(into_task(task, containers)))
392                        .collect(),
393                    token,
394                ))
395            }
396            View::Full => {
397                let tasks: Vec<_> = query
398                    .select(models::FullTask::as_select())
399                    .load(&mut conn)
400                    .await
401                    .map_err(Error::Diesel)?
402                    .into_iter()
403                    .collect();
404
405                let token = if tasks.len() < page_size as usize {
406                    None
407                } else {
408                    Some((offset as usize + tasks.len()).to_string())
409                };
410
411                Ok((
412                    models::FullContainer::belonging_to(&tasks)
413                        .select(models::FullContainer::as_select())
414                        .filter(schema::containers::executor_index.is_not_null())
415                        .order_by(schema::containers::executor_index)
416                        .load(&mut conn)
417                        .await
418                        .map_err(Error::Diesel)?
419                        .grouped_by(&tasks)
420                        .into_iter()
421                        .zip(tasks)
422                        .map(|(containers, task)| TaskResponse::Full(into_task(task, containers)))
423                        .collect(),
424                    token,
425                ))
426            }
427        }
428    }
429
430    async fn get_task_io(&self, tes_id: &str) -> DatabaseResult<TaskIo> {
431        use diesel::*;
432        use diesel_async::RunQueryDsl;
433
434        let mut conn = self.pool.get().await.map_err(Error::Pool)?;
435
436        let (inputs, outputs): (
437            Option<models::Json<Vec<Input>>>,
438            Option<models::Json<Vec<Output>>>,
439        ) = schema::tasks::table
440            .select((schema::tasks::inputs, schema::tasks::outputs))
441            .filter(schema::tasks::tes_id.eq(tes_id))
442            .first(&mut conn)
443            .await
444            .optional()
445            .map_err(Error::Diesel)?
446            .ok_or_else(|| Error::TaskNotFound(tes_id.to_string()))?;
447
448        Ok(TaskIo {
449            inputs: inputs.map(models::Json::into_inner).unwrap_or_default(),
450            outputs: outputs.map(models::Json::into_inner).unwrap_or_default(),
451        })
452    }
453
454    async fn get_in_progress_tasks(&self, before: DateTime<Utc>) -> DatabaseResult<Vec<String>> {
455        use diesel::pg::sql_types::Timestamptz;
456        use diesel::*;
457        use diesel_async::RunQueryDsl;
458        use models::TaskState;
459
460        let mut conn = self.pool.get().await.map_err(Error::Pool)?;
461
462        Ok(schema::tasks::table
463            .select(schema::tasks::tes_id)
464            .filter(
465                schema::tasks::state
466                    .eq_any(&[
467                        TaskState::Unknown,
468                        TaskState::Queued,
469                        TaskState::Initializing,
470                        TaskState::Running,
471                    ])
472                    .and(schema::tasks::creation_time.le(before.into_sql::<Timestamptz>())),
473            )
474            .get_results(&mut conn)
475            .await
476            .map_err(Error::Diesel)?)
477    }
478
479    async fn update_task_state<'a>(
480        &self,
481        tes_id: &str,
482        state: State,
483        messages: &[&str],
484        containers: Option<BoxFuture<'a, Result<Vec<TerminatedContainer<'a>>>>>,
485    ) -> DatabaseResult<bool> {
486        use diesel::pg::sql_types::Array;
487        use diesel::sql_types::Text;
488        use diesel::*;
489        use diesel_async::RunQueryDsl;
490        use models::TaskState;
491
492        /// Helper for getting the id for an updated task.
493        /// This is required because `sql_query` returns data by name, not
494        /// index.
495        #[derive(QueryableByName)]
496        #[diesel(table_name = schema::tasks)]
497        #[diesel(check_for_backend(diesel::pg::Pg))]
498        struct UpdatedTask {
499            /// The id of the updated task.
500            id: i32,
501        }
502
503        // Determine the allowed previous state for the task.
504        let previous: &[TaskState] = match state {
505            // Unknown has no previous state and paused isn't supported
506            State::Unknown | State::Paused => {
507                return Ok(false);
508            }
509            // Unknown -> Queued
510            State::Queued => &[TaskState::Unknown],
511            // [Unknown | Queued] -> Initializing
512            State::Initializing => &[TaskState::Unknown, TaskState::Queued],
513            // [Unknown | Queued | Initializing] -> Running
514            State::Running => &[
515                TaskState::Unknown,
516                TaskState::Queued,
517                TaskState::Initializing,
518            ],
519            // [Unknown | Queued | Initializing | Running] -> [Complete | ExecutorError]
520            State::Complete | State::ExecutorError => &[
521                TaskState::Unknown,
522                TaskState::Queued,
523                TaskState::Initializing,
524                TaskState::Running,
525            ],
526            // [Unknown | Queued | Initializing | Running] -> [SystemError | Canceling]
527            State::SystemError | State::Canceling => &[
528                TaskState::Unknown,
529                TaskState::Queued,
530                TaskState::Initializing,
531                TaskState::Running,
532            ],
533            // Canceling -> Canceled
534            State::Canceled => &[TaskState::Canceling],
535            // [Unknown | Queued | Initializing | Running] -> Preempted
536            State::Preempted => &[
537                TaskState::Unknown,
538                TaskState::Queued,
539                TaskState::Initializing,
540                TaskState::Running,
541            ],
542        };
543
544        let mut conn = self.pool.get().await.map_err(Error::Pool)?;
545
546        let updated = conn
547            .transaction(|conn| {
548                async move {
549                    // TODO: currently diesel hasn't released support for the PostgreSQL
550                    // `array_cat` function; remove the raw query when diesel supports it
551                    let updated: Option<UpdatedTask> = sql_query(
552                        "UPDATE tasks SET state = $1, system_logs = array_cat(system_logs, $2) \
553                         WHERE tes_id = $3 AND state = ANY ($4) RETURNING id",
554                    )
555                    .bind::<schema::sql_types::TaskState, _>(TaskState::from(state))
556                    .bind::<Array<Text>, _>(messages)
557                    .bind::<Text, _>(tes_id)
558                    .bind::<Array<schema::sql_types::TaskState>, _>(previous)
559                    .get_result(conn)
560                    .await
561                    .optional()
562                    .map_err(Error::Diesel)?;
563
564                    match updated {
565                        Some(UpdatedTask { id }) => {
566                            if let Some(containers) = containers {
567                                // Insert the containers
568                                let containers = containers.await?;
569                                diesel::insert_into(schema::containers::table)
570                                    .values(
571                                        containers
572                                            .into_iter()
573                                            .map(|c| models::NewContainer::new(id, c))
574                                            .collect::<Vec<_>>(),
575                                    )
576                                    .on_conflict_do_nothing()
577                                    .execute(conn)
578                                    .await
579                                    .map_err(Error::Diesel)?;
580                            }
581
582                            anyhow::Ok(true)
583                        }
584                        None => Ok(false),
585                    }
586                }
587                .scope_boxed()
588            })
589            .await?;
590
591        Ok(updated)
592    }
593
594    async fn append_system_log(&self, tes_id: &str, messages: &[&str]) -> DatabaseResult<()> {
595        use diesel::pg::sql_types::Array;
596        use diesel::sql_types::Text;
597        use diesel::*;
598        use diesel_async::RunQueryDsl;
599
600        let mut conn = self.pool.get().await.map_err(Error::Pool)?;
601
602        // Append the log entry
603        // TODO: currently diesel hasn't released support for the PostgreSQL
604        // `array_cat` function; remove the raw query when diesel supports it
605        sql_query("UPDATE tasks SET system_logs = array_cat(system_logs, $1) WHERE tes_id = $2")
606            .bind::<Array<Text>, _>(messages)
607            .bind::<Text, _>(tes_id)
608            .execute(&mut conn)
609            .await
610            .map_err(Error::Diesel)?;
611
612        Ok(())
613    }
614
615    async fn update_task_output_files(
616        &self,
617        tes_id: &str,
618        files: &[OutputFile],
619    ) -> DatabaseResult<()> {
620        use diesel::*;
621        use diesel_async::RunQueryDsl;
622
623        let mut conn = self.pool.get().await.map_err(Error::Pool)?;
624
625        diesel::update(schema::tasks::table)
626            .filter(
627                schema::tasks::tes_id
628                    .eq(tes_id)
629                    .and(schema::tasks::output_files.is_null()),
630            )
631            .set(schema::tasks::output_files.eq(models::Json(files)))
632            .execute(&mut conn)
633            .await
634            .map_err(Error::Diesel)?;
635
636        Ok(())
637    }
638
639    async fn insert_error(
640        &self,
641        source: &str,
642        tes_id: Option<&str>,
643        message: &str,
644    ) -> DatabaseResult<()> {
645        use diesel::*;
646        use diesel_async::RunQueryDsl;
647
648        let mut conn = self.pool.get().await.map_err(Error::Pool)?;
649
650        let transaction = conn.transaction(|conn| {
651            async move {
652                // Lookup the associated task id, if there is one
653                let task_id = if let Some(tes_id) = tes_id {
654                    Some(
655                        schema::tasks::table
656                            .select(schema::tasks::id)
657                            .filter(schema::tasks::tes_id.eq(tes_id))
658                            .for_update()
659                            .first(conn)
660                            .await
661                            .optional()
662                            .map_err(Error::Diesel)?
663                            .ok_or_else(|| Error::TaskNotFound(tes_id.to_string()))?,
664                    )
665                } else {
666                    None
667                };
668
669                // Insert the new error
670                diesel::insert_into(schema::errors::table)
671                    .values(models::NewError {
672                        source,
673                        task_id,
674                        message,
675                    })
676                    .execute(conn)
677                    .await
678                    .map_err(Error::Diesel)
679            }
680            .scope_boxed()
681        });
682
683        transaction.await?;
684        Ok(())
685    }
686}