use crate::errors::AppError;
use serde::{Deserialize, Serialize};
/// Closed set of source labels, each with a fixed lowercase string form.
///
/// NOTE(review): presumably records who or what produced a memory entry —
/// confirm against callers.
///
/// Because of `rename_all = "snake_case"`, serde reads and writes each variant
/// as its lowercase name (for example `Import` serializes to `"import"`), which
/// is the same text `MemorySource::as_str` returns.
///
/// The type is `Copy`, and derives `Eq` and `Hash`, so it can be passed by value
/// and used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemorySource {
/// Serialized as `"agent"`.
Agent,
/// Serialized as `"user"`.
User,
/// Serialized as `"system"`.
System,
/// Serialized as `"import"`.
Import,
/// Serialized as `"sync"`.
Sync,
}
impl MemorySource {
pub const fn as_str(self) -> &'static str {
match self {
Self::Agent => "agent",
Self::User => "user",
Self::System => "system",
Self::Import => "import",
Self::Sync => "sync",
}
}
pub const ALL: &'static [MemorySource] = &[
Self::Agent,
Self::User,
Self::System,
Self::Import,
Self::Sync,
];
}
impl std::fmt::Display for MemorySource {
    /// Writes the canonical lowercase name returned by `as_str`.
    ///
    /// The name is emitted with `write_str`, so width, fill and alignment
    /// flags on the format spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.as_str();
        f.write_str(label)
    }
}
impl TryFrom<&str> for MemorySource {
    type Error = AppError;

    /// Parses the exact canonical name of a variant.
    ///
    /// Matching is case-sensitive and does no trimming: only the strings
    /// produced by `as_str` are accepted.
    ///
    /// # Errors
    ///
    /// Returns `AppError::Validation` when `value` is not the name of any
    /// variant. The message quotes the rejected input and lists every
    /// accepted name in `ALL` order.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let matched = Self::ALL
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == value);
        if let Some(source) = matched {
            return Ok(source);
        }

        // Only the failure path allocates: build the list of accepted names
        // for the error message.
        let accepted: Vec<&str> = Self::ALL
            .iter()
            .map(|candidate| candidate.as_str())
            .collect();
        Err(AppError::Validation(format!(
            "invalid memory source: {other:?}; expected one of {}",
            accepted.join(", "),
            other = value
        )))
    }
}
impl TryFrom<String> for MemorySource {
    type Error = AppError;

    /// Parses an owned string by borrowing it and deferring to the
    /// `TryFrom<&str>` conversion, so both conversions accept the same names
    /// and fail with the same error.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        let borrowed: &str = &value;
        Self::try_from(borrowed)
    }
}
/// Checks that `raw` names a known memory source and returns its canonical
/// `'static` spelling.
///
/// This is a thin wrapper over the `TryFrom<&str>` conversion on
/// [`MemorySource`]: the accepted names and the error text live in one place
/// instead of being copied here, so the two cannot drift apart when a variant
/// is added or renamed.
///
/// # Errors
///
/// Returns the `AppError` produced by `MemorySource::try_from` when `raw` is
/// not an exact, case-sensitive match for one of the canonical names.
pub fn validate_source(raw: &str) -> Result<&'static str, AppError> {
    MemorySource::try_from(raw).map(MemorySource::as_str)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `as_str` yields the lowercase canonical name of each variant.
    #[test]
    fn as_str_returns_canonical_lowercase() {
        assert_eq!(MemorySource::Agent.as_str(), "agent");
        assert_eq!(MemorySource::User.as_str(), "user");
        assert_eq!(MemorySource::System.as_str(), "system");
        assert_eq!(MemorySource::Import.as_str(), "import");
        assert_eq!(MemorySource::Sync.as_str(), "sync");
    }

    /// Each canonical name parses to the matching variant.
    #[test]
    fn try_from_valid_strings_succeeds() {
        assert_eq!(
            MemorySource::try_from("agent").unwrap(),
            MemorySource::Agent
        );
        assert_eq!(MemorySource::try_from("user").unwrap(), MemorySource::User);
        assert_eq!(
            MemorySource::try_from("system").unwrap(),
            MemorySource::System
        );
        assert_eq!(
            MemorySource::try_from("import").unwrap(),
            MemorySource::Import
        );
        assert_eq!(MemorySource::try_from("sync").unwrap(), MemorySource::Sync);
    }

    /// An unknown name is rejected with a message that quotes the input and
    /// lists the accepted names.
    #[test]
    fn try_from_invalid_string_returns_err() {
        let err = MemorySource::try_from("enrich").unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("invalid memory source"), "got: {msg}");
        assert!(msg.contains("\"enrich\""), "got: {msg}");
        assert!(msg.contains("agent"), "must list agent as valid: {msg}");
    }

    /// The empty string is not a valid source.
    #[test]
    fn try_from_empty_string_returns_err() {
        assert!(MemorySource::try_from("").is_err());
    }

    /// The owned-`String` conversion accepts the same names as the `&str` one.
    #[test]
    fn try_from_string_owned_works() {
        let src: MemorySource = String::from("agent").try_into().unwrap();
        assert_eq!(src, MemorySource::Agent);
    }

    /// `Display` prints exactly what `as_str` returns.
    #[test]
    fn display_matches_as_str() {
        for v in MemorySource::ALL {
            assert_eq!(format!("{v}"), v.as_str());
        }
    }

    /// JSON serialization uses the lowercase name and round-trips.
    #[test]
    fn serialize_round_trip_preserves_variant() {
        let v = MemorySource::Import;
        let json = serde_json::to_string(&v).unwrap();
        assert_eq!(json, "\"import\"");
        let back: MemorySource = serde_json::from_str(&json).unwrap();
        assert_eq!(back, v);
    }

    /// Guards against a variant being added without updating `ALL`.
    #[test]
    fn all_slice_has_exactly_five_variants() {
        assert_eq!(MemorySource::ALL.len(), 5);
    }

    /// A length check alone would not notice one variant listed twice in
    /// place of another, so also require the entries to be distinct.
    #[test]
    fn all_slice_has_no_duplicates() {
        let unique: std::collections::HashSet<MemorySource> =
            MemorySource::ALL.iter().copied().collect();
        assert_eq!(unique.len(), MemorySource::ALL.len());
    }

    /// `as_str` and `try_from` must stay inverse to each other for every
    /// variant, not only the ones spelled out in the tests above.
    #[test]
    fn every_variant_round_trips_through_try_from() {
        for &variant in MemorySource::ALL {
            assert_eq!(MemorySource::try_from(variant.as_str()).unwrap(), variant);
        }
    }

    /// The serde wire name comes from `rename_all`, while `as_str` is written
    /// by hand; this pins the two to the same text for every variant.
    #[test]
    fn serde_names_match_as_str() {
        for &variant in MemorySource::ALL {
            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(json, format!("\"{}\"", variant.as_str()));
            let back: MemorySource = serde_json::from_str(&json).unwrap();
            assert_eq!(back, variant);
        }
    }
}