use super::StoreError;
use crate::agents::memory::MemoryTurn;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use potato_type::prompt::MessageNum;
use potato_util::create_uuid7;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::fmt::Debug;
/// A single user/assistant exchange together with the identifiers and
/// timestamp it is stored under.
///
/// The type is serializable and deserializable through serde.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StoredMemoryTurn {
    /// Identifier of this record; [`StoredMemoryTurn::new`] fills it from `create_uuid7()`.
    pub id: String,
    /// Identifier of the session this turn is associated with.
    pub session_id: String,
    /// Name of the application this turn is associated with.
    pub app_name: String,
    /// Identifier of the user this turn is associated with.
    pub user_id: String,
    /// Identifier of the invocation this turn is associated with.
    pub invocation_id: String,
    /// The user-side message of the turn.
    pub user: MessageNum,
    /// The assistant-side message of the turn.
    pub assistant: MessageNum,
    /// Optional JSON payload; `None` unless set with
    /// [`StoredMemoryTurn::with_event_data`] or populated directly.
    pub event_data: Option<Value>,
    /// UTC timestamp; [`StoredMemoryTurn::new`] sets it to the time of construction.
    pub created_at: DateTime<Utc>,
}
impl StoredMemoryTurn {
    /// Creates a turn with a freshly generated id and the current UTC time.
    ///
    /// The string arguments are copied into owned `String`s; `user` and
    /// `assistant` are moved in. `event_data` starts out as `None`.
    pub fn new(
        session_id: &str,
        app_name: &str,
        user_id: &str,
        invocation_id: &str,
        user: MessageNum,
        assistant: MessageNum,
    ) -> Self {
        // The id is generated first and the timestamp taken last, mirroring
        // the declaration order of the struct's fields.
        let id = create_uuid7();
        let session_id = session_id.to_owned();
        let app_name = app_name.to_owned();
        let user_id = user_id.to_owned();
        let invocation_id = invocation_id.to_owned();
        let created_at = Utc::now();

        Self {
            id,
            session_id,
            app_name,
            user_id,
            invocation_id,
            user,
            assistant,
            event_data: None,
            created_at,
        }
    }

    /// Returns the turn with `event_data` set to `Some(data)`, replacing any
    /// previous value. Intended for chaining after [`StoredMemoryTurn::new`].
    pub fn with_event_data(self, data: Value) -> Self {
        Self {
            event_data: Some(data),
            ..self
        }
    }

    /// Consumes the stored turn and keeps only its two messages, discarding
    /// the identifiers, event data and timestamp.
    pub fn into_memory_turn(self) -> MemoryTurn {
        let Self {
            user, assistant, ..
        } = self;
        MemoryTurn { user, assistant }
    }
}
/// Asynchronous storage interface for [`StoredMemoryTurn`] records.
///
/// Every query method is addressed by the `(app_name, user_id, session_id)`
/// triple. Implementors must be `Send + Sync + Debug`. Each method reports
/// failure through [`StoreError`].
#[async_trait]
pub trait MemoryStore: Debug + Send + Sync {
    /// Loads the turns stored for the given application, user and session.
    ///
    /// NOTE(review): this trait does not state the ordering of the returned
    /// turns; confirm against the concrete implementations.
    async fn load_turns(
        &self,
        app_name: &str,
        user_id: &str,
        session_id: &str,
    ) -> Result<Vec<StoredMemoryTurn>, StoreError>;

    /// Saves a single turn. The turn carries its own `app_name`, `user_id`
    /// and `session_id`, so no separate key arguments are taken.
    async fn save_turn(&self, turn: &StoredMemoryTurn) -> Result<(), StoreError>;

    /// Clears the data held for the given application, user and session.
    ///
    /// NOTE(review): presumably removes every stored turn for that session;
    /// exact semantics are defined by the implementation.
    async fn clear(
        &self,
        app_name: &str,
        user_id: &str,
        session_id: &str,
    ) -> Result<(), StoreError>;

    /// Returns a count for the given application, user and session.
    ///
    /// NOTE(review): presumably the number of stored turns; confirm against
    /// the concrete implementations.
    async fn count(
        &self,
        app_name: &str,
        user_id: &str,
        session_id: &str,
    ) -> Result<usize, StoreError>;
}