use crate::typed_id::{AgentId, MessageId, SessionId, TurnId};
/// Identifiers that travel together with a single turn.
///
/// Plain data holder: all fields are public and the type carries no behaviour beyond its
/// constructors.
#[derive(Debug, Clone)]
pub struct TurnContext {
/// Identifier of the session (typed id).
pub session_id: SessionId,
/// Identifier of this turn (typed id).
pub turn_id: TurnId,
/// Identifier of the input message associated with the turn (typed id).
pub input_message_id: MessageId,
/// Identifier of the agent (typed id).
pub agent_id: AgentId,
/// Organization identifier, kept as a raw `i64` rather than a typed id.
pub org_id: i64,
}
impl TurnContext {
pub fn new(
session_id: SessionId,
input_message_id: MessageId,
agent_id: AgentId,
org_id: i64,
) -> Self {
Self {
session_id,
turn_id: TurnId::new(),
input_message_id,
agent_id,
org_id,
}
}
pub fn with_turn_id(
session_id: SessionId,
turn_id: TurnId,
input_message_id: MessageId,
agent_id: AgentId,
org_id: i64,
) -> Self {
Self {
session_id,
turn_id,
input_message_id,
agent_id,
org_id,
}
}
}
/// Phase a `TurnStateMachine` is currently in.
///
/// Each non-terminal phase names the step the machine is waiting on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TurnPhase {
/// Waiting for the input step; this is the phase a new state machine starts in.
PendingInput,
/// Waiting for a reason step.
PendingReason,
/// Waiting for an act step.
PendingAct,
/// Terminal phase: the turn is over and an outcome is available.
Completed,
}
/// Next step reported by `TurnStateMachine::next_action`.
///
/// The first three variants correspond one-to-one to the pending phases of [`TurnPhase`];
/// the last one carries the final result of the turn.
#[derive(Debug, Clone)]
pub enum TurnAction {
/// Run the input step (reported while the phase is `PendingInput`).
ExecuteInput,
/// Run a reason step (reported while the phase is `PendingReason`).
ExecuteReason,
/// Run an act step (reported while the phase is `PendingAct`).
ExecuteAct,
/// The turn is finished; the payload describes how it ended.
Complete(TurnOutcome),
}
/// Final result of a turn.
#[derive(Debug, Clone)]
pub enum TurnOutcome {
/// The turn finished normally.
Success {
/// Response text produced by the turn.
response: String,
/// Iteration count reported for the turn.
iterations: usize,
/// Tool-call count reported for the turn.
tool_calls_count: usize,
},
/// The turn ended with an error; no response text is carried.
Failed {
/// Description of what went wrong.
error: String,
/// Iteration count reported for the turn.
iterations: usize,
},
/// The turn ended because of the iteration limit; a response is still carried.
MaxIterationsReached {
/// Response text available when the turn ended.
response: String,
/// Iteration count reported for the turn.
iterations: usize,
/// Tool-call count reported for the turn.
tool_calls_count: usize,
},
}
impl TurnOutcome {
pub fn is_success(&self) -> bool {
matches!(self, TurnOutcome::Success { .. })
}
pub fn response(&self) -> Option<&str> {
match self {
TurnOutcome::Success { response, .. } => Some(response),
TurnOutcome::MaxIterationsReached { response, .. } => Some(response),
TurnOutcome::Failed { .. } => None,
}
}
pub fn error(&self) -> Option<&str> {
match self {
TurnOutcome::Failed { error, .. } => Some(error),
_ => None,
}
}
pub fn iterations(&self) -> usize {
match self {
TurnOutcome::Success { iterations, .. } => *iterations,
TurnOutcome::Failed { iterations, .. } => *iterations,
TurnOutcome::MaxIterationsReached { iterations, .. } => *iterations,
}
}
}
#[derive(Debug)]
pub struct TurnStateMachine {
context: TurnContext,
phase: TurnPhase,
max_iterations: usize,
current_iteration: usize,
total_tool_calls: usize,
last_response: String,
pending_error: Option<String>,
has_pending_tool_calls: bool,
}
impl TurnStateMachine {
pub fn new(context: TurnContext, max_iterations: usize) -> Self {
Self {
context,
phase: TurnPhase::PendingInput,
max_iterations,
current_iteration: 0,
total_tool_calls: 0,
last_response: String::new(),
pending_error: None,
has_pending_tool_calls: false,
}
}
pub fn context(&self) -> &TurnContext {
&self.context
}
pub fn phase(&self) -> TurnPhase {
self.phase
}
pub fn current_iteration(&self) -> usize {
self.current_iteration
}
pub fn total_tool_calls(&self) -> usize {
self.total_tool_calls
}
pub fn next_action(&self) -> TurnAction {
match self.phase {
TurnPhase::PendingInput => TurnAction::ExecuteInput,
TurnPhase::PendingReason => TurnAction::ExecuteReason,
TurnPhase::PendingAct => TurnAction::ExecuteAct,
TurnPhase::Completed => {
if let Some(error) = &self.pending_error {
TurnAction::Complete(TurnOutcome::Failed {
error: error.clone(),
iterations: self.current_iteration,
})
} else if self.current_iteration >= self.max_iterations {
TurnAction::Complete(TurnOutcome::MaxIterationsReached {
response: self.last_response.clone(),
iterations: self.current_iteration,
tool_calls_count: self.total_tool_calls,
})
} else {
TurnAction::Complete(TurnOutcome::Success {
response: self.last_response.clone(),
iterations: self.current_iteration,
tool_calls_count: self.total_tool_calls,
})
}
}
}
}
pub fn on_input_completed(&mut self) {
debug_assert_eq!(self.phase, TurnPhase::PendingInput);
self.phase = TurnPhase::PendingReason;
}
pub fn on_reason_completed(
&mut self,
response: String,
has_tool_calls: bool,
tool_call_count: usize,
success: bool,
error: Option<String>,
has_pending_user_messages: bool,
) {
debug_assert_eq!(self.phase, TurnPhase::PendingReason);
self.current_iteration += 1;
if !response.is_empty() {
self.last_response = response;
}
if !success {
self.pending_error = error;
self.phase = TurnPhase::Completed;
return;
}
if has_tool_calls && tool_call_count > 0 {
if self.current_iteration >= self.max_iterations {
self.phase = TurnPhase::Completed;
return;
}
self.has_pending_tool_calls = true;
self.total_tool_calls += tool_call_count;
self.phase = TurnPhase::PendingAct;
} else if has_pending_user_messages {
if self.current_iteration >= self.max_iterations {
self.phase = TurnPhase::Completed;
} else {
self.phase = TurnPhase::PendingReason;
}
} else {
self.phase = TurnPhase::Completed;
}
}
pub fn on_act_completed(&mut self) {
debug_assert_eq!(self.phase, TurnPhase::PendingAct);
self.has_pending_tool_calls = false;
self.phase = TurnPhase::PendingReason;
}
pub fn is_completed(&self) -> bool {
self.phase == TurnPhase::Completed
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Context with fresh ids and an org id of 0.
    fn test_context() -> TurnContext {
        TurnContext::new(SessionId::new(), MessageId::new(), AgentId::new(), 0)
    }

    /// Fresh state machine with the given iteration limit.
    fn machine(max_iterations: usize) -> TurnStateMachine {
        TurnStateMachine::new(test_context(), max_iterations)
    }

    /// Reports a successful reason step; `tool_calls == 0` means "no tool calls".
    fn reason_ok(
        sm: &mut TurnStateMachine,
        response: &str,
        tool_calls: usize,
        pending_user_messages: bool,
    ) {
        sm.on_reason_completed(
            response.to_string(),
            tool_calls > 0,
            tool_calls,
            true,
            None,
            pending_user_messages,
        );
    }

    /// Reports a failed reason step with an empty response and the given error message.
    fn reason_failed(sm: &mut TurnStateMachine, error: &str, pending_user_messages: bool) {
        sm.on_reason_completed(
            String::new(),
            false,
            0,
            false,
            Some(error.to_string()),
            pending_user_messages,
        );
    }

    /// Unwraps a `Success` outcome into `(response, iterations, tool_calls_count)`.
    fn expect_success(sm: &TurnStateMachine) -> (String, usize, usize) {
        match sm.next_action() {
            TurnAction::Complete(TurnOutcome::Success {
                response,
                iterations,
                tool_calls_count,
            }) => (response, iterations, tool_calls_count),
            other => panic!("Expected Success, got {:?}", other),
        }
    }

    #[test]
    fn test_simple_turn_no_tools() {
        let mut sm = machine(10);
        assert!(matches!(sm.next_action(), TurnAction::ExecuteInput));
        sm.on_input_completed();
        assert!(matches!(sm.next_action(), TurnAction::ExecuteReason));

        reason_ok(&mut sm, "Hello!", 0, false);

        let (response, iterations, tool_calls_count) = expect_success(&sm);
        assert_eq!(response, "Hello!");
        assert_eq!(iterations, 1);
        assert_eq!(tool_calls_count, 0);
    }

    #[test]
    fn test_turn_with_one_tool_call() {
        let mut sm = machine(10);
        assert!(matches!(sm.next_action(), TurnAction::ExecuteInput));
        sm.on_input_completed();
        assert!(matches!(sm.next_action(), TurnAction::ExecuteReason));

        reason_ok(&mut sm, "Let me check...", 1, false);
        assert!(matches!(sm.next_action(), TurnAction::ExecuteAct));
        sm.on_act_completed();
        assert!(matches!(sm.next_action(), TurnAction::ExecuteReason));

        reason_ok(&mut sm, "Here's the result.", 0, false);

        let (response, iterations, tool_calls_count) = expect_success(&sm);
        assert_eq!(response, "Here's the result.");
        assert_eq!(iterations, 2);
        assert_eq!(tool_calls_count, 1);
    }

    #[test]
    fn test_max_iterations() {
        let mut sm = machine(2);
        sm.on_input_completed();
        reason_ok(&mut sm, "Trying...", 1, false);
        sm.on_act_completed();
        reason_ok(&mut sm, "Still trying...", 1, false);

        let action = sm.next_action();
        if let TurnAction::Complete(TurnOutcome::MaxIterationsReached { iterations, .. }) = action {
            assert_eq!(iterations, 2);
        } else {
            panic!("Expected MaxIterationsReached, got {:?}", action);
        }
    }

    #[test]
    fn test_reason_failure() {
        let mut sm = machine(10);
        sm.on_input_completed();
        reason_failed(&mut sm, "LLM error", false);

        let action = sm.next_action();
        if let TurnAction::Complete(TurnOutcome::Failed { error, .. }) = action {
            assert_eq!(error, "LLM error");
        } else {
            panic!("Expected Failed, got {:?}", action);
        }
    }

    #[test]
    fn test_context_preserved() {
        let context = TurnContext::new(SessionId::new(), MessageId::new(), AgentId::new(), 42);
        let turn_id = context.turn_id;
        let sm = TurnStateMachine::new(context, 10);
        assert_eq!(sm.context().turn_id, turn_id);
        assert_eq!(sm.context().org_id, 42);
    }

    #[test]
    fn test_outcome_helpers() {
        let success = TurnOutcome::Success {
            response: "test".to_string(),
            iterations: 1,
            tool_calls_count: 0,
        };
        assert!(success.is_success());
        assert_eq!(success.response(), Some("test"));
        assert!(success.error().is_none());

        let failed = TurnOutcome::Failed {
            error: "oops".to_string(),
            iterations: 0,
        };
        assert!(!failed.is_success());
        assert!(failed.response().is_none());
        assert_eq!(failed.error(), Some("oops"));
    }

    #[test]
    fn test_pending_user_message_continues_turn() {
        let mut sm = machine(10);
        sm.on_input_completed();

        // A pending user message keeps the turn alive even without tool calls.
        reason_ok(&mut sm, "Hello!", 0, true);
        assert!(!sm.is_completed());
        assert_eq!(sm.phase(), TurnPhase::PendingReason);
        assert!(matches!(sm.next_action(), TurnAction::ExecuteReason));

        reason_ok(&mut sm, "Got your message!", 0, false);

        let (response, iterations, _) = expect_success(&sm);
        assert_eq!(response, "Got your message!");
        assert_eq!(iterations, 2);
    }

    #[test]
    fn test_pending_messages_ignored_on_failure() {
        let mut sm = machine(10);
        sm.on_input_completed();
        reason_failed(&mut sm, "LLM error", true);

        assert!(sm.is_completed());
        assert!(matches!(
            sm.next_action(),
            TurnAction::Complete(TurnOutcome::Failed { .. })
        ));
    }

    #[test]
    fn test_pending_messages_ignored_when_tool_calls() {
        let mut sm = machine(10);
        sm.on_input_completed();
        reason_ok(&mut sm, "Working...", 2, true);
        assert_eq!(sm.phase(), TurnPhase::PendingAct);
    }
}