wesichain-agent 0.2.1

Rust-native LLM agents & chains with resumable ReAct workflows
Documentation
use crate::phase::{Acting, Completed, Idle, Interrupted, Observing, Thinking};
use crate::{
    validation, AgentError, AgentEvent, ModelAction, PolicyDecision, PolicyEngine, RepromptStrategy,
};
use tokio_util::sync::CancellationToken;

/// Outcome of a tool dispatch, used to choose the final event of a tool step.
#[derive(Debug)]
pub enum ToolDispatchOutcome {
    /// The tool finished; `emit_tool_step_events` reports this as `ToolCompleted`.
    Completed,
    /// The tool failed; `emit_tool_step_events` reports this as `StepFailed`
    /// carrying the wrapped error.
    Failed(AgentError),
}

/// Builds the event sequence for a step that ends without a tool call.
///
/// The returned events are, in order, `StepStarted`, `ModelResponded` and
/// `Completed`, each tagged with `step_id`.
pub fn emit_single_step_events(step_id: u32) -> Vec<AgentEvent> {
    Vec::from([
        AgentEvent::StepStarted { step_id },
        AgentEvent::ModelResponded { step_id },
        AgentEvent::Completed { step_id },
    ])
}

pub fn emit_tool_step_events(step_id: u32, outcome: ToolDispatchOutcome) -> Vec<AgentEvent> {
    let mut events = vec![
        AgentEvent::StepStarted { step_id },
        AgentEvent::ModelResponded { step_id },
        AgentEvent::ToolDispatched { step_id },
    ];

    match outcome {
        ToolDispatchOutcome::Completed => events.push(AgentEvent::ToolCompleted { step_id }),
        ToolDispatchOutcome::Failed(error) => {
            events.push(AgentEvent::StepFailed { step_id, error })
        }
    }

    events
}

/// Returns the events of a tool step up to, but excluding, its final
/// outcome event.
///
/// The sequence is derived from `emit_tool_step_events` so both stay in
/// agreement; only the trailing outcome event is removed.
fn emit_tool_dispatch_events(step_id: u32) -> Vec<AgentEvent> {
    let mut events = emit_tool_step_events(step_id, ToolDispatchOutcome::Completed);
    // Drop the last element (a no-op should the sequence ever be empty).
    let keep = events.len().saturating_sub(1);
    events.truncate(keep);
    events
}

fn emit_tool_failure_event(step_id: u32, error: AgentError) -> AgentEvent {
    emit_tool_step_events(step_id, ToolDispatchOutcome::Failed(error))
        .pop()
        .expect("emit_tool_step_events always emits at least one event")
}

/// Agent loop runtime whose current phase is encoded in the `Phase` type
/// parameter.
///
/// Methods are supplied by phase-specific `impl` blocks, and each transition
/// consumes the runtime and returns it re-tagged with the next phase.
/// `S`, `T` and `P` are carried purely as type parameters; `P` must implement
/// `PolicyEngine` in the phases that consult a policy.
pub struct AgentRuntime<S, T, P, Phase> {
    /// Budget units left; reduced by one whenever a policy decision asks to
    /// consume budget.
    remaining_budget: u32,
    /// Token polled at transition points to detect a cancellation request.
    cancellation: CancellationToken,
    /// Zero-sized marker that ties the type parameters to the struct without
    /// storing any values of those types.
    _marker: std::marker::PhantomData<(S, T, P, Phase)>,
}

/// Manual `Debug` implementation, written without any bounds on the type
/// parameters.
impl<S, T, P, Phase> std::fmt::Debug for AgentRuntime<S, T, P, Phase> {
    /// Prints the remaining budget and whether cancellation has been requested.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cancelled = self.cancellation.is_cancelled();

        let mut builder = f.debug_struct("AgentRuntime");
        builder.field("remaining_budget", &self.remaining_budget);
        builder.field("cancelled", &cancelled);
        builder.finish()
    }
}

/// Next state of the agent loop, wrapping the runtime tagged with the phase
/// it has moved into.
#[derive(Debug)]
pub enum LoopTransition<S, T, P> {
    /// The loop should (re-)enter the thinking phase.
    Thinking {
        /// Runtime tagged with the thinking phase.
        runtime: AgentRuntime<S, T, P, Thinking>,
        /// `Some` when a policy decision asked for a reprompt, `None` otherwise.
        reprompt_strategy: Option<RepromptStrategy>,
    },
    /// The model response validated as a tool call.
    Acting(AgentRuntime<S, T, P, Acting>),
    /// A tool call succeeded and its result is to be observed.
    Observing(AgentRuntime<S, T, P, Observing>),
    /// The model response validated as a final answer.
    Completed(AgentRuntime<S, T, P, Completed>),
    /// Cancellation was requested or a policy decided to interrupt.
    Interrupted(AgentRuntime<S, T, P, Interrupted>),
}

