use crate::SendError;
use crate::mailbox::{Envelope, MailboxSender};
use starlang_core::{ExitReason, Pid, Ref, Term};
use std::collections::HashSet;
use std::sync::{Arc, RwLock};
use tokio::sync::oneshot;
/// Mutable bookkeeping for a single process.
///
/// `ProcessHandle` holds this record behind an `Arc<RwLock<_>>`, so every
/// clone of a handle observes (and mutates) the same record.
#[derive(Debug)]
pub struct ProcessState {
    /// Pid of the process this record describes.
    pub pid: Pid,
    /// Whether the process traps exits; read by
    /// `ProcessHandle::is_trapping_exits` and written by `set_trap_exit`.
    /// NOTE(review): presumably, as in Erlang, exit signals from linked
    /// processes are delivered as messages while this is set -- confirm
    /// against the signal-delivery code.
    pub trap_exit: bool,
    /// Pids this process is linked to.
    pub links: HashSet<Pid>,
    /// Monitors created by this process: monitor reference -> target pid
    /// (filled by `ProcessHandle::add_monitor`).
    pub monitors: std::collections::HashMap<Ref, Pid>,
    /// Monitors watching this process: monitor reference -> monitoring pid
    /// (filled by `ProcessHandle::add_monitored_by`).
    pub monitored_by: std::collections::HashMap<Ref, Pid>,
    /// Set once the process has been marked terminated; `is_alive` reports
    /// its negation.
    pub terminated: bool,
    /// Exit reason stored by `ProcessHandle::mark_terminated`; `None` until
    /// then.
    pub exit_reason: Option<ExitReason>,
}
impl ProcessState {
    /// Builds the initial record for `pid`: alive, not trapping exits, no
    /// exit reason, and no links or monitors in either direction.
    pub fn new(pid: Pid) -> Self {
        // Empty collections; `default()` and `new()` are equivalent here.
        let links = HashSet::default();
        let monitors = std::collections::HashMap::default();
        let monitored_by = std::collections::HashMap::default();

        Self {
            pid,
            links,
            monitors,
            monitored_by,
            trap_exit: false,
            terminated: false,
            exit_reason: None,
        }
    }
}
/// Cloneable handle to a process.
///
/// A handle bundles the process's pid, the sending side of its mailbox, and a
/// shared `Arc<RwLock<ProcessState>>`. Cloning a handle clones the `Arc`, so
/// every clone reads and writes the same state record.
#[derive(Clone)]
pub struct ProcessHandle {
    /// Pid of the process this handle refers to.
    pid: Pid,
    /// Sending side of the process's mailbox; used by `send_raw`.
    sender: MailboxSender,
    /// Bookkeeping record shared with every clone of this handle.
    state: Arc<RwLock<ProcessState>>,
    // Never read anywhere in this type, hence the `dead_code` allowance.
    // NOTE(review): `Arc<Option<oneshot::Sender<_>>>` only gives shared access,
    // so the sender can never be taken out, and `oneshot::Sender::send`
    // consumes `self`, so nothing can ever be sent through this field. The
    // paired receiver presumably only observes the channel closing once the
    // last handle clone is dropped -- confirm that this is the intended
    // termination signal (otherwise consider `Arc<Mutex<Option<_>>>`).
    #[allow(dead_code)]
    termination_tx: Arc<Option<oneshot::Sender<ExitReason>>>,
}
impl ProcessHandle {
    /// Creates a handle for process `pid`.
    ///
    /// * `sender` - sending side of the process's mailbox.
    /// * `state` - shared bookkeeping record; all clones of the returned
    ///   handle share it.
    /// * `termination_tx` - optional one-shot sender, kept alive for as long
    ///   as any clone of the handle exists.
    pub fn new(
        pid: Pid,
        sender: MailboxSender,
        state: Arc<RwLock<ProcessState>>,
        termination_tx: Option<oneshot::Sender<ExitReason>>,
    ) -> Self {
        Self {
            pid,
            sender,
            state,
            termination_tx: Arc::new(termination_tx),
        }
    }

