use async_trait::async_trait;
use chrono::{DateTime, Utc};
use sqlx::Row;
use super::{SqliteStorage, map_sqlx_err};
use crate::storage::ProcessRegistry;
use crate::storage::error::{Result, StorageError};
use crate::storage::types::{
JobId, PodRecord, ProcessRecord, SlotAssignment, decode_queues, encode_queues,
};
#[async_trait]
impl ProcessRegistry for SqliteStorage {
    /// Registers a worker process for `queue` on `host`.
    ///
    /// `started_at` and `heartbeat_at` are both set to now and `current_job` is NULL.
    /// Uses `INSERT OR REPLACE`, so a row conflicting on a unique key (presumably
    /// `process_id` — the schema is not visible here) is replaced wholesale.
    async fn register(&self, process_id: &str, queue: &str, host: &str) -> Result<()> {
        let now = iso(Utc::now());
        sqlx::query(
            r"INSERT OR REPLACE INTO queue_process
                (process_id, queue_name, host_id, started_at, heartbeat_at, current_job)
              VALUES (?1, ?2, ?3, ?4, ?4, NULL)",
        )
        .bind(process_id)
        .bind(queue)
        .bind(host)
        .bind(&now)
        .execute(&self.write_pool)
        .await
        .map_err(map_sqlx_err)?;
        Ok(())
    }

    /// Refreshes `heartbeat_at` and `current_job` for `process_id`.
    ///
    /// If no row exists for the process, a placeholder row with empty
    /// `queue_name`/`host_id` is inserted so the heartbeat is still recorded.
    async fn heartbeat(&self, process_id: &str, current_job: Option<JobId>) -> Result<()> {
        let now = iso(Utc::now());
        let current_job_str = current_job.as_ref().map(JobId::as_str);

        // The update and the fallback insert must be atomic. Otherwise a concurrent
        // `register` could insert the row between them, and the `INSERT OR REPLACE`
        // below would overwrite its queue/host with empty placeholders.
        let mut tx = self.write_pool.begin().await.map_err(map_sqlx_err)?;

        let res = sqlx::query(
            r"UPDATE queue_process
              SET heartbeat_at = ?1, current_job = ?2
              WHERE process_id = ?3",
        )
        .bind(&now)
        .bind(current_job_str)
        .bind(process_id)
        .execute(&mut *tx)
        .await
        .map_err(map_sqlx_err)?;

        if res.rows_affected() == 0 {
            sqlx::query(
                r"INSERT OR REPLACE INTO queue_process
                    (process_id, queue_name, host_id, started_at, heartbeat_at, current_job)
                  VALUES (?1, '', '', ?2, ?2, ?3)",
            )
            .bind(process_id)
            .bind(&now)
            .bind(current_job_str)
            .execute(&mut *tx)
            .await
            .map_err(map_sqlx_err)?;
        }

        tx.commit().await.map_err(map_sqlx_err)?;
        Ok(())
    }

    /// Removes the `queue_process` row for `process_id` (no error if absent).
    async fn deregister(&self, process_id: &str) -> Result<()> {
        sqlx::query("DELETE FROM queue_process WHERE process_id = ?1")
            .bind(process_id)
            .execute(&self.write_pool)
            .await
            .map_err(map_sqlx_err)?;
        Ok(())
    }

    /// Deletes processes and pods whose heartbeat is older than `stale_before`.
    /// It then drops slot assignments for hosts that no longer have a `pod` row.
    ///
    /// Returns the number of `queue_process` rows removed. Pod and slot deletions
    /// are not counted.
    async fn reap_stale(&self, stale_before: DateTime<Utc>) -> Result<u64> {
        // Timestamps are compared as text. This relies on every stored value
        // having been produced by `iso`, which uses a fixed format.
        let cutoff = iso(stale_before);

        // All three deletes form one cleanup. Run them atomically so a failure
        // part-way cannot leave orphaned slot assignments behind.
        let mut tx = self.write_pool.begin().await.map_err(map_sqlx_err)?;

        let reaped = sqlx::query("DELETE FROM queue_process WHERE heartbeat_at < ?1")
            .bind(&cutoff)
            .execute(&mut *tx)
            .await
            .map_err(map_sqlx_err)?
            .rows_affected();

        sqlx::query("DELETE FROM pod WHERE heartbeat_at < ?1")
            .bind(&cutoff)
            .execute(&mut *tx)
            .await
            .map_err(map_sqlx_err)?;

        // NOT EXISTS instead of `NOT IN (SELECT host_id FROM pod)`: with NOT IN,
        // a single NULL `pod.host_id` makes the predicate NULL for every row,
        // so nothing would be deleted.
        sqlx::query(
            "DELETE FROM pod_slot_assignment
             WHERE NOT EXISTS (
                 SELECT 1 FROM pod WHERE pod.host_id = pod_slot_assignment.host_id
             )",
        )
        .execute(&mut *tx)
        .await
        .map_err(map_sqlx_err)?;

        tx.commit().await.map_err(map_sqlx_err)?;
        Ok(reaped)
    }

