use crate::provider::FinishReason;
use super::run::RunStatus;
use serde::{Deserialize, Serialize};
/// The phases a single agent turn moves through.
///
/// `TurnTransition::resolve` maps a phase plus the `TurnOutcome` it produced to
/// the next `TurnAction`, and `TurnTransition::status_for` maps each phase to a
/// coarser `RunStatus`.
///
/// NOTE: this enum derives serde. The derived impls pass each variant's index
/// alongside its name, and index-based formats may encode that index, so
/// reordering variants can break previously serialized data. Append new
/// variants rather than reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TurnState {
/// Success advances to `BuildingContext`; `PolicyExceeded` yields a failure
/// carrying its reason. `Persisting` success loops back to this state.
CheckingPolicy,
/// Success advances to `CallingModel`; `PipelineError` fails with
/// "context build failed".
BuildingContext,
/// Success advances to `ProcessingResponse`; `ContextOverflow` requests
/// `CompactAndRetry`; `ProviderError` fails with "provider error".
CallingModel,
/// Resolved with the same rules as `BuildingContext`, and reported as
/// `RunStatus::BuildingContext`.
Compacting,
/// Success advances to `ExecutingTools`; `OutputTruncated` returns to
/// `CallingModel`; `NoToolCalls`/`NotToolCalls` break the loop. Reported as
/// `RunStatus::CallingModel`.
ProcessingResponse,
/// Success advances to `Persisting`; `ToolFailure` fails with
/// "tool execution failed". Reported as `RunStatus::WaitingForTools`.
ExecutingTools,
/// Success loops back to `CheckingPolicy`; `PipelineError` fails with
/// "persistence failed".
Persisting,
}
/// What the step for the current [`TurnState`] reported when it finished.
///
/// Only certain `(state, outcome)` pairs have a dedicated rule in
/// `TurnTransition::resolve`; every other pairing resolves to a failure whose
/// reason names both the outcome and the state.
#[derive(Debug, Clone)]
pub enum TurnOutcome {
    /// The step completed; accepted by every state.
    Success,
    /// Accepted from `CheckingPolicy`; `reason` becomes the failure reason as-is.
    PolicyExceeded { reason: String },
    /// Accepted from `CallingModel`; resolves to `TurnAction::CompactAndRetry`.
    ContextOverflow,
    /// Accepted from `CallingModel`; resolves to a fixed "provider error"
    /// failure (the `message` is not forwarded).
    ProviderError { message: String },
    /// Accepted from `BuildingContext`, `Compacting` and `Persisting`; the
    /// `message` is not forwarded into the failure reason.
    PipelineError { message: String },
    /// Accepted from `ProcessingResponse`; breaks the loop.
    NoToolCalls,
    /// Accepted from `ProcessingResponse`; breaks the loop. Carries the
    /// provider's finish reason. NOTE(review): how this differs from
    /// `NoToolCalls` is not visible here — confirm with the producer.
    NotToolCalls { finish_reason: FinishReason },
    /// Accepted from `ExecutingTools`; resolves to a fixed
    /// "tool execution failed" failure (the `message` is not forwarded).
    ToolFailure { message: String },
    /// Accepted from `ProcessingResponse`; sends the turn back to `CallingModel`.
    OutputTruncated,
}
/// The decision `TurnTransition::resolve` makes after a step finishes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TurnAction {
    /// Proceed to the `next` phase of the turn.
    Continue { next: TurnState },
    /// Leave the turn loop; produced when a response carries no tool calls.
    BreakLoop,
    /// Produced when `CallingModel` reports `ContextOverflow`.
    CompactAndRetry,
    /// Abort with a human-readable `reason`.
    Fail { reason: String },
}
/// Stateless namespace for the turn state machine's transition rules.
///
/// `TurnTransition::resolve` picks the next `TurnAction` for a finished step,
/// and `TurnTransition::status_for` maps a `TurnState` to a `RunStatus`.
///
/// The common traits are derived so the type can be logged, copied and
/// default-constructed like any other public value type.
#[derive(Debug, Clone, Copy, Default)]
pub struct TurnTransition;
impl TurnTransition {
    /// Decides what the turn loop does after the step for `state` reported
    /// `outcome`.
    ///
    /// Each state accepts `Success` plus a few state-specific outcomes:
    /// - a success advances to the next phase;
    /// - a `ContextOverflow` while calling the model asks for compaction;
    /// - a response without tool calls breaks the loop;
    /// - the listed errors fail with a fixed reason.
    ///
    /// Only `PolicyExceeded` forwards its own text. The `message` fields of
    /// `ProviderError`, `PipelineError` and `ToolFailure` are not included in
    /// the reason.
    ///
    /// Any pair without a dedicated rule resolves to [`TurnAction::Fail`] with
    /// the reason `unexpected outcome {outcome:?} in state {state:?}`.
    #[must_use]
    pub fn resolve(state: TurnState, outcome: &TurnOutcome) -> TurnAction {
        // Move on to the `next` phase.
        fn advance(next: TurnState) -> TurnAction {
            TurnAction::Continue { next }
        }
        // Fail with a fixed reason string.
        fn fail(reason: &str) -> TurnAction {
            TurnAction::Fail {
                reason: reason.to_string(),
            }
        }
        // Fallback for every (state, outcome) pair without a dedicated rule.
        fn unexpected(state: TurnState, outcome: &TurnOutcome) -> TurnAction {
            TurnAction::Fail {
                reason: format!("unexpected outcome {outcome:?} in state {state:?}"),
            }
        }

        match state {
            TurnState::CheckingPolicy => match outcome {
                TurnOutcome::Success => advance(TurnState::BuildingContext),
                TurnOutcome::PolicyExceeded { reason } => TurnAction::Fail {
                    reason: reason.clone(),
                },
                other => unexpected(state, other),
            },
            // Compacting shares the context-building rules: success goes to the
            // model, a pipeline error fails with the same reason.
            TurnState::BuildingContext | TurnState::Compacting => match outcome {
                TurnOutcome::Success => advance(TurnState::CallingModel),
                TurnOutcome::PipelineError { .. } => fail("context build failed"),
                other => unexpected(state, other),
            },
            TurnState::CallingModel => match outcome {
                TurnOutcome::Success => advance(TurnState::ProcessingResponse),
                TurnOutcome::ContextOverflow => TurnAction::CompactAndRetry,
                TurnOutcome::ProviderError { .. } => fail("provider error"),
                other => unexpected(state, other),
            },
            TurnState::ProcessingResponse => match outcome {
                TurnOutcome::Success => advance(TurnState::ExecutingTools),
                // A truncated response re-enters CallingModel instead of failing.
                TurnOutcome::OutputTruncated => advance(TurnState::CallingModel),
                TurnOutcome::NoToolCalls | TurnOutcome::NotToolCalls { .. } => {
                    TurnAction::BreakLoop
                }
                other => unexpected(state, other),
            },
            TurnState::ExecutingTools => match outcome {
                TurnOutcome::Success => advance(TurnState::Persisting),
                TurnOutcome::ToolFailure { .. } => fail("tool execution failed"),
                other => unexpected(state, other),
            },
            TurnState::Persisting => match outcome {
                // A persisted turn starts the next cycle at the policy check.
                TurnOutcome::Success => advance(TurnState::CheckingPolicy),
                TurnOutcome::PipelineError { .. } => fail("persistence failed"),
                other => unexpected(state, other),
            },
        }
    }