    /// Locks the shared state for reading, recovering from lock poisoning.
    ///
    /// The writes made through this handle are plain field assignments and
    /// single insert/remove calls, so a poisoned lock still guards usable
    /// data. Recovering keeps one panicking lock holder from turning every
    /// later accessor (including `Debug`) into a panic as well.
    /// NOTE(review): writers outside this type (anyone holding the `Arc`) are
    /// assumed to follow the same discipline.
    fn read_state(&self) -> std::sync::RwLockReadGuard<'_, ProcessState> {
        self.state
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Locks the shared state for writing; poisoning is recovered for the same
    /// reason as in [`Self::read_state`].
    fn write_state(&self) -> std::sync::RwLockWriteGuard<'_, ProcessState> {
        self.state
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns the pid of the process this handle refers to.
    pub fn pid(&self) -> Pid {
        self.pid
    }

    /// Sends an already-encoded message to the process's mailbox.
    ///
    /// # Errors
    ///
    /// Returns [`SendError::ProcessTerminated`] if the mailbox sender reports
    /// itself closed or the underlying send fails.
    pub fn send_raw(&self, data: Vec<u8>) -> Result<(), SendError> {
        // Cheap early-out; the send below still reports failure if the
        // mailbox closes between this check and the send.
        if self.sender.is_closed() {
            return Err(SendError::ProcessTerminated);
        }
        self.sender
            .send(Envelope::new(data))
            .map_err(|_| SendError::ProcessTerminated)
    }

    /// Encodes `msg` with [`Term::encode`] and sends it via [`Self::send_raw`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::send_raw`].
    pub fn send<M: Term>(&self, msg: &M) -> Result<(), SendError> {
        self.send_raw(msg.encode())
    }

    /// Returns `false` once the shared state has been marked terminated
    /// (e.g. via [`Self::mark_terminated`]), `true` before that.
    pub fn is_alive(&self) -> bool {
        !self.read_state().terminated
    }

    /// Returns whether the process is currently trapping exits.
    pub fn is_trapping_exits(&self) -> bool {
        self.read_state().trap_exit
    }

    /// Sets the trap-exit flag in the shared state.
    pub fn set_trap_exit(&self, trap: bool) {
        self.write_state().trap_exit = trap;
    }

    /// Records a link to `other`; adding an existing link is a no-op.
    pub fn add_link(&self, other: Pid) {
        self.write_state().links.insert(other);
    }

    /// Removes the link to `other`, if present.
    pub fn remove_link(&self, other: Pid) {
        self.write_state().links.remove(&other);
    }

    /// Returns a snapshot of the linked pids, in unspecified (hash-set) order.
    pub fn links(&self) -> Vec<Pid> {
        self.read_state().links.iter().copied().collect()
    }

    /// Records that this process monitors `target` under `reference`,
    /// replacing any entry previously stored under the same reference.
    pub fn add_monitor(&self, reference: Ref, target: Pid) {
        self.write_state().monitors.insert(reference, target);
    }

    /// Removes the monitor stored under `reference`, returning its target pid
    /// if one was present.
    pub fn remove_monitor(&self, reference: Ref) -> Option<Pid> {
        self.write_state().monitors.remove(&reference)
    }

    /// Records that `monitoring_pid` monitors this process under `reference`,
    /// replacing any entry previously stored under the same reference.
    pub fn add_monitored_by(&self, reference: Ref, monitoring_pid: Pid) {
        self.write_state()
            .monitored_by
            .insert(reference, monitoring_pid);
    }

    /// Removes the incoming monitor stored under `reference`, returning the
    /// monitoring pid if one was present.
    pub fn remove_monitored_by(&self, reference: Ref) -> Option<Pid> {
        self.write_state().monitored_by.remove(&reference)
    }

    /// Returns a snapshot of `(reference, monitoring pid)` pairs for the
    /// monitors watching this process, in unspecified (hash-map) order.
    pub fn monitored_by(&self) -> Vec<(Ref, Pid)> {
        self.read_state()
            .monitored_by
            .iter()
            .map(|(&reference, &pid)| (reference, pid))
            .collect()
    }

    /// Marks the process as terminated and stores `reason`.
    ///
    /// Calling this again overwrites the previously stored reason.
    pub fn mark_terminated(&self, reason: ExitReason) {
        let mut state = self.write_state();
        state.terminated = true;
        state.exit_reason = Some(reason);
    }

