use crate::SendError;
use crate::mailbox::Mailbox;
use crate::process_handle::{ProcessHandle, ProcessState};
use crate::registry::ProcessRegistry;
use starlang_core::{ExitReason, Pid, Ref, SystemMessage, Term};
use std::sync::{Arc, RwLock};
use std::time::Duration;
/// Execution context of a single process.
///
/// Bundles the process's identity, the receiving side of its mailbox, its
/// lock-protected shared state, and the registry used to reach other
/// processes.
pub struct Context {
    /// Identifier of the process this context belongs to.
    pid: Pid,
    /// Mailbox the process receives its messages from.
    mailbox: Mailbox,
    /// Shared process state (links, monitors, trap-exit flag, ...),
    /// guarded by a read/write lock.
    state: Arc<RwLock<ProcessState>>,
    /// Registry used to look up processes, route messages and resolve names.
    registry: ProcessRegistry,
}
impl Context {
pub fn new(
pid: Pid,
mailbox: Mailbox,
state: Arc<RwLock<ProcessState>>,
registry: ProcessRegistry,
) -> Self {
Self {
pid,
mailbox,
state,
registry,
}
}
pub fn pid(&self) -> Pid {
self.pid
}
pub async fn recv(&mut self) -> Option<Vec<u8>> {
self.mailbox.recv().await.map(|e| e.data)
}
pub async fn recv_timeout(&mut self, timeout: Duration) -> Result<Option<Vec<u8>>, ()> {
self.mailbox
.recv_timeout(timeout)
.await
.map(|opt| opt.map(|e| e.data))
}
pub fn try_recv(&mut self) -> Option<Vec<u8>> {
self.mailbox.try_recv().ok().map(|e| e.data)
}
pub fn send_raw(&self, pid: Pid, data: Vec<u8>) -> Result<(), SendError> {
self.registry.send_raw(pid, data)
}
pub fn send<M: Term>(&self, pid: Pid, msg: &M) -> Result<(), SendError> {
self.registry.send(pid, msg)
}
pub fn set_trap_exit(&self, trap: bool) -> bool {
let mut state = self.state.write().unwrap();
let prev = state.trap_exit;
state.trap_exit = trap;
prev
}
pub fn is_trapping_exits(&self) -> bool {
let state = self.state.read().unwrap();
state.trap_exit
}
pub fn link(&self, other: Pid) -> Result<(), SendError> {
{
let mut state = self.state.write().unwrap();
state.links.insert(other);
}
if let Some(other_handle) = self.registry.get(other) {
other_handle.add_link(self.pid);
Ok(())
} else {
let mut state = self.state.write().unwrap();
state.links.remove(&other);
Err(SendError::ProcessNotFound(other))
}
}
pub fn unlink(&self, other: Pid) {
{
let mut state = self.state.write().unwrap();
state.links.remove(&other);
}
if let Some(other_handle) = self.registry.get(other) {
other_handle.remove_link(self.pid);
}
}
pub fn monitor(&self, target: Pid) -> Result<Ref, SendError> {
let reference = Ref::new();
{
let mut state = self.state.write().unwrap();
state.monitors.insert(reference, target);
}
if let Some(target_handle) = self.registry.get(target) {
target_handle.add_monitored_by(reference, self.pid);
Ok(reference)
} else {
let mut state = self.state.write().unwrap();
state.monitors.remove(&reference);
let down = SystemMessage::down(reference, target, ExitReason::error("noproc"));
let _ = self.registry.send(self.pid, &down);
Ok(reference)
}
}
pub fn demonitor(&self, reference: Ref) {
let target = {
let mut state = self.state.write().unwrap();
state.monitors.remove(&reference)
};
if let Some(target_pid) = target
&& let Some(target_handle) = self.registry.get(target_pid)
{
target_handle.remove_monitored_by(reference);
}
}
pub fn exit(&self, target: Pid, reason: ExitReason) -> Result<(), SendError> {
if let Some(handle) = self.registry.get(target) {
if reason.is_killed() {
handle.mark_terminated(reason);
} else if handle.is_trapping_exits() {
let exit_msg = SystemMessage::exit(self.pid, reason);
handle.send(&exit_msg)?;
} else if reason.is_abnormal() {
handle.mark_terminated(reason);
}
Ok(())
} else {
Err(SendError::ProcessNotFound(target))
}
}
pub fn whereis(&self, name: &str) -> Option<Pid> {
self.registry.whereis(name)
}
pub fn register(&self, name: String) -> bool {
self.registry.register_name(name, self.pid)
}
pub fn unregister(&self, name: &str) -> Option<Pid> {
self.registry.unregister_name(name)
}
pub fn is_alive(&self, pid: Pid) -> bool {
self.registry
.get(pid)
.map(|h| h.is_alive())
.unwrap_or(false)
}
pub(crate) fn handle(&self) -> ProcessHandle {
self.registry.get(self.pid).unwrap()
}
}
/// Debug output shows only the pid; the mailbox, state and registry are omitted.
impl std::fmt::Debug for Context {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("Context");
        builder.field("pid", &self.pid);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mailbox::MailboxSender;

    /// Registers a fresh process in `registry` and returns its context.
    ///
    /// Also returns a sender for injecting messages into the process's mailbox.
    /// The registered handle shares the same state `Arc` as the context.
    fn create_test_context(registry: &ProcessRegistry) -> (Context, MailboxSender) {
        let pid = Pid::new();
        let (mailbox, sender) = Mailbox::new();
        let state = Arc::new(RwLock::new(ProcessState::new(pid)));
        let handle = ProcessHandle::new(pid, sender.clone(), state.clone(), None);
        registry.register(handle);
        let ctx = Context::new(pid, mailbox, state, registry.clone());
        (ctx, sender)
    }

