use super::client::AmatersClient;
use super::records::{MessageBlob, MessageRecord};
use super::UidLockMap;
use crate::traits::MessageStore;
use crate::types::{MailboxCounters, MailboxId, MessageFlags, MessageMetadata, SearchCriteria};
use async_trait::async_trait;
use rusmes_proto::{Mail, MailAddress, MessageId, MimeMessage};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Serializes a parsed MIME message back into raw RFC 822 bytes.
///
/// Each header value is emitted on its own `Name: value` line terminated by
/// CRLF, so a header carrying several values produces several lines. The header
/// section is followed by the blank CRLF separator line and then the raw body
/// bytes. A large body is read fully into memory before being appended.
///
/// # Errors
/// Fails if reading a large body fails.
async fn mime_to_rfc822(mime: &MimeMessage) -> anyhow::Result<Vec<u8>> {
    // Materialize the body first; header iteration below has no side effects,
    // so the order of these two steps does not affect the result.
    let body = match mime.body() {
        rusmes_proto::MessageBody::Small(bytes) => bytes.to_vec(),
        rusmes_proto::MessageBody::Large(large) => large
            .read_to_bytes()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to read large body: {e}"))?
            .to_vec(),
    };

    let mut raw: Vec<u8> = Vec::new();
    for (name, values) in mime.headers().iter() {
        for value in values {
            for piece in [name.as_bytes(), &b": "[..], value.as_bytes(), &b"\r\n"[..]] {
                raw.extend_from_slice(piece);
            }
        }
    }
    // Blank line separating the header section from the body.
    raw.extend_from_slice(b"\r\n");
    raw.extend_from_slice(&body);
    Ok(raw)
}
/// Key under which the JSON-encoded counters of `mailbox_id` are stored:
/// `counters:{mailbox_id}`.
fn counters_key(mailbox_id: &MailboxId) -> String {
    let mut key = String::from("counters:");
    key.push_str(&mailbox_id.to_string());
    key
}
/// Loads the counters stored for `mailbox_id` in `keyspace`.
///
/// When no value is stored under the counters key, default counters are
/// returned instead of an error.
///
/// # Errors
/// Fails on client errors or when the stored value is not valid JSON for
/// `MailboxCounters`.
async fn read_counters(
    client: &AmatersClient,
    keyspace: &str,
    mailbox_id: &MailboxId,
) -> anyhow::Result<MailboxCounters> {
    let key = counters_key(mailbox_id);
    let counters = match client.get(keyspace, &key).await? {
        None => MailboxCounters::default(),
        Some(raw) => serde_json::from_slice(&raw)?,
    };
    Ok(counters)
}
/// Stores `counters` for `mailbox_id` in `keyspace` as JSON under the
/// `counters:{mailbox_id}` key.
///
/// # Errors
/// Fails if serialization or the client `put` fails.
async fn write_counters(
    client: &AmatersClient,
    keyspace: &str,
    mailbox_id: &MailboxId,
    counters: &MailboxCounters,
) -> anyhow::Result<()> {
    let payload = serde_json::to_vec(counters)?;
    client.put(keyspace, counters_key(mailbox_id), payload).await
}
/// Key under which the big-endian next-UID value of `mailbox_id` is stored:
/// `nextuid:{mailbox_id}`.
fn nextuid_key(mailbox_id: &MailboxId) -> String {
    ["nextuid:", mailbox_id.to_string().as_str()].concat()
}
/// Returns the mutex holding the in-memory next-UID value for `mailbox_id`.
///
/// On first use, the mutex is created with value `1`. The map lock is held only
/// for the lookup/insert; the returned `Arc` is locked by the caller.
async fn get_or_create_mailbox_mutex(
    uid_locks: &UidLockMap,
    mailbox_id: &MailboxId,
) -> Arc<Mutex<u32>> {
    let mut locks = uid_locks.lock().await;
    let slot = locks
        .entry(mailbox_id.to_string())
        .or_insert_with(|| Arc::new(Mutex::new(1u32)));
    Arc::clone(slot)
}
async fn sync_and_advance_uid(
client: &AmatersClient,
keyspace: &str,
mailbox_id: &MailboxId,
nextuid_guard: &mut u32,
) -> anyhow::Result<u32> {
let key = nextuid_key(mailbox_id);
if let Some(bytes) = client.get(keyspace, &key).await? {
if bytes.len() >= 4 {
let arr: [u8; 4] = bytes[..4]
.try_into()
.map_err(|_| anyhow::anyhow!("nextuid key has invalid length"))?;
let stored = u32::from_be_bytes(arr);
if stored > *nextuid_guard {
*nextuid_guard = stored;
}
}
}
let uid = *nextuid_guard;
*nextuid_guard = uid
.checked_add(1)
.ok_or_else(|| anyhow::anyhow!("UID counter overflow for mailbox {}", mailbox_id))?;
let new_nextuid = *nextuid_guard;
client
.put(keyspace, key, new_nextuid.to_be_bytes().to_vec())
.await?;
Ok(uid)
}
/// Message store backed by an `AmatersClient`, keeping message metadata and
/// message bodies in two separate keyspaces.
pub(super) struct AmatersMessageStore {
    /// Client used for every key-value read, write and delete.
    pub(super) client: Arc<AmatersClient>,
    /// Keyspace holding the `message:`, `mailbox:{mailbox}:message:`, `flags:`,
    /// `counters:` and `nextuid:` keys.
    pub(super) metadata_keyspace: String,
    /// Keyspace holding JSON-encoded `MessageBlob` values under `blob:{message_id}`.
    pub(super) blob_keyspace: String,
    /// Per-mailbox mutexes (keyed by the mailbox id string). Each holds the
    /// in-memory next-UID value and serializes UID allocation and counter updates.
    pub(super) uid_locks: UidLockMap,
}
#[async_trait]
impl MessageStore for AmatersMessageStore {
    /// Appends `message` to `mailbox_id` and allocates the next UID for that mailbox.
    ///
    /// The writes happen in this order:
    /// 1. The RFC 822 blob (`blob:{id}`, blob keyspace).
    /// 2. The metadata record (`message:{id}`).
    /// 3. The mailbox index entry (`mailbox:{mailbox}:message:{id}`).
    /// 4. The mailbox counters (`exists`, `recent` and `unseen` each +1).
    ///
    /// The per-mailbox mutex from `uid_locks` is held for the whole sequence. This
    /// serializes UID allocation and the counter read-modify-write with other
    /// operations that take the same mutex.
    ///
    /// Returns metadata carrying the new UID and only the `\Recent` flag set.
    /// Those flags are not written to the `flags:` key.
    ///
    /// # Errors
    /// Fails on serialization, UID overflow, or any client error. The writes are
    /// not transactional, so a failure part-way leaves the earlier writes in place.
    async fn append_message(
        &self,
        mailbox_id: &MailboxId,
        message: Mail,
    ) -> anyhow::Result<MessageMetadata> {
        let message_id = *message.message_id();
        let message_size = message.size();
        let rfc822_bytes = mime_to_rfc822(message.message()).await?;

        let per_mailbox_mutex = get_or_create_mailbox_mutex(&self.uid_locks, mailbox_id).await;
        let mut nextuid_guard = per_mailbox_mutex.lock().await;
        let uid = sync_and_advance_uid(
            &self.client,
            &self.metadata_keyspace,
            mailbox_id,
            &mut nextuid_guard,
        )
        .await?;

        let blob = MessageBlob {
            message_id: message_id.to_string(),
            body: rfc822_bytes,
            compressed: false,
        };
        let blob_key = format!("blob:{}", message_id);
        let blob_value = serde_json::to_vec(&blob)?;
        self.client
            .put(&self.blob_keyspace, blob_key.clone(), blob_value)
            .await?;

        let record = MessageRecord {
            id: message_id.to_string(),
            mailbox_id: mailbox_id.to_string(),
            uid,
            sender: message.sender().map(|s| s.to_string()),
            recipients: message.recipients().iter().map(|r| r.to_string()).collect(),
            headers: HashMap::new(),
            size: message_size,
            blob_key,
            created_at: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs() as i64,
        };
        let metadata_key = format!("message:{}", message_id);
        let metadata_value = serde_json::to_vec(&record)?;
        self.client
            .put(&self.metadata_keyspace, metadata_key, metadata_value)
            .await?;

        let mailbox_index_key = format!("mailbox:{}:message:{}", mailbox_id, message_id);
        self.client
            .put(&self.metadata_keyspace, mailbox_index_key, vec![])
            .await?;

        let mut counters = read_counters(&self.client, &self.metadata_keyspace, mailbox_id).await?;
        counters.exists = counters.exists.saturating_add(1);
        counters.recent = counters.recent.saturating_add(1);
        counters.unseen = counters.unseen.saturating_add(1);
        write_counters(&self.client, &self.metadata_keyspace, mailbox_id, &counters).await?;
        // Release the per-mailbox lock only after the counters are written.
        drop(nextuid_guard);

        let mut flags = MessageFlags::new();
        flags.set_recent(true);
        let metadata = MessageMetadata::new(message_id, *mailbox_id, uid, flags, message_size);
        Ok(metadata)
    }

