claude_agent/context/
level.rs1use std::path::{Path, PathBuf};
4
5use async_trait::async_trait;
6
7use super::{ContextResult, MemoryContent, MemoryLoader, MemoryProvider, RuleIndex};
8
/// Scope a piece of memory belongs to.
///
/// The derived `Ord` follows declaration order, so `Enterprise < User < Project < Local`;
/// the explicit discriminants spell out that same order.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum MemoryLevel {
    /// Enterprise-wide memory (the level configured by `LeveledMemoryProvider::with_enterprise`).
    Enterprise = 0,
    /// Per-user memory (the level configured by `LeveledMemoryProvider::with_user`).
    User = 1,
    /// Project memory, loaded from the project directory.
    Project = 2,
    /// Local memory, loaded from the project directory in local-only mode.
    Local = 3,
}
16
17impl MemoryLevel {
18 pub fn all() -> &'static [MemoryLevel] {
19 &[
20 MemoryLevel::Enterprise,
21 MemoryLevel::User,
22 MemoryLevel::Project,
23 MemoryLevel::Local,
24 ]
25 }
26}
27
28#[derive(Debug, Default)]
29pub struct LeveledMemoryProvider {
30 enterprise: LevelContent,
31 user: LevelContent,
32 project: LevelContent,
33 local: LevelContent,
34}
35
36#[derive(Debug, Default)]
37struct LevelContent {
38 path: Option<PathBuf>,
39 content: Vec<String>,
40 rules: Vec<RuleIndex>,
41}
42
43impl LeveledMemoryProvider {
44 pub fn new() -> Self {
45 Self::default()
46 }
47
48 pub fn from_project(project_dir: impl AsRef<Path>) -> Self {
49 let dir = project_dir.as_ref();
50 Self {
51 project: LevelContent {
52 path: Some(dir.to_path_buf()),
53 ..Default::default()
54 },
55 local: LevelContent {
56 path: Some(dir.to_path_buf()),
57 ..Default::default()
58 },
59 ..Default::default()
60 }
61 }
62
63 pub fn with_user(mut self) -> Self {
64 if let Some(home) = crate::common::home_dir() {
65 self.user.path = Some(home.join(".claude"));
66 }
67 self
68 }
69
70 pub fn with_enterprise(mut self) -> Self {
71 #[cfg(target_os = "macos")]
72 {
73 let path = PathBuf::from("/Library/Application Support/ClaudeCode");
74 if path.exists() {
75 self.enterprise.path = Some(path);
76 }
77 }
78 #[cfg(target_os = "linux")]
79 {
80 let path = PathBuf::from("/etc/claude-code");
81 if path.exists() {
82 self.enterprise.path = Some(path);
83 }
84 }
85 self
86 }
87
88 pub fn with_content(mut self, level: MemoryLevel, content: impl Into<String>) -> Self {
89 let target = self.level_mut(level);
90 target.content.push(content.into());
91 self
92 }
93
94 pub fn with_rule(mut self, level: MemoryLevel, rule: RuleIndex) -> Self {
95 let target = self.level_mut(level);
96 target.rules.push(rule);
97 self
98 }
99
100 pub fn add_content(&mut self, level: MemoryLevel, content: impl Into<String>) {
101 self.level_mut(level).content.push(content.into());
102 }
103
104 pub fn add_rule(&mut self, level: MemoryLevel, rule: RuleIndex) {
105 self.level_mut(level).rules.push(rule);
106 }
107
108 fn level_mut(&mut self, level: MemoryLevel) -> &mut LevelContent {
109 match level {
110 MemoryLevel::Enterprise => &mut self.enterprise,
111 MemoryLevel::User => &mut self.user,
112 MemoryLevel::Project => &mut self.project,
113 MemoryLevel::Local => &mut self.local,
114 }
115 }
116
117 async fn load_level(
118 &self,
119 level_content: &LevelContent,
120 is_local: bool,
121 ) -> ContextResult<MemoryContent> {
122 let mut content = MemoryContent::default();
123
124 if let Some(ref path) = level_content.path
125 && path.exists()
126 {
127 let mut loader = MemoryLoader::new();
128 if is_local {
129 if let Ok(loaded) = loader.load_local_only(path).await {
130 content = loaded;
131 }
132 } else if let Ok(loaded) = loader.load_all(path).await {
133 content = loaded;
134 }
135 }
136
137 for c in &level_content.content {
138 content.claude_md.push(c.clone());
139 }
140
141 content
142 .rule_indices
143 .extend(level_content.rules.iter().cloned());
144
145 Ok(content)
146 }
147}
148
149#[async_trait]
150impl MemoryProvider for LeveledMemoryProvider {
151 fn name(&self) -> &str {
152 "leveled"
153 }
154
155 async fn load(&self) -> ContextResult<MemoryContent> {
156 let mut combined = MemoryContent::default();
157
158 let enterprise = self.load_level(&self.enterprise, false).await?;
160 let user = self.load_level(&self.user, false).await?;
161 let project = self.load_level(&self.project, false).await?;
162 let local = self.load_level(&self.local, true).await?;
163
164 combined.claude_md.extend(enterprise.claude_md);
165 combined.claude_md.extend(user.claude_md);
166 combined.claude_md.extend(project.claude_md);
167 combined.local_md.extend(local.local_md);
168 combined.claude_md.extend(local.claude_md);
169
170 combined.rule_indices.extend(enterprise.rule_indices);
171 combined.rule_indices.extend(user.rule_indices);
172 combined.rule_indices.extend(project.rule_indices);
173 combined.rule_indices.extend(local.rule_indices);
174
175 Ok(combined)
176 }
177
178 fn priority(&self) -> i32 {
179 100
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Inline content from the first three levels is merged in level order.
    #[tokio::test]
    async fn test_leveled_memory_provider() {
        let provider = LeveledMemoryProvider::new()
            .with_content(MemoryLevel::Enterprise, "# Enterprise Rules")
            .with_content(MemoryLevel::User, "# User Preferences")
            .with_content(MemoryLevel::Project, "# Project Guidelines");

        let content = provider.load().await.unwrap();
        assert_eq!(content.claude_md.len(), 3);
        assert_eq!(content.claude_md[0], "# Enterprise Rules");
        assert_eq!(content.claude_md[1], "# User Preferences");
        assert_eq!(content.claude_md[2], "# Project Guidelines");
    }

    /// Inline content on the local level lands in `claude_md` after every other level,
    /// regardless of the order in which the builder calls were made.
    #[tokio::test]
    async fn test_local_content_is_merged_last() {
        let provider = LeveledMemoryProvider::new()
            .with_content(MemoryLevel::Local, "# Local Notes")
            .with_content(MemoryLevel::Project, "# Project Guidelines")
            .with_content(MemoryLevel::Enterprise, "# Enterprise Rules");

        let content = provider.load().await.unwrap();
        assert_eq!(content.claude_md.len(), 3);
        assert_eq!(content.claude_md[0], "# Enterprise Rules");
        assert_eq!(content.claude_md[1], "# Project Guidelines");
        assert_eq!(content.claude_md[2], "# Local Notes");
    }

    /// The `&mut self` mutators and the builder methods write to the same storage and keep
    /// insertion order within a level.
    #[tokio::test]
    async fn test_add_content_preserves_insertion_order() {
        let mut provider = LeveledMemoryProvider::new();
        provider.add_content(MemoryLevel::User, "first");
        provider.add_content(MemoryLevel::User, "second");
        let provider = provider.with_content(MemoryLevel::User, "third");

        let content = provider.load().await.unwrap();
        assert_eq!(content.claude_md.len(), 3);
        assert_eq!(content.claude_md[0], "first");
        assert_eq!(content.claude_md[1], "second");
        assert_eq!(content.claude_md[2], "third");
    }

    #[test]
    fn test_memory_level_order() {
        assert!(MemoryLevel::Enterprise < MemoryLevel::User);
        assert!(MemoryLevel::User < MemoryLevel::Project);
        assert!(MemoryLevel::Project < MemoryLevel::Local);
    }

    /// `MemoryLevel::all` lists each of the four levels once, in ascending order.
    #[test]
    fn test_memory_level_all_is_complete_and_sorted() {
        let levels = MemoryLevel::all();
        assert_eq!(levels.len(), 4);
        assert!(levels.windows(2).all(|pair| pair[0] < pair[1]));
    }
}