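//! # Documentation
//!
//! A small graph-driven agent engine: each node kind is bound to an async
//! [`NodeTrait`] implementation, and conditional transitions between nodes
//! decide which node runs next.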
use anyhow::Result;
use async_trait::async_trait;
use petgraph::{graph::NodeIndex, visit::EdgeRef, Graph};
use std::{collections::HashMap, fmt::Debug, hash::Hash, sync::Arc};

/// Internal identifier handed out to each registered node.
type NodeId = usize;
/// Predicate over the context that decides whether a transition may be taken.
type Condition<Ctx> = Box<dyn Fn(&Ctx) -> bool + Sync + Send>;
/// Node implementations, looked up by their [`NodeId`].
type NodeRegistry<Ctx> = HashMap<NodeId, Arc<dyn NodeTrait<Ctx>>>;
/// Marker for types usable as node-kind keys: small, copyable, hashable and
/// printable (typically a fieldless enum).
///
/// `Clone` and `PartialEq` are implied by `Copy` and `Eq` respectively.
pub trait NodeKind: Copy + Eq + Hash + Debug {}

// Every type meeting the bounds is a node kind; no manual impl is needed.
impl<K> NodeKind for K where K: Copy + Eq + Hash + Debug {}


#[async_trait]
pub trait NodeTrait<Ctx>: Send + Sync {
    async fn enter(&self, ctx: &mut Ctx) -> Result<()>;
}

/// Graph-driven state machine that runs async nodes over a shared context.
///
/// `N` is the caller-facing node kind; each kind maps to an internal id, which
/// in turn identifies both a graph vertex and a node implementation.
pub struct AgentEngine<Ctx, N: NodeKind> {
    /// Vertices carry node ids; edges carry transition conditions.
    graph: Graph<NodeId, Condition<Ctx>>,
    /// Node implementations, keyed by id.
    nodes: NodeRegistry<Ctx>,
    /// Vertex that `run` starts from; `None` until a start node is set.
    current: Option<NodeIndex>,
    /// State handed to every node and every transition condition.
    context: Ctx,
    /// Id that the next registered node will receive.
    next_node_id: NodeId,
    /// Lookup from caller-facing node kind to internal id.
    node_map: HashMap<N, NodeId>,
}

impl<Ctx, N> AgentEngine<Ctx, N>
where
    N: NodeKind,
{
    pub fn builder(context: Ctx) -> AgentEngineBuilder<Ctx, N> {
        AgentEngineBuilder::new(context)
    }

    fn new(context: Ctx) -> Self {
        Self {
            graph: Graph::new(),
            nodes: HashMap::new(),
            current: None,
            context,
            next_node_id: 0,
            node_map: HashMap::new(),
        }
    }

    pub fn add_node<T: NodeTrait<Ctx> + 'static>(&mut self, kind: N, node: T) -> NodeId {
        let node_id = self.next_node_id;
        self.next_node_id += 1;

        let _idx = self.graph.add_node(node_id);
        self.nodes.insert(node_id, Arc::new(node));
        self.node_map.insert(kind, node_id);

        node_id
    }

    pub fn add_transition<F>(&mut self, from: N, to: N, condition: F) -> Result<()>
    where
        F: Fn(&Ctx) -> bool + Send + Sync + 'static,
    {
        use anyhow::anyhow;

        let from_id = self.node_map.get(&from).copied().ok_or_else(|| {
            anyhow!("Source node {:?} not found", from)
        })?;

        let to_id = self.node_map.get(&to).copied().ok_or_else(|| {
            anyhow!("Target node {:?} not found", to)
        })?;

        let from_idx = self
            .graph
            .node_indices()
            .find(|idx| self.graph[*idx] == from_id)
            .ok_or_else(|| anyhow!("Source node not found: {:?}", from))?;

        let to_idx = self
            .graph
            .node_indices()
            .find(|idx| self.graph[*idx] == to_id)
            .ok_or_else(|| anyhow!("Target node not found: {:?}", to))?;

        self.graph.add_edge(from_idx, to_idx, Box::new(condition));
        Ok(())
    }

    pub fn set_start_node(&mut self, kind: N) -> Result<()> {
        use anyhow::anyhow;

        let id = self.node_map.get(&kind).copied().ok_or_else(|| {
            anyhow!("Start node {:?} not found", kind)
        })?;

        let start_idx = self
            .graph
            .node_indices()
            .find(|idx| self.graph[*idx] == id)
            .ok_or_else(|| anyhow!("Start node not found: {:?}", kind))?;

        self.current = Some(start_idx);
        Ok(())
    }

    pub async fn run(&mut self) -> Result<()> {
        use anyhow::anyhow;

        let mut current_idx = self
            .current
            .ok_or_else(|| anyhow!("Start node not set, call set_start_node() first"))?;

        loop {
            let node_id = match self.graph.node_weight(current_idx) {
                Some(id) => *id,
                None => break,
            };

            let node = self
                .nodes
                .get(&node_id)
                .ok_or_else(|| anyhow!("Node does not exist: {}", node_id))?;

            node.enter(&mut self.context).await?;

            let next_edge = self
                .graph
                .edges(current_idx)
                .find(|edge| (edge.weight())(&self.context));

            match next_edge {
                Some(edge) => {
                    current_idx = edge.target();
                }
                None => {
                    break;
                }
            }
        }

        self.current = Some(current_idx);
        Ok(())
    }
}

/// Fluent builder for [`AgentEngine`].
pub struct AgentEngineBuilder<Ctx, N: NodeKind> {
    /// Engine being assembled; nodes and transitions are added to it directly.
    engine: AgentEngine<Ctx, N>,
    /// Requested start node kind, resolved only when the engine is built.
    start_node: Option<N>,
}

impl<Ctx, N> AgentEngineBuilder<Ctx, N>
where
    N: NodeKind,
{
    /// Starts a builder around a fresh, empty engine that owns `context`.
    fn new(context: Ctx) -> Self {
        let engine = AgentEngine::new(context);
        Self {
            engine,
            start_node: None,
        }
    }

    /// Registers `node` under `kind` on the engine being built.
    pub fn with_node(mut self, kind: N, node: impl NodeTrait<Ctx> + 'static) -> Self {
        let _ = self.engine.add_node(kind, node);
        self
    }

    /// Adds a conditional transition between two node kinds.
    ///
    /// # Errors
    ///
    /// Fails if `from` or `to` has not yet been added with
    /// [`with_node`](Self::with_node). Nodes must therefore be declared before
    /// the transitions that connect them.
    pub fn with_transition(
        mut self,
        from: N,
        to: N,
        condition: impl Fn(&Ctx) -> bool + Send + Sync + 'static,
    ) -> Result<Self> {
        self.engine
            .add_transition(from, to, condition)
            .map(|()| self)
    }

    /// Records the node kind the built engine should start from.
    ///
    /// The kind is only resolved in [`build`](Self::build), so it may be set
    /// before the corresponding node is added.
    pub fn with_start_node(self, kind: N) -> Self {
        Self {
            start_node: Some(kind),
            ..self
        }
    }

    /// Finishes construction, resolving the configured start node if any.
    ///
    /// Without a configured start node, the engine is returned as-is, and its
    /// `run` reports an error until a start node is set.
    ///
    /// # Errors
    ///
    /// Fails if a start node was configured but its kind was never registered.
    pub fn build(self) -> Result<AgentEngine<Ctx, N>> {
        let Self {
            mut engine,
            start_node,
        } = self;

        match start_node {
            Some(kind) => engine.set_start_node(kind).map(|()| engine),
            None => Ok(engine),
        }
    }
}