use crate::{DefaultKey, Scalar, consideration::*, decision_makers::*, task::*};
use std::{collections::HashMap, hash::Hash};
/// Errors reported by [`Reasoner`] operations.
///
/// The standard traits are derived, so each one is only available when `K` provides it
/// (e.g. `Clone` requires `K: Clone`).
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum ReasonerError<K = DefaultKey> {
    /// A state switch targeted an id that is not registered in the reasoner; carries that id.
    StateDoesNotExists(K),
}

impl<K> std::fmt::Display for ReasonerError<K>
where
    K: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::StateDoesNotExists(key) => write!(f, "state does not exist: {:?}", key),
        }
    }
}

// Lets the error flow through `?` into `Box<dyn std::error::Error>` and error-reporting crates.
impl<K> std::error::Error for ReasonerError<K> where K: std::fmt::Debug {}
/// One option a [`Reasoner`] can switch to: a consideration that scores the state and the
/// task that is driven while the state is active.
pub struct ReasonerState<M = ()> {
    /// Produces this state's score; the reasoner's selector compares these scores.
    consideration: Box<dyn Consideration<M>>,
    /// Receives enter/exit/update/process callbacks while this state is active.
    task: Box<dyn Task<M>>,
}

impl<M> ReasonerState<M> {
    /// Builds a state from concrete consideration and task values, boxing both of them.
    pub fn new<C, T>(consideration: C, task: T) -> Self
    where
        C: Consideration<M> + 'static,
        T: Task<M> + 'static,
    {
        Self::new_raw(Box::new(consideration), Box::new(task))
    }

    /// Builds a state from an already-boxed consideration and task.
    pub fn new_raw(consideration: Box<dyn Consideration<M>>, task: Box<dyn Task<M>>) -> Self {
        Self { consideration, task }
    }
}
impl<M> std::fmt::Debug for ReasonerState<M> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both fields are trait objects without a `Debug` bound, so they cannot be printed.
        // `finish_non_exhaustive` renders `ReasonerState { .. }`, making the omission explicit
        // instead of suggesting the struct has no fields.
        f.debug_struct("ReasonerState").finish_non_exhaustive()
    }
}
/// Strategy a [`Reasoner`] uses to choose which state becomes active.
///
/// `scored_states` pairs every state id with the score its consideration produced for the
/// current step. Returning `None` makes [`Reasoner::process`] return `false` without
/// switching.
pub trait ReasonerStateSelector<M, K>: Send + Sync
where
    K: Clone + Hash + Eq,
{
    /// Picks a state id from `scored_states`, with read access to `memory`.
    fn select_state(&self, memory: &M, scored_states: &[(&K, Scalar)]) -> Option<K>;
}

/// Orders two scores without ever panicking.
///
/// Numbers compare normally. Two NaNs are equal. A NaN compared against a number yields
/// `nan_vs_number` (and the reverse when the NaN is on the right). Callers pass the
/// ordering that makes NaN the *worst* candidate for their search direction.
fn compare_scores(a: Scalar, b: Scalar, nan_vs_number: std::cmp::Ordering) -> std::cmp::Ordering {
    match (a.is_nan(), b.is_nan()) {
        (false, false) => a
            .partial_cmp(&b)
            .expect("non-NaN scores are always comparable"),
        (true, true) => std::cmp::Ordering::Equal,
        (true, false) => nan_vs_number,
        (false, true) => nan_vs_number.reverse(),
    }
}

/// Selects the highest-scoring state.
///
/// Ties resolve to the last tied entry of `scored_states`. NaN scores rank below every
/// number, so they are only chosen when nothing else is available.
pub struct MaxReasonerStateSelector;

impl<M, K> ReasonerStateSelector<M, K> for MaxReasonerStateSelector
where
    K: Clone + Hash + Eq,
{
    fn select_state(&self, _memory: &M, scored_states: &[(&K, Scalar)]) -> Option<K> {
        scored_states
            .iter()
            // NaN compares `Less` so `max_by` never prefers it over a real score.
            .max_by(|(_, a), (_, b)| compare_scores(*a, *b, std::cmp::Ordering::Less))
            .map(|(k, _)| (*k).clone())
    }
}

/// Selects the lowest-scoring state.
///
/// Ties resolve to the first tied entry of `scored_states`. NaN scores rank above every
/// number, so they are only chosen when nothing else is available.
pub struct MinReasonerStateSelector;

impl<M, K> ReasonerStateSelector<M, K> for MinReasonerStateSelector
where
    K: Clone + Hash + Eq,
{
    fn select_state(&self, _memory: &M, scored_states: &[(&K, Scalar)]) -> Option<K> {
        scored_states
            .iter()
            // NaN compares `Greater` so `min_by` never prefers it over a real score.
            .min_by(|(_, a), (_, b)| compare_scores(*a, *b, std::cmp::Ordering::Greater))
            .map(|(k, _)| (*k).clone())
    }
}

