//! abu-base 0.2.0
//!
//! Core data structures for chat and embeddings.
//!
//! Documentation
use std::fmt::Display;

use crate::common::Role;

use super::tool::*;
use serde::{Deserialize, Serialize};
use strum::{EnumMessage, EnumVariantNames};

/// A single message in a chat conversation, discriminated by who authored it.
///
/// Serialized as an internally tagged object: a `role` field holds the variant
/// name in snake_case (`"system"`, `"user"`, `"assistant"`, `"tool"`), and the
/// remaining fields come from the wrapped message struct. Only `Serialize` is
/// derived here, so this enum cannot be deserialized as-is.
#[derive(Debug, Clone, Serialize, EnumVariantNames, EnumMessage)]
#[serde(rename_all = "snake_case", tag = "role")]
pub enum ChatMessage {
    // NOTE(review): depending on the strum version, `EnumMessage` may expose the
    // variant doc comments below at runtime (`get_documentation`) — confirm before
    // editing them.
    /// A message from a system.
    System(SystemMessage),
    /// A message from a human.
    User(UserMessage),
    /// A message from the assistant.
    Assistant(AssistantMessage),
    /// A message from a tool.
    Tool(ToolMessage),
}

impl Into<ChatMessage> for SystemMessage {
    fn into(self) -> ChatMessage {
        ChatMessage::System(self)
    }
}

impl Into<ChatMessage> for UserMessage {
    fn into(self) -> ChatMessage {
        ChatMessage::User(self)
    }
}

impl Into<ChatMessage> for AssistantMessage {
    fn into(self) -> ChatMessage {
        ChatMessage::Assistant(self)
    }
}

impl Into<ChatMessage> for ToolMessage {
    fn into(self) -> ChatMessage {
        ChatMessage::Tool(self)
    }
}

/// A message authored by the system, carried by `ChatMessage::System`.
///
/// Only `Serialize` is derived; `name` is left out of the serialized output
/// when it is `None`.
#[derive(Debug, Clone, Serialize)]
pub struct SystemMessage {
    /// The contents of the system message.
    pub content: String,
    /// An optional name for the participant. Provides the model information to differentiate between participants of the same role.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

/// A message authored by a human user, carried by `ChatMessage::User`.
///
/// Only `Serialize` is derived; `name` is left out of the serialized output
/// when it is `None`.
#[derive(Debug, Clone, Serialize)]
pub struct UserMessage {
    /// The contents of the user message.
    pub content: String,
    /// An optional name for the participant. Provides the model information to differentiate between participants of the same role.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssistantMessage {
    /// The contents of the system message.
    #[serde(default)]
    pub content: String,
    /// An optional name for the participant. Provides the model information to differentiate between participants of the same role.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub name: Option<String>,    
    /// The tool calls generated by the model, such as function calls
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub tool_calls: Vec<ToolCall>,
}

/// A message carrying the output of a tool call, wrapped by `ChatMessage::Tool`.
///
/// Only `Serialize` is derived; both fields are always serialized.
#[derive(Debug, Clone, Serialize)]
pub struct ToolMessage {
    /// The textual content of the tool message.
    pub content: String,
    /// Identifier of the tool call this message answers.
    // NOTE(review): presumably the `id` of a `ToolCall` from a preceding
    // assistant message — confirm against the `tool` module.
    pub tool_call_id: String,
}

impl ChatMessage {
    pub fn role(&self) -> Role {
        match self {
            Self::Assistant(_) => Role::Assistant,
            Self::User(_) => Role::User,
            Self::Tool(_) => Role::Tool,
            Self::System(_) => Role::System,
        }
    }

    pub fn content(&self) -> &str {
        match self {
            Self::Assistant(m) => &m.content,
            Self::User(m) => &m.content,
            Self::Tool(m) => &m.content,
            Self::System(m) => &m.content,
        }
    }

    pub fn system(content: impl Into<String>) -> Self {
        Self::System(SystemMessage {
            content: content.into(),
            name: None,
        })
    }

    pub fn user(content: impl Into<String>) -> Self {
        Self::User(UserMessage {
            content: content.into(),
            name: None,
        })
    }

    pub fn assistant(content: impl Into<String>, tool_calls: impl Into<Vec<ToolCall>>) -> Self {
        Self::Assistant(AssistantMessage {
            content: content.into(),
            name: None,
            tool_calls: tool_calls.into(),
        })
    }

    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
        Self::Tool(ToolMessage {
            content: content.into(),
            tool_call_id: tool_call_id.into()
        })
    }

    pub fn is_system(&self) -> bool {
        matches!(self, Self::System(_))
    }

    pub fn is_user(&self) -> bool {
        matches!(self, Self::User(_))
    }

    pub fn is_assistant(&self) -> bool {
        matches!(self, Self::Assistant(_))
    }

    pub fn is_tool(&self) -> bool {
        matches!(self, Self::Tool(_))
    }

    pub fn as_system(&self) -> Option<&SystemMessage> {
        if let Self::System(msg) = self {
            Some(msg)
        } else {
            None
        }
    }

    pub fn as_user(&self) -> Option<&UserMessage> {
        if let Self::User(msg) = self {
            Some(msg)
        } else {
            None
        }
    }

    pub fn as_assistant(&self) -> Option<&AssistantMessage> {
        if let Self::Assistant(msg) = self {
            Some(msg)
        } else {
            None
        }
    }

    pub fn as_tool(&self) -> Option<&ToolMessage> {
        if let Self::Tool(msg) = self {
            Some(msg)
        } else {
            None
        }
    }
}

/// Formats a [`Role`] as its lowercase name: `system`, `user`, `assistant` or `tool`.
///
/// These strings are the same values that `ChatMessage` writes into its `role`
/// tag (`#[serde(rename_all = "snake_case", tag = "role")]`), so the two stay
/// interchangeable.
impl Display for Role {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Role::System => "system",
            Role::User => "user",
            // Was misspelled "assiatant", which did not match the serde tag.
            Role::Assistant => "assistant",
            Role::Tool => "tool",
        };
        f.write_str(name)
    }
}