herolib-ai 0.3.13

AI client with multi-provider support (Groq, OpenRouter, SambaNova) and automatic failover
Documentation
//! Prompt builder with verification support.
//!
//! This module provides a fluent API for building prompts and executing them
//! with optional verification and retry logic.

use crate::client::AiClient;
use crate::error::{AiError, AiResult};
use crate::model::Model;
use crate::types::{ChatCompletionResponse, Message};

/// Callback that accepts or rejects the text of an AI response.
///
/// Return `Ok(())` to accept the response. Return `Err(feedback)` to reject it;
/// the feedback string is meant to be fed back to the model as guidance for
/// a retry. The callback must be `Send + Sync` so a builder holding it can be
/// shared across threads.
pub type VerifyFn = Box<dyn Send + Sync + Fn(&str) -> Result<(), String>>;

/// Builder for constructing and executing AI prompts.
///
/// Supports:
/// - System and user messages
/// - Conversation history
/// - Model selection with automatic provider failover
/// - Optional verification with retry logic
/// - Temperature and max tokens configuration
///
/// A builder only describes a request. Nothing is sent until one of its
/// `execute*` methods is called, so an unused builder is almost certainly a bug.
#[must_use = "a PromptBuilder does nothing until one of its execute methods is called"]
pub struct PromptBuilder<'a> {
    /// The AI client that requests are sent through.
    client: &'a AiClient,
    /// The model to use.
    model: Model,
    /// Optional system message (context/instructions), sent first.
    system_message: Option<String>,
    /// User messages, sent in insertion order after the history.
    user_messages: Vec<String>,
    /// Conversation history, sent between the system and user messages.
    history: Vec<Message>,
    /// Sampling temperature forwarded to the client (`None` when unset).
    temperature: Option<f32>,
    /// Maximum tokens to generate, forwarded to the client (`None` when unset).
    max_tokens: Option<u32>,
    /// Optional response verification callback.
    verify_fn: Option<VerifyFn>,
    /// Maximum retry attempts after a failed verification.
    max_retries: usize,
}

impl<'a> PromptBuilder<'a> {
    /// Creates a new `PromptBuilder` that sends requests through `client`.
    ///
    /// Defaults: [`Model::default_general`], no system message, no user
    /// messages, no history, temperature and max tokens unset (`None`),
    /// no verification, and `max_retries` of 3.
    pub fn new(client: &'a AiClient) -> Self {
        Self {
            client,
            model: Model::default_general(),
            system_message: None,
            user_messages: Vec::new(),
            history: Vec::new(),
            temperature: None,
            max_tokens: None,
            verify_fn: None,
            max_retries: 3,
        }
    }

    /// Sets the model to use.
    pub fn model(mut self, model: Model) -> Self {
        self.model = model;
        self
    }

    /// Sets the system message (context/instructions), replacing any earlier one.
    pub fn system(mut self, message: impl Into<String>) -> Self {
        self.system_message = Some(message.into());
        self
    }

    /// Appends a user message. Repeated calls add further messages in order.
    pub fn user(mut self, message: impl Into<String>) -> Self {
        self.user_messages.push(message.into());
        self
    }

    /// Appends a user message; an alias for [`PromptBuilder::user`].
    pub fn prompt(self, message: impl Into<String>) -> Self {
        self.user(message)
    }

