Skip to main content

nika_engine/runtime/rig_agent_loop/
thinking.rs

//! Extended thinking, guardrails, and confidence routing
//!
//! Contains: check_completion_signal, check_guardrails,
//! determine_status, confidence routing, and run_claude_with_thinking.
5
6use std::sync::Arc;
7
8use futures::StreamExt;
9use rig::client::{CompletionClient, ProviderClient};
10use rig::completion::CompletionModel as _;
11use rig::completion::GetTokenUsage;
12use rig::message::ReasoningContent;
13use rig::providers::anthropic;
14use rig::streaming::StreamedAssistantContent;
15use serde_json;
16use tokio::time::timeout;
17
18use crate::ast::guardrails::{escalation_required, immediate_failures, run_sync_guardrails};
19use crate::error::NikaError;
20use crate::event::{AgentTurnMetadata, EventKind};
21use crate::util::STREAM_CHUNK_TIMEOUT;
22
23use super::types::{GuardrailCheckResult, RigAgentLoopResult, RigAgentStatus};
24use super::RigAgentLoop;
25
26impl RigAgentLoop {
27    /// Check if output contains explicit completion signal
28    ///
29    /// This checks for the COMPLETION_MARKER in tool results, indicating the
30    /// agent called nika:complete to signal task completion.
31    pub(crate) fn check_completion_signal(&self, output: &str) -> bool {
32        use crate::runtime::builtin::COMPLETION_MARKER;
33        output.contains(COMPLETION_MARKER)
34    }
35
36    /// Run all configured guardrails against the output
37    ///
38    /// Emits events for each guardrail result:
39    /// - `GuardrailPassed`: Guardrail check succeeded
40    /// - `GuardrailFailed`: Guardrail check failed
41    /// - `GuardrailEscalation`: Guardrail failed with `on_failure: escalate`
42    ///
43    /// Returns `GuardrailCheckResult` indicating the appropriate action:
44    /// - `AllPassed`: All guardrails passed
45    /// - `FailedRetry`: Some failed with `on_failure: retry` (default)
46    /// - `FailedEscalate`: Some failed with `on_failure: escalate`
47    /// - `FailedImmediate`: Some failed with `on_failure: fail`
48    ///
49    /// Priority: Immediate > Escalate > Retry
50    pub fn check_guardrails(&self, output: &str) -> GuardrailCheckResult {
51        if self.params.guardrails.is_empty() {
52            return GuardrailCheckResult::AllPassed;
53        }
54
55        let results = run_sync_guardrails(&self.params.guardrails, output);
56        let mut all_passed = true;
57
58        // PERF: Hoist Arc allocation outside loops to avoid
59        // per-iteration allocation overhead. Arc::clone is cheap (atomic inc).
60        let task_id: Arc<str> = Arc::from(self.task_id.as_str());
61
62        // Emit events for each result
63        for result in &results {
64            if result.passed {
65                self.event_log.emit(EventKind::GuardrailPassed {
66                    task_id: Arc::clone(&task_id),
67                    guardrail_type: result.guardrail_type.clone(),
68                    description: result.guardrail_id.clone(),
69                });
70            } else {
71                self.event_log.emit(EventKind::GuardrailFailed {
72                    task_id: Arc::clone(&task_id),
73                    guardrail_type: result.guardrail_type.clone(),
74                    description: result.guardrail_id.clone(),
75                    message: result
76                        .message
77                        .clone()
78                        .unwrap_or_else(|| "Guardrail check failed".to_string()),
79                });
80                all_passed = false;
81            }
82        }
83
84        if all_passed {
85            return GuardrailCheckResult::AllPassed;
86        }
87
88        // Check for immediate failures first (highest priority)
89        let immediate = immediate_failures(&results);
90        if !immediate.is_empty() {
91            return GuardrailCheckResult::FailedImmediate;
92        }
93
94        // Check for escalation requirements
95        let escalations = escalation_required(&results);
96        if !escalations.is_empty() {
97            // Emit escalation events for each guardrail requiring escalation
98            for result in escalations {
99                self.event_log.emit(EventKind::GuardrailEscalation {
100                    task_id: Arc::clone(&task_id),
101                    guardrail_type: result.guardrail_type.clone(),
102                    guardrail_id: result.guardrail_id.clone(),
103                    message: result
104                        .message
105                        .clone()
106                        .unwrap_or_else(|| "Guardrail requires escalation".to_string()),
107                    severity: "high".to_string(),
108                    suggested_action: Some("Review agent output and provide guidance".to_string()),
109                });
110            }
111            return GuardrailCheckResult::FailedEscalate;
112        }
113
114        // Default: retry (on_failure: retry)
115        // Collect error messages from failed guardrails for retry feedback
116        let failure_messages: Vec<String> = results
117            .iter()
118            .filter(|r| !r.passed)
119            .map(|r| {
120                r.message
121                    .clone()
122                    .unwrap_or_else(|| format!("Guardrail '{}' failed", r.guardrail_id))
123            })
124            .collect();
125        GuardrailCheckResult::FailedRetry(failure_messages)
126    }
127
128    /// Determine agent status based on output content
129    ///
130    /// Checks in order:
131    /// 1. Explicit completion via nika:complete tool
132    ///    - With confidence: compare against threshold → HighConfidence/LowConfidence
133    ///    - Without confidence: ExplicitCompletion
134    /// 2. Pattern completion (if `completion.mode: pattern`)
135    /// 3. Natural completion — but only when mode is NOT `explicit`
136    ///    When `completion.mode: explicit`, natural end-of-turn without
137    ///    `nika:complete` returns `LowConfidence(0.0)` to trigger retry.
138    pub fn determine_status(&self, output: &str) -> RigAgentStatus {
139        if self.check_completion_signal(output) {
140            // Parse the completion response to extract confidence
141            use crate::runtime::builtin::parse_completion_response;
142
143            if let Some(response) = parse_completion_response(output) {
144                // Check if confidence is provided
145                if let Some(confidence) = response.confidence {
146                    // Use apply_routing for confidence-based status
147                    return self.apply_routing(confidence);
148                }
149            }
150            // No confidence provided, treat as explicit completion
151            return RigAgentStatus::ExplicitCompletion;
152        }
153
154        // Check pattern-based completion
155        if let Some(ref completion_config) = self.params.completion {
156            if completion_config.check_pattern_match(output) {
157                return RigAgentStatus::ExplicitCompletion;
158            }
159        }
160
161        // Check if mode is explicit — if so, natural end-of-turn is NOT completion.
162        // The agent must call nika:complete. Returning LowConfidence triggers retry
163        // in the generic provider loop.
164        if let Some(ref completion_config) = self.params.completion {
165            use crate::ast::completion::CompletionMode;
166            if completion_config.mode == CompletionMode::Explicit {
167                tracing::debug!(
168                    task_id = %self.task_id,
169                    "Agent ended turn without calling nika:complete (mode: explicit)"
170                );
171                return RigAgentStatus::LowConfidence(0.0);
172            }
173        }
174
175        RigAgentStatus::NaturalCompletion
176    }
177
178    /// Get confidence threshold from completion config
179    ///
180    /// Returns the configured threshold, or 0.8 as default.
181    pub(crate) fn get_confidence_threshold(&self) -> f64 {
182        self.params
183            .effective_completion()
184            .and_then(|c| c.confidence)
185            .map(|conf| conf.threshold)
186            .unwrap_or(0.8)
187    }
188
189    /// Get low confidence configuration
190    ///
191    /// Returns the OnLowConfidenceConfig if available, or None.
192    pub(super) fn get_low_confidence_config(
193        &self,
194    ) -> Option<crate::ast::completion::OnLowConfidenceConfig> {
195        self.params
196            .effective_completion()
197            .and_then(|c| c.confidence)
198            .map(|conf| conf.on_low.clone())
199    }
200
201    /// Check if retry should be attempted for low confidence
202    ///
203    /// Returns true if:
204    /// - Status is LowConfidence
205    /// - on_low.action is Retry
206    /// - retry_count < max_retries
207    pub(super) fn should_retry(&self, status: &RigAgentStatus, retry_count: u32) -> bool {
208        if !matches!(status, RigAgentStatus::LowConfidence(_)) {
209            return false;
210        }
211
212        let Some(config) = self.get_low_confidence_config() else {
213            return false;
214        };
215
216        config.action == crate::ast::completion::LowConfidenceAction::Retry
217            && retry_count < config.max_retries
218    }
219
220    /// Get retry feedback message
221    ///
222    /// Returns the feedback message to append to prompt on retry.
223    pub(super) fn get_retry_feedback(&self, confidence: f64) -> String {
224        let config = self.get_low_confidence_config();
225        let threshold = self.get_confidence_threshold();
226
227        // Use custom feedback if configured
228        if let Some(feedback) = config.as_ref().and_then(|c| c.feedback.clone()) {
229            return format!(
230                "\n\n[RETRY: Your previous response had confidence {:.2}, below threshold {:.2}. {}]",
231                confidence, threshold, feedback
232            );
233        }
234
235        // Default feedback
236        format!(
237            "\n\n[RETRY: Your previous response had confidence {:.2}, which is below the required threshold of {:.2}. Please reconsider your response and provide a higher confidence answer.]",
238            confidence, threshold
239        )
240    }
241
242    /// Get confidence routing configuration
243    ///
244    /// Returns the ConfidenceRouting if available, or None.
245    pub(crate) fn get_confidence_routing(
246        &self,
247    ) -> Option<crate::ast::completion::ConfidenceRouting> {
248        self.params
249            .effective_completion()
250            .and_then(|c| c.confidence)
251            .and_then(|conf| conf.routing.clone())
252    }
253
254    /// Apply confidence-based routing
255    ///
256    /// Uses routing configuration to determine the appropriate status
257    /// based on confidence level. If routing is not configured, falls back
258    /// to simple threshold-based High/Low confidence.
259    pub(crate) fn apply_routing(&self, confidence: f64) -> RigAgentStatus {
260        let Some(routing) = self.get_confidence_routing() else {
261            // No routing configured, use simple threshold
262            let threshold = self.get_confidence_threshold();
263            return if confidence >= threshold {
264                RigAgentStatus::HighConfidence(confidence)
265            } else {
266                RigAgentStatus::LowConfidence(confidence)
267            };
268        };
269
270        // Determine which route applies based on confidence
271        // Check high route first (highest min value)
272        if let Some(high_min) = routing.high.min {
273            if confidence >= high_min {
274                return self.route_action_to_status(&routing.high.action, confidence);
275            }
276        }
277
278        // Check medium route (typically >= threshold)
279        if let Some(medium_min) = routing.medium.min {
280            if confidence >= medium_min {
281                return self.route_action_to_status(&routing.medium.action, confidence);
282            }
283        }
284
285        // Default to low route
286        self.route_action_to_status(&routing.low.action, confidence)
287    }
288
289    /// Convert a RouteAction to RigAgentStatus
290    pub(crate) fn route_action_to_status(
291        &self,
292        action: &crate::ast::completion::RouteAction,
293        confidence: f64,
294    ) -> RigAgentStatus {
295        use crate::ast::completion::RouteAction;
296
297        match action {
298            RouteAction::Accept => RigAgentStatus::HighConfidence(confidence),
299            RouteAction::AcceptWithFlag => RigAgentStatus::FlaggedForReview(confidence),
300            RouteAction::Retry => RigAgentStatus::LowConfidence(confidence),
301            RouteAction::Escalate => RigAgentStatus::Escalated(confidence),
302        }
303    }
304
305    /// Run the agent loop with extended thinking enabled (Claude only).
306    ///
307    /// Uses rig-core's streaming API to capture thinking blocks from Claude's
308    /// extended thinking feature. The thinking is accumulated and stored in
309    /// the AgentTurnMetadata for observability.
310    ///
311    /// # Errors
312    /// - NIKA-113: Extended thinking failed
313    /// - NIKA-110: Agent execution error
314    pub async fn run_claude_with_thinking(&mut self) -> Result<RigAgentLoopResult, NikaError> {
315        // Create Anthropic client from environment
316        let client = anthropic::Client::from_env();
317
318        // Get model name — validated by analyzer (NIKA-034)
319        let model_name =
320            Self::strip_model_prefix(self.params.model.as_deref().ok_or_else(|| {
321                NikaError::ValidationError {
322                    reason: "model field is required for LLM verbs (NIKA-034)".to_string(),
323                }
324            })?);
325        let model = client.completion_model(model_name);
326
327        // Build completion request with thinking enabled
328        // Use configurable thinking_budget from AgentParams (default: 4096)
329        let thinking_budget = self.params.effective_thinking_budget();
330
331        // Extended thinking requires additional_params (Claude-specific API feature)
332        let thinking_config = serde_json::json!({
333            "thinking": {
334                "type": "enabled",
335                "budget_tokens": thinking_budget
336            }
337        });
338
339        // Build request with native temperature method
340        // Inject skills into system prompt if configured
341        let preamble = self.inject_skills_into_prompt().await?;
342
343        // Use effective_max_tokens (required for extended thinking)
344        // Claude requires max_tokens > thinking_budget
345        let max_tokens = self
346            .params
347            .effective_max_tokens()
348            .unwrap_or((thinking_budget as u32) + 8192);
349
350        let mut request_builder = model
351            .completion_request(&self.params.prompt)
352            .preamble(preamble)
353            .max_tokens(max_tokens as u64)
354            .additional_params(thinking_config);
355
356        // Apply temperature using native rig-core method
357        if let Some(temp) = self.params.effective_temperature() {
358            request_builder = request_builder.temperature(f64::from(temp));
359        }
360
361        let request = request_builder.build();
362
363        // Emit start event
364        self.event_log.emit(EventKind::AgentTurn {
365            task_id: Arc::from(self.task_id.as_str()),
366            turn_index: 1,
367            kind: "started".to_string(),
368            metadata: None,
369        });
370
371        // Execute streaming request
372        let mut stream =
373            model
374                .stream(request)
375                .await
376                .map_err(|e| NikaError::AgentExecutionError {
377                    task_id: self.task_id.clone(),
378                    reason: format!("Streaming request failed: {}", e),
379                })?;
380
381        // Accumulate thinking, response, and token usage
382        let mut thinking_parts: Vec<String> = Vec::new();
383        let mut response_parts: Vec<String> = Vec::new();
384        let mut input_tokens: u64 = 0;
385        let mut output_tokens: u64 = 0;
386        let mut cached_input_tokens: u64 = 0;
387
388        // Per-chunk timeout to prevent hanging streams
389        loop {
390            let chunk_result = match timeout(STREAM_CHUNK_TIMEOUT, stream.next()).await {
391                Ok(Some(chunk)) => chunk,
392                Ok(None) => break, // Stream ended normally
393                Err(_elapsed) => {
394                    // Timeout - stream stalled
395                    tracing::warn!(
396                        task_id = %self.task_id,
397                        timeout_secs = STREAM_CHUNK_TIMEOUT.as_secs(),
398                        "Thinking stream timed out waiting for chunk"
399                    );
400                    return Err(NikaError::Timeout {
401                        operation: format!("thinking capture (task: {})", self.task_id),
402                        duration_ms: STREAM_CHUNK_TIMEOUT.as_millis() as u64,
403                    });
404                }
405            };
406
407            match chunk_result {
408                Ok(content) => match content {
409                    StreamedAssistantContent::Text(text) => {
410                        response_parts.push(text.text);
411                    }
412                    StreamedAssistantContent::ReasoningDelta { reasoning, .. } => {
413                        thinking_parts.push(reasoning);
414                    }
415                    StreamedAssistantContent::Reasoning(reasoning) => {
416                        // Final reasoning block - extract text from content blocks
417                        for block in reasoning.content {
418                            if let ReasoningContent::Text { text, .. } = block {
419                                thinking_parts.push(text);
420                            }
421                        }
422                    }
423                    StreamedAssistantContent::Final(final_resp) => {
424                        // Extract token usage from final response
425                        if let Some(usage) = final_resp.token_usage() {
426                            input_tokens = usage.input_tokens;
427                            output_tokens = usage.output_tokens;
428                            cached_input_tokens = usage.cached_input_tokens;
429                        }
430                    }
431                    _ => {
432                        // Tool calls and other events - handled by agent loop
433                        tracing::debug!("Streaming event: {:?}", content);
434                    }
435                },
436                Err(e) => {
437                    // Return error instead of silently swallowing - critical for debugging
438                    return Err(NikaError::ThinkingCaptureFailed {
439                        reason: format!(
440                            "Streaming chunk failed for task '{}': {}",
441                            self.task_id, e
442                        ),
443                    });
444                }
445            }
446        }
447
448        // Combine accumulated text
449        let thinking = if thinking_parts.is_empty() {
450            None
451        } else {
452            Some(thinking_parts.concat())
453        };
454        let response = response_parts.concat();
455
456        // Determine status
457        let status = self.determine_status(&response);
458
459        // Build metadata with thinking and token usage
460        let stop_reason = status.as_canonical_str();
461        let metadata = AgentTurnMetadata {
462            thinking,
463            response_text: response.clone(),
464            input_tokens,
465            output_tokens,
466            cache_read_tokens: cached_input_tokens,
467            stop_reason: stop_reason.to_string(),
468        };
469
470        // Emit completion event
471        self.event_log.emit(EventKind::AgentTurn {
472            task_id: Arc::from(self.task_id.as_str()),
473            turn_index: 1,
474            kind: stop_reason.to_string(),
475            metadata: Some(metadata),
476        });
477
478        // Check guardrails
479        let guardrail_result = self.check_guardrails(&response);
480        let guardrails_passed = guardrail_result.is_passed();
481
482        Ok(RigAgentLoopResult {
483            status: status.clone(),
484            turns: 1,
485            final_output: serde_json::json!({ "response": response }),
486            total_tokens: input_tokens + output_tokens,
487            confidence: status.confidence(),
488            retry_count: 0,
489            guardrails_passed,
490            cost_usd: crate::provider::cost::calculate_cost(
491                crate::provider::cost::ProviderKind::Claude,
492                model_name,
493                input_tokens,
494                output_tokens,
495            ),
496            partial_result: None,
497        })
498    }
499}