use std::path::PathBuf;
use async_trait::async_trait;
use super::{ContextResult, MemoryContent, MemoryLoader};
/// A source of [`MemoryContent`] for building context.
///
/// Implementors must be `Send + Sync` (required by the supertrait bound).
#[async_trait]
pub trait MemoryProvider: Send + Sync {
    /// Produces this provider's memory content.
    ///
    /// # Errors
    ///
    /// Implementation-specific; an implementation that delegates to another
    /// loader propagates that loader's error unchanged.
    async fn load(&self) -> ContextResult<MemoryContent>;

    /// Short identifier for this provider kind.
    ///
    /// The implementations in this module return `"in-memory"` and `"file"`.
    fn name(&self) -> &str;
}
/// Memory provider backed by strings supplied directly in code rather than loaded from disk.
///
/// Build one with [`MemoryContextProvider::new`] and chain
/// [`claude_md`](Self::claude_md) / [`local_md`](Self::local_md); each call
/// appends one entry, so insertion order is preserved.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MemoryContextProvider {
    /// `claude_md` entries, in the order they were added.
    pub claude_md: Vec<String>,
    /// `local_md` entries, in the order they were added.
    pub local_md: Vec<String>,
}

impl MemoryContextProvider {
    /// Creates a provider with no entries (equivalent to `Default::default()`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `content` to the `claude_md` entries and returns the updated provider.
    ///
    /// Takes `self` by value so calls can be chained; ignoring the result
    /// discards the entry, hence `#[must_use]`.
    #[must_use = "builder methods return the updated provider; dropping it discards the entry"]
    pub fn claude_md(mut self, content: impl Into<String>) -> Self {
        self.claude_md.push(content.into());
        self
    }

    /// Appends `content` to the `local_md` entries and returns the updated provider.
    ///
    /// Takes `self` by value so calls can be chained; ignoring the result
    /// discards the entry, hence `#[must_use]`.
    #[must_use = "builder methods return the updated provider; dropping it discards the entry"]
    pub fn local_md(mut self, content: impl Into<String>) -> Self {
        self.local_md.push(content.into());
        self
    }
}
#[async_trait]
impl MemoryProvider for MemoryContextProvider {
    /// Always `"in-memory"`.
    fn name(&self) -> &str {
        "in-memory"
    }

    /// Returns copies of the stored `claude_md` and `local_md` entries.
    ///
    /// Never returns `Err`. `rule_indices` is always left empty.
    async fn load(&self) -> ContextResult<MemoryContent> {
        let claude_md = self.claude_md.clone();
        let local_md = self.local_md.clone();
        let content = MemoryContent {
            claude_md,
            local_md,
            rule_indices: Vec::new(),
        };
        Ok(content)
    }
}
/// Memory provider that delegates loading of `path` to a `MemoryLoader`.
#[derive(Debug, Clone)]
pub struct FileMemoryProvider {
    /// Path passed to the loader on every `load` call.
    pub path: PathBuf,
}

impl FileMemoryProvider {
    /// Creates a provider for `path` (anything convertible into a [`PathBuf`]).
    pub fn new(path: impl Into<PathBuf>) -> Self {
        let path: PathBuf = path.into();
        Self { path }
    }
}
#[async_trait]
impl MemoryProvider for FileMemoryProvider {
    /// Always `"file"`.
    fn name(&self) -> &str {
        "file"
    }

    /// Loads content for `self.path` using a freshly constructed `MemoryLoader`.
    ///
    /// The loader's result, including any error, is returned unchanged.
    async fn load(&self) -> ContextResult<MemoryContent> {
        MemoryLoader::new().load(&self.path).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[tokio::test]
    async fn test_in_memory_provider() {
        let provider = MemoryContextProvider::new()
            .claude_md("# Project Rules")
            .claude_md("Use async/await.");
        let content = provider.load().await.unwrap();
        assert_eq!(content.claude_md.len(), 2);
        assert!(content.claude_md[0].contains("Project Rules"));
        // Entries must come back in insertion order.
        assert_eq!(content.claude_md[1], "Use async/await.");
        assert!(content.local_md.is_empty());
        // The in-memory provider never produces rule indices.
        assert!(content.rule_indices.is_empty());
    }

    #[tokio::test]
    async fn test_in_memory_provider_with_local() {
        let provider = MemoryContextProvider::new()
            .claude_md("Shared rules")
            .local_md("Local settings");
        let content = provider.load().await.unwrap();
        assert_eq!(content.claude_md.len(), 1);
        assert_eq!(content.local_md.len(), 1);
        // Check the entries went to the right lists, not just the counts.
        assert_eq!(content.claude_md[0], "Shared rules");
        assert_eq!(content.local_md[0], "Local settings");
    }

    #[tokio::test]
    async fn test_empty_provider() {
        let provider = MemoryContextProvider::new();
        let content = provider.load().await.unwrap();
        assert!(content.is_empty());
    }

    #[test]
    fn test_provider_names() {
        assert_eq!(MemoryContextProvider::new().name(), "in-memory");
        assert_eq!(FileMemoryProvider::new("CLAUDE.md").name(), "file");
    }

    #[test]
    fn test_file_provider_stores_path() {
        let provider = FileMemoryProvider::new("some/dir");
        assert_eq!(provider.path, PathBuf::from("some/dir"));
    }
}