1use crate::constants::env::system;
8use crate::AgentError;
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Mutex;
13
14#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
16pub struct TeamMemoryContent {
17 pub entries: HashMap<String, String>,
19 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
21 pub entry_checksums: HashMap<String, String>,
22}
23
24#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
26pub struct TeamMemoryData {
27 pub organization_id: String,
28 pub repo: String,
29 pub version: u32,
30 pub last_modified: String,
31 pub checksum: String,
32 pub content: TeamMemoryContent,
33}
34
35#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
37pub struct TeamMemoryTooManyEntries {
38 pub error: TeamMemoryTooManyEntriesError,
39}
40
41#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
42pub struct TeamMemoryTooManyEntriesError {
43 pub details: TeamMemoryTooManyEntriesDetails,
44}
45
46#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
47pub struct TeamMemoryTooManyEntriesDetails {
48 #[serde(rename = "error_code")]
49 pub error_code: String,
50 #[serde(rename = "max_entries")]
51 pub max_entries: u32,
52 #[serde(rename = "received_entries")]
53 pub received_entries: u32,
54}
55
56#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
58pub struct SkippedSecretFile {
59 pub path: String,
61 pub rule_id: String,
63 pub label: String,
65}
66
67#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
69pub struct TeamMemorySyncFetchResult {
70 pub success: bool,
71 pub data: Option<TeamMemoryData>,
72 #[serde(default, skip_serializing_if = "Option::is_none")]
74 pub is_empty: Option<bool>,
75 #[serde(default, skip_serializing_if = "Option::is_none")]
77 pub not_modified: Option<bool>,
78 #[serde(default, skip_serializing_if = "Option::is_none")]
80 pub checksum: Option<String>,
81 #[serde(default, skip_serializing_if = "Option::is_none")]
82 pub error: Option<String>,
83 #[serde(default, skip_serializing_if = "Option::is_none")]
84 pub skip_retry: Option<bool>,
85 #[serde(default, skip_serializing_if = "Option::is_none")]
86 pub error_type: Option<String>,
87 #[serde(default, skip_serializing_if = "Option::is_none")]
88 pub http_status: Option<u16>,
89}
90
91#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
93pub struct TeamMemoryHashesResult {
94 pub success: bool,
95 #[serde(default, skip_serializing_if = "Option::is_none")]
96 pub version: Option<u32>,
97 #[serde(default, skip_serializing_if = "Option::is_none")]
98 pub checksum: Option<String>,
99 #[serde(default, skip_serializing_if = "Option::is_none")]
100 pub entry_checksums: Option<HashMap<String, String>>,
101 #[serde(default, skip_serializing_if = "Option::is_none")]
102 pub error: Option<String>,
103 #[serde(default, skip_serializing_if = "Option::is_none")]
104 pub error_type: Option<String>,
105 #[serde(default, skip_serializing_if = "Option::is_none")]
106 pub http_status: Option<u16>,
107}
108
109#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
111pub struct TeamMemorySyncPushResult {
112 pub success: bool,
113 pub files_uploaded: u32,
114 #[serde(default, skip_serializing_if = "Option::is_none")]
115 pub checksum: Option<String>,
116 #[serde(default, skip_serializing_if = "Option::is_none")]
118 pub conflict: Option<bool>,
119 #[serde(default, skip_serializing_if = "Option::is_none")]
120 pub error: Option<String>,
121 #[serde(default, skip_serializing_if = "Vec::is_empty")]
123 pub skipped_secrets: Vec<SkippedSecretFile>,
124 #[serde(default, skip_serializing_if = "Option::is_none")]
125 pub error_type: Option<String>,
126 #[serde(default, skip_serializing_if = "Option::is_none")]
127 pub http_status: Option<u16>,
128}
129
130#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
132pub struct TeamMemorySyncUploadResult {
133 pub success: bool,
134 #[serde(default, skip_serializing_if = "Option::is_none")]
135 pub checksum: Option<String>,
136 #[serde(default, skip_serializing_if = "Option::is_none")]
137 pub last_modified: Option<String>,
138 #[serde(default, skip_serializing_if = "Option::is_none")]
140 pub conflict: Option<bool>,
141 #[serde(default, skip_serializing_if = "Option::is_none")]
142 pub error: Option<String>,
143 #[serde(default, skip_serializing_if = "Option::is_none")]
145 pub server_error_code: Option<String>,
146 #[serde(default, skip_serializing_if = "Option::is_none")]
148 pub server_max_entries: Option<u32>,
149 #[serde(default, skip_serializing_if = "Option::is_none")]
151 pub server_received_entries: Option<u32>,
152 #[serde(default, skip_serializing_if = "Option::is_none")]
153 pub error_type: Option<String>,
154 #[serde(default, skip_serializing_if = "Option::is_none")]
155 pub http_status: Option<u16>,
156}
157
/// Mutable bookkeeping that is threaded through the sync functions.
///
/// The derived `Default` yields the empty state: no known checksum, no
/// per-key server checksums and no known entry limit.
#[derive(Debug, Clone, Default)]
pub struct SyncState {
    /// Most recent overall checksum, if one is known.
    pub last_known_checksum: Option<String>,
    /// Server-side checksum strings addressed by entry key; `sync_team_memory`
    /// compares local content against this map.
    pub server_checksums: HashMap<String, String>,
    /// Entry limit reported by the server, if one is known.
    pub server_max_entries: Option<u32>,
}

