use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use tokio::sync::Notify;
use uuid::Uuid;
/// Origin category of a [`Message`] (stored in [`Message::source`]).
///
/// Serde (de)serializes the variants in `snake_case`, e.g. `"user"`, `"teammate"`, `"tick"`.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageSource {
    User,
    Teammate,
    System,
    Tick,
    Task,
}
/// A single item carried by a [`Mailbox`].
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct Message {
    /// Identifier; [`Message::new`] fills it with a random v4 UUID string.
    pub id: String,
    /// Origin category of the message.
    pub source: MessageSource,
    /// Message body.
    pub content: String,
    /// Optional originator label. Presumably the sender's name; confirm with producers.
    pub from: Option<String>,
    /// Optional colour hint. NOTE(review): semantics are not visible here; confirm with consumers.
    pub color: Option<String>,
    /// Creation time; [`Message::new`] sets it to the current UTC time.
    pub timestamp: DateTime<Utc>,
}
impl Message {
    /// Builds a message from `source` and `content`.
    ///
    /// The id is a freshly generated v4 UUID, the timestamp is "now" in UTC, and both `from`
    /// and `color` start out as `None`.
    pub fn new(source: MessageSource, content: String) -> Self {
        let id = Uuid::new_v4().to_string();
        let timestamp = Utc::now();
        Self {
            id,
            source,
            content,
            from: None,
            color: None,
            timestamp,
        }
    }
}
/// Boxed, sendable test that decides whether a [`Message`] is wanted.
///
/// The mailbox stores one of these for every pending receiver.
pub type MessagePredicate = Box<dyn for<'m> Fn(&'m Message) -> bool + Send + 'static>;
/// A pending [`Mailbox::receive`] call that found no matching queued message.
struct Waiter {
    /// Decides whether an incoming message belongs to this waiter.
    predicate: MessagePredicate,
    /// One-shot channel used to hand the matching message to the waiting future.
    sender: tokio::sync::oneshot::Sender<Message>,
}
/// In-memory message store with predicate-based retrieval.
///
/// A sent message is either handed directly to a pending receiver whose predicate accepts it,
/// or appended to the queue until someone polls or receives it.
pub struct Mailbox {
    /// Unclaimed messages, in arrival order.
    queue: VecDeque<Message>,
    /// Pending receivers, in registration order; consulted by `send` before queuing.
    waiters: Vec<Waiter>,
    /// Wakes `subscribe` callers when the mailbox changes.
    notify: Notify,
    /// Count of `send` calls.
    revision: u64,
}
/// Future returned by [`Mailbox::receive`].
///
/// It owns everything it needs: either an already-dequeued message or the receiving half of a
/// oneshot channel. It therefore does not borrow the mailbox, and [`Mailbox::send`] can be
/// called while it is pending, for example after spawning it onto another task.
pub type ReceiveFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Message> + Send>>;

impl Mailbox {
    /// Creates an empty mailbox with no pending receivers and revision `0`.
    pub fn new() -> Self {
        Self {
            queue: VecDeque::new(),
            waiters: Vec::new(),
            notify: Notify::new(),
            revision: 0,
        }
    }

    /// Number of messages currently queued.
    ///
    /// A message handed directly to a pending [`receive`](Self::receive) is never queued, so it
    /// is not counted.
    pub fn len(&self) -> usize {
        self.queue.len()
    }

