use crate::typed_id::{AgentId, MessageId, SessionId, TurnId};
use serde::{Deserialize, Serialize};
/// Identifiers that tie one agent turn to its session, input message, agent and
/// organization. Built with [`TurnContext::new`] (fresh turn id) or
/// [`TurnContext::with_turn_id`] (caller-supplied turn id).
#[derive(Debug, Clone)]
pub struct TurnContext {
/// Session this turn belongs to.
pub session_id: SessionId,
/// Identifier of this turn.
pub turn_id: TurnId,
/// The input message this turn processes (presumably the message that triggered it — confirm with callers).
pub input_message_id: MessageId,
/// Agent executing the turn.
pub agent_id: AgentId,
/// Numeric id of the owning organization.
pub org_id: i64,
}
impl TurnContext {
pub fn new(
session_id: SessionId,
input_message_id: MessageId,
agent_id: AgentId,
org_id: i64,
) -> Self {
Self {
session_id,
turn_id: TurnId::new(),
input_message_id,
agent_id,
org_id,
}
}
pub fn with_turn_id(
session_id: SessionId,
turn_id: TurnId,
input_message_id: MessageId,
agent_id: AgentId,
org_id: i64,
) -> Self {
Self {
session_id,
turn_id,
input_message_id,
agent_id,
org_id,
}
}
}
/// Phase of a turn; [`TurnStateMachine::next_action`] maps each phase to a [`TurnAction`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TurnPhase {
/// Initial phase: the input step has not run yet.
PendingInput,
/// A reason step must run (after input, after tools ran, or to handle pending user messages).
PendingReason,
/// Tool calls requested by the last reason step must be executed.
PendingAct,
/// Terminal phase: the turn has an outcome.
Completed,
}
/// Instruction returned by [`TurnStateMachine::next_action`] telling the caller what to do next.
#[derive(Debug, Clone)]
pub enum TurnAction {
/// Run the input step, then report it with `on_input_completed`.
ExecuteInput,
/// Run a reason step, then report it with `on_reason_completed`.
ExecuteReason,
/// Execute the pending tool calls, then report with `on_act_completed`.
ExecuteAct,
/// The turn is finished with this outcome.
Complete(TurnOutcome),
}
/// Why a turn was sealed, i.e. forced to complete via [`TurnStateMachine::seal`].
///
/// Serialized in snake_case (`"no_progress"`, `"budget"`), the same strings
/// produced by [`SealReason::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SealReason {
/// Repeated observations showed no progress (compare [`ProgressToken::advanced_from`]).
NoProgress,
/// A budget limit was hit (presumably enforced by the caller; not visible in this file).
Budget,
}
impl SealReason {
pub fn as_str(&self) -> &'static str {
match self {
SealReason::NoProgress => "no_progress",
SealReason::Budget => "budget",
}
}
pub fn from_str_lossy(s: &str) -> Self {
match s {
"budget" => SealReason::Budget,
_ => SealReason::NoProgress,
}
}
}
impl std::fmt::Display for SealReason {
    /// Writes the wire name from `as_str` verbatim (width/fill flags are not applied).
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.as_str();
        formatter.write_str(name)
    }
}
/// Progress marker derived from the highest observed event sequence number.
/// A token counts as progress over another only when its inner value is strictly greater.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ProgressToken(pub u64);
impl ProgressToken {
    /// Token meaning "no events observed yet"; every negative sequence maps here.
    pub const ZERO: ProgressToken = ProgressToken(0);

    /// Converts the highest observed event sequence number into a token.
    ///
    /// The token value is `highest_sequence + 1`, so sequence `0` yields
    /// `ProgressToken(1)` and any negative input (e.g. `-1`) yields
    /// [`ProgressToken::ZERO`]. The addition is done in `u64`, so `i64::MAX` maps to
    /// `2^63` instead of overflowing `i64` (which would panic in debug builds).
    pub fn from_event_sequence(highest_sequence: i64) -> Self {
        if highest_sequence < 0 {
            Self::ZERO
        } else {
            // Lossless: the value is non-negative, and i64::MAX + 1 == 2^63 fits in u64.
            ProgressToken(highest_sequence as u64 + 1)
        }
    }

    /// Returns `true` when `self` is strictly newer than `prev`; an equal token is not progress.
    pub fn advanced_from(&self, prev: ProgressToken) -> bool {
        self.0 > prev.0
    }
}
/// Default no-progress seal threshold, used when `DURABLE_NO_PROGRESS_SEAL_THRESHOLD`
/// is unset or invalid.
pub const DEFAULT_NO_PROGRESS_SEAL_THRESHOLD: u32 = 3;

