// bob_runtime/typestate.rs
1//! # Typestate Agent Runner
2//!
3//! Exposes the agent turn FSM as an explicit typestate machine, enabling:
4//!
5//! - **Step-by-step execution**: developers can manually advance the agent
6//! - **Human-in-the-loop**: pause before tool execution for approval
7//! - **Distributed suspend/resume**: serialize state between steps
8//! - **Compile-time safety**: invalid state transitions are rejected at build time
9//!
10//! ## Example
11//!
12//! ```rust,ignore
13//! use bob_runtime::typestate::*;
14//!
15//! // High-level: run to completion
16//! let result = AgentRunner::new(context, llm, tools)
17//!     .run_to_completion()
18//!     .await?;
19//!
20//! // Low-level: step-by-step control
21//! let runner = AgentRunner::new(context, llm, tools);
22//! let step = runner.infer().await?;
23//! match step {
24//!     AgentStepResult::Finished(runner) => {
25//!         println!("done: {}", runner.response().content);
26//!     }
27//!     AgentStepResult::RequiresTool(runner) => {
28//!         // Human-in-the-loop: ask for approval
29//!         let results = execute_tools_with_approval(runner.pending_calls()).await;
30//!         let runner = runner.provide_tool_results(results);
31//!         // Continue to next step...
32//!     }
33//! }
34//! ```
35
36use bob_core::{
37    error::AgentError,
38    ports::{ContextCompactorPort, LlmPort, SessionStore, ToolPort},
39    types::{
40        AgentResponse, FinishReason, LlmRequest, Message, Role, SessionState, TokenUsage, ToolCall,
41        ToolResult, TurnPolicy,
42    },
43};
44
45// ── State Markers ────────────────────────────────────────────────────
46
/// Agent is ready to perform LLM inference.
///
/// Zero-sized marker type for the typestate machine. Deriving `Clone`,
/// `Copy`, `PartialEq`, `Eq`, and `Default` costs nothing for a unit
/// struct and makes the marker easier to use in tests and matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct Ready;
50
/// Agent is waiting for tool call results before continuing.
///
/// `pending_calls` and `call_ids` are parallel vectors: entry `i` of
/// `call_ids` is the provider-native ID (if any) for call `i`.
#[derive(Debug)]
pub struct AwaitingToolCall {
    /// Pending tool calls requested by the LLM.
    pub pending_calls: Vec<ToolCall>,
    /// Native tool call IDs (if using provider-native tool calling).
    /// Parallel to `pending_calls`; `None` when the provider has no native ID.
    pub call_ids: Vec<Option<String>>,
}
59
/// Agent has finished execution with a final response.
///
/// Terminal state of the typestate machine; no further transitions
/// are defined from here.
#[derive(Debug)]
pub struct Finished {
    /// The final agent response (content, transcript, usage, finish reason).
    pub response: AgentResponse,
}
66
67// ── Agent Runner ─────────────────────────────────────────────────────
68
/// Typestate-parameterized agent runner.
///
/// The type parameter `S` encodes the current FSM state:
/// - [`Ready`] — can call [`infer`](AgentRunner<Ready>::infer)
/// - [`AwaitingToolCall`] — can call
///   [`provide_tool_results`](AgentRunner<AwaitingToolCall>::provide_tool_results)
/// - [`Finished`] — can access the final [`response`](AgentRunner<Finished>::response)
///
/// Transitions consume `self` and return a runner in the new state, so
/// invalid transitions are compile errors rather than runtime checks.
#[derive(Debug)]
pub struct AgentRunner<S> {
    // State-specific payload (e.g. pending calls, final response).
    state: S,
    // Conversation history; messages are appended as the turn progresses.
    session: SessionState,
    // Accumulated counters, usage, and configuration for this turn.
    context: RunnerContext,
}
82
/// Execution context carried through the typestate machine.
///
/// Cloned/moved between states; counters and usage are accumulated as
/// the turn progresses.
#[derive(Debug, Clone)]
pub struct RunnerContext {
    /// Identifier used when persisting the session via a `SessionStore`.
    pub session_id: String,
    /// Model name forwarded in every `LlmRequest`.
    pub model: String,
    /// System instructions for the agent.
    /// NOTE(review): stored here but not visibly included in `LlmRequest`
    /// by this module — confirm a downstream component injects it.
    pub system_instructions: String,
    /// Guard-rail limits for the turn (e.g. `max_steps`).
    pub policy: TurnPolicy,
    /// Number of LLM inference steps performed so far.
    pub steps_taken: u32,
    /// Number of tool results recorded so far.
    pub tool_calls_made: u32,
    /// Prompt/completion token usage accumulated across all steps.
    pub total_usage: TokenUsage,
    /// All tool results recorded during this turn, in execution order.
    pub tool_transcript: Vec<ToolResult>,
}
95
/// Result of an inference step — either finished or needs tools.
///
/// `Debug` is implemented manually rather than derived so that only the
/// variant name is printed, without dumping the contained runner state.
pub enum AgentStepResult {
    /// Agent produced a final response.
    Finished(AgentRunner<Finished>),
    /// Agent requested tool calls that need to be executed.
    RequiresTool(AgentRunner<AwaitingToolCall>),
}
103
104impl std::fmt::Debug for AgentStepResult {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        match self {
107            Self::Finished(_) => f.write_str("AgentStepResult::Finished"),
108            Self::RequiresTool(_) => f.write_str("AgentStepResult::RequiresTool"),
109        }
110    }
111}
112
113// ── Ready State ──────────────────────────────────────────────────────
114
impl AgentRunner<Ready> {
    /// Create a new runner in the `Ready` state.
    ///
    /// Counters (`steps_taken`, `tool_calls_made`) and token usage start
    /// at zero; the tool transcript starts empty.
    #[must_use]
    pub fn new(
        session_id: impl Into<String>,
        model: impl Into<String>,
        system_instructions: impl Into<String>,
        policy: TurnPolicy,
        session: SessionState,
    ) -> Self {
        Self {
            state: Ready,
            session,
            context: RunnerContext {
                session_id: session_id.into(),
                model: model.into(),
                system_instructions: system_instructions.into(),
                policy,
                steps_taken: 0,
                tool_calls_made: 0,
                total_usage: TokenUsage::default(),
                tool_transcript: Vec::new(),
            },
        }
    }