    /// Maps a fine-grained [`TurnState`] to the coarser [`RunStatus`].
    ///
    /// Several turn states share one status:
    /// - `CheckingPolicy`, `BuildingContext` and `Compacting` all report
    ///   `BuildingContext`;
    /// - `CallingModel` and `ProcessingResponse` both report `CallingModel`.
    #[must_use]
    pub fn status_for(state: TurnState) -> RunStatus {
        match state {
            TurnState::Persisting => RunStatus::Persisting,
            TurnState::ExecutingTools => RunStatus::WaitingForTools,
            TurnState::CallingModel | TurnState::ProcessingResponse => RunStatus::CallingModel,
            TurnState::CheckingPolicy | TurnState::BuildingContext | TurnState::Compacting => {
                RunStatus::BuildingContext
            }
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the turn transition table and the status mapping.
    //! Every dedicated `(state, outcome)` rule in `resolve` is covered, plus the
    //! catch-all fallback and every `status_for` arm.
    use super::*;

    #[test]
    fn policy_exceeded_fails() {
        let action = TurnTransition::resolve(
            TurnState::CheckingPolicy,
            &TurnOutcome::PolicyExceeded {
                reason: "iteration".into(),
            },
        );
        assert_eq!(
            action,
            TurnAction::Fail {
                reason: "iteration".into()
            }
        );
    }

    #[test]
    fn policy_ok_proceeds_to_context() {
        let action = TurnTransition::resolve(TurnState::CheckingPolicy, &TurnOutcome::Success);
        assert_eq!(
            action,
            TurnAction::Continue {
                next: TurnState::BuildingContext
            }
        );
    }

    #[test]
    fn context_built_proceeds_to_model() {
        let action = TurnTransition::resolve(TurnState::BuildingContext, &TurnOutcome::Success);
        assert_eq!(
            action,
            TurnAction::Continue {
                next: TurnState::CallingModel
            }
        );
    }

    #[test]
    fn pipeline_error_while_building_or_compacting_fails() {
        for state in [TurnState::BuildingContext, TurnState::Compacting] {
            let action = TurnTransition::resolve(
                state,
                &TurnOutcome::PipelineError {
                    message: "boom".into(),
                },
            );
            assert_eq!(
                action,
                TurnAction::Fail {
                    reason: "context build failed".into()
                },
                "state {state:?}"
            );
        }
    }

    #[test]
    fn context_overflow_compacts_and_retries() {
        let action =
            TurnTransition::resolve(TurnState::CallingModel, &TurnOutcome::ContextOverflow);
        assert_eq!(action, TurnAction::CompactAndRetry);
    }

    #[test]
    fn model_success_moves_to_processing() {
        let action = TurnTransition::resolve(TurnState::CallingModel, &TurnOutcome::Success);
        assert_eq!(
            action,
            TurnAction::Continue {
                next: TurnState::ProcessingResponse
            }
        );
    }

    #[test]
    fn provider_error_fails() {
        let action = TurnTransition::resolve(
            TurnState::CallingModel,
            &TurnOutcome::ProviderError {
                message: "boom".into(),
            },
        );
        assert_eq!(
            action,
            TurnAction::Fail {
                reason: "provider error".into()
            }
        );
    }

    #[test]
    fn compaction_success_returns_to_model() {
        let action = TurnTransition::resolve(TurnState::Compacting, &TurnOutcome::Success);
        assert_eq!(
            action,
            TurnAction::Continue {
                next: TurnState::CallingModel
            }
        );
    }

    #[test]
    fn no_tool_calls_breaks_loop() {
        let action =
            TurnTransition::resolve(TurnState::ProcessingResponse, &TurnOutcome::NoToolCalls);
        assert_eq!(action, TurnAction::BreakLoop);
    }

    #[test]
    fn not_tool_calls_finish_breaks() {
        let action = TurnTransition::resolve(
            TurnState::ProcessingResponse,
            &TurnOutcome::NotToolCalls {
                finish_reason: FinishReason::Stop,
            },
        );
        assert_eq!(action, TurnAction::BreakLoop);
    }

    #[test]
    fn tool_calls_moves_to_execution() {
        let action = TurnTransition::resolve(TurnState::ProcessingResponse, &TurnOutcome::Success);
        assert_eq!(
            action,
            TurnAction::Continue {
                next: TurnState::ExecutingTools
            }
        );
    }

    #[test]
    fn output_truncated_continues_to_model() {
        let action =
            TurnTransition::resolve(TurnState::ProcessingResponse, &TurnOutcome::OutputTruncated);
        assert_eq!(
            action,
            TurnAction::Continue {
                next: TurnState::CallingModel
            }
        );
    }

    #[test]
    fn tools_success_moves_to_persisting() {
        let action = TurnTransition::resolve(TurnState::ExecutingTools, &TurnOutcome::Success);
        assert_eq!(
            action,
            TurnAction::Continue {
                next: TurnState::Persisting
            }
        );
    }

    #[test]
    fn tool_failure_fails() {
        let action = TurnTransition::resolve(
            TurnState::ExecutingTools,
            &TurnOutcome::ToolFailure {
                message: "boom".into(),
            },
        );
        assert_eq!(
            action,
            TurnAction::Fail {
                reason: "tool execution failed".into()
            }
        );
    }

    #[test]
    fn persisting_success_returns_to_check() {
        let action = TurnTransition::resolve(TurnState::Persisting, &TurnOutcome::Success);
        assert_eq!(
            action,
            TurnAction::Continue {
                next: TurnState::CheckingPolicy
            }
        );
    }

    #[test]
    fn persistence_error_fails() {
        let action = TurnTransition::resolve(
            TurnState::Persisting,
            &TurnOutcome::PipelineError {
                message: "disk full".into(),
            },
        );
        assert_eq!(
            action,
            TurnAction::Fail {
                reason: "persistence failed".into()
            }
        );
    }

    #[test]
    fn unexpected_pair_fails_with_diagnostic() {
        // A pair with no dedicated rule must fall through to the catch-all,
        // which names both the outcome and the state.
        let action =
            TurnTransition::resolve(TurnState::CheckingPolicy, &TurnOutcome::NoToolCalls);
        assert_eq!(
            action,
            TurnAction::Fail {
                reason: "unexpected outcome NoToolCalls in state CheckingPolicy".into()
            }
        );
    }

    #[test]
    fn status_maps_correctly() {
        assert_eq!(
            TurnTransition::status_for(TurnState::BuildingContext),
            RunStatus::BuildingContext
        );
        assert_eq!(
            TurnTransition::status_for(TurnState::CallingModel),
            RunStatus::CallingModel
        );
        assert_eq!(
            TurnTransition::status_for(TurnState::ExecutingTools),
            RunStatus::WaitingForTools
        );
        assert_eq!(
            TurnTransition::status_for(TurnState::Persisting),
            RunStatus::Persisting
        );
    }

    #[test]
    fn status_maps_grouped_states() {
        // States that share a coarse status with a sibling state.
        assert_eq!(
            TurnTransition::status_for(TurnState::CheckingPolicy),
            RunStatus::BuildingContext
        );
        assert_eq!(
            TurnTransition::status_for(TurnState::Compacting),
            RunStatus::BuildingContext
        );
        assert_eq!(
            TurnTransition::status_for(TurnState::ProcessingResponse),
            RunStatus::CallingModel
        );
    }
}