1use std::sync::Arc;
8
9use rig::client::{CompletionClient, ProviderClient};
10use rig::providers::{anthropic, openai};
11use serde_json;
12
13use crate::error::NikaError;
14use crate::event::{AgentTurnMetadata, EventKind};
15
16use crate::ast::limits::LimitType;
17
18use super::types::{RigAgentLoopResult, RigAgentStatus};
19use super::RigAgentLoop;
20
21impl RigAgentLoop {
22 pub async fn run_mock(&self) -> Result<RigAgentLoopResult, NikaError> {
26 self.event_log.emit(EventKind::AgentTurn {
28 task_id: Arc::from(self.task_id.as_str()),
29 turn_index: 1,
30 kind: "started".to_string(),
31 metadata: None,
32 });
33
34 let response_text = "Mock response from rig agent".to_string();
36 let final_output = serde_json::json!({
37 "response": &response_text,
38 "completed": true
39 });
40
41 let status = self.determine_status(&final_output.to_string());
43
44 let stop_reason = status.as_canonical_str();
46 let metadata = AgentTurnMetadata {
47 thinking: None, response_text: response_text.clone(),
49 input_tokens: 50,
50 output_tokens: 50,
51 cache_read_tokens: 0,
52 stop_reason: stop_reason.to_string(),
53 };
54
55 self.event_log.emit(EventKind::AgentTurn {
57 task_id: Arc::from(self.task_id.as_str()),
58 turn_index: 1,
59 kind: stop_reason.to_string(),
60 metadata: Some(metadata),
61 });
62
63 let guardrail_result = self.check_guardrails(&response_text);
65 let guardrails_passed = guardrail_result.is_passed();
66
67 Ok(RigAgentLoopResult {
68 status: status.clone(),
69 turns: 1,
70 final_output,
71 total_tokens: 100, confidence: status.confidence(),
73 retry_count: 0,
74 guardrails_passed,
75 cost_usd: 0.0,
76 partial_result: None,
77 })
78 }
79
80 pub async fn run_claude(&mut self) -> Result<RigAgentLoopResult, NikaError> {
102 if self.params.extended_thinking == Some(true) {
104 return self.run_claude_with_thinking().await;
105 }
106
107 let client = anthropic::Client::from_env();
109
110 let raw_model = self
112 .params
113 .model
114 .clone()
115 .ok_or_else(|| NikaError::ValidationError {
116 reason: "model field is required for LLM verbs (NIKA-034)".to_string(),
117 })?;
118 let model_name = Self::strip_model_prefix(&raw_model).to_string();
119 let model = client.completion_model(&model_name);
120
121 let tools = self.tools_as_boxed();
123
124 let max_turns = self.params.max_turns.unwrap_or(10) as usize;
126 let max_retries = self
127 .get_low_confidence_config()
128 .map(|c| c.max_retries)
129 .unwrap_or(2);
130 let base_prompt = self.params.prompt.clone();
131
132 let mut retry_count: u32 = 0;
133 let mut current_prompt = base_prompt.clone();
134 let mut total_input_tokens: u64 = 0;
135 let mut total_output_tokens: u64 = 0;
136 let mut total_cached_input_tokens: u64 = 0;
137
138 self.event_log.emit(EventKind::AgentTurn {
140 task_id: Arc::from(self.task_id.as_str()),
141 turn_index: 1,
142 kind: "started".to_string(),
143 metadata: None,
144 });
145
146 let mut result = self
148 .stream_with_tools(model.clone(), ¤t_prompt, tools, max_turns)
149 .await?;
150
151 total_input_tokens += result.input_tokens;
152 total_output_tokens += result.output_tokens;
153 total_cached_input_tokens += result.cached_input_tokens;
154
155 let cost = crate::provider::cost::calculate_cost_with_cache(
157 crate::provider::cost::ProviderKind::Claude,
158 &model_name,
159 result.input_tokens,
160 result.output_tokens,
161 result.cached_input_tokens,
162 );
163 self.limit_tracker
164 .record_turn(result.input_tokens, result.output_tokens, cost);
165
166 if let Some(exceeded) = self.limit_tracker.check_limits() {
168 let status = match exceeded.limit_type {
169 LimitType::Turns => RigAgentStatus::MaxTurnsReached,
170 LimitType::Tokens => RigAgentStatus::TokenBudgetExceeded,
171 LimitType::Cost => RigAgentStatus::CostLimitReached,
172 LimitType::Duration => RigAgentStatus::DurationLimitReached,
173 };
174 tracing::warn!(
175 task_id = %self.task_id,
176 limit = %exceeded.limit_type,
177 current = exceeded.current,
178 maximum = exceeded.maximum,
179 "Claude agent limit exceeded after first turn"
180 );
181 return Ok(RigAgentLoopResult {
182 status,
183 turns: 1,
184 final_output: serde_json::json!({ "response": result.response }),
185 total_tokens: total_input_tokens + total_output_tokens,
186 confidence: None,
187 retry_count: 0,
188 guardrails_passed: true,
189 cost_usd: self.limit_tracker.cost_usd(),
190 partial_result: None,
191 });
192 }
193
194 let mut status = self.determine_status(&result.response);
195
196 while self.should_retry(&status, retry_count) {
198 retry_count += 1;
199
200 if let Some(exceeded) = self.limit_tracker.check_limits() {
202 let limit_status = match exceeded.limit_type {
203 LimitType::Turns => RigAgentStatus::MaxTurnsReached,
204 LimitType::Tokens => RigAgentStatus::TokenBudgetExceeded,
205 LimitType::Cost => RigAgentStatus::CostLimitReached,
206 LimitType::Duration => RigAgentStatus::DurationLimitReached,
207 };
208 tracing::warn!(
209 task_id = %self.task_id,
210 limit = %exceeded.limit_type,
211 retry = retry_count,
212 "Claude agent limit exceeded during retry loop"
213 );
214 status = limit_status;
215 break;
216 }
217
218 let confidence = match &status {
220 RigAgentStatus::LowConfidence(c) => *c,
221 _ => 0.0,
222 };
223
224 self.event_log.emit(EventKind::AgentTurn {
226 task_id: Arc::from(self.task_id.as_str()),
227 turn_index: retry_count + 1,
228 kind: format!("retry_{}", retry_count),
229 metadata: Some(AgentTurnMetadata {
230 thinking: None,
231 response_text: format!(
232 "Low confidence ({:.2}), retrying ({}/{})",
233 confidence, retry_count, max_retries
234 ),
235 input_tokens: 0,
236 output_tokens: 0,
237 cache_read_tokens: 0,
238 stop_reason: "low_confidence_retry".to_string(),
239 }),
240 });
241
242 current_prompt = format!(
244 "{}\n\n{}\n\nPrevious response:\n{}",
245 base_prompt,
246 self.get_retry_feedback(confidence),
247 result.response
248 );
249
250 result = self
252 .stream_with_tools(model.clone(), ¤t_prompt, vec![], max_turns)
253 .await?;
254
255 total_input_tokens += result.input_tokens;
256 total_output_tokens += result.output_tokens;
257 total_cached_input_tokens += result.cached_input_tokens;
258
259 let retry_cost = crate::provider::cost::calculate_cost_with_cache(
261 crate::provider::cost::ProviderKind::Claude,
262 &model_name,
263 result.input_tokens,
264 result.output_tokens,
265 result.cached_input_tokens,
266 );
267 self.limit_tracker
268 .record_turn(result.input_tokens, result.output_tokens, retry_cost);
269
270 status = self.determine_status(&result.response);
271 }
272
273 let stop_reason = status.as_canonical_str();
275 let metadata = AgentTurnMetadata {
276 thinking: result.thinking,
277 response_text: result.response.clone(),
278 input_tokens: total_input_tokens,
279 output_tokens: total_output_tokens,
280 cache_read_tokens: total_cached_input_tokens,
281 stop_reason: stop_reason.to_string(),
282 };
283
284 self.event_log.emit(EventKind::AgentTurn {
286 task_id: Arc::from(self.task_id.as_str()),
287 turn_index: retry_count + 1,
288 kind: stop_reason.to_string(),
289 metadata: Some(metadata),
290 });
291
292 let max_guardrail_retries: u32 = 2;
294 let mut guardrail_retry_count: u32 = 0;
295 let mut guardrail_result = self.check_guardrails(&result.response);
296
297 while guardrail_result.should_retry() && guardrail_retry_count < max_guardrail_retries {
298 guardrail_retry_count += 1;
299
300 if let Some(exceeded) = self.limit_tracker.check_limits() {
302 tracing::warn!(
303 task_id = %self.task_id,
304 limit = %exceeded.limit_type,
305 guardrail_retry = guardrail_retry_count,
306 "Claude agent limit exceeded during guardrail retry loop"
307 );
308 break;
309 }
310
311 let feedback = guardrail_result.failure_messages().join("; ");
313 tracing::info!(
314 task_id = %self.task_id,
315 guardrail_retry = guardrail_retry_count,
316 max = max_guardrail_retries,
317 feedback = %feedback,
318 "Retrying Claude due to guardrail failure"
319 );
320
321 self.event_log.emit(EventKind::AgentTurn {
323 task_id: Arc::from(self.task_id.as_str()),
324 turn_index: retry_count + guardrail_retry_count + 1,
325 kind: format!("guardrail_retry_{}", guardrail_retry_count),
326 metadata: Some(AgentTurnMetadata {
327 thinking: None,
328 response_text: format!(
329 "Guardrail validation failed, retrying ({}/{}): {}",
330 guardrail_retry_count, max_guardrail_retries, feedback
331 ),
332 input_tokens: 0,
333 output_tokens: 0,
334 cache_read_tokens: 0,
335 stop_reason: "guardrail_retry".to_string(),
336 }),
337 });
338
339 current_prompt = format!(
341 "{}\n\n[GUARDRAIL RETRY {}/{}] Your previous output failed quality validation:\n{}\n\nPlease fix these issues and try again.\n\nPrevious response:\n{}",
342 base_prompt,
343 guardrail_retry_count,
344 max_guardrail_retries,
345 feedback,
346 result.response
347 );
348
349 result = self
351 .stream_with_tools(model.clone(), ¤t_prompt, vec![], max_turns)
352 .await?;
353
354 total_input_tokens += result.input_tokens;
355 total_output_tokens += result.output_tokens;
356 total_cached_input_tokens += result.cached_input_tokens;
357
358 let gr_cost = crate::provider::cost::calculate_cost_with_cache(
360 crate::provider::cost::ProviderKind::Claude,
361 &model_name,
362 result.input_tokens,
363 result.output_tokens,
364 result.cached_input_tokens,
365 );
366 self.limit_tracker
367 .record_turn(result.input_tokens, result.output_tokens, gr_cost);
368
369 status = self.determine_status(&result.response);
371 guardrail_result = self.check_guardrails(&result.response);
372 }
373
374 if guardrail_result.should_retry() {
376 tracing::warn!(
377 task_id = %self.task_id,
378 retries = guardrail_retry_count,
379 "Claude guardrail retries exhausted, accepting output with guardrails_passed=false"
380 );
381 }
382
383 let guardrails_passed = guardrail_result.is_passed();
384
385 let status = if guardrail_result.should_fail() {
387 RigAgentStatus::Failed
388 } else if guardrail_result.should_escalate() {
389 RigAgentStatus::Escalated(status.confidence().unwrap_or(0.0))
390 } else {
391 status
392 };
393
394 let total_retries = retry_count + guardrail_retry_count;
395
396 let total_cost = crate::provider::cost::calculate_cost_with_cache(
398 crate::provider::cost::ProviderKind::Claude,
399 &model_name,
400 total_input_tokens,
401 total_output_tokens,
402 total_cached_input_tokens,
403 );
404 self.event_log.emit(EventKind::ProviderResponded {
405 task_id: Arc::from(self.task_id.as_str()),
406 request_id: None,
407 input_tokens: total_input_tokens,
408 output_tokens: total_output_tokens,
409 cache_read_tokens: total_cached_input_tokens,
410 ttft_ms: None,
411 finish_reason: stop_reason.to_string(),
412 cost_usd: if total_cost.is_finite() {
413 total_cost
414 } else {
415 0.0
416 },
417 });
418
419 Ok(RigAgentLoopResult {
420 status: status.clone(),
421 turns: (total_retries + 1) as usize,
422 final_output: serde_json::json!({ "response": result.response }),
423 total_tokens: total_input_tokens + total_output_tokens,
424 confidence: status.confidence(),
425 retry_count: total_retries,
426 guardrails_passed,
427 cost_usd: self.limit_tracker.cost_usd(),
428 partial_result: None,
429 })
430 }
431
432 pub async fn run_openai(&mut self) -> Result<RigAgentLoopResult, NikaError> {
444 let client = openai::Client::from_env();
446
447 let raw_model = self
449 .params
450 .model
451 .clone()
452 .ok_or_else(|| NikaError::ValidationError {
453 reason: "model field is required for LLM verbs (NIKA-034)".to_string(),
454 })?;
455 let model_name = Self::strip_model_prefix(&raw_model).to_string();
456 let model = client.completion_model(&model_name);
457
458 let tools = self.tools_as_boxed();
460
461 let max_turns = self.params.max_turns.unwrap_or(10) as usize;
463 let max_retries = self
464 .get_low_confidence_config()
465 .map(|c| c.max_retries)
466 .unwrap_or(2);
467 let base_prompt = self.params.prompt.clone();
468
469 let mut retry_count: u32 = 0;
470 let mut current_prompt = base_prompt.clone();
471 let mut total_input_tokens: u64 = 0;
472 let mut total_output_tokens: u64 = 0;
473 let mut total_cached_input_tokens: u64 = 0;
474
475 self.event_log.emit(EventKind::AgentTurn {
477 task_id: Arc::from(self.task_id.as_str()),
478 turn_index: 1,
479 kind: "started".to_string(),
480 metadata: None,
481 });
482
483 let mut result = self
485 .stream_with_tools(model.clone(), ¤t_prompt, tools, max_turns)
486 .await?;
487
488 total_input_tokens += result.input_tokens;
489 total_output_tokens += result.output_tokens;
490 total_cached_input_tokens += result.cached_input_tokens;
491
492 let cost = crate::provider::cost::calculate_cost_with_cache(
494 crate::provider::cost::ProviderKind::OpenAI,
495 &model_name,
496 result.input_tokens,
497 result.output_tokens,
498 result.cached_input_tokens,
499 );
500 self.limit_tracker
501 .record_turn(result.input_tokens, result.output_tokens, cost);
502
503 if let Some(exceeded) = self.limit_tracker.check_limits() {
505 let status = match exceeded.limit_type {
506 LimitType::Turns => RigAgentStatus::MaxTurnsReached,
507 LimitType::Tokens => RigAgentStatus::TokenBudgetExceeded,
508 LimitType::Cost => RigAgentStatus::CostLimitReached,
509 LimitType::Duration => RigAgentStatus::DurationLimitReached,
510 };
511 tracing::warn!(
512 task_id = %self.task_id,
513 limit = %exceeded.limit_type,
514 current = exceeded.current,
515 maximum = exceeded.maximum,
516 "OpenAI agent limit exceeded after first turn"
517 );
518 return Ok(RigAgentLoopResult {
519 status,
520 turns: 1,
521 final_output: serde_json::json!({ "response": result.response }),
522 total_tokens: total_input_tokens + total_output_tokens,
523 confidence: None,
524 retry_count: 0,
525 guardrails_passed: true,
526 cost_usd: self.limit_tracker.cost_usd(),
527 partial_result: None,
528 });
529 }
530
531 let mut status = self.determine_status(&result.response);
532
533 while self.should_retry(&status, retry_count) {
535 retry_count += 1;
536
537 if let Some(exceeded) = self.limit_tracker.check_limits() {
539 let limit_status = match exceeded.limit_type {
540 LimitType::Turns => RigAgentStatus::MaxTurnsReached,
541 LimitType::Tokens => RigAgentStatus::TokenBudgetExceeded,
542 LimitType::Cost => RigAgentStatus::CostLimitReached,
543 LimitType::Duration => RigAgentStatus::DurationLimitReached,
544 };
545 tracing::warn!(
546 task_id = %self.task_id,
547 limit = %exceeded.limit_type,
548 retry = retry_count,
549 "OpenAI agent limit exceeded during retry loop"
550 );
551 status = limit_status;
552 break;
553 }
554
555 let confidence = match &status {
557 RigAgentStatus::LowConfidence(c) => *c,
558 _ => 0.0,
559 };
560
561 self.event_log.emit(EventKind::AgentTurn {
563 task_id: Arc::from(self.task_id.as_str()),
564 turn_index: retry_count + 1,
565 kind: format!("retry_{}", retry_count),
566 metadata: Some(AgentTurnMetadata {
567 thinking: None,
568 response_text: format!(
569 "Low confidence ({:.2}), retrying ({}/{})",
570 confidence, retry_count, max_retries
571 ),
572 input_tokens: 0,
573 output_tokens: 0,
574 cache_read_tokens: 0,
575 stop_reason: "low_confidence_retry".to_string(),
576 }),
577 });
578
579 current_prompt = format!(
581 "{}\n\n{}\n\nPrevious response:\n{}",
582 base_prompt,
583 self.get_retry_feedback(confidence),
584 result.response
585 );
586
587 result = self
589 .stream_with_tools(model.clone(), ¤t_prompt, vec![], max_turns)
590 .await?;
591
592 total_input_tokens += result.input_tokens;
593 total_output_tokens += result.output_tokens;
594 total_cached_input_tokens += result.cached_input_tokens;
595
596 let retry_cost = crate::provider::cost::calculate_cost_with_cache(
598 crate::provider::cost::ProviderKind::OpenAI,
599 &model_name,
600 result.input_tokens,
601 result.output_tokens,
602 result.cached_input_tokens,
603 );
604 self.limit_tracker
605 .record_turn(result.input_tokens, result.output_tokens, retry_cost);
606
607 status = self.determine_status(&result.response);
608 }
609
610 let stop_reason = status.as_canonical_str();
612 let metadata = AgentTurnMetadata {
613 thinking: result.thinking,
614 response_text: result.response.clone(),
615 input_tokens: total_input_tokens,
616 output_tokens: total_output_tokens,
617 cache_read_tokens: total_cached_input_tokens,
618 stop_reason: stop_reason.to_string(),
619 };
620
621 self.event_log.emit(EventKind::AgentTurn {
622 task_id: Arc::from(self.task_id.as_str()),
623 turn_index: retry_count + 1,
624 kind: stop_reason.to_string(),
625 metadata: Some(metadata),
626 });
627
628 let max_guardrail_retries: u32 = 2;
630 let mut guardrail_retry_count: u32 = 0;
631 let mut guardrail_result = self.check_guardrails(&result.response);
632
633 while guardrail_result.should_retry() && guardrail_retry_count < max_guardrail_retries {
634 guardrail_retry_count += 1;
635
636 if let Some(exceeded) = self.limit_tracker.check_limits() {
638 tracing::warn!(
639 task_id = %self.task_id,
640 limit = %exceeded.limit_type,
641 guardrail_retry = guardrail_retry_count,
642 "OpenAI agent limit exceeded during guardrail retry loop"
643 );
644 break;
645 }
646
647 let feedback = guardrail_result.failure_messages().join("; ");
649 tracing::info!(
650 task_id = %self.task_id,
651 guardrail_retry = guardrail_retry_count,
652 max = max_guardrail_retries,
653 feedback = %feedback,
654 "Retrying OpenAI due to guardrail failure"
655 );
656
657 self.event_log.emit(EventKind::AgentTurn {
659 task_id: Arc::from(self.task_id.as_str()),
660 turn_index: retry_count + guardrail_retry_count + 1,
661 kind: format!("guardrail_retry_{}", guardrail_retry_count),
662 metadata: Some(AgentTurnMetadata {
663 thinking: None,
664 response_text: format!(
665 "Guardrail validation failed, retrying ({}/{}): {}",
666 guardrail_retry_count, max_guardrail_retries, feedback
667 ),
668 input_tokens: 0,
669 output_tokens: 0,
670 cache_read_tokens: 0,
671 stop_reason: "guardrail_retry".to_string(),
672 }),
673 });
674
675 current_prompt = format!(
677 "{}\n\n[GUARDRAIL RETRY {}/{}] Your previous output failed quality validation:\n{}\n\nPlease fix these issues and try again.\n\nPrevious response:\n{}",
678 base_prompt,
679 guardrail_retry_count,
680 max_guardrail_retries,
681 feedback,
682 result.response
683 );
684
685 result = self
687 .stream_with_tools(model.clone(), ¤t_prompt, vec![], max_turns)
688 .await?;
689
690 total_input_tokens += result.input_tokens;
691 total_output_tokens += result.output_tokens;
692 total_cached_input_tokens += result.cached_input_tokens;
693
694 let gr_cost = crate::provider::cost::calculate_cost_with_cache(
696 crate::provider::cost::ProviderKind::OpenAI,
697 &model_name,
698 result.input_tokens,
699 result.output_tokens,
700 result.cached_input_tokens,
701 );
702 self.limit_tracker
703 .record_turn(result.input_tokens, result.output_tokens, gr_cost);
704
705 status = self.determine_status(&result.response);
707 guardrail_result = self.check_guardrails(&result.response);
708 }
709
710 if guardrail_result.should_retry() {
712 tracing::warn!(
713 task_id = %self.task_id,
714 retries = guardrail_retry_count,
715 "OpenAI guardrail retries exhausted, accepting output with guardrails_passed=false"
716 );
717 }
718
719 let guardrails_passed = guardrail_result.is_passed();
720
721 let status = if guardrail_result.should_fail() {
723 RigAgentStatus::Failed
724 } else if guardrail_result.should_escalate() {
725 RigAgentStatus::Escalated(status.confidence().unwrap_or(0.0))
726 } else {
727 status
728 };
729
730 let total_retries = retry_count + guardrail_retry_count;
731
732 let total_cost = crate::provider::cost::calculate_cost_with_cache(
734 crate::provider::cost::ProviderKind::OpenAI,
735 &model_name,
736 total_input_tokens,
737 total_output_tokens,
738 total_cached_input_tokens,
739 );
740 self.event_log.emit(EventKind::ProviderResponded {
741 task_id: Arc::from(self.task_id.as_str()),
742 request_id: None,
743 input_tokens: total_input_tokens,
744 output_tokens: total_output_tokens,
745 cache_read_tokens: total_cached_input_tokens,
746 ttft_ms: None,
747 finish_reason: stop_reason.to_string(),
748 cost_usd: if total_cost.is_finite() {
749 total_cost
750 } else {
751 0.0
752 },
753 });
754
755 Ok(RigAgentLoopResult {
756 status: status.clone(),
757 turns: (total_retries + 1) as usize,
758 final_output: serde_json::json!({ "response": result.response }),
759 total_tokens: total_input_tokens + total_output_tokens,
760 confidence: status.confidence(),
761 retry_count: total_retries,
762 guardrails_passed,
763 cost_usd: self.limit_tracker.cost_usd(),
764 partial_result: None,
765 })
766 }
767
768 pub async fn run_auto(&mut self) -> Result<RigAgentLoopResult, NikaError> {
782 if let Some(ref provider_name) = self.params.provider {
784 let resolved = crate::core::find_provider(provider_name).ok_or_else(|| {
785 NikaError::AgentValidationError {
786 reason: format!(
787 "Unknown provider: '{}'. Use 'claude', 'openai', 'mistral', 'groq', 'deepseek', 'gemini', or 'xai'.",
788 provider_name
789 ),
790 }
791 })?;
792 return match resolved.id {
793 "anthropic" => self.run_claude().await,
794 "openai" => self.run_openai().await,
795 "mistral" => self.run_mistral().await,
796 "groq" => self.run_groq().await,
797 "deepseek" => self.run_deepseek().await,
798 "gemini" => self.run_gemini().await,
799 "xai" => self.run_xai().await,
800 "native" => Err(NikaError::AgentValidationError {
801 reason: "Provider 'native' is not supported for agent: tasks. Native inference (mistral.rs) is only available for infer: tasks. Use a cloud provider (claude, openai, mistral, groq, deepseek, gemini, xai) for agent tasks.".to_string(),
802 }),
803 _ => Err(NikaError::AgentValidationError {
804 reason: format!("Provider '{}' is not supported for agent: tasks.", resolved.id),
805 }),
806 };
807 }
808
809 use crate::core::providers::{ProviderCategory, KNOWN_PROVIDERS};
811 for p in KNOWN_PROVIDERS.iter() {
812 if p.category == ProviderCategory::Llm && p.has_env_key() {
813 return match p.id {
814 "anthropic" => self.run_claude().await,
815 "openai" => self.run_openai().await,
816 "mistral" => self.run_mistral().await,
817 "groq" => self.run_groq().await,
818 "deepseek" => self.run_deepseek().await,
819 "gemini" => self.run_gemini().await,
820 "xai" => self.run_xai().await,
821 _ => continue,
822 };
823 }
824 }
825
826 Err(NikaError::AgentValidationError {
827 reason: "No API key found. Set one of: ANTHROPIC_API_KEY, OPENAI_API_KEY, MISTRAL_API_KEY, GROQ_API_KEY, DEEPSEEK_API_KEY, GEMINI_API_KEY, or XAI_API_KEY.".to_string(),
828 })
829 }
830
831 pub async fn run_mistral(&mut self) -> Result<RigAgentLoopResult, NikaError> {
837 let model_name = self
838 .params
839 .model
840 .clone()
841 .unwrap_or_else(|| rig::providers::mistral::MISTRAL_LARGE.to_string());
842 let client = rig::providers::mistral::Client::from_env();
843 self.run_generic_provider_impl(
844 client,
845 &model_name,
846 Some(crate::provider::cost::ProviderKind::Mistral),
847 )
848 .await
849 }
850
851 pub async fn run_groq(&mut self) -> Result<RigAgentLoopResult, NikaError> {
853 let model_name = self
854 .params
855 .model
856 .clone()
857 .unwrap_or_else(|| "llama-3.3-70b-versatile".to_string());
858 let client = rig::providers::groq::Client::from_env();
859 self.run_generic_provider_impl(
860 client,
861 &model_name,
862 Some(crate::provider::cost::ProviderKind::Groq),
863 )
864 .await
865 }
866
867 pub async fn run_deepseek(&mut self) -> Result<RigAgentLoopResult, NikaError> {
869 let model_name = self
870 .params
871 .model
872 .clone()
873 .unwrap_or_else(|| "deepseek-chat".to_string());
874 let client = rig::providers::deepseek::Client::from_env();
875 self.run_generic_provider_impl(
876 client,
877 &model_name,
878 Some(crate::provider::cost::ProviderKind::DeepSeek),
879 )
880 .await
881 }
882
883 pub async fn run_gemini(&mut self) -> Result<RigAgentLoopResult, NikaError> {
885 let model_name = self
886 .params
887 .model
888 .clone()
889 .unwrap_or_else(|| "gemini-2.0-flash".to_string());
890 let client = rig::providers::gemini::Client::from_env();
891 self.run_generic_provider_impl(
892 client,
893 &model_name,
894 Some(crate::provider::cost::ProviderKind::Gemini),
895 )
896 .await
897 }
898
899 pub async fn run_xai(&mut self) -> Result<RigAgentLoopResult, NikaError> {
901 let model_name = self
902 .params
903 .model
904 .clone()
905 .unwrap_or_else(|| "grok-3-fast".to_string());
906 let client = rig::providers::xai::Client::from_env();
907 self.run_generic_provider_impl(
908 client,
909 &model_name,
910 Some(crate::provider::cost::ProviderKind::XAi),
911 )
912 .await
913 }
914
915 async fn run_generic_provider_impl<C>(
920 &mut self,
921 client: C,
922 model_name: &str,
923 provider_kind: Option<crate::provider::cost::ProviderKind>,
924 ) -> Result<RigAgentLoopResult, NikaError>
925 where
926 C: CompletionClient,
927 C::CompletionModel: Clone + 'static,
928 <C::CompletionModel as rig::completion::CompletionModel>::Response: Send,
929 {
930 let model_name = Self::strip_model_prefix(model_name);
931 let model = client.completion_model(model_name);
932
933 let tools = self.tools_as_boxed();
935 let max_turns = self.params.max_turns.unwrap_or(10) as usize;
936 let base_prompt = self.params.prompt.clone();
937
938 let max_retries = self
940 .get_low_confidence_config()
941 .map(|c| c.max_retries)
942 .unwrap_or(2);
943
944 let mut retry_count: u32 = 0;
945 let mut current_prompt = base_prompt.clone();
946 let mut total_input_tokens: u64 = 0;
947 let mut total_output_tokens: u64 = 0;
948 let mut total_cached_input_tokens: u64 = 0;
949
950 self.event_log.emit(EventKind::AgentTurn {
952 task_id: Arc::from(self.task_id.as_str()),
953 turn_index: 1,
954 kind: "started".to_string(),
955 metadata: None,
956 });
957
958 let mut result = self
960 .stream_with_tools(model.clone(), ¤t_prompt, tools, max_turns)
961 .await?;
962
963 total_input_tokens += result.input_tokens;
964 total_output_tokens += result.output_tokens;
965 total_cached_input_tokens += result.cached_input_tokens;
966
967 let turn_cost = provider_kind
969 .map(|pk| {
970 crate::provider::cost::calculate_cost_with_cache(
971 pk,
972 model_name,
973 result.input_tokens,
974 result.output_tokens,
975 result.cached_input_tokens,
976 )
977 })
978 .unwrap_or(0.0);
979 self.limit_tracker
980 .record_turn(result.input_tokens, result.output_tokens, turn_cost);
981
982 if let Some(exceeded) = self.limit_tracker.check_limits() {
984 let status = match exceeded.limit_type {
985 LimitType::Turns => RigAgentStatus::MaxTurnsReached,
986 LimitType::Tokens => RigAgentStatus::TokenBudgetExceeded,
987 LimitType::Cost => RigAgentStatus::CostLimitReached,
988 LimitType::Duration => RigAgentStatus::DurationLimitReached,
989 };
990 tracing::warn!(
991 task_id = %self.task_id,
992 limit = %exceeded.limit_type,
993 current = exceeded.current,
994 maximum = exceeded.maximum,
995 "Agent limit exceeded after first turn"
996 );
997 return Ok(RigAgentLoopResult {
998 status,
999 turns: 1,
1000 final_output: serde_json::json!({ "response": result.response }),
1001 total_tokens: total_input_tokens + total_output_tokens,
1002 confidence: None,
1003 retry_count: 0,
1004 guardrails_passed: true,
1005 cost_usd: self.limit_tracker.cost_usd(),
1006 partial_result: None,
1007 });
1008 }
1009
1010 let mut status = self.determine_status(&result.response);
1011
1012 while self.should_retry(&status, retry_count) {
1014 retry_count += 1;
1015
1016 if let Some(exceeded) = self.limit_tracker.check_limits() {
1018 let limit_status = match exceeded.limit_type {
1019 LimitType::Turns => RigAgentStatus::MaxTurnsReached,
1020 LimitType::Tokens => RigAgentStatus::TokenBudgetExceeded,
1021 LimitType::Cost => RigAgentStatus::CostLimitReached,
1022 LimitType::Duration => RigAgentStatus::DurationLimitReached,
1023 };
1024 tracing::warn!(
1025 task_id = %self.task_id,
1026 limit = %exceeded.limit_type,
1027 retry = retry_count,
1028 "Agent limit exceeded during retry loop"
1029 );
1030 status = limit_status;
1031 break;
1032 }
1033
1034 let confidence = match &status {
1036 RigAgentStatus::LowConfidence(c) => *c,
1037 _ => 0.0,
1038 };
1039
1040 self.event_log.emit(EventKind::AgentTurn {
1042 task_id: Arc::from(self.task_id.as_str()),
1043 turn_index: retry_count + 1,
1044 kind: format!("retry_{}", retry_count),
1045 metadata: Some(AgentTurnMetadata {
1046 thinking: None,
1047 response_text: format!(
1048 "Low confidence ({:.2}), retrying ({}/{})",
1049 confidence, retry_count, max_retries
1050 ),
1051 input_tokens: 0,
1052 output_tokens: 0,
1053 cache_read_tokens: 0,
1054 stop_reason: "low_confidence_retry".to_string(),
1055 }),
1056 });
1057
1058 current_prompt = format!(
1060 "{}\n\n{}\n\nPrevious response:\n{}",
1061 base_prompt,
1062 self.get_retry_feedback(confidence),
1063 result.response
1064 );
1065
1066 result = self
1069 .stream_with_tools(model.clone(), ¤t_prompt, vec![], max_turns)
1070 .await?;
1071
1072 total_input_tokens += result.input_tokens;
1073 total_output_tokens += result.output_tokens;
1074 total_cached_input_tokens += result.cached_input_tokens;
1075
1076 let retry_cost = provider_kind
1078 .map(|pk| {
1079 crate::provider::cost::calculate_cost_with_cache(
1080 pk,
1081 model_name,
1082 result.input_tokens,
1083 result.output_tokens,
1084 result.cached_input_tokens,
1085 )
1086 })
1087 .unwrap_or(0.0);
1088 self.limit_tracker
1089 .record_turn(result.input_tokens, result.output_tokens, retry_cost);
1090
1091 status = self.determine_status(&result.response);
1092 }
1093
1094 let stop_reason = status.as_canonical_str();
1096 let metadata = AgentTurnMetadata {
1097 thinking: result.thinking,
1098 response_text: result.response.clone(),
1099 input_tokens: total_input_tokens,
1100 output_tokens: total_output_tokens,
1101 cache_read_tokens: total_cached_input_tokens,
1102 stop_reason: stop_reason.to_string(),
1103 };
1104
1105 self.event_log.emit(EventKind::AgentTurn {
1106 task_id: Arc::from(self.task_id.as_str()),
1107 turn_index: retry_count + 1,
1108 kind: stop_reason.to_string(),
1109 metadata: Some(metadata),
1110 });
1111
1112 let max_guardrail_retries: u32 = 2;
1114 let mut guardrail_retry_count: u32 = 0;
1115 let mut guardrail_result = self.check_guardrails(&result.response);
1116
1117 while guardrail_result.should_retry() && guardrail_retry_count < max_guardrail_retries {
1118 guardrail_retry_count += 1;
1119
1120 if let Some(exceeded) = self.limit_tracker.check_limits() {
1122 tracing::warn!(
1123 task_id = %self.task_id,
1124 limit = %exceeded.limit_type,
1125 guardrail_retry = guardrail_retry_count,
1126 "Agent limit exceeded during guardrail retry loop"
1127 );
1128 break;
1129 }
1130
1131 let feedback = guardrail_result.failure_messages().join("; ");
1133 tracing::info!(
1134 task_id = %self.task_id,
1135 guardrail_retry = guardrail_retry_count,
1136 max = max_guardrail_retries,
1137 feedback = %feedback,
1138 "Retrying due to guardrail failure"
1139 );
1140
1141 self.event_log.emit(EventKind::AgentTurn {
1143 task_id: Arc::from(self.task_id.as_str()),
1144 turn_index: retry_count + guardrail_retry_count + 1,
1145 kind: format!("guardrail_retry_{}", guardrail_retry_count),
1146 metadata: Some(AgentTurnMetadata {
1147 thinking: None,
1148 response_text: format!(
1149 "Guardrail validation failed, retrying ({}/{}): {}",
1150 guardrail_retry_count, max_guardrail_retries, feedback
1151 ),
1152 input_tokens: 0,
1153 output_tokens: 0,
1154 cache_read_tokens: 0,
1155 stop_reason: "guardrail_retry".to_string(),
1156 }),
1157 });
1158
1159 current_prompt = format!(
1161 "{}\n\n[GUARDRAIL RETRY {}/{}] Your previous output failed quality validation:\n{}\n\nPlease fix these issues and try again.\n\nPrevious response:\n{}",
1162 base_prompt,
1163 guardrail_retry_count,
1164 max_guardrail_retries,
1165 feedback,
1166 result.response
1167 );
1168
1169 result = self
1171 .stream_with_tools(model.clone(), ¤t_prompt, vec![], max_turns)
1172 .await?;
1173
1174 total_input_tokens += result.input_tokens;
1175 total_output_tokens += result.output_tokens;
1176 total_cached_input_tokens += result.cached_input_tokens;
1177
1178 let gr_cost = provider_kind
1180 .map(|pk| {
1181 crate::provider::cost::calculate_cost_with_cache(
1182 pk,
1183 model_name,
1184 result.input_tokens,
1185 result.output_tokens,
1186 result.cached_input_tokens,
1187 )
1188 })
1189 .unwrap_or(0.0);
1190 self.limit_tracker
1191 .record_turn(result.input_tokens, result.output_tokens, gr_cost);
1192
1193 status = self.determine_status(&result.response);
1195 guardrail_result = self.check_guardrails(&result.response);
1196 }
1197
1198 if guardrail_result.should_retry() {
1201 tracing::warn!(
1202 task_id = %self.task_id,
1203 retries = guardrail_retry_count,
1204 "Guardrail retries exhausted, accepting output with guardrails_passed=false"
1205 );
1206 }
1207
1208 let guardrails_passed = guardrail_result.is_passed();
1209
1210 let status = if guardrail_result.should_fail() {
1212 RigAgentStatus::Failed
1213 } else if guardrail_result.should_escalate() {
1214 RigAgentStatus::Escalated(status.confidence().unwrap_or(0.0))
1215 } else {
1216 status
1217 };
1218
1219 let total_retries = retry_count + guardrail_retry_count;
1220
1221 let total_cost = provider_kind
1223 .map(|pk| {
1224 crate::provider::cost::calculate_cost_with_cache(
1225 pk,
1226 model_name,
1227 total_input_tokens,
1228 total_output_tokens,
1229 total_cached_input_tokens,
1230 )
1231 })
1232 .unwrap_or(0.0);
1233 self.event_log.emit(EventKind::ProviderResponded {
1234 task_id: Arc::from(self.task_id.as_str()),
1235 request_id: None,
1236 input_tokens: total_input_tokens,
1237 output_tokens: total_output_tokens,
1238 cache_read_tokens: total_cached_input_tokens,
1239 ttft_ms: None,
1240 finish_reason: stop_reason.to_string(),
1241 cost_usd: if total_cost.is_finite() {
1242 total_cost
1243 } else {
1244 0.0
1245 },
1246 });
1247
1248 Ok(RigAgentLoopResult {
1249 status: status.clone(),
1250 turns: (total_retries + 1) as usize,
1251 final_output: serde_json::json!({ "response": result.response }),
1252 total_tokens: total_input_tokens + total_output_tokens,
1253 confidence: status.confidence(),
1254 retry_count: total_retries,
1255 guardrails_passed,
1256 cost_usd: self.limit_tracker.cost_usd(),
1257 partial_result: None,
1258 })
1259 }
1260}