    /// Perform LLM inference and transition to the next state.
    ///
    /// Returns [`AgentStepResult::Finished`] if the LLM produces a final
    /// response, or [`AgentStepResult::RequiresTool`] if tool calls are
    /// needed.
    ///
    /// # Errors
    ///
    /// Returns an error if the LLM call fails or policy limits are exceeded.
    pub async fn infer(
        mut self,
        llm: &(impl LlmPort + ?Sized),
        tools: &(impl ToolPort + ?Sized),
        compactor: &(impl ContextCompactorPort + ?Sized),
    ) -> Result<AgentStepResult, AgentError> {
        // Step-budget guard: once the policy's limit is reached, finish
        // gracefully with `GuardExceeded` instead of making another LLM call.
        if self.context.steps_taken >= self.context.policy.max_steps {
            return Ok(AgentStepResult::Finished(AgentRunner {
                state: Finished {
                    response: AgentResponse {
                        content: "Max steps exceeded.".to_string(),
                        // Clone because `context` is also moved into the runner below.
                        tool_transcript: self.context.tool_transcript.clone(),
                        usage: self.context.total_usage.clone(),
                        finish_reason: FinishReason::GuardExceeded,
                    },
                },
                session: self.session,
                context: self.context,
            }));
        }

        // Best-effort: a failure listing tools degrades to "no tools
        // available" rather than aborting the turn.
        let tool_descriptors = tools.list_tools().await.unwrap_or_default();
        // Compaction may summarize/trim history; the request carries the
        // compacted view, not the raw session messages.
        let messages = compactor.compact(&self.session).await;

        // NOTE(review): `context.system_instructions` is not included here —
        // confirm the compactor or provider injects it into the prompt.
        let request = LlmRequest {
            model: self.context.model.clone(),
            messages,
            tools: tool_descriptors,
            output_schema: None,
        };

        let response = llm.complete(request).await?;

        self.context.steps_taken += 1;
        // Accumulate usage with saturating adds so malformed provider
        // counts cannot overflow the totals.
        self.context.total_usage.prompt_tokens =
            self.context.total_usage.prompt_tokens.saturating_add(response.usage.prompt_tokens);
        self.context.total_usage.completion_tokens = self
            .context
            .total_usage
            .completion_tokens
            .saturating_add(response.usage.completion_tokens);

        if response.tool_calls.is_empty() {
            // No tool calls: record the assistant message and finish.
            let assistant_msg = Message::text(Role::Assistant, response.content.clone());
            self.session.messages.push(assistant_msg);

            Ok(AgentStepResult::Finished(AgentRunner {
                state: Finished {
                    response: AgentResponse {
                        content: response.content,
                        tool_transcript: self.context.tool_transcript.clone(),
                        usage: self.context.total_usage.clone(),
                        finish_reason: FinishReason::Stop,
                    },
                },
                session: self.session,
                context: self.context,
            }))
        } else {
            // Capture provider-native call IDs in the same order as the
            // calls so `provide_tool_results` can pair them back up.
            let call_ids: Vec<Option<String>> =
                response.tool_calls.iter().map(|c| c.call_id.clone()).collect();

            let assistant_msg =
                Message::assistant_tool_calls(response.content, response.tool_calls.clone());
            self.session.messages.push(assistant_msg);

            Ok(AgentStepResult::RequiresTool(AgentRunner {
                state: AwaitingToolCall { pending_calls: response.tool_calls, call_ids },
                session: self.session,
                context: self.context,
            }))
        }
    }

