1use bob_core::{
53 error::AgentError,
54 ports::{ContextCompactorPort, LlmPort, SessionStore, ToolPort},
55 types::{
56 AgentResponse, FinishReason, LlmRequest, Message, Role, SessionState, TokenUsage, ToolCall,
57 ToolResult, TurnPolicy,
58 },
59};
60
/// Typestate marker for a runner that is idle and can start an inference step.
///
/// [`AgentRunner<Ready>`] is the only state that exposes `infer` and
/// `run_to_completion`.
#[derive(Debug)]
pub struct Ready;
69
70#[derive(Debug)]
78pub struct AwaitingToolCall {
79 pub pending_calls: Vec<ToolCall>,
81 pub call_ids: Vec<Option<String>>,
83}
84
85#[derive(Debug)]
95pub struct AwaitingApproval {
96 pub pending_calls: Vec<ToolCall>,
98 pub call_ids: Vec<Option<String>>,
100 pub reason: String,
102}
103
104#[derive(Debug)]
109pub struct Finished {
110 pub response: AgentResponse,
112}
113
114#[derive(Debug)]
124pub struct AgentRunner<S> {
125 state: S,
126 session: SessionState,
127 context: RunnerContext,
128}
129
130#[derive(Debug, Clone)]
132pub struct RunnerContext {
133 pub session_id: String,
134 pub model: String,
135 pub system_instructions: String,
136 pub policy: TurnPolicy,
137 pub steps_taken: u32,
138 pub tool_calls_made: u32,
139 pub total_usage: TokenUsage,
140 pub tool_transcript: Vec<ToolResult>,
141}
142
143pub enum AgentStepResult {
145 Finished(AgentRunner<Finished>),
147 RequiresTool(AgentRunner<AwaitingToolCall>),
149}
150
151impl std::fmt::Debug for AgentStepResult {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 match self {
154 Self::Finished(_) => f.write_str("AgentStepResult::Finished"),
155 Self::RequiresTool(_) => f.write_str("AgentStepResult::RequiresTool"),
156 }
157 }
158}
159
160impl AgentRunner<Ready> {
163 #[must_use]
165 pub fn new(
166 session_id: impl Into<String>,
167 model: impl Into<String>,
168 system_instructions: impl Into<String>,
169 policy: TurnPolicy,
170 session: SessionState,
171 ) -> Self {
172 Self {
173 state: Ready,
174 session,
175 context: RunnerContext {
176 session_id: session_id.into(),
177 model: model.into(),
178 system_instructions: system_instructions.into(),
179 policy,
180 steps_taken: 0,
181 tool_calls_made: 0,
182 total_usage: TokenUsage::default(),
183 tool_transcript: Vec::new(),
184 },
185 }
186 }
187
188 pub async fn infer(
198 mut self,
199 llm: &(impl LlmPort + ?Sized),
200 tools: &(impl ToolPort + ?Sized),
201 compactor: &(impl ContextCompactorPort + ?Sized),
202 ) -> Result<AgentStepResult, AgentError> {
203 if self.context.steps_taken >= self.context.policy.max_steps {
204 return Ok(AgentStepResult::Finished(AgentRunner {
205 state: Finished {
206 response: AgentResponse {
207 content: "Max steps exceeded.".to_string(),
208 tool_transcript: self.context.tool_transcript.clone(),
209 usage: self.context.total_usage.clone(),
210 finish_reason: FinishReason::GuardExceeded,
211 },
212 },
213 session: self.session,
214 context: self.context,
215 }));
216 }
217
218 let tool_descriptors = tools.list_tools().await.unwrap_or_default();
219 let messages = compactor.compact(&self.session).await;
220
221 let request = LlmRequest {
222 model: self.context.model.clone(),
223 messages,
224 tools: tool_descriptors,
225 output_schema: None,
226 };
227
228 let response = llm.complete(request).await?;
229
230 self.context.steps_taken += 1;
231 self.context.total_usage.prompt_tokens =
232 self.context.total_usage.prompt_tokens.saturating_add(response.usage.prompt_tokens);
233 self.context.total_usage.completion_tokens = self
234 .context
235 .total_usage
236 .completion_tokens
237 .saturating_add(response.usage.completion_tokens);
238
239 if response.tool_calls.is_empty() {
240 let assistant_msg = Message::text(Role::Assistant, response.content.clone());
241 self.session.messages.push(assistant_msg);
242
243 Ok(AgentStepResult::Finished(AgentRunner {
244 state: Finished {
245 response: AgentResponse {
246 content: response.content,
247 tool_transcript: self.context.tool_transcript.clone(),
248 usage: self.context.total_usage.clone(),
249 finish_reason: FinishReason::Stop,
250 },
251 },
252 session: self.session,
253 context: self.context,
254 }))
255 } else {
256 let call_ids: Vec<Option<String>> =
257 response.tool_calls.iter().map(|c| c.call_id.clone()).collect();
258
259 let assistant_msg =
260 Message::assistant_tool_calls(response.content, response.tool_calls.clone());
261 self.session.messages.push(assistant_msg);
262
263 Ok(AgentStepResult::RequiresTool(AgentRunner {
264 state: AwaitingToolCall { pending_calls: response.tool_calls, call_ids },
265 session: self.session,
266 context: self.context,
267 }))
268 }
269 }
270
271 pub async fn run_to_completion(
280 self,
281 llm: &(impl LlmPort + ?Sized),
282 tools: &(impl ToolPort + ?Sized),
283 compactor: &(impl ContextCompactorPort + ?Sized),
284 store: &(impl SessionStore + ?Sized),
285 ) -> Result<AgentRunner<Finished>, AgentError> {
286 let mut current = self.infer(llm, tools, compactor).await?;
287
288 loop {
289 match current {
290 AgentStepResult::Finished(runner) => {
291 store.save(&runner.context.session_id, &runner.session).await?;
292 return Ok(runner);
293 }
294 AgentStepResult::RequiresTool(runner) => {
295 let mut results = Vec::new();
296 for call in &runner.state.pending_calls {
297 match tools.call_tool(call.clone()).await {
298 Ok(result) => results.push(result),
299 Err(err) => results.push(ToolResult {
300 name: call.name.clone(),
301 output: serde_json::json!({"error": err.to_string()}),
302 is_error: true,
303 }),
304 }
305 }
306 let ready = runner.provide_tool_results(results);
307 current = ready.infer(llm, tools, compactor).await?;
308 }
309 }
310 }
311 }
312}
313
314impl AgentRunner<AwaitingToolCall> {
317 #[must_use]
319 pub fn pending_calls(&self) -> &[ToolCall] {
320 &self.state.pending_calls
321 }
322
323 #[must_use]
328 pub fn provide_tool_results(mut self, results: Vec<ToolResult>) -> AgentRunner<Ready> {
329 for (result, call_id) in results.iter().zip(self.state.call_ids.iter()) {
330 let output_str = serde_json::to_string(&result.output).unwrap_or_default();
331 self.session.messages.push(Message::tool_result(
332 result.name.clone(),
333 call_id.clone(),
334 output_str,
335 ));
336 self.context.tool_calls_made += 1;
337 self.context.tool_transcript.push(result.clone());
338 }
339
340 AgentRunner { state: Ready, session: self.session, context: self.context }
341 }
342
343 #[must_use]
345 pub fn cancel(self, reason: impl Into<String>) -> AgentRunner<Finished> {
346 AgentRunner {
347 state: Finished {
348 response: AgentResponse {
349 content: reason.into(),
350 tool_transcript: self.context.tool_transcript.clone(),
351 usage: self.context.total_usage.clone(),
352 finish_reason: FinishReason::Cancelled,
353 },
354 },
355 session: self.session,
356 context: self.context,
357 }
358 }
359
360 #[must_use]
371 pub fn require_approval(self, reason: impl Into<String>) -> AgentRunner<AwaitingApproval> {
372 AgentRunner {
373 state: AwaitingApproval {
374 pending_calls: self.state.pending_calls,
375 call_ids: self.state.call_ids,
376 reason: reason.into(),
377 },
378 session: self.session,
379 context: self.context,
380 }
381 }
382}
383
384impl AgentRunner<AwaitingApproval> {
387 #[must_use]
389 pub fn pending_calls(&self) -> &[ToolCall] {
390 &self.state.pending_calls
391 }
392
393 #[must_use]
395 pub fn approval_reason(&self) -> &str {
396 &self.state.reason
397 }
398
399 #[must_use]
404 pub fn approve(self) -> AgentRunner<AwaitingToolCall> {
405 AgentRunner {
406 state: AwaitingToolCall {
407 pending_calls: self.state.pending_calls,
408 call_ids: self.state.call_ids,
409 },
410 session: self.session,
411 context: self.context,
412 }
413 }
414
415 #[must_use]
420 pub fn deny(self, reason: impl Into<String>) -> AgentRunner<Finished> {
421 AgentRunner {
422 state: Finished {
423 response: AgentResponse {
424 content: reason.into(),
425 tool_transcript: self.context.tool_transcript.clone(),
426 usage: self.context.total_usage.clone(),
427 finish_reason: FinishReason::Cancelled,
428 },
429 },
430 session: self.session,
431 context: self.context,
432 }
433 }
434}
435
436impl AgentRunner<Finished> {
439 #[must_use]
441 pub fn response(&self) -> &AgentResponse {
442 &self.state.response
443 }
444
445 #[must_use]
447 pub fn into_response(self) -> AgentResponse {
448 self.state.response
449 }
450
451 #[must_use]
453 pub fn session(&self) -> &SessionState {
454 &self.session
455 }
456
457 #[must_use]
459 pub fn context(&self) -> &RunnerContext {
460 &self.context
461 }
462}
463
#[cfg(test)]
mod tests {
    use bob_core::{
        error::{LlmError, StoreError, ToolError},
        types::{LlmResponse, LlmStream, SessionId, ToolDescriptor},
    };

