//! herolib-ai 0.3.13
//!
//! AI client with multi-provider support (Groq, OpenRouter, SambaNova) and
//! automatic failover.
//!
//! OpenAI-compatible types for chat completions.
//!
//! This module defines the request and response types compatible with the OpenAI API.

use serde::{Deserialize, Serialize};

/// Role of a message in the conversation.
/// Role of a message in the conversation.
///
/// Serialized in lowercase (`"system"`, `"user"`, `"assistant"`, `"tool"`)
/// per the OpenAI chat-completions wire format, via `rename_all`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// System message setting context/behavior; serializes as `"system"`.
    System,
    /// User message (input); serializes as `"user"`.
    User,
    /// Assistant message (output from model); serializes as `"assistant"`.
    Assistant,
    /// Tool/function result; serializes as `"tool"`.
    Tool,
}

/// A message in the conversation.
///
/// Round-trips through the OpenAI `messages` array format; the optional
/// `name` field is omitted from the JSON object entirely when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// The role of the message author.
    pub role: Role,
    /// The content of the message.
    pub content: String,
    /// Optional name for the message author.
    /// Skipped during serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

impl Message {
    /// Creates a new system message.
    pub fn system(content: impl Into<String>) -> Self {
        Self {
            role: Role::System,
            content: content.into(),
            name: None,
        }
    }

    /// Creates a new user message.
    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: Role::User,
            content: content.into(),
            name: None,
        }
    }

    /// Creates a new assistant message.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self {
            role: Role::Assistant,
            content: content.into(),
            name: None,
        }
    }

    /// Sets the name of the message author.
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }
}

/// Request body for chat completions.
///
/// Every optional field is omitted from the serialized JSON when `None`,
/// letting the provider apply its own defaults.
#[derive(Debug, Clone, Serialize)]
pub struct ChatCompletionRequest {
    /// The model to use.
    pub model: String,
    /// The messages in the conversation.
    pub messages: Vec<Message>,
    /// Sampling temperature (0.0 to 2.0).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Maximum tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Top-p sampling parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Frequency penalty (-2.0 to 2.0).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Presence penalty (-2.0 to 2.0).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Stop sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Whether to stream the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}

impl ChatCompletionRequest {
    /// Creates a new chat completion request.
    ///
    /// All sampling options start unset (and are omitted from the JSON);
    /// `stream` is explicitly `Some(false)` so the request is serialized
    /// as non-streaming by default.
    pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
        Self {
            model: model.into(),
            messages,
            temperature: None,
            max_tokens: None,
            top_p: None,
            frequency_penalty: None,
            presence_penalty: None,
            stop: None,
            stream: Some(false),
        }
    }

    /// Sets the sampling temperature (0.0 to 2.0).
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the maximum number of tokens to generate.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the top-p (nucleus sampling) value.
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets the frequency penalty (-2.0 to 2.0).
    ///
    /// Added for builder parity: the field existed but had no setter.
    pub fn with_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
        self.frequency_penalty = Some(frequency_penalty);
        self
    }

    /// Sets the presence penalty (-2.0 to 2.0).
    ///
    /// Added for builder parity: the field existed but had no setter.
    pub fn with_presence_penalty(mut self, presence_penalty: f32) -> Self {
        self.presence_penalty = Some(presence_penalty);
        self
    }

    /// Sets stop sequences.
    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
        self.stop = Some(stop);
        self
    }

    /// Enables or disables response streaming.
    ///
    /// Added for builder parity: the field existed but had no setter.
    pub fn with_stream(mut self, stream: bool) -> Self {
        self.stream = Some(stream);
        self
    }
}

/// Response from chat completions API.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionResponse {
    /// Unique identifier for the completion.
    pub id: String,
    /// Object type (always "chat.completion").
    pub object: String,
    /// Unix timestamp of creation.
    pub created: u64,
    /// Model used for completion.
    pub model: String,
    /// List of completion choices.
    pub choices: Vec<Choice>,
    /// Token usage statistics.
    /// `None` when the provider omits the `usage` object entirely.
    #[serde(default)]
    pub usage: Option<Usage>,
}

impl ChatCompletionResponse {
    /// Returns the content of the first choice.
    ///
    /// Prefers `message` (non-streaming responses) and falls back to
    /// `delta` (streaming chunks), so the accessor works for both shapes;
    /// previously streaming chunks always yielded `None` because only
    /// `message` was consulted. Returns `None` when there are no choices
    /// or neither field is set.
    pub fn content(&self) -> Option<&str> {
        let choice = self.choices.first()?;
        choice
            .message
            .as_ref()
            .or(choice.delta.as_ref())
            .map(|m| m.effective_content())
    }

    /// Returns the finish reason of the first choice, if any.
    pub fn finish_reason(&self) -> Option<&str> {
        self.choices
            .first()
            .and_then(|c| c.finish_reason.as_deref())
    }
}

