// nika-engine 0.47.0
//
// Nika workflow engine — embeddable runtime, provider, DAG, and binding logic
// Documentation
//! Chat history management and multi-turn conversation support
//!
//! Provides methods for managing conversation history and continuing
//! multi-turn conversations with different LLM providers.

use std::sync::Arc;

use rig::agent::AgentBuilder;
use rig::client::{CompletionClient, ProviderClient};
use rig::completion::{Chat, CompletionModel};
use rig::message::Message;
use rig::providers::{anthropic, openai};
use serde_json;

use crate::error::NikaError;
use crate::event::{AgentTurnMetadata, EventKind};

use super::types::RigAgentLoopResult;
use super::RigAgentLoop;

impl RigAgentLoop {
    // =========================================================================
    // Chat History Management
    // =========================================================================

    /// Record one completed exchange in the conversation history.
    ///
    /// Appends the user prompt followed by the assistant response and bumps the
    /// turn counter by one. Call this after each finished turn so that
    /// `chat_continue()` sees the full context.
    pub fn add_to_history(&mut self, user_prompt: &str, assistant_response: &str) {
        let exchange = [
            Message::user(user_prompt),
            Message::assistant(assistant_response),
        ];
        self.history.extend(exchange);
        self.turn_count += 1;
    }

    /// Append a single, already-constructed message to the end of the history.
    ///
    /// Unlike `add_to_history()`, this does not touch the turn counter.
    pub fn push_message(&mut self, message: Message) {
        self.history.push(message);
    }

    /// Reset the conversation: zero the turn counter and drop every stored message.
    pub fn clear_history(&mut self) {
        self.turn_count = 0;
        self.history.clear();
    }

    /// Number of messages currently stored in the history.
    ///
    /// This counts individual messages, not turns: one exchange recorded via
    /// `add_to_history()` contributes two messages.
    pub fn history_len(&self) -> usize {
        self.history.len()
    }

    /// Number of completed turns (user + assistant exchanges).
    ///
    /// Incremented by `add_to_history()` and by each successful chat
    /// continuation; reset to zero by `clear_history()`.
    pub fn turn_count(&self) -> u32 {
        self.turn_count
    }

    /// Borrow the conversation history as a read-only slice, oldest message first
    /// (messages are only ever appended).
    pub fn history(&self) -> &[Message] {
        &self.history
    }