    use super::*;

    /// LLM stub that always answers "done" without requesting tools.
    struct StubLlm;

    impl StubLlm {
        /// Builds a plain `Stop` response carrying `content` and default usage.
        fn finish_response(content: &str) -> LlmResponse {
            LlmResponse {
                content: content.to_string(),
                usage: TokenUsage::default(),
                finish_reason: FinishReason::Stop,
                tool_calls: Vec::new(),
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmPort for StubLlm {
        async fn complete(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
            Ok(Self::finish_response("done"))
        }

        async fn complete_stream(&self, _req: LlmRequest) -> Result<LlmStream, LlmError> {
            Err(LlmError::Provider("not implemented".into()))
        }
    }

    /// Tool port stub: exposes no tools and answers every call with a `null` output.
    struct StubTools;

    #[async_trait::async_trait]
    impl ToolPort for StubTools {
        async fn list_tools(&self) -> Result<Vec<ToolDescriptor>, ToolError> {
            Ok(vec![])
        }

        async fn call_tool(&self, call: ToolCall) -> Result<ToolResult, ToolError> {
            Ok(ToolResult { name: call.name, output: serde_json::json!(null), is_error: false })
        }
    }

    /// Compactor stub that forwards the session history unchanged.
    struct StubCompactor;

    #[async_trait::async_trait]
    impl ContextCompactorPort for StubCompactor {
        async fn compact(&self, session: &SessionState) -> Vec<Message> {
            session.messages.clone()
        }
    }

    /// Store stub: loads nothing and accepts every save.
    struct StubStore;

    #[async_trait::async_trait]
    impl SessionStore for StubStore {
        async fn load(&self, _id: &SessionId) -> Result<Option<SessionState>, StoreError> {
            Ok(None)
        }

        async fn save(&self, _id: &SessionId, _state: &SessionState) -> Result<(), StoreError> {
            Ok(())
        }
    }

    /// Fresh runner with default policy and an empty session, shared by the async tests.
    fn ready_runner() -> AgentRunner<Ready> {
        AgentRunner::new(
            "test-session",
            "test-model",
            "You are a test assistant.",
            TurnPolicy::default(),
            SessionState::default(),
        )
    }

    #[tokio::test]
    async fn ready_infer_to_finished() {
        let outcome = ready_runner().infer(&StubLlm, &StubTools, &StubCompactor).await;
        assert!(outcome.is_ok(), "infer should succeed");

        match outcome {
            Ok(AgentStepResult::Finished(done)) => {
                assert_eq!(done.response().content, "done");
                assert_eq!(done.response().finish_reason, FinishReason::Stop);
            }
            _ => panic!("expected Finished result"),
        }
    }

    #[tokio::test]
    async fn run_to_completion() {
        let outcome = ready_runner()
            .run_to_completion(&StubLlm, &StubTools, &StubCompactor, &StubStore)
            .await;
        assert!(outcome.is_ok(), "run_to_completion should succeed");

        let finished = outcome.unwrap();
        assert_eq!(finished.response().content, "done");
    }

    #[test]
    fn awaiting_tool_call_provide_results() {
        let context = RunnerContext {
            session_id: "test".into(),
            model: "test".into(),
            system_instructions: String::new(),
            policy: TurnPolicy::default(),
            steps_taken: 1,
            tool_calls_made: 0,
            total_usage: TokenUsage::default(),
            tool_transcript: Vec::new(),
        };
        let state = AwaitingToolCall {
            pending_calls: vec![ToolCall::new("test", serde_json::json!({}))],
            call_ids: vec![Some("call-1".into())],
        };
        let runner = AgentRunner { state, session: SessionState::default(), context };

        let ready = runner.provide_tool_results(vec![ToolResult {
            name: "test".into(),
            output: serde_json::json!({"ok": true}),
            is_error: false,
        }]);

        // One result in => one counted call and one tool-result message.
        assert_eq!(ready.context.tool_calls_made, 1);
        assert_eq!(ready.session.messages.len(), 1);
    }
}