// SPDX-License-Identifier: Apache-2.0
//! Postgres-backed reference storage for the stateless server.
#![cfg(feature = "postgres")]
use std::sync::Arc;
use objects::{
error::{HeddleError, Result},
object::ChangeId,
};
use sqlx::{PgPool, Row};
use uuid::Uuid;
use super::{CoreRefBackend, Head, RefBackend, RefExpectation, RefUpdate, resolve_refspec};
fn sqlx_err(e: sqlx::Error) -> HeddleError {
HeddleError::Io(std::io::Error::other(e.to_string()))
}
/// Reference storage for a single repository, kept in Postgres.
///
/// Every query issued by this backend is scoped to `repo_id`. Cloning is cheap
/// because the connection pool is shared through an `Arc`.
#[derive(Clone)]
pub struct PgRefBackend {
    /// Shared connection pool used for every query.
    pool: Arc<PgPool>,
    /// Repository whose rows (the `repo_id` column) this backend reads and writes.
    repo_id: Uuid,
}
impl PgRefBackend {
pub fn new(pool: Arc<PgPool>, repo_id: Uuid) -> Self {
Self { pool, repo_id }
}
fn block<F, T>(&self, f: F) -> Result<T>
where
F: std::future::Future<Output = Result<T>> + Send,
{
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(f))
}
fn id_to_bytes(id: &ChangeId) -> Vec<u8> {
id.as_bytes().to_vec()
}
fn bytes_to_id(bytes: Vec<u8>) -> Result<ChangeId> {
let arr: [u8; 16] = bytes
.try_into()
.map_err(|_| HeddleError::InvalidObject("invalid ChangeId bytes in database".into()))?;
Ok(ChangeId::from_bytes(arr))
}
async fn get_ref_async(
pool: &PgPool,
repo_id: Uuid,
name: &str,
is_thread: bool,
) -> Result<Option<ChangeId>> {
let row = sqlx::query(
"SELECT change_id FROM refs WHERE repo_id = $1 AND name = $2 AND is_thread = $3",
)
.bind(repo_id)
.bind(name)
.bind(is_thread)
.fetch_optional(pool)
.await
.map_err(sqlx_err)?;
row.map(|r| {
let bytes: Vec<u8> = r.try_get("change_id").map_err(sqlx_err)?;
Self::bytes_to_id(bytes)
})
.transpose()
}
}
impl CoreRefBackend for PgRefBackend {
type Error = HeddleError;
fn read_head(&self) -> Result<Head> {
let pool = Arc::clone(&self.pool);
let repo_id = self.repo_id;
self.block(async move {
let maybe_row = sqlx::query("SELECT thread, change_id FROM heads WHERE repo_id = $1")
.bind(repo_id)
.fetch_optional(pool.as_ref())
.await
.map_err(sqlx_err)?;
match maybe_row {
None => Ok(Head::Attached {
thread: "main".to_string(),
}),
Some(r) => {
let thread: Option<String> = r.try_get("thread").map_err(sqlx_err)?;
let change_id: Option<Vec<u8>> = r.try_get("change_id").map_err(sqlx_err)?;
if let Some(t) = thread {
Ok(Head::Attached { thread: t })
} else if let Some(b) = change_id {
Ok(Head::Detached {
state: Self::bytes_to_id(b)?,
})
} else {
Ok(Head::Attached {
thread: "main".to_string(),
})
}
}
}
})
}
fn write_head(&self, head: &Head) -> Result<()> {
let pool = Arc::clone(&self.pool);
let repo_id = self.repo_id;
let (thread, change_id): (Option<String>, Option<Vec<u8>>) = match head {
Head::Attached { thread } => (Some(thread.clone()), None),
Head::Detached { state } => (None, Some(Self::id_to_bytes(state))),
};
self.block(async move {
sqlx::query(
"INSERT INTO heads (repo_id, thread, change_id)
VALUES ($1, $2, $3)
ON CONFLICT (repo_id)
DO UPDATE SET thread = EXCLUDED.thread, change_id = EXCLUDED.change_id",
)
.bind(repo_id)
.bind(thread)
.bind(change_id)
.execute(pool.as_ref())
.await
.map_err(sqlx_err)?;
Ok(())
})
}
fn write_head_cas(&self, expected: RefExpectation<Head>, head: &Head) -> Result<()> {
let current = self.read_head()?;
match &expected {
RefExpectation::Any => {}
RefExpectation::Missing => {
return Err(HeddleError::Conflict(
"HEAD cannot be missing on a Postgres backend".into(),
));
}
RefExpectation::Value(expected_head) => {
if ¤t != expected_head {
return Err(HeddleError::Conflict(format!(
"HEAD CAS conflict: expected {:?}, found {:?}",
expected_head, current
)));
}
}
}
self.write_head(head)
}
fn get_thread(&self, name: &str) -> Result<Option<ChangeId>> {
let pool = Arc::clone(&self.pool);
let repo_id = self.repo_id;
let name = name.to_string();
self.block(async move { Self::get_ref_async(&pool, repo_id, &name, true).await })
}
fn set_thread(&self, name: &str, state: &ChangeId) -> Result<()> {
let pool = Arc::clone(&self.pool);
let repo_id = self.repo_id;
let name = name.to_string();
let bytes = Self::id_to_bytes(state);
self.block(async move { sqlx::query("INSERT INTO refs (repo_id, name, is_thread, change_id, updated_at) VALUES ($1, $2, true, $3, NOW()) ON CONFLICT (repo_id, name) DO UPDATE SET change_id = EXCLUDED.change_id, updated_at = NOW()").bind(repo_id).bind(&name).bind(bytes).execute(pool.as_ref()).await.map_err(sqlx_err)?; Ok(()) })
}
fn set_thread_cas(
&self,
name: &str,
expected: RefExpectation<ChangeId>,
state: &ChangeId,
) -> Result<()> {
let pool = Arc::clone(&self.pool);
let repo_id = self.repo_id;
let name = name.to_string();
let new_bytes = Self::id_to_bytes(state);
self.block(async move { match expected { RefExpectation::Any => { sqlx::query("INSERT INTO refs (repo_id, name, is_thread, change_id, updated_at) VALUES ($1, $2, true, $3, NOW()) ON CONFLICT (repo_id, name) DO UPDATE SET change_id = EXCLUDED.change_id, updated_at = NOW()").bind(repo_id).bind(&name).bind(&new_bytes).execute(pool.as_ref()).await.map_err(sqlx_err)?; } RefExpectation::Missing => { let n = sqlx::query("INSERT INTO refs (repo_id, name, is_thread, change_id, updated_at) VALUES ($1, $2, true, $3, NOW()) ON CONFLICT DO NOTHING").bind(repo_id).bind(&name).bind(&new_bytes).execute(pool.as_ref()).await.map_err(sqlx_err)?.rows_affected(); if n == 0 { return Err(HeddleError::Conflict(format!("thread '{}' already exists", name))); } } RefExpectation::Value(old) => { let old_bytes = Self::id_to_bytes(&old); let n = sqlx::query("UPDATE refs SET change_id = $4, updated_at = NOW() WHERE repo_id = $1 AND name = $2 AND is_thread = true AND change_id = $3").bind(repo_id).bind(&name).bind(old_bytes).bind(&new_bytes).execute(pool.as_ref()).await.map_err(sqlx_err)?.rows_affected(); if n == 0 { return Err(HeddleError::Conflict(format!("thread '{}' CAS conflict", name))); } } } Ok(()) })
}
fn delete_thread(&self, name: &str) -> Result<Option<ChangeId>> {
let pool = Arc::clone(&self.pool);
let repo_id = self.repo_id;
let name = name.to_string();
self.block(async move { let row = sqlx::query("DELETE FROM refs WHERE repo_id = $1 AND name = $2 AND is_thread = true RETURNING change_id").bind(repo_id).bind(&name).fetch_optional(pool.as_ref()).await.map_err(sqlx_err)?; row.map(|r| { let bytes: Vec<u8> = r.try_get("change_id").map_err(sqlx_err)?; Self::bytes_to_id(bytes) }).transpose() })
}
fn delete_thread_cas(&self, name: &str, expected: RefExpectation<ChangeId>) -> Result<()> {
let pool = Arc::clone(&self.pool);
let repo_id = self.repo_id;
let name = name.to_string();
self.block(async move { let n = match expected { RefExpectation::Any | RefExpectation::Missing => sqlx::query("DELETE FROM refs WHERE repo_id = $1 AND name = $2 AND is_thread = true").bind(repo_id).bind(&name).execute(pool.as_ref()).await.map_err(sqlx_err)?.rows_affected(), RefExpectation::Value(old) => { let old_bytes = Self::id_to_bytes(&old); sqlx::query("DELETE FROM refs WHERE repo_id = $1 AND name = $2 AND is_thread = true AND change_id = $3").bind(repo_id).bind(&name).bind(old_bytes).execute(pool.as_ref()).await.map_err(sqlx_err)?.rows_affected() } }; if n == 0 { Err(HeddleError::Conflict(format!("thread '{}' delete CAS conflict", name))) } else { Ok(()) } })
}
fn list_threads(&self) -> Result<Vec<String>> {
let pool = Arc::clone(&self.pool);
let repo_id = self.repo_id;
self.block(async move {
sqlx::query_scalar::<_, String>(
"SELECT name FROM refs WHERE repo_id = $1 AND is_thread = true ORDER BY name",
)
.bind(repo_id)
.fetch_all(pool.as_ref())
.await
.map_err(sqlx_err)
})
}
fn get_marker(&self, name: &str) -> Result<Option<ChangeId>> {
let pool = Arc::clone(&self.pool);
let repo_id = self.repo_id;
let name = name.to_string();
self.block(async move { Self::get_ref_async(&pool, repo_id, &name, false).await })
}
fn create_marker(&self, name: &str, state: &ChangeId) -> Result<()> {
let pool = Arc::clone(&self.pool);
let repo_id = self.repo_id;
let name = name.to_string();
let bytes = Self::id_to_bytes(state);
self.block(async move { let n = sqlx::query("INSERT INTO refs (repo_id, name, is_thread, change_id, updated_at) VALUES ($1, $2, false, $3, NOW()) ON CONFLICT DO NOTHING").bind(repo_id).bind(&name).bind(bytes).execute(pool.as_ref()).await.map_err(sqlx_err)?.rows_affected(); if n == 0 { Err(HeddleError::Conflict(format!("marker '{}' already exists", name))) } else { Ok(()) } })
}
fn set_marker_cas(
&self,
name: &str,
expected: RefExpectation<ChangeId>,
state: &ChangeId,
) -> Result<()> {
let pool = Arc::clone(&self.pool);
let repo_id = self.repo_id;
let name = name.to_string();
let new_bytes = Self::id_to_bytes(state);
self.block(async move { match expected { RefExpectation::Any => { sqlx::query("INSERT INTO refs (repo_id, name, is_thread, change_id, updated_at) VALUES ($1, $2, false, $3, NOW()) ON CONFLICT (repo_id, name) DO UPDATE SET change_id = EXCLUDED.change_id, updated_at = NOW()").bind(repo_id).bind(&name).bind(&new_bytes).execute(pool.as_ref()).await.map_err(sqlx_err)?; } RefExpectation::Missing => { let n = sqlx::query("INSERT INTO refs (repo_id, name, is_thread, change_id, updated_at) VALUES ($1, $2, false, $3, NOW()) ON CONFLICT DO NOTHING").bind(repo_id).bind(&name).bind(&new_bytes).execute(pool.as_ref()).await.map_err(sqlx_err)?.rows_affected(); if n == 0 { return Err(HeddleError::Conflict(format!("marker '{}' already exists", name))); } } RefExpectation::Value(old) => { let old_bytes = Self::id_to_bytes(&old); let n = sqlx::query("UPDATE refs SET change_id = $4, updated_at = NOW() WHERE repo_id = $1 AND name = $2 AND is_thread = false AND change_id = $3").bind(repo_id).bind(&name).bind(old_bytes).bind(&new_bytes).execute(pool.as_ref()).await.map_err(sqlx_err)?.rows_affected(); if n == 0 { return Err(HeddleError::Conflict(format!("marker '{}' CAS conflict", name))); } } } Ok(()) })
}
fn delete_marker(&self, name: &str) -> Result<Option<ChangeId>> {
let pool = Arc::clone(&self.pool);
let repo_id = self.repo_id;
let name = name.to_string();
self.block(async move { let row = sqlx::query("DELETE FROM refs WHERE repo_id = $1 AND name = $2 AND is_thread = false RETURNING change_id").bind(repo_id).bind(&name).fetch_optional(pool.as_ref()).await.map_err(sqlx_err)?; row.map(|r| { let bytes: Vec<u8> = r.try_get("change_id").map_err(sqlx_err)?; Self::bytes_to_id(bytes) }).transpose() })
}
fn delete_marker_cas(&self, name: &str, expected: RefExpectation<ChangeId>) -> Result<()> {
let pool = Arc::clone(&self.pool);
let repo_id = self.repo_id;
let name = name.to_string();
self.block(async move { let n = match expected { RefExpectation::Any | RefExpectation::Missing => sqlx::query("DELETE FROM refs WHERE repo_id = $1 AND name = $2 AND is_thread = false").bind(repo_id).bind(&name).execute(pool.as_ref()).await.map_err(sqlx_err)?.rows_affected(), RefExpectation::Value(old) => { let old_bytes = Self::id_to_bytes(&old); sqlx::query("DELETE FROM refs WHERE repo_id = $1 AND name = $2 AND is_thread = false AND change_id = $3").bind(repo_id).bind(&name).bind(old_bytes).execute(pool.as_ref()).await.map_err(sqlx_err)?.rows_affected() } }; if n == 0 { Err(HeddleError::Conflict(format!("marker '{}' delete CAS conflict", name))) } else { Ok(()) } })
}
fn list_markers(&self) -> Result<Vec<String>> {
let pool = Arc::clone(&self.pool);
let repo_id = self.repo_id;
self.block(async move {
sqlx::query_scalar::<_, String>(
"SELECT name FROM refs WHERE repo_id = $1 AND is_thread = false ORDER BY name",
)
.bind(repo_id)
.fetch_all(pool.as_ref())
.await
.map_err(sqlx_err)
})
}
fn update_refs(&self, updates: &[RefUpdate]) -> Result<()> {
let pool = Arc::clone(&self.pool);
let repo_id = self.repo_id;
let updates = updates.to_vec();
self.block(async move { let mut tx = pool.begin().await.map_err(sqlx_err)?; for update in &updates { match update { RefUpdate::Thread { name, expected, new } => match (expected, new) { (_, None) => { sqlx::query("DELETE FROM refs WHERE repo_id = $1 AND name = $2 AND is_thread = true").bind(repo_id).bind(name).execute(&mut *tx).await.map_err(sqlx_err)?; } (RefExpectation::Missing, Some(state)) => { sqlx::query("INSERT INTO refs (repo_id, name, is_thread, change_id, updated_at) VALUES ($1, $2, true, $3, NOW()) ON CONFLICT DO NOTHING").bind(repo_id).bind(name).bind(Self::id_to_bytes(state)).execute(&mut *tx).await.map_err(sqlx_err)?; } (RefExpectation::Value(old), Some(new_state)) => { sqlx::query("UPDATE refs SET change_id = $4, updated_at = NOW() WHERE repo_id = $1 AND name = $2 AND is_thread = true AND change_id = $3").bind(repo_id).bind(name).bind(Self::id_to_bytes(old)).bind(Self::id_to_bytes(new_state)).execute(&mut *tx).await.map_err(sqlx_err)?; } (_, Some(state)) => { sqlx::query("INSERT INTO refs (repo_id, name, is_thread, change_id, updated_at) VALUES ($1, $2, true, $3, NOW()) ON CONFLICT (repo_id, name) DO UPDATE SET change_id = EXCLUDED.change_id, updated_at = NOW()").bind(repo_id).bind(name).bind(Self::id_to_bytes(state)).execute(&mut *tx).await.map_err(sqlx_err)?; } }, RefUpdate::Marker { name, expected: _, new } => match new { None => { sqlx::query("DELETE FROM refs WHERE repo_id = $1 AND name = $2 AND is_thread = false").bind(repo_id).bind(name).execute(&mut *tx).await.map_err(sqlx_err)?; } Some(state) => { sqlx::query("INSERT INTO refs (repo_id, name, is_thread, change_id, updated_at) VALUES ($1, $2, false, $3, NOW()) ON CONFLICT (repo_id, name) DO UPDATE SET change_id = EXCLUDED.change_id, updated_at = NOW()").bind(repo_id).bind(name).bind(Self::id_to_bytes(state)).execute(&mut *tx).await.map_err(sqlx_err)?; } }, RefUpdate::Head { new, .. 
} => { let (thread, change_id): (Option<String>, Option<Vec<u8>>) = match new { Head::Attached { thread } => (Some(thread.clone()), None), Head::Detached { state } => (None, Some(Self::id_to_bytes(state))), }; sqlx::query("INSERT INTO heads (repo_id, thread, change_id) VALUES ($1, $2, $3) ON CONFLICT (repo_id) DO UPDATE SET thread = EXCLUDED.thread, change_id = EXCLUDED.change_id").bind(repo_id).bind(thread).bind(change_id).execute(&mut *tx).await.map_err(sqlx_err)?; } } } tx.commit().await.map_err(sqlx_err)?; Ok(()) })
}
fn resolve(&self, refspec: &str) -> Result<Option<ChangeId>> {
resolve_refspec(
refspec,
|| self.read_head(),
|name| self.get_thread(name),
|name| self.get_marker(name),
)
}
}
impl RefBackend for PgRefBackend {
fn get_remote_thread(&self, _remote: &str, _track: &str) -> Result<Option<ChangeId>> {
Err(HeddleError::Conflict(
"remote threading refs are not supported on the server backend".into(),
))
}
fn set_remote_thread(&self, _remote: &str, _track: &str, _state: &ChangeId) -> Result<()> {
Err(HeddleError::Conflict(
"remote threading refs are not supported on the server backend".into(),
))
}
fn delete_remote_thread(&self, _remote: &str, _track: &str) -> Result<Option<ChangeId>> {
Err(HeddleError::Conflict(
"remote threading refs are not supported on the server backend".into(),
))
}
fn list_remotes(&self) -> Result<Vec<String>> {
Ok(Vec::new())
}
fn list_remote_threads(&self, _remote: &str) -> Result<Vec<String>> {
Ok(Vec::new())
}
}