1use bob_core::{
37 error::AgentError,
38 ports::{ContextCompactorPort, LlmPort, SessionStore, ToolPort},
39 types::{
40 AgentResponse, FinishReason, LlmRequest, Message, Role, SessionState, TokenUsage, ToolCall,
41 ToolResult, TurnPolicy,
42 },
43};
44
/// Typestate marker: the runner is ready to perform an inference step
/// (see `AgentRunner::<Ready>::infer`).
#[derive(Debug)]
pub struct Ready;
50
/// Typestate marker: the last LLM response requested tool calls whose results
/// have not been provided yet.
#[derive(Debug)]
pub struct AwaitingToolCall {
    /// Tool calls requested by the last LLM response.
    pub pending_calls: Vec<ToolCall>,
    /// Call ID of each pending call, used to tag the tool-result messages.
    /// As populated by `infer`, this is in the same order as `pending_calls`
    /// (`None` when a call carried no ID).
    pub call_ids: Vec<Option<String>>,
}
59
/// Typestate marker: the turn is over and a final response is available.
#[derive(Debug)]
pub struct Finished {
    /// The final response produced for this turn.
    pub response: AgentResponse,
}
66
/// Agent turn runner whose type parameter `S` records the current state.
///
/// `S` is one of [`Ready`], [`AwaitingToolCall`] or [`Finished`]. Each state has
/// its own `impl` block exposing only the operations valid in that state, and
/// state transitions consume the runner and return a runner in the next state.
#[derive(Debug)]
pub struct AgentRunner<S> {
    /// Current state marker plus any state-specific data.
    state: S,
    /// Conversation history; passed to the compactor to build each LLM request
    /// and extended with assistant and tool-result messages as the turn runs.
    session: SessionState,
    /// Configuration and running counters carried across state transitions.
    context: RunnerContext,
}
82
/// Configuration and accumulated bookkeeping for one [`AgentRunner`] turn.
#[derive(Debug, Clone)]
pub struct RunnerContext {
    /// Identifier under which the session is saved when the run completes.
    pub session_id: String,
    /// Model name copied into every LLM request.
    pub model: String,
    /// System instructions supplied when the runner was created.
    pub system_instructions: String,
    /// Turn policy; its `max_steps` bounds how many inference steps may run.
    pub policy: TurnPolicy,
    /// Number of completed LLM inference calls.
    pub steps_taken: u32,
    /// Number of tool results recorded into the session.
    pub tool_calls_made: u32,
    /// Prompt and completion tokens accumulated (saturating) across all steps.
    pub total_usage: TokenUsage,
    /// Every tool result recorded so far, in order.
    pub tool_transcript: Vec<ToolResult>,
}
95
/// Outcome of a single `AgentRunner::<Ready>::infer` step.
pub enum AgentStepResult {
    /// The turn ended: either the LLM answered without tool calls or the step
    /// budget was exhausted.
    Finished(AgentRunner<Finished>),
    /// The LLM requested tool calls; provide their results to continue.
    RequiresTool(AgentRunner<AwaitingToolCall>),
}
103
104impl std::fmt::Debug for AgentStepResult {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 match self {
107 Self::Finished(_) => f.write_str("AgentStepResult::Finished"),
108 Self::RequiresTool(_) => f.write_str("AgentStepResult::RequiresTool"),
109 }
110 }
111}
112
113impl AgentRunner<Ready> {
116 #[must_use]
118 pub fn new(
119 session_id: impl Into<String>,
120 model: impl Into<String>,
121 system_instructions: impl Into<String>,
122 policy: TurnPolicy,
123 session: SessionState,
124 ) -> Self {
125 Self {
126 state: Ready,
127 session,
128 context: RunnerContext {
129 session_id: session_id.into(),
130 model: model.into(),
131 system_instructions: system_instructions.into(),
132 policy,
133 steps_taken: 0,
134 tool_calls_made: 0,
135 total_usage: TokenUsage::default(),
136 tool_transcript: Vec::new(),
137 },
138 }
139 }
140
141 pub async fn infer(
151 mut self,
152 llm: &(impl LlmPort + ?Sized),
153 tools: &(impl ToolPort + ?Sized),
154 compactor: &(impl ContextCompactorPort + ?Sized),
155 ) -> Result<AgentStepResult, AgentError> {
156 if self.context.steps_taken >= self.context.policy.max_steps {
157 return Ok(AgentStepResult::Finished(AgentRunner {
158 state: Finished {
159 response: AgentResponse {
160 content: "Max steps exceeded.".to_string(),
161 tool_transcript: self.context.tool_transcript.clone(),
162 usage: self.context.total_usage.clone(),
163 finish_reason: FinishReason::GuardExceeded,
164 },
165 },
166 session: self.session,
167 context: self.context,
168 }));
169 }
170
171 let tool_descriptors = tools.list_tools().await.unwrap_or_default();
172 let messages = compactor.compact(&self.session).await;
173
174 let request = LlmRequest {
175 model: self.context.model.clone(),
176 messages,
177 tools: tool_descriptors,
178 output_schema: None,
179 };
180
181 let response = llm.complete(request).await?;
182
183 self.context.steps_taken += 1;
184 self.context.total_usage.prompt_tokens =
185 self.context.total_usage.prompt_tokens.saturating_add(response.usage.prompt_tokens);
186 self.context.total_usage.completion_tokens = self
187 .context
188 .total_usage
189 .completion_tokens
190 .saturating_add(response.usage.completion_tokens);
191
192 if response.tool_calls.is_empty() {
193 let assistant_msg = Message::text(Role::Assistant, response.content.clone());
194 self.session.messages.push(assistant_msg);
195
196 Ok(AgentStepResult::Finished(AgentRunner {
197 state: Finished {
198 response: AgentResponse {
199 content: response.content,
200 tool_transcript: self.context.tool_transcript.clone(),
201 usage: self.context.total_usage.clone(),
202 finish_reason: FinishReason::Stop,
203 },
204 },
205 session: self.session,
206 context: self.context,
207 }))
208 } else {
209 let call_ids: Vec<Option<String>> =
210 response.tool_calls.iter().map(|c| c.call_id.clone()).collect();
211
212 let assistant_msg =
213 Message::assistant_tool_calls(response.content, response.tool_calls.clone());
214 self.session.messages.push(assistant_msg);
215
216 Ok(AgentStepResult::RequiresTool(AgentRunner {
217 state: AwaitingToolCall { pending_calls: response.tool_calls, call_ids },
218 session: self.session,
219 context: self.context,
220 }))
221 }
222 }
223
224 pub async fn run_to_completion(
233 self,
234 llm: &(impl LlmPort + ?Sized),
235 tools: &(impl ToolPort + ?Sized),
236 compactor: &(impl ContextCompactorPort + ?Sized),
237 store: &(impl SessionStore + ?Sized),
238 ) -> Result<AgentRunner<Finished>, AgentError> {
239 let mut current = self.infer(llm, tools, compactor).await?;
240
241 loop {
242 match current {
243 AgentStepResult::Finished(runner) => {
244 store.save(&runner.context.session_id, &runner.session).await?;
245 return Ok(runner);
246 }
247 AgentStepResult::RequiresTool(runner) => {
248 let mut results = Vec::new();
249 for call in &runner.state.pending_calls {
250 match tools.call_tool(call.clone()).await {
251 Ok(result) => results.push(result),
252 Err(err) => results.push(ToolResult {
253 name: call.name.clone(),
254 output: serde_json::json!({"error": err.to_string()}),
255 is_error: true,
256 }),
257 }
258 }
259 let ready = runner.provide_tool_results(results);
260 current = ready.infer(llm, tools, compactor).await?;
261 }
262 }
263 }
264 }
265}
266
267impl AgentRunner<AwaitingToolCall> {
270 #[must_use]
272 pub fn pending_calls(&self) -> &[ToolCall] {
273 &self.state.pending_calls
274 }
275
276 #[must_use]
281 pub fn provide_tool_results(mut self, results: Vec<ToolResult>) -> AgentRunner<Ready> {
282 for (result, call_id) in results.iter().zip(self.state.call_ids.iter()) {
283 let output_str = serde_json::to_string(&result.output).unwrap_or_default();
284 self.session.messages.push(Message::tool_result(
285 result.name.clone(),
286 call_id.clone(),
287 output_str,
288 ));
289 self.context.tool_calls_made += 1;
290 self.context.tool_transcript.push(result.clone());
291 }
292
293 AgentRunner { state: Ready, session: self.session, context: self.context }
294 }
295
296 #[must_use]
298 pub fn cancel(self, reason: impl Into<String>) -> AgentRunner<Finished> {
299 AgentRunner {
300 state: Finished {
301 response: AgentResponse {
302 content: reason.into(),
303 tool_transcript: self.context.tool_transcript.clone(),
304 usage: self.context.total_usage.clone(),
305 finish_reason: FinishReason::Cancelled,
306 },
307 },
308 session: self.session,
309 context: self.context,
310 }
311 }
312}
313
314impl AgentRunner<Finished> {
317 #[must_use]
319 pub fn response(&self) -> &AgentResponse {
320 &self.state.response
321 }
322
323 #[must_use]
325 pub fn into_response(self) -> AgentResponse {
326 self.state.response
327 }
328
329 #[must_use]
331 pub fn session(&self) -> &SessionState {
332 &self.session
333 }
334
335 #[must_use]
337 pub fn context(&self) -> &RunnerContext {
338 &self.context
339 }
340}
341
#[cfg(test)]
mod tests {
    use bob_core::types::ToolDescriptor;