impl SyncState {
    /// Creates the empty state (identical to `SyncState::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}
186
187pub fn create_sync_state() -> SyncState {
189 SyncState::new()
190}
191
/// Returns the SHA-256 digest of `content` as `sha256:<64 lowercase hex digits>`.
///
/// The digest is computed over the UTF-8 bytes of `content` with a
/// self-contained FIPS 180-4 implementation, so the value is stable across
/// processes, platforms and Rust releases. The previous implementation
/// formatted a 64-bit `DefaultHasher` value under the same `sha256:` prefix.
/// That algorithm is explicitly unspecified and may change between Rust
/// releases, and its output can never equal a real SHA-256 checksum.
pub fn hash_content(content: &str) -> String {
    // SHA-256 round constants: the first 32 bits of the fractional parts of
    // the cube roots of the first 64 primes.
    const ROUND_CONSTANTS: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
        0xc67178f2,
    ];

    // Initial hash state: the first 32 bits of the fractional parts of the
    // square roots of the first 8 primes.
    let mut state: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];

    // Padding: a single 0x80 byte, zeros up to 56 mod 64, then the message
    // length in bits as a big-endian u64.
    let bytes = content.as_bytes();
    let bit_len = (bytes.len() as u64).wrapping_mul(8);
    let mut message = Vec::with_capacity(bytes.len() + 72);
    message.extend_from_slice(bytes);
    message.push(0x80);
    while message.len() % 64 != 56 {
        message.push(0);
    }
    message.extend_from_slice(&bit_len.to_be_bytes());

    for block in message.chunks_exact(64) {
        // Message schedule: 16 words straight from the block, 48 derived.
        let mut schedule = [0u32; 64];
        for (word, chunk) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *word = u32::from_be_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
        }
        for i in 16..64 {
            let s0 = schedule[i - 15].rotate_right(7)
                ^ schedule[i - 15].rotate_right(18)
                ^ (schedule[i - 15] >> 3);
            let s1 = schedule[i - 2].rotate_right(17)
                ^ schedule[i - 2].rotate_right(19)
                ^ (schedule[i - 2] >> 10);
            schedule[i] = schedule[i - 16]
                .wrapping_add(s0)
                .wrapping_add(schedule[i - 7])
                .wrapping_add(s1);
        }

        // Compression rounds over the working variables.
        let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = state;
        for (constant, word) in ROUND_CONSTANTS.iter().zip(schedule.iter()) {
            let big_sigma1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let choose = (e & f) ^ (!e & g);
            let temp1 = h
                .wrapping_add(big_sigma1)
                .wrapping_add(choose)
                .wrapping_add(*constant)
                .wrapping_add(*word);
            let big_sigma0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let majority = (a & b) ^ (a & c) ^ (b & c);
            let temp2 = big_sigma0.wrapping_add(majority);

            h = g;
            g = f;
            f = e;
            e = d.wrapping_add(temp1);
            d = c;
            c = b;
            b = a;
            a = temp1.wrapping_add(temp2);
        }

        let mixed = [a, b, c, d, e, f, g, h];
        for (slot, value) in state.iter_mut().zip(mixed.iter()) {
            *slot = slot.wrapping_add(*value);
        }
    }

    let mut digest = String::with_capacity("sha256:".len() + 64);
    digest.push_str("sha256:");
    for word in state.iter() {
        digest.push_str(&format!("{:08x}", word));
    }
    digest
}
206
// None of these limits is referenced in this part of the file; the notes on
// intended use below are inferred from the names and should be confirmed
// against the callers.