    /// Builder-style setter that installs a pre-existing history.
    ///
    /// Useful for resuming conversations or injecting context. Any history
    /// already held is replaced. The turn counter is left unchanged, so it does
    /// not reflect exchanges contained in the supplied `history`.
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.history = history;
        self
    }

    /// Run one more conversational turn on top of the accumulated history.
    ///
    /// Picks a provider and hands off to the matching provider wrapper, which
    /// uses rig-core's `Chat` trait. On success the prompt and the response are
    /// appended to the history automatically.
    ///
    /// Provider selection:
    /// - If `params.provider` is set, the name is resolved to its canonical ID
    ///   through the core provider catalog and that ID selects the wrapper.
    /// - Otherwise the first non-blank API-key environment variable wins, checked
    ///   in this order: Anthropic, OpenAI, Mistral, Groq, DeepSeek, Gemini, xAI.
    ///
    /// # Errors
    /// Returns `NikaError::AgentValidationError` when the configured provider is
    /// unknown or unsupported, or when no provider is configured and none of the
    /// API-key variables is set. Errors from the selected wrapper are propagated.
    ///
    /// # Example
    /// ```rust,ignore
    /// // First turn
    /// let result1 = agent.run_claude().await?;
    /// agent.add_to_history("Initial prompt", &extract_text(&result1));
    ///
    /// // Continue conversation
    /// let result2 = agent.chat_continue("Follow-up question").await?;
    /// // History now contains both turns
    /// ```
    pub async fn chat_continue(&mut self, prompt: &str) -> Result<RigAgentLoopResult, NikaError> {
        let Some(name) = self.params.provider.as_deref() else {
            // No explicit provider: fall back to env-var auto-detection.
            // A variable counts only if it is set and not blank.
            let env_key_present =
                |key: &str| std::env::var(key).is_ok_and(|value| !value.trim().is_empty());

            return if env_key_present("ANTHROPIC_API_KEY") {
                self.chat_continue_claude(prompt).await
            } else if env_key_present("OPENAI_API_KEY") {
                self.chat_continue_openai(prompt).await
            } else if env_key_present("MISTRAL_API_KEY") {
                self.chat_continue_mistral(prompt).await
            } else if env_key_present("GROQ_API_KEY") {
                self.chat_continue_groq(prompt).await
            } else if env_key_present("DEEPSEEK_API_KEY") {
                self.chat_continue_deepseek(prompt).await
            } else if env_key_present("GEMINI_API_KEY") {
                self.chat_continue_gemini(prompt).await
            } else if env_key_present("XAI_API_KEY") {
                self.chat_continue_xai(prompt).await
            } else {
                Err(NikaError::AgentValidationError {
                    reason: "chat_continue requires a configured provider or one of: ANTHROPIC_API_KEY, OPENAI_API_KEY, MISTRAL_API_KEY, GROQ_API_KEY, DEEPSEEK_API_KEY, GEMINI_API_KEY, or XAI_API_KEY".to_string(),
                })
            };
        };

        // Explicit provider: map the (possibly aliased) name to its canonical ID.
        let Some(resolved) = crate::core::find_provider(name) else {
            return Err(NikaError::AgentValidationError {
                reason: format!(
                    "Unknown provider: '{}'. Use 'claude', 'openai', 'mistral', 'groq', 'deepseek', 'gemini', or 'xai'.",
                    name
                ),
            });
        };

        match resolved.id {
            "anthropic" => self.chat_continue_claude(prompt).await,
            "openai" => self.chat_continue_openai(prompt).await,
            "mistral" => self.chat_continue_mistral(prompt).await,
            "groq" => self.chat_continue_groq(prompt).await,
            "deepseek" => self.chat_continue_deepseek(prompt).await,
            "gemini" => self.chat_continue_gemini(prompt).await,
            "xai" => self.chat_continue_xai(prompt).await,
            other => Err(NikaError::AgentValidationError {
                reason: format!("Provider '{}' is not supported for chat_continue.", other),
            }),
        }
    }

    // =========================================================================
    // Provider-specific chat_continue wrappers
    // =========================================================================
    //
    // Each wrapper creates a provider client and delegates to the generic
    // `chat_continue_with_model()`. Adding a new provider requires only
    // a 4-line wrapper + a match arm in `chat_continue()`.

    async fn chat_continue_claude(
        &mut self,
        prompt: &str,
    ) -> Result<RigAgentLoopResult, NikaError> {
        let model_name = self.resolve_model_name()?;
        let model = anthropic::Client::from_env().completion_model(&model_name);
        self.chat_continue_with_model(prompt, model, &model_name)
            .await
    }

    async fn chat_continue_openai(
        &mut self,
        prompt: &str,
    ) -> Result<RigAgentLoopResult, NikaError> {
        let model_name = self.resolve_model_name()?;
        let model = openai::Client::from_env().completion_model(&model_name);
        self.chat_continue_with_model(prompt, model, &model_name)
            .await
    }

    async fn chat_continue_mistral(
        &mut self,
        prompt: &str,
    ) -> Result<RigAgentLoopResult, NikaError> {
        let model_name = self.resolve_model_name()?;
        let model = rig::providers::mistral::Client::from_env().completion_model(&model_name);
        self.chat_continue_with_model(prompt, model, &model_name)
            .await
    }

    async fn chat_continue_groq(&mut self, prompt: &str) -> Result<RigAgentLoopResult, NikaError> {
        let model_name = self.resolve_model_name()?;
        let model = rig::providers::groq::Client::from_env().completion_model(&model_name);
        self.chat_continue_with_model(prompt, model, &model_name)
            .await
    }

    async fn chat_continue_deepseek(
        &mut self,
        prompt: &str,
    ) -> Result<RigAgentLoopResult, NikaError> {
        let model_name = self.resolve_model_name()?;
        let model = rig::providers::deepseek::Client::from_env().completion_model(&model_name);
        self.chat_continue_with_model(prompt, model, &model_name)
            .await
    }

    async fn chat_continue_gemini(
        &mut self,
        prompt: &str,
    ) -> Result<RigAgentLoopResult, NikaError> {
        let model_name = self.resolve_model_name()?;
        let model = rig::providers::gemini::Client::from_env().completion_model(&model_name);
        self.chat_continue_with_model(prompt, model, &model_name)
            .await
    }

    async fn chat_continue_xai(&mut self, prompt: &str) -> Result<RigAgentLoopResult, NikaError> {
        let model_name = self.resolve_model_name()?;
        let model = rig::providers::xai::Client::from_env().completion_model(&model_name);
        self.chat_continue_with_model(prompt, model, &model_name)
            .await
    }

    // =========================================================================
    // Shared Implementation
    // =========================================================================

    /// Read the model name from `params.model` and strip its prefix.
    ///
    /// The result is an owned `String` so that no immutable borrow of `self`
    /// is held across the subsequent `&mut self` call to
    /// `chat_continue_with_model`.
    ///
    /// # Errors
    /// Returns `NikaError::ValidationError` when no model is configured.
    fn resolve_model_name(&self) -> Result<String, NikaError> {
        match self.params.model.as_deref() {
            Some(configured) => Ok(Self::strip_model_prefix(configured).to_string()),
            None => Err(NikaError::ValidationError {
                reason: "model field is required for LLM verbs (NIKA-034)".to_string(),
            }),
        }
    }

    /// Generic chat continuation — all provider wrappers delegate here.
    ///
    /// Builds a rig Agent from the given CompletionModel, runs the chat with
    /// accumulated history, updates history, emits telemetry events, estimates
    /// token costs, and checks guardrails.
    ///
    /// Sequence of side effects:
    /// 1. Skills are injected into the system prompt (may fail before any event).
    /// 2. An `AgentTurn` event with kind `"started"` is emitted.
    /// 3. The agent is built and `chat()` is called with a clone of the history.
    /// 4. Only on success: the prompt and response are appended to the history
    ///    and `turn_count` is incremented.
    /// 5. A second `AgentTurn` event is emitted whose kind is the canonical
    ///    status string.
    ///
    /// If the chat call fails, the function returns early: the history and turn
    /// count are untouched and no completion event follows the `"started"` one.
    ///
    /// **Note:** Token tracking uses char-based estimation (Chat trait returns
    /// only String, no usage metadata). Use `run_claude()` / `run_openai()` etc.
    /// for single-turn requests with full streaming token tracking.
    ///
    /// # Parameters
    /// - `prompt`: the new user message for this turn.
    /// - `model`: provider completion model the agent is built around.
    /// - `model_name`: model identifier, used here only for cost calculation.
    ///
    /// # Errors
    /// Propagates errors from skill injection and maps any chat failure to
    /// `NikaError::AgentExecutionError` carrying this task's ID.
    async fn chat_continue_with_model<M: CompletionModel>(
        &mut self,
        prompt: &str,
        model: M,
        model_name: &str,
    ) -> Result<RigAgentLoopResult, NikaError> {
        // 1-based index of the turn being executed (turn_count is bumped only on success).
        let turn_index = self.turn_count + 1;

        // Inject skills into system prompt if configured
        let preamble = self.inject_skills_into_prompt().await?;

        // Emit start event
        self.event_log.emit(EventKind::AgentTurn {
            task_id: Arc::from(self.task_id.as_str()),
            turn_index,
            kind: "started".to_string(),
            metadata: None,
        });

        // Build agent with full config; max_tokens defaults to 8192 when not configured.
        let effective_max_tokens = self.params.effective_max_tokens().unwrap_or(8192) as u64;
        let mut builder = AgentBuilder::new(model)
            .preamble(&preamble)
            .max_tokens(effective_max_tokens);

        // Temperature is only set when the params provide one.
        if let Some(temp) = self.params.effective_temperature() {
            builder = builder.temperature(f64::from(temp));
        }

        // Tool choice is forwarded only when it was set explicitly.
        if self.params.has_explicit_tool_choice() {
            let tool_choice = self.params.effective_tool_choice();
            builder = builder.tool_choice(tool_choice.into());
        }

        // NOTE(review): on the env auto-detect path `params.provider` is `None`,
        // so the provider string passed here is "" — presumably that yields no
        // stop-sequence params; confirm against `stop_sequences_params`.
        if let Some(stop_params) = Self::stop_sequences_params(
            self.params.provider.as_deref().unwrap_or(""),
            &self.params.stop_sequences,
        ) {
            builder = builder.additional_params(stop_params);
        }

        let tools = self.tools_as_boxed();
        let agent = builder.tools(tools).build();

        // The history is cloned for the call; the stored copy is mutated only
        // after the call succeeds.
        let response = agent
            .chat(prompt, self.history.clone())
            .await
            .map_err(|e| NikaError::AgentExecutionError {
                task_id: self.task_id.clone(),
                reason: e.to_string(),
            })?;

        // Update history and increment turn count
        self.history.push(Message::user(prompt));
        self.history.push(Message::assistant(&response));
        self.turn_count += 1;

        // Determine status
        let status = self.determine_status(&response);

        // Emit completion; the event kind is the status's canonical string.
        let stop_reason = status.as_canonical_str();
        let metadata = AgentTurnMetadata::text_only(&response, stop_reason);

        self.event_log.emit(EventKind::AgentTurn {
            task_id: Arc::from(self.task_id.as_str()),
            turn_index,
            kind: stop_reason.to_string(),
            metadata: Some(metadata),
        });

        // Check guardrails; the outcome is reported in the result, it does not
        // turn the call into an error.
        let guardrail_result = self.check_guardrails(&response);
        let guardrails_passed = guardrail_result.is_passed();

        // Estimate tokens for cost tracking (Chat trait returns only String, no metadata).
        // Heuristic: one token per 4 characters, rounded up.
        // NOTE(review): the input estimate covers only `prompt`; the preamble and
        // the history sent with the request are not counted — confirm this is intended.
        let est_input = prompt.chars().count().div_ceil(4) as u64;
        let est_output = response.chars().count().div_ceil(4) as u64;
        // Uses the configured provider string as-is ("" when none is configured);
        // the cost falls back to 0.0 when that string does not parse to a provider kind.
        let provider_kind = crate::provider::cost::ProviderKind::parse(
            self.params.provider.as_deref().unwrap_or(""),
        );
        let cost = provider_kind
            .map(|pk| crate::provider::cost::calculate_cost(pk, model_name, est_input, est_output))
            .unwrap_or(0.0);

        Ok(RigAgentLoopResult {
            status: status.clone(),
            turns: turn_index as usize,
            final_output: serde_json::json!({ "response": response }),
            total_tokens: est_input + est_output,
            confidence: status.confidence(),
            retry_count: 0,
            guardrails_passed,
            cost_usd: cost,
            partial_result: None,
        })
    }
}