nika_engine/runtime/rig_agent_loop/
chat.rs1use std::sync::Arc;
7
8use rig::agent::AgentBuilder;
9use rig::client::{CompletionClient, ProviderClient};
10use rig::completion::{Chat, CompletionModel};
11use rig::message::Message;
12use rig::providers::{anthropic, openai};
13use serde_json;
14
15use crate::error::NikaError;
16use crate::event::{AgentTurnMetadata, EventKind};
17
18use super::types::RigAgentLoopResult;
19use super::RigAgentLoop;
20
21impl RigAgentLoop {
22 pub fn add_to_history(&mut self, user_prompt: &str, assistant_response: &str) {
30 self.history.push(Message::user(user_prompt));
31 self.history.push(Message::assistant(assistant_response));
32 self.turn_count += 1;
33 }
34
35 pub fn push_message(&mut self, message: Message) {
37 self.history.push(message);
38 }
39
40 pub fn clear_history(&mut self) {
42 self.history.clear();
43 self.turn_count = 0;
44 }
45
46 pub fn history_len(&self) -> usize {
48 self.history.len()
49 }
50
51 pub fn turn_count(&self) -> u32 {
53 self.turn_count
54 }
55
56 pub fn history(&self) -> &[Message] {
58 &self.history
59 }
60
61 pub fn with_history(mut self, history: Vec<Message>) -> Self {
65 self.history = history;
66 self
67 }
68
69 pub async fn chat_continue(&mut self, prompt: &str) -> Result<RigAgentLoopResult, NikaError> {
85 let provider = self.params.provider.as_deref();
87 match provider {
88 Some(name) => {
89 let resolved = crate::core::find_provider(name).ok_or_else(|| {
91 NikaError::AgentValidationError {
92 reason: format!(
93 "Unknown provider: '{}'. Use 'claude', 'openai', 'mistral', 'groq', 'deepseek', 'gemini', or 'xai'.",
94 name
95 ),
96 }
97 })?;
98 match resolved.id {
99 "anthropic" => self.chat_continue_claude(prompt).await,
100 "openai" => self.chat_continue_openai(prompt).await,
101 "mistral" => self.chat_continue_mistral(prompt).await,
102 "groq" => self.chat_continue_groq(prompt).await,
103 "deepseek" => self.chat_continue_deepseek(prompt).await,
104 "gemini" => self.chat_continue_gemini(prompt).await,
105 "xai" => self.chat_continue_xai(prompt).await,
106 other => Err(NikaError::AgentValidationError {
107 reason: format!("Provider '{}' is not supported for chat_continue.", other),
108 }),
109 }
110 }
111 None => {
112 let has_key = |key: &str| std::env::var(key).is_ok_and(|v| !v.trim().is_empty());
114
115 if has_key("ANTHROPIC_API_KEY") {
116 return self.chat_continue_claude(prompt).await;
117 }
118 if has_key("OPENAI_API_KEY") {
119 return self.chat_continue_openai(prompt).await;
120 }
121 if has_key("MISTRAL_API_KEY") {
122 return self.chat_continue_mistral(prompt).await;
123 }
124 if has_key("GROQ_API_KEY") {
125 return self.chat_continue_groq(prompt).await;
126 }
127 if has_key("DEEPSEEK_API_KEY") {
128 return self.chat_continue_deepseek(prompt).await;
129 }
130 if has_key("GEMINI_API_KEY") {
131 return self.chat_continue_gemini(prompt).await;
132 }
133 if has_key("XAI_API_KEY") {
134 return self.chat_continue_xai(prompt).await;
135 }
136 Err(NikaError::AgentValidationError {
137 reason: "chat_continue requires a configured provider or one of: ANTHROPIC_API_KEY, OPENAI_API_KEY, MISTRAL_API_KEY, GROQ_API_KEY, DEEPSEEK_API_KEY, GEMINI_API_KEY, or XAI_API_KEY".to_string(),
138 })
139 }
140 }
141 }
142
143 async fn chat_continue_claude(
152 &mut self,
153 prompt: &str,
154 ) -> Result<RigAgentLoopResult, NikaError> {
155 let model_name = self.resolve_model_name()?;
156 let model = anthropic::Client::from_env().completion_model(&model_name);
157 self.chat_continue_with_model(prompt, model, &model_name)
158 .await
159 }
160
161 async fn chat_continue_openai(
162 &mut self,
163 prompt: &str,
164 ) -> Result<RigAgentLoopResult, NikaError> {
165 let model_name = self.resolve_model_name()?;
166 let model = openai::Client::from_env().completion_model(&model_name);
167 self.chat_continue_with_model(prompt, model, &model_name)
168 .await
169 }
170
171 async fn chat_continue_mistral(
172 &mut self,
173 prompt: &str,
174 ) -> Result<RigAgentLoopResult, NikaError> {
175 let model_name = self.resolve_model_name()?;
176 let model = rig::providers::mistral::Client::from_env().completion_model(&model_name);
177 self.chat_continue_with_model(prompt, model, &model_name)
178 .await
179 }
180
181 async fn chat_continue_groq(&mut self, prompt: &str) -> Result<RigAgentLoopResult, NikaError> {
182 let model_name = self.resolve_model_name()?;
183 let model = rig::providers::groq::Client::from_env().completion_model(&model_name);
184 self.chat_continue_with_model(prompt, model, &model_name)
185 .await
186 }
187
188 async fn chat_continue_deepseek(
189 &mut self,
190 prompt: &str,
191 ) -> Result<RigAgentLoopResult, NikaError> {
192 let model_name = self.resolve_model_name()?;
193 let model = rig::providers::deepseek::Client::from_env().completion_model(&model_name);
194 self.chat_continue_with_model(prompt, model, &model_name)
195 .await
196 }
197
198 async fn chat_continue_gemini(
199 &mut self,
200 prompt: &str,
201 ) -> Result<RigAgentLoopResult, NikaError> {
202 let model_name = self.resolve_model_name()?;
203 let model = rig::providers::gemini::Client::from_env().completion_model(&model_name);
204 self.chat_continue_with_model(prompt, model, &model_name)
205 .await
206 }
207
208 async fn chat_continue_xai(&mut self, prompt: &str) -> Result<RigAgentLoopResult, NikaError> {
209 let model_name = self.resolve_model_name()?;
210 let model = rig::providers::xai::Client::from_env().completion_model(&model_name);
211 self.chat_continue_with_model(prompt, model, &model_name)
212 .await
213 }
214
215 fn resolve_model_name(&self) -> Result<String, NikaError> {
224 let raw = self
225 .params
226 .model
227 .as_deref()
228 .ok_or_else(|| NikaError::ValidationError {
229 reason: "model field is required for LLM verbs (NIKA-034)".to_string(),
230 })?;
231 Ok(Self::strip_model_prefix(raw).to_string())
232 }
233
234 async fn chat_continue_with_model<M: CompletionModel>(
244 &mut self,
245 prompt: &str,
246 model: M,
247 model_name: &str,
248 ) -> Result<RigAgentLoopResult, NikaError> {
249 let turn_index = self.turn_count + 1;
250
251 let preamble = self.inject_skills_into_prompt().await?;
253
254 self.event_log.emit(EventKind::AgentTurn {
256 task_id: Arc::from(self.task_id.as_str()),
257 turn_index,
258 kind: "started".to_string(),
259 metadata: None,
260 });
261
262 let effective_max_tokens = self.params.effective_max_tokens().unwrap_or(8192) as u64;
264 let mut builder = AgentBuilder::new(model)
265 .preamble(&preamble)
266 .max_tokens(effective_max_tokens);
267
268 if let Some(temp) = self.params.effective_temperature() {
269 builder = builder.temperature(f64::from(temp));
270 }
271
272 if self.params.has_explicit_tool_choice() {
273 let tool_choice = self.params.effective_tool_choice();
274 builder = builder.tool_choice(tool_choice.into());
275 }
276
277 if let Some(stop_params) = Self::stop_sequences_params(
278 self.params.provider.as_deref().unwrap_or(""),
279 &self.params.stop_sequences,
280 ) {
281 builder = builder.additional_params(stop_params);
282 }
283
284 let tools = self.tools_as_boxed();
285 let agent = builder.tools(tools).build();
286
287 let response = agent
288 .chat(prompt, self.history.clone())
289 .await
290 .map_err(|e| NikaError::AgentExecutionError {
291 task_id: self.task_id.clone(),
292 reason: e.to_string(),
293 })?;
294
295 self.history.push(Message::user(prompt));
297 self.history.push(Message::assistant(&response));
298 self.turn_count += 1;
299
300 let status = self.determine_status(&response);
302
303 let stop_reason = status.as_canonical_str();
305 let metadata = AgentTurnMetadata::text_only(&response, stop_reason);
306
307 self.event_log.emit(EventKind::AgentTurn {
308 task_id: Arc::from(self.task_id.as_str()),
309 turn_index,
310 kind: stop_reason.to_string(),
311 metadata: Some(metadata),
312 });
313
314 let guardrail_result = self.check_guardrails(&response);
316 let guardrails_passed = guardrail_result.is_passed();
317
318 let est_input = prompt.chars().count().div_ceil(4) as u64;
320 let est_output = response.chars().count().div_ceil(4) as u64;
321 let provider_kind = crate::provider::cost::ProviderKind::parse(
322 self.params.provider.as_deref().unwrap_or(""),
323 );
324 let cost = provider_kind
325 .map(|pk| crate::provider::cost::calculate_cost(pk, model_name, est_input, est_output))
326 .unwrap_or(0.0);
327
328 Ok(RigAgentLoopResult {
329 status: status.clone(),
330 turns: turn_index as usize,
331 final_output: serde_json::json!({ "response": response }),
332 total_tokens: est_input + est_output,
333 confidence: status.confidence(),
334 retry_count: 0,
335 guardrails_passed,
336 cost_usd: cost,
337 partial_result: None,
338 })
339 }
340}