/// Selects the state whose score is closest (by absolute difference) to the wrapped target.
///
/// Ties resolve to the first tied entry of `scored_states`. A NaN distance (a NaN score,
/// a NaN target, or `inf - inf`) ranks as the farthest possible.
pub struct ClosestToReasonerStateSelector(pub Scalar);

impl<M, K> ReasonerStateSelector<M, K> for ClosestToReasonerStateSelector
where
    K: Clone + Hash + Eq,
{
    fn select_state(&self, _memory: &M, scored_states: &[(&K, Scalar)]) -> Option<K> {
        let target = self.0;
        scored_states
            .iter()
            .min_by(|(_, a), (_, b)| {
                compare_scores(
                    (*a - target).abs(),
                    (*b - target).abs(),
                    std::cmp::Ordering::Greater,
                )
            })
            .map(|(k, _)| (*k).clone())
    }
}
/// Decision maker that switches between keyed [`ReasonerState`]s.
///
/// On each [`Reasoner::process`] it scores every state's consideration and hands the scores
/// to a [`ReasonerStateSelector`]. `M` is the memory type passed to considerations and tasks;
/// `K` is the state id type.
pub struct Reasoner<M = (), K = DefaultKey>
where
    K: Clone + Hash + Eq,
{
    /// Every state this reasoner can switch between, keyed by id.
    states: HashMap<K, ReasonerState<M>>,
    /// Id of the currently active state; `None` when nothing is active.
    active_state: Option<K>,
    /// Strategy that turns per-state scores into the id to activate.
    state_selector: Box<dyn ReasonerStateSelector<M, K>>,
}