/// A loop transition paired with the events emitted while producing it.
pub type TransitionWithEvents<S, T, P> = (LoopTransition<S, T, P>, Vec<AgentEvent>);

/// Helpers available in every phase of the runtime.
impl<S, T, P, Phase> AgentRuntime<S, T, P, Phase> {
    /// Validates a model response for `step_id` against `allowed_tools`.
    ///
    /// This is a thin wrapper that forwards its arguments unchanged to
    /// `validation::validate_model_action` and returns its result.
    pub fn validate_model_action(
        step_id: u32,
        response: wesichain_core::LlmResponse,
        allowed_tools: &[String],
    ) -> Result<ModelAction, AgentError> {
        validation::validate_model_action(step_id, response, allowed_tools)
    }

    /// Returns the number of budget units still available.
    pub fn remaining_budget(&self) -> u32 {
        self.remaining_budget
    }

    /// Maps any transport-level error to `AgentError::ModelTransport`.
    ///
    /// The original error value is discarded.
    fn map_model_transport_error(_error: wesichain_core::WesichainError) -> AgentError {
        AgentError::ModelTransport
    }

    /// Reports whether the cancellation token has been cancelled.
    fn cancellation_is_requested(&self) -> bool {
        let token = &self.cancellation;
        token.is_cancelled()
    }

    /// Re-tags the runtime with `NextPhase`, carrying over the budget and the
    /// cancellation token unchanged.
    fn transition<NextPhase>(self) -> AgentRuntime<S, T, P, NextPhase> {
        let AgentRuntime {
            remaining_budget,
            cancellation,
            _marker: _,
        } = self;

        AgentRuntime {
            remaining_budget,
            cancellation,
            _marker: std::marker::PhantomData,
        }
    }

    /// Spends one budget unit when `consume` is `true`.
    ///
    /// With `consume == false` the runtime is returned untouched.
    ///
    /// # Errors
    ///
    /// Returns `AgentError::BudgetExceeded` when a unit must be spent but the
    /// remaining budget is already zero.
    fn consume_budget(mut self, consume: bool) -> Result<Self, AgentError> {
        if consume {
            // `checked_sub` yields `None` exactly when the budget is zero.
            self.remaining_budget = self
                .remaining_budget
                .checked_sub(1)
                .ok_or(AgentError::BudgetExceeded)?;
        }

        Ok(self)
    }
}

impl<S, T, P> AgentRuntime<S, T, P, Idle> {
    pub fn new() -> Self {
        Self::with_budget(u32::MAX)
    }

    pub fn with_budget(remaining_budget: u32) -> Self {
        Self::with_budget_and_cancellation(remaining_budget, CancellationToken::new())
    }

    pub fn with_cancellation(cancellation: CancellationToken) -> Self {
        Self::with_budget_and_cancellation(u32::MAX, cancellation)
    }

    pub fn with_budget_and_cancellation(
        remaining_budget: u32,
        cancellation: CancellationToken,
    ) -> Self {
        Self {
            remaining_budget,
            cancellation,
            _marker: std::marker::PhantomData,
        }
    }

    pub fn think(self) -> AgentRuntime<S, T, P, Thinking> {
        self.transition()
    }

    pub fn begin_thinking(self) -> LoopTransition<S, T, P> {
        if self.cancellation_is_requested() {
            return LoopTransition::Interrupted(self.transition());
        }

        LoopTransition::Thinking {
            runtime: self.transition(),
            reprompt_strategy: None,
        }
    }
}

/// `Default` for an idle runtime; equivalent to calling `AgentRuntime::new`.
impl<S, T, P> Default for AgentRuntime<S, T, P, Idle> {
    /// Returns the runtime produced by `Self::new()`.
    fn default() -> Self {
        Self::new()
    }
}

