use polaris_system::param::SystemContext;
use polaris_system::resource::Output;
use std::any::{TypeId, type_name};
use std::fmt;
use std::marker::PhantomData;
/// Error produced when a predicate or discriminator cannot be evaluated
/// against a system context.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PredicateError {
    /// The output type the predicate/discriminator reads could not be obtained.
    OutputNotFound {
        /// Name of the output type that was looked up. In this module it comes from
        /// `std::any::type_name`, whose exact text is meant for diagnostics only and
        /// is not guaranteed to be stable.
        type_name: &'static str,
    },
    /// A context-level failure, carried as a human-readable message.
    /// (Not constructed by the implementations in this module.)
    ContextError(String),
}

impl fmt::Display for PredicateError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::OutputNotFound { type_name } => write!(f, "output not found: {type_name}"),
            Self::ContextError(msg) => write!(f, "context error: {msg}"),
        }
    }
}

impl std::error::Error for PredicateError {}
/// Type-erased boolean predicate evaluated against a [`SystemContext`].
///
/// The concrete input type is hidden behind `dyn`; it is exposed only through
/// [`input_type_id`](Self::input_type_id) and
/// [`input_type_name`](Self::input_type_name).
pub trait ErasedPredicate: Send + Sync {
    /// [`TypeId`] of the input type this predicate reads.
    fn input_type_id(&self) -> TypeId;

    /// Human-readable name of the input type this predicate reads.
    fn input_type_name(&self) -> &'static str;

    /// Evaluates the predicate against `ctx`.
    ///
    /// # Errors
    ///
    /// Returns a [`PredicateError`] when the predicate cannot be evaluated
    /// (for example, its input is not present in `ctx`).
    fn evaluate(&self, ctx: &SystemContext<'_>) -> Result<bool, PredicateError>;
}

/// Debug output lists only the erased input type's name.
impl fmt::Debug for dyn ErasedPredicate {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("ErasedPredicate");
        builder.field("input_type", &self.input_type_name());
        builder.finish()
    }
}

/// Owned, heap-allocated [`ErasedPredicate`] trait object.
pub type BoxedPredicate = Box<dyn ErasedPredicate>;
/// Type-erased key selector evaluated against a [`SystemContext`].
///
/// Produces a `&'static str` key from the context; how that key is consumed is
/// up to the caller. As with [`ErasedPredicate`], the concrete input type is
/// exposed only through [`input_type_id`](Self::input_type_id) and
/// [`input_type_name`](Self::input_type_name).
pub trait ErasedDiscriminator: Send + Sync {
    /// [`TypeId`] of the input type this discriminator reads.
    fn input_type_id(&self) -> TypeId;

    /// Human-readable name of the input type this discriminator reads.
    fn input_type_name(&self) -> &'static str;

    /// Computes the key for `ctx`.
    ///
    /// # Errors
    ///
    /// Returns a [`PredicateError`] when the key cannot be computed
    /// (for example, the input is not present in `ctx`).
    fn discriminate(&self, ctx: &SystemContext<'_>) -> Result<&'static str, PredicateError>;
}

/// Debug output lists only the erased input type's name.
impl fmt::Debug for dyn ErasedDiscriminator {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("ErasedDiscriminator");
        builder.field("input_type", &self.input_type_name());
        builder.finish()
    }
}

/// Owned, heap-allocated [`ErasedDiscriminator`] trait object.
pub type BoxedDiscriminator = Box<dyn ErasedDiscriminator>;
/// Typed boolean check over the `T` output held by a [`SystemContext`].
///
/// The wrapped closure receives `&T`. Evaluation goes through the
/// [`ErasedPredicate`] impl below, which first fetches the `T` output from the
/// context.
pub struct Predicate<T, F> {
    /// Closure applied to the fetched output.
    check: F,
    /// Ties the predicate to `T` without storing one; the `fn() -> T` form keeps
    /// `Send`/`Sync` independent of `T`.
    _input: PhantomData<fn() -> T>,
}

impl<T, F> Predicate<T, F>
where
    T: Output,
    F: Fn(&T) -> bool + Send + Sync + 'static,
{
    /// Wraps `func` as a predicate over outputs of type `T`.
    #[must_use]
    pub fn new(func: F) -> Self {
        Predicate {
            check: func,
            _input: PhantomData,
        }
    }
}

impl<T, F> ErasedPredicate for Predicate<T, F>
where
    T: Output,
    F: Fn(&T) -> bool + Send + Sync + 'static,
{
    /// Fetches the `T` output from `ctx` and applies the wrapped closure to it.
    ///
    /// # Errors
    ///
    /// Returns [`PredicateError::OutputNotFound`] (naming `T`) whenever
    /// `ctx.get_output::<T>()` fails; the lookup's own error value is discarded.
    fn evaluate(&self, ctx: &SystemContext<'_>) -> Result<bool, PredicateError> {
        match ctx.get_output::<T>() {
            Ok(output) => Ok((self.check)(&output)),
            Err(_) => Err(PredicateError::OutputNotFound {
                type_name: type_name::<T>(),
            }),
        }
    }

    fn input_type_id(&self) -> TypeId {
        TypeId::of::<T>()
    }

    fn input_type_name(&self) -> &'static str {
        type_name::<T>()
    }
}

