use crate::constants::env::system;
use crate::AgentError;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
/// Body of a team-memory snapshot: entry contents plus optional per-entry checksums.
///
/// Field order here is the serialized field order.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct TeamMemoryContent {
/// Entry contents keyed by entry name (presumably the relative file path, as
/// produced by `read_local_team_memory` — TODO confirm against the server schema).
pub entries: HashMap<String, String>,
/// Per-entry checksums keyed like `entries`. Defaults to empty when absent on
/// input and is omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub entry_checksums: HashMap<String, String>,
}
/// A full team-memory snapshot (presumably the server's representation — TODO
/// confirm against the API schema).
///
/// No field has a serde default, so all of them must be present when
/// deserializing. Field order here is the serialized field order.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TeamMemoryData {
/// Organization the snapshot is scoped to (identifier format defined by the server).
pub organization_id: String,
/// Repository the snapshot belongs to (presumably the slug passed to
/// `pull_team_memory` — TODO confirm).
pub repo: String,
/// Snapshot version number.
pub version: u32,
/// Last-modified timestamp, kept as an opaque string (format not established here).
pub last_modified: String,
/// Checksum of the snapshot as a whole.
pub checksum: String,
/// The snapshot's entries and per-entry checksums.
pub content: TeamMemoryContent,
}
/// Error envelope describing an upload that carried more entries than allowed
/// (presumably the body of a server rejection — TODO confirm status code).
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct TeamMemoryTooManyEntries {
    /// The wrapped error object.
    pub error: TeamMemoryTooManyEntriesError,
}

/// Inner error object of [`TeamMemoryTooManyEntries`].
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct TeamMemoryTooManyEntriesError {
    /// Structured details about the limit that was hit.
    pub details: TeamMemoryTooManyEntriesDetails,
}

