grok_api 0.1.71

Rust client library for the Grok AI API (xAI)
Documentation
//! Common types used across the API

use serde::{Deserialize, Serialize};

/// The author of a single message in a chat conversation.
///
/// On the wire every variant is represented by its lowercase name,
/// e.g. `Role::Assistant` (de)serializes as `"assistant"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Role {
    /// Instructions for the model (a "system" message).
    #[serde(rename = "system")]
    System,
    /// A message written by the user.
    #[serde(rename = "user")]
    User,
    /// A response produced by the assistant/model.
    #[serde(rename = "assistant")]
    Assistant,
    /// The response returned from a tool/function call.
    #[serde(rename = "tool")]
    Tool,
}

impl Role {
    /// Convert role to string
    pub fn as_str(&self) -> &'static str {
        match self {
            Role::System => "system",
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::Tool => "tool",
        }
    }
}

impl std::fmt::Display for Role {
    /// Writes the role's lowercase name (the same text as [`Role::as_str`]).
    ///
    /// Delegates to `Formatter::pad` so that width, fill, alignment and
    /// precision flags supplied by the caller are honoured. For example,
    /// `format!("{:>6}", Role::User)` yields `"  user"`. Without flags the
    /// output is exactly the role name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}

impl From<Role> for String {
    /// Produces an owned `String` containing the role's lowercase name,
    /// identical to the text returned by [`Role::as_str`].
    fn from(role: Role) -> Self {
        let name: &'static str = role.as_str();
        name.to_owned()
    }
}

/// Reason why a completion finished.
///
/// Variants (de)serialize as their snake_case names (`"stop"`, `"length"`,
/// `"tool_calls"`, `"content_filter"`, `"other"`). Any value that does not
/// match a known variant is deserialized as [`StopReason::Other`] instead of
/// producing an error.
///
/// Like [`Role`], this is a fieldless enum and is therefore `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// Natural completion
    Stop,

    /// Maximum tokens reached
    Length,

    /// Tool/function call
    ToolCalls,

    /// Content filtered
    ContentFilter,

    /// Other/unknown reason; also the fallback for unrecognized values
    /// during deserialization.
    #[serde(other)]
    Other,
}

impl StopReason {
    /// Check if completion finished naturally
    pub fn is_complete(&self) -> bool {
        matches!(self, StopReason::Stop)
    }

    /// Check if more tokens were needed
    pub fn needs_more_tokens(&self) -> bool {
        matches!(self, StopReason::Length)
    }

    /// Check if tool calls were made
    pub fn has_tool_calls(&self) -> bool {
        matches!(self, StopReason::ToolCalls)
    }
}

impl std::fmt::Display for StopReason {
    /// Writes the snake_case name of the reason (e.g. `tool_calls`), which is
    /// the same text used for the variant's serde representation.
    ///
    /// Delegates to `Formatter::pad` so width, fill, alignment and precision
    /// flags are honoured. Without flags the output is exactly the name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            StopReason::Stop => "stop",
            StopReason::Length => "length",
            StopReason::ToolCalls => "tool_calls",
            StopReason::ContentFilter => "content_filter",
            StopReason::Other => "other",
        };
        f.pad(name)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_role_conversion() {
        assert_eq!(Role::System.as_str(), "system");
        assert_eq!(Role::User.as_str(), "user");
        assert_eq!(Role::Assistant.as_str(), "assistant");
        assert_eq!(Role::Tool.as_str(), "tool");

        let role_string: String = Role::User.into();
        assert_eq!(role_string, "user");
    }

    #[test]
    fn test_role_display_matches_serde() {
        // Display, as_str and the serde wire name must all agree.
        for role in &[Role::System, Role::User, Role::Assistant, Role::Tool] {
            let json = serde_json::to_string(role).unwrap();
            assert_eq!(json, format!("\"{}\"", role));
            assert_eq!(role.to_string(), role.as_str());
        }
    }

    #[test]
    fn test_stop_reason_checks() {
        assert!(StopReason::Stop.is_complete());
        assert!(!StopReason::Length.is_complete());

        assert!(StopReason::Length.needs_more_tokens());
        assert!(!StopReason::Stop.needs_more_tokens());

        assert!(StopReason::ToolCalls.has_tool_calls());
        assert!(!StopReason::Stop.has_tool_calls());

        // Reasons that are none of the above answer `false` to every check.
        for reason in &[StopReason::ContentFilter, StopReason::Other] {
            assert!(!reason.is_complete());
            assert!(!reason.needs_more_tokens());
            assert!(!reason.has_tool_calls());
        }
    }

    #[test]
    fn test_stop_reason_serde_round_trip() {
        let reasons = [
            StopReason::Stop,
            StopReason::Length,
            StopReason::ToolCalls,
            StopReason::ContentFilter,
        ];
        for reason in &reasons {
            let json = serde_json::to_string(reason).unwrap();
            // Display output must match the serialized wire name.
            assert_eq!(json, format!("\"{}\"", reason));

            let back: StopReason = serde_json::from_str(&json).unwrap();
            assert_eq!(&back, reason);
        }
    }

    #[test]
    fn test_stop_reason_unknown_maps_to_other() {
        // `#[serde(other)]` must absorb values this crate does not know about.
        let reason: StopReason = serde_json::from_str(r#""end_turn""#).unwrap();
        assert_eq!(reason, StopReason::Other);
    }

    #[test]
    fn test_role_serialization() {
        let json = serde_json::to_string(&Role::System).unwrap();
        assert_eq!(json, r#""system""#);

        let role: Role = serde_json::from_str(r#""user""#).unwrap();
        assert_eq!(role, Role::User);

        let role: Role = serde_json::from_str(r#""tool""#).unwrap();
        assert_eq!(role, Role::Tool);
    }
}