/// Shows only the input type name; `F` has no `Debug` bound, so the closure
/// itself is not printed.
impl<T, F> fmt::Debug for Predicate<T, F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let input = type_name::<T>();
        f.debug_struct("Predicate")
            .field("input_type", &input)
            .finish()
    }
}
/// Typed key selector over the `T` output held by a [`SystemContext`].
///
/// The wrapped closure receives `&T` and returns a `&'static str` key.
/// Evaluation goes through the [`ErasedDiscriminator`] impl below, which first
/// fetches the `T` output from the context.
pub struct Discriminator<T, F> {
    /// Closure that maps the fetched output to a key.
    select: F,
    /// Ties the discriminator to `T` without storing one; the `fn() -> T` form
    /// keeps `Send`/`Sync` independent of `T`.
    _input: PhantomData<fn() -> T>,
}

impl<T, F> Discriminator<T, F>
where
    T: Output,
    F: Fn(&T) -> &'static str + Send + Sync + 'static,
{
    /// Wraps `func` as a discriminator over outputs of type `T`.
    #[must_use]
    pub fn new(func: F) -> Self {
        Discriminator {
            select: func,
            _input: PhantomData,
        }
    }
}

impl<T, F> ErasedDiscriminator for Discriminator<T, F>
where
    T: Output,
    F: Fn(&T) -> &'static str + Send + Sync + 'static,
{
    /// Fetches the `T` output from `ctx` and returns the key chosen by the
    /// wrapped closure.
    ///
    /// # Errors
    ///
    /// Returns [`PredicateError::OutputNotFound`] (naming `T`) whenever
    /// `ctx.get_output::<T>()` fails; the lookup's own error value is discarded.
    fn discriminate(&self, ctx: &SystemContext<'_>) -> Result<&'static str, PredicateError> {
        match ctx.get_output::<T>() {
            Ok(output) => Ok((self.select)(&output)),
            Err(_) => Err(PredicateError::OutputNotFound {
                type_name: type_name::<T>(),
            }),
        }
    }

    fn input_type_id(&self) -> TypeId {
        TypeId::of::<T>()
    }

    fn input_type_name(&self) -> &'static str {
        type_name::<T>()
    }
}

/// Shows only the input type name; `F` has no `Debug` bound, so the closure
/// itself is not printed.
impl<T, F> fmt::Debug for Discriminator<T, F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let input = type_name::<T>();
        f.debug_struct("Discriminator")
            .field("input_type", &input)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Output fixture read by the predicate tests.
    #[derive(Debug, Clone)]
    struct TestOutput {
        value: i32,
        done: bool,
    }

    /// Output fixture read by the discriminator tests.
    #[derive(Debug, Clone)]
    struct RouterOutput {
        action: &'static str,
    }

    // --- Predicate -------------------------------------------------------

    #[test]
    fn predicate_evaluate_true() {
        let above_five = Predicate::<TestOutput, _>::new(|out| out.value > 5);
        let mut ctx = SystemContext::new();
        ctx.insert_output(TestOutput { value: 10, done: false });

        let verdict = above_five.evaluate(&ctx).expect("TestOutput was inserted");
        assert!(verdict);
    }

    #[test]
    fn predicate_evaluate_false() {
        let above_five = Predicate::<TestOutput, _>::new(|out| out.value > 5);
        let mut ctx = SystemContext::new();
        ctx.insert_output(TestOutput { value: 3, done: false });

        let verdict = above_five.evaluate(&ctx).expect("TestOutput was inserted");
        assert!(!verdict);
    }

    #[test]
    fn predicate_missing_output() {
        let always = Predicate::<TestOutput, _>::new(|_| true);
        let empty = SystemContext::new();

        let outcome = always.evaluate(&empty);
        assert!(matches!(outcome, Err(PredicateError::OutputNotFound { .. })));
    }

    #[test]
    fn boxed_predicate() {
        let is_done: BoxedPredicate = Box::new(Predicate::<TestOutput, _>::new(|out| out.done));
        let mut ctx = SystemContext::new();
        ctx.insert_output(TestOutput { value: 0, done: true });

        assert!(is_done.evaluate(&ctx).expect("TestOutput was inserted"));
    }

    // --- Discriminator ---------------------------------------------------

    #[test]
    fn discriminator_returns_key() {
        let route = Discriminator::<RouterOutput, _>::new(|out| out.action);
        let mut ctx = SystemContext::new();
        ctx.insert_output(RouterOutput { action: "tool" });

        let key = route.discriminate(&ctx).expect("RouterOutput was inserted");
        assert_eq!(key, "tool");
    }

    #[test]
    fn discriminator_different_keys() {
        let route = Discriminator::<RouterOutput, _>::new(|out| out.action);
        let mut ctx = SystemContext::new();

        // Re-inserting the output must change the key seen on the next call.
        for expected in ["respond", "clarify"] {
            ctx.insert_output(RouterOutput { action: expected });
            let key = route.discriminate(&ctx).expect("RouterOutput was inserted");
            assert_eq!(key, expected);
        }
    }

    #[test]
    fn discriminator_missing_output() {
        let constant = Discriminator::<RouterOutput, _>::new(|_| "test");
        let empty = SystemContext::new();

        let outcome = constant.discriminate(&empty);
        assert!(matches!(outcome, Err(PredicateError::OutputNotFound { .. })));
    }

    #[test]
    fn boxed_discriminator() {
        let route: BoxedDiscriminator =
            Box::new(Discriminator::<RouterOutput, _>::new(|out| out.action));
        let mut ctx = SystemContext::new();
        ctx.insert_output(RouterOutput { action: "agent" });

        let key = route.discriminate(&ctx).expect("RouterOutput was inserted");
        assert_eq!(key, "agent");
    }
}