/// Structured details of a too-many-entries rejection.
///
/// The Rust field names are exactly the wire keys (`error_code`, `max_entries`,
/// `received_entries`), so no `rename` attributes are needed. Field order is the
/// serialized order.
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct TeamMemoryTooManyEntriesDetails {
    /// Machine-readable error code.
    pub error_code: String,
    /// Maximum number of entries accepted.
    pub max_entries: u32,
    /// Number of entries that were sent.
    pub received_entries: u32,
}
/// A file left out of an upload because a secret scan flagged it.
///
/// Values are produced by `scan_for_secrets` (a stub in this module that
/// currently never reports anything).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SkippedSecretFile {
/// Entry path / key of the skipped file.
pub path: String,
/// Identifier of the scanning rule that matched (presumably — TODO confirm).
pub rule_id: String,
/// Human-readable description of the match (presumably — TODO confirm).
pub label: String,
}
/// Outcome of a team-memory pull (fetch) attempt.
///
/// `success` and `data` are always serialized. Every other field defaults to
/// `None` when absent on input and is omitted from output when `None`.
/// Field order here is the serialized field order.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TeamMemorySyncFetchResult {
/// Whether the fetch succeeded; `sync_team_memory` stops early when false.
pub success: bool,
/// Snapshot returned by the fetch, if any.
pub data: Option<TeamMemoryData>,
/// Presumably set when no team memory exists remotely yet — TODO confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub is_empty: Option<bool>,
/// Presumably set when the remote snapshot is unchanged since the last
/// fetch — TODO confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub not_modified: Option<bool>,
/// Snapshot checksum reported with the result, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub checksum: Option<String>,
/// Human-readable failure message (set by the `pull_team_memory` stub).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
/// Presumably tells callers not to retry; the `pull_team_memory` stub sets it
/// to `Some(true)` for its auth failure — TODO confirm consumers.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub skip_retry: Option<bool>,
/// Coarse failure category, e.g. `"auth"` as set by `pull_team_memory`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error_type: Option<String>,
/// HTTP status of the failed request, when one was made (presumably).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub http_status: Option<u16>,
}
/// Result of a checksum-only query against the sync service (presumably a
/// lighter-weight alternative to a full fetch — TODO confirm; not used in this
/// module's visible code).
///
/// Only `success` is always serialized. Optional fields default to `None` when
/// absent and are omitted from output when `None`. Field order is the serialized
/// order.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TeamMemoryHashesResult {
/// Whether the query succeeded.
pub success: bool,
/// Snapshot version, if reported.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub version: Option<u32>,
/// Snapshot checksum, if reported.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub checksum: Option<String>,
/// Per-entry checksums keyed by entry name, if reported.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub entry_checksums: Option<HashMap<String, String>>,
/// Human-readable failure message.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
/// Coarse failure category.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error_type: Option<String>,
/// HTTP status of the failed request, when one was made (presumably).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub http_status: Option<u16>,
}
/// Outcome of a push, which is also the return type of `sync_team_memory`.
///
/// `success` and `files_uploaded` are always serialized. `skipped_secrets` is
/// omitted when empty and defaults to empty when absent. Optional fields default
/// to `None` and are omitted when `None`. Field order is the serialized order.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TeamMemorySyncPushResult {
/// Whether the operation succeeded.
pub success: bool,
/// Number of files uploaded (0 for the stub and for no-op syncs).
pub files_uploaded: u32,
/// Snapshot checksum after the operation, if known.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub checksum: Option<String>,
/// Presumably set when the server rejected the push as conflicting — TODO confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub conflict: Option<bool>,
/// Human-readable failure message.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
/// Files withheld from the upload by the secret scanner.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub skipped_secrets: Vec<SkippedSecretFile>,
/// Coarse failure category, e.g. `"auth"`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error_type: Option<String>,
/// HTTP status of the failed request, when one was made (presumably).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub http_status: Option<u16>,
}
/// Outcome of a single upload request (presumably one batch of a push — TODO
/// confirm; not used in this module's visible code).
///
/// Only `success` is always serialized. Optional fields default to `None` when
/// absent and are omitted when `None`. Field order is the serialized order.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TeamMemorySyncUploadResult {
/// Whether the upload succeeded.
pub success: bool,
/// Snapshot checksum after the upload, if reported.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub checksum: Option<String>,
/// Last-modified timestamp after the upload, as an opaque string.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub last_modified: Option<String>,
/// Presumably set when the upload conflicted with a newer remote snapshot — TODO confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub conflict: Option<bool>,
/// Human-readable failure message.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
/// Server error code; presumably copied from
/// `TeamMemoryTooManyEntriesDetails::error_code` — TODO confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub server_error_code: Option<String>,
/// Server entry limit; presumably from `TeamMemoryTooManyEntriesDetails::max_entries`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub server_max_entries: Option<u32>,
/// Entries the server counted; presumably from
/// `TeamMemoryTooManyEntriesDetails::received_entries`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub server_received_entries: Option<u32>,
/// Coarse failure category.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error_type: Option<String>,
/// HTTP status of the failed request, when one was made (presumably).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub http_status: Option<u16>,
}
/// Client-side bookkeeping carried across team-memory sync calls.
///
/// Every field starts out empty (`None` / empty map).
#[derive(Debug, Clone, Default)]
pub struct SyncState {
    /// Last snapshot checksum known to this client, if any. `sync_team_memory`
    /// echoes it back when there is nothing to push. Where it gets set is not
    /// visible here — TODO confirm.
    pub last_known_checksum: Option<String>,
    /// Per-entry checksums believed to be on the server. `compute_delta` compares
    /// local content against these.
    pub server_checksums: HashMap<String, String>,
    /// Server-side entry limit, when known (not read by the visible code).
    pub server_max_entries: Option<u32>,
}

impl SyncState {
    /// Creates an empty state: no checksums and no known entry limit.
    pub fn new() -> Self {
        Self::default()
    }
}