impl<M, K> Reasoner<M, K>
where
    K: Clone + Hash + Eq,
{
    /// Creates a reasoner that activates the highest-scoring state
    /// (via [`MaxReasonerStateSelector`]).
    pub fn new(states: HashMap<K, ReasonerState<M>>) -> Self {
        Self::with_selector(states, MaxReasonerStateSelector)
    }

    /// Creates a reasoner that delegates the choice of state to `state_selector`.
    ///
    /// No state is active until one is activated by [`Self::process`] or
    /// [`Self::change_active_state`].
    pub fn with_selector<SS>(states: HashMap<K, ReasonerState<M>>, state_selector: SS) -> Self
    where
        SS: ReasonerStateSelector<M, K> + 'static,
    {
        let state_selector: Box<dyn ReasonerStateSelector<M, K>> = Box::new(state_selector);
        Self {
            states,
            active_state: None,
            state_selector,
        }
    }

    /// Returns the id of the active state, or `None` if no state is active.
    pub fn active_state(&self) -> Option<&K> {
        self.active_state.as_ref()
    }

    /// Switches the active state to `id`; `None` deactivates everything.
    ///
    /// Returns `Ok(true)` when the switch happened. Returns `Ok(false)` when nothing changed,
    /// either because `id` is already active or because the current task reports itself
    /// locked and `forced` is `false`. On a switch, the outgoing task's `on_exit` runs before
    /// the incoming task's `on_enter`.
    ///
    /// # Errors
    ///
    /// Returns [`ReasonerError::StateDoesNotExists`] when `id` names an unregistered state.
    /// In that case no task callbacks run and the active state is unchanged.
    pub fn change_active_state(
        &mut self,
        id: Option<K>,
        memory: &mut M,
        forced: bool,
    ) -> Result<bool, ReasonerError<K>> {
        if id == self.active_state {
            return Ok(false);
        }
        // Validate the target first so an unknown id never makes the current state exit.
        if let Some(missing) = id.as_ref().filter(|target| !self.states.contains_key(*target)) {
            return Err(ReasonerError::StateDoesNotExists(missing.clone()));
        }
        if let Some(current) = &self.active_state {
            let outgoing = self.states.get_mut(current).unwrap();
            // `forced` is checked first so a forced switch never consults the lock.
            if !forced && outgoing.task.is_locked(memory) {
                return Ok(false);
            }
            outgoing.task.on_exit(memory);
        }
        if let Some(target) = &id {
            let incoming = self.states.get_mut(target).unwrap();
            incoming.task.on_enter(memory);
        }
        self.active_state = id;
        Ok(true)
    }

    /// Runs one decision step.
    ///
    /// Every state's consideration is scored against `memory` and the selector picks a
    /// winner. The result depends on what happens next:
    ///
    /// - Returns `true` if the reasoner switched to the winner (unforced).
    /// - Otherwise (winner already active, current task locked, or winner unknown), the
    ///   active task's `on_process` runs and its result is returned.
    /// - Returns `false` without processing anything when there are no states, when the
    ///   selector returns `None`, or when no state is active.
    pub fn process(&mut self, memory: &mut M) -> bool {
        if self.states.is_empty() {
            return false;
        }
        let scores: Vec<_> = self
            .states
            .iter()
            .map(|(key, state)| (key, state.consideration.score(memory)))
            .collect();
        let Some(winner) = self.state_selector.select_state(memory, &scores) else {
            return false;
        };
        let switched = matches!(
            self.change_active_state(Some(winner), memory, false),
            Ok(true)
        );
        if switched {
            return true;
        }
        match &self.active_state {
            Some(current) => self.states.get_mut(current).unwrap().task.on_process(memory),
            None => false,
        }
    }

    /// Forwards `on_update` to the active state's task; does nothing when no state is active.
    pub fn update(&mut self, memory: &mut M) {
        let Some(current) = &self.active_state else {
            return;
        };
        self.states.get_mut(current).unwrap().task.on_update(memory);
    }
}
impl<M, K> DecisionMaker<M, K> for Reasoner<M, K>
where
    K: Clone + Hash + Eq + Send + Sync,
{
    /// Runs one [`Reasoner::process`] step and reports the id that is active afterwards.
    fn decide(&mut self, memory: &mut M) -> Option<K> {
        // Only the resulting active state matters here, not whether the step did anything.
        let _ = self.process(memory);
        self.active_state.clone()
    }

    /// Forces a switch to `id`, ignoring task locks.
    ///
    /// Returns `true` only when the active state actually changed. An unknown id yields
    /// `false`.
    fn change_mind(&mut self, id: Option<K>, memory: &mut M) -> bool {
        self.change_active_state(id, memory, true).unwrap_or(false)
    }
}
impl<M, K> Task<M> for Reasoner<M, K>
where
    K: Clone + Hash + Eq + Send + Sync,
{
    /// Locked exactly when the active state's task is locked; never locked without one.
    fn is_locked(&self, memory: &M) -> bool {
        self.active_state
            .as_ref()
            .and_then(|id| self.states.get(id))
            .is_some_and(|state| state.task.is_locked(memory))
    }

    /// Clears any inner active state (forced), then runs a decision step so a state can be
    /// chosen right away.
    fn on_enter(&mut self, memory: &mut M) {
        let _ = self.change_active_state(None, memory, true);
        self.process(memory);
    }

    /// Clears the inner active state. The switch is forced, so a locked inner task cannot
    /// stay active.
    fn on_exit(&mut self, memory: &mut M) {
        let _ = self.change_active_state(None, memory, true);
    }

    /// Forwards to [`Reasoner::update`].
    fn on_update(&mut self, memory: &mut M) {
        self.update(memory);
    }

    /// Forwards to [`Reasoner::process`] and returns its result.
    fn on_process(&mut self, memory: &mut M) -> bool {
        self.process(memory)
    }
}
impl<M, K> std::fmt::Debug for Reasoner<M, K>
where
    K: Clone + Hash + Eq + std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `state_selector` is a trait object without a `Debug` bound, so it cannot be printed.
        // `finish_non_exhaustive` appends `..` to signal the omitted field.
        f.debug_struct("Reasoner")
            .field("states", &self.states)
            .field("active_state", &self.active_state)
            .finish_non_exhaustive()
    }
}
/// Fluent builder that collects keyed [`ReasonerState`]s and turns them into a [`Reasoner`].
///
/// The wrapped map is public, so states can also be inserted into it directly.
pub struct ReasonerBuilder<M = (), K = DefaultKey>(pub HashMap<K, ReasonerState<M>>);

impl<M, K> Default for ReasonerBuilder<M, K> {
    /// Creates a builder with no states.
    fn default() -> Self {
        Self(HashMap::new())
    }
}

impl<M, K> ReasonerBuilder<M, K>
where
    K: Clone + Hash + Eq,
{
    /// Registers `state` under `id`, replacing any state already registered under that id.
    ///
    /// Consumes and returns the builder, so the result must be used (or the states are lost).
    #[must_use]
    pub fn state(mut self, id: K, state: ReasonerState<M>) -> Self {
        self.0.insert(id, state);
        self
    }

    /// Builds a [`Reasoner`] that activates the highest-scoring state
    /// (via [`MaxReasonerStateSelector`]).
    #[must_use]
    pub fn build(self) -> Reasoner<M, K> {
        Reasoner::new(self.0)
    }

    /// Builds a [`Reasoner`] that delegates the choice of state to `state_selector`.
    #[must_use]
    pub fn build_with_state_selector<SS>(self, state_selector: SS) -> Reasoner<M, K>
    where
        SS: ReasonerStateSelector<M, K> + 'static,
    {
        Reasoner::with_selector(self.0, state_selector)
    }
}
impl<M, K> std::fmt::Debug for ReasonerBuilder<M, K>
where
K: Clone + Hash + Eq + std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ReasonerBuilder")
.field("states", &self.0)
.finish()
}
}