    /// Returns `true` when no messages are queued.
    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }

    /// Number of calls to [`send`](Self::send) so far.
    ///
    /// Removing messages does not change it.
    pub fn revision(&self) -> u64 {
        self.revision
    }

    /// Delivers `msg`.
    ///
    /// If a pending receiver's predicate accepts the message, it goes to the earliest-registered
    /// such receiver. Otherwise it is appended to the queue. Receivers whose futures have been
    /// dropped are skipped, so a cancelled `receive` cannot swallow a message. Subscribers are
    /// woken in both cases.
    pub fn send(&mut self, mut msg: Message) {
        self.revision += 1;
        self.prune_closed_waiters();
        while let Some(idx) = self.waiters.iter().position(|w| (w.predicate)(&msg)) {
            let waiter = self.waiters.remove(idx);
            match waiter.sender.send(msg) {
                Ok(()) => {
                    self.notify();
                    return;
                }
                // The receiving future was dropped after pruning (for example on another task).
                // `send` hands the message back, so offer it to the next matching waiter.
                Err(returned) => msg = returned,
            }
        }
        self.queue.push_back(msg);
        self.notify();
    }

    /// Removes and returns the first queued message accepted by `predicate`, without waiting.
    ///
    /// Returns `None` when no queued message matches. Subscribers are woken when a message is
    /// removed.
    pub fn poll<F>(&mut self, predicate: F) -> Option<Message>
    where
        F: FnMut(&Message) -> bool,
    {
        let idx = self.queue.iter().position(predicate)?;
        // `idx` came from `position` on this same queue, so `remove` always yields a message.
        let msg = self.queue.remove(idx)?;
        self.notify();
        Some(msg)
    }

    /// Returns a future that resolves to the first message accepted by `predicate`.
    ///
    /// The queue is checked at call time. If a queued message matches, it is removed and the
    /// future resolves to it immediately. Otherwise a waiter is registered right away, and the
    /// next matching message passed to [`send`](Self::send) is delivered to it instead of being
    /// queued. The returned future does not borrow the mailbox, so `send` may be called while it
    /// is pending.
    ///
    /// Dropping the future cancels the wait. The stale waiter is discarded later without
    /// consuming a message.
    ///
    /// # Panics
    ///
    /// The future panics if the mailbox is dropped while the wait is still pending.
    pub fn receive<F>(&mut self, predicate: F) -> ReceiveFuture
    where
        F: Fn(&Message) -> bool + Send + 'static,
    {
        if let Some(msg) = self.poll(&predicate) {
            return Box::pin(std::future::ready(msg));
        }
        // Keep the waiter list from growing with receivers that were dropped before any send.
        self.prune_closed_waiters();
        let (sender, receiver) = tokio::sync::oneshot::channel();
        self.waiters.push(Waiter {
            predicate: Box::new(predicate),
            sender,
        });
        Box::pin(async move {
            // Errors only if the sender was dropped undelivered, i.e. the mailbox went away.
            receiver.await.expect("Mailbox receiver cancelled")
        })
    }

    /// Waits until the mailbox next wakes its subscribers.
    ///
    /// Wake-ups happen on every `send`, and whenever `poll` or `receive` removes a queued
    /// message. `Notify::notify_waiters` stores no permit, so only wake-ups that occur after
    /// this future is first polled are observed.
    pub async fn subscribe(&self) {
        self.notify.notified().await;
    }

    /// Drops waiters whose `receive` future has been dropped; they can never accept a message.
    fn prune_closed_waiters(&mut self) {
        self.waiters.retain(|w| !w.sender.is_closed());
    }

    /// Wakes every task currently awaiting [`subscribe`](Self::subscribe).
    fn notify(&self) {
        self.notify.notify_waiters();
    }
}

impl Default for Mailbox {
    /// Equivalent to [`Mailbox::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mailbox_send_poll() {
        let mut mailbox = Mailbox::new();
        let msg = Message::new(MessageSource::User, "hello".to_string());
        let id = msg.id.clone();
        mailbox.send(msg);
        assert_eq!(mailbox.len(), 1);
        assert_eq!(mailbox.revision(), 1);
        let received = mailbox.poll(|m| m.id == id);
        assert_eq!(received.map(|m| m.content).as_deref(), Some("hello"));
        assert!(mailbox.is_empty());
    }

    #[tokio::test]
    async fn test_mailbox_receive() {
        let mut mailbox = Mailbox::new();
        // The call registers the waiter; the returned future does not borrow `mailbox`.
        let handle = tokio::spawn(mailbox.receive(|m| m.source == MessageSource::User));
        mailbox.send(Message::new(MessageSource::Tick, "tick".to_string()));
        assert_eq!(mailbox.len(), 1, "non-matching message should be queued");
        mailbox.send(Message::new(MessageSource::User, "test".to_string()));
        let received = handle.await.unwrap();
        assert_eq!(received.content, "test");
        assert_eq!(mailbox.len(), 1, "matching message should bypass the queue");
    }

    #[tokio::test]
    async fn test_mailbox_receive_queued() {
        let mut mailbox = Mailbox::new();
        mailbox.send(Message::new(MessageSource::Tick, "tick".to_string()));
        mailbox.send(Message::new(MessageSource::User, "hi".to_string()));
        let received = mailbox.receive(|m| m.source == MessageSource::User).await;
        assert_eq!(received.content, "hi");
        assert_eq!(mailbox.len(), 1);
    }

    #[test]
    fn test_cancelled_receive_does_not_swallow_message() {
        let mut mailbox = Mailbox::new();
        drop(mailbox.receive(|m| m.source == MessageSource::User));
        mailbox.send(Message::new(MessageSource::User, "kept".to_string()));
        assert_eq!(mailbox.len(), 1);
        let received = mailbox.poll(|m| m.source == MessageSource::User);
        assert_eq!(received.map(|m| m.content).as_deref(), Some("kept"));
    }
}