    /// Run the agent loop until completion (high-level convenience method).
    ///
    /// This is equivalent to calling [`infer`](Self::infer) in a loop until
    /// a [`Finished`] state is reached. The session is persisted via
    /// `store` only on successful completion; mid-loop errors propagate
    /// without saving.
    ///
    /// # Errors
    ///
    /// Returns an error if any LLM call fails.
    pub async fn run_to_completion(
        self,
        llm: &(impl LlmPort + ?Sized),
        tools: &(impl ToolPort + ?Sized),
        compactor: &(impl ContextCompactorPort + ?Sized),
        store: &(impl SessionStore + ?Sized),
    ) -> Result<AgentRunner<Finished>, AgentError> {
        let mut current = self.infer(llm, tools, compactor).await?;

        loop {
            match current {
                AgentStepResult::Finished(runner) => {
                    store.save(&runner.context.session_id, &runner.session).await?;
                    return Ok(runner);
                }
                AgentStepResult::RequiresTool(runner) => {
                    let mut results = Vec::new();
                    for call in &runner.state.pending_calls {
                        match tools.call_tool(call.clone()).await {
                            Ok(result) => results.push(result),
                            // Tool failures are recorded as error results and
                            // fed back to the LLM instead of aborting the turn.
                            Err(err) => results.push(ToolResult {
                                name: call.name.clone(),
                                output: serde_json::json!({"error": err.to_string()}),
                                is_error: true,
                            }),
                        }
                    }
                    let ready = runner.provide_tool_results(results);
                    current = ready.infer(llm, tools, compactor).await?;
                }
            }
        }
    }
}
266
267// ── AwaitingToolCall State ───────────────────────────────────────────
268
269impl AgentRunner<AwaitingToolCall> {
270    /// Get the pending tool calls that need to be executed.
271    #[must_use]
272    pub fn pending_calls(&self) -> &[ToolCall] {
273        &self.state.pending_calls
274    }
275
276    /// Provide tool execution results and transition back to `Ready`.
277    ///
278    /// The results must correspond 1:1 with the pending calls in the
279    /// same order.
280    #[must_use]
281    pub fn provide_tool_results(mut self, results: Vec<ToolResult>) -> AgentRunner<Ready> {
282        for (result, call_id) in results.iter().zip(self.state.call_ids.iter()) {
283            let output_str = serde_json::to_string(&result.output).unwrap_or_default();
284            self.session.messages.push(Message::tool_result(
285                result.name.clone(),
286                call_id.clone(),
287                output_str,
288            ));
289            self.context.tool_calls_made += 1;
290            self.context.tool_transcript.push(result.clone());
291        }
292
293        AgentRunner { state: Ready, session: self.session, context: self.context }
294    }
295
296    /// Cancel the pending tool calls and transition to `Finished`.
297    #[must_use]
298    pub fn cancel(self, reason: impl Into<String>) -> AgentRunner<Finished> {
299        AgentRunner {
300            state: Finished {
301                response: AgentResponse {
302                    content: reason.into(),
303                    tool_transcript: self.context.tool_transcript.clone(),
304                    usage: self.context.total_usage.clone(),
305                    finish_reason: FinishReason::Cancelled,
306                },
307            },
308            session: self.session,
309            context: self.context,
310        }
311    }
312}
313
314// ── Finished State ───────────────────────────────────────────────────
315
316impl AgentRunner<Finished> {
317    /// Get a reference to the final response.
318    #[must_use]
319    pub fn response(&self) -> &AgentResponse {
320        &self.state.response
321    }
322
323    /// Consume the runner and return the final response.
324    #[must_use]
325    pub fn into_response(self) -> AgentResponse {
326        self.state.response
327    }
328
329    /// Get the final session state (for persistence).
330    #[must_use]
331    pub fn session(&self) -> &SessionState {
332        &self.session
333    }
334
335    /// Get the execution context with usage stats.
336    #[must_use]
337    pub fn context(&self) -> &RunnerContext {
338        &self.context
339    }
340}
341
342// ── Tests ────────────────────────────────────────────────────────────
343
#[cfg(test)]
mod tests {
    use bob_core::types::ToolDescriptor;

    use super::*;

