use std::fmt;
use std::hash::{Hash, Hasher};
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::CupelError;
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
#[error("invalid context kind: {0:?}")]
pub struct ParseContextKindError(String);
/// Label describing what sort of context an item represents.
///
/// Any non-blank string is a valid kind. Well-known labels are exposed as
/// associated constants and factory functions. Equality and hashing ignore
/// ASCII case, while the original spelling is preserved for display and
/// serialization.
#[derive(Clone, Debug)]
pub struct ContextKind(String);
impl ContextKind {
    /// Canonical spelling of the built-in message kind.
    pub const MESSAGE: &str = "Message";
    /// Canonical spelling of the built-in document kind.
    pub const DOCUMENT: &str = "Document";
    /// Canonical spelling of the built-in tool-output kind.
    pub const TOOL_OUTPUT: &str = "ToolOutput";
    /// Canonical spelling of the built-in memory kind.
    pub const MEMORY: &str = "Memory";
    /// Canonical spelling of the built-in system-prompt kind.
    pub const SYSTEM_PROMPT: &str = "SystemPrompt";

    /// Creates a kind from an arbitrary label.
    ///
    /// The label is stored exactly as supplied: no trimming and no case folding.
    ///
    /// # Errors
    ///
    /// Returns [`CupelError::EmptyKind`] when the label is empty or made up
    /// entirely of whitespace.
    pub fn new(value: impl Into<String>) -> Result<Self, CupelError> {
        let label: String = value.into();
        // `str::trim` strips exactly the chars for which `char::is_whitespace`
        // holds, so "all whitespace" is the same test as "trims to empty".
        if label.chars().all(char::is_whitespace) {
            Err(CupelError::EmptyKind)
        } else {
            Ok(Self(label))
        }
    }

    /// Wraps a label without the blank-input check that [`Self::new`] performs.
    ///
    /// Meant for the built-in constants above. Passing blank input here would
    /// bypass validation.
    pub(crate) fn from_static(value: &str) -> Self {
        Self(String::from(value))
    }

    /// Returns the built-in [`Self::MESSAGE`] kind.
    #[must_use]
    pub fn message() -> Self {
        Self::from_static(Self::MESSAGE)
    }

    /// Returns the built-in [`Self::SYSTEM_PROMPT`] kind.
    #[must_use]
    pub fn system_prompt() -> Self {
        Self::from_static(Self::SYSTEM_PROMPT)
    }

    /// Returns the built-in [`Self::DOCUMENT`] kind.
    #[must_use]
    pub fn document() -> Self {
        Self::from_static(Self::DOCUMENT)
    }

    /// Returns the built-in [`Self::TOOL_OUTPUT`] kind.
    #[must_use]
    pub fn tool_output() -> Self {
        Self::from_static(Self::TOOL_OUTPUT)
    }

    /// Returns the built-in [`Self::MEMORY`] kind.
    #[must_use]
    pub fn memory() -> Self {
        Self::from_static(Self::MEMORY)
    }

    /// Borrows the label with its original casing.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
impl Default for ContextKind {
fn default() -> Self {
Self(Self::MESSAGE.to_owned())
}
}
/// Two kinds are equal when their labels match ignoring ASCII case.
///
/// Non-ASCII characters must match exactly.
impl PartialEq for ContextKind {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.0.as_str(), other.0.as_str());
        str::eq_ignore_ascii_case(lhs, rhs)
    }
}

