use std::collections::HashMap;
use crate::errors::AgentResult;
use crate::runtime::context::AuthContext;
use super::{ContentSource, MemoryContent, MemoryService};
/// A conversation that has finished and is ready to be written to memory.
///
/// Consumed by [`CompletedConversation::into_memory_contents`], which turns
/// each message into one memory entry.
#[derive(Debug, Clone)]
pub struct CompletedConversation {
    /// Identifier of the conversation context the messages came from.
    pub context_id: String,
    /// The conversation's messages, in the order supplied by the caller.
    pub messages: Vec<CompletedMessage>,
}
/// One message of a [`CompletedConversation`].
#[derive(Debug, Clone)]
pub struct CompletedMessage {
    /// Identifier of the message within its context.
    pub message_id: String,
    /// Speaker role (the tests in this file use values like "user" / "agent").
    pub role: String,
    /// Message body; becomes the searchable memory text.
    pub text: String,
    /// Optional timestamp; stored under the "timestamp" metadata key when present.
    pub timestamp: Option<String>,
}
impl CompletedConversation {
    /// Consumes the conversation and produces one [`MemoryContent`] per message.
    ///
    /// Every entry is tagged with a `PastConversation` source carrying the
    /// shared context id plus the message's id and role. When a message has a
    /// timestamp, it is stored under the `"timestamp"` metadata key;
    /// otherwise the metadata is left at its default (empty) value.
    #[must_use]
    pub fn into_memory_contents(self) -> Vec<MemoryContent> {
        let Self {
            context_id,
            messages,
        } = self;
        messages
            .into_iter()
            .map(|msg| {
                let metadata = match msg.timestamp {
                    Some(ts) => {
                        std::iter::once(("timestamp".to_string(), ts.into())).collect()
                    }
                    None => Default::default(),
                };
                MemoryContent {
                    text: msg.text,
                    source: ContentSource::PastConversation {
                        context_id: context_id.clone(),
                        message_id: msg.message_id,
                        role: msg.role,
                    },
                    metadata,
                }
            })
            .collect()
    }
}
/// A named document to be chunked and stored in memory.
#[derive(Debug, Clone)]
pub struct Document {
    /// Stable identifier; the delete helpers in this file probe chunk ids
    /// derived from it (`doc:<id>:chunk-<n>`).
    pub id: String,
    /// Human-readable name, propagated into every chunk's source.
    pub name: String,
    /// Full document text; split into chunks by `into_memory_contents`.
    pub content: String,
    /// Arbitrary metadata, attached verbatim to every chunk.
    pub metadata: HashMap<String, serde_json::Value>,
}
impl Document {
pub fn new(id: impl Into<String>, name: impl Into<String>, content: impl Into<String>) -> Self {
Self {
id: id.into(),
name: name.into(),
content: content.into(),
metadata: HashMap::new(),
}
}
#[must_use]
pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
self.metadata.insert(key.into(), value);
self
}
#[must_use]
pub fn into_memory_contents(self, chunk_size: usize) -> Vec<MemoryContent> {
let chunks = chunk_text(&self.content, chunk_size);
let total_chunks = chunks.len();
chunks
.into_iter()
.enumerate()
.map(|(i, text)| MemoryContent {
text,
source: ContentSource::Document {
document_id: self.id.clone(),
name: self.name.clone(),
chunk_index: i,
total_chunks,
},
metadata: self.metadata.clone(),
})
.collect()
}
}
/// Splits `text` into chunks of roughly `chunk_size` bytes, preferring to
/// break at sentence boundaries (`.`, `!`, `?`).
///
/// Sentences are accumulated into a buffer; the buffer is flushed to a new
/// chunk whenever adding the next sentence would exceed `chunk_size`. A
/// single sentence longer than `chunk_size` is kept intact as its own chunk
/// rather than split mid-sentence. Empty input yields an empty vector.
#[must_use]
pub fn chunk_text(text: &str, chunk_size: usize) -> Vec<String> {
    let mut chunks: Vec<String> = Vec::new();
    let mut buffer = String::new();
    for sentence in text.split_inclusive(['.', '!', '?']) {
        let would_overflow = buffer.len() + sentence.len() > chunk_size;
        if would_overflow && !buffer.is_empty() {
            chunks.push(std::mem::take(&mut buffer));
        }
        buffer.push_str(sentence);
    }
    if !buffer.is_empty() {
        chunks.push(buffer);
    }
    // Defensive fallback: guarantee non-empty input always yields a chunk.
    if chunks.is_empty() && !text.is_empty() {
        chunks.push(text.to_string());
    }
    chunks
}
/// Extension trait adding conversation ingestion to any [`MemoryService`].
///
/// Blanket-implemented below for every `MemoryService`, so callers only need
/// the trait in scope.
#[cfg_attr(
    all(target_os = "wasi", target_env = "p1"),
    async_trait::async_trait(?Send)
)]
#[cfg_attr(
    not(all(target_os = "wasi", target_env = "p1")),
    async_trait::async_trait
)]
pub trait MemoryServiceConversationExt: MemoryService {
    /// Converts every message of `conversation` into a memory entry and
    /// stores the whole batch at once, returning the created entry ids.
    async fn add_conversation(
        &self,
        auth_ctx: &AuthContext,
        conversation: CompletedConversation,
    ) -> AgentResult<Vec<String>> {
        self.add_batch(auth_ctx, conversation.into_memory_contents())
            .await
    }
}

