1use std::sync::Arc;
43
44use bob_core::{
45 error::AgentError,
46 normalize_tool_list,
47 ports::{
48 ApprovalPort, ArtifactStorePort, CostMeterPort, EventSink, LlmPort, SessionStore,
49 ToolPolicyPort, ToolPort, TurnCheckpointStorePort,
50 },
51 types::{
52 AgentAction, AgentEvent, AgentEventStream, AgentRequest, AgentResponse, AgentRunResult,
53 AgentStreamEvent, ApprovalContext, ApprovalDecision, ArtifactRecord, FinishReason,
54 GuardReason, Message, Role, TokenUsage, ToolCall, ToolResult, TurnCheckpoint, TurnPolicy,
55 },
56};
57use futures_util::StreamExt;
58use tokio::time::Instant;
59use tokio_stream::wrappers::UnboundedReceiverStream;
60
61#[derive(Debug)]
64pub struct LoopGuard {
65 policy: TurnPolicy,
66 steps: u32,
67 tool_calls: u32,
68 consecutive_errors: u32,
69 start: Instant,
70}
71
72impl LoopGuard {
73 #[must_use]
75 pub fn new(policy: TurnPolicy) -> Self {
76 Self { policy, steps: 0, tool_calls: 0, consecutive_errors: 0, start: Instant::now() }
77 }
78
79 #[must_use]
81 pub fn can_continue(&self) -> bool {
82 self.steps < self.policy.max_steps
83 && self.tool_calls < self.policy.max_tool_calls
84 && self.consecutive_errors < self.policy.max_consecutive_errors
85 && !self.timed_out()
86 }
87
88 pub fn record_step(&mut self) {
90 self.steps += 1;
91 }
92
93 pub fn record_tool_call(&mut self) {
95 self.tool_calls += 1;
96 }
97
98 pub fn record_error(&mut self) {
100 self.consecutive_errors += 1;
101 }
102
103 pub fn reset_errors(&mut self) {
105 self.consecutive_errors = 0;
106 }
107
108 #[must_use]
112 pub fn reason(&self) -> GuardReason {
113 if self.steps >= self.policy.max_steps {
114 GuardReason::MaxSteps
115 } else if self.tool_calls >= self.policy.max_tool_calls {
116 GuardReason::MaxToolCalls
117 } else if self.consecutive_errors >= self.policy.max_consecutive_errors {
118 GuardReason::MaxConsecutiveErrors
119 } else if self.timed_out() {
120 GuardReason::TurnTimeout
121 } else {
122 GuardReason::Cancelled
124 }
125 }
126
127 #[must_use]
129 pub fn timed_out(&self) -> bool {
130 self.start.elapsed().as_millis() >= u128::from(self.policy.turn_timeout_ms)
131 }
132}
133
/// Base system prompt used for every turn; request-specific skill prompts
/// are appended after it by `resolve_system_instructions`.
const DEFAULT_SYSTEM_INSTRUCTIONS: &str = concat!(
    "You are a helpful AI assistant. ",
    "Think step by step before answering. ",
    "When you need external information, use the available tools."
);
140
141fn resolve_system_instructions(req: &AgentRequest) -> String {
142 let Some(skills_prompt) = req.context.system_prompt.as_deref() else {
143 return DEFAULT_SYSTEM_INSTRUCTIONS.to_string();
144 };
145
146 if skills_prompt.trim().is_empty() {
147 DEFAULT_SYSTEM_INSTRUCTIONS.to_string()
148 } else {
149 format!("{DEFAULT_SYSTEM_INSTRUCTIONS}\n\n{skills_prompt}")
150 }
151}
152
153fn resolve_selected_skills(req: &AgentRequest) -> Vec<String> {
154 req.context.selected_skills.clone()
155}
156
157#[derive(Debug, Clone, Default)]
158struct ToolCallPolicy {
159 deny_tools: Vec<String>,
160 allow_tools: Option<Vec<String>>,
161}
162
163fn resolve_tool_call_policy(req: &AgentRequest) -> ToolCallPolicy {
164 let deny_tools =
165 normalize_tool_list(req.context.tool_policy.deny_tools.iter().map(String::as_str));
166 let allow_tools = req
167 .context
168 .tool_policy
169 .allow_tools
170 .as_ref()
171 .map(|tools| normalize_tool_list(tools.iter().map(String::as_str)));
172 ToolCallPolicy { deny_tools, allow_tools }
173}
174
175fn prompt_options_for_mode(
176 dispatch_mode: crate::DispatchMode,
177 llm: &dyn LlmPort,
178) -> crate::prompt::PromptBuildOptions {
179 match dispatch_mode {
180 crate::DispatchMode::PromptGuided => crate::prompt::PromptBuildOptions::default(),
181 crate::DispatchMode::NativePreferred => {
182 if llm.capabilities().native_tool_calling {
183 crate::prompt::PromptBuildOptions {
184 include_action_schema: false,
185 include_tool_schema: false,
186 }
187 } else {
188 crate::prompt::PromptBuildOptions::default()
189 }
190 }
191 }
192}
193
194fn parse_action_for_mode(
195 dispatch_mode: crate::DispatchMode,
196 llm: &dyn LlmPort,
197 response: &bob_core::types::LlmResponse,
198) -> Result<AgentAction, crate::action::ActionParseError> {
199 match dispatch_mode {
200 crate::DispatchMode::PromptGuided => crate::action::parse_action(&response.content),
201 crate::DispatchMode::NativePreferred => {
202 if llm.capabilities().native_tool_calling {
203 if let Some(tool_call) = response.tool_calls.first() {
204 return Ok(AgentAction::ToolCall {
205 name: tool_call.name.clone(),
206 arguments: tool_call.arguments.clone(),
207 });
208 }
209 }
210 crate::action::parse_action(&response.content)
211 }
212 }
213}
214
215async fn execute_tool_call(
216 tools: &dyn ToolPort,
217 guard: &mut LoopGuard,
218 tool_call: ToolCall,
219 policy: &ToolCallPolicy,
220 tool_policy_port: &dyn ToolPolicyPort,
221 approval_port: &dyn ApprovalPort,
222 approval_context: &ApprovalContext,
223 timeout_ms: u64,
224) -> ToolResult {
225 if !tool_policy_port.is_tool_allowed(
226 &tool_call.name,
227 &policy.deny_tools,
228 policy.allow_tools.as_deref(),
229 ) {
230 guard.record_error();
231 return ToolResult {
232 name: tool_call.name.clone(),
233 output: serde_json::json!({
234 "error": format!("tool '{}' denied by policy", tool_call.name)
235 }),
236 is_error: true,
237 };
238 }
239
240 match approval_port.approve_tool_call(&tool_call, approval_context).await {
241 Ok(ApprovalDecision::Approved) => {}
242 Ok(ApprovalDecision::Denied { reason }) => {
243 guard.record_error();
244 return ToolResult {
245 name: tool_call.name.clone(),
246 output: serde_json::json!({"error": reason}),
247 is_error: true,
248 };
249 }
250 Err(err) => {
251 guard.record_error();
252 return ToolResult {
253 name: tool_call.name.clone(),
254 output: serde_json::json!({"error": err.to_string()}),
255 is_error: true,
256 };
257 }
258 }
259
260 match tokio::time::timeout(
261 std::time::Duration::from_millis(timeout_ms),
262 tools.call_tool(tool_call.clone()),
263 )
264 .await
265 {
266 Ok(Ok(result)) => {
267 guard.reset_errors();
268 result
269 }
270 Ok(Err(err)) => {
271 guard.record_error();
272 ToolResult {
273 name: tool_call.name,
274 output: serde_json::json!({"error": err.to_string()}),
275 is_error: true,
276 }
277 }
278 Err(_) => {
279 guard.record_error();
280 ToolResult {
281 name: tool_call.name,
282 output: serde_json::json!({"error": "tool call timed out"}),
283 is_error: true,
284 }
285 }
286 }
287}
288
289pub async fn run_turn(
296 llm: &dyn LlmPort,
297 tools: &dyn ToolPort,
298 store: &dyn SessionStore,
299 events: &dyn EventSink,
300 req: AgentRequest,
301 policy: &TurnPolicy,
302 default_model: &str,
303) -> Result<AgentRunResult, AgentError> {
304 let tool_policy = crate::DefaultToolPolicyPort;
305 let approval = crate::AllowAllApprovalPort;
306 let checkpoint_store = crate::NoOpCheckpointStorePort;
307 let artifact_store = crate::NoOpArtifactStorePort;
308 let cost_meter = crate::NoOpCostMeterPort;
309 run_turn_with_extensions(
310 llm,
311 tools,
312 store,
313 events,
314 req,
315 policy,
316 default_model,
317 &tool_policy,
318 &approval,
319 crate::DispatchMode::NativePreferred,
320 &checkpoint_store,
321 &artifact_store,
322 &cost_meter,
323 )
324 .await
325}
326
327#[allow(dead_code)]
329pub(crate) async fn run_turn_with_controls(
330 llm: &dyn LlmPort,
331 tools: &dyn ToolPort,
332 store: &dyn SessionStore,
333 events: &dyn EventSink,
334 req: AgentRequest,
335 policy: &TurnPolicy,
336 default_model: &str,
337 tool_policy_port: &dyn ToolPolicyPort,
338 approval_port: &dyn ApprovalPort,
339) -> Result<AgentRunResult, AgentError> {
340 let checkpoint_store = crate::NoOpCheckpointStorePort;
341 let artifact_store = crate::NoOpArtifactStorePort;
342 let cost_meter = crate::NoOpCostMeterPort;
343 run_turn_with_extensions(
344 llm,
345 tools,
346 store,
347 events,
348 req,
349 policy,
350 default_model,
351 tool_policy_port,
352 approval_port,
353 crate::DispatchMode::PromptGuided,
354 &checkpoint_store,
355 &artifact_store,
356 &cost_meter,
357 )
358 .await
359}
360
/// Runs a single (non-streaming) agent turn with every port supplied explicitly.
///
/// Setup:
/// - Loads the session, falling back to `SessionState::default()` when none
///   is stored, and lists the available tools.
/// - Emits `TurnStarted`, plus `SkillsSelected` when skills were selected.
/// - Appends the request input as a `Role::User` message.
///
/// Each loop iteration:
/// 1. Checks the cancel token, the cost budget and the [`LoopGuard`] limits.
/// 2. Builds the prompt and calls the LLM, racing the call against the cancel
///    token when one is present.
/// 3. Accumulates usage, appends the assistant reply and saves a best-effort
///    checkpoint.
/// 4. Dispatches on the parsed action:
///    - `Final` / `AskUser` finish the turn with `FinishReason::Stop`.
///    - `ToolCall` executes the tool, records its output and loops again.
///
/// A reply that fails to parse triggers one corrective re-prompt. A second
/// consecutive failure aborts the turn.
///
/// # Errors
///
/// Returns an error if any of these fail:
/// - loading the session or listing tools;
/// - the budget check;
/// - the LLM call or recording LLM usage;
/// - saving the session when the turn finishes.
///
/// It also returns an error if the LLM produces two unparseable replies in a
/// row. Cancellation and guard exhaustion are *not* errors: they finish the
/// turn with `FinishReason::Cancelled` / `FinishReason::GuardExceeded`.
pub(crate) async fn run_turn_with_extensions(
    llm: &dyn LlmPort,
    tools: &dyn ToolPort,
    store: &dyn SessionStore,
    events: &dyn EventSink,
    req: AgentRequest,
    policy: &TurnPolicy,
    default_model: &str,
    tool_policy_port: &dyn ToolPolicyPort,
    approval_port: &dyn ApprovalPort,
    dispatch_mode: crate::DispatchMode,
    checkpoint_store: &dyn TurnCheckpointStorePort,
    artifact_store: &dyn ArtifactStorePort,
    cost_meter: &dyn CostMeterPort,
) -> Result<AgentRunResult, AgentError> {
    // A per-request model overrides the configured default.
    let model = req.model.as_deref().unwrap_or(default_model);
    let cancel_token = req.cancel_token.clone();
    let system_instructions = resolve_system_instructions(&req);
    let selected_skills = resolve_selected_skills(&req);
    let tool_call_policy = resolve_tool_call_policy(&req);

    let mut session = store.load(&req.session_id).await?.unwrap_or_default();
    let tool_descriptors = tools.list_tools().await?;
    let mut guard = LoopGuard::new(policy.clone());

    events.emit(AgentEvent::TurnStarted { session_id: req.session_id.clone() });
    if !selected_skills.is_empty() {
        events.emit(AgentEvent::SkillsSelected { skill_names: selected_skills.clone() });
    }

    session.messages.push(Message { role: Role::User, content: req.input.clone() });

    let mut tool_transcript: Vec<ToolResult> = Vec::new();
    let mut total_usage = TokenUsage::default();
    // Reset whenever a reply parses; two failures in a row abort the turn.
    let mut consecutive_parse_failures: u32 = 0;

    loop {
        // Cancellation requested between steps: finish cleanly (not an error).
        if let Some(ref token) = cancel_token
            && token.is_cancelled()
        {
            return finish_turn(
                store,
                events,
                &req.session_id,
                &session,
                FinishResult {
                    content: "Turn cancelled.",
                    tool_transcript,
                    usage: total_usage,
                    finish_reason: FinishReason::Cancelled,
                },
            )
            .await;
        }

        // Budget failures propagate as errors without saving the session.
        cost_meter.check_budget(&req.session_id).await?;

        // Step/tool-call/error/timeout limits end the turn with GuardExceeded.
        if !guard.can_continue() {
            let reason = guard.reason();
            let msg = format!("Turn stopped: {reason:?}");
            return finish_turn(
                store,
                events,
                &req.session_id,
                &session,
                FinishResult {
                    content: &msg,
                    tool_transcript,
                    usage: total_usage,
                    finish_reason: FinishReason::GuardExceeded,
                },
            )
            .await;
        }

        let llm_request = crate::prompt::build_llm_request_with_options(
            model,
            &session,
            &tool_descriptors,
            &system_instructions,
            prompt_options_for_mode(dispatch_mode, llm),
        );

        events.emit(AgentEvent::LlmCallStarted { model: model.to_string() });

        // Race the LLM call against cancellation when a token is present.
        let llm_response = if let Some(ref token) = cancel_token {
            tokio::select! {
                result = llm.complete(llm_request) => result?,
                () = token.cancelled() => {
                    return finish_turn(
                        store, events, &req.session_id, &session,
                        FinishResult { content: "Turn cancelled.", tool_transcript, usage: total_usage, finish_reason: FinishReason::Cancelled },
                    ).await;
                }
            }
        } else {
            llm.complete(llm_request).await?
        };

        guard.record_step();
        total_usage.prompt_tokens += llm_response.usage.prompt_tokens;
        total_usage.completion_tokens += llm_response.usage.completion_tokens;
        cost_meter.record_llm_usage(&req.session_id, model, &llm_response.usage).await?;

        events.emit(AgentEvent::LlmCallCompleted { usage: llm_response.usage.clone() });

        session
            .messages
            .push(Message { role: Role::Assistant, content: llm_response.content.clone() });

        // Checkpointing is best-effort: failures are deliberately ignored.
        let _ = checkpoint_store
            .save_checkpoint(&TurnCheckpoint {
                session_id: req.session_id.clone(),
                step: guard.steps,
                tool_calls: guard.tool_calls,
                usage: total_usage.clone(),
            })
            .await;

        match parse_action_for_mode(dispatch_mode, llm, &llm_response) {
            Ok(action) => {
                consecutive_parse_failures = 0;
                match action {
                    AgentAction::Final { content } => {
                        return finish_turn(
                            store,
                            events,
                            &req.session_id,
                            &session,
                            FinishResult {
                                content: &content,
                                tool_transcript,
                                usage: total_usage,
                                finish_reason: FinishReason::Stop,
                            },
                        )
                        .await;
                    }
                    // A clarifying question also ends the turn; the question is the response.
                    AgentAction::AskUser { question } => {
                        return finish_turn(
                            store,
                            events,
                            &req.session_id,
                            &session,
                            FinishResult {
                                content: &question,
                                tool_transcript,
                                usage: total_usage,
                                finish_reason: FinishReason::Stop,
                            },
                        )
                        .await;
                    }
                    AgentAction::ToolCall { name, arguments } => {
                        events.emit(AgentEvent::ToolCallStarted { name: name.clone() });
                        let approval_context = ApprovalContext {
                            session_id: req.session_id.clone(),
                            turn_step: guard.steps.max(1),
                            selected_skills: selected_skills.clone(),
                        };

                        // Never fails: policy/approval/tool errors come back as is_error results.
                        let tool_result = execute_tool_call(
                            tools,
                            &mut guard,
                            ToolCall { name: name.clone(), arguments },
                            &tool_call_policy,
                            tool_policy_port,
                            approval_port,
                            &approval_context,
                            policy.tool_timeout_ms,
                        )
                        .await;

                        // Counted even when the call was denied or failed.
                        guard.record_tool_call();
                        // Tool cost recording is best-effort, unlike LLM usage above.
                        let _ = cost_meter.record_tool_result(&req.session_id, &tool_result).await;

                        let is_error = tool_result.is_error;
                        events.emit(AgentEvent::ToolCallCompleted { name: name.clone(), is_error });

                        // The tool output is fed back to the LLM as a JSON-serialized Tool message.
                        let output_str =
                            serde_json::to_string(&tool_result.output).unwrap_or_default();
                        session.messages.push(Message { role: Role::Tool, content: output_str });

                        // Artifact persistence is best-effort.
                        let _ = artifact_store
                            .put(ArtifactRecord {
                                session_id: req.session_id.clone(),
                                kind: "tool_result".to_string(),
                                name: name.clone(),
                                content: tool_result.output.clone(),
                            })
                            .await;

                        tool_transcript.push(tool_result);
                    }
                }
            }
            Err(_parse_err) => {
                consecutive_parse_failures += 1;
                if consecutive_parse_failures >= 2 {
                    // Persist progress best-effort; the parse error is what gets reported.
                    let _ = store.save(&req.session_id, &session).await;
                    return Err(AgentError::Internal(
                        "LLM produced invalid JSON after re-prompt".into(),
                    ));
                }
                // Ask once more, as a user message, for a schema-conforming reply.
                session.messages.push(Message {
                    role: Role::User,
                    content: "Your response was not valid JSON. \
                              Please respond with exactly one JSON object \
                              matching the required schema."
                        .into(),
                });
            }
        }
    }
}
577
578struct FinishResult<'a> {
580 content: &'a str,
581 tool_transcript: Vec<ToolResult>,
582 usage: TokenUsage,
583 finish_reason: FinishReason,
584}
585
586async fn finish_turn(
588 store: &dyn SessionStore,
589 events: &dyn EventSink,
590 session_id: &bob_core::types::SessionId,
591 session: &bob_core::types::SessionState,
592 result: FinishResult<'_>,
593) -> Result<AgentRunResult, AgentError> {
594 store.save(session_id, session).await?;
595 events.emit(AgentEvent::TurnCompleted { finish_reason: result.finish_reason });
596 Ok(AgentRunResult::Finished(AgentResponse {
597 content: result.content.to_string(),
598 tool_transcript: result.tool_transcript,
599 usage: result.usage,
600 finish_reason: result.finish_reason,
601 }))
602}
603
604pub async fn run_turn_stream(
606 llm: Arc<dyn LlmPort>,
607 tools: Arc<dyn ToolPort>,
608 store: Arc<dyn SessionStore>,
609 events: Arc<dyn EventSink>,
610 req: AgentRequest,
611 policy: TurnPolicy,
612 default_model: String,
613) -> Result<AgentEventStream, AgentError> {
614 let tool_policy: Arc<dyn ToolPolicyPort> = Arc::new(crate::DefaultToolPolicyPort);
615 let approval: Arc<dyn ApprovalPort> = Arc::new(crate::AllowAllApprovalPort);
616 let checkpoint_store: Arc<dyn TurnCheckpointStorePort> =
617 Arc::new(crate::NoOpCheckpointStorePort);
618 let artifact_store: Arc<dyn ArtifactStorePort> = Arc::new(crate::NoOpArtifactStorePort);
619 let cost_meter: Arc<dyn CostMeterPort> = Arc::new(crate::NoOpCostMeterPort);
620 run_turn_stream_with_controls(
621 llm,
622 tools,
623 store,
624 events,
625 req,
626 policy,
627 default_model,
628 tool_policy,
629 approval,
630 crate::DispatchMode::NativePreferred,
631 checkpoint_store,
632 artifact_store,
633 cost_meter,
634 )
635 .await
636}
637
638pub(crate) async fn run_turn_stream_with_controls(
640 llm: Arc<dyn LlmPort>,
641 tools: Arc<dyn ToolPort>,
642 store: Arc<dyn SessionStore>,
643 events: Arc<dyn EventSink>,
644 req: AgentRequest,
645 policy: TurnPolicy,
646 default_model: String,
647 tool_policy: Arc<dyn ToolPolicyPort>,
648 approval: Arc<dyn ApprovalPort>,
649 dispatch_mode: crate::DispatchMode,
650 checkpoint_store: Arc<dyn TurnCheckpointStorePort>,
651 artifact_store: Arc<dyn ArtifactStorePort>,
652 cost_meter: Arc<dyn CostMeterPort>,
653) -> Result<AgentEventStream, AgentError> {
654 let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<AgentStreamEvent>();
655 let config = StreamRunConfig {
656 policy,
657 default_model,
658 tool_policy,
659 approval,
660 dispatch_mode,
661 checkpoint_store,
662 artifact_store,
663 cost_meter,
664 };
665
666 tokio::spawn(async move {
667 if let Err(err) = run_turn_stream_inner(llm, tools, store, events, req, &config, &tx).await
668 {
669 let _ = tx.send(AgentStreamEvent::Error { error: err.to_string() });
670 }
671 });
672
673 Ok(Box::pin(UnboundedReceiverStream::new(rx)))
674}
675
676struct StreamRunConfig {
677 policy: TurnPolicy,
678 default_model: String,
679 tool_policy: Arc<dyn ToolPolicyPort>,
680 approval: Arc<dyn ApprovalPort>,
681 dispatch_mode: crate::DispatchMode,
682 checkpoint_store: Arc<dyn TurnCheckpointStorePort>,
683 artifact_store: Arc<dyn ArtifactStorePort>,
684 cost_meter: Arc<dyn CostMeterPort>,
685}
686
687async fn run_turn_stream_inner(
688 llm: Arc<dyn LlmPort>,
689 tools: Arc<dyn ToolPort>,
690 store: Arc<dyn SessionStore>,
691 events: Arc<dyn EventSink>,
692 req: AgentRequest,
693 config: &StreamRunConfig,
694 tx: &tokio::sync::mpsc::UnboundedSender<AgentStreamEvent>,
695) -> Result<(), AgentError> {
696 let model = req.model.as_deref().unwrap_or(&config.default_model);
697 let cancel_token = req.cancel_token.clone();
698 let system_instructions = resolve_system_instructions(&req);
699 let selected_skills = resolve_selected_skills(&req);
700 let tool_call_policy = resolve_tool_call_policy(&req);
701
702 let mut session = store.load(&req.session_id).await?.unwrap_or_default();
703 let tool_descriptors = tools.list_tools().await?;
704 let mut guard = LoopGuard::new(config.policy.clone());
705 let mut total_usage = TokenUsage::default();
706 let mut consecutive_parse_failures: u32 = 0;
707 let mut next_call_id: u64 = 0;
708
709 events.emit(AgentEvent::TurnStarted { session_id: req.session_id.clone() });
710 if !selected_skills.is_empty() {
711 events.emit(AgentEvent::SkillsSelected { skill_names: selected_skills.clone() });
712 }
713 session.messages.push(Message { role: Role::User, content: req.input.clone() });
714
715 loop {
716 if let Some(ref token) = cancel_token
717 && token.is_cancelled()
718 {
719 events.emit(AgentEvent::Error { error: "turn cancelled".to_string() });
720 events.emit(AgentEvent::TurnCompleted { finish_reason: FinishReason::Cancelled });
721 store.save(&req.session_id, &session).await?;
722 let _ = tx.send(AgentStreamEvent::Error { error: "turn cancelled".to_string() });
723 let _ = tx.send(AgentStreamEvent::Finished { usage: total_usage.clone() });
724 return Ok(());
725 }
726
727 config.cost_meter.check_budget(&req.session_id).await?;
728
729 if !guard.can_continue() {
730 let reason = guard.reason();
731 let msg = format!("Turn stopped: {reason:?}");
732 events.emit(AgentEvent::Error { error: msg.clone() });
733 events.emit(AgentEvent::TurnCompleted { finish_reason: FinishReason::GuardExceeded });
734 store.save(&req.session_id, &session).await?;
735 let _ = tx.send(AgentStreamEvent::Error { error: msg });
736 let _ = tx.send(AgentStreamEvent::Finished { usage: total_usage.clone() });
737 return Ok(());
738 }
739
740 let llm_request = crate::prompt::build_llm_request_with_options(
741 model,
742 &session,
743 &tool_descriptors,
744 &system_instructions,
745 prompt_options_for_mode(config.dispatch_mode, llm.as_ref()),
746 );
747 events.emit(AgentEvent::LlmCallStarted { model: model.to_string() });
748
749 let mut assistant_content = String::new();
750 let mut llm_usage = TokenUsage::default();
751 let mut llm_tool_calls: Vec<ToolCall> = Vec::new();
752 let mut fallback_to_complete = false;
753
754 match llm.complete_stream(llm_request.clone()).await {
755 Ok(mut stream) => {
756 while let Some(item) = stream.next().await {
757 match item {
758 Ok(bob_core::types::LlmStreamChunk::TextDelta(delta)) => {
759 assistant_content.push_str(&delta);
760 let _ = tx.send(AgentStreamEvent::TextDelta { content: delta });
761 }
762 Ok(bob_core::types::LlmStreamChunk::Done { usage }) => {
763 llm_usage = usage;
764 }
765 Err(err) => {
766 events.emit(AgentEvent::Error { error: err.to_string() });
767 return Err(AgentError::Llm(err));
768 }
769 }
770 }
771 }
772 Err(err) => {
773 fallback_to_complete = true;
774 events.emit(AgentEvent::Error { error: err.to_string() });
775 }
776 }
777
778 if fallback_to_complete {
780 let llm_response = llm.complete(llm_request).await?;
781 assistant_content = llm_response.content.clone();
782 llm_usage = llm_response.usage;
783 llm_tool_calls = llm_response.tool_calls;
784 let _ = tx.send(AgentStreamEvent::TextDelta { content: llm_response.content });
785 }
786
787 guard.record_step();
788 total_usage.prompt_tokens += llm_usage.prompt_tokens;
789 total_usage.completion_tokens += llm_usage.completion_tokens;
790 config.cost_meter.record_llm_usage(&req.session_id, model, &llm_usage).await?;
791 events.emit(AgentEvent::LlmCallCompleted { usage: llm_usage.clone() });
792 session
793 .messages
794 .push(Message { role: Role::Assistant, content: assistant_content.clone() });
795
796 let _ = config
797 .checkpoint_store
798 .save_checkpoint(&TurnCheckpoint {
799 session_id: req.session_id.clone(),
800 step: guard.steps,
801 tool_calls: guard.tool_calls,
802 usage: total_usage.clone(),
803 })
804 .await;
805
806 let response_for_dispatch = bob_core::types::LlmResponse {
807 content: assistant_content.clone(),
808 usage: llm_usage.clone(),
809 finish_reason: FinishReason::Stop,
810 tool_calls: llm_tool_calls,
811 };
812
813 if let Ok(action) =
814 parse_action_for_mode(config.dispatch_mode, llm.as_ref(), &response_for_dispatch)
815 {
816 consecutive_parse_failures = 0;
817 match action {
818 AgentAction::Final { .. } | AgentAction::AskUser { .. } => {
819 store.save(&req.session_id, &session).await?;
820 events.emit(AgentEvent::TurnCompleted { finish_reason: FinishReason::Stop });
821 let _ = tx.send(AgentStreamEvent::Finished { usage: total_usage.clone() });
822 return Ok(());
823 }
824 AgentAction::ToolCall { name, arguments } => {
825 events.emit(AgentEvent::ToolCallStarted { name: name.clone() });
826 next_call_id += 1;
827 let call_id = format!("call-{next_call_id}");
828 let _ = tx.send(AgentStreamEvent::ToolCallStarted {
829 name: name.clone(),
830 call_id: call_id.clone(),
831 });
832 let approval_context = ApprovalContext {
833 session_id: req.session_id.clone(),
834 turn_step: guard.steps.max(1),
835 selected_skills: selected_skills.clone(),
836 };
837
838 let tool_result = execute_tool_call(
839 tools.as_ref(),
840 &mut guard,
841 ToolCall { name: name.clone(), arguments },
842 &tool_call_policy,
843 config.tool_policy.as_ref(),
844 config.approval.as_ref(),
845 &approval_context,
846 config.policy.tool_timeout_ms,
847 )
848 .await;
849
850 guard.record_tool_call();
851 let _ =
852 config.cost_meter.record_tool_result(&req.session_id, &tool_result).await;
853 let is_error = tool_result.is_error;
854 events.emit(AgentEvent::ToolCallCompleted { name: name.clone(), is_error });
855 let _ = tx.send(AgentStreamEvent::ToolCallCompleted {
856 call_id,
857 result: tool_result.clone(),
858 });
859
860 let output_str = serde_json::to_string(&tool_result.output).unwrap_or_default();
861 session.messages.push(Message { role: Role::Tool, content: output_str });
862 let _ = config
863 .artifact_store
864 .put(ArtifactRecord {
865 session_id: req.session_id.clone(),
866 kind: "tool_result".to_string(),
867 name: name.clone(),
868 content: tool_result.output.clone(),
869 })
870 .await;
871 }
872 }
873 } else {
874 consecutive_parse_failures += 1;
875 if consecutive_parse_failures >= 2 {
876 store.save(&req.session_id, &session).await?;
877 events.emit(AgentEvent::Error {
878 error: "LLM produced invalid JSON after re-prompt".to_string(),
879 });
880 return Err(AgentError::Internal(
881 "LLM produced invalid JSON after re-prompt".into(),
882 ));
883 }
884 session.messages.push(Message {
885 role: Role::User,
886 content: "Your response was not valid JSON. \
887 Please respond with exactly one JSON object \
888 matching the required schema."
889 .into(),
890 });
891 }
892 }
893}
894
895#[cfg(test)]
896mod tests {
897 use super::*;
898
899 fn test_policy() -> TurnPolicy {
901 TurnPolicy {
902 max_steps: 3,
903 max_tool_calls: 2,
904 max_consecutive_errors: 2,
905 turn_timeout_ms: 100,
906 tool_timeout_ms: 50,
907 }
908 }
909
910 #[test]
911 fn trips_on_max_steps() {
912 let mut guard = LoopGuard::new(test_policy());
913 assert!(guard.can_continue());
914
915 for _ in 0..3 {
916 guard.record_step();
917 }
918
919 assert!(!guard.can_continue(), "guard should trip after reaching max_steps");
920 assert_eq!(guard.reason(), GuardReason::MaxSteps);
921 }
922
923 #[test]
924 fn trips_on_max_tool_calls() {
925 let mut guard = LoopGuard::new(test_policy());
926 assert!(guard.can_continue());
927
928 for _ in 0..2 {
929 guard.record_tool_call();
930 }
931
932 assert!(!guard.can_continue(), "guard should trip after reaching max_tool_calls");
933 assert_eq!(guard.reason(), GuardReason::MaxToolCalls);
934 }
935
936 #[test]
937 fn trips_on_max_consecutive_errors() {
938 let mut guard = LoopGuard::new(test_policy());
939 assert!(guard.can_continue());
940
941 for _ in 0..2 {
942 guard.record_error();
943 }
944
945 assert!(!guard.can_continue(), "guard should trip after reaching max_consecutive_errors");
946 assert_eq!(guard.reason(), GuardReason::MaxConsecutiveErrors);
947 }
948
949 #[tokio::test]
950 async fn trips_on_timeout() {
951 let guard = LoopGuard::new(test_policy());
952 assert!(guard.can_continue());
953 assert!(!guard.timed_out());
954
955 tokio::time::sleep(std::time::Duration::from_millis(150)).await;
957
958 assert!(!guard.can_continue(), "guard should trip after timeout");
959 assert!(guard.timed_out());
960 assert_eq!(guard.reason(), GuardReason::TurnTimeout);
961 }
962
963 #[test]
964 fn reset_errors_clears_counter() {
965 let mut guard = LoopGuard::new(test_policy());
966
967 guard.record_error();
968 guard.reset_errors();
969
970 guard.record_error();
972 assert!(guard.can_continue(), "single error after reset should not trip guard");
973 }
974
975 use std::{
978 collections::{HashMap, VecDeque},
979 sync::{Arc, Mutex},
980 };
981
982 use bob_core::{
983 error::{CostError, LlmError, StoreError, ToolError},
984 ports::{
985 ApprovalPort, ArtifactStorePort, CostMeterPort, EventSink, LlmPort, SessionStore,
986 ToolPolicyPort, ToolPort, TurnCheckpointStorePort,
987 },
988 types::{
989 AgentEvent, AgentRequest, AgentRunResult, AgentStreamEvent, ApprovalContext,
990 ApprovalDecision, ArtifactRecord, CancelToken, LlmRequest, LlmResponse, LlmStream,
991 LlmStreamChunk, SessionId, SessionState, ToolCall, ToolDescriptor, ToolResult,
992 ToolSource, TurnCheckpoint,
993 },
994 };
995 use futures_util::StreamExt;
996
997 struct SequentialLlm {
1001 responses: Mutex<VecDeque<Result<LlmResponse, LlmError>>>,
1002 }
1003
1004 impl SequentialLlm {
1005 fn from_contents(contents: Vec<&str>) -> Self {
1006 let responses = contents
1007 .into_iter()
1008 .map(|c| {
1009 Ok(LlmResponse {
1010 content: c.to_string(),
1011 usage: TokenUsage::default(),
1012 finish_reason: FinishReason::Stop,
1013 tool_calls: Vec::new(),
1014 })
1015 })
1016 .collect();
1017 Self { responses: Mutex::new(responses) }
1018 }
1019 }
1020
1021 #[async_trait::async_trait]
1022 impl LlmPort for SequentialLlm {
1023 async fn complete(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
1024 let mut q = self.responses.lock().unwrap_or_else(|p| p.into_inner());
1025 q.pop_front().unwrap_or_else(|| {
1026 Ok(LlmResponse {
1027 content: r#"{"type": "final", "content": "fallback"}"#.to_string(),
1028 usage: TokenUsage::default(),
1029 finish_reason: FinishReason::Stop,
1030 tool_calls: Vec::new(),
1031 })
1032 })
1033 }
1034
1035 async fn complete_stream(&self, _req: LlmRequest) -> Result<LlmStream, LlmError> {
1036 Err(LlmError::Provider("not implemented".into()))
1037 }
1038 }
1039
1040 struct MockToolPort {
1042 tools: Vec<ToolDescriptor>,
1043 call_results: Mutex<VecDeque<Result<ToolResult, ToolError>>>,
1044 }
1045
1046 impl MockToolPort {
1047 fn empty() -> Self {
1048 Self { tools: vec![], call_results: Mutex::new(VecDeque::new()) }
1049 }
1050
1051 fn with_tool_and_results(
1052 tool_name: &str,
1053 results: Vec<Result<ToolResult, ToolError>>,
1054 ) -> Self {
1055 Self {
1056 tools: vec![ToolDescriptor {
1057 id: tool_name.to_string(),
1058 description: format!("{tool_name} tool"),
1059 input_schema: serde_json::json!({"type": "object"}),
1060 source: ToolSource::Local,
1061 }],
1062 call_results: Mutex::new(results.into()),
1063 }
1064 }
1065 }
1066
1067 #[async_trait::async_trait]
1068 impl ToolPort for MockToolPort {
1069 async fn list_tools(&self) -> Result<Vec<ToolDescriptor>, ToolError> {
1070 Ok(self.tools.clone())
1071 }
1072
1073 async fn call_tool(&self, call: ToolCall) -> Result<ToolResult, ToolError> {
1074 let mut q = self.call_results.lock().unwrap_or_else(|p| p.into_inner());
1075 q.pop_front().unwrap_or_else(|| {
1076 Ok(ToolResult {
1077 name: call.name,
1078 output: serde_json::json!({"result": "default"}),
1079 is_error: false,
1080 })
1081 })
1082 }
1083 }
1084
1085 struct NoCallToolPort {
1086 tools: Vec<ToolDescriptor>,
1087 }
1088
1089 #[async_trait::async_trait]
1090 impl ToolPort for NoCallToolPort {
1091 async fn list_tools(&self) -> Result<Vec<ToolDescriptor>, ToolError> {
1092 Ok(self.tools.clone())
1093 }
1094
1095 async fn call_tool(&self, _call: ToolCall) -> Result<ToolResult, ToolError> {
1096 Err(ToolError::Execution(
1097 "tool call should be blocked by policy before execution".to_string(),
1098 ))
1099 }
1100 }
1101
1102 struct AllowAllPolicyPort;
1103
1104 impl ToolPolicyPort for AllowAllPolicyPort {
1105 fn is_tool_allowed(
1106 &self,
1107 _tool: &str,
1108 _deny_tools: &[String],
1109 _allow_tools: Option<&[String]>,
1110 ) -> bool {
1111 true
1112 }
1113 }
1114
1115 struct DenySearchPolicyPort;
1116
1117 impl ToolPolicyPort for DenySearchPolicyPort {
1118 fn is_tool_allowed(
1119 &self,
1120 tool: &str,
1121 _deny_tools: &[String],
1122 _allow_tools: Option<&[String]>,
1123 ) -> bool {
1124 tool != "search"
1125 }
1126 }
1127
1128 struct AlwaysApprovePort;
1129
1130 #[async_trait::async_trait]
1131 impl ApprovalPort for AlwaysApprovePort {
1132 async fn approve_tool_call(
1133 &self,
1134 _call: &ToolCall,
1135 _context: &ApprovalContext,
1136 ) -> Result<ApprovalDecision, ToolError> {
1137 Ok(ApprovalDecision::Approved)
1138 }
1139 }
1140
1141 struct AlwaysDenyApprovalPort;
1142
1143 #[async_trait::async_trait]
1144 impl ApprovalPort for AlwaysDenyApprovalPort {
1145 async fn approve_tool_call(
1146 &self,
1147 _call: &ToolCall,
1148 _context: &ApprovalContext,
1149 ) -> Result<ApprovalDecision, ToolError> {
1150 Ok(ApprovalDecision::Denied {
1151 reason: "approval policy rejected tool call".to_string(),
1152 })
1153 }
1154 }
1155
1156 struct CountingCheckpointPort {
1157 saved: Mutex<Vec<TurnCheckpoint>>,
1158 }
1159
1160 impl CountingCheckpointPort {
1161 fn new() -> Self {
1162 Self { saved: Mutex::new(Vec::new()) }
1163 }
1164 }
1165
1166 #[async_trait::async_trait]
1167 impl TurnCheckpointStorePort for CountingCheckpointPort {
1168 async fn save_checkpoint(&self, checkpoint: &TurnCheckpoint) -> Result<(), StoreError> {
1169 self.saved.lock().unwrap_or_else(|p| p.into_inner()).push(checkpoint.clone());
1170 Ok(())
1171 }
1172
1173 async fn load_latest(
1174 &self,
1175 _session_id: &SessionId,
1176 ) -> Result<Option<TurnCheckpoint>, StoreError> {
1177 Ok(None)
1178 }
1179 }
1180
1181 struct NoopArtifactStore;
1182
1183 #[async_trait::async_trait]
1184 impl ArtifactStorePort for NoopArtifactStore {
1185 async fn put(&self, _artifact: ArtifactRecord) -> Result<(), StoreError> {
1186 Ok(())
1187 }
1188
1189 async fn list_by_session(
1190 &self,
1191 _session_id: &SessionId,
1192 ) -> Result<Vec<ArtifactRecord>, StoreError> {
1193 Ok(Vec::new())
1194 }
1195 }
1196
1197 struct CountingCostMeter {
1198 llm_calls: Mutex<u32>,
1199 }
1200
1201 impl CountingCostMeter {
1202 fn new() -> Self {
1203 Self { llm_calls: Mutex::new(0) }
1204 }
1205 }
1206
1207 #[async_trait::async_trait]
1208 impl CostMeterPort for CountingCostMeter {
1209 async fn check_budget(&self, _session_id: &SessionId) -> Result<(), CostError> {
1210 Ok(())
1211 }
1212
1213 async fn record_llm_usage(
1214 &self,
1215 _session_id: &SessionId,
1216 _model: &str,
1217 _usage: &TokenUsage,
1218 ) -> Result<(), CostError> {
1219 let mut count = self.llm_calls.lock().unwrap_or_else(|p| p.into_inner());
1220 *count += 1;
1221 Ok(())
1222 }
1223
1224 async fn record_tool_result(
1225 &self,
1226 _session_id: &SessionId,
1227 _tool_result: &ToolResult,
1228 ) -> Result<(), CostError> {
1229 Ok(())
1230 }
1231 }
1232
1233 struct MemoryStore {
1234 data: Mutex<HashMap<SessionId, SessionState>>,
1235 }
1236
1237 impl MemoryStore {
1238 fn new() -> Self {
1239 Self { data: Mutex::new(HashMap::new()) }
1240 }
1241 }
1242
1243 struct FailingSaveStore;
1244
1245 #[async_trait::async_trait]
1246 impl SessionStore for FailingSaveStore {
1247 async fn load(&self, _id: &SessionId) -> Result<Option<SessionState>, StoreError> {
1248 Ok(None)
1249 }
1250
1251 async fn save(&self, _id: &SessionId, _state: &SessionState) -> Result<(), StoreError> {
1252 Err(StoreError::Backend("simulated save failure".into()))
1253 }
1254 }
1255
1256 #[async_trait::async_trait]
1257 impl SessionStore for MemoryStore {
1258 async fn load(&self, id: &SessionId) -> Result<Option<SessionState>, StoreError> {
1259 let map = self.data.lock().unwrap_or_else(|p| p.into_inner());
1260 Ok(map.get(id).cloned())
1261 }
1262
1263 async fn save(&self, id: &SessionId, state: &SessionState) -> Result<(), StoreError> {
1264 let mut map = self.data.lock().unwrap_or_else(|p| p.into_inner());
1265 map.insert(id.clone(), state.clone());
1266 Ok(())
1267 }
1268 }
1269
1270 struct CollectingSink {
1271 events: Mutex<Vec<AgentEvent>>,
1272 }
1273
1274 impl CollectingSink {
1275 fn new() -> Self {
1276 Self { events: Mutex::new(Vec::new()) }
1277 }
1278
1279 fn event_count(&self) -> usize {
1280 self.events.lock().unwrap_or_else(|p| p.into_inner()).len()
1281 }
1282
1283 fn all_events(&self) -> Vec<AgentEvent> {
1284 self.events.lock().unwrap_or_else(|p| p.into_inner()).clone()
1285 }
1286 }
1287
1288 impl EventSink for CollectingSink {
1289 fn emit(&self, event: AgentEvent) {
1290 self.events.lock().unwrap_or_else(|p| p.into_inner()).push(event);
1291 }
1292 }
1293
1294 fn make_request(input: &str) -> AgentRequest {
1295 AgentRequest {
1296 input: input.into(),
1297 session_id: "test-session".into(),
1298 model: None,
1299 context: bob_core::types::RequestContext::default(),
1300 cancel_token: None,
1301 }
1302 }
1303
1304 fn generous_policy() -> TurnPolicy {
1305 TurnPolicy {
1306 max_steps: 20,
1307 max_tool_calls: 10,
1308 max_consecutive_errors: 3,
1309 turn_timeout_ms: 30_000,
1310 tool_timeout_ms: 5_000,
1311 }
1312 }
1313
1314 struct StreamLlm {
1315 chunks: Mutex<VecDeque<Result<LlmStreamChunk, LlmError>>>,
1316 }
1317
1318 impl StreamLlm {
1319 fn new(chunks: Vec<Result<LlmStreamChunk, LlmError>>) -> Self {
1320 Self { chunks: Mutex::new(chunks.into()) }
1321 }
1322 }
1323
1324 #[async_trait::async_trait]
1325 impl LlmPort for StreamLlm {
1326 async fn complete(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
1327 Err(LlmError::Provider("complete() should not be called in stream test".into()))
1328 }
1329
1330 async fn complete_stream(&self, _req: LlmRequest) -> Result<LlmStream, LlmError> {
1331 let mut chunks = self.chunks.lock().unwrap_or_else(|p| p.into_inner());
1332 let items: Vec<Result<LlmStreamChunk, LlmError>> = chunks.drain(..).collect();
1333 Ok(Box::pin(futures_util::stream::iter(items)))
1334 }
1335 }
1336
1337 struct InspectingLlm {
1338 expected_substring: String,
1339 }
1340
1341 #[async_trait::async_trait]
1342 impl LlmPort for InspectingLlm {
1343 async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
1344 let system = req
1345 .messages
1346 .iter()
1347 .find(|m| m.role == Role::System)
1348 .map(|m| m.content.clone())
1349 .unwrap_or_default();
1350 if !system.contains(&self.expected_substring) {
1351 return Err(LlmError::Provider(format!(
1352 "expected system prompt to include '{}', got: {}",
1353 self.expected_substring, system
1354 )));
1355 }
1356 Ok(LlmResponse {
1357 content: r#"{"type": "final", "content": "ok"}"#.to_string(),
1358 usage: TokenUsage::default(),
1359 finish_reason: FinishReason::Stop,
1360 tool_calls: Vec::new(),
1361 })
1362 }
1363
1364 async fn complete_stream(&self, _req: LlmRequest) -> Result<LlmStream, LlmError> {
1365 Err(LlmError::Provider("not used".into()))
1366 }
1367 }
1368
1369 #[tokio::test]
1372 async fn tc01_simple_final_response() {
1373 let llm =
1374 SequentialLlm::from_contents(vec![r#"{"type": "final", "content": "Hello there!"}"#]);
1375 let tools = MockToolPort::empty();
1376 let store = MemoryStore::new();
1377 let sink = CollectingSink::new();
1378
1379 let result = run_turn(
1380 &llm,
1381 &tools,
1382 &store,
1383 &sink,
1384 make_request("Hi"),
1385 &generous_policy(),
1386 "test-model",
1387 )
1388 .await;
1389
1390 assert!(
1391 matches!(&result, Ok(AgentRunResult::Finished(_))),
1392 "expected Finished, got {result:?}"
1393 );
1394 let resp = match result {
1395 Ok(AgentRunResult::Finished(r)) => r,
1396 _ => return,
1397 };
1398
1399 assert_eq!(resp.content, "Hello there!");
1400 assert_eq!(resp.finish_reason, FinishReason::Stop);
1401 assert!(resp.tool_transcript.is_empty());
1402 assert!(sink.event_count() >= 3, "should emit TurnStarted, LlmCall*, TurnCompleted");
1403 }
1404
1405 #[tokio::test]
1408 async fn tc02_tool_call_then_final() {
1409 let llm = SequentialLlm::from_contents(vec![
1410 r#"{"type": "tool_call", "name": "search", "arguments": {"q": "rust"}}"#,
1411 r#"{"type": "final", "content": "Found results."}"#,
1412 ]);
1413 let tools = MockToolPort::with_tool_and_results(
1414 "search",
1415 vec![Ok(ToolResult {
1416 name: "search".into(),
1417 output: serde_json::json!({"hits": 42}),
1418 is_error: false,
1419 })],
1420 );
1421 let store = MemoryStore::new();
1422 let sink = CollectingSink::new();
1423
1424 let result = run_turn(
1425 &llm,
1426 &tools,
1427 &store,
1428 &sink,
1429 make_request("Search for rust"),
1430 &generous_policy(),
1431 "test-model",
1432 )
1433 .await;
1434
1435 assert!(
1436 matches!(&result, Ok(AgentRunResult::Finished(_))),
1437 "expected Finished, got {result:?}"
1438 );
1439 let resp = match result {
1440 Ok(AgentRunResult::Finished(r)) => r,
1441 _ => return,
1442 };
1443
1444 assert_eq!(resp.content, "Found results.");
1445 assert_eq!(resp.finish_reason, FinishReason::Stop);
1446 assert_eq!(resp.tool_transcript.len(), 1);
1447 assert_eq!(resp.tool_transcript[0].name, "search");
1448 assert!(!resp.tool_transcript[0].is_error);
1449 }
1450
1451 #[tokio::test]
1454 async fn tc03_parse_error_reprompt_success() {
1455 let llm = SequentialLlm::from_contents(vec![
1456 "This is not JSON at all.",
1457 r#"{"type": "final", "content": "Recovered"}"#,
1458 ]);
1459 let tools = MockToolPort::empty();
1460 let store = MemoryStore::new();
1461 let sink = CollectingSink::new();
1462
1463 let result = run_turn(
1464 &llm,
1465 &tools,
1466 &store,
1467 &sink,
1468 make_request("Hi"),
1469 &generous_policy(),
1470 "test-model",
1471 )
1472 .await;
1473
1474 assert!(
1475 matches!(&result, Ok(AgentRunResult::Finished(_))),
1476 "expected Finished after re-prompt, got {result:?}"
1477 );
1478 let resp = match result {
1479 Ok(AgentRunResult::Finished(r)) => r,
1480 _ => return,
1481 };
1482
1483 assert_eq!(resp.content, "Recovered");
1484 assert_eq!(resp.finish_reason, FinishReason::Stop);
1485 }
1486
1487 #[tokio::test]
1490 async fn tc04_double_parse_error() {
1491 let llm = SequentialLlm::from_contents(vec!["not json 1", "not json 2"]);
1492 let tools = MockToolPort::empty();
1493 let store = MemoryStore::new();
1494 let sink = CollectingSink::new();
1495
1496 let result = run_turn(
1497 &llm,
1498 &tools,
1499 &store,
1500 &sink,
1501 make_request("Hi"),
1502 &generous_policy(),
1503 "test-model",
1504 )
1505 .await;
1506
1507 assert!(result.is_err(), "should return error after two parse failures");
1508 let msg = match result {
1509 Err(err) => err.to_string(),
1510 Ok(value) => format!("unexpected success: {value:?}"),
1511 };
1512 assert!(msg.contains("invalid JSON"), "error message = {msg}");
1513 }
1514
1515 #[tokio::test]
1518 async fn tc05_max_steps_exhaustion() {
1519 let llm = SequentialLlm::from_contents(vec![
1521 r#"{"type": "tool_call", "name": "t1", "arguments": {}}"#,
1522 r#"{"type": "tool_call", "name": "t1", "arguments": {}}"#,
1523 r#"{"type": "tool_call", "name": "t1", "arguments": {}}"#,
1524 r#"{"type": "tool_call", "name": "t1", "arguments": {}}"#,
1525 ]);
1526 let tools = MockToolPort::with_tool_and_results(
1527 "t1",
1528 vec![
1529 Ok(ToolResult {
1530 name: "t1".into(),
1531 output: serde_json::json!(null),
1532 is_error: false,
1533 }),
1534 Ok(ToolResult {
1535 name: "t1".into(),
1536 output: serde_json::json!(null),
1537 is_error: false,
1538 }),
1539 Ok(ToolResult {
1540 name: "t1".into(),
1541 output: serde_json::json!(null),
1542 is_error: false,
1543 }),
1544 ],
1545 );
1546 let store = MemoryStore::new();
1547 let sink = CollectingSink::new();
1548
1549 let policy = TurnPolicy {
1550 max_steps: 2,
1551 max_tool_calls: 10,
1552 max_consecutive_errors: 5,
1553 turn_timeout_ms: 30_000,
1554 tool_timeout_ms: 5_000,
1555 };
1556
1557 let result =
1558 run_turn(&llm, &tools, &store, &sink, make_request("do work"), &policy, "test-model")
1559 .await;
1560
1561 assert!(
1562 matches!(&result, Ok(AgentRunResult::Finished(_))),
1563 "expected Finished with GuardExceeded, got {result:?}"
1564 );
1565 let resp = match result {
1566 Ok(AgentRunResult::Finished(r)) => r,
1567 _ => return,
1568 };
1569
1570 assert_eq!(resp.finish_reason, FinishReason::GuardExceeded);
1571 assert!(resp.content.contains("MaxSteps"), "content = {}", resp.content);
1572 }
1573
1574 #[tokio::test]
1577 async fn tc06_cancellation() {
1578 let llm = SequentialLlm::from_contents(vec![
1579 r#"{"type": "final", "content": "should not reach"}"#,
1580 ]);
1581 let tools = MockToolPort::empty();
1582 let store = MemoryStore::new();
1583 let sink = CollectingSink::new();
1584
1585 let token = CancelToken::new();
1586 token.cancel();
1588
1589 let mut req = make_request("Hi");
1590 req.cancel_token = Some(token);
1591
1592 let result =
1593 run_turn(&llm, &tools, &store, &sink, req, &generous_policy(), "test-model").await;
1594
1595 assert!(
1596 matches!(&result, Ok(AgentRunResult::Finished(_))),
1597 "expected Finished with Cancelled, got {result:?}"
1598 );
1599 let resp = match result {
1600 Ok(AgentRunResult::Finished(r)) => r,
1601 _ => return,
1602 };
1603
1604 assert_eq!(resp.finish_reason, FinishReason::Cancelled);
1605 }
1606
1607 #[tokio::test]
1610 async fn tc07_tool_error_then_final() {
1611 let llm = SequentialLlm::from_contents(vec![
1612 r#"{"type": "tool_call", "name": "flaky_tool", "arguments": {}}"#,
1613 r#"{"type": "final", "content": "Recovered from tool error."}"#,
1614 ]);
1615 let tools = MockToolPort::with_tool_and_results(
1616 "flaky_tool",
1617 vec![Err(ToolError::Execution("connection refused".into()))],
1618 );
1619 let store = MemoryStore::new();
1620 let sink = CollectingSink::new();
1621
1622 let result = run_turn(
1623 &llm,
1624 &tools,
1625 &store,
1626 &sink,
1627 make_request("call flaky"),
1628 &generous_policy(),
1629 "test-model",
1630 )
1631 .await;
1632
1633 assert!(
1634 matches!(&result, Ok(AgentRunResult::Finished(_))),
1635 "expected Finished, got {result:?}"
1636 );
1637 let resp = match result {
1638 Ok(AgentRunResult::Finished(r)) => r,
1639 _ => return,
1640 };
1641
1642 assert_eq!(resp.content, "Recovered from tool error.");
1643 assert_eq!(resp.tool_transcript.len(), 1);
1644 assert!(resp.tool_transcript[0].is_error);
1645 }
1646
1647 #[tokio::test]
1648 async fn tc08_save_failure_is_propagated() {
1649 let llm = SequentialLlm::from_contents(vec![r#"{"type": "final", "content": "done"}"#]);
1650 let tools = MockToolPort::empty();
1651 let store = FailingSaveStore;
1652 let sink = CollectingSink::new();
1653
1654 let result = run_turn(
1655 &llm,
1656 &tools,
1657 &store,
1658 &sink,
1659 make_request("hello"),
1660 &generous_policy(),
1661 "test-model",
1662 )
1663 .await;
1664
1665 assert!(matches!(result, Err(AgentError::Store(_))), "expected Store error to be returned");
1666 }
1667
1668 #[tokio::test]
1669 async fn tc09_stream_turn_emits_text_and_finished() {
1670 let llm: Arc<dyn LlmPort> = Arc::new(StreamLlm::new(vec![
1671 Ok(LlmStreamChunk::TextDelta("{\"type\":\"final\",\"content\":\"he".into())),
1672 Ok(LlmStreamChunk::TextDelta("llo\"}".into())),
1673 Ok(LlmStreamChunk::Done {
1674 usage: TokenUsage { prompt_tokens: 3, completion_tokens: 4 },
1675 }),
1676 ]));
1677 let tools: Arc<dyn ToolPort> = Arc::new(MockToolPort::empty());
1678 let store: Arc<dyn SessionStore> = Arc::new(MemoryStore::new());
1679 let sink: Arc<dyn EventSink> = Arc::new(CollectingSink::new());
1680
1681 let stream_result = run_turn_stream(
1682 llm,
1683 tools,
1684 store,
1685 sink,
1686 make_request("hello"),
1687 generous_policy(),
1688 "test-model".to_string(),
1689 )
1690 .await;
1691 assert!(stream_result.is_ok(), "run_turn_stream should produce a stream");
1692 let mut stream = match stream_result {
1693 Ok(stream) => stream,
1694 Err(_) => return,
1695 };
1696
1697 let mut saw_text = false;
1698 let mut saw_finished = false;
1699 while let Some(event) = stream.next().await {
1700 match event {
1701 AgentStreamEvent::TextDelta { content } => {
1702 saw_text = saw_text || !content.is_empty();
1703 }
1704 AgentStreamEvent::Finished { usage } => {
1705 saw_finished = true;
1706 assert_eq!(usage.prompt_tokens, 3);
1707 assert_eq!(usage.completion_tokens, 4);
1708 }
1709 AgentStreamEvent::ToolCallStarted { .. }
1710 | AgentStreamEvent::ToolCallCompleted { .. }
1711 | AgentStreamEvent::Error { .. } => {}
1712 }
1713 }
1714
1715 assert!(saw_text, "expected at least one text delta");
1716 assert!(saw_finished, "expected a finished event");
1717 }
1718
1719 #[tokio::test]
1720 async fn tc10_skills_prompt_context_is_injected() {
1721 let llm = InspectingLlm { expected_substring: "Skill: rust-review".to_string() };
1722 let tools = MockToolPort::empty();
1723 let store = MemoryStore::new();
1724 let sink = CollectingSink::new();
1725
1726 let mut req = make_request("review this code");
1727 req.context.system_prompt = Some("Skill: rust-review\nUse strict checks.".to_string());
1728
1729 let result =
1730 run_turn(&llm, &tools, &store, &sink, req, &generous_policy(), "test-model").await;
1731
1732 assert!(result.is_ok(), "run should succeed when skills prompt is injected");
1733 }
1734
1735 #[tokio::test]
1736 async fn tc11_selected_skills_context_emits_event() {
1737 let llm =
1738 SequentialLlm::from_contents(vec![r#"{"type": "final", "content": "looks good"}"#]);
1739 let tools = MockToolPort::empty();
1740 let store = MemoryStore::new();
1741 let sink = CollectingSink::new();
1742
1743 let mut req = make_request("review code");
1744 req.context.selected_skills = vec!["rust-review".to_string(), "security-audit".to_string()];
1745
1746 let result =
1747 run_turn(&llm, &tools, &store, &sink, req, &generous_policy(), "test-model").await;
1748 assert!(result.is_ok(), "run should succeed");
1749
1750 let events = sink.all_events();
1751 assert!(
1752 events.iter().any(|event| matches!(
1753 event,
1754 AgentEvent::SkillsSelected { skill_names }
1755 if skill_names == &vec!["rust-review".to_string(), "security-audit".to_string()]
1756 )),
1757 "skills.selected event should be emitted with context skill names"
1758 );
1759 }
1760
1761 #[tokio::test]
1762 async fn tc12_policy_deny_tool_blocks_execution() {
1763 let llm = SequentialLlm::from_contents(vec![
1764 r#"{"type": "tool_call", "name": "search", "arguments": {"q": "rust"}}"#,
1765 r#"{"type": "final", "content": "done"}"#,
1766 ]);
1767 let tools = NoCallToolPort {
1768 tools: vec![ToolDescriptor {
1769 id: "search".to_string(),
1770 description: "search tool".to_string(),
1771 input_schema: serde_json::json!({"type":"object"}),
1772 source: ToolSource::Local,
1773 }],
1774 };
1775 let store = MemoryStore::new();
1776 let sink = CollectingSink::new();
1777
1778 let mut req = make_request("search rust");
1779 req.context.tool_policy.deny_tools =
1780 vec!["search".to_string(), "local/shell_exec".to_string()];
1781
1782 let result =
1783 run_turn(&llm, &tools, &store, &sink, req, &generous_policy(), "test-model").await;
1784 assert!(
1785 matches!(&result, Ok(AgentRunResult::Finished(_))),
1786 "expected finished response, got {result:?}"
1787 );
1788 let resp = match result {
1789 Ok(AgentRunResult::Finished(r)) => r,
1790 _ => return,
1791 };
1792
1793 assert_eq!(resp.finish_reason, FinishReason::Stop);
1794 assert_eq!(resp.tool_transcript.len(), 1);
1795 assert!(resp.tool_transcript[0].is_error);
1796 assert!(
1797 resp.tool_transcript[0].output.to_string().contains("denied"),
1798 "tool error should explain policy denial"
1799 );
1800 }
1801
1802 #[tokio::test]
1803 async fn tc13_approval_denied_blocks_execution() {
1804 let llm = SequentialLlm::from_contents(vec![
1805 r#"{"type": "tool_call", "name": "search", "arguments": {"q": "rust"}}"#,
1806 r#"{"type": "final", "content": "done"}"#,
1807 ]);
1808 let tools = NoCallToolPort {
1809 tools: vec![ToolDescriptor {
1810 id: "search".to_string(),
1811 description: "search tool".to_string(),
1812 input_schema: serde_json::json!({"type":"object"}),
1813 source: ToolSource::Local,
1814 }],
1815 };
1816 let store = MemoryStore::new();
1817 let sink = CollectingSink::new();
1818 let req = make_request("search rust");
1819 let tool_policy = AllowAllPolicyPort;
1820 let approval = AlwaysDenyApprovalPort;
1821
1822 let result = run_turn_with_controls(
1823 &llm,
1824 &tools,
1825 &store,
1826 &sink,
1827 req,
1828 &generous_policy(),
1829 "test-model",
1830 &tool_policy,
1831 &approval,
1832 )
1833 .await;
1834 assert!(matches!(&result, Ok(AgentRunResult::Finished(_))), "unexpected result {result:?}");
1835 let resp = match result {
1836 Ok(AgentRunResult::Finished(r)) => r,
1837 _ => return,
1838 };
1839
1840 assert_eq!(resp.tool_transcript.len(), 1);
1841 assert!(resp.tool_transcript[0].is_error);
1842 assert!(
1843 resp.tool_transcript[0].output.to_string().contains("approval policy rejected"),
1844 "tool error should explain approval denial"
1845 );
1846 }
1847
1848 #[tokio::test]
1849 async fn tc14_custom_policy_port_blocks_execution() {
1850 let llm = SequentialLlm::from_contents(vec![
1851 r#"{"type": "tool_call", "name": "search", "arguments": {"q": "rust"}}"#,
1852 r#"{"type": "final", "content": "done"}"#,
1853 ]);
1854 let tools = NoCallToolPort {
1855 tools: vec![ToolDescriptor {
1856 id: "search".to_string(),
1857 description: "search tool".to_string(),
1858 input_schema: serde_json::json!({"type":"object"}),
1859 source: ToolSource::Local,
1860 }],
1861 };
1862 let store = MemoryStore::new();
1863 let sink = CollectingSink::new();
1864 let req = make_request("search rust");
1865 let tool_policy = DenySearchPolicyPort;
1866 let approval = AlwaysApprovePort;
1867
1868 let result = run_turn_with_controls(
1869 &llm,
1870 &tools,
1871 &store,
1872 &sink,
1873 req,
1874 &generous_policy(),
1875 "test-model",
1876 &tool_policy,
1877 &approval,
1878 )
1879 .await;
1880 assert!(matches!(&result, Ok(AgentRunResult::Finished(_))), "unexpected result {result:?}");
1881 let resp = match result {
1882 Ok(AgentRunResult::Finished(r)) => r,
1883 _ => return,
1884 };
1885
1886 assert_eq!(resp.tool_transcript.len(), 1);
1887 assert!(resp.tool_transcript[0].is_error);
1888 assert!(
1889 resp.tool_transcript[0].output.to_string().contains("denied"),
1890 "tool error should explain policy denial"
1891 );
1892 }
1893
1894 #[tokio::test]
1895 async fn tc15_native_dispatch_mode_uses_llm_tool_calls() {
1896 struct NativeToolLlm {
1897 responses: Mutex<VecDeque<LlmResponse>>,
1898 }
1899
1900 #[async_trait::async_trait]
1901 impl LlmPort for NativeToolLlm {
1902 fn capabilities(&self) -> bob_core::types::LlmCapabilities {
1903 bob_core::types::LlmCapabilities { native_tool_calling: true, streaming: true }
1904 }
1905
1906 async fn complete(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
1907 let mut q = self.responses.lock().unwrap_or_else(|p| p.into_inner());
1908 Ok(q.pop_front().unwrap_or(LlmResponse {
1909 content: r#"{"type":"final","content":"fallback"}"#.to_string(),
1910 usage: TokenUsage::default(),
1911 finish_reason: FinishReason::Stop,
1912 tool_calls: Vec::new(),
1913 }))
1914 }
1915
1916 async fn complete_stream(&self, _req: LlmRequest) -> Result<LlmStream, LlmError> {
1917 Err(LlmError::Provider("not used".into()))
1918 }
1919 }
1920
1921 let llm = NativeToolLlm {
1922 responses: Mutex::new(VecDeque::from(vec![
1923 LlmResponse {
1924 content: "ignored".to_string(),
1925 usage: TokenUsage::default(),
1926 finish_reason: FinishReason::Stop,
1927 tool_calls: vec![ToolCall {
1928 name: "search".to_string(),
1929 arguments: serde_json::json!({"q":"rust"}),
1930 }],
1931 },
1932 LlmResponse {
1933 content: r#"{"type":"final","content":"done"}"#.to_string(),
1934 usage: TokenUsage::default(),
1935 finish_reason: FinishReason::Stop,
1936 tool_calls: Vec::new(),
1937 },
1938 ])),
1939 };
1940 let tools = MockToolPort::with_tool_and_results(
1941 "search",
1942 vec![Ok(ToolResult {
1943 name: "search".to_string(),
1944 output: serde_json::json!({"hits": 2}),
1945 is_error: false,
1946 })],
1947 );
1948 let store = MemoryStore::new();
1949 let sink = CollectingSink::new();
1950 let checkpoint = CountingCheckpointPort::new();
1951 let artifacts = NoopArtifactStore;
1952 let cost = CountingCostMeter::new();
1953 let policy = AllowAllPolicyPort;
1954 let approval = AlwaysApprovePort;
1955
1956 let result = run_turn_with_extensions(
1957 &llm,
1958 &tools,
1959 &store,
1960 &sink,
1961 make_request("search rust"),
1962 &generous_policy(),
1963 "test-model",
1964 &policy,
1965 &approval,
1966 crate::DispatchMode::NativePreferred,
1967 &checkpoint,
1968 &artifacts,
1969 &cost,
1970 )
1971 .await;
1972
1973 assert!(matches!(&result, Ok(AgentRunResult::Finished(_))), "unexpected {result:?}");
1974 let resp = match result {
1975 Ok(AgentRunResult::Finished(r)) => r,
1976 _ => return,
1977 };
1978 assert_eq!(resp.tool_transcript.len(), 1);
1979 assert_eq!(resp.tool_transcript[0].name, "search");
1980 }
1981
1982 #[tokio::test]
1983 async fn tc16_checkpoint_and_cost_ports_are_invoked() {
1984 let llm = SequentialLlm::from_contents(vec![r#"{"type": "final", "content": "ok"}"#]);
1985 let tools = MockToolPort::empty();
1986 let store = MemoryStore::new();
1987 let sink = CollectingSink::new();
1988 let checkpoint = CountingCheckpointPort::new();
1989 let artifacts = NoopArtifactStore;
1990 let cost = CountingCostMeter::new();
1991 let policy = AllowAllPolicyPort;
1992 let approval = AlwaysApprovePort;
1993
1994 let result = run_turn_with_extensions(
1995 &llm,
1996 &tools,
1997 &store,
1998 &sink,
1999 make_request("hello"),
2000 &generous_policy(),
2001 "test-model",
2002 &policy,
2003 &approval,
2004 crate::DispatchMode::PromptGuided,
2005 &checkpoint,
2006 &artifacts,
2007 &cost,
2008 )
2009 .await;
2010 assert!(result.is_ok(), "turn should succeed");
2011 let checkpoints = checkpoint.saved.lock().unwrap_or_else(|p| p.into_inner()).len();
2012 let llm_calls = *cost.llm_calls.lock().unwrap_or_else(|p| p.into_inner());
2013 assert!(checkpoints >= 1, "checkpoint port should be invoked at least once");
2014 assert!(llm_calls >= 1, "cost meter should record llm usage");
2015 }
2016}