    /// Loads the message stored under `message_id`.
    ///
    /// Returns `Ok(None)` when there is no `message:{id}` record. It also returns
    /// `Ok(None)`, after logging a warning, when the record's blob is missing.
    /// The returned `Mail` keeps `message_id` as its id.
    ///
    /// # Errors
    /// Fails on client errors, undecodable record/blob JSON, an unparseable RFC 822
    /// blob, or stored sender/recipient strings that are not valid addresses.
    async fn get_message(&self, message_id: &MessageId) -> anyhow::Result<Option<Mail>> {
        let metadata_key = format!("message:{}", message_id);
        let record_bytes = match self
            .client
            .get(&self.metadata_keyspace, &metadata_key)
            .await?
        {
            Some(b) => b,
            None => return Ok(None),
        };
        let record: MessageRecord = serde_json::from_slice(&record_bytes)?;

        let blob_bytes = match self
            .client
            .get(&self.blob_keyspace, &record.blob_key)
            .await?
        {
            Some(b) => b,
            None => {
                tracing::warn!(
                    "Blob {} for message {} not found",
                    record.blob_key,
                    message_id
                );
                return Ok(None);
            }
        };
        let blob: MessageBlob = serde_json::from_slice(&blob_bytes)?;
        let mime = MimeMessage::parse_from_bytes(&blob.body)
            .map_err(|e| anyhow::anyhow!("Failed to parse stored RFC 822 blob: {e}"))?;

        let sender: Option<MailAddress> = record
            .sender
            .as_deref()
            .map(|s| {
                s.parse::<MailAddress>()
                    .map_err(|e| anyhow::anyhow!("Invalid stored sender '{}': {}", s, e))
            })
            .transpose()?;
        let recipients: Vec<MailAddress> = record
            .recipients
            .iter()
            .map(|r| {
                r.parse::<MailAddress>()
                    .map_err(|e| anyhow::anyhow!("Invalid stored recipient '{}': {}", r, e))
            })
            .collect::<anyhow::Result<Vec<_>>>()?;

        let mail = Mail::with_message_id(sender, recipients, mime, None, None, *message_id);
        Ok(Some(mail))
    }

