use std::collections::{HashMap, HashSet, VecDeque};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use futures::future::{self, FutureExt};
use payjoin::directory::ShortId;
use rand::rngs::OsRng;
use rand::RngCore;
use tokio::fs::{self, File};
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
use tokio::sync::{oneshot, Mutex};
use tracing::trace;
use crate::db::{Db as DbTrait, Error as DbError};
/// Default upper bound on the number of tracked mailboxes and pending waiters
/// (2^21 entries).
const DEFAULT_CAPACITY: usize = 1 << (1 + 12 + 8);

/// Default age after which the oldest unread mailbox may be evicted when the
/// store is at capacity (one day).
const DEFAULT_UNREAD_TTL_AT_CAPACITY: Duration = Duration::from_secs(24 * 60 * 60);

/// Default age after which an unread mailbox is pruned regardless of capacity
/// (one week).
const DEFAULT_UNREAD_TTL_BELOW_CAPACITY: Duration = Duration::from_secs(7 * 24 * 60 * 60);

/// Default time a mailbox is retained after it has first been read (ten minutes).
const DEFAULT_READ_TTL: Duration = Duration::from_secs(10 * 60);
/// Bookkeeping for readers that are waiting on a mailbox that has no payload yet.
#[derive(Debug)]
struct V2WaitMapEntry {
    /// Shared receiving end; every waiter gets a clone of this handle.
    receiver: future::Shared<oneshot::Receiver<Arc<Vec<u8>>>>,
    /// Sending end, consumed when a payload is posted for the mailbox.
    sender: oneshot::Sender<Arc<Vec<u8>>>,
}
/// Bookkeeping for an in-flight v1 request whose sender is waiting for a response.
#[derive(Debug)]
struct V1WaitMapEntry {
    /// The request payload; taken (left as `None`) the first time it is read.
    payload: Option<Arc<Vec<u8>>>,
    /// Channel used to hand the response back to the waiting v1 sender.
    sender: oneshot::Sender<Vec<u8>>,
}
/// In-memory index and wait maps layered over the on-disk mailbox storage.
#[derive(Debug)]
pub(crate) struct Mailboxes {
    /// Maximum value of `len()` before new waiters are rejected.
    capacity: usize,
    /// File-backed storage holding the mailbox payloads.
    persistent_storage: DiskStorage,
    /// In-flight v1 requests, keyed by mailbox id.
    pending_v1: HashMap<ShortId, V1WaitMapEntry>,
    /// Readers waiting for a payload to be posted, keyed by mailbox id.
    pending_v2: HashMap<ShortId, V2WaitMapEntry>,
    /// Stored mailboxes with their creation time, oldest at the front.
    insert_order: VecDeque<(SystemTime, ShortId)>,
    /// Mailboxes that have been read, with the time of the first read, oldest at the front.
    read_order: VecDeque<(SystemTime, ShortId)>,
    /// Ids already present in `read_order`, so a mailbox is only queued once.
    read_mailbox_ids: HashSet<ShortId>,
    /// Age at which an unread mailbox is pruned unconditionally.
    unread_ttl_below_capacity: Duration,
    /// Age at which the oldest unread mailbox may be pruned when at capacity.
    unread_ttl_at_capacity: Duration,
    /// How long a mailbox is kept after its first read.
    read_ttl: Duration,
    /// Number of entries still in `insert_order` whose file was already
    /// removed by read-TTL pruning; subtracted from the queue length in `len()`.
    early_removal_count: usize,
}
/// Stores each mailbox as one file named after its id inside `dir`.
#[derive(Debug)]
struct DiskStorage {
    /// Directory holding the mailbox files, the `tmp` staging directory and `xor.dat`.
    dir: PathBuf,
    /// Byte pattern XORed (cyclically) over payloads before they are written to disk.
    xor: Vec<u8>,
}
impl DiskStorage {
/// Prepares `dir` for use as mailbox storage.
///
/// Any existing `tmp` staging directory is wiped and recreated, and the
/// obfuscation pattern is loaded from `xor.dat`, or generated from the OS
/// RNG and persisted if the file does not exist yet.
async fn init(dir: PathBuf) -> io::Result<Self> {
    // Start from an empty staging directory so files left behind by
    // interrupted inserts do not linger.
    let staging = dir.join("tmp");
    if fs::try_exists(&staging).await? {
        fs::remove_dir_all(&staging).await?;
    }
    fs::create_dir_all(&staging).await?;

    let pattern_path = dir.join("xor.dat");
    let xor = if fs::try_exists(&pattern_path).await? {
        fs::read(pattern_path).await?
    } else {
        let fresh = OsRng.next_u64().to_ne_bytes().to_vec();
        let mut out = fs::File::create_new(pattern_path).await?;
        out.write_all(&fresh).await?;
        out.sync_all().await?;
        fresh
    };

    Ok(Self { dir, xor })
}
/// Path of the file that holds the payload for mailbox `id`.
fn mailbox_path(&self, id: &ShortId) -> PathBuf {
    let file_name = id.to_string();
    self.dir.join(file_name)
}
/// Path of the staging file inside `tmp` used while inserting mailbox `id`.
fn insert_mailbox_path(&self, id: &ShortId) -> PathBuf {
    let mut staging = self.dir.join("tmp");
    staging.push(id.to_string());
    staging
}
/// Reports whether a file for mailbox `id` exists on disk.
async fn contains_key(&self, id: &ShortId) -> io::Result<bool> {
    let path = self.mailbox_path(id);
    fs::try_exists(path).await
}
/// Reads mailbox `id` from disk.
///
/// Returns `Ok(None)` when no file exists for the id; otherwise returns the
/// file's creation time together with the de-obfuscated contents.
async fn get(&self, id: &ShortId) -> io::Result<Option<(SystemTime, Vec<u8>)>> {
    let path = self.mailbox_path(id);
    let mut file = match File::open(path).await {
        Ok(opened) => opened,
        Err(e) => {
            // A missing file is an expected outcome, not a failure.
            return if e.kind() == std::io::ErrorKind::NotFound { Ok(None) } else { Err(e) };
        }
    };

    let created = file.metadata().await?.created()?;

    let mut contents = Vec::new();
    file.read_to_end(&mut contents).await?;
    // XOR is its own inverse, so this undoes the on-disk obfuscation.
    self.xor_buffer(&mut contents);

    Ok(Some((created, contents)))
}
/// Stores `contents` as mailbox `id` unless a mailbox with that id already exists.
///
/// Returns:
/// * `Ok(Some(created))` when the mailbox was written, or when it already
///   existed with byte-identical contents (idempotent re-post); `created`
///   is the file's creation time.
/// * `Ok(None)` when a mailbox with different contents already exists.
/// * `Err(_)` on I/O failure, including failure to read an existing mailbox
///   for comparison.
///
/// The payload is first written to a staging file under `tmp` and then
/// hard-linked into place, so a partially written file is never visible
/// under the final name. The staging file is removed on every path once it
/// has been created, so a failed write does not block later inserts.
async fn try_insert(
    &self,
    id: &ShortId,
    contents: impl AsRef<[u8]>,
) -> io::Result<Option<SystemTime>> {
    let mailbox_path = self.mailbox_path(id);

    // `get` already reports a missing file as `None`, so one call covers both
    // the existence check and the comparison, and real I/O errors propagate
    // instead of being mistaken for a conflicting mailbox.
    if let Some((created, existing_contents)) = self.get(id).await? {
        let identical = &existing_contents[..] == contents.as_ref();
        return Ok(identical.then_some(created));
    }

    let mut buffer = contents.as_ref().to_vec();
    self.xor_buffer(&mut buffer);

    let tmp_path = self.insert_mailbox_path(id);
    let mut file = fs::File::create_new(&tmp_path).await?;

    // Run the fallible steps as one unit so the staging file can be cleaned
    // up whichever step fails.
    let publish: io::Result<()> = async {
        file.write_all(&buffer).await?;
        file.sync_data().await?;
        fs::hard_link(&tmp_path, &mailbox_path).await
    }
    .await;
    let cleanup = fs::remove_file(&tmp_path).await;

    match publish {
        Ok(()) => {
            cleanup?;
            // The open handle still refers to the inode now linked at `mailbox_path`.
            Ok(Some(file.metadata().await?.created()?))
        }
        Err(e) if e.kind() == io::ErrorKind::AlreadyExists => {
            cleanup?;
            Ok(None)
        }
        // Report the primary failure rather than a secondary cleanup error.
        Err(e) => Err(e),
    }
}
/// XORs `buffer` in place with the stored pattern, repeating the pattern as
/// needed. Applying it twice restores the original bytes.
fn xor_buffer(&self, buffer: &mut [u8]) {
    let key_stream = self.xor.iter().copied().cycle();
    buffer.iter_mut().zip(key_stream).for_each(|(byte, key)| *byte ^= key);
}
/// Deletes the file for mailbox `id`.
///
/// Returns `Ok(Some(()))` if a file was removed and `Ok(None)` if there was
/// no file to remove.
async fn remove(&self, id: &ShortId) -> io::Result<Option<()>> {
    let outcome = fs::remove_file(self.mailbox_path(id)).await;
    match outcome {
        Ok(()) => Ok(Some(())),
        Err(e) => {
            if e.kind() == io::ErrorKind::NotFound {
                Ok(None)
            } else {
                Err(e)
            }
        }
    }
}
/// Lists the mailboxes present in the storage directory, sorted by file
/// creation time (oldest first).
///
/// Directory entries whose names do not parse as a `ShortId` are skipped.
async fn insert_order(&self) -> io::Result<Vec<(SystemTime, ShortId)>> {
    let mut found: Vec<(SystemTime, ShortId)> = Vec::new();

    let mut listing = fs::read_dir(&self.dir).await?;
    while let Some(entry) = listing.next_entry().await? {
        let name = entry.file_name();
        let Some(id) = name.to_str().and_then(|s| ShortId::from_str(s).ok()) else {
            continue;
        };
        let created = entry.metadata().await?.created()?;
        found.push((created, id));
    }

    found.sort_by_key(|&(created, _)| created);
    Ok(found)
}
}
impl Mailboxes {
    /// Opens (or creates) the storage under `dir` and rebuilds the insertion
    /// queue from the files already present; everything else starts empty
    /// with the default capacity and TTLs.
    async fn init(dir: PathBuf) -> io::Result<Self> {
        let persistent_storage = DiskStorage::init(dir).await?;
        let existing = persistent_storage.insert_order().await?;

        Ok(Self {
            capacity: DEFAULT_CAPACITY,
            persistent_storage,
            pending_v1: HashMap::default(),
            pending_v2: HashMap::default(),
            insert_order: existing.into(),
            read_order: VecDeque::default(),
            read_mailbox_ids: HashSet::default(),
            unread_ttl_below_capacity: DEFAULT_UNREAD_TTL_BELOW_CAPACITY,
            unread_ttl_at_capacity: DEFAULT_UNREAD_TTL_AT_CAPACITY,
            read_ttl: DEFAULT_READ_TTL,
            early_removal_count: 0,
        })
    }
}
/// Cloneable handle to the file-backed mailbox database.
#[derive(Clone, Debug)]
pub struct FilesDb {
    /// How long waiters block for a payload or response before giving up.
    timeout: Duration,
    /// Shared mailbox state; all clones of the handle refer to the same state.
    pub(crate) mailboxes: Arc<Mutex<Mailboxes>>,
}
impl FilesDb {
/// Opens the mailbox database stored under `path`.
///
/// `timeout` bounds how long the waiting operations block.
pub async fn init(timeout: Duration, path: PathBuf) -> io::Result<Self> {
    let mailboxes = Mailboxes::init(path).await?;
    Ok(Self { timeout, mailboxes: Arc::new(Mutex::new(mailboxes)) })
}
/// Runs one pruning pass and returns how long to wait before the next one.
pub async fn prune(&self) -> io::Result<Duration> {
    let mut mailboxes = self.mailboxes.lock().await;
    mailboxes.prune().await
}
/// Spawns a detached task that prunes repeatedly, sleeping between passes
/// for the duration each pass returns.
///
/// # Panics
///
/// The spawned task panics if a pruning pass fails with an I/O error.
pub async fn spawn_background_prune(&self) {
    let db = self.clone();
    tokio::spawn(async move {
        loop {
            // Scope the guard so the lock is released before sleeping.
            let next_wakeup = {
                let mut mailboxes = db.mailboxes.lock().await;
                mailboxes.prune().await.expect("disk storage failed")
            };
            tokio::time::sleep(next_wakeup).await;
        }
    });
}
}
impl DbTrait for FilesDb {
type OperationalError = io::Error;
async fn post_v2_payload(
&self,
id: &ShortId,
payload: Vec<u8>,
) -> Result<Option<()>, DbError<Self::OperationalError>> {
let mut guard = self.mailboxes.lock().await;
Ok(guard.post_v2(id, payload).await?)
}
/// Returns the payload of mailbox `id`, waiting up to the configured
/// timeout for one to be posted if none is available yet.
///
/// # Errors
///
/// `DbError::Timeout` if nothing arrives in time, plus any error produced
/// while reading or registering the waiter.
async fn wait_for_v2_payload(
    &self,
    id: &ShortId,
) -> Result<Arc<Vec<u8>>, DbError<Self::OperationalError>> {
    // Hold the lock only while checking for an existing payload and
    // registering as a waiter, not while waiting.
    let receiver = {
        let mut mailboxes = self.mailboxes.lock().await;
        match mailboxes.read(id).await? {
            Some(payload) => return Ok(payload),
            None => mailboxes.wait_v2(id).await?,
        }
    };

    let outcome = tokio::time::timeout(self.timeout, receiver).await;
    let ret = match outcome {
        Ok(delivered) => Ok(delivered.expect("receiver must not fail")),
        Err(elapsed) => Err(DbError::Timeout(elapsed)),
    };

    // Our clone of the shared receiver is gone by now; drop the wait-map
    // entry if nobody else is waiting on it.
    self.mailboxes.lock().await.maybe_cleanup_v2_waitmap(id);
    ret
}
/// Registers a v1 request under `id` and waits up to the configured timeout
/// for the receiver's response.
///
/// # Errors
///
/// * `DbError::OverCapacity` if a v1 request is already in flight for `id`.
/// * `DbError::Timeout` if no response arrives in time.
async fn post_v1_request_and_wait_for_response(
    &self,
    id: &ShortId,
    payload: Vec<u8>,
) -> Result<Arc<Vec<u8>>, DbError<Self::OperationalError>> {
    let receiver = {
        let mut mailboxes = self.mailboxes.lock().await;
        mailboxes.post_v1_req_and_wait(id, payload).await?.ok_or(DbError::OverCapacity)?
    };

    trace!("v1 sender waiting for v2 receiver's response");

    let ret = match tokio::time::timeout(self.timeout, receiver).await {
        Ok(payload) => Ok(Arc::new(payload.expect("receiver must not fail"))),
        Err(elapsed) => Err(DbError::Timeout(elapsed)),
    };

    // Clean up our own wait-map entry only. Our receiver has been dropped by
    // now, so an entry belonging to this call has a closed sender. An entry
    // with an open sender was registered by a newer request for the same id
    // (possible once the response removed ours) and must be left alone;
    // removing it would drop that request's sender and make its wait fail.
    {
        let mut mailboxes = self.mailboxes.lock().await;
        let ours = mailboxes.pending_v1.get(id).is_some_and(|entry| entry.sender.is_closed());
        if ours {
            mailboxes.pending_v1.remove(id);
        }
    }
    ret
}
async fn post_v1_response(
&self,
id: &ShortId,
payload: Vec<u8>,
) -> Result<(), DbError<Self::OperationalError>> {
let mut guard = self.mailboxes.lock().await;
Ok(guard.post_v1_res(id, payload).await?)
}
}
impl Mailboxes {
/// Fetches the payload for `id` if one is available right now.
///
/// An in-flight v1 request takes precedence and is handed out only once
/// (its payload is taken). Otherwise the payload is loaded from disk and
/// the mailbox is marked as read.
async fn read(&mut self, id: &ShortId) -> io::Result<Option<Arc<Vec<u8>>>> {
    let in_flight = self.pending_v1.get_mut(id).and_then(|entry| entry.payload.take());
    if let Some(payload) = in_flight {
        return Ok(Some(payload));
    }

    match self.persistent_storage.get(id).await? {
        Some((_created, payload)) => {
            self.mark_read(id);
            Ok(Some(Arc::new(payload)))
        }
        None => Ok(None),
    }
}
/// Records the first read of mailbox `id`; later calls for the same id are
/// no-ops, so the read timestamp is not refreshed.
fn mark_read(&mut self, id: &ShortId) {
    let first_read = self.read_mailbox_ids.insert(*id);
    if !first_read {
        return;
    }
    self.read_order.push_back((SystemTime::now(), *id));
}
/// Prunes first, then reports whether there is room for another entry.
async fn has_capacity(&mut self) -> io::Result<bool> {
    self.maybe_prune().await?;
    let in_use = self.len();
    Ok(in_use < self.capacity)
}
async fn wait_v2(
&mut self,
id: &ShortId,
) -> Result<future::Shared<oneshot::Receiver<Arc<Vec<u8>>>>, Error> {
if !self.has_capacity().await? {
return Err(Error::OverCapacity);
}
if let Some(entry) = self.pending_v1.get(id) {
if entry.payload.is_some() {
return Err(Error::OverCapacity);
}
return Err(Error::AlreadyRead);
}
let receiver = self
.pending_v2
.entry(*id)
.or_insert_with(|| {
let (sender, receiver) = oneshot::channel::<Arc<Vec<u8>>>();
let shared_receiver = receiver.shared();
V2WaitMapEntry { sender, receiver: shared_receiver.clone() }
})
.receiver
.clone();
Ok(receiver)
}
/// Drops the wait-map entry for `id` when the entry's own receiver handle
/// is the only one left, i.e. no waiter still holds a clone.
fn maybe_cleanup_v2_waitmap(&mut self, id: &ShortId) {
    let abandoned = self
        .pending_v2
        .get(id)
        .is_some_and(|entry| entry.receiver.strong_count().unwrap_or(0) <= 1);
    if abandoned {
        self.pending_v2.remove(id);
    }
}
/// Stores `payload` in mailbox `id` and wakes any readers waiting on it.
///
/// Returns `Ok(None)` when the mailbox already holds a different payload,
/// and `Ok(Some(()))` when the payload was stored or was already stored
/// with identical contents.
async fn post_v2(&mut self, id: &ShortId, payload: Vec<u8>) -> Result<Option<()>, Error> {
    // `try_insert` also succeeds for an idempotent re-post of identical
    // contents. Such a mailbox is already tracked, so remember whether the
    // file existed beforehand to avoid queueing it twice: a duplicate queue
    // entry would inflate `len()` and later trip the early-removal
    // accounting in `prune`.
    let already_stored = self.persistent_storage.contains_key(id).await?;

    let Some(created) = self.persistent_storage.try_insert(id, &payload).await? else {
        return Ok(None);
    };

    if !already_stored {
        self.insert_order.push_back((created, *id));
    }

    if let Some(pending) = self.pending_v2.remove(id) {
        trace!("notifying pending readers for {}", id);
        // Delivering to waiters counts as the first read of this mailbox.
        self.mark_read(id);
        pending
            .sender
            .send(Arc::new(payload))
            .expect("sending on oneshot channel must succeed");
    }

    Ok(Some(()))
}
/// Registers an in-flight v1 request for `id`.
///
/// Returns the receiver on which the response will arrive, or `Ok(None)` if
/// a v1 request is already registered for `id`. Readers already waiting on
/// `id` are handed the payload immediately, in which case the stored copy
/// is cleared so it is not served a second time.
async fn post_v1_req_and_wait(
    &mut self,
    id: &ShortId,
    payload: Vec<u8>,
) -> Result<Option<oneshot::Receiver<Vec<u8>>>, Error> {
    let payload = Arc::new(payload);

    let response_rx = match self.pending_v1.entry(*id) {
        std::collections::hash_map::Entry::Occupied(_) => None,
        std::collections::hash_map::Entry::Vacant(slot) => {
            let (sender, receiver) = oneshot::channel::<Vec<u8>>();
            slot.insert(V1WaitMapEntry { payload: Some(Arc::clone(&payload)), sender });
            Some(receiver)
        }
    };

    if let Some(waiting) = self.pending_v2.remove(id) {
        trace!("notifying pending readers for {} (v1 fallback)", id);
        waiting
            .sender
            .send(Arc::clone(&payload))
            .expect("sending on oneshot channel must succeed");
        if let Some(entry) = self.pending_v1.get_mut(id) {
            entry.payload.take();
        }
    }

    Ok(response_rx)
}
/// Forgets that `id` was read and deletes its file.
///
/// Returns `Ok(None)` when there was no file to delete.
async fn remove(&mut self, id: &ShortId) -> io::Result<Option<()>> {
    let _was_read = self.read_mailbox_ids.remove(id);
    let removed = self.persistent_storage.remove(id).await?;
    Ok(removed)
}
async fn post_v1_res(&mut self, id: &ShortId, payload: Vec<u8>) -> Result<(), Error> {
match self.pending_v1.remove(id) {
None => Err(Error::V1SenderUnavailable),
Some(V1WaitMapEntry { sender, .. }) =>
sender.send(payload).map_err(|_| Error::V1SenderUnavailable),
}
}
/// Number of entries counted against capacity: mailboxes still on disk
/// plus pending v1 requests and pending v2 waiters.
fn len(&self) -> usize {
    // Queue entries whose file was already removed do not count.
    let stored = self.insert_order.len() - self.early_removal_count;
    let waiting_v1 = self.pending_v1.len();
    let waiting_v2 = self.pending_v2.len();
    stored + waiting_v1 + waiting_v2
}
/// Currently always runs a full pruning pass; see `prune`.
async fn maybe_prune(&mut self) -> io::Result<Duration> {
    let next = self.prune().await?;
    Ok(next)
}
/// Removes expired state and returns how long to wait before pruning again.
///
/// In order, a pass:
/// 1. drops v2 wait-map entries that no waiter holds any more;
/// 2. deletes mailboxes older than `unread_ttl_below_capacity`;
/// 3. deletes mailboxes first read more than `read_ttl` ago;
/// 4. if still exactly at capacity, deletes the oldest mailbox provided it
///    is older than `unread_ttl_at_capacity`.
///
/// A mailbox deleted in step 3 keeps its `insert_order` entry;
/// `early_removal_count` tracks how many such stale entries exist so that
/// `len()` stays accurate. Every place that pops an `insert_order` entry
/// therefore has to undo that count when the file turns out to be gone.
async fn prune(&mut self) -> io::Result<Duration> {
    trace!("pruning");
    let now = SystemTime::now();
    debug_assert!(self.read_ttl < self.unread_ttl_at_capacity);
    debug_assert!(self.unread_ttl_at_capacity < self.unread_ttl_below_capacity);
    debug_assert!(self.pending_v1.iter().all(|(_, v)| !v.sender.is_closed()));

    // The entry itself holds one receiver handle; more than one means a
    // waiter is still interested.
    self.pending_v2.retain(|_, v| v.receiver.strong_count().unwrap_or(0) > 1);

    while let Some((created, id)) = self.insert_order.front().cloned() {
        if created + self.unread_ttl_below_capacity < now {
            debug_assert!(self.insert_order.len() >= self.early_removal_count);
            _ = self.insert_order.pop_front();
            if self.remove(&id).await?.is_none() {
                // File was already deleted by read-TTL pruning.
                self.early_removal_count = self
                    .early_removal_count
                    .checked_sub(1)
                    .expect("early removal adjustment should never underflow");
            }
            debug_assert!(self.insert_order.len() >= self.early_removal_count);
            trace!("Pruned old mailbox {id}");
        } else {
            break;
        }
    }

    while let Some((read, id)) = self.read_order.front().cloned() {
        if read + self.read_ttl < now {
            _ = self.read_order.pop_front();
            if self.remove(&id).await?.is_some() {
                // The `insert_order` entry for this id is now stale.
                self.early_removal_count += 1;
                debug_assert!(self.insert_order.len() >= self.early_removal_count);
            }
            trace!("Pruned read mailbox {id}");
        } else {
            break;
        }
    }

    debug_assert!(self.len() <= self.capacity);

    // At capacity: evict the oldest sufficiently old mailbox. Stale queue
    // entries are skipped, with the early-removal count corrected for each
    // one, so that popping them neither skews `len()` nor counts as having
    // made room. The loop ends as soon as `len()` drops below capacity.
    while self.len() == self.capacity {
        let Some((created, id)) = self.insert_order.front().cloned() else {
            break;
        };
        if created + self.unread_ttl_at_capacity >= now {
            trace!("Nothing to prune, {} entries remain", self.len());
            break;
        }
        _ = self.insert_order.pop_front();
        if self.remove(&id).await?.is_some() {
            trace!("Pruned unread mailbox {id} to make room");
        } else {
            // Saturating: this path never panicked before, keep it that way.
            self.early_removal_count = self.early_removal_count.saturating_sub(1);
        }
    }

    Ok(self.next_prune())
}
fn next_prune(&mut self) -> Duration {
let earliest_read_prune_opportunity = self
.read_order
.front()
.map(|(read, _id)| {
self.read_ttl
.checked_sub(read.elapsed().expect("system clock moved back"))
.unwrap_or(self.read_ttl)
})
.unwrap_or_else(|| self.read_ttl);
let earliest_unread_prune_opportunity = self
.insert_order
.front()
.map(|(created, _id)| {
self.unread_ttl_at_capacity
.checked_sub(created.elapsed().expect("system clock moved back"))
.unwrap_or(self.unread_ttl_at_capacity)
})
.unwrap_or_else(|| self.unread_ttl_at_capacity);
std::cmp::min(earliest_read_prune_opportunity, earliest_unread_prune_opportunity)
}
}
/// Errors produced by the mailbox operations in this module.
#[derive(Debug)]
pub enum Error {
    /// The store is full, or the id is already occupied by an in-flight v1 request.
    OverCapacity,
    /// The payload of an in-flight v1 request was already handed out.
    AlreadyRead,
    /// No v1 sender is waiting for a response on this id (any more).
    V1SenderUnavailable,
    /// The underlying file storage failed.
    IO(io::Error),
}
impl From<io::Error> for Error {
fn from(e: io::Error) -> Self { Self::IO(e) }
}
impl From<Error> for DbError<std::io::Error> {
fn from(val: Error) -> DbError<io::Error> {
match val {
Error::V1SenderUnavailable => DbError::V1SenderUnavailable,
Error::OverCapacity => DbError::OverCapacity,
Error::AlreadyRead => DbError::AlreadyRead,
Error::IO(e) => DbError::Operational(e),
}
}
}
impl std::error::Error for Error {
    /// Only the I/O variant carries an underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::IO(cause) = self {
            Some(cause)
        } else {
            None
        }
    }
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use Error::*;
match self {
OverCapacity => "Database over capacity".fmt(f),
AlreadyRead => "Mailbox payload already read".fmt(f),
V1SenderUnavailable => "Sender no longer connected".fmt(f),
IO(e) => write!(f, "Internal Error: {e}"),
}
}
}
// Marker implementation; the trait is declared in `crate::db` and needs no methods here.
impl crate::db::SendableError for Error {}
/// Initializing storage creates `tmp` and `xor.dat`; re-initializing clears
/// `tmp` and reuses the stored obfuscation pattern.
#[tokio::test]
async fn test_disk_storage_initialization() -> std::io::Result<()> {
    let dir = tempfile::tempdir()?;
    assert!(!dir.path().join("tmp").exists(), "tmp subdirectory should not have been created yet");
    // First initialization: remember the generated pattern and leave junk in tmp.
    let xor_pattern = {
        let storage = DiskStorage::init(dir.path().to_owned())
            .await
            .expect("initializing storage directory should succeed");
        assert!(dir.path().join("tmp").exists(), "tmp subdirectory should have been created");
        assert!(
            dir.path().join("xor.dat").exists(),
            "random obfuscation pattern should have been generated"
        );
        fs::write(dir.path().join("tmp").join("blah"), "junk").await?;
        storage.xor
    };
    assert!(
        dir.path().join("tmp").join("blah").exists(),
        "temp file should not have been cleared yet"
    );
    // Second initialization on the same directory.
    let storage = DiskStorage::init(dir.path().to_owned())
        .await
        .expect("initializing storage directory should succeed");
    assert!(!dir.path().join("tmp").join("blah").exists(), "temp file should have been cleared");
    assert_eq!(storage.xor, xor_pattern, "xor pattern loaded from file");
    Ok(())
}
/// Exercises insert, idempotent re-insert, conflicting insert, enumeration,
/// on-disk obfuscation and removal on `DiskStorage`.
#[tokio::test]
async fn test_disk_storage_mailboxes() -> std::io::Result<()> {
    let dir = tempfile::tempdir()?;
    let storage = DiskStorage::init(dir.path().to_owned())
        .await
        .expect("initializing storage directory should succeed");
    let id1 = ShortId::try_from(&(b"12345678")[..]).unwrap();
    let id2 = ShortId::try_from(&(b"87654321")[..]).unwrap();
    // Fresh storage contains neither mailbox.
    assert!(!storage
        .contains_key(&id1)
        .await
        .expect("checking mailbox existence should not error"));
    assert!(!storage
        .contains_key(&id2)
        .await
        .expect("checking mailbox existence should not error"));
    assert!(matches!(storage.get(&id1).await, Ok(None)));
    assert!(matches!(storage.get(&id2).await, Ok(None)));
    let contents1 = b"OH HAI";
    let contents2 = b"HI FREN";
    // First insert succeeds and round-trips.
    let created1 = storage
        .try_insert(&id1, contents1)
        .await
        .expect("writing should succeed")
        .expect("writing should return a creation time");
    match storage.get(&id1).await {
        Ok(Some((got_created, got_contents))) => {
            assert_eq!(got_created, created1.to_owned());
            assert_eq!(got_contents, contents1.to_owned());
        }
        e => {
            e.expect("retrieval should work");
        }
    };
    assert!(matches!(storage.get(&id2).await, Ok(None)));
    // Different contents for an existing id are rejected ...
    assert!(
        storage
            .try_insert(&id1, contents2)
            .await
            .expect("writing a second time should not fail with IO error")
            .is_none(),
        "writing a second time should be rejected",
    );
    // ... while identical contents are accepted idempotently.
    assert_eq!(
        storage.try_insert(&id1, contents1).await.expect("idempotent write should not fail"),
        Some(created1),
        "idempotent write should have the same creation time",
    );
    // Short pause so the second file gets a later creation time.
    tokio::time::sleep(Duration::from_millis(1)).await;
    let created2 = storage
        .try_insert(&id2, contents2)
        .await
        .expect("writing should succeed")
        .expect("writing should return a creation time");
    assert!(created1 < created2, "creation times should be ordered as expected");
    assert_eq!(
        storage.insert_order().await.expect("enumeration should succeed"),
        vec![(created1, id1), (created2, id2)],
        "enumeration should return expected keys and creation times",
    );
    // The raw file holds obfuscated bytes of the same length.
    let mut file_contents =
        fs::read(storage.mailbox_path(&id1)).await.expect("mailbox file should be readable");
    assert_eq!(file_contents.len(), contents1.len(), "file data should have the right length");
    assert_ne!(file_contents, contents1, "file data should be obfuscated");
    storage.xor_buffer(&mut file_contents[..]);
    assert_eq!(file_contents, contents1, "deobfuscation should recover contents");
    // Removal is effective and tolerant of repeated calls.
    storage.remove(&id1).await.expect("removing an existing mailbox should succeed");
    assert!(
        !storage.contains_key(&id1).await.expect("checking existence should not error"),
        "mailbox file should no longer exist"
    );
    storage.remove(&id1).await.expect("removing a non-existing mailbox should still not error");
    assert_eq!(
        storage.insert_order().await.expect("enumeration should succeed"),
        vec![(created2, id2)],
        "enumeration should return expected keys and creation times",
    );
    Ok(())
}
/// A payload posted before anyone waits can be retrieved immediately.
#[tokio::test]
async fn test_mailbox_storage() -> std::io::Result<()> {
    let dir = tempfile::tempdir()?;
    let db = FilesDb::init(Duration::from_millis(10), dir.path().to_owned())
        .await
        .expect("initializing mailbox database should succeed");
    let id = ShortId([0u8; 8]);
    let contents = b"foo bar";
    db.post_v2_payload(&id, contents.to_vec())
        .await
        .expect("posting payload should succeed")
        .expect("contents should be accepted");
    let res = db.wait_for_v2_payload(&id).await.expect("waiting for payload should succeed");
    assert_eq!(&res[..], contents, "posted payload should be retrievable");
    Ok(())
}
/// Waiting on an empty mailbox times out; concurrent waiters all receive a
/// later post; a conflicting second post is rejected.
#[tokio::test]
async fn test_v2_wait() -> std::io::Result<()> {
    let dir = tempfile::tempdir()?;
    let db = FilesDb::init(Duration::from_millis(1), dir.path().to_owned())
        .await
        .expect("initializing mailbox database should succeed");
    let id = ShortId([0u8; 8]);
    let contents = b"foo bar";
    // Nothing posted yet: the wait must time out.
    match db.wait_for_v2_payload(&id).await {
        Err(DbError::Timeout(_)) => {}
        res => panic!("expected timeout, got {res:?}"),
    }
    // Two readers waiting on the same mailbox.
    let read_task1 = tokio::spawn({
        let db = db.clone();
        async move { db.wait_for_v2_payload(&id).await }
    });
    let read_task2 = tokio::spawn({
        let db = db.clone();
        async move { db.wait_for_v2_payload(&id).await }
    });
    db.post_v2_payload(&id, contents.to_vec())
        .await
        .expect("posting payload should succeed")
        .expect("contents should be accepted");
    let res = read_task1
        .await
        .expect("joining task should succeed")
        .expect("waiting for payload should succeed");
    assert_eq!(&res[..], contents, "posted payload should be retrievable");
    let res = read_task2
        .await
        .expect("joining task should succeed")
        .expect("waiting for payload should succeed");
    assert_eq!(&res[..], contents, "posted payload should be retrievable");
    // A different payload for the same id must not replace the first one.
    assert!(
        db.post_v2_payload(&id, b"something else".to_vec())
            .await
            .expect("posting payload should succeed")
            .is_none(),
        "duplicate POST should be rejected"
    );
    let res = db.wait_for_v2_payload(&id).await.expect("reading payload should succeed");
    assert_eq!(&res[..], contents, "posted payload should be retrievable");
    Ok(())
}
/// A v1 request is readable by the receiver, a second request on the same id
/// is rejected while the first is in flight, and the response reaches the
/// original sender exactly once.
#[tokio::test]
async fn test_v1_wait() -> std::io::Result<()> {
    let dir = tempfile::tempdir()?;
    let db = Arc::new(
        FilesDb::init(Duration::from_millis(1), dir.path().to_owned())
            .await
            .expect("initializing mailbox database should succeed"),
    );
    let id = ShortId([0u8; 8]);
    // The v1 sender blocks in the background until a response is posted.
    let v1_sender_task = tokio::spawn({
        let db = db.clone();
        async move { db.post_v1_request_and_wait_for_response(&id, b"request".to_vec()).await }
    });
    let res = db.wait_for_v2_payload(&id).await.expect("reading payload should succeed");
    assert_eq!(&res[..], b"request", "in flight v1 request should be retrievable");
    assert!(
        matches!(
            db.post_v1_request_and_wait_for_response(&id, b"different request".to_vec()).await,
            Err(DbError::OverCapacity),
        ),
        "second v1 sender with the same shortid should be rejected while request is in flight",
    );
    db.post_v1_response(&id, b"response".to_vec()).await.expect("posting payload should succeed");
    let res = v1_sender_task
        .await
        .expect("joining task should succeed")
        .expect("waiting for payload should succeed");
    assert_eq!(&res[..], b"response", "should be response from v2 receiver");
    // The sender is gone now, so another response has nowhere to go.
    assert!(
        matches!(
            db.post_v1_response(&id, b"response".to_vec()).await,
            Err(DbError::V1SenderUnavailable)
        ),
        "posting without a v1 sender waiting should fail"
    );
    Ok(())
}
/// A v1 request payload is handed out once and then cleared from memory,
/// while the sender can still receive its response.
#[tokio::test]
async fn test_v1_data_minimization() -> std::io::Result<()> {
    let dir = tempfile::tempdir()?;
    let db = Arc::new(
        FilesDb::init(Duration::from_millis(500), dir.path().to_owned())
            .await
            .expect("initializing mailbox database should succeed"),
    );
    let id = ShortId([0u8; 8]);
    let v1_sender_task = tokio::spawn({
        let db = db.clone();
        async move { db.post_v1_request_and_wait_for_response(&id, b"request".to_vec()).await }
    });
    // Give the spawned sender time to register its request.
    tokio::time::sleep(Duration::from_millis(10)).await;
    let res = db.wait_for_v2_payload(&id).await.expect("first read should succeed");
    assert_eq!(&res[..], b"request", "first read should return the payload");
    assert!(
        matches!(db.wait_for_v2_payload(&id).await, Err(DbError::AlreadyRead)),
        "subsequent reads should indicate the payload was already consumed"
    );
    // Inspect the wait map directly to confirm the payload is gone.
    {
        let guard = db.mailboxes.lock().await;
        let entry = guard.pending_v1.get(&id);
        assert!(
            entry.is_none_or(|e| e.payload.is_none()),
            "v1 payload should have been cleared after first read"
        );
    }
    db.post_v1_response(&id, b"response".to_vec()).await.expect("posting response should succeed");
    let res = v1_sender_task
        .await
        .expect("joining task should succeed")
        .expect("v1 sender should get response");
    assert_eq!(&res[..], b"response", "v1 sender should receive the response");
    Ok(())
}
/// Walks `len()` through pruning of abandoned waiters, expired unread
/// mailboxes and expired read mailboxes, by back-dating queue timestamps.
#[tokio::test]
async fn test_prune() -> std::io::Result<()> {
    let dir = tempfile::tempdir()?;
    let db = FilesDb::init(Duration::from_millis(2), dir.path().to_owned())
        .await
        .expect("initializing mailbox database should succeed");
    let read_ttl = Duration::from_secs(60);
    let unread_ttl_at_capacity = Duration::from_secs(600);
    let unread_ttl_below_capacity = Duration::from_secs(3600);
    // Override the defaults with small, test-specific limits.
    {
        let mut guard = db.mailboxes.lock().await;
        guard.capacity = 2;
        guard.read_ttl = read_ttl;
        guard.unread_ttl_at_capacity = unread_ttl_at_capacity;
        guard.unread_ttl_below_capacity = unread_ttl_below_capacity;
    }
    // Pruning an empty database is a no-op.
    assert_eq!(db.mailboxes.lock().await.len(), 0);
    db.prune().await.expect("pruning should not fail");
    assert_eq!(db.mailboxes.lock().await.len(), 0);
    let id = ShortId([0u8; 8]);
    let contents = b"fooo";
    // A pending waiter counts towards len() until it times out and is pruned.
    let read_task1 = tokio::spawn({
        let db = db.clone();
        async move { db.wait_for_v2_payload(&id).await }
    });
    tokio::time::sleep(Duration::from_millis(1)).await;
    assert_eq!(db.mailboxes.lock().await.len(), 1);
    match read_task1.await.expect("joining should succeed") {
        Err(DbError::Timeout(_)) => {}
        res => panic!("expected timeout, got {res:?}"),
    }
    db.prune().await.expect("pruning should not fail");
    assert_eq!(db.mailboxes.lock().await.len(), 0);
    // A fresh unread mailbox survives pruning ...
    db.post_v2_payload(&id, contents.to_vec())
        .await
        .expect("posting payload should succeed")
        .expect("contents should be accepted");
    assert_eq!(db.mailboxes.lock().await.len(), 1);
    db.prune().await.expect("pruning should not fail");
    assert_eq!(db.mailboxes.lock().await.len(), 1);
    // ... until its timestamp is moved past the unread TTL.
    {
        let mut guard = db.mailboxes.lock().await;
        for (ts, _) in guard.insert_order.iter_mut() {
            *ts -= unread_ttl_below_capacity + Duration::from_secs(1);
        }
    }
    assert_eq!(db.mailboxes.lock().await.len(), 1);
    db.prune().await.expect("pruning should not fail");
    assert_eq!(db.mailboxes.lock().await.len(), 0);
    // A read mailbox survives pruning ...
    db.post_v2_payload(&id, contents.to_vec())
        .await
        .expect("posting payload should succeed")
        .expect("contents should be accepted");
    assert_eq!(db.mailboxes.lock().await.len(), 1);
    _ = db.wait_for_v2_payload(&id).await.expect("waiting for payload should succeed");
    assert_eq!(db.mailboxes.lock().await.len(), 1);
    db.prune().await.expect("pruning should not fail");
    assert_eq!(db.mailboxes.lock().await.len(), 1);
    // ... until its read timestamp is moved past the read TTL.
    {
        let mut guard = db.mailboxes.lock().await;
        for (ts, _) in guard.read_order.iter_mut() {
            *ts -= read_ttl + Duration::from_secs(1);
        }
    }
    assert_eq!(db.mailboxes.lock().await.len(), 1);
    db.prune().await.expect("pruning should not fail");
    assert_eq!(db.mailboxes.lock().await.len(), 0);
    // Pruning again afterwards changes nothing.
    db.prune().await.expect("pruning should not fail");
    assert_eq!(db.mailboxes.lock().await.len(), 0);
    Ok(())
}