/// Transitions available while the runtime is waiting on the model.
impl<S, T, P> AgentRuntime<S, T, P, Thinking>
where
    P: PolicyEngine,
{
    /// Moves the runtime into the acting phase.
    pub fn act(self) -> AgentRuntime<S, T, P, Acting> {
        self.transition::<Acting>()
    }

    /// Moves the runtime into the completed phase.
    pub fn complete(self) -> AgentRuntime<S, T, P, Completed> {
        self.transition::<Completed>()
    }

    /// Moves the runtime into the interrupted phase.
    pub fn interrupt(self) -> AgentRuntime<S, T, P, Interrupted> {
        self.transition::<Interrupted>()
    }

    /// Same as `on_model_response_with_events`, with the events discarded.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `on_model_response_with_events`.
    pub fn on_model_response(
        self,
        step_id: u32,
        response: wesichain_core::LlmResponse,
        allowed_tools: &[String],
    ) -> Result<LoopTransition<S, T, P>, AgentError> {
        let (transition, _events) =
            self.on_model_response_with_events(step_id, response, allowed_tools)?;
        Ok(transition)
    }

    /// Decides the next phase from a model response and reports the events
    /// emitted along the way.
    ///
    /// * A pending cancellation yields `Interrupted` with no events.
    /// * A response that validates as a tool call yields `Acting` together
    ///   with the tool dispatch events.
    /// * A response that validates as a final answer yields `Completed`
    ///   together with the single-step events.
    /// * A validation failure is handed to the policy via `P::on_model_error`.
    ///
    /// # Errors
    ///
    /// Returns the validation error when the policy decides to fail, or
    /// `AgentError::BudgetExceeded` when the policy asks to consume budget
    /// and none is left.
    pub fn on_model_response_with_events(
        self,
        step_id: u32,
        response: wesichain_core::LlmResponse,
        allowed_tools: &[String],
    ) -> Result<TransitionWithEvents<S, T, P>, AgentError> {
        if self.cancellation_is_requested() {
            return Ok((LoopTransition::Interrupted(self.interrupt()), Vec::new()));
        }

        // Split validation from dispatch so the error path exits early.
        let action = match Self::validate_model_action(step_id, response, allowed_tools) {
            Ok(action) => action,
            Err(error) => return self.on_model_error_with_events(error),
        };

        let outcome = match action {
            ModelAction::ToolCall { .. } => (
                LoopTransition::Acting(self.act()),
                emit_tool_dispatch_events(step_id),
            ),
            ModelAction::FinalAnswer { .. } => (
                LoopTransition::Completed(self.complete()),
                emit_single_step_events(step_id),
            ),
        };

        Ok(outcome)
    }

    /// Same as `on_model_transport_error_with_events`, with the events
    /// discarded.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by
    /// `on_model_transport_error_with_events`.
    pub fn on_model_transport_error(
        self,
        error: wesichain_core::WesichainError,
    ) -> Result<LoopTransition<S, T, P>, AgentError> {
        let (transition, _events) = self.on_model_transport_error_with_events(error)?;
        Ok(transition)
    }

    /// Handles a transport failure while talking to the model.
    ///
    /// A pending cancellation yields `Interrupted` with no events and the
    /// error is dropped. Otherwise the error is mapped to
    /// `AgentError::ModelTransport` and handed to the policy via
    /// `P::on_model_error`.
    ///
    /// # Errors
    ///
    /// Returns `AgentError::ModelTransport` when the policy decides to fail,
    /// or `AgentError::BudgetExceeded` when the policy asks to consume budget
    /// and none is left.
    pub fn on_model_transport_error_with_events(
        self,
        error: wesichain_core::WesichainError,
    ) -> Result<TransitionWithEvents<S, T, P>, AgentError> {
        if self.cancellation_is_requested() {
            Ok((LoopTransition::Interrupted(self.interrupt()), Vec::new()))
        } else {
            let mapped = Self::map_model_transport_error(error);
            self.on_model_error_with_events(mapped)
        }
    }

    /// Asks the policy how to react to a model-side `error` and applies the
    /// decision, starting from an empty event list.
    fn on_model_error_with_events(
        self,
        error: AgentError,
    ) -> Result<TransitionWithEvents<S, T, P>, AgentError> {
        let no_events = Vec::new();
        let decision = P::on_model_error(&error);
        self.apply_policy_decision_with_events(error, decision, no_events)
    }

    /// Turns a policy `decision` into the next transition, passing `events`
    /// through unchanged.
    ///
    /// * `Fail` returns `error`.
    /// * `Interrupt` yields `Interrupted`.
    /// * `Retry` and `Reprompt` yield `Thinking`, after spending budget when
    ///   requested; only `Reprompt` sets a reprompt strategy.
    ///
    /// # Errors
    ///
    /// Returns `error` for `Fail`, or `AgentError::BudgetExceeded` when
    /// budget must be spent but none is left.
    fn apply_policy_decision_with_events(
        self,
        error: AgentError,
        decision: PolicyDecision,
        events: Vec<AgentEvent>,
    ) -> Result<TransitionWithEvents<S, T, P>, AgentError> {
        // Terminal decisions exit here; the two "think again" decisions only
        // differ in whether they carry a reprompt strategy.
        let (reprompt_strategy, consume_budget) = match decision {
            PolicyDecision::Fail => return Err(error),
            PolicyDecision::Interrupt => {
                return Ok((LoopTransition::Interrupted(self.interrupt()), events));
            }
            PolicyDecision::Retry { consume_budget } => (None, consume_budget),
            PolicyDecision::Reprompt {
                strategy,
                consume_budget,
            } => (Some(strategy), consume_budget),
        };

        let runtime = self.consume_budget(consume_budget)?;
        let transition = LoopTransition::Thinking {
            runtime,
            reprompt_strategy,
        };

        Ok((transition, events))
    }
}