    /// Deletes each message in `message_ids` along with its associated keys.
    ///
    /// For each message, the blob, the mailbox index entry, the `flags:{id}` key
    /// and the `message:{id}` record are deleted. Ids with no stored record still
    /// have their `flags:` and `message:` keys deleted.
    ///
    /// Afterwards, each affected mailbox's `exists` counter is reduced by the number
    /// of messages removed from it. `recent` and `unseen` are clamped so they never
    /// exceed `exists`.
    ///
    /// # Errors
    /// Fails on the first client error. Deletions already performed are kept, and
    /// counters are not adjusted for them.
    async fn delete_messages(&self, message_ids: &[MessageId]) -> anyhow::Result<()> {
        // Deleted-message count per mailbox, keyed by the mailbox id string stored in
        // the record. Counters are adjusted once per mailbox after all keys are removed.
        let mut mailbox_deletes: HashMap<String, u32> = HashMap::new();
        for message_id in message_ids {
            let key = format!("message:{}", message_id);
            if let Some(data) = self.client.get(&self.metadata_keyspace, &key).await? {
                match serde_json::from_slice::<MessageRecord>(&data) {
                    Ok(record) => {
                        self.client
                            .delete(&self.blob_keyspace, &record.blob_key)
                            .await?;
                        let index_key =
                            format!("mailbox:{}:message:{}", record.mailbox_id, message_id);
                        self.client
                            .delete(&self.metadata_keyspace, &index_key)
                            .await?;
                        *mailbox_deletes.entry(record.mailbox_id).or_insert(0) += 1;
                    }
                    // Without a readable record, the blob key, index key and owning
                    // mailbox are unknown. Report it rather than dropping it silently.
                    Err(e) => tracing::warn!(
                        "Undecodable metadata record for message {}: {}; blob and mailbox index entries were left in place",
                        message_id,
                        e
                    ),
                }
            }
            // `set_flags` persists flags under their own key. Remove it so it does not
            // outlive the message. The record key goes last so a retry after a failure
            // can still find the record.
            let flags_key = format!("flags:{}", message_id);
            self.client
                .delete(&self.metadata_keyspace, &flags_key)
                .await?;
            self.client.delete(&self.metadata_keyspace, &key).await?;
        }

        for (mailbox_id_str, count) in mailbox_deletes {
            let uuid = match uuid::Uuid::parse_str(&mailbox_id_str) {
                Ok(uuid) => uuid,
                Err(e) => {
                    tracing::warn!(
                        "Cannot update counters for unparseable mailbox id '{}': {}",
                        mailbox_id_str,
                        e
                    );
                    continue;
                }
            };
            let mailbox_id = MailboxId::from_uuid(uuid);
            // Same per-mailbox mutex as `append_message`, so this counter
            // read-modify-write does not interleave with an append that goes through
            // the same `uid_locks`.
            let per_mailbox_mutex =
                get_or_create_mailbox_mutex(&self.uid_locks, &mailbox_id).await;
            let _guard = per_mailbox_mutex.lock().await;
            let mut counters =
                read_counters(&self.client, &self.metadata_keyspace, &mailbox_id).await?;
            counters.exists = counters.exists.saturating_sub(count);
            // The deleted messages' flags are not consulted here, so `recent`/`unseen`
            // cannot be decremented exactly. At minimum they must never exceed `exists`.
            counters.recent = counters.recent.min(counters.exists);
            counters.unseen = counters.unseen.min(counters.exists);
            write_counters(
                &self.client,
                &self.metadata_keyspace,
                &mailbox_id,
                &counters,
            )
            .await?;
        }
        Ok(())
    }