/// ASCII case-insensitive comparison is reflexive, symmetric and transitive.
impl Eq for ContextKind {}
impl TryFrom<&str> for ContextKind {
type Error = ParseContextKindError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
Self::new(value).map_err(|_| ParseContextKindError(value.to_owned()))
}
}
/// Hashes the label ASCII-case-insensitively, consistent with `PartialEq`.
///
/// Each byte is lowercased before it is fed to the hasher, so labels that
/// compare equal always hash equally. A trailing `0xff` byte is then
/// written, the same terminator `str`'s own `Hash` impl uses. Without it,
/// composite keys such as `(ContextKind, ContextKind)` built from
/// `("ab", "c")` and `("a", "bc")` would feed the hasher identical byte
/// streams and always collide.
impl Hash for ContextKind {
    fn hash<H: Hasher>(&self, state: &mut H) {
        for byte in self.0.bytes() {
            state.write_u8(byte.to_ascii_lowercase());
        }
        // 0xff never occurs in valid UTF-8, so it unambiguously ends the label.
        state.write_u8(0xff);
    }
}
/// Formats the label exactly as it was supplied (original casing).
///
/// Uses [`fmt::Formatter::pad`] so that width, alignment, fill and precision
/// flags (e.g. `{:>12}` or `{:.3}`) are honoured, just as they are for `str`.
/// A bare `{}` produces the label unchanged.
impl fmt::Display for ContextKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(&self.0)
    }
}
/// Serializes the kind as a plain string, keeping its original casing.
#[cfg(feature = "serde")]
impl Serialize for ContextKind {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let label: &str = &self.0;
        serializer.serialize_str(label)
    }
}
/// Deserializes from a string and validates it through [`ContextKind::new`].
///
/// Blank input is rejected with a custom deserialization error built from
/// the `CupelError` that `new` returns.
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for ContextKind {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let raw = <String as Deserialize<'de>>::deserialize(deserializer)?;
        Self::new(raw).map_err(<D::Error as serde::de::Error>::custom)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn factory_message_returns_message() {
        assert_eq!(ContextKind::message().as_str(), "Message");
    }

    #[test]
    fn factory_system_prompt_returns_system_prompt() {
        assert_eq!(ContextKind::system_prompt().as_str(), "SystemPrompt");
    }

    #[test]
    fn factory_document_returns_document() {
        assert_eq!(ContextKind::document().as_str(), "Document");
    }

    #[test]
    fn factory_tool_output_returns_tool_output() {
        assert_eq!(ContextKind::tool_output().as_str(), "ToolOutput");
    }

    #[test]
    fn factory_memory_returns_memory() {
        assert_eq!(ContextKind::memory().as_str(), "Memory");
    }

    #[test]
    fn factory_equals_new() {
        assert_eq!(ContextKind::message(), ContextKind::new("message").unwrap());
    }

    #[test]
    fn try_from_valid_string() {
        let kind = ContextKind::try_from("Custom").unwrap();
        assert_eq!(kind.as_str(), "Custom");
    }

    #[test]
    fn try_from_empty_string_fails() {
        let err = ContextKind::try_from("").unwrap_err();
        assert_eq!(err, ParseContextKindError(String::new()));
    }

    #[test]
    fn try_from_whitespace_only_fails() {
        assert!(ContextKind::try_from(" ").is_err());
    }

    #[test]
    fn parse_context_kind_error_display() {
        let err = ParseContextKindError("".to_owned());
        assert!(err.to_string().contains("invalid context kind"));
    }

    // Equality is ASCII case-insensitive but still distinguishes different kinds.
    #[test]
    fn equality_ignores_ascii_case() {
        assert_eq!(ContextKind::tool_output(), ContextKind::new("TOOLOUTPUT").unwrap());
        assert_ne!(ContextKind::memory(), ContextKind::document());
    }

    // Hash must agree with the case-insensitive Eq, or HashMap/HashSet lookups break.
    #[test]
    fn hash_agrees_with_case_insensitive_eq() {
        let mut set = HashSet::new();
        set.insert(ContextKind::system_prompt());
        assert!(set.contains(&ContextKind::new("systemprompt").unwrap()));
        assert!(!set.insert(ContextKind::new("SYSTEMPROMPT").unwrap()));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn new_rejects_mixed_whitespace_only() {
        assert!(ContextKind::new("\t\n ").is_err());
        assert!(ContextKind::new(String::new()).is_err());
    }

    #[test]
    fn new_preserves_original_casing() {
        let kind = ContextKind::new("cUsToM").unwrap();
        assert_eq!(kind.as_str(), "cUsToM");
        assert_eq!(kind.to_string(), "cUsToM");
    }

    #[test]
    fn default_is_message() {
        assert_eq!(ContextKind::default().as_str(), ContextKind::MESSAGE);
        assert_eq!(ContextKind::default(), ContextKind::message());
    }
}