1use std::time::Duration;
4
5use anyhow::Context;
6use anyhow::Result;
7use anyhow::anyhow;
8use chrono::DateTime;
9use chrono::Utc;
10use diesel::Connection;
11use diesel_async::AsyncConnection;
12use diesel_async::AsyncPgConnection;
13use diesel_async::pooled_connection::AsyncDieselConnectionManager;
14use diesel_async::pooled_connection::deadpool::Pool;
15use diesel_async::scoped_futures::ScopedFutureExt;
16use diesel_migrations::EmbeddedMigrations;
17use diesel_migrations::HarnessWithOutput;
18use diesel_migrations::MigrationHarness;
19use diesel_migrations::embed_migrations;
20use futures::future::BoxFuture;
21use secrecy::ExposeSecret;
22use secrecy::SecretString;
23use tes::v1::types::requests::DEFAULT_PAGE_SIZE;
24use tes::v1::types::requests::GetTaskParams;
25use tes::v1::types::requests::ListTasksParams;
26use tes::v1::types::requests::Task as TesTask;
27use tes::v1::types::requests::View;
28use tes::v1::types::responses::ExecutorLog;
29use tes::v1::types::responses::OutputFile;
30use tes::v1::types::responses::Task;
31use tes::v1::types::responses::TaskLog;
32use tes::v1::types::responses::TaskResponse;
33use tes::v1::types::task::Input;
34use tes::v1::types::task::Output;
35use tes::v1::types::task::State;
36use tracing::debug;
37use tracing::info;
38
39use super::Database;
40use super::DatabaseResult;
41use super::TaskIo;
42use crate::TerminatedContainer;
43
44pub(crate) mod models;
45#[allow(clippy::missing_docs_in_private_items)]
46pub(crate) mod schema;
47
/// Database migrations embedded from `src/postgres/migrations`, applied by
/// `PostgresDatabase::run_pending_migrations`.
const MIGRATIONS: EmbeddedMigrations = embed_migrations!("src/postgres/migrations");

/// How long the background pool-maintenance task sleeps between pruning passes.
const POOL_RETAIN_INTERVAL: Duration = Duration::from_secs(30);

/// Pooled connections are retained only while their `last_used()` duration is below
/// this value; the rest are removed on each pruning pass.
///
/// NOTE(review): despite the name, this is compared against `metrics.last_used()`
/// (presumably the time since the connection was last used — confirm against the
/// deadpool version in use), so it behaves as an idle limit, not a total-age limit.
const MAX_CONNECTION_AGE: Duration = Duration::from_secs(60);