    /// Stores `flags` as JSON under `flags:{id}` for every id in `message_ids`,
    /// replacing whatever flags were written before.
    ///
    /// Ids are not checked against existing messages.
    ///
    /// NOTE(review): mailbox counters (e.g. `unseen`) are not updated here, and no
    /// visible reader loads the `flags:` key back. `get_mailbox_messages` reports
    /// fresh `MessageFlags`.
    ///
    /// # Errors
    /// Fails if serialization or a client `put` fails.
    async fn set_flags(
        &self,
        message_ids: &[MessageId],
        flags: MessageFlags,
    ) -> anyhow::Result<()> {
        if message_ids.is_empty() {
            return Ok(());
        }
        // The same flags go to every message, so serialize once.
        let value = serde_json::to_vec(&flags)?;
        for message_id in message_ids {
            let key = format!("flags:{}", message_id);
            self.client
                .put(&self.metadata_keyspace, key, value.clone())
                .await?;
        }
        Ok(())
    }

    /// Returns the ids of messages indexed under `mailbox_id`.
    ///
    /// `_criteria` is currently ignored, so every message with a
    /// `mailbox:{mailbox}:message:{id}` index entry is returned. Index keys whose
    /// suffix is not a UUID are skipped.
    async fn search(
        &self,
        mailbox_id: &MailboxId,
        _criteria: SearchCriteria,
    ) -> anyhow::Result<Vec<MessageId>> {
        let prefix = format!("mailbox:{}:message:", mailbox_id);
        let keys = self
            .client
            .list_prefix(&self.metadata_keyspace, &prefix)
            .await?;
        let message_ids = keys
            .into_iter()
            .filter_map(|k| {
                k.strip_prefix(&prefix)
                    .and_then(|id_str| uuid::Uuid::parse_str(id_str).ok().map(MessageId::from_uuid))
            })
            .collect();
        Ok(message_ids)
    }

