Skip to main content

nika_engine/runtime/rig_agent_loop/
chat.rs

//! Chat history management and multi-turn conversation support
//!
//! Provides methods for managing conversation history and continuing
//! multi-turn conversations with different LLM providers.
5
6use std::sync::Arc;
7
8use rig::agent::AgentBuilder;
9use rig::client::{CompletionClient, ProviderClient};
10use rig::completion::{Chat, CompletionModel};
11use rig::message::Message;
12use rig::providers::{anthropic, openai};
13use serde_json;
14
15use crate::error::NikaError;
16use crate::event::{AgentTurnMetadata, EventKind};
17
18use super::types::RigAgentLoopResult;
19use super::RigAgentLoop;
20
21impl RigAgentLoop {
22    // =========================================================================
23    // Chat History Management
24    // =========================================================================
25
26    /// Add a user/assistant turn to the conversation history
27    ///
28    /// Call this after each completed turn to maintain context for `chat_continue()`.
29    pub fn add_to_history(&mut self, user_prompt: &str, assistant_response: &str) {
30        self.history.push(Message::user(user_prompt));
31        self.history.push(Message::assistant(assistant_response));
32        self.turn_count += 1;
33    }
34
35    /// Add a single message to the history
36    pub fn push_message(&mut self, message: Message) {
37        self.history.push(message);
38    }
39
40    /// Clear all conversation history and reset turn count
41    pub fn clear_history(&mut self) {
42        self.history.clear();
43        self.turn_count = 0;
44    }
45
46    /// Get the current history length (number of messages)
47    pub fn history_len(&self) -> usize {
48        self.history.len()
49    }
50
51    /// Get the number of completed turns (user + assistant exchanges).
52    pub fn turn_count(&self) -> u32 {
53        self.turn_count
54    }
55
56    /// Get a reference to the conversation history
57    pub fn history(&self) -> &[Message] {
58        &self.history
59    }
60
61    /// Create with pre-existing history
62    ///
63    /// Useful for resuming conversations or injecting context.
64    pub fn with_history(mut self, history: Vec<Message>) -> Self {
65        self.history = history;
66        self
67    }
68
69    /// Continue a conversation using the accumulated history
70    ///
71    /// Uses rig-core's `Chat` trait for multi-turn conversations.
72    /// The history is automatically updated with the user prompt and response.
73    ///
74    /// # Example
75    /// ```rust,ignore
76    /// // First turn
77    /// let result1 = agent.run_claude().await?;
78    /// agent.add_to_history("Initial prompt", &extract_text(&result1));
79    ///
80    /// // Continue conversation
81    /// let result2 = agent.chat_continue("Follow-up question").await?;
82    /// // History now contains both turns
83    /// ```
84    pub async fn chat_continue(&mut self, prompt: &str) -> Result<RigAgentLoopResult, NikaError> {
85        // Use configured provider first, fallback to env-var auto-detection
86        let provider = self.params.provider.as_deref();
87        match provider {
88            Some(name) => {
89                // Resolve alias to canonical provider ID via core catalog
90                let resolved = crate::core::find_provider(name).ok_or_else(|| {
91                    NikaError::AgentValidationError {
92                        reason: format!(
93                            "Unknown provider: '{}'. Use 'claude', 'openai', 'mistral', 'groq', 'deepseek', 'gemini', or 'xai'.",
94                            name
95                        ),
96                    }
97                })?;
98                match resolved.id {
99                    "anthropic" => self.chat_continue_claude(prompt).await,
100                    "openai" => self.chat_continue_openai(prompt).await,
101                    "mistral" => self.chat_continue_mistral(prompt).await,
102                    "groq" => self.chat_continue_groq(prompt).await,
103                    "deepseek" => self.chat_continue_deepseek(prompt).await,
104                    "gemini" => self.chat_continue_gemini(prompt).await,
105                    "xai" => self.chat_continue_xai(prompt).await,
106                    other => Err(NikaError::AgentValidationError {
107                        reason: format!("Provider '{}' is not supported for chat_continue.", other),
108                    }),
109                }
110            }
111            None => {
112                // Auto-detect: check env vars in priority order
113                let has_key = |key: &str| std::env::var(key).is_ok_and(|v| !v.trim().is_empty());
114
115                if has_key("ANTHROPIC_API_KEY") {
116                    return self.chat_continue_claude(prompt).await;
117                }
118                if has_key("OPENAI_API_KEY") {
119                    return self.chat_continue_openai(prompt).await;
120                }
121                if has_key("MISTRAL_API_KEY") {
122                    return self.chat_continue_mistral(prompt).await;
123                }
124                if has_key("GROQ_API_KEY") {
125                    return self.chat_continue_groq(prompt).await;
126                }
127                if has_key("DEEPSEEK_API_KEY") {
128                    return self.chat_continue_deepseek(prompt).await;
129                }
130                if has_key("GEMINI_API_KEY") {
131                    return self.chat_continue_gemini(prompt).await;
132                }
133                if has_key("XAI_API_KEY") {
134                    return self.chat_continue_xai(prompt).await;
135                }
136                Err(NikaError::AgentValidationError {
137                    reason: "chat_continue requires a configured provider or one of: ANTHROPIC_API_KEY, OPENAI_API_KEY, MISTRAL_API_KEY, GROQ_API_KEY, DEEPSEEK_API_KEY, GEMINI_API_KEY, or XAI_API_KEY".to_string(),
138                })
139            }
140        }
141    }
142
143    // =========================================================================
144    // Provider-specific chat_continue wrappers
145    // =========================================================================
146    //
147    // Each wrapper creates a provider client and delegates to the generic
148    // `chat_continue_with_model()`. Adding a new provider requires only
149    // a 4-line wrapper + a match arm in `chat_continue()`.
150
151    async fn chat_continue_claude(
152        &mut self,
153        prompt: &str,
154    ) -> Result<RigAgentLoopResult, NikaError> {
155        let model_name = self.resolve_model_name()?;
156        let model = anthropic::Client::from_env().completion_model(&model_name);
157        self.chat_continue_with_model(prompt, model, &model_name)
158            .await
159    }
160
161    async fn chat_continue_openai(
162        &mut self,
163        prompt: &str,
164    ) -> Result<RigAgentLoopResult, NikaError> {
165        let model_name = self.resolve_model_name()?;
166        let model = openai::Client::from_env().completion_model(&model_name);
167        self.chat_continue_with_model(prompt, model, &model_name)
168            .await
169    }
170
171    async fn chat_continue_mistral(
172        &mut self,
173        prompt: &str,
174    ) -> Result<RigAgentLoopResult, NikaError> {
175        let model_name = self.resolve_model_name()?;
176        let model = rig::providers::mistral::Client::from_env().completion_model(&model_name);
177        self.chat_continue_with_model(prompt, model, &model_name)
178            .await
179    }
180
181    async fn chat_continue_groq(&mut self, prompt: &str) -> Result<RigAgentLoopResult, NikaError> {
182        let model_name = self.resolve_model_name()?;
183        let model = rig::providers::groq::Client::from_env().completion_model(&model_name);
184        self.chat_continue_with_model(prompt, model, &model_name)
185            .await
186    }
187
188    async fn chat_continue_deepseek(
189        &mut self,
190        prompt: &str,
191    ) -> Result<RigAgentLoopResult, NikaError> {
192        let model_name = self.resolve_model_name()?;
193        let model = rig::providers::deepseek::Client::from_env().completion_model(&model_name);
194        self.chat_continue_with_model(prompt, model, &model_name)
195            .await
196    }
197
198    async fn chat_continue_gemini(
199        &mut self,
200        prompt: &str,
201    ) -> Result<RigAgentLoopResult, NikaError> {
202        let model_name = self.resolve_model_name()?;
203        let model = rig::providers::gemini::Client::from_env().completion_model(&model_name);
204        self.chat_continue_with_model(prompt, model, &model_name)
205            .await
206    }
207
208    async fn chat_continue_xai(&mut self, prompt: &str) -> Result<RigAgentLoopResult, NikaError> {
209        let model_name = self.resolve_model_name()?;
210        let model = rig::providers::xai::Client::from_env().completion_model(&model_name);
211        self.chat_continue_with_model(prompt, model, &model_name)
212            .await
213    }
214
215    // =========================================================================
216    // Shared Implementation
217    // =========================================================================
218
219    /// Extract and validate the model name from params.
220    ///
221    /// Returns an owned String to avoid holding an immutable borrow on `self`
222    /// across the `&mut self` call to `chat_continue_with_model`.
223    fn resolve_model_name(&self) -> Result<String, NikaError> {
224        let raw = self
225            .params
226            .model
227            .as_deref()
228            .ok_or_else(|| NikaError::ValidationError {
229                reason: "model field is required for LLM verbs (NIKA-034)".to_string(),
230            })?;
231        Ok(Self::strip_model_prefix(raw).to_string())
232    }
233
234    /// Generic chat continuation — all provider wrappers delegate here.
235    ///
236    /// Builds a rig Agent from the given CompletionModel, runs the chat with
237    /// accumulated history, updates history, emits telemetry events, estimates
238    /// token costs, and checks guardrails.
239    ///
240    /// **Note:** Token tracking uses char-based estimation (Chat trait returns
241    /// only String, no usage metadata). Use `run_claude()` / `run_openai()` etc.
242    /// for single-turn requests with full streaming token tracking.
243    async fn chat_continue_with_model<M: CompletionModel>(
244        &mut self,
245        prompt: &str,
246        model: M,
247        model_name: &str,
248    ) -> Result<RigAgentLoopResult, NikaError> {
249        let turn_index = self.turn_count + 1;
250
251        // Inject skills into system prompt if configured
252        let preamble = self.inject_skills_into_prompt().await?;
253
254        // Emit start event
255        self.event_log.emit(EventKind::AgentTurn {
256            task_id: Arc::from(self.task_id.as_str()),
257            turn_index,
258            kind: "started".to_string(),
259            metadata: None,
260        });
261
262        // Build agent with full config
263        let effective_max_tokens = self.params.effective_max_tokens().unwrap_or(8192) as u64;
264        let mut builder = AgentBuilder::new(model)
265            .preamble(&preamble)
266            .max_tokens(effective_max_tokens);
267
268        if let Some(temp) = self.params.effective_temperature() {
269            builder = builder.temperature(f64::from(temp));
270        }
271
272        if self.params.has_explicit_tool_choice() {
273            let tool_choice = self.params.effective_tool_choice();
274            builder = builder.tool_choice(tool_choice.into());
275        }
276
277        if let Some(stop_params) = Self::stop_sequences_params(
278            self.params.provider.as_deref().unwrap_or(""),
279            &self.params.stop_sequences,
280        ) {
281            builder = builder.additional_params(stop_params);
282        }
283
284        let tools = self.tools_as_boxed();
285        let agent = builder.tools(tools).build();
286
287        let response = agent
288            .chat(prompt, self.history.clone())
289            .await
290            .map_err(|e| NikaError::AgentExecutionError {
291                task_id: self.task_id.clone(),
292                reason: e.to_string(),
293            })?;
294
295        // Update history and increment turn count
296        self.history.push(Message::user(prompt));
297        self.history.push(Message::assistant(&response));
298        self.turn_count += 1;
299
300        // Determine status
301        let status = self.determine_status(&response);
302
303        // Emit completion
304        let stop_reason = status.as_canonical_str();
305        let metadata = AgentTurnMetadata::text_only(&response, stop_reason);
306
307        self.event_log.emit(EventKind::AgentTurn {
308            task_id: Arc::from(self.task_id.as_str()),
309            turn_index,
310            kind: stop_reason.to_string(),
311            metadata: Some(metadata),
312        });
313
314        // Check guardrails
315        let guardrail_result = self.check_guardrails(&response);
316        let guardrails_passed = guardrail_result.is_passed();
317
318        // Estimate tokens for cost tracking (Chat trait returns only String, no metadata)
319        let est_input = prompt.chars().count().div_ceil(4) as u64;
320        let est_output = response.chars().count().div_ceil(4) as u64;
321        let provider_kind = crate::provider::cost::ProviderKind::parse(
322            self.params.provider.as_deref().unwrap_or(""),
323        );
324        let cost = provider_kind
325            .map(|pk| crate::provider::cost::calculate_cost(pk, model_name, est_input, est_output))
326            .unwrap_or(0.0);
327
328        Ok(RigAgentLoopResult {
329            status: status.clone(),
330            turns: turn_index as usize,
331            final_output: serde_json::json!({ "response": response }),
332            total_tokens: est_input + est_output,
333            confidence: status.confidence(),
334            retry_count: 0,
335            guardrails_passed,
336            cost_usd: cost,
337            partial_result: None,
338        })
339    }
340}