impl<S, T, P> AgentRuntime<S, T, P, Acting>
where
    P: PolicyEngine,
{
    pub fn observe(self) -> AgentRuntime<S, T, P, Observing> {
        self.transition()
    }

    pub fn on_tool_success(self) -> LoopTransition<S, T, P> {
        if self.cancellation_is_requested() {
            return LoopTransition::Interrupted(self.interrupt());
        }

        LoopTransition::Observing(self.observe())
    }

    pub fn on_tool_success_with_events(self, step_id: u32) -> TransitionWithEvents<S, T, P> {
        match self.on_tool_success() {
            LoopTransition::Observing(runtime) => (
                LoopTransition::Observing(runtime),
                vec![AgentEvent::ToolCompleted { step_id }],
            ),
            LoopTransition::Interrupted(runtime) => (
                LoopTransition::Interrupted(runtime),
                vec![emit_tool_failure_event(
                    step_id,
                    AgentError::PolicyRuntimeViolation,
                )],
            ),
            _ => unreachable!("on_tool_success only returns Observing or Interrupted"),
        }
    }

    pub fn interrupt(self) -> AgentRuntime<S, T, P, Interrupted> {
        self.transition()
    }

    pub fn on_tool_error(self, error: AgentError) -> Result<LoopTransition<S, T, P>, AgentError> {
        self.on_tool_error_internal(None, error)
            .map(|(transition, _events)| transition)
    }

    pub fn on_tool_error_with_events(
        self,
        step_id: u32,
        error: AgentError,
    ) -> Result<TransitionWithEvents<S, T, P>, AgentError> {
        self.on_tool_error_internal(Some(step_id), error)
    }

    fn on_tool_error_internal(
        self,
        step_id: Option<u32>,
        error: AgentError,
    ) -> Result<TransitionWithEvents<S, T, P>, AgentError> {
        let decision = P::on_tool_error(&error);
        let mut events = Vec::new();
        if let Some(step_id) = step_id {
            events.push(emit_tool_failure_event(step_id, error.clone()));
        }

        match decision {
            PolicyDecision::Fail => Err(error),
            PolicyDecision::Interrupt => {
                Ok((LoopTransition::Interrupted(self.interrupt()), events))
            }
            PolicyDecision::Retry { consume_budget } => {
                let runtime = self.consume_budget(consume_budget)?;
                Ok((
                    LoopTransition::Thinking {
                        runtime: runtime.transition(),
                        reprompt_strategy: None,
                    },
                    events,
                ))
            }
            PolicyDecision::Reprompt {
                strategy,
                consume_budget,
            } => {
                let runtime = self.consume_budget(consume_budget)?;
                Ok((
                    LoopTransition::Thinking {
                        runtime: runtime.transition(),
                        reprompt_strategy: Some(strategy),
                    },
                    events,
                ))
            }
        }
    }
}

/// Transitions available after a tool result has been observed.
impl<S, T, P> AgentRuntime<S, T, P, Observing> {
    /// Returns to the thinking phase, keeping the budget and cancellation
    /// token unchanged.
    pub fn think(self) -> AgentRuntime<S, T, P, Thinking> {
        self.transition::<Thinking>()
    }
}