    /// Returns a clone of the stored exit reason, or `None` if none has been
    /// recorded yet.
    pub fn exit_reason(&self) -> Option<ExitReason> {
        self.read_state().exit_reason.clone()
    }
}
impl std::fmt::Debug for ProcessHandle {
    /// Formats the pid and liveness without blocking or panicking.
    ///
    /// Uses `try_read`, like std's own `RwLock` `Debug` impl. If the state
    /// lock is currently held for writing (possibly by the very thread that
    /// is formatting), `alive` is rendered as `<locked>` instead of
    /// deadlocking. A poisoned lock is read through rather than panicking
    /// mid-format.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ProcessHandle");
        out.field("pid", &self.pid);
        match self.state.try_read() {
            Ok(state) => out.field("alive", &!state.terminated),
            Err(std::sync::TryLockError::Poisoned(poisoned)) => {
                out.field("alive", &!poisoned.into_inner().terminated)
            }
            Err(std::sync::TryLockError::WouldBlock) => {
                out.field("alive", &format_args!("<locked>"))
            }
        };
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mailbox::Mailbox;

    /// Builds a handle backed by a fresh mailbox and state record. The mailbox
    /// is returned so tests can keep it alive or read from it.
    fn create_test_handle() -> (ProcessHandle, crate::mailbox::Mailbox) {
        let pid = Pid::new();
        let (mailbox, sender) = Mailbox::new();
        let state = Arc::new(RwLock::new(ProcessState::new(pid)));
        let handle = ProcessHandle::new(pid, sender, state, None);
        (handle, mailbox)
    }

    #[test]
    fn test_process_handle_pid() {
        let (handle, _mailbox) = create_test_handle();
        let pid = handle.pid();
        assert!(pid.is_local());
    }

    #[tokio::test]
    async fn test_send_message() {
        let (handle, mut mailbox) = create_test_handle();
        handle.send_raw(vec![1, 2, 3]).unwrap();
        let envelope = mailbox.recv().await.unwrap();
        assert_eq!(envelope.data, vec![1, 2, 3]);
    }

    #[test]
    fn test_trap_exit() {
        let (handle, _mailbox) = create_test_handle();
        assert!(!handle.is_trapping_exits());
        handle.set_trap_exit(true);
        assert!(handle.is_trapping_exits());
        handle.set_trap_exit(false);
        assert!(!handle.is_trapping_exits());
    }

    #[test]
    fn test_links() {
        let (handle, _mailbox) = create_test_handle();
        let other_pid = Pid::new();
        assert!(handle.links().is_empty());
        handle.add_link(other_pid);
        assert_eq!(handle.links(), vec![other_pid]);
        handle.remove_link(other_pid);
        assert!(handle.links().is_empty());
    }

    #[test]
    fn test_monitors() {
        let (handle, _mailbox) = create_test_handle();
        let target_pid = Pid::new();
        let reference = Ref::new();
        handle.add_monitor(reference, target_pid);
        let removed = handle.remove_monitor(reference);
        assert_eq!(removed, Some(target_pid));
        let removed_again = handle.remove_monitor(reference);
        assert_eq!(removed_again, None);
    }

    #[test]
    fn test_monitored_by() {
        let (handle, _mailbox) = create_test_handle();
        let watcher = Pid::new();
        let reference = Ref::new();
        assert!(handle.monitored_by().is_empty());
        handle.add_monitored_by(reference, watcher);
        assert_eq!(handle.monitored_by(), vec![(reference, watcher)]);
        assert_eq!(handle.remove_monitored_by(reference), Some(watcher));
        assert_eq!(handle.remove_monitored_by(reference), None);
        assert!(handle.monitored_by().is_empty());
    }

    #[test]
    fn test_terminated() {
        let (handle, _mailbox) = create_test_handle();
        assert!(handle.is_alive());
        assert!(handle.exit_reason().is_none());
        handle.mark_terminated(ExitReason::Normal);
        assert!(!handle.is_alive());
        assert_eq!(handle.exit_reason(), Some(ExitReason::Normal));
    }

    #[test]
    fn test_clones_share_state() {
        let (handle, _mailbox) = create_test_handle();
        let clone = handle.clone();
        assert_eq!(clone.pid(), handle.pid());
        clone.set_trap_exit(true);
        assert!(handle.is_trapping_exits());
        clone.mark_terminated(ExitReason::Normal);
        assert!(!handle.is_alive());
        assert_eq!(handle.exit_reason(), Some(ExitReason::Normal));
    }

    #[test]
    fn test_debug_reports_liveness() {
        let (handle, _mailbox) = create_test_handle();
        assert!(format!("{:?}", handle).contains("alive: true"));
        handle.mark_terminated(ExitReason::Normal);
        assert!(format!("{:?}", handle).contains("alive: false"));
    }
}