    #[test]
    fn test_context_pid() {
        let registry = ProcessRegistry::new();
        let (ctx, _sender) = create_test_context(&registry);
        assert!(ctx.pid().is_local());
    }

    #[test]
    fn test_trap_exit() {
        let registry = ProcessRegistry::new();
        let (ctx, _sender) = create_test_context(&registry);
        assert!(!ctx.is_trapping_exits());
        let prev = ctx.set_trap_exit(true);
        assert!(!prev);
        assert!(ctx.is_trapping_exits());
    }

    #[tokio::test]
    async fn test_recv() {
        let registry = ProcessRegistry::new();
        let (mut ctx, sender) = create_test_context(&registry);
        sender
            .send(crate::mailbox::Envelope::new(vec![1, 2, 3]))
            .unwrap();
        let msg = ctx.recv().await.unwrap();
        assert_eq!(msg, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_try_recv_empty_mailbox() {
        let registry = ProcessRegistry::new();
        let (mut ctx, _sender) = create_test_context(&registry);
        assert!(ctx.try_recv().is_none());
    }

    #[test]
    fn test_link() {
        let registry = ProcessRegistry::new();
        let (ctx1, _sender1) = create_test_context(&registry);
        let (ctx2, _sender2) = create_test_context(&registry);
        ctx1.link(ctx2.pid()).unwrap();
        let state1 = ctx1.state.read().unwrap();
        assert!(state1.links.contains(&ctx2.pid()));
        let state2 = ctx2.state.read().unwrap();
        assert!(state2.links.contains(&ctx1.pid()));
    }

    #[test]
    fn test_link_to_unknown_process_fails_and_rolls_back() {
        let registry = ProcessRegistry::new();
        let (ctx, _sender) = create_test_context(&registry);
        let missing = Pid::new();
        let result = ctx.link(missing);
        assert!(matches!(&result, Err(SendError::ProcessNotFound(pid)) if *pid == missing));
        assert!(!ctx.state.read().unwrap().links.contains(&missing));
    }

    #[test]
    fn test_unlink() {
        let registry = ProcessRegistry::new();
        let (ctx1, _sender1) = create_test_context(&registry);
        let (ctx2, _sender2) = create_test_context(&registry);
        ctx1.link(ctx2.pid()).unwrap();
        ctx1.unlink(ctx2.pid());
        let state1 = ctx1.state.read().unwrap();
        assert!(!state1.links.contains(&ctx2.pid()));
        let state2 = ctx2.state.read().unwrap();
        assert!(!state2.links.contains(&ctx1.pid()));
    }

    #[test]
    fn test_monitor() {
        let registry = ProcessRegistry::new();
        let (ctx1, _sender1) = create_test_context(&registry);
        let (ctx2, _sender2) = create_test_context(&registry);
        let reference = ctx1.monitor(ctx2.pid()).unwrap();
        let state1 = ctx1.state.read().unwrap();
        assert_eq!(state1.monitors.get(&reference), Some(&ctx2.pid()));
        let state2 = ctx2.state.read().unwrap();
        assert_eq!(state2.monitored_by.get(&reference), Some(&ctx1.pid()));
    }

    #[test]
    fn test_demonitor() {
        let registry = ProcessRegistry::new();
        let (ctx1, _sender1) = create_test_context(&registry);
        let (ctx2, _sender2) = create_test_context(&registry);
        let reference = ctx1.monitor(ctx2.pid()).unwrap();
        ctx1.demonitor(reference);
        let state1 = ctx1.state.read().unwrap();
        assert!(!state1.monitors.contains_key(&reference));
        let state2 = ctx2.state.read().unwrap();
        assert!(!state2.monitored_by.contains_key(&reference));
    }

    #[test]
    fn test_demonitor_unknown_reference_is_noop() {
        let registry = ProcessRegistry::new();
        let (ctx, _sender) = create_test_context(&registry);
        let reference = Ref::new();
        ctx.demonitor(reference);
        assert!(!ctx.state.read().unwrap().monitors.contains_key(&reference));
    }

    #[test]
    fn test_exit_unknown_process() {
        let registry = ProcessRegistry::new();
        let (ctx, _sender) = create_test_context(&registry);
        let missing = Pid::new();
        let result = ctx.exit(missing, ExitReason::error("test"));
        assert!(matches!(&result, Err(SendError::ProcessNotFound(pid)) if *pid == missing));
    }

    #[test]
    fn test_is_alive_unknown_pid() {
        let registry = ProcessRegistry::new();
        let (ctx, _sender) = create_test_context(&registry);
        assert!(!ctx.is_alive(Pid::new()));
    }

    #[test]
    fn test_register_name() {
        let registry = ProcessRegistry::new();
        let (ctx, _sender) = create_test_context(&registry);
        assert!(ctx.register("test_proc".to_string()));
        assert_eq!(ctx.whereis("test_proc"), Some(ctx.pid()));
        assert!(!ctx.register("test_proc".to_string()));
    }

    #[test]
    fn test_unregister_name() {
        let registry = ProcessRegistry::new();
        let (ctx, _sender) = create_test_context(&registry);
        assert!(ctx.register("test_unreg".to_string()));
        assert_eq!(ctx.unregister("test_unreg"), Some(ctx.pid()));
        assert_eq!(ctx.whereis("test_unreg"), None);
    }
}