nika_engine/runtime/rig_agent_loop/
thinking.rs1use std::sync::Arc;
7
8use futures::StreamExt;
9use rig::client::{CompletionClient, ProviderClient};
10use rig::completion::CompletionModel as _;
11use rig::completion::GetTokenUsage;
12use rig::message::ReasoningContent;
13use rig::providers::anthropic;
14use rig::streaming::StreamedAssistantContent;
15use serde_json;
16use tokio::time::timeout;
17
18use crate::ast::guardrails::{escalation_required, immediate_failures, run_sync_guardrails};
19use crate::error::NikaError;
20use crate::event::{AgentTurnMetadata, EventKind};
21use crate::util::STREAM_CHUNK_TIMEOUT;
22
23use super::types::{GuardrailCheckResult, RigAgentLoopResult, RigAgentStatus};
24use super::RigAgentLoop;
25
26impl RigAgentLoop {
27 pub(crate) fn check_completion_signal(&self, output: &str) -> bool {
32 use crate::runtime::builtin::COMPLETION_MARKER;
33 output.contains(COMPLETION_MARKER)
34 }
35
36 pub fn check_guardrails(&self, output: &str) -> GuardrailCheckResult {
51 if self.params.guardrails.is_empty() {
52 return GuardrailCheckResult::AllPassed;
53 }
54
55 let results = run_sync_guardrails(&self.params.guardrails, output);
56 let mut all_passed = true;
57
58 let task_id: Arc<str> = Arc::from(self.task_id.as_str());
61
62 for result in &results {
64 if result.passed {
65 self.event_log.emit(EventKind::GuardrailPassed {
66 task_id: Arc::clone(&task_id),
67 guardrail_type: result.guardrail_type.clone(),
68 description: result.guardrail_id.clone(),
69 });
70 } else {
71 self.event_log.emit(EventKind::GuardrailFailed {
72 task_id: Arc::clone(&task_id),
73 guardrail_type: result.guardrail_type.clone(),
74 description: result.guardrail_id.clone(),
75 message: result
76 .message
77 .clone()
78 .unwrap_or_else(|| "Guardrail check failed".to_string()),
79 });
80 all_passed = false;
81 }
82 }
83
84 if all_passed {
85 return GuardrailCheckResult::AllPassed;
86 }
87
88 let immediate = immediate_failures(&results);
90 if !immediate.is_empty() {
91 return GuardrailCheckResult::FailedImmediate;
92 }
93
94 let escalations = escalation_required(&results);
96 if !escalations.is_empty() {
97 for result in escalations {
99 self.event_log.emit(EventKind::GuardrailEscalation {
100 task_id: Arc::clone(&task_id),
101 guardrail_type: result.guardrail_type.clone(),
102 guardrail_id: result.guardrail_id.clone(),
103 message: result
104 .message
105 .clone()
106 .unwrap_or_else(|| "Guardrail requires escalation".to_string()),
107 severity: "high".to_string(),
108 suggested_action: Some("Review agent output and provide guidance".to_string()),
109 });
110 }
111 return GuardrailCheckResult::FailedEscalate;
112 }
113
114 let failure_messages: Vec<String> = results
117 .iter()
118 .filter(|r| !r.passed)
119 .map(|r| {
120 r.message
121 .clone()
122 .unwrap_or_else(|| format!("Guardrail '{}' failed", r.guardrail_id))
123 })
124 .collect();
125 GuardrailCheckResult::FailedRetry(failure_messages)
126 }
127
128 pub fn determine_status(&self, output: &str) -> RigAgentStatus {
139 if self.check_completion_signal(output) {
140 use crate::runtime::builtin::parse_completion_response;
142
143 if let Some(response) = parse_completion_response(output) {
144 if let Some(confidence) = response.confidence {
146 return self.apply_routing(confidence);
148 }
149 }
150 return RigAgentStatus::ExplicitCompletion;
152 }
153
154 if let Some(ref completion_config) = self.params.completion {
156 if completion_config.check_pattern_match(output) {
157 return RigAgentStatus::ExplicitCompletion;
158 }
159 }
160
161 if let Some(ref completion_config) = self.params.completion {
165 use crate::ast::completion::CompletionMode;
166 if completion_config.mode == CompletionMode::Explicit {
167 tracing::debug!(
168 task_id = %self.task_id,
169 "Agent ended turn without calling nika:complete (mode: explicit)"
170 );
171 return RigAgentStatus::LowConfidence(0.0);
172 }
173 }
174
175 RigAgentStatus::NaturalCompletion
176 }
177
178 pub(crate) fn get_confidence_threshold(&self) -> f64 {
182 self.params
183 .effective_completion()
184 .and_then(|c| c.confidence)
185 .map(|conf| conf.threshold)
186 .unwrap_or(0.8)
187 }
188
189 pub(super) fn get_low_confidence_config(
193 &self,
194 ) -> Option<crate::ast::completion::OnLowConfidenceConfig> {
195 self.params
196 .effective_completion()
197 .and_then(|c| c.confidence)
198 .map(|conf| conf.on_low.clone())
199 }
200
201 pub(super) fn should_retry(&self, status: &RigAgentStatus, retry_count: u32) -> bool {
208 if !matches!(status, RigAgentStatus::LowConfidence(_)) {
209 return false;
210 }
211
212 let Some(config) = self.get_low_confidence_config() else {
213 return false;
214 };
215
216 config.action == crate::ast::completion::LowConfidenceAction::Retry
217 && retry_count < config.max_retries
218 }
219
220 pub(super) fn get_retry_feedback(&self, confidence: f64) -> String {
224 let config = self.get_low_confidence_config();
225 let threshold = self.get_confidence_threshold();
226
227 if let Some(feedback) = config.as_ref().and_then(|c| c.feedback.clone()) {
229 return format!(
230 "\n\n[RETRY: Your previous response had confidence {:.2}, below threshold {:.2}. {}]",
231 confidence, threshold, feedback
232 );
233 }
234
235 format!(
237 "\n\n[RETRY: Your previous response had confidence {:.2}, which is below the required threshold of {:.2}. Please reconsider your response and provide a higher confidence answer.]",
238 confidence, threshold
239 )
240 }
241
242 pub(crate) fn get_confidence_routing(
246 &self,
247 ) -> Option<crate::ast::completion::ConfidenceRouting> {
248 self.params
249 .effective_completion()
250 .and_then(|c| c.confidence)
251 .and_then(|conf| conf.routing.clone())
252 }
253
254 pub(crate) fn apply_routing(&self, confidence: f64) -> RigAgentStatus {
260 let Some(routing) = self.get_confidence_routing() else {
261 let threshold = self.get_confidence_threshold();
263 return if confidence >= threshold {
264 RigAgentStatus::HighConfidence(confidence)
265 } else {
266 RigAgentStatus::LowConfidence(confidence)
267 };
268 };
269
270 if let Some(high_min) = routing.high.min {
273 if confidence >= high_min {
274 return self.route_action_to_status(&routing.high.action, confidence);
275 }
276 }
277
278 if let Some(medium_min) = routing.medium.min {
280 if confidence >= medium_min {
281 return self.route_action_to_status(&routing.medium.action, confidence);
282 }
283 }
284
285 self.route_action_to_status(&routing.low.action, confidence)
287 }
288
289 pub(crate) fn route_action_to_status(
291 &self,
292 action: &crate::ast::completion::RouteAction,
293 confidence: f64,
294 ) -> RigAgentStatus {
295 use crate::ast::completion::RouteAction;
296
297 match action {
298 RouteAction::Accept => RigAgentStatus::HighConfidence(confidence),
299 RouteAction::AcceptWithFlag => RigAgentStatus::FlaggedForReview(confidence),
300 RouteAction::Retry => RigAgentStatus::LowConfidence(confidence),
301 RouteAction::Escalate => RigAgentStatus::Escalated(confidence),
302 }
303 }
304
305 pub async fn run_claude_with_thinking(&mut self) -> Result<RigAgentLoopResult, NikaError> {
315 let client = anthropic::Client::from_env();
317
318 let model_name =
320 Self::strip_model_prefix(self.params.model.as_deref().ok_or_else(|| {
321 NikaError::ValidationError {
322 reason: "model field is required for LLM verbs (NIKA-034)".to_string(),
323 }
324 })?);
325 let model = client.completion_model(model_name);
326
327 let thinking_budget = self.params.effective_thinking_budget();
330
331 let thinking_config = serde_json::json!({
333 "thinking": {
334 "type": "enabled",
335 "budget_tokens": thinking_budget
336 }
337 });
338
339 let preamble = self.inject_skills_into_prompt().await?;
342
343 let max_tokens = self
346 .params
347 .effective_max_tokens()
348 .unwrap_or((thinking_budget as u32) + 8192);
349
350 let mut request_builder = model
351 .completion_request(&self.params.prompt)
352 .preamble(preamble)
353 .max_tokens(max_tokens as u64)
354 .additional_params(thinking_config);
355
356 if let Some(temp) = self.params.effective_temperature() {
358 request_builder = request_builder.temperature(f64::from(temp));
359 }
360
361 let request = request_builder.build();
362
363 self.event_log.emit(EventKind::AgentTurn {
365 task_id: Arc::from(self.task_id.as_str()),
366 turn_index: 1,
367 kind: "started".to_string(),
368 metadata: None,
369 });
370
371 let mut stream =
373 model
374 .stream(request)
375 .await
376 .map_err(|e| NikaError::AgentExecutionError {
377 task_id: self.task_id.clone(),
378 reason: format!("Streaming request failed: {}", e),
379 })?;
380
381 let mut thinking_parts: Vec<String> = Vec::new();
383 let mut response_parts: Vec<String> = Vec::new();
384 let mut input_tokens: u64 = 0;
385 let mut output_tokens: u64 = 0;
386 let mut cached_input_tokens: u64 = 0;
387
388 loop {
390 let chunk_result = match timeout(STREAM_CHUNK_TIMEOUT, stream.next()).await {
391 Ok(Some(chunk)) => chunk,
392 Ok(None) => break, Err(_elapsed) => {
394 tracing::warn!(
396 task_id = %self.task_id,
397 timeout_secs = STREAM_CHUNK_TIMEOUT.as_secs(),
398 "Thinking stream timed out waiting for chunk"
399 );
400 return Err(NikaError::Timeout {
401 operation: format!("thinking capture (task: {})", self.task_id),
402 duration_ms: STREAM_CHUNK_TIMEOUT.as_millis() as u64,
403 });
404 }
405 };
406
407 match chunk_result {
408 Ok(content) => match content {
409 StreamedAssistantContent::Text(text) => {
410 response_parts.push(text.text);
411 }
412 StreamedAssistantContent::ReasoningDelta { reasoning, .. } => {
413 thinking_parts.push(reasoning);
414 }
415 StreamedAssistantContent::Reasoning(reasoning) => {
416 for block in reasoning.content {
418 if let ReasoningContent::Text { text, .. } = block {
419 thinking_parts.push(text);
420 }
421 }
422 }
423 StreamedAssistantContent::Final(final_resp) => {
424 if let Some(usage) = final_resp.token_usage() {
426 input_tokens = usage.input_tokens;
427 output_tokens = usage.output_tokens;
428 cached_input_tokens = usage.cached_input_tokens;
429 }
430 }
431 _ => {
432 tracing::debug!("Streaming event: {:?}", content);
434 }
435 },
436 Err(e) => {
437 return Err(NikaError::ThinkingCaptureFailed {
439 reason: format!(
440 "Streaming chunk failed for task '{}': {}",
441 self.task_id, e
442 ),
443 });
444 }
445 }
446 }
447
448 let thinking = if thinking_parts.is_empty() {
450 None
451 } else {
452 Some(thinking_parts.concat())
453 };
454 let response = response_parts.concat();
455
456 let status = self.determine_status(&response);
458
459 let stop_reason = status.as_canonical_str();
461 let metadata = AgentTurnMetadata {
462 thinking,
463 response_text: response.clone(),
464 input_tokens,
465 output_tokens,
466 cache_read_tokens: cached_input_tokens,
467 stop_reason: stop_reason.to_string(),
468 };
469
470 self.event_log.emit(EventKind::AgentTurn {
472 task_id: Arc::from(self.task_id.as_str()),
473 turn_index: 1,
474 kind: stop_reason.to_string(),
475 metadata: Some(metadata),
476 });
477
478 let guardrail_result = self.check_guardrails(&response);
480 let guardrails_passed = guardrail_result.is_passed();
481
482 Ok(RigAgentLoopResult {
483 status: status.clone(),
484 turns: 1,
485 final_output: serde_json::json!({ "response": response }),
486 total_tokens: input_tokens + output_tokens,
487 confidence: status.confidence(),
488 retry_count: 0,
489 guardrails_passed,
490 cost_usd: crate::provider::cost::calculate_cost(
491 crate::provider::cost::ProviderKind::Claude,
492 model_name,
493 input_tokens,
494 output_tokens,
495 ),
496 partial_result: None,
497 })
498 }
499}