    /// Lists registered processes.
    ///
    /// With `Some(queue)`, only that queue's processes are returned, ordered by
    /// `process_id`. Otherwise all processes are returned, ordered by
    /// `queue_name`, then `process_id`.
    async fn list(&self, queue: Option<&str>) -> Result<Vec<ProcessRecord>> {
        let rows = if let Some(q) = queue {
            sqlx::query("SELECT * FROM queue_process WHERE queue_name = ?1 ORDER BY process_id ASC")
                .bind(q)
                .fetch_all(&self.read_pool)
                .await
        } else {
            sqlx::query("SELECT * FROM queue_process ORDER BY queue_name ASC, process_id ASC")
                .fetch_all(&self.read_pool)
                .await
        }
        .map_err(map_sqlx_err)?;
        rows.iter().map(row_to_proc).collect()
    }

    /// Deletes every process, pod and slot-assignment row belonging to `host`.
    ///
    /// Returns the number of `queue_process` rows removed.
    async fn delete_for_host(&self, host: &str) -> Result<u64> {
        // One transaction, so the host is removed from all three tables or none.
        let mut tx = self.write_pool.begin().await.map_err(map_sqlx_err)?;

        let removed = sqlx::query("DELETE FROM queue_process WHERE host_id = ?1")
            .bind(host)
            .execute(&mut *tx)
            .await
            .map_err(map_sqlx_err)?
            .rows_affected();

        for sql in [
            "DELETE FROM pod WHERE host_id = ?1",
            "DELETE FROM pod_slot_assignment WHERE host_id = ?1",
        ] {
            sqlx::query(sql)
                .bind(host)
                .execute(&mut *tx)
                .await
                .map_err(map_sqlx_err)?;
        }

        tx.commit().await.map_err(map_sqlx_err)?;
        Ok(removed)
    }

    /// Upserts the `pod` row for `host`, setting its heartbeat to now and
    /// storing `worker_name` and the encoded `queues` list.
    async fn pod_heartbeat(
        &self,
        host: &str,
        worker_name: Option<&str>,
        queues: &[String],
    ) -> Result<()> {
        let now = iso(Utc::now());
        let queues_csv = encode_queues(queues);
        sqlx::query(
            r"INSERT INTO pod (host_id, heartbeat_at, worker_name, queues)
              VALUES (?1, ?2, ?3, ?4)
              ON CONFLICT(host_id) DO UPDATE
              SET heartbeat_at = excluded.heartbeat_at,
                  worker_name = excluded.worker_name,
                  queues = excluded.queues",
        )
        .bind(host)
        .bind(&now)
        .bind(worker_name)
        .bind(queues_csv)
        .execute(&self.write_pool)
        .await
        .map_err(map_sqlx_err)?;
        Ok(())
    }

    /// Returns pods whose heartbeat is at or after `stale_before`, ordered by `host_id`.
    async fn list_live_pods(&self, stale_before: DateTime<Utc>) -> Result<Vec<PodRecord>> {
        let rows = sqlx::query(
            "SELECT host_id, worker_name, queues, heartbeat_at FROM pod
             WHERE heartbeat_at >= ?1 ORDER BY host_id ASC",
        )
        .bind(iso(stale_before))
        .fetch_all(&self.read_pool)
        .await
        .map_err(map_sqlx_err)?;
        rows.iter().map(row_to_pod).collect()
    }

    /// Returns every slot assignment, ordered by `host_id`, then `queue_name`.
    ///
    /// A stored `slots` value outside the `i32` range is reported as 0.
    async fn list_slot_assignments(&self) -> Result<Vec<SlotAssignment>> {
        let rows = sqlx::query(
            "SELECT queue_name, host_id, slots FROM pod_slot_assignment
             ORDER BY host_id ASC, queue_name ASC",
        )
        .fetch_all(&self.read_pool)
        .await
        .map_err(map_sqlx_err)?;
        rows.iter()
            .map(|r| {
                Ok(SlotAssignment {
                    queue_name: r.try_get("queue_name").map_err(map_sqlx_err)?,
                    host_id: r.try_get("host_id").map_err(map_sqlx_err)?,
                    slots: i32::try_from(r.try_get::<i64, _>("slots").map_err(map_sqlx_err)?)
                        .unwrap_or(0),
                })
            })
            .collect()
    }