    /// Sets the conversation history, replacing any history from an earlier call.
    ///
    /// History is sent after the system message and before the user messages.
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.history = history;
        self
    }

    /// Sets the temperature for sampling.
    pub fn temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the maximum tokens to generate.
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets a verification function used by [`PromptBuilder::execute_verified`].
    ///
    /// The function receives the response content and should return:
    /// - `Ok(())` if the response is valid
    /// - `Err(feedback)` with feedback to add to the prompt for retry
    pub fn verify<F>(mut self, verify_fn: F) -> Self
    where
        F: Fn(&str) -> Result<(), String> + Send + Sync + 'static,
    {
        self.verify_fn = Some(Box::new(verify_fn));
        self
    }

    /// Sets the maximum number of retries after a failed verification.
    ///
    /// `execute_verified` makes at most `max_retries + 1` requests in total.
    pub fn max_retries(mut self, max_retries: usize) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Assembles the request messages in order: the system message (if any),
    /// then the conversation history, then each user message.
    fn build_messages(&self) -> Vec<Message> {
        let capacity = usize::from(self.system_message.is_some())
            + self.history.len()
            + self.user_messages.len();
        let mut messages = Vec::with_capacity(capacity);

        if let Some(system) = &self.system_message {
            messages.push(Message::system(system));
        }

        // Clone element-wise rather than cloning the whole Vec first.
        messages.extend(self.history.iter().cloned());

        for user_msg in &self.user_messages {
            messages.push(Message::user(user_msg));
        }

        messages
    }

    /// Builds the request messages, rejecting an empty conversation.
    fn non_empty_messages(&self) -> AiResult<Vec<Message>> {
        let messages = self.build_messages();
        if messages.is_empty() {
            return Err(AiError::InvalidRequest("No messages provided".to_string()));
        }
        Ok(messages)
    }

    /// Executes the prompt and returns the full response.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::InvalidRequest`] if there is no system, history, or
    /// user message. Any error from `AiClient::chat_with_options` is propagated.
    pub fn execute(&self) -> AiResult<ChatCompletionResponse> {
        let messages = self.non_empty_messages()?;
        self.client
            .chat_with_options(self.model, messages, self.temperature, self.max_tokens)
    }

    /// Executes the prompt and returns just the content string.
    ///
    /// # Errors
    ///
    /// Returns every error [`PromptBuilder::execute`] can return. It also
    /// returns [`AiError::ParseError`] if the response carries no content.
    pub fn execute_content(&self) -> AiResult<String> {
        let response = self.execute()?;
        response
            .content()
            .map(|s| s.to_string())
            .ok_or_else(|| AiError::ParseError("No content in response".to_string()))
    }

    /// Executes the prompt with verification and retry logic.
    ///
    /// If no verification function is set, this behaves like
    /// [`PromptBuilder::execute_content`]. Otherwise it:
    /// 1. Executes the prompt
    /// 2. Verifies the response
    /// 3. On failure, appends the rejected answer and the feedback to the
    ///    conversation and retries
    /// 4. Repeats until verification passes or `max_retries` retries are used
    ///
    /// # Errors
    ///
    /// - Returns [`AiError::InvalidRequest`] if there are no messages.
    /// - Returns [`AiError::ParseError`] if a response has no content.
    /// - Propagates any client error.
    /// - Returns [`AiError::VerificationFailed`] carrying the last feedback
    ///   once all attempts are rejected.
    pub fn execute_verified(&self) -> AiResult<String> {
        let Some(verify_fn) = &self.verify_fn else {
            // No verification, just execute normally
            return self.execute_content();
        };

        // Validate up front, like `execute`, instead of sending an empty
        // conversation to the provider.
        let mut messages = self.non_empty_messages()?;
        let mut last_error = String::new();

        for attempt in 0..=self.max_retries {
            let response = self.client.chat_with_options(
                self.model,
                messages.clone(),
                self.temperature,
                self.max_tokens,
            )?;

            let content = response
                .content()
                .ok_or_else(|| AiError::ParseError("No content in response".to_string()))?;

            match verify_fn(content) {
                Ok(()) => return Ok(content.to_string()),
                Err(feedback) => {
                    if attempt < self.max_retries {
                        // Show the model its rejected answer plus the reason,
                        // so the next attempt can correct it.
                        messages.push(Message::assistant(content));
                        messages.push(Message::user(format!(
                            "The previous response was not acceptable. Please fix the following issue and try again:\n\n{}",
                            feedback
                        )));
                    }
                    // Only the most recent feedback is reported; move it instead of cloning.
                    last_error = feedback;
                }
            }
        }

        Err(AiError::VerificationFailed {
            retries: self.max_retries,
            message: last_error,
        })
    }
}

/// Extension trait that lets an [`AiClient`] start a [`PromptBuilder`] directly,
/// e.g. `client.prompt().user("Hi").execute()`.
pub trait PromptBuilderExt {
    /// Starts a new prompt builder that borrows `self` as its client.
    fn prompt(&self) -> PromptBuilder<'_>;
}

impl PromptBuilderExt for AiClient {
    /// Returns a builder with default settings that sends through this client.
    fn prompt(&self) -> PromptBuilder<'_> {
        <PromptBuilder<'_>>::new(self)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `messages` holds exactly the `expected` contents, in order.
    fn assert_contents(messages: &[Message], expected: &[&str]) {
        assert_eq!(messages.len(), expected.len());
        for (msg, want) in messages.iter().zip(expected) {
            assert_eq!(msg.content, *want);
        }
    }

    #[test]
    fn test_prompt_builder_messages() {
        let client = AiClient::new();
        let messages = client
            .prompt()
            .system("You are a helpful assistant")
            .user("Hello")
            .user("How are you?")
            .build_messages();

        assert_contents(
            &messages,
            &["You are a helpful assistant", "Hello", "How are you?"],
        );
    }

    #[test]
    fn test_prompt_builder_with_history() {
        let client = AiClient::new();
        let past = vec![
            Message::user("Previous question"),
            Message::assistant("Previous answer"),
        ];

        let messages = client
            .prompt()
            .system("System")
            .with_history(past)
            .user("New question")
            .build_messages();

        // Index 0 is the system message; history and the new question follow it.
        assert_eq!(messages.len(), 4);
        assert_contents(
            &messages[1..],
            &["Previous question", "Previous answer", "New question"],
        );
    }

    #[test]
    fn test_empty_messages_error() {
        let client = AiClient::new();
        let outcome = client.prompt().execute();
        assert!(matches!(outcome, Err(AiError::InvalidRequest(_))));
    }
}