    /// Copies each message in `message_ids` into `dest_mailbox_id` via
    /// `get_message` + `append_message`.
    ///
    /// Ids that `get_message` cannot find are skipped. Returns the metadata of each
    /// appended copy, in input order.
    async fn copy_messages(
        &self,
        message_ids: &[MessageId],
        dest_mailbox_id: &MailboxId,
    ) -> anyhow::Result<Vec<MessageMetadata>> {
        // NOTE(review): the copy keeps the source message id. `get_message` rebuilds
        // the mail with `Mail::with_message_id(.., *message_id)`, and `append_message`
        // keys its writes on `message.message_id()`. As a result, appending the copy
        // overwrites the source's `message:{id}` record (mailbox id, UID) and its
        // `blob:{id}`. The source's `mailbox:{src}:message:{id}` index entry then
        // points at a record belonging to the destination. TODO: allocate a fresh
        // MessageId for the copy.
        let mut metadata_list = Vec::new();
        for message_id in message_ids {
            if let Some(message) = self.get_message(message_id).await? {
                let metadata = self.append_message(dest_mailbox_id, message).await?;
                metadata_list.push(metadata);
            }
        }
        Ok(metadata_list)
    }

    /// Lists the metadata of every message indexed under `mailbox_id`, ordered by
    /// ascending UID.
    ///
    /// Index entries whose suffix is not a UUID, and entries whose record is
    /// missing or undecodable, are skipped.
    ///
    /// NOTE(review): flags are always reported as `MessageFlags::new()`. Values
    /// written by `set_flags` are not loaded.
    ///
    /// # Errors
    /// Fails on client errors.
    async fn get_mailbox_messages(
        &self,
        mailbox_id: &MailboxId,
    ) -> anyhow::Result<Vec<MessageMetadata>> {
        let prefix = format!("mailbox:{}:message:", mailbox_id);
        let keys = self
            .client
            .list_prefix(&self.metadata_keyspace, &prefix)
            .await?;

        // (uid, metadata) pairs, so the result can be ordered by UID below.
        let mut entries: Vec<(u32, MessageMetadata)> = Vec::new();
        for key in keys {
            if let Some(id_str) = key.strip_prefix(&prefix) {
                if let Ok(uuid) = uuid::Uuid::parse_str(id_str) {
                    let message_id = MessageId::from_uuid(uuid);
                    let metadata_key = format!("message:{}", message_id);
                    if let Some(data) = self
                        .client
                        .get(&self.metadata_keyspace, &metadata_key)
                        .await?
                    {
                        if let Ok(record) = serde_json::from_slice::<MessageRecord>(&data) {
                            let metadata = MessageMetadata::new(
                                message_id,
                                *mailbox_id,
                                record.uid,
                                MessageFlags::new(),
                                record.size,
                            );
                            entries.push((record.uid, metadata));
                        }
                    }
                }
            }
        }

        // IMAP assigns sequence numbers in ascending UID order. The key order from
        // `list_prefix` is at best by message id, not by UID.
        entries.sort_unstable_by_key(|(uid, _)| *uid);
        Ok(entries.into_iter().map(|(_, metadata)| metadata).collect())
    }
}