Skip to main content

bob_runtime/
typestate.rs

1//! # Typestate Agent Runner
2//!
3//! Exposes the agent turn FSM as an explicit typestate machine, enabling:
4//!
5//! - **Step-by-step execution**: developers can manually advance the agent
6//! - **Human-in-the-loop**: pause before tool execution for approval
7//! - **Distributed suspend/resume**: serialize state between steps
8//! - **Compile-time safety**: invalid state transitions are rejected at build time
9//!
10//! ## State Machine
11//!
//! ```text
//! Ready ──infer()──▶ AwaitingToolCall
//!   │                     │
//!   │                     ├──provide_tool_results()──▶ Ready
//!   │                     │
//!   │                     ├──cancel()──▶ Finished
//!   │                     │
//!   │                     └──require_approval()──▶ AwaitingApproval
//!   │                                                  │
//!   │                                                  ├──approve()──▶ AwaitingToolCall
//!   │                                                  │
//!   │                                                  └──deny()──▶ Finished
//!   │
//!   └──infer()──▶ Finished (no tool calls requested, or max steps reached)
//! ```
25//!
26//! ## Example
27//!
//! ```rust,ignore
//! use bob_runtime::typestate::*;
//!
//! // High-level: run to completion. Every requested tool is executed through
//! // `tools`, and the final session is saved to `store`.
//! let runner = AgentRunner::new(session_id, model, system_instructions, policy, session);
//! let finished = runner.run_to_completion(&llm, &tools, &compactor, &store).await?;
//! println!("done: {}", finished.response().content);
//!
//! // Low-level: step-by-step control
//! let runner = AgentRunner::new(session_id, model, system_instructions, policy, session);
//! match runner.infer(&llm, &tools, &compactor).await? {
//!     AgentStepResult::Finished(runner) => {
//!         println!("done: {}", runner.response().content);
//!     }
//!     AgentStepResult::RequiresTool(runner) => {
//!         // Human-in-the-loop: ask for approval
//!         let results = execute_tools_with_approval(runner.pending_calls()).await;
//!         let runner = runner.provide_tool_results(results);
//!         // Continue with `runner.infer(&llm, &tools, &compactor)` for the next step...
//!     }
//! }
//! ```
51
52use bob_core::{
53    error::AgentError,
54    ports::{ContextCompactorPort, LlmPort, SessionStore, ToolPort},
55    types::{
56        AgentResponse, FinishReason, LlmRequest, Message, Role, SessionState, TokenUsage, ToolCall,
57        ToolResult, TurnPolicy,
58    },
59};
60
61// ── State Markers ────────────────────────────────────────────────────
62
/// Agent is ready to perform LLM inference.
///
/// This is the initial state and the state after tool results are provided.
/// Only valid transitions: `infer()` → `AwaitingToolCall` or `Finished`.
///
/// The marker carries no data, so it is freely copyable and comparable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Ready;
69
70/// Agent has requested tool calls and is waiting for results.
71///
72/// This state holds the pending tool calls that need to be executed.
73/// Valid transitions:
74/// - `provide_tool_results()` → `Ready` (continue execution)
75/// - `require_approval()` → `AwaitingApproval` (human-in-the-loop)
76/// - `cancel()` → `Finished` (abort execution)
77#[derive(Debug)]
78pub struct AwaitingToolCall {
79    /// Pending tool calls requested by the LLM.
80    pub pending_calls: Vec<ToolCall>,
81    /// Native tool call IDs (if using provider-native tool calling).
82    pub call_ids: Vec<Option<String>>,
83}
84
85/// Agent is waiting for human approval before executing tool calls.
86///
87/// This state is entered when `require_approval()` is called on
88/// `AwaitingToolCall`. It provides compile-time safety by preventing
89/// direct tool execution without explicit approval.
90///
91/// Valid transitions:
92/// - `approve()` → `AwaitingToolCall` (proceed with execution)
93/// - `deny()` → `Finished` (abort with reason)
94#[derive(Debug)]
95pub struct AwaitingApproval {
96    /// Pending tool calls awaiting approval.
97    pub pending_calls: Vec<ToolCall>,
98    /// Native tool call IDs.
99    pub call_ids: Vec<Option<String>>,
100    /// Reason for requiring approval (for audit trail).
101    pub reason: String,
102}
103
104/// Agent has finished execution with a final response.
105///
106/// Terminal state — no further transitions are possible.
107/// Access the final response via `response()` or `into_response()`.
108#[derive(Debug)]
109pub struct Finished {
110    /// The final agent response.
111    pub response: AgentResponse,
112}
113
114// ── Agent Runner ─────────────────────────────────────────────────────
115
116/// Typestate-parameterized agent runner.
117///
118/// The type parameter `S` encodes the current FSM state:
119/// - [`Ready`] — can call [`infer`](AgentRunner<Ready>::infer)
120/// - [`AwaitingToolCall`] — can call
121///   [`provide_tool_results`](AgentRunner<AwaitingToolCall>::provide_tool_results)
122/// - [`Finished`] — can access the final [`response`](AgentRunner<Finished>::response)
123#[derive(Debug)]
124pub struct AgentRunner<S> {
125    state: S,
126    session: SessionState,
127    context: RunnerContext,
128}
129
130/// Immutable execution context carried through the typestate machine.
131#[derive(Debug, Clone)]
132pub struct RunnerContext {
133    pub session_id: String,
134    pub model: String,
135    pub system_instructions: String,
136    pub policy: TurnPolicy,
137    pub steps_taken: u32,
138    pub tool_calls_made: u32,
139    pub total_usage: TokenUsage,
140    pub tool_transcript: Vec<ToolResult>,
141}
142
143/// Result of an inference step — either finished or needs tools.
144pub enum AgentStepResult {
145    /// Agent produced a final response.
146    Finished(AgentRunner<Finished>),
147    /// Agent requested tool calls that need to be executed.
148    RequiresTool(AgentRunner<AwaitingToolCall>),
149}
150
151impl std::fmt::Debug for AgentStepResult {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        match self {
154            Self::Finished(_) => f.write_str("AgentStepResult::Finished"),
155            Self::RequiresTool(_) => f.write_str("AgentStepResult::RequiresTool"),
156        }
157    }
158}
159
160// ── Ready State ──────────────────────────────────────────────────────
161
162impl AgentRunner<Ready> {
163    /// Create a new runner in the `Ready` state.
164    #[must_use]
165    pub fn new(
166        session_id: impl Into<String>,
167        model: impl Into<String>,
168        system_instructions: impl Into<String>,
169        policy: TurnPolicy,
170        session: SessionState,
171    ) -> Self {
172        Self {
173            state: Ready,
174            session,
175            context: RunnerContext {
176                session_id: session_id.into(),
177                model: model.into(),
178                system_instructions: system_instructions.into(),
179                policy,
180                steps_taken: 0,
181                tool_calls_made: 0,
182                total_usage: TokenUsage::default(),
183                tool_transcript: Vec::new(),
184            },
185        }
186    }
187
188    /// Perform LLM inference and transition to the next state.
189    ///
190    /// Returns [`AgentStepResult::Finished`] if the LLM produces a final
191    /// response, or [`AgentStepResult::RequiresTool`] if tool calls are
192    /// needed.
193    ///
194    /// # Errors
195    ///
196    /// Returns an error if the LLM call fails or policy limits are exceeded.
197    pub async fn infer(
198        mut self,
199        llm: &(impl LlmPort + ?Sized),
200        tools: &(impl ToolPort + ?Sized),
201        compactor: &(impl ContextCompactorPort + ?Sized),
202    ) -> Result<AgentStepResult, AgentError> {
203        if self.context.steps_taken >= self.context.policy.max_steps {
204            return Ok(AgentStepResult::Finished(AgentRunner {
205                state: Finished {
206                    response: AgentResponse {
207                        content: "Max steps exceeded.".to_string(),
208                        tool_transcript: self.context.tool_transcript.clone(),
209                        usage: self.context.total_usage.clone(),
210                        finish_reason: FinishReason::GuardExceeded,
211                    },
212                },
213                session: self.session,
214                context: self.context,
215            }));
216        }
217
218        let tool_descriptors = tools.list_tools().await.unwrap_or_default();
219        let messages = compactor.compact(&self.session).await;
220
221        let request = LlmRequest {
222            model: self.context.model.clone(),
223            messages,
224            tools: tool_descriptors,
225            output_schema: None,
226        };
227
228        let response = llm.complete(request).await?;
229
230        self.context.steps_taken += 1;
231        self.context.total_usage.prompt_tokens =
232            self.context.total_usage.prompt_tokens.saturating_add(response.usage.prompt_tokens);
233        self.context.total_usage.completion_tokens = self
234            .context
235            .total_usage
236            .completion_tokens
237            .saturating_add(response.usage.completion_tokens);
238
239        if response.tool_calls.is_empty() {
240            let assistant_msg = Message::text(Role::Assistant, response.content.clone());
241            self.session.messages.push(assistant_msg);
242
243            Ok(AgentStepResult::Finished(AgentRunner {
244                state: Finished {
245                    response: AgentResponse {
246                        content: response.content,
247                        tool_transcript: self.context.tool_transcript.clone(),
248                        usage: self.context.total_usage.clone(),
249                        finish_reason: FinishReason::Stop,
250                    },
251                },
252                session: self.session,
253                context: self.context,
254            }))
255        } else {
256            let call_ids: Vec<Option<String>> =
257                response.tool_calls.iter().map(|c| c.call_id.clone()).collect();
258
259            let assistant_msg =
260                Message::assistant_tool_calls(response.content, response.tool_calls.clone());
261            self.session.messages.push(assistant_msg);
262
263            Ok(AgentStepResult::RequiresTool(AgentRunner {
264                state: AwaitingToolCall { pending_calls: response.tool_calls, call_ids },
265                session: self.session,
266                context: self.context,
267            }))
268        }
269    }
270
271    /// Run the agent loop until completion (high-level convenience method).
272    ///
273    /// This is equivalent to calling [`infer`](Self::infer) in a loop until
274    /// a [`Finished`] state is reached.
275    ///
276    /// # Errors
277    ///
278    /// Returns an error if any LLM call fails.
279    pub async fn run_to_completion(
280        self,
281        llm: &(impl LlmPort + ?Sized),
282        tools: &(impl ToolPort + ?Sized),
283        compactor: &(impl ContextCompactorPort + ?Sized),
284        store: &(impl SessionStore + ?Sized),
285    ) -> Result<AgentRunner<Finished>, AgentError> {
286        let mut current = self.infer(llm, tools, compactor).await?;
287
288        loop {
289            match current {
290                AgentStepResult::Finished(runner) => {
291                    store.save(&runner.context.session_id, &runner.session).await?;
292                    return Ok(runner);
293                }
294                AgentStepResult::RequiresTool(runner) => {
295                    let mut results = Vec::new();
296                    for call in &runner.state.pending_calls {
297                        match tools.call_tool(call.clone()).await {
298                            Ok(result) => results.push(result),
299                            Err(err) => results.push(ToolResult {
300                                name: call.name.clone(),
301                                output: serde_json::json!({"error": err.to_string()}),
302                                is_error: true,
303                            }),
304                        }
305                    }
306                    let ready = runner.provide_tool_results(results);
307                    current = ready.infer(llm, tools, compactor).await?;
308                }
309            }
310        }
311    }
312}
313
314// ── AwaitingToolCall State ───────────────────────────────────────────
315
316impl AgentRunner<AwaitingToolCall> {
317    /// Get the pending tool calls that need to be executed.
318    #[must_use]
319    pub fn pending_calls(&self) -> &[ToolCall] {
320        &self.state.pending_calls
321    }
322
323    /// Provide tool execution results and transition back to `Ready`.
324    ///
325    /// The results must correspond 1:1 with the pending calls in the
326    /// same order.
327    #[must_use]
328    pub fn provide_tool_results(mut self, results: Vec<ToolResult>) -> AgentRunner<Ready> {
329        for (result, call_id) in results.iter().zip(self.state.call_ids.iter()) {
330            let output_str = serde_json::to_string(&result.output).unwrap_or_default();
331            self.session.messages.push(Message::tool_result(
332                result.name.clone(),
333                call_id.clone(),
334                output_str,
335            ));
336            self.context.tool_calls_made += 1;
337            self.context.tool_transcript.push(result.clone());
338        }
339
340        AgentRunner { state: Ready, session: self.session, context: self.context }
341    }
342
343    /// Cancel the pending tool calls and transition to `Finished`.
344    #[must_use]
345    pub fn cancel(self, reason: impl Into<String>) -> AgentRunner<Finished> {
346        AgentRunner {
347            state: Finished {
348                response: AgentResponse {
349                    content: reason.into(),
350                    tool_transcript: self.context.tool_transcript.clone(),
351                    usage: self.context.total_usage.clone(),
352                    finish_reason: FinishReason::Cancelled,
353                },
354            },
355            session: self.session,
356            context: self.context,
357        }
358    }
359
360    /// Require human approval before executing tool calls.
361    ///
362    /// Transitions to `AwaitingApproval` state, which prevents
363    /// direct tool execution. Callers must explicitly `approve()`
364    /// or `deny()` to proceed.
365    ///
366    /// This provides compile-time safety: `AwaitingToolCall` can
367    /// directly execute tools via `provide_tool_results()`, but
368    /// `AwaitingApproval` cannot — it must go through the approval
369    /// gate first.
370    #[must_use]
371    pub fn require_approval(self, reason: impl Into<String>) -> AgentRunner<AwaitingApproval> {
372        AgentRunner {
373            state: AwaitingApproval {
374                pending_calls: self.state.pending_calls,
375                call_ids: self.state.call_ids,
376                reason: reason.into(),
377            },
378            session: self.session,
379            context: self.context,
380        }
381    }
382}
383
384// ── AwaitingApproval State ───────────────────────────────────────────
385
386impl AgentRunner<AwaitingApproval> {
387    /// Get the pending tool calls awaiting approval.
388    #[must_use]
389    pub fn pending_calls(&self) -> &[ToolCall] {
390        &self.state.pending_calls
391    }
392
393    /// Get the reason approval was required.
394    #[must_use]
395    pub fn approval_reason(&self) -> &str {
396        &self.state.reason
397    }
398
399    /// Approve the tool calls and transition back to `AwaitingToolCall`.
400    ///
401    /// After approval, call `provide_tool_results()` with the execution
402    /// results to continue the agent loop.
403    #[must_use]
404    pub fn approve(self) -> AgentRunner<AwaitingToolCall> {
405        AgentRunner {
406            state: AwaitingToolCall {
407                pending_calls: self.state.pending_calls,
408                call_ids: self.state.call_ids,
409            },
410            session: self.session,
411            context: self.context,
412        }
413    }
414
415    /// Deny the tool calls and transition to `Finished`.
416    ///
417    /// The provided reason will be included in the final response
418    /// for the user/audit trail.
419    #[must_use]
420    pub fn deny(self, reason: impl Into<String>) -> AgentRunner<Finished> {
421        AgentRunner {
422            state: Finished {
423                response: AgentResponse {
424                    content: reason.into(),
425                    tool_transcript: self.context.tool_transcript.clone(),
426                    usage: self.context.total_usage.clone(),
427                    finish_reason: FinishReason::Cancelled,
428                },
429            },
430            session: self.session,
431            context: self.context,
432        }
433    }
434}
435
436// ── Finished State ───────────────────────────────────────────────────
437
438impl AgentRunner<Finished> {
439    /// Get a reference to the final response.
440    #[must_use]
441    pub fn response(&self) -> &AgentResponse {
442        &self.state.response
443    }
444
445    /// Consume the runner and return the final response.
446    #[must_use]
447    pub fn into_response(self) -> AgentResponse {
448        self.state.response
449    }
450
451    /// Get the final session state (for persistence).
452    #[must_use]
453    pub fn session(&self) -> &SessionState {
454        &self.session
455    }
456
457    /// Get the execution context with usage stats.
458    #[must_use]
459    pub fn context(&self) -> &RunnerContext {
460        &self.context
461    }
462}
463
464// ── Tests ────────────────────────────────────────────────────────────
465
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use bob_core::types::ToolDescriptor;

    use super::*;

    /// LLM stub that always answers with a final "done" response.
    struct StubLlm;

    impl StubLlm {
        fn finish_response(content: &str) -> bob_core::types::LlmResponse {
            bob_core::types::LlmResponse {
                content: content.to_string(),
                usage: TokenUsage::default(),
                finish_reason: FinishReason::Stop,
                tool_calls: Vec::new(),
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmPort for StubLlm {
        async fn complete(
            &self,
            _req: LlmRequest,
        ) -> Result<bob_core::types::LlmResponse, bob_core::error::LlmError> {
            Ok(Self::finish_response("done"))
        }

        async fn complete_stream(
            &self,
            _req: LlmRequest,
        ) -> Result<bob_core::types::LlmStream, bob_core::error::LlmError> {
            Err(bob_core::error::LlmError::Provider("not implemented".into()))
        }
    }

    /// LLM stub that requests one `test` tool call on its first call and
    /// answers "done" on every later call.
    struct ToolThenFinishLlm {
        calls: AtomicUsize,
    }

    #[async_trait::async_trait]
    impl LlmPort for ToolThenFinishLlm {
        async fn complete(
            &self,
            _req: LlmRequest,
        ) -> Result<bob_core::types::LlmResponse, bob_core::error::LlmError> {
            if self.calls.fetch_add(1, Ordering::SeqCst) == 0 {
                Ok(bob_core::types::LlmResponse {
                    content: String::new(),
                    usage: TokenUsage::default(),
                    finish_reason: FinishReason::Stop,
                    tool_calls: vec![ToolCall::new("test", serde_json::json!({}))],
                })
            } else {
                Ok(StubLlm::finish_response("done"))
            }
        }

        async fn complete_stream(
            &self,
            _req: LlmRequest,
        ) -> Result<bob_core::types::LlmStream, bob_core::error::LlmError> {
            Err(bob_core::error::LlmError::Provider("not implemented".into()))
        }
    }

    struct StubTools;

    #[async_trait::async_trait]
    impl ToolPort for StubTools {
        async fn list_tools(&self) -> Result<Vec<ToolDescriptor>, bob_core::error::ToolError> {
            Ok(vec![])
        }

        async fn call_tool(
            &self,
            call: ToolCall,
        ) -> Result<ToolResult, bob_core::error::ToolError> {
            Ok(ToolResult { name: call.name, output: serde_json::json!(null), is_error: false })
        }
    }

    struct StubCompactor;

    #[async_trait::async_trait]
    impl ContextCompactorPort for StubCompactor {
        async fn compact(&self, session: &SessionState) -> Vec<Message> {
            session.messages.clone()
        }
    }

    struct StubStore;

    #[async_trait::async_trait]
    impl SessionStore for StubStore {
        async fn load(
            &self,
            _id: &bob_core::types::SessionId,
        ) -> Result<Option<SessionState>, bob_core::error::StoreError> {
            Ok(None)
        }

        async fn save(
            &self,
            _id: &bob_core::types::SessionId,
            _state: &SessionState,
        ) -> Result<(), bob_core::error::StoreError> {
            Ok(())
        }
    }

    /// A fresh `Ready` runner with an empty session.
    fn ready_runner() -> AgentRunner<Ready> {
        AgentRunner::new(
            "test-session",
            "test-model",
            "You are a test assistant.",
            TurnPolicy::default(),
            SessionState::default(),
        )
    }

    /// A runner waiting on a single `test` tool call with native ID `call-1`.
    fn awaiting_runner() -> AgentRunner<AwaitingToolCall> {
        AgentRunner {
            state: AwaitingToolCall {
                pending_calls: vec![ToolCall::new("test", serde_json::json!({}))],
                call_ids: vec![Some("call-1".into())],
            },
            session: SessionState::default(),
            context: RunnerContext {
                session_id: "test".into(),
                model: "test".into(),
                system_instructions: String::new(),
                policy: TurnPolicy::default(),
                steps_taken: 1,
                tool_calls_made: 0,
                total_usage: TokenUsage::default(),
                tool_transcript: Vec::new(),
            },
        }
    }

    fn ok_result() -> ToolResult {
        ToolResult { name: "test".into(), output: serde_json::json!({"ok": true}), is_error: false }
    }

    #[tokio::test]
    async fn ready_infer_to_finished() {
        let result = ready_runner().infer(&StubLlm, &StubTools, &StubCompactor).await;
        assert!(result.is_ok(), "infer should succeed");

        if let Ok(AgentStepResult::Finished(runner)) = result {
            assert_eq!(runner.response().content, "done");
            assert_eq!(runner.response().finish_reason, FinishReason::Stop);
            assert_eq!(runner.context().steps_taken, 1);
        } else {
            panic!("expected Finished result");
        }
    }

    #[tokio::test]
    async fn infer_at_step_limit_finishes_with_guard() {
        let mut runner = ready_runner();
        runner.context.policy.max_steps = 0;

        let result = runner.infer(&StubLlm, &StubTools, &StubCompactor).await;
        let Ok(AgentStepResult::Finished(finished)) = result else {
            panic!("expected Finished result at the step limit");
        };
        assert_eq!(finished.response().finish_reason, FinishReason::GuardExceeded);
        assert_eq!(finished.response().content, "Max steps exceeded.");
        assert_eq!(finished.context().steps_taken, 0);
        assert!(finished.session().messages.is_empty());
    }

    #[tokio::test]
    async fn run_to_completion() {
        let result = ready_runner()
            .run_to_completion(&StubLlm, &StubTools, &StubCompactor, &StubStore)
            .await;
        assert!(result.is_ok(), "run_to_completion should succeed");

        let finished = result.unwrap();
        assert_eq!(finished.response().content, "done");
    }

    #[tokio::test]
    async fn run_to_completion_executes_requested_tools() {
        let mut runner = ready_runner();
        runner.context.policy.max_steps = 8;
        let llm = ToolThenFinishLlm { calls: AtomicUsize::new(0) };

        let finished = runner
            .run_to_completion(&llm, &StubTools, &StubCompactor, &StubStore)
            .await
            .expect("run_to_completion should succeed");

        assert_eq!(finished.response().content, "done");
        assert_eq!(finished.response().finish_reason, FinishReason::Stop);
        assert_eq!(finished.response().tool_transcript.len(), 1);
        assert_eq!(finished.context().tool_calls_made, 1);
        assert_eq!(finished.context().steps_taken, 2);
        // Assistant tool-call message, tool result, final assistant answer.
        assert_eq!(finished.session().messages.len(), 3);
    }

    #[test]
    fn awaiting_tool_call_provide_results() {
        let ready = awaiting_runner().provide_tool_results(vec![ok_result()]);
        assert_eq!(ready.context.tool_calls_made, 1);
        assert_eq!(ready.context.tool_transcript.len(), 1);
        assert_eq!(ready.session.messages.len(), 1);
    }

    #[test]
    fn awaiting_tool_call_cancel_finishes_cancelled() {
        let finished = awaiting_runner().cancel("stopped by user");
        assert_eq!(finished.response().content, "stopped by user");
        assert_eq!(finished.response().finish_reason, FinishReason::Cancelled);
        assert!(finished.response().tool_transcript.is_empty());
    }

    #[test]
    fn approval_then_approve_allows_results() {
        let pending = awaiting_runner().require_approval("needs review");
        assert_eq!(pending.approval_reason(), "needs review");
        assert_eq!(pending.pending_calls().len(), 1);

        let approved = pending.approve();
        assert_eq!(approved.pending_calls().len(), 1);

        let ready = approved.provide_tool_results(vec![ok_result()]);
        assert_eq!(ready.context.tool_calls_made, 1);
    }

    #[test]
    fn approval_then_deny_finishes_cancelled() {
        let finished = awaiting_runner().require_approval("needs review").deny("denied");
        assert_eq!(finished.response().content, "denied");
        assert_eq!(finished.response().finish_reason, FinishReason::Cancelled);
    }
}