/// A completion choice.
///
/// Non-streaming responses populate `message`; streaming chunks carry
/// `delta` instead. Both are `Option` so either shape deserializes.
#[derive(Debug, Clone, Deserialize)]
pub struct Choice {
    /// Index of this choice.
    pub index: u32,
    /// The generated message (non-streaming responses).
    pub message: Option<ResponseMessage>,
    /// Delta for streaming responses.
    pub delta: Option<ResponseMessage>,
    /// Reason the model stopped generating (e.g. `"stop"`), when reported.
    pub finish_reason: Option<String>,
}

/// Message in a response.
#[derive(Debug, Clone, Deserialize)]
pub struct ResponseMessage {
    /// Role of the message, when the provider includes one.
    pub role: Option<Role>,
    /// Content of the message.
    /// Defaults to an empty string when the field is absent.
    /// NOTE(review): `#[serde(default)]` covers a *missing* field only; an
    /// explicit JSON `null` here would fail to deserialize into `String` —
    /// confirm the supported providers never send `content: null`.
    #[serde(default)]
    pub content: String,
    /// Reasoning content (used by some models like GPT-OSS in reasoning mode).
    /// If content is empty, this field may contain the actual response.
    #[serde(default)]
    pub reasoning: Option<String>,
}

impl ResponseMessage {
    /// Effective text of this message.
    ///
    /// Yields `content` when it is non-empty; otherwise falls back to the
    /// `reasoning` field, and finally to `""` when that is absent too.
    pub fn effective_content(&self) -> &str {
        match self.content.as_str() {
            "" => self.reasoning.as_deref().unwrap_or(""),
            text => text,
        }
    }
}

/// Token usage statistics.
///
/// Every counter falls back to `0` when its field is absent, so partial
/// `usage` objects from a provider still deserialize.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct Usage {
    /// Tokens in the prompt.
    #[serde(default)]
    pub prompt_tokens: u32,
    /// Tokens in the completion.
    #[serde(default)]
    pub completion_tokens: u32,
    /// Total tokens used, as reported by the provider.
    #[serde(default)]
    pub total_tokens: u32,
}

/// Error response from the API.
///
/// Mirrors the OpenAI-style error envelope: `{"error": {...}}`.
#[derive(Debug, Clone, Deserialize)]
pub struct ApiErrorResponse {
    /// Error details.
    pub error: ApiErrorDetail,
}

/// Error detail from the API.
#[derive(Debug, Clone, Deserialize)]
pub struct ApiErrorDetail {
    /// Human-readable error message.
    pub message: String,
    /// Error type.
    /// Renamed because `type` is a reserved word in Rust.
    #[serde(rename = "type")]
    pub error_type: Option<String>,
    /// Machine-readable error code, when provided.
    pub code: Option<String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_creation() {
        let system = Message::system("You are helpful");
        assert_eq!(system.role, Role::System);
        assert_eq!(system.content, "You are helpful");

        let user = Message::user("Hello").with_name("John");
        assert_eq!(user.role, Role::User);
        assert_eq!(user.name, Some("John".to_string()));
    }

    #[test]
    fn test_request_builder() {
        let request = ChatCompletionRequest::new("gpt-4", vec![Message::user("Hello")])
            .with_temperature(0.7)
            .with_max_tokens(1000);

        assert_eq!(request.model, "gpt-4");
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.max_tokens, Some(1000));
    }

    #[test]
    fn test_response_parsing() {
        let json = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello! How can I help you?"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 20,
                "total_tokens": 30
            }
        }"#;

        let response: ChatCompletionResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.content(), Some("Hello! How can I help you?"));
        assert_eq!(response.finish_reason(), Some("stop"));
        assert_eq!(response.usage.as_ref().unwrap().total_tokens, 30);
    }

    // New: covers the reasoning fallback path, which had no test before.
    #[test]
    fn test_effective_content_fallback() {
        let msg = ResponseMessage {
            role: Some(Role::Assistant),
            content: String::new(),
            reasoning: Some("from reasoning".to_string()),
        };
        assert_eq!(msg.effective_content(), "from reasoning");

        // Non-empty content wins over reasoning.
        let msg = ResponseMessage {
            role: None,
            content: "direct".to_string(),
            reasoning: Some("ignored".to_string()),
        };
        assert_eq!(msg.effective_content(), "direct");
    }

    // New: responses without a `usage` object (or choices) must still parse.
    #[test]
    fn test_response_without_usage() {
        let json = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 0,
            "model": "gpt-4",
            "choices": []
        }"#;

        let response: ChatCompletionResponse = serde_json::from_str(json).unwrap();
        assert!(response.usage.is_none());
        assert_eq!(response.content(), None);
        assert_eq!(response.finish_reason(), None);
    }
}