/// Reads the no-progress seal threshold from `DURABLE_NO_PROGRESS_SEAL_THRESHOLD`.
///
/// Leading/trailing whitespace in the value is ignored (so `"5\n"` parses as `5`).
/// Falls back to [`DEFAULT_NO_PROGRESS_SEAL_THRESHOLD`] when the variable is unset,
/// not valid Unicode, or not a `u32`. The result is clamped to at least `1`, so the
/// returned threshold is never zero.
pub fn no_progress_seal_threshold_from_env() -> u32 {
    std::env::var("DURABLE_NO_PROGRESS_SEAL_THRESHOLD")
        .ok()
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .unwrap_or(DEFAULT_NO_PROGRESS_SEAL_THRESHOLD)
        .max(1)
}
/// Final result of a turn, carried by [`TurnAction::Complete`].
#[derive(Debug, Clone)]
pub enum TurnOutcome {
/// Completed without seal or error while below the iteration limit.
Success {
/// Last non-empty response produced by a reason step.
response: String,
/// Number of reason steps run.
iterations: usize,
/// Total tool calls scheduled for execution during the turn.
tool_calls_count: usize,
},
/// A reason step reported failure.
Failed {
/// Error message reported for the failing step.
error: String,
/// Number of reason steps run, including the failing one.
iterations: usize,
},
/// Completed with the iteration count at or above the configured maximum.
/// NOTE(review): `next_action` reports this whenever the limit is reached, even if the
/// final step requested no further work — confirm that is intended.
MaxIterationsReached {
/// Last non-empty response produced by a reason step.
response: String,
/// Number of reason steps run.
iterations: usize,
/// Total tool calls scheduled for execution (calls requested at the limit are not counted).
tool_calls_count: usize,
},
/// Forced completion via `TurnStateMachine::seal`; takes precedence over all other outcomes.
Sealed {
/// Why the turn was sealed.
reason: SealReason,
/// Last non-empty response produced before sealing.
response: String,
/// Number of reason steps run before sealing.
iterations: usize,
/// Total tool calls scheduled before sealing.
tool_calls_count: usize,
},
}
impl TurnOutcome {
pub fn is_success(&self) -> bool {
matches!(self, TurnOutcome::Success { .. })
}
pub fn is_sealed(&self) -> bool {
matches!(self, TurnOutcome::Sealed { .. })
}
pub fn seal_reason(&self) -> Option<SealReason> {
match self {
TurnOutcome::Sealed { reason, .. } => Some(*reason),
_ => None,
}
}
pub fn response(&self) -> Option<&str> {
match self {
TurnOutcome::Success { response, .. } => Some(response),
TurnOutcome::MaxIterationsReached { response, .. } => Some(response),
TurnOutcome::Sealed { response, .. } => Some(response),
TurnOutcome::Failed { .. } => None,
}
}
pub fn error(&self) -> Option<&str> {
match self {
TurnOutcome::Failed { error, .. } => Some(error),
_ => None,
}
}
pub fn iterations(&self) -> usize {
match self {
TurnOutcome::Success { iterations, .. } => *iterations,
TurnOutcome::Failed { iterations, .. } => *iterations,
TurnOutcome::MaxIterationsReached { iterations, .. } => *iterations,
TurnOutcome::Sealed { iterations, .. } => *iterations,
}
}
}
/// State machine that drives one turn through input -> reason <-> act -> completed.
///
/// It performs no I/O itself: the caller executes each [`TurnAction`] returned by
/// `next_action` and reports the result through the `on_*` methods.
#[derive(Debug)]
pub struct TurnStateMachine {
// Identifiers of the turn being driven.
context: TurnContext,
// Current phase; `Completed` is terminal.
phase: TurnPhase,
// Maximum number of reason steps before the turn is forced to complete.
max_iterations: usize,
// Number of reason steps reported so far (including a failing one).
current_iteration: usize,
// Sum of tool calls scheduled for execution.
total_tool_calls: usize,
// Most recent non-empty response from a reason step.
last_response: String,
// Error recorded when a reason step fails.
pending_error: Option<String>,
// Set when tool calls were scheduled; cleared by `on_act_completed`.
has_pending_tool_calls: bool,
// First seal reason recorded; overrides every other outcome.
pending_seal: Option<SealReason>,
}
impl TurnStateMachine {
pub fn new(context: TurnContext, max_iterations: usize) -> Self {
Self {
context,
phase: TurnPhase::PendingInput,
max_iterations,
current_iteration: 0,
total_tool_calls: 0,
last_response: String::new(),
pending_error: None,
has_pending_tool_calls: false,
pending_seal: None,
}
}
pub fn context(&self) -> &TurnContext {
&self.context
}
pub fn phase(&self) -> TurnPhase {
self.phase
}
pub fn current_iteration(&self) -> usize {
self.current_iteration
}
pub fn total_tool_calls(&self) -> usize {
self.total_tool_calls
}
pub fn next_action(&self) -> TurnAction {
match self.phase {
TurnPhase::PendingInput => TurnAction::ExecuteInput,
TurnPhase::PendingReason => TurnAction::ExecuteReason,
TurnPhase::PendingAct => TurnAction::ExecuteAct,
TurnPhase::Completed => {
if let Some(reason) = self.pending_seal {
return TurnAction::Complete(TurnOutcome::Sealed {
reason,
response: self.last_response.clone(),
iterations: self.current_iteration,
tool_calls_count: self.total_tool_calls,
});
}
if let Some(error) = &self.pending_error {
TurnAction::Complete(TurnOutcome::Failed {
error: error.clone(),
iterations: self.current_iteration,
})
} else if self.current_iteration >= self.max_iterations {
TurnAction::Complete(TurnOutcome::MaxIterationsReached {
response: self.last_response.clone(),
iterations: self.current_iteration,
tool_calls_count: self.total_tool_calls,
})
} else {
TurnAction::Complete(TurnOutcome::Success {
response: self.last_response.clone(),
iterations: self.current_iteration,
tool_calls_count: self.total_tool_calls,
})
}
}
}
}
pub fn on_input_completed(&mut self) {
debug_assert_eq!(self.phase, TurnPhase::PendingInput);
self.phase = TurnPhase::PendingReason;
}
pub fn on_reason_completed(
&mut self,
response: String,
has_tool_calls: bool,
tool_call_count: usize,
success: bool,
error: Option<String>,
has_pending_user_messages: bool,
) {
debug_assert_eq!(self.phase, TurnPhase::PendingReason);
self.current_iteration += 1;
if !response.is_empty() {
self.last_response = response;
}
if !success {
self.pending_error = error;
self.phase = TurnPhase::Completed;
return;
}
if has_tool_calls && tool_call_count > 0 {
if self.current_iteration >= self.max_iterations {
self.phase = TurnPhase::Completed;
return;
}
self.has_pending_tool_calls = true;
self.total_tool_calls += tool_call_count;
self.phase = TurnPhase::PendingAct;
} else if has_pending_user_messages {
if self.current_iteration >= self.max_iterations {
self.phase = TurnPhase::Completed;
} else {
self.phase = TurnPhase::PendingReason;
}
} else {
self.phase = TurnPhase::Completed;
}
}
pub fn on_act_completed(&mut self) {
debug_assert_eq!(self.phase, TurnPhase::PendingAct);
self.has_pending_tool_calls = false;
self.phase = TurnPhase::PendingReason;
}
pub fn seal(&mut self, reason: SealReason) {
if self.pending_seal.is_none() {
self.pending_seal = Some(reason);
}
self.phase = TurnPhase::Completed;
}
pub fn is_completed(&self) -> bool {
self.phase == TurnPhase::Completed
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Context with fresh ids and org 0; these tests only exercise the state machine.
    fn test_context() -> TurnContext {
        TurnContext::new(SessionId::new(), MessageId::new(), AgentId::new(), 0)
    }

    #[test]
    fn test_simple_turn_no_tools() {
        let mut sm = TurnStateMachine::new(test_context(), 10);
        assert!(matches!(sm.next_action(), TurnAction::ExecuteInput));
        sm.on_input_completed();
        assert!(matches!(sm.next_action(), TurnAction::ExecuteReason));
        sm.on_reason_completed("Hello!".to_string(), false, 0, true, None, false);
        match sm.next_action() {
            TurnAction::Complete(TurnOutcome::Success {
                response,
                iterations,
                tool_calls_count,
            }) => {
                assert_eq!(response, "Hello!");
                assert_eq!(iterations, 1);
                assert_eq!(tool_calls_count, 0);
            }
            other => panic!("Expected Success, got {:?}", other),
        }
    }

    #[test]
    fn test_turn_with_one_tool_call() {
        let mut sm = TurnStateMachine::new(test_context(), 10);
        assert!(matches!(sm.next_action(), TurnAction::ExecuteInput));
        sm.on_input_completed();
        assert!(matches!(sm.next_action(), TurnAction::ExecuteReason));
        sm.on_reason_completed("Let me check...".to_string(), true, 1, true, None, false);
        assert!(matches!(sm.next_action(), TurnAction::ExecuteAct));
        sm.on_act_completed();
        assert!(matches!(sm.next_action(), TurnAction::ExecuteReason));
        sm.on_reason_completed(
            "Here's the result.".to_string(),
            false,
            0,
            true,
            None,
            false,
        );
        match sm.next_action() {
            TurnAction::Complete(TurnOutcome::Success {
                response,
                iterations,
                tool_calls_count,
            }) => {
                assert_eq!(response, "Here's the result.");
                assert_eq!(iterations, 2);
                assert_eq!(tool_calls_count, 1);
            }
            other => panic!("Expected Success, got {:?}", other),
        }
    }

    #[test]
    fn test_max_iterations() {
        let mut sm = TurnStateMachine::new(test_context(), 2);
        sm.on_input_completed();
        sm.on_reason_completed("Trying...".to_string(), true, 1, true, None, false);
        sm.on_act_completed();
        sm.on_reason_completed("Still trying...".to_string(), true, 1, true, None, false);
        match sm.next_action() {
            TurnAction::Complete(TurnOutcome::MaxIterationsReached { iterations, .. }) => {
                assert_eq!(iterations, 2);
            }
            other => panic!("Expected MaxIterationsReached, got {:?}", other),
        }
    }

    #[test]
    fn test_reason_failure() {
        let mut sm = TurnStateMachine::new(test_context(), 10);
        sm.on_input_completed();
        sm.on_reason_completed(
            String::new(),
            false,
            0,
            false,
            Some("LLM error".to_string()),
            false,
        );
        match sm.next_action() {
            TurnAction::Complete(TurnOutcome::Failed { error, .. }) => {
                assert_eq!(error, "LLM error");
            }
            other => panic!("Expected Failed, got {:?}", other),
        }
    }

    #[test]
    fn test_context_preserved() {
        let context = TurnContext::new(SessionId::new(), MessageId::new(), AgentId::new(), 42);
        let turn_id = context.turn_id;
        let sm = TurnStateMachine::new(context, 10);
        assert_eq!(sm.context().turn_id, turn_id);
        assert_eq!(sm.context().org_id, 42);
    }

    #[test]
    fn test_outcome_helpers() {
        let success = TurnOutcome::Success {
            response: "test".to_string(),
            iterations: 1,
            tool_calls_count: 0,
        };
        assert!(success.is_success());
        assert_eq!(success.response(), Some("test"));
        assert!(success.error().is_none());
        let failed = TurnOutcome::Failed {
            error: "oops".to_string(),
            iterations: 0,
        };
        assert!(!failed.is_success());
        assert!(failed.response().is_none());
        assert_eq!(failed.error(), Some("oops"));
    }

    #[test]
    fn test_pending_user_message_continues_turn() {
        let mut sm = TurnStateMachine::new(test_context(), 10);
        sm.on_input_completed();
        sm.on_reason_completed("Hello!".to_string(), false, 0, true, None, true);
        assert!(!sm.is_completed());
        assert_eq!(sm.phase(), TurnPhase::PendingReason);
        assert!(matches!(sm.next_action(), TurnAction::ExecuteReason));
        sm.on_reason_completed("Got your message!".to_string(), false, 0, true, None, false);
        match sm.next_action() {
            TurnAction::Complete(TurnOutcome::Success {
                response,
                iterations,
                ..
            }) => {
                assert_eq!(response, "Got your message!");
                assert_eq!(iterations, 2);
            }
            other => panic!("Expected Success, got {:?}", other),
        }
    }

    #[test]
    fn test_pending_messages_ignored_on_failure() {
        let mut sm = TurnStateMachine::new(test_context(), 10);
        sm.on_input_completed();
        sm.on_reason_completed(
            String::new(),
            false,
            0,
            false,
            Some("LLM error".to_string()),
            true,
        );
        assert!(sm.is_completed());
        assert!(matches!(
            sm.next_action(),
            TurnAction::Complete(TurnOutcome::Failed { .. })
        ));
    }

    #[test]
    fn test_progress_token_monotonicity() {
        let t0 = ProgressToken::from_event_sequence(-1);
        let t1 = ProgressToken::from_event_sequence(0);
        let t2 = ProgressToken::from_event_sequence(5);
        assert_eq!(t0, ProgressToken::ZERO);
        assert!(t1.advanced_from(t0));
        assert!(t2.advanced_from(t1));
        let t2_again = ProgressToken::from_event_sequence(5);
        assert!(!t2_again.advanced_from(t2));
        assert!(!t2.advanced_from(t2_again));
        assert!(t2 > t1 && t1 > t0);
    }

    /// Simulates the caller-side no-progress counter built on `advanced_from`.
    #[test]
    fn test_no_progress_counter_logic() {
        let threshold = 3u32;
        let mut recorded = ProgressToken::ZERO;
        let mut no_progress = 0u32;
        let step =
            |recorded: &mut ProgressToken, no_progress: &mut u32, observed: ProgressToken| {
                if observed.advanced_from(*recorded) {
                    *recorded = observed;
                    *no_progress = 0;
                } else {
                    *no_progress += 1;
                }
                *no_progress
            };
        assert_eq!(
            step(&mut recorded, &mut no_progress, ProgressToken::ZERO),
            1
        );
        assert_eq!(
            step(&mut recorded, &mut no_progress, ProgressToken::ZERO),
            2
        );
        assert_eq!(
            step(
                &mut recorded,
                &mut no_progress,
                ProgressToken::from_event_sequence(2)
            ),
            0
        );
        let stuck = ProgressToken::from_event_sequence(2);
        assert_eq!(step(&mut recorded, &mut no_progress, stuck), 1);
        assert_eq!(step(&mut recorded, &mut no_progress, stuck), 2);
        let count = step(&mut recorded, &mut no_progress, stuck);
        assert_eq!(count, 3);
        assert!(count >= threshold, "should seal once threshold reached");
    }

    #[test]
    fn test_seal_threshold_env_never_zero() {
        assert_eq!(DEFAULT_NO_PROGRESS_SEAL_THRESHOLD, 3);
        // Whatever the process environment holds, the effective threshold is clamped to >= 1.
        assert!(no_progress_seal_threshold_from_env() >= 1);
    }

    #[test]
    fn test_budget_seal_outcome() {
        let mut sm = TurnStateMachine::new(test_context(), 10);
        sm.on_input_completed();
        sm.on_reason_completed("Working...".to_string(), true, 1, true, None, false);
        sm.on_act_completed();
        sm.seal(SealReason::Budget);
        assert!(sm.is_completed());
        match sm.next_action() {
            TurnAction::Complete(TurnOutcome::Sealed {
                reason, iterations, ..
            }) => {
                assert_eq!(reason, SealReason::Budget);
                assert_eq!(iterations, 1);
            }
            other => panic!("Expected Sealed, got {:?}", other),
        }
    }

    #[test]
    fn test_seal_takes_precedence_over_error() {
        let mut sm = TurnStateMachine::new(test_context(), 10);
        sm.on_input_completed();
        sm.on_reason_completed(
            String::new(),
            false,
            0,
            false,
            Some("LLM error".to_string()),
            false,
        );
        sm.seal(SealReason::NoProgress);
        match sm.next_action() {
            TurnAction::Complete(TurnOutcome::Sealed { reason, .. }) => {
                assert_eq!(reason, SealReason::NoProgress);
            }
            other => panic!("Expected Sealed, got {:?}", other),
        }
    }

    #[test]
    fn test_seal_reason_wire_roundtrip() {
        assert_eq!(SealReason::NoProgress.as_str(), "no_progress");
        assert_eq!(SealReason::Budget.as_str(), "budget");
        assert_eq!(SealReason::from_str_lossy("budget"), SealReason::Budget);
        assert_eq!(
            SealReason::from_str_lossy("no_progress"),
            SealReason::NoProgress
        );
        assert_eq!(
            SealReason::from_str_lossy("future_reason"),
            SealReason::NoProgress
        );
    }

    #[test]
    fn test_outcome_sealed_helpers() {
        let sealed = TurnOutcome::Sealed {
            reason: SealReason::Budget,
            response: "partial".to_string(),
            iterations: 2,
            tool_calls_count: 1,
        };
        assert!(sealed.is_sealed());
        assert!(!sealed.is_success());
        assert_eq!(sealed.seal_reason(), Some(SealReason::Budget));
        assert_eq!(sealed.response(), Some("partial"));
        assert!(sealed.error().is_none());
        assert_eq!(sealed.iterations(), 2);
    }

    #[test]
    fn test_pending_messages_ignored_when_tool_calls() {
        let mut sm = TurnStateMachine::new(test_context(), 10);
        sm.on_input_completed();
        sm.on_reason_completed("Working...".to_string(), true, 2, true, None, true);
        assert_eq!(sm.phase(), TurnPhase::PendingAct);
    }
}