/// Free-function constructor, equivalent to [`SyncState::new`].
pub fn create_sync_state() -> SyncState {
    SyncState::default()
}
/// Returns `"sha256:"` followed by the lowercase hex SHA-256 digest of
/// `content`'s UTF-8 bytes (64 hex characters).
///
/// This must be a real SHA-256: the value is compared against server-side
/// checksums in `compute_delta`, and it has to be stable across builds.
/// `std::collections::hash_map::DefaultHasher` satisfies neither requirement:
/// it is a 64-bit hash whose algorithm is explicitly unspecified between Rust
/// releases. No crypto crate is available to this module, so the FIPS 180-4
/// algorithm is implemented inline.
pub fn hash_content(content: &str) -> String {
    // Round constants: first 32 bits of the fractional parts of the cube roots
    // of the first 64 primes.
    const K: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
        0xc67178f2,
    ];
    // Initial hash value: first 32 bits of the fractional parts of the square
    // roots of the first 8 primes.
    let mut state: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];

    // Padding: a 0x80 byte, then zeros up to 56 mod 64, then the message length
    // in bits as a big-endian u64. `usize -> u64` is lossless on supported targets.
    let bytes = content.as_bytes();
    let bit_len = (bytes.len() as u64).wrapping_mul(8);
    let mut message = Vec::with_capacity(bytes.len() + 72);
    message.extend_from_slice(bytes);
    message.push(0x80);
    while message.len() % 64 != 56 {
        message.push(0);
    }
    message.extend_from_slice(&bit_len.to_be_bytes());

    let mut w = [0u32; 64];
    for block in message.chunks_exact(64) {
        // Message schedule: 16 big-endian words from the block, expanded to 64.
        for (slot, word) in w.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_be_bytes([word[0], word[1], word[2], word[3]]);
        }
        for i in 16..64 {
            let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3);
            let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10);
            w[i] = w[i - 16]
                .wrapping_add(s0)
                .wrapping_add(w[i - 7])
                .wrapping_add(s1);
        }

        // Compression: working variables a..h live in v[0..8].
        let mut v = state;
        for (&k, &wi) in K.iter().zip(w.iter()) {
            let big_s1 = v[4].rotate_right(6) ^ v[4].rotate_right(11) ^ v[4].rotate_right(25);
            let ch = (v[4] & v[5]) ^ (!v[4] & v[6]);
            let t1 = v[7]
                .wrapping_add(big_s1)
                .wrapping_add(ch)
                .wrapping_add(k)
                .wrapping_add(wi);
            let big_s0 = v[0].rotate_right(2) ^ v[0].rotate_right(13) ^ v[0].rotate_right(22);
            let maj = (v[0] & v[1]) ^ (v[0] & v[2]) ^ (v[1] & v[2]);
            let t2 = big_s0.wrapping_add(maj);
            v = [
                t1.wrapping_add(t2),
                v[0],
                v[1],
                v[2],
                v[3].wrapping_add(t1),
                v[4],
                v[5],
                v[6],
            ];
        }
        for (acc, x) in state.iter_mut().zip(v.iter()) {
            *acc = acc.wrapping_add(*x);
        }
    }

    let mut out = String::with_capacity(7 + 64);
    out.push_str("sha256:");
    for word in state.iter() {
        out.push_str(&format!("{:08x}", word));
    }
    out
}
/// Time budget for one sync request, in milliseconds (30 s). Not referenced in
/// this module's visible code — TODO confirm where it is applied.
pub const TEAM_MEMORY_SYNC_TIMEOUT_MS: u64 = 30 * 1_000;
/// Per-file size limit in bytes. Not enforced in this module's visible code —
/// TODO confirm where it is applied.
pub const MAX_FILE_SIZE_BYTES: usize = 250 * 1_000;
/// Byte budget for one upload body. Presumably the `max_bytes` passed to
/// `batch_delta_by_bytes` — TODO confirm.
pub const MAX_PUT_BODY_BYTES: usize = 200 * 1_000;
/// Retry budget for failed sync requests (usage not visible here).
pub const MAX_RETRIES: u32 = 3;
/// Retry budget for conflicting uploads (usage not visible here).
pub const MAX_CONFLICT_RETRIES: u32 = 2;
/// Root directory of the local team-memory mirror:
/// `<home>/.open-agent-sdk/team_memory`.
///
/// `<home>` comes from the `system::HOME` variable, then `system::USERPROFILE`,
/// and finally falls back to `/tmp` when neither is set to valid Unicode.
pub fn get_team_memory_dir() -> PathBuf {
    let home = match std::env::var(system::HOME) {
        Ok(value) => value,
        Err(_) => std::env::var(system::USERPROFILE).unwrap_or_else(|_| "/tmp".to_string()),
    };
    let mut root = PathBuf::from(home);
    root.push(".open-agent-sdk");
    root.push("team_memory");
    root
}
pub fn get_team_memory_path(key: &str) -> PathBuf {
if key.contains("..") || key.starts_with('/') {
return get_team_memory_dir().join("INVALID");
}
get_team_memory_dir().join(key)
}
/// Checks that `key` is usable as a team-memory entry key.
///
/// Rules are checked in this order, and the first violation's message is returned:
/// 1. the key is non-empty;
/// 2. it does not contain `..` anywhere;
/// 3. it does not start with `/`.
///
/// Only these three rules are enforced. For example, backslash-separated or
/// drive-prefixed keys are not rejected here.
pub fn validate_team_memory_key(key: &str) -> Result<(), String> {
    let violation = if key.is_empty() {
        Some("Key cannot be empty")
    } else if key.contains("..") {
        Some("Key cannot contain '..'")
    } else if key.starts_with('/') {
        Some("Key cannot start with '/'")
    } else {
        None
    };
    match violation {
        Some(message) => Err(message.to_string()),
        None => Ok(()),
    }
}
/// Recursively reads every regular file under [`get_team_memory_dir`].
///
/// Returns a map from path relative to the team-memory directory (rendered with
/// `to_string_lossy`, so it uses the platform separator) to the file's contents.
/// A missing directory yields an empty map.
///
/// Skipped entries:
/// * anything whose relative path starts with `.`. Hidden top-level directories
///   are not descended into, since every file under them would be skipped anyway.
/// * entries that are neither directories nor regular files.
/// * symlinks to directories. Entry types come from `DirEntry::file_type`, which
///   does not follow links, so a link cycle cannot re-walk the tree. A symlink
///   to a regular file is still read, and dangling links are skipped.
///
/// # Errors
/// * `AgentError::Io` if a directory cannot be listed, an entry's type cannot
///   be read, or a file cannot be read as UTF-8 text.
/// * `AgentError::Internal` if a listed path is unexpectedly not under the root.
pub async fn read_local_team_memory() -> Result<HashMap<String, String>, AgentError> {
    let dir = get_team_memory_dir();
    if !dir.exists() {
        return Ok(HashMap::new());
    }
    let mut entries = HashMap::new();
    let mut dirs_to_process: Vec<PathBuf> = vec![dir.clone()];
    while let Some(current_dir) = dirs_to_process.pop() {
        let mut read_dir = tokio::fs::read_dir(&current_dir)
            .await
            .map_err(AgentError::Io)?;
        while let Some(entry) = read_dir.next_entry().await.map_err(AgentError::Io)? {
            let path = entry.path();
            let relative = path
                .strip_prefix(&dir)
                .map_err(|_| AgentError::Internal("Failed to get relative path".to_string()))?
                .to_string_lossy()
                .to_string();
            if relative.starts_with('.') {
                continue;
            }
            // Non-blocking, and does not follow symlinks (unlike `Path::is_dir`).
            let file_type = entry.file_type().await.map_err(AgentError::Io)?;
            if file_type.is_dir() {
                dirs_to_process.push(path);
                continue;
            }
            let is_regular_file = if file_type.is_symlink() {
                // Resolve the link only to decide whether it points at a file;
                // linked directories are never descended.
                tokio::fs::metadata(&path)
                    .await
                    .map(|meta| meta.is_file())
                    .unwrap_or(false)
            } else {
                file_type.is_file()
            };
            if !is_regular_file {
                continue;
            }
            let content = tokio::fs::read_to_string(&path)
                .await
                .map_err(AgentError::Io)?;
            entries.insert(relative, content);
        }
    }
    Ok(entries)
}
pub async fn write_local_team_memory(entries: &HashMap<String, String>) -> Result<(), AgentError> {
let dir = get_team_memory_dir();
tokio::fs::create_dir_all(&dir)
.await
.map_err(AgentError::Io)?;
for (key, content) in entries {
let path = get_team_memory_path(key);
if let Some(parent) = path.parent() {
tokio::fs::create_dir_all(parent)
.await
.map_err(AgentError::Io)?;
}
tokio::fs::write(&path, content)
.await
.map_err(AgentError::Io)?;
}
Ok(())
}
pub async fn delete_local_team_memory_entry(key: &str) -> Result<(), AgentError> {
let path = get_team_memory_path(key);
if path.exists() {
tokio::fs::remove_file(path).await.map_err(AgentError::Io)?;
}
Ok(())
}
/// Returns the local entries that the server does not already hold verbatim.
///
/// An entry is included when the server has no checksum for its key, or when
/// the server's checksum differs from [`hash_content`] of the local content.
/// Keys present only on the server are ignored.
pub fn compute_delta(
    local_entries: &HashMap<String, String>,
    server_checksums: &HashMap<String, String>,
) -> HashMap<String, String> {
    local_entries
        .iter()
        .filter(|&(path, body)| server_checksums.get(path) != Some(&hash_content(body)))
        .map(|(path, body)| (path.clone(), body.clone()))
        .collect()
}
/// Splits `delta` into batches whose combined `key.len() + content.len()` stays
/// within `max_bytes`.
///
/// Keys are visited in ascending order, so the batching is deterministic. An
/// entry that alone exceeds `max_bytes` cannot be split. The batch in progress
/// is flushed and the entry goes out as a singleton batch. An empty `delta`
/// yields no batches.
pub fn batch_delta_by_bytes(
    delta: &HashMap<String, String>,
    max_bytes: usize,
) -> Vec<HashMap<String, String>> {
    let mut ordered: Vec<(&String, &String)> = delta.iter().collect();
    ordered.sort_by(|left, right| left.0.cmp(right.0));

    let mut batches = Vec::new();
    let mut pending: HashMap<String, String> = HashMap::new();
    let mut pending_bytes = 0usize;

    for (key, content) in ordered {
        let cost = key.len() + content.len();
        // Close the open batch when this entry does not fit. An oversized entry
        // never fits, so this also flushes before emitting a singleton.
        if !pending.is_empty() && pending_bytes + cost > max_bytes {
            batches.push(std::mem::take(&mut pending));
            pending_bytes = 0;
        }
        if cost > max_bytes {
            batches.push(HashMap::from([(key.clone(), content.clone())]));
        } else {
            pending.insert(key.clone(), content.clone());
            pending_bytes += cost;
        }
    }
    if !pending.is_empty() {
        batches.push(pending);
    }
    batches
}
/// Reports whether remote team-memory sync can be used.
///
/// Always `false`: the pull/push functions in this module are stubs that fail
/// with an authentication error.
pub fn is_team_memory_sync_available() -> bool {
    // No authenticated transport is implemented yet.
    false
}
/// Fetches the remote team-memory snapshot.
///
/// Currently a stub. It never touches `_state` or the network, and it always
/// returns `Ok` with an unsuccessful result: `error_type = "auth"`,
/// `skip_retry = Some(true)`, and a message saying OAuth is required.
pub async fn pull_team_memory(
    _state: &mut SyncState,
    _repo_slug: &str,
) -> Result<TeamMemorySyncFetchResult, AgentError> {
    let outcome = TeamMemorySyncFetchResult {
        success: false,
        error: Some(String::from("Team memory sync requires OAuth authentication")),
        error_type: Some(String::from("auth")),
        skip_retry: Some(true),
        data: None,
        is_empty: None,
        not_modified: None,
        checksum: None,
        http_status: None,
    };
    Ok(outcome)
}
/// Uploads `_entries` to the remote team-memory store.
///
/// Currently a stub. It never touches `_state` or the network, and it always
/// returns `Ok` with an unsuccessful result: zero files uploaded,
/// `error_type = "auth"`, and a message saying OAuth is required.
pub async fn push_team_memory(
    _state: &mut SyncState,
    _repo_slug: &str,
    _entries: &HashMap<String, String>,
) -> Result<TeamMemorySyncPushResult, AgentError> {
    let outcome = TeamMemorySyncPushResult {
        success: false,
        files_uploaded: 0,
        error: Some(String::from("Team memory sync requires OAuth authentication")),
        error_type: Some(String::from("auth")),
        skipped_secrets: Vec::new(),
        checksum: None,
        conflict: None,
        http_status: None,
    };
    Ok(outcome)
}
pub async fn sync_team_memory(
state: &mut SyncState,
repo_slug: &str,
) -> Result<TeamMemorySyncPushResult, AgentError> {
let pull_result = pull_team_memory(state, repo_slug).await?;
if !pull_result.success {
return Ok(TeamMemorySyncPushResult {
success: false,
files_uploaded: 0,
checksum: None,
conflict: None,
error: pull_result.error,
skipped_secrets: Vec::new(),
error_type: pull_result.error_type,
http_status: pull_result.http_status,
});
}
let local_entries = read_local_team_memory().await?;
let delta = compute_delta(&local_entries, &state.server_checksums);
if delta.is_empty() {
return Ok(TeamMemorySyncPushResult {
success: true,
files_uploaded: 0,
checksum: state.last_known_checksum.clone(),
conflict: None,
error: None,
skipped_secrets: Vec::new(),
error_type: None,
http_status: None,
});
}
push_team_memory(state, repo_slug, &delta).await
}
/// Scans one file's content for secrets and reports a [`SkippedSecretFile`]
/// when something is found.
///
/// Currently a placeholder with no detection rules: it always returns `None`.
pub fn scan_for_secrets(_content: &str, _path: &str) -> Option<SkippedSecretFile> {
    Option::None
}
/// Runs [`scan_for_secrets`] over every `(path, content)` entry and collects the
/// hits, in the map's iteration order.
pub fn scan_entries_for_secrets(entries: &HashMap<String, String>) -> Vec<SkippedSecretFile> {
    entries
        .iter()
        .filter_map(|(path, content)| scan_for_secrets(content, path))
        .collect()
}
/// Process-wide on/off switch for team memory; starts disabled.
static TEAM_MEMORY_ENABLED: AtomicBool = AtomicBool::new(false);

