// claude_agent/context/provider.rs
use std::path::PathBuf;

use async_trait::async_trait;

use super::{ContextResult, MemoryContent, MemoryLoader};

12#[async_trait]
19pub trait MemoryProvider: Send + Sync {
20 fn name(&self) -> &str;
22
23 async fn load(&self) -> ContextResult<MemoryContent>;
25}
26
/// A [`MemoryProvider`] whose content is held entirely in memory.
///
/// Handy for tests, or when memory text is produced programmatically rather
/// than read from disk. Build one with [`InMemoryProvider::new`] and the
/// `with_*` builder methods:
///
/// ```ignore
/// let provider = InMemoryProvider::new()
///     .with_claude_md("# Project Rules")
///     .with_local_md("Local settings");
/// ```
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct InMemoryProvider {
    /// Entries reported as `MemoryContent::claude_md`, in insertion order.
    pub claude_md: Vec<String>,
    /// Entries reported as `MemoryContent::local_md`, in insertion order.
    pub local_md: Vec<String>,
}

impl InMemoryProvider {
    /// Creates a provider with no `claude_md` and no `local_md` entries.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends one entry to [`claude_md`](Self::claude_md) and returns the provider.
    ///
    /// Can be chained repeatedly; entries keep the order they were added in.
    // `self` is consumed, so ignoring the return value would silently drop
    // the content that was just added — hence `must_use`.
    #[must_use = "builder methods return the updated provider; dropping it discards the added content"]
    pub fn with_claude_md(mut self, content: impl Into<String>) -> Self {
        self.claude_md.push(content.into());
        self
    }

    /// Appends one entry to [`local_md`](Self::local_md) and returns the provider.
    ///
    /// Can be chained repeatedly; entries keep the order they were added in.
    #[must_use = "builder methods return the updated provider; dropping it discards the added content"]
    pub fn with_local_md(mut self, content: impl Into<String>) -> Self {
        self.local_md.push(content.into());
        self
    }
}

63#[async_trait]
64impl MemoryProvider for InMemoryProvider {
65 fn name(&self) -> &str {
66 "in-memory"
67 }
68
69 async fn load(&self) -> ContextResult<MemoryContent> {
70 Ok(MemoryContent {
71 claude_md: self.claude_md.clone(),
72 local_md: self.local_md.clone(),
73 rule_indices: Vec::new(),
74 })
75 }
76}
77
/// A [`MemoryProvider`] that delegates loading to [`MemoryLoader`] using a
/// fixed path.
#[derive(Debug, Clone)]
pub struct FileMemoryProvider {
    /// Path passed to `MemoryLoader::load` each time content is loaded.
    /// NOTE(review): whether this should name a file or a directory is up to
    /// `MemoryLoader` — confirm against the loader.
    pub path: PathBuf,
}

impl FileMemoryProvider {
    /// Creates a provider that loads from `path`.
    ///
    /// Accepts anything convertible into a [`PathBuf`], such as `&str`,
    /// `String`, `&Path` or `PathBuf`.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        let path: PathBuf = path.into();
        FileMemoryProvider { path }
    }
}

102#[async_trait]
103impl MemoryProvider for FileMemoryProvider {
104 fn name(&self) -> &str {
105 "file"
106 }
107
108 async fn load(&self) -> ContextResult<MemoryContent> {
109 let loader = MemoryLoader::new();
110 loader.load(&self.path).await
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_in_memory_provider() {
        let provider = InMemoryProvider::new()
            .with_claude_md("# Project Rules")
            .with_claude_md("Use async/await.");

        let content = provider.load().await.unwrap();
        assert_eq!(content.claude_md.len(), 2);
        assert!(content.claude_md[0].contains("Project Rules"));
        // Entries must come back in the order they were added.
        assert_eq!(content.claude_md[1], "Use async/await.");
        assert!(content.local_md.is_empty());
        // The in-memory provider never produces rule indices.
        assert!(content.rule_indices.is_empty());
    }

    #[tokio::test]
    async fn test_in_memory_provider_with_local() {
        let provider = InMemoryProvider::new()
            .with_claude_md("Shared rules")
            .with_local_md("Local settings");

        let content = provider.load().await.unwrap();
        // Check contents, not just counts, so entries can't be swapped
        // between the two lists unnoticed.
        assert_eq!(content.claude_md, vec!["Shared rules"]);
        assert_eq!(content.local_md, vec!["Local settings"]);
    }

    #[tokio::test]
    async fn test_empty_provider() {
        let provider = InMemoryProvider::new();
        let content = provider.load().await.unwrap();
        assert!(content.is_empty());
    }

    #[tokio::test]
    async fn test_in_memory_load_is_repeatable() {
        // `load` borrows the provider, so it can be called more than once
        // and must yield the same data each time.
        let provider = InMemoryProvider::new().with_claude_md("rules");
        let first = provider.load().await.unwrap();
        let second = provider.load().await.unwrap();
        assert_eq!(first.claude_md, second.claude_md);
    }

    #[test]
    fn test_provider_names() {
        assert_eq!(InMemoryProvider::new().name(), "in-memory");
        assert_eq!(FileMemoryProvider::new("CLAUDE.md").name(), "file");
    }

    #[test]
    fn test_file_provider_stores_path() {
        let provider = FileMemoryProvider::new("project/CLAUDE.md");
        assert_eq!(provider.path, PathBuf::from("project/CLAUDE.md"));
    }
}
146}