use std::fmt;
use std::hash::{Hash, Hasher};
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::CupelError;
#[derive(Debug, Clone)]
pub struct ContextSource(String);
impl ContextSource {
pub const CHAT: &str = "Chat";
pub const TOOL: &str = "Tool";
pub const RAG: &str = "Rag";
pub fn new(value: impl Into<String>) -> Result<Self, CupelError> {
let s = value.into();
if s.trim().is_empty() {
return Err(CupelError::EmptySource);
}
Ok(Self(s))
}
pub fn as_str(&self) -> &str {
&self.0
}
}
impl Default for ContextSource {
fn default() -> Self {
Self(Self::CHAT.to_owned())
}
}
impl PartialEq for ContextSource {
fn eq(&self, other: &Self) -> bool {
self.0.eq_ignore_ascii_case(&other.0)
}
}
impl Eq for ContextSource {}
impl Hash for ContextSource {
fn hash<H: Hasher>(&self, state: &mut H) {
for byte in self.0.bytes() {
state.write_u8(byte.to_ascii_lowercase());
}
}
}
impl fmt::Display for ContextSource {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.0)
}
}
#[cfg(feature = "serde")]
impl Serialize for ContextSource {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_str(&self.0)
}
}
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for ContextSource {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
ContextSource::new(s).map_err(serde::de::Error::custom)
}
}