impl<T: MemoryService + ?Sized> MemoryServiceConversationExt for T {}
/// Extension trait adding document ingestion and removal to any
/// [`MemoryService`].
///
/// NOTE(review): the delete helpers assume the backing service keys document
/// chunks as `doc:<document_id>:chunk-<n>` — confirm this matches the id
/// scheme used by the `MemoryService` implementation's `add_batch`.
#[cfg_attr(
    all(target_os = "wasi", target_env = "p1"),
    async_trait::async_trait(?Send)
)]
#[cfg_attr(
    not(all(target_os = "wasi", target_env = "p1")),
    async_trait::async_trait
)]
pub trait MemoryServiceDocumentExt: MemoryService {
    /// Chunks `document` (default chunk size 1000 bytes) and stores every
    /// chunk, then removes chunks left over from a previous, longer version
    /// of the same document. Returns the ids of the newly stored chunks.
    async fn add_document(
        &self,
        auth_ctx: &AuthContext,
        document: Document,
        chunk_size: Option<usize>,
    ) -> AgentResult<Vec<String>> {
        let doc_id = document.id.clone();
        let contents = document.into_memory_contents(chunk_size.unwrap_or(1000));
        let fresh_count = contents.len();
        let ids = self.add_batch(auth_ctx, contents).await?;
        // Drop chunks beyond the new count so a document that shrank on
        // re-ingest leaves no stale, still-searchable tail behind.
        self.delete_stale_chunks(auth_ctx, &doc_id, fresh_count)
            .await?;
        Ok(ids)
    }

    /// Deletes chunks of `document_id` starting at `start_index`, walking
    /// forward until ten consecutive ids are missing or 1000 indices have
    /// been probed. Returns the number of chunks actually deleted.
    async fn delete_stale_chunks(
        &self,
        auth_ctx: &AuthContext,
        document_id: &str,
        start_index: usize,
    ) -> AgentResult<usize> {
        // Hard cap on how far past `start_index` we will probe.
        const MAX_STALE_CHUNKS: usize = 1000;
        // Stop early after this many back-to-back missing ids.
        const CONSECUTIVE_MISS_THRESHOLD: usize = 10;
        let mut removed = 0;
        let mut miss_streak = 0;
        let mut index = start_index;
        while index < start_index + MAX_STALE_CHUNKS {
            let chunk_id = format!("doc:{}:chunk-{}", document_id, index);
            if self.delete(auth_ctx, &chunk_id).await? {
                removed += 1;
                miss_streak = 0;
            } else {
                miss_streak += 1;
                if miss_streak >= CONSECUTIVE_MISS_THRESHOLD {
                    break;
                }
            }
            index += 1;
        }
        Ok(removed)
    }