/// Sync timeout in milliseconds (30 seconds).
pub const TEAM_MEMORY_SYNC_TIMEOUT_MS: u64 = 30 * 1_000;
/// Size limit, in bytes, for a single file (presumably per entry).
pub const MAX_FILE_SIZE_BYTES: usize = 250_000;
/// Size limit, in bytes, for one PUT request body (presumably the value
/// handed to `batch_delta_by_bytes` as `max_bytes`).
pub const MAX_PUT_BODY_BYTES: usize = 200_000;
/// Retry limit (presumably for failed requests).
pub const MAX_RETRIES: u32 = 3;
/// Retry limit for conflicts.
pub const MAX_CONFLICT_RETRIES: u32 = 2;
219
220pub fn get_team_memory_dir() -> PathBuf {
224 let home = std::env::var(system::HOME)
225 .or_else(|_| std::env::var(system::USERPROFILE))
226 .unwrap_or_else(|_| "/tmp".to_string());
227 PathBuf::from(home)
228 .join(".open-agent-sdk")
229 .join("team_memory")
230}
231
232pub fn get_team_memory_path(key: &str) -> PathBuf {
234 if key.contains("..") || key.starts_with('/') {
236 return get_team_memory_dir().join("INVALID");
237 }
238 get_team_memory_dir().join(key)
239}
240
/// Checks that `key` is usable as a relative path inside the team memory
/// directory.
///
/// # Errors
///
/// Returns a message describing the first violated rule:
///
/// * the key is empty;
/// * the key contains `..`;
/// * the key starts with `/`;
/// * the key has a root or prefix component on the host platform, e.g.
///   `C:\x` or `\x` on Windows, where joining it onto a base directory would
///   discard that base. On Unix this rule never triggers for a key that
///   passed the earlier checks.
pub fn validate_team_memory_key(key: &str) -> Result<(), String> {
    if key.is_empty() {
        return Err("Key cannot be empty".to_string());
    }
    if key.contains("..") {
        return Err("Key cannot contain '..'".to_string());
    }
    if key.starts_with('/') {
        return Err("Key cannot start with '/'".to_string());
    }

    let has_unsafe_component = std::path::Path::new(key).components().any(|part| {
        !matches!(
            part,
            std::path::Component::Normal(_) | std::path::Component::CurDir
        )
    });
    if has_unsafe_component {
        return Err("Key must be a relative path".to_string());
    }

    Ok(())
}
254
255pub async fn read_local_team_memory() -> Result<HashMap<String, String>, AgentError> {
257 let dir = get_team_memory_dir();
258
259 if !dir.exists() {
260 return Ok(HashMap::new());
261 }
262
263 let mut entries = HashMap::new();
264 let mut dirs_to_process: Vec<PathBuf> = vec![dir.clone()];
265
266 while let Some(current_dir) = dirs_to_process.pop() {
267 let mut read_dir = tokio::fs::read_dir(¤t_dir)
268 .await
269 .map_err(AgentError::Io)?;
270
271 while let Some(entry) = read_dir.next_entry().await.map_err(AgentError::Io)? {
272 let path = entry.path();
273 let relative = path
274 .strip_prefix(&dir)
275 .map_err(|_| AgentError::Internal("Failed to get relative path".to_string()))?
276 .to_string_lossy()
277 .to_string();
278
279 if path.is_dir() {
280 dirs_to_process.push(path);
281 } else if path.is_file() {
282 if relative.starts_with('.') {
284 continue;
285 }
286 let content = tokio::fs::read_to_string(&path)
287 .await
288 .map_err(AgentError::Io)?;
289 entries.insert(relative, content);
290 }
291 }
292 }
293
294 Ok(entries)
295}
296
297pub async fn write_local_team_memory(entries: &HashMap<String, String>) -> Result<(), AgentError> {
299 let dir = get_team_memory_dir();
300 tokio::fs::create_dir_all(&dir)
301 .await
302 .map_err(AgentError::Io)?;
303
304 for (key, content) in entries {
305 let path = get_team_memory_path(key);
306 if let Some(parent) = path.parent() {
307 tokio::fs::create_dir_all(parent)
308 .await
309 .map_err(AgentError::Io)?;
310 }
311 tokio::fs::write(&path, content)
312 .await
313 .map_err(AgentError::Io)?;
314 }
315
316 Ok(())
317}
318
319pub async fn delete_local_team_memory_entry(key: &str) -> Result<(), AgentError> {
321 let path = get_team_memory_path(key);
322 if path.exists() {
323 tokio::fs::remove_file(path).await.map_err(AgentError::Io)?;
324 }
325 Ok(())
326}
327
328pub fn compute_delta(
332 local_entries: &HashMap<String, String>,
333 server_checksums: &HashMap<String, String>,
334) -> HashMap<String, String> {
335 let mut delta = HashMap::new();
336
337 for (key, content) in local_entries {
338 let local_hash = hash_content(content);
339 let server_hash = server_checksums.get(key);
340
341 if server_hash.is_none() || server_hash != Some(&local_hash) {
343 delta.insert(key.clone(), content.clone());
344 }
345 }
346
347 delta
348}
349
/// Splits `delta` into batches whose combined size stays within `max_bytes`.
///
/// An entry's size is `key.len() + content.len()`. Entries are taken in
/// ascending key order and packed greedily: the current batch is closed as
/// soon as adding the next entry would push it over the limit. An entry that
/// is larger than `max_bytes` on its own is emitted as a single-entry batch,
/// after first closing any batch in progress.
///
/// Returns the batches in the order they were closed; no batch is empty.
pub fn batch_delta_by_bytes(
    delta: &HashMap<String, String>,
    max_bytes: usize,
) -> Vec<HashMap<String, String>> {
    // Sort so the batching is deterministic despite HashMap iteration order.
    let mut ordered: Vec<(&String, &String)> = delta.iter().collect();
    ordered.sort_by(|left, right| left.0.cmp(right.0));

    let mut finished: Vec<HashMap<String, String>> = Vec::new();
    let mut open: HashMap<String, String> = HashMap::new();
    let mut open_size: usize = 0;

    for (name, body) in ordered {
        let cost = name.len() + body.len();
        let oversized = cost > max_bytes;

        // Close the batch in progress when this entry cannot join it.
        if (oversized || open_size + cost > max_bytes) && !open.is_empty() {
            finished.push(std::mem::take(&mut open));
            open_size = 0;
        }

        if oversized {
            finished.push(HashMap::from([(name.clone(), body.clone())]));
        } else {
            open.insert(name.clone(), body.clone());
            open_size += cost;
        }
    }

    if !open.is_empty() {
        finished.push(open);
    }

    finished
}
400
/// Reports whether team memory sync can be used.
///
/// Always returns `false`; no condition is evaluated.
pub fn is_team_memory_sync_available() -> bool {
    const SYNC_AVAILABLE: bool = false;
    SYNC_AVAILABLE
}
410
411pub async fn pull_team_memory(
413 _state: &mut SyncState,
414 _repo_slug: &str,
415) -> Result<TeamMemorySyncFetchResult, AgentError> {
416 Ok(TeamMemorySyncFetchResult {
417 success: false,
418 data: None,
419 is_empty: None,
420 not_modified: None,
421 checksum: None,
422 error: Some("Team memory sync requires OAuth authentication".to_string()),
423 skip_retry: Some(true),
424 error_type: Some("auth".to_string()),
425 http_status: None,
426 })
427}
428
429pub async fn push_team_memory(
431 _state: &mut SyncState,
432 _repo_slug: &str,
433 _entries: &HashMap<String, String>,
434) -> Result<TeamMemorySyncPushResult, AgentError> {
435 Ok(TeamMemorySyncPushResult {
436 success: false,
437 files_uploaded: 0,
438 checksum: None,
439 conflict: None,
440 error: Some("Team memory sync requires OAuth authentication".to_string()),
441 skipped_secrets: Vec::new(),
442 error_type: Some("auth".to_string()),
443 http_status: None,
444 })
445}
446
447pub async fn sync_team_memory(
449 state: &mut SyncState,
450 repo_slug: &str,
451) -> Result<TeamMemorySyncPushResult, AgentError> {
452 let pull_result = pull_team_memory(state, repo_slug).await?;
454
455 if !pull_result.success {
456 return Ok(TeamMemorySyncPushResult {
457 success: false,
458 files_uploaded: 0,
459 checksum: None,
460 conflict: None,
461 error: pull_result.error,
462 skipped_secrets: Vec::new(),
463 error_type: pull_result.error_type,
464 http_status: pull_result.http_status,
465 });
466 }
467
468 let local_entries = read_local_team_memory().await?;
470
471 let delta = compute_delta(&local_entries, &state.server_checksums);
473
474 if delta.is_empty() {
475 return Ok(TeamMemorySyncPushResult {
476 success: true,
477 files_uploaded: 0,
478 checksum: state.last_known_checksum.clone(),
479 conflict: None,
480 error: None,
481 skipped_secrets: Vec::new(),
482 error_type: None,
483 http_status: None,
484 });
485 }
486
487 push_team_memory(state, repo_slug, &delta).await
489}
490
491pub fn scan_for_secrets(_content: &str, _path: &str) -> Option<SkippedSecretFile> {
495 None
498}
499
500pub fn scan_entries_for_secrets(entries: &HashMap<String, String>) -> Vec<SkippedSecretFile> {
502 let mut skipped = Vec::new();
503
504 for (path, content) in entries {
505 if let Some(secret) = scan_for_secrets(content, path) {
506 skipped.push(secret);
507 }
508 }
509
510 skipped
511}
512
/// Process-wide on/off switch for team memory; starts out `false`.
static TEAM_MEMORY_ENABLED: AtomicBool = AtomicBool::new(false);

