use std::sync::Arc;
use crate::memory::{RecallRequest, RecalledMemory, StoreMemoryRequest, StoreMemoryResponse};
use crate::{DakeraClient, Result};
/// A chat conversation bound to one agent and one server-side session.
///
/// Build one with [`ChatMemorySession::create`] or
/// [`ChatMemorySession::create_with_metadata`]. Every message stored through it is
/// attached to the session it was created with. Call [`ChatMemorySession::close`]
/// to end the session when the conversation is over.
pub struct ChatMemorySession {
    /// Shared client used for every request this session makes.
    client: Arc<DakeraClient>,
    /// Agent whose memories are stored and recalled through this session.
    agent_id: String,
    /// Identifier of the session returned when it was started.
    session_id: String,
}

// Hand-written rather than derived so that it does not require `DakeraClient: Debug`;
// the client handle is therefore omitted from the output.
impl std::fmt::Debug for ChatMemorySession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ChatMemorySession")
            .field("agent_id", &self.agent_id)
            .field("session_id", &self.session_id)
            .finish_non_exhaustive()
    }
}
impl ChatMemorySession {
pub async fn create(
client: Arc<DakeraClient>,
agent_id: impl Into<String>,
) -> Result<ChatMemorySession> {
let agent_id = agent_id.into();
let session = client.start_session(&agent_id).await?;
Ok(ChatMemorySession {
client,
agent_id,
session_id: session.id,
})
}
pub async fn create_with_metadata(
client: Arc<DakeraClient>,
agent_id: impl Into<String>,
metadata: serde_json::Value,
) -> Result<ChatMemorySession> {
let agent_id = agent_id.into();
let session = client
.start_session_with_metadata(&agent_id, metadata)
.await?;
Ok(ChatMemorySession {
client,
agent_id,
session_id: session.id,
})
}
pub async fn store(&self, role: &str, content: &str) -> Result<StoreMemoryResponse> {
self.store_with_opts(role, content, 0.6, &[]).await
}
pub async fn store_with_opts(
&self,
role: &str,
content: &str,
importance: f32,
extra_tags: &[&str],
) -> Result<StoreMemoryResponse> {
let mut tags: Vec<String> = extra_tags.iter().map(|&t| t.to_owned()).collect();
if !tags.iter().any(|t| t == role) {
tags.push(role.to_owned());
}
let request = StoreMemoryRequest::new(&self.agent_id, content)
.with_importance(importance)
.with_tags(tags)
.with_session(self.session_id.clone());
self.client.store_memory(request).await
}
pub async fn recall(&self, query: &str) -> Result<Vec<RecalledMemory>> {
self.recall_top_k(query, 5).await
}
pub async fn recall_top_k(&self, query: &str, top_k: usize) -> Result<Vec<RecalledMemory>> {
let request = RecallRequest::new(&self.agent_id, query).with_top_k(top_k);
let response = self.client.recall(request).await?;
Ok(response.memories)
}
pub async fn close(self) -> Result<()> {
self.client.end_session(&self.session_id, None).await?;
Ok(())
}
pub fn session_id(&self) -> &str {
&self.session_id
}
pub fn agent_id(&self) -> &str {
&self.agent_id
}
}