    /// Deletes every chunk of `document_id`, scanning from index 0 until ten
    /// consecutive misses or a hard cap of 10000 probes. Returns the number
    /// of chunks deleted.
    async fn delete_document(
        &self,
        auth_ctx: &AuthContext,
        document_id: &str,
    ) -> AgentResult<usize> {
        // Larger cap than the stale-chunk sweep: this must cover a whole
        // document, not just a trailing remainder.
        const MAX_CHUNKS: usize = 10000;
        const CONSECUTIVE_MISS_THRESHOLD: usize = 10;
        let mut removed = 0;
        let mut miss_streak = 0;
        for index in 0..MAX_CHUNKS {
            let chunk_id = format!("doc:{}:chunk-{}", document_id, index);
            match self.delete(auth_ctx, &chunk_id).await? {
                true => {
                    removed += 1;
                    miss_streak = 0;
                }
                false => {
                    miss_streak += 1;
                    if miss_streak >= CONSECUTIVE_MISS_THRESHOLD {
                        break;
                    }
                }
            }
        }
        Ok(removed)
    }
}

impl<T: MemoryService + ?Sized> MemoryServiceDocumentExt for T {}
#[cfg(test)]
mod tests {
    use super::*;

    // Chunking breaks at sentence boundaries when the budget is exceeded.
    #[test]
    fn chunk_text_splits_by_sentences() {
        let chunks = chunk_text("First sentence. Second sentence. Third sentence.", 30);
        assert_eq!(chunks.len(), 3);
        for chunk in &chunks {
            assert!(chunk.ends_with('.'));
        }
    }

    // Text with no delimiters comes back as a single oversized chunk.
    #[test]
    fn chunk_text_handles_no_sentences() {
        let input = "No sentence delimiters here";
        let chunks = chunk_text(input, 10);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], input);
    }

    // Empty input produces no chunks at all.
    #[test]
    fn chunk_text_handles_empty() {
        assert!(chunk_text("", 100).is_empty());
    }

    // Every chunk of a document carries the document's id and name.
    #[test]
    fn document_into_memory_contents() {
        let contents = Document::new("test-doc", "Test Document", "First. Second. Third.")
            .into_memory_contents(10);
        assert!(!contents.is_empty());
        for content in &contents {
            match &content.source {
                ContentSource::Document {
                    document_id, name, ..
                } => {
                    assert_eq!(document_id, "test-doc");
                    assert_eq!(name, "Test Document");
                }
                _ => panic!("Expected Document source"),
            }
        }
    }

    // Each message becomes one memory entry tagged with the shared context id.
    #[test]
    fn completed_conversation_into_memory_contents() {
        let messages = vec![
            CompletedMessage {
                message_id: "msg-1".to_string(),
                role: "user".to_string(),
                text: "Hello".to_string(),
                timestamp: Some("2024-01-01T00:00:00Z".to_string()),
            },
            CompletedMessage {
                message_id: "msg-2".to_string(),
                role: "agent".to_string(),
                text: "Hi there!".to_string(),
                timestamp: None,
            },
        ];
        let conv = CompletedConversation {
            context_id: "ctx-123".to_string(),
            messages,
        };
        let contents = conv.into_memory_contents();
        assert_eq!(contents.len(), 2);
        match &contents[0].source {
            ContentSource::PastConversation {
                context_id,
                message_id,
                role,
            } => {
                assert_eq!(context_id, "ctx-123");
                assert_eq!(message_id, "msg-1");
                assert_eq!(role, "user");
            }
            _ => panic!("Expected PastConversation source"),
        }
    }
}
// Integration tests for the extension traits against the in-memory backend.
// Only built when the "runtime" feature is enabled.
#[cfg(all(test, feature = "runtime"))]
mod async_tests {
    use super::*;
    use crate::runtime::memory::{InMemoryMemoryService, SearchOptions};