    // LLM stub: always completes with a plain "done" response and no tool
    // calls, so `infer` transitions straight to `Finished`.
    struct StubLlm;

    impl StubLlm {
        // Build a terminal response with the given content and zero usage.
        fn finish_response(content: &str) -> bob_core::types::LlmResponse {
            bob_core::types::LlmResponse {
                content: content.to_string(),
                usage: TokenUsage::default(),
                finish_reason: FinishReason::Stop,
                tool_calls: Vec::new(),
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmPort for StubLlm {
        async fn complete(
            &self,
            _req: LlmRequest,
        ) -> Result<bob_core::types::LlmResponse, bob_core::error::LlmError> {
            Ok(Self::finish_response("done"))
        }

        // Streaming is out of scope for these tests.
        async fn complete_stream(
            &self,
            _req: LlmRequest,
        ) -> Result<bob_core::types::LlmStream, bob_core::error::LlmError> {
            Err(bob_core::error::LlmError::Provider("not implemented".into()))
        }
    }

    // Tool stub: exposes no tools and answers every call with a null,
    // non-error result.
    struct StubTools;

    #[async_trait::async_trait]
    impl ToolPort for StubTools {
        async fn list_tools(&self) -> Result<Vec<ToolDescriptor>, bob_core::error::ToolError> {
            Ok(vec![])
        }

        async fn call_tool(
            &self,
            call: ToolCall,
        ) -> Result<ToolResult, bob_core::error::ToolError> {
            Ok(ToolResult { name: call.name, output: serde_json::json!(null), is_error: false })
        }
    }

    // Compactor stub: identity pass-through of the session messages.
    struct StubCompactor;

    #[async_trait::async_trait]
    impl ContextCompactorPort for StubCompactor {
        async fn compact(&self, session: &SessionState) -> Vec<Message> {
            session.messages.clone()
        }
    }

    // Store stub: loads nothing and discards saves.
    struct StubStore;

    #[async_trait::async_trait]
    impl SessionStore for StubStore {
        async fn load(
            &self,
            _id: &bob_core::types::SessionId,
        ) -> Result<Option<SessionState>, bob_core::error::StoreError> {
            Ok(None)
        }

        async fn save(
            &self,
            _id: &bob_core::types::SessionId,
            _state: &SessionState,
        ) -> Result<(), bob_core::error::StoreError> {
            Ok(())
        }
    }

    // A single `infer` with a no-tool-call LLM must yield `Finished`
    // carrying the stub's content and a `Stop` finish reason.
    #[tokio::test]
    async fn ready_infer_to_finished() {
        let runner = AgentRunner::new(
            "test-session",
            "test-model",
            "You are a test assistant.",
            TurnPolicy::default(),
            SessionState::default(),
        );

        let result = runner.infer(&StubLlm, &StubTools, &StubCompactor).await;
        assert!(result.is_ok(), "infer should succeed");

        if let Ok(AgentStepResult::Finished(runner)) = result {
            assert_eq!(runner.response().content, "done");
            assert_eq!(runner.response().finish_reason, FinishReason::Stop);
        } else {
            panic!("expected Finished result");
        }
    }

    // The high-level loop terminates on the first step (no tool calls)
    // and returns the final response.
    #[tokio::test]
    async fn run_to_completion() {
        let runner = AgentRunner::new(
            "test-session",
            "test-model",
            "You are a test assistant.",
            TurnPolicy::default(),
            SessionState::default(),
        );

        let result =
            runner.run_to_completion(&StubLlm, &StubTools, &StubCompactor, &StubStore).await;
        assert!(result.is_ok(), "run_to_completion should succeed");

        let finished = result.unwrap();
        assert_eq!(finished.response().content, "done");
    }

    // Feeding one result for one pending call appends one tool message
    // and bumps the tool-call counter.
    #[test]
    fn awaiting_tool_call_provide_results() {
        let runner = AgentRunner {
            state: AwaitingToolCall {
                pending_calls: vec![ToolCall::new("test", serde_json::json!({}))],
                call_ids: vec![Some("call-1".into())],
            },
            session: SessionState::default(),
            context: RunnerContext {
                session_id: "test".into(),
                model: "test".into(),
                system_instructions: String::new(),
                policy: TurnPolicy::default(),
                steps_taken: 1,
                tool_calls_made: 0,
                total_usage: TokenUsage::default(),
                tool_transcript: Vec::new(),
            },
        };

        let results = vec![ToolResult {
            name: "test".into(),
            output: serde_json::json!({"ok": true}),
            is_error: false,
        }];

        let ready = runner.provide_tool_results(results);
        assert_eq!(ready.context.tool_calls_made, 1);
        assert_eq!(ready.session.messages.len(), 1);
    }
}