/// Stores `enabled` into the process-wide switch.
fn store_team_memory_flag(enabled: bool) {
    TEAM_MEMORY_ENABLED.store(enabled, Ordering::SeqCst);
}

/// Returns the current value of the team memory switch.
pub fn is_team_memory_enabled() -> bool {
    TEAM_MEMORY_ENABLED.load(Ordering::SeqCst)
}

/// Turns the team memory switch on.
pub fn enable_team_memory() {
    store_team_memory_flag(true);
}

/// Turns the team memory switch off.
pub fn disable_team_memory() {
    store_team_memory_flag(false);
}
532
/// Process-wide slot holding the most recently recorded sync error message.
static LAST_SYNC_ERROR: Mutex<Option<String>> = Mutex::new(None);

/// Replaces the recorded sync error; pass `None` to clear it.
///
/// # Panics
///
/// Panics if the mutex guarding the slot is poisoned.
pub fn set_last_sync_error(error: Option<String>) {
    let mut slot = LAST_SYNC_ERROR.lock().unwrap();
    *slot = error;
}

/// Returns a copy of the recorded sync error, if any.
///
/// # Panics
///
/// Panics if the mutex guarding the slot is poisoned.
pub fn get_last_sync_error() -> Option<String> {
    let slot = LAST_SYNC_ERROR.lock().unwrap();
    slot.as_ref().cloned()
}
545
/// Unit tests for the pure helpers and the process-wide flags.
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh state carries no checksum, no server checksums and no limit.
    #[test]
    fn test_create_sync_state() {
        let state = create_sync_state();
        assert!(state.last_known_checksum.is_none());
        assert!(state.server_checksums.is_empty());
        assert!(state.server_max_entries.is_none());
    }

    /// Hashes are prefixed, deterministic and content-sensitive.
    #[test]
    fn test_hash_content() {
        let hash1 = hash_content("hello");
        let hash2 = hash_content("hello");
        let hash3 = hash_content("world");

        assert!(hash1.starts_with("sha256:"));
        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }

    /// Relative keys pass; empty, traversing and absolute keys are rejected.
    #[test]
    fn test_validate_team_memory_key() {
        assert!(validate_team_memory_key("MEMORY.md").is_ok());
        assert!(validate_team_memory_key("subdir/notes.md").is_ok());
        assert!(validate_team_memory_key("").is_err());
        assert!(validate_team_memory_key("../etc/passwd").is_err());
        assert!(validate_team_memory_key("a/../b").is_err());
        assert!(validate_team_memory_key("/absolute/path").is_err());
    }

    /// Only changed (`b.txt`) and new (`c.txt`) entries end up in the delta,
    /// and they carry the local content.
    #[test]
    fn test_compute_delta() {
        let local = HashMap::from([
            ("a.txt".to_string(), "content1".to_string()),
            ("b.txt".to_string(), "content2".to_string()),
            ("c.txt".to_string(), "content3".to_string()),
        ]);

        let server = HashMap::from([
            ("a.txt".to_string(), hash_content("content1")),
            ("b.txt".to_string(), hash_content("different")),
        ]);

        let delta = compute_delta(&local, &server);

        assert_eq!(delta.len(), 2);
        assert_eq!(delta.get("b.txt").map(String::as_str), Some("content2"));
        assert_eq!(delta.get("c.txt").map(String::as_str), Some("content3"));
        assert!(!delta.contains_key("a.txt"));
    }

    /// Each entry lands in exactly one batch with its content intact, and the
    /// oversized entry travels alone.
    #[test]
    fn test_batch_delta_by_bytes() {
        let delta = HashMap::from([
            ("a.txt".to_string(), "x".repeat(100)),
            ("b.txt".to_string(), "y".repeat(100)),
            ("c.txt".to_string(), "z".repeat(250)),
        ]);

        let batches = batch_delta_by_bytes(&delta, 150);

        // a.txt and b.txt weigh 105 bytes each (key + content), so they
        // cannot share a 150-byte batch; c.txt (255 bytes) exceeds the limit
        // by itself.
        assert_eq!(batches.len(), 3);
        assert!(batches.iter().all(|batch| !batch.is_empty()));

        for (key, content) in &delta {
            let holders: Vec<_> = batches
                .iter()
                .filter(|batch| batch.contains_key(key))
                .collect();
            assert_eq!(holders.len(), 1, "{} must appear in exactly one batch", key);
            assert_eq!(holders[0].get(key), Some(content));
        }

        let oversized = batches
            .iter()
            .find(|batch| batch.contains_key("c.txt"))
            .expect("c.txt must be batched");
        assert_eq!(oversized.len(), 1);

        let empty: HashMap<String, String> = HashMap::new();
        assert!(batch_delta_by_bytes(&empty, 150).is_empty());
    }

    /// The global switch follows enable/disable calls.
    #[test]
    fn test_team_memory_enabled() {
        disable_team_memory();
        assert!(!is_team_memory_enabled());

        enable_team_memory();
        assert!(is_team_memory_enabled());

        disable_team_memory();
        assert!(!is_team_memory_enabled());
    }

    /// The error slot stores, returns and clears the last message.
    #[test]
    fn test_last_sync_error() {
        set_last_sync_error(None);
        assert!(get_last_sync_error().is_none());

        set_last_sync_error(Some("test error".to_string()));
        assert_eq!(get_last_sync_error(), Some("test error".to_string()));

        // Leave the process-wide slot clean for anything that runs later.
        set_last_sync_error(None);
        assert!(get_last_sync_error().is_none());
    }
}