    /// Upserts the slot count for (`queue`, `host`). Negative values are stored as 0.
    async fn set_slots(&self, queue: &str, host: &str, slots: i32) -> Result<()> {
        let now = iso(Utc::now());
        sqlx::query(
            r"INSERT INTO pod_slot_assignment (queue_name, host_id, slots, updated_at)
              VALUES (?1, ?2, ?3, ?4)
              ON CONFLICT(queue_name, host_id) DO UPDATE
              SET slots = excluded.slots, updated_at = excluded.updated_at",
        )
        .bind(queue)
        .bind(host)
        .bind(i64::from(slots.max(0)))
        .bind(&now)
        .execute(&self.write_pool)
        .await
        .map_err(map_sqlx_err)?;
        Ok(())
    }

    /// Returns the slot count for (`queue`, `host`), or `None` if no row exists.
    ///
    /// A stored value outside the `i32` range is reported as 0.
    async fn get_slots(&self, queue: &str, host: &str) -> Result<Option<i32>> {
        let row = sqlx::query(
            "SELECT slots FROM pod_slot_assignment WHERE queue_name = ?1 AND host_id = ?2",
        )
        .bind(queue)
        .bind(host)
        .fetch_optional(&self.read_pool)
        .await
        .map_err(map_sqlx_err)?;
        row.map(|r| {
            r.try_get::<i64, _>("slots")
                .map_err(map_sqlx_err)
                .map(|n| i32::try_from(n).unwrap_or(0))
        })
        .transpose()
    }
}
/// Renders `dt` as RFC 3339 text with millisecond precision and a `Z` suffix.
///
/// Every timestamp this module writes or compares against goes through this function.
/// Stored values therefore share one fixed shape, so text comparison in SQL
/// (`<`, `>=`) matches chronological order (for 4-digit years).
fn iso(dt: DateTime<Utc>) -> String {
    const USE_Z_SUFFIX: bool = true;
    let precision = chrono::SecondsFormat::Millis;
    dt.to_rfc3339_opts(precision, USE_Z_SUFFIX)
}
/// Converts one `pod` row into a [`PodRecord`].
///
/// Reads `host_id`, `worker_name` and `queues` (passed through `decode_queues`).
/// Reads `heartbeat_at` as RFC 3339 text, converted to UTC.
///
/// # Errors
/// Column read failures go through `map_sqlx_err`. An unparsable
/// `heartbeat_at` yields `StorageError::Backend`.
fn row_to_pod(r: &sqlx::sqlite::SqliteRow) -> Result<PodRecord> {
    let heartbeat_at: String = r.try_get("heartbeat_at").map_err(map_sqlx_err)?;
    let host_id = r.try_get("host_id").map_err(map_sqlx_err)?;
    let worker_name = r.try_get("worker_name").map_err(map_sqlx_err)?;
    let raw_queues = r.try_get("queues").map_err(map_sqlx_err)?;
    let queues = decode_queues(raw_queues);
    let parsed_heartbeat = match DateTime::parse_from_rfc3339(&heartbeat_at) {
        Ok(fixed) => fixed.with_timezone(&Utc),
        Err(e) => {
            return Err(StorageError::Backend(format!(
                "bad datetime {heartbeat_at:?}: {e}"
            )));
        }
    };
    Ok(PodRecord {
        host_id,
        worker_name,
        queues,
        heartbeat_at: parsed_heartbeat,
    })
}
/// Converts one `queue_process` row into a [`ProcessRecord`].
///
/// `started_at` and `heartbeat_at` are parsed from RFC 3339 text and converted to UTC.
/// A NULL `current_job` becomes `None`.
///
/// # Errors
/// Column read failures go through `map_sqlx_err`. An unparsable timestamp
/// yields `StorageError::Backend`.
fn row_to_proc(r: &sqlx::sqlite::SqliteRow) -> Result<ProcessRecord> {
    /// Parses stored RFC 3339 text into a UTC timestamp.
    fn to_utc(s: String) -> Result<DateTime<Utc>> {
        match DateTime::parse_from_rfc3339(&s) {
            Ok(fixed) => Ok(fixed.with_timezone(&Utc)),
            Err(e) => Err(StorageError::Backend(format!("bad datetime {s:?}: {e}"))),
        }
    }

    let process_id = r.try_get("process_id").map_err(map_sqlx_err)?;
    let queue_name = r.try_get("queue_name").map_err(map_sqlx_err)?;
    let host_id = r.try_get("host_id").map_err(map_sqlx_err)?;
    let started_at = to_utc(r.try_get("started_at").map_err(map_sqlx_err)?)?;
    let heartbeat_at = to_utc(r.try_get("heartbeat_at").map_err(map_sqlx_err)?)?;
    let raw_job: Option<String> = r.try_get("current_job").map_err(map_sqlx_err)?;
    let current_job = raw_job.map(JobId::new);

    Ok(ProcessRecord {
        process_id,
        queue_name,
        host_id,
        started_at,
        heartbeat_at,
        current_job,
    })
}