/// Stores `flag` into the global switch (sequentially consistent).
fn set_team_memory_enabled(flag: bool) {
    TEAM_MEMORY_ENABLED.store(flag, Ordering::SeqCst);
}

/// Reports whether team memory is currently switched on.
pub fn is_team_memory_enabled() -> bool {
    TEAM_MEMORY_ENABLED.load(Ordering::SeqCst)
}

/// Switches team memory on for the whole process.
pub fn enable_team_memory() {
    set_team_memory_enabled(true);
}

/// Switches team memory off for the whole process.
pub fn disable_team_memory() {
    set_team_memory_enabled(false);
}
/// Most recent team-memory sync error message, shared process-wide.
static LAST_SYNC_ERROR: Mutex<Option<String>> = Mutex::new(None);

/// Replaces the stored sync error; pass `None` to clear it.
///
/// # Panics
/// Panics if the mutex is poisoned.
pub fn set_last_sync_error(error: Option<String>) {
    let mut slot = LAST_SYNC_ERROR.lock().unwrap();
    *slot = error;
}

/// Returns a copy of the stored sync error, if any.
///
/// # Panics
/// Panics if the mutex is poisoned.
pub fn get_last_sync_error() -> Option<String> {
    let slot = LAST_SYNC_ERROR.lock().unwrap();
    slot.as_ref().cloned()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sorted key list per batch, so assertions do not depend on `HashMap` order.
    fn batch_keys(batches: &[HashMap<String, String>]) -> Vec<Vec<String>> {
        batches
            .iter()
            .map(|batch| {
                let mut keys: Vec<String> = batch.keys().cloned().collect();
                keys.sort();
                keys
            })
            .collect()
    }

    #[test]
    fn test_create_sync_state() {
        let state = create_sync_state();
        assert!(state.last_known_checksum.is_none());
        assert!(state.server_checksums.is_empty());
        assert!(state.server_max_entries.is_none());
    }

    #[test]
    fn test_hash_content() {
        let hash1 = hash_content("hello");
        let hash2 = hash_content("hello");
        let hash3 = hash_content("world");
        assert!(hash1.starts_with("sha256:"));
        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }

    #[test]
    fn test_validate_team_memory_key() {
        assert!(validate_team_memory_key("MEMORY.md").is_ok());
        assert!(validate_team_memory_key("subdir/notes.md").is_ok());
        assert_eq!(
            validate_team_memory_key(""),
            Err("Key cannot be empty".to_string())
        );
        assert_eq!(
            validate_team_memory_key("../etc/passwd"),
            Err("Key cannot contain '..'".to_string())
        );
        assert_eq!(
            validate_team_memory_key("/absolute/path"),
            Err("Key cannot start with '/'".to_string())
        );
    }

    #[test]
    fn test_get_team_memory_path() {
        assert!(get_team_memory_path("notes/a.md").ends_with("team_memory/notes/a.md"));
        // Traversal and absolute keys must never resolve outside the directory.
        assert!(get_team_memory_path("../etc/passwd").ends_with("INVALID"));
        assert!(get_team_memory_path("/absolute/path").ends_with("INVALID"));
    }

    #[test]
    fn test_compute_delta() {
        let local = HashMap::from([
            ("a.txt".to_string(), "content1".to_string()),
            ("b.txt".to_string(), "content2".to_string()),
            ("c.txt".to_string(), "content3".to_string()),
        ]);
        let server = HashMap::from([
            // Unchanged: the server already has identical content.
            ("a.txt".to_string(), hash_content("content1")),
            // Changed: the server holds a different version.
            ("b.txt".to_string(), hash_content("different")),
        ]);
        let delta = compute_delta(&local, &server);
        assert_eq!(delta.len(), 2);
        assert_eq!(delta.get("b.txt").map(String::as_str), Some("content2"));
        // New: not on the server at all.
        assert_eq!(delta.get("c.txt").map(String::as_str), Some("content3"));
        assert!(!delta.contains_key("a.txt"));
    }

    #[test]
    fn test_batch_delta_by_bytes() {
        // Entry cost is key.len() + content.len(): a = 105, b = 105, c = 255.
        let delta = HashMap::from([
            ("a.txt".to_string(), "x".repeat(100)),
            ("b.txt".to_string(), "y".repeat(100)),
            ("c.txt".to_string(), "z".repeat(250)),
        ]);
        assert_eq!(
            batch_keys(&batch_delta_by_bytes(&delta, 150)),
            vec![vec!["a.txt"], vec!["b.txt"], vec!["c.txt"]]
        );
        assert_eq!(
            batch_keys(&batch_delta_by_bytes(&delta, 250)),
            vec![vec!["a.txt", "b.txt"], vec!["c.txt"]]
        );
        assert!(batch_delta_by_bytes(&HashMap::new(), 150).is_empty());
    }

    #[test]
    fn test_team_memory_enabled() {
        disable_team_memory();
        assert!(!is_team_memory_enabled());
        enable_team_memory();
        assert!(is_team_memory_enabled());
        disable_team_memory();
        assert!(!is_team_memory_enabled());
    }

    #[test]
    fn test_last_sync_error() {
        set_last_sync_error(None);
        assert!(get_last_sync_error().is_none());
        set_last_sync_error(Some("test error".to_string()));
        assert_eq!(get_last_sync_error(), Some("test error".to_string()));
    }
}