Skip to main content

nika_engine/runtime/rig_agent_loop/
providers.rs

1//! Provider-specific execution methods
2//!
3//! Contains: run_mock, run_claude, run_openai, run_auto,
4//! run_mistral, run_groq, run_deepseek, run_gemini, run_xai,
5//! and the generic provider implementation with retry logic.
6
7use std::sync::Arc;
8
9use rig::client::{CompletionClient, ProviderClient};
10use rig::providers::{anthropic, openai};
11use serde_json;
12
13use crate::error::NikaError;
14use crate::event::{AgentTurnMetadata, EventKind};
15
16use crate::ast::limits::LimitType;
17
18use super::types::{RigAgentLoopResult, RigAgentStatus};
19use super::RigAgentLoop;
20
21impl RigAgentLoop {
22    /// Run the agent loop with a mock provider (for testing)
23    ///
24    /// This method simulates agent execution without making real API calls.
25    pub async fn run_mock(&self) -> Result<RigAgentLoopResult, NikaError> {
26        // Emit start event (no metadata for "started")
27        self.event_log.emit(EventKind::AgentTurn {
28            task_id: Arc::from(self.task_id.as_str()),
29            turn_index: 1,
30            kind: "started".to_string(),
31            metadata: None,
32        });
33
34        // For mock execution, we simulate a single turn with natural completion
35        let response_text = "Mock response from rig agent".to_string();
36        let final_output = serde_json::json!({
37            "response": &response_text,
38            "completed": true
39        });
40
41        // Check stop conditions
42        let status = self.determine_status(&final_output.to_string());
43
44        // Build metadata for completion event
45        let stop_reason = status.as_canonical_str();
46        let metadata = AgentTurnMetadata {
47            thinking: None, // Mock mode doesn't have thinking
48            response_text: response_text.clone(),
49            input_tokens: 50,
50            output_tokens: 50,
51            cache_read_tokens: 0,
52            stop_reason: stop_reason.to_string(),
53        };
54
55        // Emit completion event with metadata
56        self.event_log.emit(EventKind::AgentTurn {
57            task_id: Arc::from(self.task_id.as_str()),
58            turn_index: 1,
59            kind: stop_reason.to_string(),
60            metadata: Some(metadata),
61        });
62
63        // Check guardrails
64        let guardrail_result = self.check_guardrails(&response_text);
65        let guardrails_passed = guardrail_result.is_passed();
66
67        Ok(RigAgentLoopResult {
68            status: status.clone(),
69            turns: 1,
70            final_output,
71            total_tokens: 100, // Mock token count
72            confidence: status.confidence(),
73            retry_count: 0,
74            guardrails_passed,
75            cost_usd: 0.0,
76            partial_result: None,
77        })
78    }
79
80    /// Run the agent loop with the real Claude provider
81    ///
82    /// This method uses rig-core's AgentBuilder for actual execution.
83    /// Requires ANTHROPIC_API_KEY environment variable to be set.
84    ///
85    /// Includes confidence retry loop and guardrail retry loop, matching
86    /// the generic provider path behavior.
87    ///
88    /// # Note
89    /// This method takes `&mut self` because tools are consumed (moved to rig's AgentBuilder).
90    /// The agent loop is designed for single-use execution.
91    ///
92    /// ## Extended Thinking
93    /// When `extended_thinking: true` is set in AgentParams, this method uses
94    /// the streaming API to capture Claude's reasoning process. The thinking
95    /// is stored in `AgentTurnMetadata.thinking` for observability.
96    ///
97    /// ## Token Tracking
98    /// - Without tools: Uses streaming API for accurate token tracking
99    /// - With tools: Falls back to agent.prompt() (tokens will be 0)
100    /// - With extended_thinking: Uses dedicated streaming path
101    pub async fn run_claude(&mut self) -> Result<RigAgentLoopResult, NikaError> {
102        // Check if extended thinking is enabled
103        if self.params.extended_thinking == Some(true) {
104            return self.run_claude_with_thinking().await;
105        }
106
107        // Create Anthropic client from environment
108        let client = anthropic::Client::from_env();
109
110        // Get model name — validated by analyzer (NIKA-034)
111        let raw_model = self
112            .params
113            .model
114            .clone()
115            .ok_or_else(|| NikaError::ValidationError {
116                reason: "model field is required for LLM verbs (NIKA-034)".to_string(),
117            })?;
118        let model_name = Self::strip_model_prefix(&raw_model).to_string();
119        let model = client.completion_model(&model_name);
120
121        // Take ownership of tools (they'll be consumed by the builder)
122        let tools = self.tools_as_boxed();
123
124        // Get max_turns and retry config
125        let max_turns = self.params.max_turns.unwrap_or(10) as usize;
126        let max_retries = self
127            .get_low_confidence_config()
128            .map(|c| c.max_retries)
129            .unwrap_or(2);
130        let base_prompt = self.params.prompt.clone();
131
132        let mut retry_count: u32 = 0;
133        let mut current_prompt = base_prompt.clone();
134        let mut total_input_tokens: u64 = 0;
135        let mut total_output_tokens: u64 = 0;
136        let mut total_cached_input_tokens: u64 = 0;
137
138        // Emit start event (no metadata for "started")
139        self.event_log.emit(EventKind::AgentTurn {
140            task_id: Arc::from(self.task_id.as_str()),
141            turn_index: 1,
142            kind: "started".to_string(),
143            metadata: None,
144        });
145
146        // First attempt with tools
147        let mut result = self
148            .stream_with_tools(model.clone(), &current_prompt, tools, max_turns)
149            .await?;
150
151        total_input_tokens += result.input_tokens;
152        total_output_tokens += result.output_tokens;
153        total_cached_input_tokens += result.cached_input_tokens;
154
155        // Record turn in limit tracker
156        let cost = crate::provider::cost::calculate_cost_with_cache(
157            crate::provider::cost::ProviderKind::Claude,
158            &model_name,
159            result.input_tokens,
160            result.output_tokens,
161            result.cached_input_tokens,
162        );
163        self.limit_tracker
164            .record_turn(result.input_tokens, result.output_tokens, cost);
165
166        // Check limits after first turn
167        if let Some(exceeded) = self.limit_tracker.check_limits() {
168            let status = match exceeded.limit_type {
169                LimitType::Turns => RigAgentStatus::MaxTurnsReached,
170                LimitType::Tokens => RigAgentStatus::TokenBudgetExceeded,
171                LimitType::Cost => RigAgentStatus::CostLimitReached,
172                LimitType::Duration => RigAgentStatus::DurationLimitReached,
173            };
174            tracing::warn!(
175                task_id = %self.task_id,
176                limit = %exceeded.limit_type,
177                current = exceeded.current,
178                maximum = exceeded.maximum,
179                "Claude agent limit exceeded after first turn"
180            );
181            return Ok(RigAgentLoopResult {
182                status,
183                turns: 1,
184                final_output: serde_json::json!({ "response": result.response }),
185                total_tokens: total_input_tokens + total_output_tokens,
186                confidence: None,
187                retry_count: 0,
188                guardrails_passed: true,
189                cost_usd: self.limit_tracker.cost_usd(),
190                partial_result: None,
191            });
192        }
193
194        let mut status = self.determine_status(&result.response);
195
196        // Confidence retry loop (matches generic provider path)
197        while self.should_retry(&status, retry_count) {
198            retry_count += 1;
199
200            // Check limits before starting a retry
201            if let Some(exceeded) = self.limit_tracker.check_limits() {
202                let limit_status = match exceeded.limit_type {
203                    LimitType::Turns => RigAgentStatus::MaxTurnsReached,
204                    LimitType::Tokens => RigAgentStatus::TokenBudgetExceeded,
205                    LimitType::Cost => RigAgentStatus::CostLimitReached,
206                    LimitType::Duration => RigAgentStatus::DurationLimitReached,
207                };
208                tracing::warn!(
209                    task_id = %self.task_id,
210                    limit = %exceeded.limit_type,
211                    retry = retry_count,
212                    "Claude agent limit exceeded during retry loop"
213                );
214                status = limit_status;
215                break;
216            }
217
218            // Get confidence from status for feedback message
219            let confidence = match &status {
220                RigAgentStatus::LowConfidence(c) => *c,
221                _ => 0.0,
222            };
223
224            // Emit retry event
225            self.event_log.emit(EventKind::AgentTurn {
226                task_id: Arc::from(self.task_id.as_str()),
227                turn_index: retry_count + 1,
228                kind: format!("retry_{}", retry_count),
229                metadata: Some(AgentTurnMetadata {
230                    thinking: None,
231                    response_text: format!(
232                        "Low confidence ({:.2}), retrying ({}/{})",
233                        confidence, retry_count, max_retries
234                    ),
235                    input_tokens: 0,
236                    output_tokens: 0,
237                    cache_read_tokens: 0,
238                    stop_reason: "low_confidence_retry".to_string(),
239                }),
240            });
241
242            // Append feedback to prompt for retry
243            current_prompt = format!(
244                "{}\n\n{}\n\nPrevious response:\n{}",
245                base_prompt,
246                self.get_retry_feedback(confidence),
247                result.response
248            );
249
250            // Retry without tools (agent has already gathered context)
251            result = self
252                .stream_with_tools(model.clone(), &current_prompt, vec![], max_turns)
253                .await?;
254
255            total_input_tokens += result.input_tokens;
256            total_output_tokens += result.output_tokens;
257            total_cached_input_tokens += result.cached_input_tokens;
258
259            // Record retry turn in limit tracker
260            let retry_cost = crate::provider::cost::calculate_cost_with_cache(
261                crate::provider::cost::ProviderKind::Claude,
262                &model_name,
263                result.input_tokens,
264                result.output_tokens,
265                result.cached_input_tokens,
266            );
267            self.limit_tracker
268                .record_turn(result.input_tokens, result.output_tokens, retry_cost);
269
270            status = self.determine_status(&result.response);
271        }
272
273        // Build metadata WITH token tracking
274        let stop_reason = status.as_canonical_str();
275        let metadata = AgentTurnMetadata {
276            thinking: result.thinking,
277            response_text: result.response.clone(),
278            input_tokens: total_input_tokens,
279            output_tokens: total_output_tokens,
280            cache_read_tokens: total_cached_input_tokens,
281            stop_reason: stop_reason.to_string(),
282        };
283
284        // Emit completion event
285        self.event_log.emit(EventKind::AgentTurn {
286            task_id: Arc::from(self.task_id.as_str()),
287            turn_index: retry_count + 1,
288            kind: stop_reason.to_string(),
289            metadata: Some(metadata),
290        });
291
292        // Guardrail retry loop (matches generic provider path)
293        let max_guardrail_retries: u32 = 2;
294        let mut guardrail_retry_count: u32 = 0;
295        let mut guardrail_result = self.check_guardrails(&result.response);
296
297        while guardrail_result.should_retry() && guardrail_retry_count < max_guardrail_retries {
298            guardrail_retry_count += 1;
299
300            // Check limits before starting a guardrail retry
301            if let Some(exceeded) = self.limit_tracker.check_limits() {
302                tracing::warn!(
303                    task_id = %self.task_id,
304                    limit = %exceeded.limit_type,
305                    guardrail_retry = guardrail_retry_count,
306                    "Claude agent limit exceeded during guardrail retry loop"
307                );
308                break;
309            }
310
311            // Build feedback from guardrail failure messages
312            let feedback = guardrail_result.failure_messages().join("; ");
313            tracing::info!(
314                task_id = %self.task_id,
315                guardrail_retry = guardrail_retry_count,
316                max = max_guardrail_retries,
317                feedback = %feedback,
318                "Retrying Claude due to guardrail failure"
319            );
320
321            // Emit guardrail retry event
322            self.event_log.emit(EventKind::AgentTurn {
323                task_id: Arc::from(self.task_id.as_str()),
324                turn_index: retry_count + guardrail_retry_count + 1,
325                kind: format!("guardrail_retry_{}", guardrail_retry_count),
326                metadata: Some(AgentTurnMetadata {
327                    thinking: None,
328                    response_text: format!(
329                        "Guardrail validation failed, retrying ({}/{}): {}",
330                        guardrail_retry_count, max_guardrail_retries, feedback
331                    ),
332                    input_tokens: 0,
333                    output_tokens: 0,
334                    cache_read_tokens: 0,
335                    stop_reason: "guardrail_retry".to_string(),
336                }),
337            });
338
339            // Append guardrail feedback to prompt
340            current_prompt = format!(
341                "{}\n\n[GUARDRAIL RETRY {}/{}] Your previous output failed quality validation:\n{}\n\nPlease fix these issues and try again.\n\nPrevious response:\n{}",
342                base_prompt,
343                guardrail_retry_count,
344                max_guardrail_retries,
345                feedback,
346                result.response
347            );
348
349            // Re-run without tools (agent already has context)
350            result = self
351                .stream_with_tools(model.clone(), &current_prompt, vec![], max_turns)
352                .await?;
353
354            total_input_tokens += result.input_tokens;
355            total_output_tokens += result.output_tokens;
356            total_cached_input_tokens += result.cached_input_tokens;
357
358            // Record guardrail retry turn in limit tracker
359            let gr_cost = crate::provider::cost::calculate_cost_with_cache(
360                crate::provider::cost::ProviderKind::Claude,
361                &model_name,
362                result.input_tokens,
363                result.output_tokens,
364                result.cached_input_tokens,
365            );
366            self.limit_tracker
367                .record_turn(result.input_tokens, result.output_tokens, gr_cost);
368
369            // Re-determine status and re-check guardrails
370            status = self.determine_status(&result.response);
371            guardrail_result = self.check_guardrails(&result.response);
372        }
373
374        // After guardrail retries exhausted, if still failing with retry -> accept anyway
375        if guardrail_result.should_retry() {
376            tracing::warn!(
377                task_id = %self.task_id,
378                retries = guardrail_retry_count,
379                "Claude guardrail retries exhausted, accepting output with guardrails_passed=false"
380            );
381        }
382
383        let guardrails_passed = guardrail_result.is_passed();
384
385        // Override status when guardrails fail with terminal actions
386        let status = if guardrail_result.should_fail() {
387            RigAgentStatus::Failed
388        } else if guardrail_result.should_escalate() {
389            RigAgentStatus::Escalated(status.confidence().unwrap_or(0.0))
390        } else {
391            status
392        };
393
394        let total_retries = retry_count + guardrail_retry_count;
395
396        // Emit ProviderResponded so runner cost summary includes agent work
397        let total_cost = crate::provider::cost::calculate_cost_with_cache(
398            crate::provider::cost::ProviderKind::Claude,
399            &model_name,
400            total_input_tokens,
401            total_output_tokens,
402            total_cached_input_tokens,
403        );
404        self.event_log.emit(EventKind::ProviderResponded {
405            task_id: Arc::from(self.task_id.as_str()),
406            request_id: None,
407            input_tokens: total_input_tokens,
408            output_tokens: total_output_tokens,
409            cache_read_tokens: total_cached_input_tokens,
410            ttft_ms: None,
411            finish_reason: stop_reason.to_string(),
412            cost_usd: if total_cost.is_finite() {
413                total_cost
414            } else {
415                0.0
416            },
417        });
418
419        Ok(RigAgentLoopResult {
420            status: status.clone(),
421            turns: (total_retries + 1) as usize,
422            final_output: serde_json::json!({ "response": result.response }),
423            total_tokens: total_input_tokens + total_output_tokens,
424            confidence: status.confidence(),
425            retry_count: total_retries,
426            guardrails_passed,
427            cost_usd: self.limit_tracker.cost_usd(),
428            partial_result: None,
429        })
430    }
431
432    /// Run the agent loop with the OpenAI provider
433    ///
434    /// This method uses rig-core's OpenAI client for actual execution.
435    /// Requires OPENAI_API_KEY environment variable to be set.
436    ///
437    /// Includes confidence retry loop and guardrail retry loop, matching
438    /// the generic provider path behavior.
439    ///
440    /// # Note
441    /// This method takes `&mut self` because tools are consumed (moved to rig's AgentBuilder).
442    /// The agent loop is designed for single-use execution.
443    pub async fn run_openai(&mut self) -> Result<RigAgentLoopResult, NikaError> {
444        // Create OpenAI client from environment
445        let client = openai::Client::from_env();
446
447        // Get model name — validated by analyzer (NIKA-034)
448        let raw_model = self
449            .params
450            .model
451            .clone()
452            .ok_or_else(|| NikaError::ValidationError {
453                reason: "model field is required for LLM verbs (NIKA-034)".to_string(),
454            })?;
455        let model_name = Self::strip_model_prefix(&raw_model).to_string();
456        let model = client.completion_model(&model_name);
457
458        // Take ownership of tools (they'll be consumed by the builder)
459        let tools = self.tools_as_boxed();
460
461        // Get max_turns and retry config
462        let max_turns = self.params.max_turns.unwrap_or(10) as usize;
463        let max_retries = self
464            .get_low_confidence_config()
465            .map(|c| c.max_retries)
466            .unwrap_or(2);
467        let base_prompt = self.params.prompt.clone();
468
469        let mut retry_count: u32 = 0;
470        let mut current_prompt = base_prompt.clone();
471        let mut total_input_tokens: u64 = 0;
472        let mut total_output_tokens: u64 = 0;
473        let mut total_cached_input_tokens: u64 = 0;
474
475        // Emit start event (no metadata for "started")
476        self.event_log.emit(EventKind::AgentTurn {
477            task_id: Arc::from(self.task_id.as_str()),
478            turn_index: 1,
479            kind: "started".to_string(),
480            metadata: None,
481        });
482
483        // First attempt with tools
484        let mut result = self
485            .stream_with_tools(model.clone(), &current_prompt, tools, max_turns)
486            .await?;
487
488        total_input_tokens += result.input_tokens;
489        total_output_tokens += result.output_tokens;
490        total_cached_input_tokens += result.cached_input_tokens;
491
492        // Record turn in limit tracker
493        let cost = crate::provider::cost::calculate_cost_with_cache(
494            crate::provider::cost::ProviderKind::OpenAI,
495            &model_name,
496            result.input_tokens,
497            result.output_tokens,
498            result.cached_input_tokens,
499        );
500        self.limit_tracker
501            .record_turn(result.input_tokens, result.output_tokens, cost);
502
503        // Check limits after first turn
504        if let Some(exceeded) = self.limit_tracker.check_limits() {
505            let status = match exceeded.limit_type {
506                LimitType::Turns => RigAgentStatus::MaxTurnsReached,
507                LimitType::Tokens => RigAgentStatus::TokenBudgetExceeded,
508                LimitType::Cost => RigAgentStatus::CostLimitReached,
509                LimitType::Duration => RigAgentStatus::DurationLimitReached,
510            };
511            tracing::warn!(
512                task_id = %self.task_id,
513                limit = %exceeded.limit_type,
514                current = exceeded.current,
515                maximum = exceeded.maximum,
516                "OpenAI agent limit exceeded after first turn"
517            );
518            return Ok(RigAgentLoopResult {
519                status,
520                turns: 1,
521                final_output: serde_json::json!({ "response": result.response }),
522                total_tokens: total_input_tokens + total_output_tokens,
523                confidence: None,
524                retry_count: 0,
525                guardrails_passed: true,
526                cost_usd: self.limit_tracker.cost_usd(),
527                partial_result: None,
528            });
529        }
530
531        let mut status = self.determine_status(&result.response);
532
533        // Confidence retry loop (matches generic provider path)
534        while self.should_retry(&status, retry_count) {
535            retry_count += 1;
536
537            // Check limits before starting a retry
538            if let Some(exceeded) = self.limit_tracker.check_limits() {
539                let limit_status = match exceeded.limit_type {
540                    LimitType::Turns => RigAgentStatus::MaxTurnsReached,
541                    LimitType::Tokens => RigAgentStatus::TokenBudgetExceeded,
542                    LimitType::Cost => RigAgentStatus::CostLimitReached,
543                    LimitType::Duration => RigAgentStatus::DurationLimitReached,
544                };
545                tracing::warn!(
546                    task_id = %self.task_id,
547                    limit = %exceeded.limit_type,
548                    retry = retry_count,
549                    "OpenAI agent limit exceeded during retry loop"
550                );
551                status = limit_status;
552                break;
553            }
554
555            // Get confidence from status for feedback message
556            let confidence = match &status {
557                RigAgentStatus::LowConfidence(c) => *c,
558                _ => 0.0,
559            };
560
561            // Emit retry event
562            self.event_log.emit(EventKind::AgentTurn {
563                task_id: Arc::from(self.task_id.as_str()),
564                turn_index: retry_count + 1,
565                kind: format!("retry_{}", retry_count),
566                metadata: Some(AgentTurnMetadata {
567                    thinking: None,
568                    response_text: format!(
569                        "Low confidence ({:.2}), retrying ({}/{})",
570                        confidence, retry_count, max_retries
571                    ),
572                    input_tokens: 0,
573                    output_tokens: 0,
574                    cache_read_tokens: 0,
575                    stop_reason: "low_confidence_retry".to_string(),
576                }),
577            });
578
579            // Append feedback to prompt for retry
580            current_prompt = format!(
581                "{}\n\n{}\n\nPrevious response:\n{}",
582                base_prompt,
583                self.get_retry_feedback(confidence),
584                result.response
585            );
586
587            // Retry without tools (agent has already gathered context)
588            result = self
589                .stream_with_tools(model.clone(), &current_prompt, vec![], max_turns)
590                .await?;
591
592            total_input_tokens += result.input_tokens;
593            total_output_tokens += result.output_tokens;
594            total_cached_input_tokens += result.cached_input_tokens;
595
596            // Record retry turn in limit tracker
597            let retry_cost = crate::provider::cost::calculate_cost_with_cache(
598                crate::provider::cost::ProviderKind::OpenAI,
599                &model_name,
600                result.input_tokens,
601                result.output_tokens,
602                result.cached_input_tokens,
603            );
604            self.limit_tracker
605                .record_turn(result.input_tokens, result.output_tokens, retry_cost);
606
607            status = self.determine_status(&result.response);
608        }
609
610        // Build metadata WITH token tracking
611        let stop_reason = status.as_canonical_str();
612        let metadata = AgentTurnMetadata {
613            thinking: result.thinking,
614            response_text: result.response.clone(),
615            input_tokens: total_input_tokens,
616            output_tokens: total_output_tokens,
617            cache_read_tokens: total_cached_input_tokens,
618            stop_reason: stop_reason.to_string(),
619        };
620
621        self.event_log.emit(EventKind::AgentTurn {
622            task_id: Arc::from(self.task_id.as_str()),
623            turn_index: retry_count + 1,
624            kind: stop_reason.to_string(),
625            metadata: Some(metadata),
626        });
627
628        // Guardrail retry loop (matches generic provider path)
629        let max_guardrail_retries: u32 = 2;
630        let mut guardrail_retry_count: u32 = 0;
631        let mut guardrail_result = self.check_guardrails(&result.response);
632
633        while guardrail_result.should_retry() && guardrail_retry_count < max_guardrail_retries {
634            guardrail_retry_count += 1;
635
636            // Check limits before starting a guardrail retry
637            if let Some(exceeded) = self.limit_tracker.check_limits() {
638                tracing::warn!(
639                    task_id = %self.task_id,
640                    limit = %exceeded.limit_type,
641                    guardrail_retry = guardrail_retry_count,
642                    "OpenAI agent limit exceeded during guardrail retry loop"
643                );
644                break;
645            }
646
647            // Build feedback from guardrail failure messages
648            let feedback = guardrail_result.failure_messages().join("; ");
649            tracing::info!(
650                task_id = %self.task_id,
651                guardrail_retry = guardrail_retry_count,
652                max = max_guardrail_retries,
653                feedback = %feedback,
654                "Retrying OpenAI due to guardrail failure"
655            );
656
657            // Emit guardrail retry event
658            self.event_log.emit(EventKind::AgentTurn {
659                task_id: Arc::from(self.task_id.as_str()),
660                turn_index: retry_count + guardrail_retry_count + 1,
661                kind: format!("guardrail_retry_{}", guardrail_retry_count),
662                metadata: Some(AgentTurnMetadata {
663                    thinking: None,
664                    response_text: format!(
665                        "Guardrail validation failed, retrying ({}/{}): {}",
666                        guardrail_retry_count, max_guardrail_retries, feedback
667                    ),
668                    input_tokens: 0,
669                    output_tokens: 0,
670                    cache_read_tokens: 0,
671                    stop_reason: "guardrail_retry".to_string(),
672                }),
673            });
674
675            // Append guardrail feedback to prompt
676            current_prompt = format!(
677                "{}\n\n[GUARDRAIL RETRY {}/{}] Your previous output failed quality validation:\n{}\n\nPlease fix these issues and try again.\n\nPrevious response:\n{}",
678                base_prompt,
679                guardrail_retry_count,
680                max_guardrail_retries,
681                feedback,
682                result.response
683            );
684
685            // Re-run without tools (agent already has context)
686            result = self
687                .stream_with_tools(model.clone(), &current_prompt, vec![], max_turns)
688                .await?;
689
690            total_input_tokens += result.input_tokens;
691            total_output_tokens += result.output_tokens;
692            total_cached_input_tokens += result.cached_input_tokens;
693
694            // Record guardrail retry turn in limit tracker
695            let gr_cost = crate::provider::cost::calculate_cost_with_cache(
696                crate::provider::cost::ProviderKind::OpenAI,
697                &model_name,
698                result.input_tokens,
699                result.output_tokens,
700                result.cached_input_tokens,
701            );
702            self.limit_tracker
703                .record_turn(result.input_tokens, result.output_tokens, gr_cost);
704
705            // Re-determine status and re-check guardrails
706            status = self.determine_status(&result.response);
707            guardrail_result = self.check_guardrails(&result.response);
708        }
709
710        // After guardrail retries exhausted, if still failing with retry -> accept anyway
711        if guardrail_result.should_retry() {
712            tracing::warn!(
713                task_id = %self.task_id,
714                retries = guardrail_retry_count,
715                "OpenAI guardrail retries exhausted, accepting output with guardrails_passed=false"
716            );
717        }
718
719        let guardrails_passed = guardrail_result.is_passed();
720
721        // Override status when guardrails fail with terminal actions
722        let status = if guardrail_result.should_fail() {
723            RigAgentStatus::Failed
724        } else if guardrail_result.should_escalate() {
725            RigAgentStatus::Escalated(status.confidence().unwrap_or(0.0))
726        } else {
727            status
728        };
729
730        let total_retries = retry_count + guardrail_retry_count;
731
732        // Emit ProviderResponded so runner cost summary includes agent work
733        let total_cost = crate::provider::cost::calculate_cost_with_cache(
734            crate::provider::cost::ProviderKind::OpenAI,
735            &model_name,
736            total_input_tokens,
737            total_output_tokens,
738            total_cached_input_tokens,
739        );
740        self.event_log.emit(EventKind::ProviderResponded {
741            task_id: Arc::from(self.task_id.as_str()),
742            request_id: None,
743            input_tokens: total_input_tokens,
744            output_tokens: total_output_tokens,
745            cache_read_tokens: total_cached_input_tokens,
746            ttft_ms: None,
747            finish_reason: stop_reason.to_string(),
748            cost_usd: if total_cost.is_finite() {
749                total_cost
750            } else {
751                0.0
752            },
753        });
754
755        Ok(RigAgentLoopResult {
756            status: status.clone(),
757            turns: (total_retries + 1) as usize,
758            final_output: serde_json::json!({ "response": result.response }),
759            total_tokens: total_input_tokens + total_output_tokens,
760            confidence: status.confidence(),
761            retry_count: total_retries,
762            guardrails_passed,
763            cost_usd: self.limit_tracker.cost_usd(),
764            partial_result: None,
765        })
766    }
767
768    /// Run the agent loop with the best available provider
769    ///
770    /// Provider selection order:
771    /// 1. Check AgentParams.provider field
772    /// 2. Check ANTHROPIC_API_KEY env var → use Claude
773    /// 3. Check OPENAI_API_KEY env var → use OpenAI
774    /// 4. Check MISTRAL_API_KEY env var → use Mistral
775    /// 5. Check GROQ_API_KEY env var → use Groq
776    /// 6. Check DEEPSEEK_API_KEY env var → use DeepSeek
777    /// 7. Error if no provider available
778    ///
779    /// # Note
780    /// This is the recommended method for production use.
781    pub async fn run_auto(&mut self) -> Result<RigAgentLoopResult, NikaError> {
782        // Check explicit provider from params
783        if let Some(ref provider_name) = self.params.provider {
784            let resolved = crate::core::find_provider(provider_name).ok_or_else(|| {
785                NikaError::AgentValidationError {
786                    reason: format!(
787                        "Unknown provider: '{}'. Use 'claude', 'openai', 'mistral', 'groq', 'deepseek', 'gemini', or 'xai'.",
788                        provider_name
789                    ),
790                }
791            })?;
792            return match resolved.id {
793                "anthropic" => self.run_claude().await,
794                "openai" => self.run_openai().await,
795                "mistral" => self.run_mistral().await,
796                "groq" => self.run_groq().await,
797                "deepseek" => self.run_deepseek().await,
798                "gemini" => self.run_gemini().await,
799                "xai" => self.run_xai().await,
800                "native" => Err(NikaError::AgentValidationError {
801                    reason: "Provider 'native' is not supported for agent: tasks. Native inference (mistral.rs) is only available for infer: tasks. Use a cloud provider (claude, openai, mistral, groq, deepseek, gemini, xai) for agent tasks.".to_string(),
802                }),
803                _ => Err(NikaError::AgentValidationError {
804                    reason: format!("Provider '{}' is not supported for agent: tasks.", resolved.id),
805                }),
806            };
807        }
808
809        // Auto-detect: iterate KNOWN_PROVIDERS in priority order (LLM category only)
810        use crate::core::providers::{ProviderCategory, KNOWN_PROVIDERS};
811        for p in KNOWN_PROVIDERS.iter() {
812            if p.category == ProviderCategory::Llm && p.has_env_key() {
813                return match p.id {
814                    "anthropic" => self.run_claude().await,
815                    "openai" => self.run_openai().await,
816                    "mistral" => self.run_mistral().await,
817                    "groq" => self.run_groq().await,
818                    "deepseek" => self.run_deepseek().await,
819                    "gemini" => self.run_gemini().await,
820                    "xai" => self.run_xai().await,
821                    _ => continue,
822                };
823            }
824        }
825
826        Err(NikaError::AgentValidationError {
827            reason: "No API key found. Set one of: ANTHROPIC_API_KEY, OPENAI_API_KEY, MISTRAL_API_KEY, GROQ_API_KEY, DEEPSEEK_API_KEY, GEMINI_API_KEY, or XAI_API_KEY.".to_string(),
828        })
829    }
830
831    // =========================================================================
832    // Additional Provider Methods
833    // =========================================================================
834
835    /// Run with Mistral provider (requires MISTRAL_API_KEY)
836    pub async fn run_mistral(&mut self) -> Result<RigAgentLoopResult, NikaError> {
837        let model_name = self
838            .params
839            .model
840            .clone()
841            .unwrap_or_else(|| rig::providers::mistral::MISTRAL_LARGE.to_string());
842        let client = rig::providers::mistral::Client::from_env();
843        self.run_generic_provider_impl(
844            client,
845            &model_name,
846            Some(crate::provider::cost::ProviderKind::Mistral),
847        )
848        .await
849    }
850
851    /// Run with Groq provider (requires GROQ_API_KEY)
852    pub async fn run_groq(&mut self) -> Result<RigAgentLoopResult, NikaError> {
853        let model_name = self
854            .params
855            .model
856            .clone()
857            .unwrap_or_else(|| "llama-3.3-70b-versatile".to_string());
858        let client = rig::providers::groq::Client::from_env();
859        self.run_generic_provider_impl(
860            client,
861            &model_name,
862            Some(crate::provider::cost::ProviderKind::Groq),
863        )
864        .await
865    }
866
867    /// Run with DeepSeek provider (requires DEEPSEEK_API_KEY)
868    pub async fn run_deepseek(&mut self) -> Result<RigAgentLoopResult, NikaError> {
869        let model_name = self
870            .params
871            .model
872            .clone()
873            .unwrap_or_else(|| "deepseek-chat".to_string());
874        let client = rig::providers::deepseek::Client::from_env();
875        self.run_generic_provider_impl(
876            client,
877            &model_name,
878            Some(crate::provider::cost::ProviderKind::DeepSeek),
879        )
880        .await
881    }
882
883    /// Run with Gemini provider (requires GEMINI_API_KEY)
884    pub async fn run_gemini(&mut self) -> Result<RigAgentLoopResult, NikaError> {
885        let model_name = self
886            .params
887            .model
888            .clone()
889            .unwrap_or_else(|| "gemini-2.0-flash".to_string());
890        let client = rig::providers::gemini::Client::from_env();
891        self.run_generic_provider_impl(
892            client,
893            &model_name,
894            Some(crate::provider::cost::ProviderKind::Gemini),
895        )
896        .await
897    }
898
899    /// Run with xAI provider (requires XAI_API_KEY)
900    pub async fn run_xai(&mut self) -> Result<RigAgentLoopResult, NikaError> {
901        let model_name = self
902            .params
903            .model
904            .clone()
905            .unwrap_or_else(|| "grok-3-fast".to_string());
906        let client = rig::providers::xai::Client::from_env();
907        self.run_generic_provider_impl(
908            client,
909            &model_name,
910            Some(crate::provider::cost::ProviderKind::XAi),
911        )
912        .await
913    }
914
915    /// Generic provider runner implementation
916    ///
917    /// Uses rig-core's unified ProviderClient + CompletionClient interface.
918    /// Includes retry logic for low confidence responses.
919    async fn run_generic_provider_impl<C>(
920        &mut self,
921        client: C,
922        model_name: &str,
923        provider_kind: Option<crate::provider::cost::ProviderKind>,
924    ) -> Result<RigAgentLoopResult, NikaError>
925    where
926        C: CompletionClient,
927        C::CompletionModel: Clone + 'static,
928        <C::CompletionModel as rig::completion::CompletionModel>::Response: Send,
929    {
930        let model_name = Self::strip_model_prefix(model_name);
931        let model = client.completion_model(model_name);
932
933        // Take ownership of tools for first attempt
934        let tools = self.tools_as_boxed();
935        let max_turns = self.params.max_turns.unwrap_or(10) as usize;
936        let base_prompt = self.params.prompt.clone();
937
938        // Get max retries from config (default: 2)
939        let max_retries = self
940            .get_low_confidence_config()
941            .map(|c| c.max_retries)
942            .unwrap_or(2);
943
944        let mut retry_count: u32 = 0;
945        let mut current_prompt = base_prompt.clone();
946        let mut total_input_tokens: u64 = 0;
947        let mut total_output_tokens: u64 = 0;
948        let mut total_cached_input_tokens: u64 = 0;
949
950        // Emit start event
951        self.event_log.emit(EventKind::AgentTurn {
952            task_id: Arc::from(self.task_id.as_str()),
953            turn_index: 1,
954            kind: "started".to_string(),
955            metadata: None,
956        });
957
958        // First attempt with tools
959        let mut result = self
960            .stream_with_tools(model.clone(), &current_prompt, tools, max_turns)
961            .await?;
962
963        total_input_tokens += result.input_tokens;
964        total_output_tokens += result.output_tokens;
965        total_cached_input_tokens += result.cached_input_tokens;
966
967        // Record turn in limit tracker
968        let turn_cost = provider_kind
969            .map(|pk| {
970                crate::provider::cost::calculate_cost_with_cache(
971                    pk,
972                    model_name,
973                    result.input_tokens,
974                    result.output_tokens,
975                    result.cached_input_tokens,
976                )
977            })
978            .unwrap_or(0.0);
979        self.limit_tracker
980            .record_turn(result.input_tokens, result.output_tokens, turn_cost);
981
982        // Check limits after first turn
983        if let Some(exceeded) = self.limit_tracker.check_limits() {
984            let status = match exceeded.limit_type {
985                LimitType::Turns => RigAgentStatus::MaxTurnsReached,
986                LimitType::Tokens => RigAgentStatus::TokenBudgetExceeded,
987                LimitType::Cost => RigAgentStatus::CostLimitReached,
988                LimitType::Duration => RigAgentStatus::DurationLimitReached,
989            };
990            tracing::warn!(
991                task_id = %self.task_id,
992                limit = %exceeded.limit_type,
993                current = exceeded.current,
994                maximum = exceeded.maximum,
995                "Agent limit exceeded after first turn"
996            );
997            return Ok(RigAgentLoopResult {
998                status,
999                turns: 1,
1000                final_output: serde_json::json!({ "response": result.response }),
1001                total_tokens: total_input_tokens + total_output_tokens,
1002                confidence: None,
1003                retry_count: 0,
1004                guardrails_passed: true,
1005                cost_usd: self.limit_tracker.cost_usd(),
1006                partial_result: None,
1007            });
1008        }
1009
1010        let mut status = self.determine_status(&result.response);
1011
1012        // Retry loop for low confidence
1013        while self.should_retry(&status, retry_count) {
1014            retry_count += 1;
1015
1016            // Check limits before starting a retry
1017            if let Some(exceeded) = self.limit_tracker.check_limits() {
1018                let limit_status = match exceeded.limit_type {
1019                    LimitType::Turns => RigAgentStatus::MaxTurnsReached,
1020                    LimitType::Tokens => RigAgentStatus::TokenBudgetExceeded,
1021                    LimitType::Cost => RigAgentStatus::CostLimitReached,
1022                    LimitType::Duration => RigAgentStatus::DurationLimitReached,
1023                };
1024                tracing::warn!(
1025                    task_id = %self.task_id,
1026                    limit = %exceeded.limit_type,
1027                    retry = retry_count,
1028                    "Agent limit exceeded during retry loop"
1029                );
1030                status = limit_status;
1031                break;
1032            }
1033
1034            // Get confidence from status for feedback message
1035            let confidence = match &status {
1036                RigAgentStatus::LowConfidence(c) => *c,
1037                _ => 0.0,
1038            };
1039
1040            // Emit retry event
1041            self.event_log.emit(EventKind::AgentTurn {
1042                task_id: Arc::from(self.task_id.as_str()),
1043                turn_index: retry_count + 1,
1044                kind: format!("retry_{}", retry_count),
1045                metadata: Some(AgentTurnMetadata {
1046                    thinking: None,
1047                    response_text: format!(
1048                        "Low confidence ({:.2}), retrying ({}/{})",
1049                        confidence, retry_count, max_retries
1050                    ),
1051                    input_tokens: 0,
1052                    output_tokens: 0,
1053                    cache_read_tokens: 0,
1054                    stop_reason: "low_confidence_retry".to_string(),
1055                }),
1056            });
1057
1058            // Append feedback to prompt for retry
1059            current_prompt = format!(
1060                "{}\n\n{}\n\nPrevious response:\n{}",
1061                base_prompt,
1062                self.get_retry_feedback(confidence),
1063                result.response
1064            );
1065
1066            // Retry without tools (agent has already gathered context)
1067            // Using empty tools vec for retry attempts
1068            result = self
1069                .stream_with_tools(model.clone(), &current_prompt, vec![], max_turns)
1070                .await?;
1071
1072            total_input_tokens += result.input_tokens;
1073            total_output_tokens += result.output_tokens;
1074            total_cached_input_tokens += result.cached_input_tokens;
1075
1076            // Record retry turn in limit tracker
1077            let retry_cost = provider_kind
1078                .map(|pk| {
1079                    crate::provider::cost::calculate_cost_with_cache(
1080                        pk,
1081                        model_name,
1082                        result.input_tokens,
1083                        result.output_tokens,
1084                        result.cached_input_tokens,
1085                    )
1086                })
1087                .unwrap_or(0.0);
1088            self.limit_tracker
1089                .record_turn(result.input_tokens, result.output_tokens, retry_cost);
1090
1091            status = self.determine_status(&result.response);
1092        }
1093
1094        // Build metadata WITH token tracking
1095        let stop_reason = status.as_canonical_str();
1096        let metadata = AgentTurnMetadata {
1097            thinking: result.thinking,
1098            response_text: result.response.clone(),
1099            input_tokens: total_input_tokens,
1100            output_tokens: total_output_tokens,
1101            cache_read_tokens: total_cached_input_tokens,
1102            stop_reason: stop_reason.to_string(),
1103        };
1104
1105        self.event_log.emit(EventKind::AgentTurn {
1106            task_id: Arc::from(self.task_id.as_str()),
1107            turn_index: retry_count + 1,
1108            kind: stop_reason.to_string(),
1109            metadata: Some(metadata),
1110        });
1111
1112        // Check guardrails with retry loop for `on_failure: retry`
1113        let max_guardrail_retries: u32 = 2;
1114        let mut guardrail_retry_count: u32 = 0;
1115        let mut guardrail_result = self.check_guardrails(&result.response);
1116
1117        while guardrail_result.should_retry() && guardrail_retry_count < max_guardrail_retries {
1118            guardrail_retry_count += 1;
1119
1120            // Check limits before starting a guardrail retry
1121            if let Some(exceeded) = self.limit_tracker.check_limits() {
1122                tracing::warn!(
1123                    task_id = %self.task_id,
1124                    limit = %exceeded.limit_type,
1125                    guardrail_retry = guardrail_retry_count,
1126                    "Agent limit exceeded during guardrail retry loop"
1127                );
1128                break;
1129            }
1130
1131            // Build feedback from guardrail failure messages
1132            let feedback = guardrail_result.failure_messages().join("; ");
1133            tracing::info!(
1134                task_id = %self.task_id,
1135                guardrail_retry = guardrail_retry_count,
1136                max = max_guardrail_retries,
1137                feedback = %feedback,
1138                "Retrying due to guardrail failure"
1139            );
1140
1141            // Emit guardrail retry event
1142            self.event_log.emit(EventKind::AgentTurn {
1143                task_id: Arc::from(self.task_id.as_str()),
1144                turn_index: retry_count + guardrail_retry_count + 1,
1145                kind: format!("guardrail_retry_{}", guardrail_retry_count),
1146                metadata: Some(AgentTurnMetadata {
1147                    thinking: None,
1148                    response_text: format!(
1149                        "Guardrail validation failed, retrying ({}/{}): {}",
1150                        guardrail_retry_count, max_guardrail_retries, feedback
1151                    ),
1152                    input_tokens: 0,
1153                    output_tokens: 0,
1154                    cache_read_tokens: 0,
1155                    stop_reason: "guardrail_retry".to_string(),
1156                }),
1157            });
1158
1159            // Append guardrail feedback to prompt
1160            current_prompt = format!(
1161                "{}\n\n[GUARDRAIL RETRY {}/{}] Your previous output failed quality validation:\n{}\n\nPlease fix these issues and try again.\n\nPrevious response:\n{}",
1162                base_prompt,
1163                guardrail_retry_count,
1164                max_guardrail_retries,
1165                feedback,
1166                result.response
1167            );
1168
1169            // Re-run without tools (agent already has context)
1170            result = self
1171                .stream_with_tools(model.clone(), &current_prompt, vec![], max_turns)
1172                .await?;
1173
1174            total_input_tokens += result.input_tokens;
1175            total_output_tokens += result.output_tokens;
1176            total_cached_input_tokens += result.cached_input_tokens;
1177
1178            // Record guardrail retry turn in limit tracker
1179            let gr_cost = provider_kind
1180                .map(|pk| {
1181                    crate::provider::cost::calculate_cost_with_cache(
1182                        pk,
1183                        model_name,
1184                        result.input_tokens,
1185                        result.output_tokens,
1186                        result.cached_input_tokens,
1187                    )
1188                })
1189                .unwrap_or(0.0);
1190            self.limit_tracker
1191                .record_turn(result.input_tokens, result.output_tokens, gr_cost);
1192
1193            // Re-determine status and re-check guardrails
1194            status = self.determine_status(&result.response);
1195            guardrail_result = self.check_guardrails(&result.response);
1196        }
1197
1198        // After guardrail retries exhausted, if still failing with retry -> accept anyway
1199        // (don't block forever, the guardrails_passed flag will indicate the failure)
1200        if guardrail_result.should_retry() {
1201            tracing::warn!(
1202                task_id = %self.task_id,
1203                retries = guardrail_retry_count,
1204                "Guardrail retries exhausted, accepting output with guardrails_passed=false"
1205            );
1206        }
1207
1208        let guardrails_passed = guardrail_result.is_passed();
1209
1210        // Override status when guardrails fail with terminal actions
1211        let status = if guardrail_result.should_fail() {
1212            RigAgentStatus::Failed
1213        } else if guardrail_result.should_escalate() {
1214            RigAgentStatus::Escalated(status.confidence().unwrap_or(0.0))
1215        } else {
1216            status
1217        };
1218
1219        let total_retries = retry_count + guardrail_retry_count;
1220
1221        // Emit ProviderResponded so runner cost summary includes agent work
1222        let total_cost = provider_kind
1223            .map(|pk| {
1224                crate::provider::cost::calculate_cost_with_cache(
1225                    pk,
1226                    model_name,
1227                    total_input_tokens,
1228                    total_output_tokens,
1229                    total_cached_input_tokens,
1230                )
1231            })
1232            .unwrap_or(0.0);
1233        self.event_log.emit(EventKind::ProviderResponded {
1234            task_id: Arc::from(self.task_id.as_str()),
1235            request_id: None,
1236            input_tokens: total_input_tokens,
1237            output_tokens: total_output_tokens,
1238            cache_read_tokens: total_cached_input_tokens,
1239            ttft_ms: None,
1240            finish_reason: stop_reason.to_string(),
1241            cost_usd: if total_cost.is_finite() {
1242                total_cost
1243            } else {
1244                0.0
1245            },
1246        });
1247
1248        Ok(RigAgentLoopResult {
1249            status: status.clone(),
1250            turns: (total_retries + 1) as usize,
1251            final_output: serde_json::json!({ "response": result.response }),
1252            total_tokens: total_input_tokens + total_output_tokens,
1253            confidence: status.confidence(),
1254            retry_count: total_retries,
1255            guardrails_passed,
1256            cost_usd: self.limit_tracker.cost_usd(),
1257            partial_result: None,
1258        })
1259    }
1260}