/// The maximum number of connections the pool may hold.
const MAX_POOL_SIZE: usize = 10;
64
/// Zips two iterables together, continuing until *both* are exhausted.
///
/// Whichever side runs out first is padded with `Default::default()` for the
/// remaining pairs, so the output is as long as the longer input.
fn zip_longest<A, B>(a: A, b: B) -> impl Iterator<Item = (A::Item, B::Item)>
where
    A: IntoIterator,
    A::Item: Default,
    B: IntoIterator,
    B::Item: Default,
{
    let (mut left, mut right) = (a.into_iter(), b.into_iter());

    // Pull one element from each side per step and stop at the first step where
    // neither side produced anything.
    std::iter::repeat_with(move || (left.next(), right.next())).map_while(|pair| {
        if let (None, None) = pair {
            None
        } else {
            Some((pair.0.unwrap_or_default(), pair.1.unwrap_or_default()))
        }
    })
}
82
83pub fn format_database_url(
85 user: &str,
86 password: &SecretString,
87 host: &str,
88 port: i32,
89 database_name: &str,
90 app_name: &str,
91) -> String {
92 format!(
93 "postgres://{user}:{password}@{host}:{port}/{database_name}?application_name={app_name}",
94 password = password.expose_secret(),
95 )
96}
97
/// Errors produced by the PostgreSQL database implementation.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// No task exists with the given TES identifier.
    #[error("task `{0}` was not found")]
    TaskNotFound(String),
    /// A connection could not be obtained from the connection pool.
    #[error(transparent)]
    Pool(#[from] diesel_async::pooled_connection::deadpool::PoolError),
    /// A diesel operation (query, insert, update, or transaction) failed.
    #[error(transparent)]
    Diesel(#[from] diesel::result::Error),
}
111
112fn into_task<T, C>(task: T, containers: Vec<C>) -> Task
114where
115 T: Into<(Task, Vec<OutputFile>, Vec<String>)>,
116 C: Into<ExecutorLog>,
117{
118 let (mut task, outputs, system_logs) = task.into();
119 let executor_logs: Vec<_> = containers.into_iter().map(Into::into).collect();
120
121 if !outputs.is_empty() || !executor_logs.is_empty() || !system_logs.is_empty() {
122 let start_time = executor_logs.first().and_then(|e| e.start_time);
123 let end_time = executor_logs.last().and_then(|e| e.end_time);
124
125 task.logs = Some(vec![TaskLog {
126 logs: executor_logs,
127 metadata: None,
128 start_time,
129 end_time,
130 outputs,
131 system_logs: if system_logs.is_empty() {
132 None
133 } else {
134 Some(system_logs)
135 },
136 }]);
137 }
138
139 task
140}
141
142pub struct PostgresDatabase {
144 url: SecretString,
146 pool: Pool<AsyncPgConnection>,
148}
149
150impl PostgresDatabase {
151 pub fn new(url: SecretString) -> Result<Self> {
153 let config = AsyncDieselConnectionManager::new(url.expose_secret());
154 debug!("creating database connection pool with {MAX_POOL_SIZE} slots");
155
156 let pool = Pool::builder(config)
157 .max_size(MAX_POOL_SIZE)
158 .build()
159 .context("failed to initialize PostgreSQL connection pool")?;
160
161 let p = pool.clone();
162
163 tokio::spawn(async move {
166 loop {
167 tokio::time::sleep(POOL_RETAIN_INTERVAL).await;
168
169 let res = p.retain(|_, metrics| metrics.last_used() < MAX_CONNECTION_AGE);
170
171 debug!(
172 "removed {removed} and retained {retained} connections(s) from the database \
173 connection pool",
174 removed = res.removed.len(),
175 retained = res.retained
176 );
177 }
178 });
179
180 Ok(Self { url, pool })
181 }
182
183 pub async fn run_pending_migrations(&self) -> Result<()> {
185 struct Writer;
186 impl std::io::Write for Writer {
187 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
188 let buf = String::from_utf8_lossy(buf);
189 info!("{buf}", buf = buf.trim_end());
190 Ok(buf.len())
191 }
192
193 fn flush(&mut self) -> std::io::Result<()> {
194 Ok(())
195 }
196 }
197
198 let mut conn = diesel::pg::PgConnection::establish(self.url.expose_secret())?;
201 HarnessWithOutput::new(&mut conn, std::io::LineWriter::new(Writer))
202 .run_pending_migrations(MIGRATIONS)
203 .map_err(|e| anyhow!("failed to run pending database migrations: {e}"))?;
204
205 Ok(())
206 }
207}
208
209#[async_trait::async_trait]
210impl Database for PostgresDatabase {
211 async fn insert_task(&self, task: &TesTask) -> DatabaseResult<String> {
212 use diesel_async::RunQueryDsl;
213
214 let task = models::NewTask::new(task);
215
216 let mut conn = self.pool.get().await.map_err(Error::Pool)?;
218 diesel::insert_into(schema::tasks::table)
219 .values(&task)
220 .execute(&mut conn)
221 .await
222 .map_err(Error::Diesel)?;
223
224 Ok(task.tes_id)
225 }
226
227 async fn get_task(&self, tes_id: &str, params: GetTaskParams) -> DatabaseResult<TaskResponse> {
228 use diesel::*;
229 use diesel_async::RunQueryDsl;
230
231 let mut conn = self.pool.get().await.map_err(Error::Pool)?;
232
233 match params.view {
234 View::Minimal => Ok(TaskResponse::Minimal(
235 schema::tasks::table
236 .select(models::MinimalTask::as_select())
237 .filter(schema::tasks::tes_id.eq(tes_id))
238 .first(&mut conn)
239 .await
240 .optional()
241 .map_err(Error::Diesel)?
242 .ok_or_else(|| Error::TaskNotFound(tes_id.to_string()))?
243 .into(),
244 )),
245 View::Basic => {
246 let task = schema::tasks::table
247 .select(models::BasicTask::as_select())
248 .filter(schema::tasks::tes_id.eq(tes_id))
249 .first(&mut conn)
250 .await
251 .optional()
252 .map_err(Error::Diesel)?
253 .ok_or_else(|| Error::TaskNotFound(tes_id.to_string()))?;
254
255 let containers = models::BasicContainer::belonging_to(&task)
256 .select(models::BasicContainer::as_select())
257 .filter(schema::containers::executor_index.is_not_null())
258 .order_by(schema::containers::executor_index)
259 .load(&mut conn)
260 .await
261 .map_err(Error::Diesel)?;
262
263 Ok(TaskResponse::Basic(into_task(task, containers)))
264 }
265 View::Full => {
266 let task = schema::tasks::table
267 .select(models::FullTask::as_select())
268 .filter(schema::tasks::tes_id.eq(tes_id))
269 .first(&mut conn)
270 .await
271 .optional()
272 .map_err(Error::Diesel)?
273 .ok_or_else(|| Error::TaskNotFound(tes_id.to_string()))?;
274
275 let containers = models::FullContainer::belonging_to(&task)
276 .select(models::FullContainer::as_select())
277 .filter(schema::containers::executor_index.is_not_null())
278 .order_by(schema::containers::executor_index)
279 .load(&mut conn)
280 .await
281 .map_err(Error::Diesel)?;
282
283 Ok(TaskResponse::Full(into_task(task, containers)))
284 }
285 }
286 }
287
288 async fn get_tasks(
289 &self,
290 params: ListTasksParams,
291 ) -> DatabaseResult<(Vec<TaskResponse>, Option<String>)> {
292 use diesel::*;
293 use diesel_async::RunQueryDsl;
294
295 let mut query = schema::tasks::table.into_boxed();
296
297 if let Some(prefix) = ¶ms.name_prefix {
299 query = query.filter(schema::tasks::name.like(format!("{prefix}%")));
300 }
301
302 if let Some(state) = params.state {
304 query = query.filter(schema::tasks::state.eq(models::TaskState::from(state)));
305 }
306
307 let offset = if let Some(page_token) = params.page_token {
309 let offset: i64 = page_token
310 .parse()
311 .map_err(|_| super::Error::InvalidPageToken(page_token.clone()))?;
312
313 if offset < 0 {
314 return Err(super::Error::InvalidPageToken(page_token));
315 }
316
317 query = query.offset(offset);
318 offset
319 } else {
320 0
321 };
322
323 for (k, v) in zip_longest(
325 params.tag_keys.unwrap_or_default(),
326 params.tag_values.unwrap_or_default(),
327 ) {
328 if !v.is_empty() {
329 query = query.filter(
330 schema::tasks::tags.contains(models::Json(models::TagFilter::new(k, v))),
331 );
332 } else {
333 query = query.filter(schema::tasks::tags.has_key(k));
334 }
335 }
336
337 let page_size = params.page_size.unwrap_or(DEFAULT_PAGE_SIZE);
339 query = query.limit(page_size as i64).order_by(schema::tasks::id);
340
341 let mut conn = self.pool.get().await.map_err(Error::Pool)?;
342
343 match params.view.unwrap_or_default() {
344 View::Minimal => {
345 let tasks = query
346 .select(models::MinimalTask::as_select())
347 .load(&mut conn)
348 .await
349 .map_err(Error::Diesel)?;
350
351 let token = if tasks.len() < page_size as usize {
352 None
353 } else {
354 Some((offset as usize + tasks.len()).to_string())
355 };
356
357 Ok((
358 tasks
359 .into_iter()
360 .map(|t| TaskResponse::Minimal(t.into()))
361 .collect(),
362 token,
363 ))
364 }
365 View::Basic => {
366 let tasks: Vec<_> = query
367 .select(models::BasicTask::as_select())
368 .load(&mut conn)
369 .await
370 .map_err(Error::Diesel)?
371 .into_iter()
372 .collect();
373
374 let token = if tasks.len() < page_size as usize {
375 None
376 } else {
377 Some((offset as usize + tasks.len()).to_string())
378 };
379
380 Ok((
381 models::BasicContainer::belonging_to(&tasks)
382 .select(models::BasicContainer::as_select())
383 .filter(schema::containers::executor_index.is_not_null())
384 .order_by(schema::containers::executor_index)
385 .load(&mut conn)
386 .await
387 .map_err(Error::Diesel)?
388 .grouped_by(&tasks)
389 .into_iter()
390 .zip(tasks)
391 .map(|(containers, task)| TaskResponse::Basic(into_task(task, containers)))
392 .collect(),
393 token,
394 ))
395 }
396 View::Full => {
397 let tasks: Vec<_> = query
398 .select(models::FullTask::as_select())
399 .load(&mut conn)
400 .await
401 .map_err(Error::Diesel)?
402 .into_iter()
403 .collect();
404
405 let token = if tasks.len() < page_size as usize {
406 None
407 } else {
408 Some((offset as usize + tasks.len()).to_string())
409 };
410
411 Ok((
412 models::FullContainer::belonging_to(&tasks)
413 .select(models::FullContainer::as_select())
414 .filter(schema::containers::executor_index.is_not_null())
415 .order_by(schema::containers::executor_index)
416 .load(&mut conn)
417 .await
418 .map_err(Error::Diesel)?
419 .grouped_by(&tasks)
420 .into_iter()
421 .zip(tasks)
422 .map(|(containers, task)| TaskResponse::Full(into_task(task, containers)))
423 .collect(),
424 token,
425 ))
426 }
427 }
428 }
429
430 async fn get_task_io(&self, tes_id: &str) -> DatabaseResult<TaskIo> {
431 use diesel::*;
432 use diesel_async::RunQueryDsl;
433
434 let mut conn = self.pool.get().await.map_err(Error::Pool)?;
435
436 let (inputs, outputs): (
437 Option<models::Json<Vec<Input>>>,
438 Option<models::Json<Vec<Output>>>,
439 ) = schema::tasks::table
440 .select((schema::tasks::inputs, schema::tasks::outputs))
441 .filter(schema::tasks::tes_id.eq(tes_id))
442 .first(&mut conn)
443 .await
444 .optional()
445 .map_err(Error::Diesel)?
446 .ok_or_else(|| Error::TaskNotFound(tes_id.to_string()))?;
447
448 Ok(TaskIo {
449 inputs: inputs.map(models::Json::into_inner).unwrap_or_default(),
450 outputs: outputs.map(models::Json::into_inner).unwrap_or_default(),
451 })
452 }
453
454 async fn get_in_progress_tasks(&self, before: DateTime<Utc>) -> DatabaseResult<Vec<String>> {
455 use diesel::pg::sql_types::Timestamptz;
456 use diesel::*;
457 use diesel_async::RunQueryDsl;
458 use models::TaskState;
459
460 let mut conn = self.pool.get().await.map_err(Error::Pool)?;
461
462 Ok(schema::tasks::table
463 .select(schema::tasks::tes_id)
464 .filter(
465 schema::tasks::state
466 .eq_any(&[
467 TaskState::Unknown,
468 TaskState::Queued,
469 TaskState::Initializing,
470 TaskState::Running,
471 ])
472 .and(schema::tasks::creation_time.le(before.into_sql::<Timestamptz>())),
473 )
474 .get_results(&mut conn)
475 .await
476 .map_err(Error::Diesel)?)
477 }
478
479 async fn update_task_state<'a>(
480 &self,
481 tes_id: &str,
482 state: State,
483 messages: &[&str],
484 containers: Option<BoxFuture<'a, Result<Vec<TerminatedContainer<'a>>>>>,
485 ) -> DatabaseResult<bool> {
486 use diesel::pg::sql_types::Array;
487 use diesel::sql_types::Text;
488 use diesel::*;
489 use diesel_async::RunQueryDsl;
490 use models::TaskState;
491
492 #[derive(QueryableByName)]
496 #[diesel(table_name = schema::tasks)]
497 #[diesel(check_for_backend(diesel::pg::Pg))]
498 struct UpdatedTask {
499 id: i32,
501 }
502
503 let previous: &[TaskState] = match state {
505 State::Unknown | State::Paused => {
507 return Ok(false);
508 }
509 State::Queued => &[TaskState::Unknown],
511 State::Initializing => &[TaskState::Unknown, TaskState::Queued],
513 State::Running => &[
515 TaskState::Unknown,
516 TaskState::Queued,
517 TaskState::Initializing,
518 ],
519 State::Complete | State::ExecutorError => &[
521 TaskState::Unknown,
522 TaskState::Queued,
523 TaskState::Initializing,
524 TaskState::Running,
525 ],
526 State::SystemError | State::Canceling => &[
528 TaskState::Unknown,
529 TaskState::Queued,
530 TaskState::Initializing,
531 TaskState::Running,
532 ],
533 State::Canceled => &[TaskState::Canceling],
535 State::Preempted => &[
537 TaskState::Unknown,
538 TaskState::Queued,
539 TaskState::Initializing,
540 TaskState::Running,
541 ],
542 };
543
544 let mut conn = self.pool.get().await.map_err(Error::Pool)?;
545
546 let updated = conn
547 .transaction(|conn| {
548 async move {
549 let updated: Option<UpdatedTask> = sql_query(
552 "UPDATE tasks SET state = $1, system_logs = array_cat(system_logs, $2) \
553 WHERE tes_id = $3 AND state = ANY ($4) RETURNING id",
554 )
555 .bind::<schema::sql_types::TaskState, _>(TaskState::from(state))
556 .bind::<Array<Text>, _>(messages)
557 .bind::<Text, _>(tes_id)
558 .bind::<Array<schema::sql_types::TaskState>, _>(previous)
559 .get_result(conn)
560 .await
561 .optional()
562 .map_err(Error::Diesel)?;
563
564 match updated {
565 Some(UpdatedTask { id }) => {
566 if let Some(containers) = containers {
567 let containers = containers.await?;
569 diesel::insert_into(schema::containers::table)
570 .values(
571 containers
572 .into_iter()
573 .map(|c| models::NewContainer::new(id, c))
574 .collect::<Vec<_>>(),
575 )
576 .on_conflict_do_nothing()
577 .execute(conn)
578 .await
579 .map_err(Error::Diesel)?;
580 }
581
582 anyhow::Ok(true)
583 }
584 None => Ok(false),
585 }
586 }
587 .scope_boxed()
588 })
589 .await?;
590
591 Ok(updated)
592 }
593
594 async fn append_system_log(&self, tes_id: &str, messages: &[&str]) -> DatabaseResult<()> {
595 use diesel::pg::sql_types::Array;
596 use diesel::sql_types::Text;
597 use diesel::*;
598 use diesel_async::RunQueryDsl;
599
600 let mut conn = self.pool.get().await.map_err(Error::Pool)?;
601
602 sql_query("UPDATE tasks SET system_logs = array_cat(system_logs, $1) WHERE tes_id = $2")
606 .bind::<Array<Text>, _>(messages)
607 .bind::<Text, _>(tes_id)
608 .execute(&mut conn)
609 .await
610 .map_err(Error::Diesel)?;
611
612 Ok(())
613 }
614
615 async fn update_task_output_files(
616 &self,
617 tes_id: &str,
618 files: &[OutputFile],
619 ) -> DatabaseResult<()> {
620 use diesel::*;
621 use diesel_async::RunQueryDsl;
622
623 let mut conn = self.pool.get().await.map_err(Error::Pool)?;
624
625 diesel::update(schema::tasks::table)
626 .filter(
627 schema::tasks::tes_id
628 .eq(tes_id)
629 .and(schema::tasks::output_files.is_null()),
630 )
631 .set(schema::tasks::output_files.eq(models::Json(files)))
632 .execute(&mut conn)
633 .await
634 .map_err(Error::Diesel)?;
635
636 Ok(())
637 }
638
639 async fn insert_error(
640 &self,
641 source: &str,
642 tes_id: Option<&str>,
643 message: &str,
644 ) -> DatabaseResult<()> {
645 use diesel::*;
646 use diesel_async::RunQueryDsl;
647
648 let mut conn = self.pool.get().await.map_err(Error::Pool)?;
649
650 let transaction = conn.transaction(|conn| {
651 async move {
652 let task_id = if let Some(tes_id) = tes_id {
654 Some(
655 schema::tasks::table
656 .select(schema::tasks::id)
657 .filter(schema::tasks::tes_id.eq(tes_id))
658 .for_update()
659 .first(conn)
660 .await
661 .optional()
662 .map_err(Error::Diesel)?
663 .ok_or_else(|| Error::TaskNotFound(tes_id.to_string()))?,
664 )
665 } else {
666 None
667 };
668
669 diesel::insert_into(schema::errors::table)
671 .values(models::NewError {
672 source,
673 task_id,
674 message,
675 })
676 .execute(conn)
677 .await
678 .map_err(Error::Diesel)
679 }
680 .scope_boxed()
681 });
682
683 transaction.await?;
684 Ok(())
685 }
686}