    use super::*;

    /// LLM stub that always answers "done" with no tool calls.
    struct StubLlm;

    impl StubLlm {
        fn finish_response(content: &str) -> bob_core::types::LlmResponse {
            bob_core::types::LlmResponse {
                content: content.to_owned(),
                usage: TokenUsage::default(),
                finish_reason: FinishReason::Stop,
                tool_calls: Vec::new(),
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmPort for StubLlm {
        async fn complete(
            &self,
            _req: LlmRequest,
        ) -> Result<bob_core::types::LlmResponse, bob_core::error::LlmError> {
            Ok(Self::finish_response("done"))
        }

        async fn complete_stream(
            &self,
            _req: LlmRequest,
        ) -> Result<bob_core::types::LlmStream, bob_core::error::LlmError> {
            Err(bob_core::error::LlmError::Provider("not implemented".into()))
        }
    }

    /// Tool stub exposing no tools and echoing a null output for any call.
    struct StubTools;

    #[async_trait::async_trait]
    impl ToolPort for StubTools {
        async fn list_tools(&self) -> Result<Vec<ToolDescriptor>, bob_core::error::ToolError> {
            Ok(Vec::new())
        }

        async fn call_tool(
            &self,
            call: ToolCall,
        ) -> Result<ToolResult, bob_core::error::ToolError> {
            let output = serde_json::json!(null);
            Ok(ToolResult { name: call.name, output, is_error: false })
        }
    }

    /// Compactor stub that passes the session history through unchanged.
    struct StubCompactor;

    #[async_trait::async_trait]
    impl ContextCompactorPort for StubCompactor {
        async fn compact(&self, session: &SessionState) -> Vec<Message> {
            session.messages.clone()
        }
    }

    /// Session store stub that never holds anything and accepts every save.
    struct StubStore;

    #[async_trait::async_trait]
    impl SessionStore for StubStore {
        async fn load(
            &self,
            _id: &bob_core::types::SessionId,
        ) -> Result<Option<SessionState>, bob_core::error::StoreError> {
            Ok(None)
        }

        async fn save(
            &self,
            _id: &bob_core::types::SessionId,
            _state: &SessionState,
        ) -> Result<(), bob_core::error::StoreError> {
            Ok(())
        }
    }

    /// Builds a fresh `Ready` runner with an empty session.
    fn fresh_runner() -> AgentRunner<Ready> {
        AgentRunner::new(
            "test-session",
            "test-model",
            "You are a test assistant.",
            TurnPolicy::default(),
            SessionState::default(),
        )
    }

    #[tokio::test]
    async fn ready_infer_to_finished() {
        let outcome = fresh_runner().infer(&StubLlm, &StubTools, &StubCompactor).await;
        assert!(outcome.is_ok(), "infer should succeed");

        match outcome {
            Ok(AgentStepResult::Finished(done)) => {
                let reply = done.response();
                assert_eq!(reply.content, "done");
                assert_eq!(reply.finish_reason, FinishReason::Stop);
            }
            _ => panic!("expected Finished result"),
        }
    }

    #[tokio::test]
    async fn run_to_completion() {
        let outcome = fresh_runner()
            .run_to_completion(&StubLlm, &StubTools, &StubCompactor, &StubStore)
            .await;
        assert!(outcome.is_ok(), "run_to_completion should succeed");

        let done = outcome.unwrap();
        assert_eq!(done.response().content, "done");
    }

    #[test]
    fn awaiting_tool_call_provide_results() {
        let context = RunnerContext {
            session_id: "test".into(),
            model: "test".into(),
            system_instructions: String::new(),
            policy: TurnPolicy::default(),
            steps_taken: 1,
            tool_calls_made: 0,
            total_usage: TokenUsage::default(),
            tool_transcript: Vec::new(),
        };
        let awaiting = AgentRunner {
            state: AwaitingToolCall {
                pending_calls: vec![ToolCall::new("test", serde_json::json!({}))],
                call_ids: vec![Some("call-1".into())],
            },
            session: SessionState::default(),
            context,
        };

        let answers = vec![ToolResult {
            name: "test".into(),
            output: serde_json::json!({"ok": true}),
            is_error: false,
        }];

        let ready = awaiting.provide_tool_results(answers);
        assert_eq!(ready.context.tool_calls_made, 1);
        assert_eq!(ready.session.messages.len(), 1);
    }
}