    // Fixed auth context shared by every test in this module.
    fn test_auth() -> AuthContext {
        AuthContext {
            app_name: "test-app".to_string(),
            user_name: "test-user".to_string(),
        }
    }

    // Re-ingesting a shorter version of a document must remove the chunks
    // left over from the longer version (the delete_stale_chunks pass inside
    // add_document).
    #[tokio::test]
    async fn add_document_removes_stale_chunks_on_reingest() {
        let memory = InMemoryMemoryService::new();
        let auth = test_auth();
        // v1: four sentences with a 30-byte budget -> multiple chunks.
        let doc_v1 = Document::new(
            "shrinking-doc",
            "Shrinking Document",
            "First sentence is here. Second sentence is here. Third sentence is here. Fourth sentence is here.",
        );
        let ids_v1 = memory.add_document(&auth, doc_v1, Some(30)).await.unwrap();
        assert!(
            ids_v1.len() >= 3,
            "Should have at least 3 chunks, got {} with IDs: {:?}",
            ids_v1.len(),
            ids_v1
        );
        let original_chunk_count = ids_v1.len();
        // NOTE(review): min_score 0.1 is presumably low enough that every
        // chunk containing the query term matches — depends on
        // InMemoryMemoryService scoring; confirm if this test flakes.
        let results_v1 = memory
            .search(
                &auth,
                "sentence",
                SearchOptions::documents_only().with_min_score(0.1),
            )
            .await
            .unwrap();
        assert_eq!(
            results_v1.len(),
            original_chunk_count,
            "All chunks should contain 'sentence'"
        );
        // v2: same document id, much shorter content -> single chunk.
        let doc_v2 = Document::new("shrinking-doc", "Shrinking Document", "Only one chunk now.");
        let ids_v2 = memory
            .add_document(&auth, doc_v2, Some(1000))
            .await
            .unwrap();
        assert_eq!(ids_v2.len(), 1, "Should have only 1 chunk now");
        // The v1-only term must no longer be findable…
        let results_v2 = memory
            .search(
                &auth,
                "sentence",
                SearchOptions::documents_only().with_min_score(0.1),
            )
            .await
            .unwrap();
        assert!(
            results_v2.is_empty(),
            "Old content should be gone after re-ingest, but found {} results",
            results_v2.len()
        );
        // …while the v2 content is.
        let results_new = memory
            .search(
                &auth,
                "Only one chunk",
                SearchOptions::documents_only().with_min_score(0.1),
            )
            .await
            .unwrap();
        assert_eq!(results_new.len(), 1, "New content should be searchable");
    }

    // Deleting a document that was never ingested should hit the
    // consecutive-miss early-exit and delete nothing.
    #[tokio::test]
    async fn delete_document_stops_early_for_missing_doc() {
        let memory = InMemoryMemoryService::new();
        let auth = test_auth();
        let deleted = memory
            .delete_document(&auth, "nonexistent-document")
            .await
            .unwrap();
        assert_eq!(deleted, 0, "Should not delete anything");
    }

    // Ingest a multi-chunk document, delete it, and verify every chunk is
    // gone (delete count matches and search returns nothing).
    #[tokio::test]
    async fn delete_document_handles_normal_case() {
        let memory = InMemoryMemoryService::new();
        let auth = test_auth();
        let doc = Document::new(
            "to-delete",
            "Document to Delete",
            "First part. Second part. Third part.",
        );
        let ids = memory.add_document(&auth, doc, Some(15)).await.unwrap();
        let chunk_count = ids.len();
        assert!(chunk_count >= 2, "Should have multiple chunks");
        let before = memory
            .search(&auth, "part", SearchOptions::documents_only())
            .await
            .unwrap();
        assert_eq!(before.len(), chunk_count);
        let deleted = memory.delete_document(&auth, "to-delete").await.unwrap();
        assert_eq!(deleted, chunk_count, "Should delete all chunks");
        let after = memory
            .search(&auth, "part", SearchOptions::documents_only())
            .await
            .unwrap();
        assert!(after.is_empty(), "All chunks should be deleted");
    }
}