use std::sync::Arc;
use rig::agent::AgentBuilder;
use rig::client::{CompletionClient, ProviderClient};
use rig::completion::{Chat, CompletionModel};
use rig::message::Message;
use rig::providers::{anthropic, openai};
use serde_json;
use crate::error::NikaError;
use crate::event::{AgentTurnMetadata, EventKind};
use super::types::RigAgentLoopResult;
use super::RigAgentLoop;
impl RigAgentLoop {
/// Records one completed exchange in the conversation history.
///
/// Appends `user_prompt` as a user message followed by `assistant_response`
/// as an assistant message, then increments the turn counter by one.
pub fn add_to_history(&mut self, user_prompt: &str, assistant_response: &str) {
    let exchange = [
        Message::user(user_prompt),
        Message::assistant(assistant_response),
    ];
    // Order matters: the user message must precede the assistant reply.
    self.history.extend(exchange);
    self.turn_count += 1;
}
/// Appends a single, already-built message to the end of the history.
///
/// Unlike `add_to_history`, this leaves the turn counter untouched.
pub fn push_message(&mut self, message: Message) {
    let messages = &mut self.history;
    messages.push(message);
}
/// Resets the conversation: zeroes the turn counter and removes every
/// message from the history.
pub fn clear_history(&mut self) {
    self.turn_count = 0;
    // Truncating to zero length drops all stored messages.
    self.history.truncate(0);
}
pub fn history_len(&self) -> usize {
self.history.len()
}
/// Returns the current value of the turn counter.
pub fn turn_count(&self) -> u32 {
    let completed_turns: u32 = self.turn_count;
    completed_turns
}
/// Borrows the stored conversation messages as a slice, in insertion order.
pub fn history(&self) -> &[Message] {
    self.history.as_slice()
}
/// Builder-style setter: consumes `self`, replaces the stored history with
/// `history`, and returns the updated value.
///
/// Only the message list is replaced; the turn counter is not modified.
pub fn with_history(mut self, history: Vec<Message>) -> Self {
    // The previous history is discarded.
    drop(std::mem::replace(&mut self.history, history));
    self
}
pub async fn chat_continue(&mut self, prompt: &str) -> Result<RigAgentLoopResult, NikaError> {
let provider = self.params.provider.as_deref();
match provider {
Some(name) => {
let resolved = crate::core::find_provider(name).ok_or_else(|| {
NikaError::AgentValidationError {
reason: format!(
"Unknown provider: '{}'. Use 'claude', 'openai', 'mistral', 'groq', 'deepseek', 'gemini', or 'xai'.",
name
),
}
})?;
match resolved.id {
"anthropic" => self.chat_continue_claude(prompt).await,
"openai" => self.chat_continue_openai(prompt).await,
"mistral" => self.chat_continue_mistral(prompt).await,
"groq" => self.chat_continue_groq(prompt).await,
"deepseek" => self.chat_continue_deepseek(prompt).await,
"gemini" => self.chat_continue_gemini(prompt).await,
"xai" => self.chat_continue_xai(prompt).await,
other => Err(NikaError::AgentValidationError {
reason: format!("Provider '{}' is not supported for chat_continue.", other),
}),
}
}
None => {
let has_key = |key: &str| std::env::var(key).is_ok_and(|v| !v.trim().is_empty());
if has_key("ANTHROPIC_API_KEY") {
return self.chat_continue_claude(prompt).await;
}
if has_key("OPENAI_API_KEY") {
return self.chat_continue_openai(prompt).await;
}
if has_key("MISTRAL_API_KEY") {
return self.chat_continue_mistral(prompt).await;
}
if has_key("GROQ_API_KEY") {
return self.chat_continue_groq(prompt).await;
}
if has_key("DEEPSEEK_API_KEY") {
return self.chat_continue_deepseek(prompt).await;
}
if has_key("GEMINI_API_KEY") {
return self.chat_continue_gemini(prompt).await;
}
if has_key("XAI_API_KEY") {
return self.chat_continue_xai(prompt).await;
}
Err(NikaError::AgentValidationError {
reason: "chat_continue requires a configured provider or one of: ANTHROPIC_API_KEY, OPENAI_API_KEY, MISTRAL_API_KEY, GROQ_API_KEY, DEEPSEEK_API_KEY, GEMINI_API_KEY, or XAI_API_KEY".to_string(),
})
}
}
}
/// Continues the chat using an Anthropic completion model.
///
/// Resolves the configured model name, builds an Anthropic client with
/// `from_env()`, and delegates the turn to `chat_continue_with_model`.
///
/// # Errors
///
/// Fails if no model is configured or if the chat turn itself fails.
async fn chat_continue_claude(
    &mut self,
    prompt: &str,
) -> Result<RigAgentLoopResult, NikaError> {
    let model_id = self.resolve_model_name()?;
    let client = anthropic::Client::from_env();
    let completion_model = client.completion_model(&model_id);
    self.chat_continue_with_model(prompt, completion_model, &model_id)
        .await
}
/// Continues the chat using an OpenAI completion model.
///
/// Resolves the configured model name, builds an OpenAI client with
/// `from_env()`, and delegates the turn to `chat_continue_with_model`.
///
/// # Errors
///
/// Fails if no model is configured or if the chat turn itself fails.
async fn chat_continue_openai(
    &mut self,
    prompt: &str,
) -> Result<RigAgentLoopResult, NikaError> {
    let model_id = self.resolve_model_name()?;
    let client = openai::Client::from_env();
    let completion_model = client.completion_model(&model_id);
    self.chat_continue_with_model(prompt, completion_model, &model_id)
        .await
}
/// Continues the chat using a Mistral completion model.
///
/// Resolves the configured model name, builds a Mistral client with
/// `from_env()`, and delegates the turn to `chat_continue_with_model`.
///
/// # Errors
///
/// Fails if no model is configured or if the chat turn itself fails.
async fn chat_continue_mistral(
    &mut self,
    prompt: &str,
) -> Result<RigAgentLoopResult, NikaError> {
    let model_id = self.resolve_model_name()?;
    let client = rig::providers::mistral::Client::from_env();
    let completion_model = client.completion_model(&model_id);
    self.chat_continue_with_model(prompt, completion_model, &model_id)
        .await
}
/// Continues the chat using a Groq completion model.
///
/// Resolves the configured model name, builds a Groq client with
/// `from_env()`, and delegates the turn to `chat_continue_with_model`.
///
/// # Errors
///
/// Fails if no model is configured or if the chat turn itself fails.
async fn chat_continue_groq(&mut self, prompt: &str) -> Result<RigAgentLoopResult, NikaError> {
    let model_id = self.resolve_model_name()?;
    let client = rig::providers::groq::Client::from_env();
    let completion_model = client.completion_model(&model_id);
    self.chat_continue_with_model(prompt, completion_model, &model_id)
        .await
}
/// Continues the chat using a DeepSeek completion model.
///
/// Resolves the configured model name, builds a DeepSeek client with
/// `from_env()`, and delegates the turn to `chat_continue_with_model`.
///
/// # Errors
///
/// Fails if no model is configured or if the chat turn itself fails.
async fn chat_continue_deepseek(
    &mut self,
    prompt: &str,
) -> Result<RigAgentLoopResult, NikaError> {
    let model_id = self.resolve_model_name()?;
    let client = rig::providers::deepseek::Client::from_env();
    let completion_model = client.completion_model(&model_id);
    self.chat_continue_with_model(prompt, completion_model, &model_id)
        .await
}
/// Continues the chat using a Gemini completion model.
///
/// Resolves the configured model name, builds a Gemini client with
/// `from_env()`, and delegates the turn to `chat_continue_with_model`.
///
/// # Errors
///
/// Fails if no model is configured or if the chat turn itself fails.
async fn chat_continue_gemini(
    &mut self,
    prompt: &str,
) -> Result<RigAgentLoopResult, NikaError> {
    let model_id = self.resolve_model_name()?;
    let client = rig::providers::gemini::Client::from_env();
    let completion_model = client.completion_model(&model_id);
    self.chat_continue_with_model(prompt, completion_model, &model_id)
        .await
}
/// Continues the chat using an xAI completion model.
///
/// Resolves the configured model name, builds an xAI client with
/// `from_env()`, and delegates the turn to `chat_continue_with_model`.
///
/// # Errors
///
/// Fails if no model is configured or if the chat turn itself fails.
async fn chat_continue_xai(&mut self, prompt: &str) -> Result<RigAgentLoopResult, NikaError> {
    let model_id = self.resolve_model_name()?;
    let client = rig::providers::xai::Client::from_env();
    let completion_model = client.completion_model(&model_id);
    self.chat_continue_with_model(prompt, completion_model, &model_id)
        .await
}
/// Returns the configured model name after passing it through
/// `strip_model_prefix`.
///
/// # Errors
///
/// Returns `NikaError::ValidationError` when `params.model` is not set.
fn resolve_model_name(&self) -> Result<String, NikaError> {
    let Some(configured) = self.params.model.as_deref() else {
        return Err(NikaError::ValidationError {
            reason: "model field is required for LLM verbs (NIKA-034)".to_string(),
        });
    };
    let stripped = Self::strip_model_prefix(configured);
    Ok(stripped.to_string())
}
/// Runs one chat turn against `model` and records it in the conversation.
///
/// In order, this:
/// 1. builds the preamble via `inject_skills_into_prompt` and emits an
///    `AgentTurn` event of kind `"started"`;
/// 2. configures an agent (max tokens, optional temperature, optional tool
///    choice, optional stop-sequence parameters, tools);
/// 3. sends `prompt` together with a clone of the current history;
/// 4. on success, appends the prompt and the reply to the history, increments
///    the turn counter, and emits a second `AgentTurn` event whose kind is the
///    canonical status string;
/// 5. runs the guardrail check, estimates token usage and cost, and returns
///    the result.
///
/// Token counts are estimates: one token per four characters (rounded up) of
/// `prompt` and of the reply only.
///
/// # Errors
///
/// Propagates any error from `inject_skills_into_prompt`. A failed chat call
/// is returned as `NikaError::AgentExecutionError`; in that case the history
/// and the turn counter are left unchanged.
async fn chat_continue_with_model<M: CompletionModel>(
    &mut self,
    prompt: &str,
    model: M,
    model_name: &str,
) -> Result<RigAgentLoopResult, NikaError> {
    let this_turn = self.turn_count + 1;
    let system_preamble = self.inject_skills_into_prompt().await?;

    let started_event = EventKind::AgentTurn {
        task_id: Arc::from(self.task_id.as_str()),
        turn_index: this_turn,
        kind: String::from("started"),
        metadata: None,
    };
    self.event_log.emit(started_event);

    // Assemble the agent; 8192 is the fallback when no max-token value is set.
    let max_tokens = self.params.effective_max_tokens().unwrap_or(8192) as u64;
    let mut agent_builder = AgentBuilder::new(model)
        .preamble(&system_preamble)
        .max_tokens(max_tokens);
    if let Some(temperature) = self.params.effective_temperature() {
        agent_builder = agent_builder.temperature(f64::from(temperature));
    }
    if self.params.has_explicit_tool_choice() {
        agent_builder = agent_builder.tool_choice(self.params.effective_tool_choice().into());
    }
    // NOTE(review): when `params.provider` is unset, an empty provider string
    // is passed here and to the cost lookup below — confirm that is intended
    // for the env-auto-detected case.
    let stop_params = Self::stop_sequences_params(
        self.params.provider.as_deref().unwrap_or(""),
        &self.params.stop_sequences,
    );
    if let Some(extra_params) = stop_params {
        agent_builder = agent_builder.additional_params(extra_params);
    }
    let agent = agent_builder.tools(self.tools_as_boxed()).build();

    // The agent receives a copy of the history; the stored history is only
    // extended after the call succeeds.
    let reply = match agent.chat(prompt, self.history.clone()).await {
        Ok(text) => text,
        Err(err) => {
            return Err(NikaError::AgentExecutionError {
                task_id: self.task_id.clone(),
                reason: err.to_string(),
            });
        }
    };

    self.history
        .extend([Message::user(prompt), Message::assistant(&reply)]);
    self.turn_count += 1;

    let status = self.determine_status(&reply);
    let stop_reason = status.as_canonical_str();
    let metadata = AgentTurnMetadata::text_only(&reply, stop_reason);
    let finished_event = EventKind::AgentTurn {
        task_id: Arc::from(self.task_id.as_str()),
        turn_index: this_turn,
        kind: stop_reason.to_string(),
        metadata: Some(metadata),
    };
    self.event_log.emit(finished_event);

    let guardrails_passed = self.check_guardrails(&reply).is_passed();

    // Rough usage estimate: ceil(chars / 4) for the prompt and for the reply.
    let prompt_tokens = prompt.chars().count().div_ceil(4) as u64;
    let reply_tokens = reply.chars().count().div_ceil(4) as u64;
    // Cost is 0.0 when the provider string cannot be parsed.
    let cost_usd = crate::provider::cost::ProviderKind::parse(
        self.params.provider.as_deref().unwrap_or(""),
    )
    .map_or(0.0, |kind| {
        crate::provider::cost::calculate_cost(kind, model_name, prompt_tokens, reply_tokens)
    });

    Ok(RigAgentLoopResult {
        status: status.clone(),
        turns: this_turn as usize,
        final_output: serde_json::json!({ "response": reply }),
        total_tokens: prompt_tokens + reply_tokens,
        confidence: status.confidence(),
        retry_count: 0,
        guardrails_passed,
        cost_usd,
        partial_result: None,
    })
}
}