use uuid::Uuid;
use vti_common::audit::AuditKey;
use vti_common::error::AppError;
use vti_common::pagination::{Cursor, Paginated, paginate};
use vti_common::store::KeyspaceHandle;
use super::JoinRequest;
pub const JOIN_REQUEST_VP_MAX_BYTES: usize = 256 * 1024;
pub const JOIN_REQUEST_EXTENSIONS_MAX_BYTES: usize = 16 * 1024;
const PREFIX: &[u8] = b"join_requests:";
fn key(id: Uuid) -> Vec<u8> {
let mut k = PREFIX.to_vec();
k.extend_from_slice(id.as_hyphenated().to_string().as_bytes());
k
}
fn decode(bytes: &[u8]) -> Result<JoinRequest, AppError> {
serde_json::from_slice(bytes)
.map_err(|e| AppError::Internal(format!("JoinRequest decode: {e}")))
}
pub async fn get_join_request(
ks: &KeyspaceHandle,
id: Uuid,
) -> Result<Option<JoinRequest>, AppError> {
let raw = ks.get_raw(key(id)).await?;
match raw {
Some(bytes) => Ok(Some(decode(&bytes)?)),
None => Ok(None),
}
}
pub async fn store_join_request(
ks: &KeyspaceHandle,
request: &JoinRequest,
) -> Result<(), AppError> {
let vp_bytes = serde_json::to_vec(&request.vp)
.map_err(|e| AppError::Internal(format!("JoinRequest vp serialize: {e}")))?;
if vp_bytes.len() > JOIN_REQUEST_VP_MAX_BYTES {
return Err(AppError::Validation(format!(
"join request VP exceeds {} bytes (got {})",
JOIN_REQUEST_VP_MAX_BYTES,
vp_bytes.len(),
)));
}
let extensions_bytes = serde_json::to_vec(&request.extensions)
.map_err(|e| AppError::Internal(format!("JoinRequest extensions serialize: {e}")))?;
if extensions_bytes.len() > JOIN_REQUEST_EXTENSIONS_MAX_BYTES {
return Err(AppError::Validation(format!(
"join request extensions exceeds {} bytes (got {})",
JOIN_REQUEST_EXTENSIONS_MAX_BYTES,
extensions_bytes.len(),
)));
}
ks.insert(
String::from_utf8(key(request.id)).expect("key is ASCII"),
request,
)
.await
}
pub async fn delete_join_request(ks: &KeyspaceHandle, id: Uuid) -> Result<(), AppError> {
ks.remove(key(id)).await
}
pub async fn list_join_requests(ks: &KeyspaceHandle) -> Result<Vec<JoinRequest>, AppError> {
let raw = ks.prefix_iter_raw(PREFIX.to_vec()).await?;
let mut out = Vec::with_capacity(raw.len());
for (_k, v) in raw {
match decode(&v) {
Ok(r) => out.push(r),
Err(err) => tracing::warn!(error = %err, "skipping unparseable join_request row"),
}
}
Ok(out)
}
pub async fn list_join_requests_paginated(
ks: &KeyspaceHandle,
audit_key: &AuditKey,
cursor: Option<&Cursor>,
limit: usize,
) -> Result<Paginated<JoinRequest>, AppError> {
let mut pairs = ks.prefix_iter_raw(PREFIX.to_vec()).await?;
pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
let snapshot_id: u64 = pairs.len() as u64;
paginate(pairs, cursor, limit, &audit_key.key, snapshot_id, decode)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::join::JoinStatus;
use vti_common::config::StoreConfig;
use vti_common::store::Store;
async fn temp_ks() -> (KeyspaceHandle, tempfile::TempDir) {
let dir = tempfile::tempdir().unwrap();
let store = Store::open(&StoreConfig {
data_dir: dir.path().to_path_buf(),
})
.unwrap();
let ks = store.keyspace("join_requests").unwrap();
(ks, dir)
}
fn fresh(applicant: &str) -> JoinRequest {
JoinRequest::new(applicant, serde_json::json!({"vp":"placeholder"}))
}
#[tokio::test]
async fn round_trip() {
let (ks, _dir) = temp_ks().await;
let r = fresh("did:key:zApplicant");
store_join_request(&ks, &r).await.unwrap();
let got = get_join_request(&ks, r.id).await.unwrap().unwrap();
assert_eq!(got, r);
}
#[tokio::test]
async fn list_returns_every_request() {
let (ks, _dir) = temp_ks().await;
for did in ["did:key:zA", "did:key:zB", "did:key:zC"] {
store_join_request(&ks, &fresh(did)).await.unwrap();
}
let list = list_join_requests(&ks).await.unwrap();
assert_eq!(list.len(), 3);
}
#[tokio::test]
async fn delete_is_idempotent() {
let (ks, _dir) = temp_ks().await;
let r = fresh("did:key:z");
store_join_request(&ks, &r).await.unwrap();
delete_join_request(&ks, r.id).await.unwrap();
delete_join_request(&ks, r.id).await.unwrap();
assert!(get_join_request(&ks, r.id).await.unwrap().is_none());
}
#[tokio::test]
async fn vp_size_limit_enforced() {
let (ks, _dir) = temp_ks().await;
let big = "a".repeat(JOIN_REQUEST_VP_MAX_BYTES + 1);
let mut r = fresh("did:key:zBig");
r.vp = serde_json::json!(big);
let err = store_join_request(&ks, &r).await.expect_err("size hit");
assert!(matches!(err, AppError::Validation(_)));
}
#[test]
fn join_status_lowercase_wire() {
let r = JoinRequest {
status: JoinStatus::Pending,
..fresh("did:key:z")
};
let json = serde_json::to_value(&r).unwrap();
assert_eq!(json["status"], "pending");
}
}