use std::path::{Path, PathBuf};
use anyhow::{anyhow, bail, Context, Result};
use super::types::{MemorySourceEntry, SourceKind};
/// Default sync limits for a Composio toolkit, as `(max_items, sync_depth_days)`.
///
/// Matching is exact and case-sensitive; any toolkit not listed falls back to
/// `(30, 14)`. Both values are always `Some`.
pub fn memory_sync_defaults_for_toolkit(toolkit: &str) -> (Option<u32>, Option<u32>) {
    let (max_items, sync_depth_days): (u32, u32) = match toolkit {
        "gmail" => (100, 30),
        "slack" => (50, 14),
        "notion" => (30, 30),
        "linear" | "clickup" | "github" => (50, 30),
        _ => (30, 14),
    };
    (Some(max_items), Some(sync_depth_days))
}
/// Handle to the `[[memory_sources]]` array stored in a TOML config file.
///
/// Only the file location is held in memory; the file contents are not cached.
#[derive(Clone, Debug)]
pub struct SourceRegistry {
    /// Location of the TOML config file.
    path: PathBuf,
}
impl SourceRegistry {
    /// Creates a registry backed by the TOML config file at `config_path`.
    ///
    /// Nothing is read or written until one of the other methods is called.
    pub fn new(config_path: impl Into<PathBuf>) -> Self {
        Self {
            path: config_path.into(),
        }
    }

    /// Path of the config file this registry reads and writes.
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Reads and parses the whole config file as a TOML table.
    ///
    /// A missing file yields an empty table.
    ///
    /// # Errors
    /// Returns an error if the file exists but cannot be read, or is not valid TOML.
    fn read_table(&self) -> Result<toml::Table> {
        // Match on the read error instead of pre-checking `Path::exists()`. The
        // pre-check races with the read. `exists()` also returns `false` for errors
        // such as permission denied, which would silently be treated as "no config"
        // instead of being reported.
        let text = match std::fs::read_to_string(&self.path) {
            Ok(text) => text,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
                return Ok(toml::Table::new());
            }
            Err(err) => {
                return Err(err)
                    .with_context(|| format!("failed to read {}", self.path.display()));
            }
        };
        toml::from_str(&text).with_context(|| format!("failed to parse {}", self.path.display()))
    }

    /// Returns every `[[memory_sources]]` entry, in file order.
    ///
    /// A missing file or a missing `memory_sources` key yields an empty list.
    ///
    /// # Errors
    /// Fails if the file cannot be read/parsed or the array cannot be decoded.
    pub fn list(&self) -> Result<Vec<MemorySourceEntry>> {
        let mut table = self.read_table()?;
        // The table is owned and discarded afterwards, so move the value out
        // instead of deep-cloning it.
        match table.remove("memory_sources") {
            Some(value) => value
                .try_into()
                .context("failed to decode [[memory_sources]]"),
            None => Ok(Vec::new()),
        }
    }

    /// Returns the enabled entries whose `kind` equals `kind`, in file order.
    pub fn list_enabled_by_kind(&self, kind: SourceKind) -> Result<Vec<MemorySourceEntry>> {
        Ok(self
            .list()?
            .into_iter()
            .filter(|s| s.kind == kind && s.enabled)
            .collect())
    }

    /// Looks up a single entry by `id`; `Ok(None)` when no entry matches.
    pub fn get(&self, id: &str) -> Result<Option<MemorySourceEntry>> {
        Ok(self.list()?.into_iter().find(|s| s.id == id))
    }

    /// Replaces the `memory_sources` array with `entries` and persists the file.
    ///
    /// The existing file is re-read first so that all other top-level keys are
    /// kept as they are. Missing parent directories are created.
    fn write_all(&self, entries: &[MemorySourceEntry]) -> Result<()> {
        let mut table = self.read_table()?;
        let value = toml::Value::try_from(entries).context("failed to encode memory_sources")?;
        table.insert("memory_sources".to_string(), value);
        let text = toml::to_string_pretty(&table).context("failed to serialize config")?;
        if let Some(parent) = self.path.parent() {
            if !parent.as_os_str().is_empty() {
                std::fs::create_dir_all(parent)
                    .with_context(|| format!("failed to create {}", parent.display()))?;
            }
        }
        self.atomic_write(text.as_bytes())?;
        Ok(())
    }

    /// Writes `bytes` to the config path via a temp file in the same directory plus
    /// `rename`, so readers see either the old or the new contents, never a partial file.
    ///
    /// The temp file is removed if any step before or including the rename fails.
    fn atomic_write(&self, bytes: &[u8]) -> Result<()> {
        use std::io::Write;

        let parent = self
            .path
            .parent()
            .filter(|p| !p.as_os_str().is_empty())
            .unwrap_or_else(|| Path::new("."));
        let filename = self
            .path
            .file_name()
            .and_then(|n| n.to_str())
            .ok_or_else(|| anyhow!("config path has no file name: {}", self.path.display()))?;
        // A random suffix keeps concurrent writers from sharing a temp file.
        let tmp_path = parent.join(format!(
            ".{filename}.tmp-{}",
            uuid::Uuid::new_v4().as_simple()
        ));
        // Carry the current file's permissions over to the replacement, so a
        // restrictive mode on the existing config is not widened to the process
        // default. If there is no existing file, there is nothing to preserve.
        let existing_permissions = std::fs::metadata(&self.path)
            .ok()
            .map(|meta| meta.permissions());

        let write_result = (|| -> Result<()> {
            {
                let mut file = std::fs::File::create(&tmp_path)
                    .with_context(|| format!("failed to create {}", tmp_path.display()))?;
                file.write_all(bytes)
                    .with_context(|| format!("failed to write {}", tmp_path.display()))?;
                if let Some(permissions) = existing_permissions {
                    file.set_permissions(permissions).with_context(|| {
                        format!("failed to set permissions on {}", tmp_path.display())
                    })?;
                }
                file.sync_all()
                    .with_context(|| format!("failed to sync {}", tmp_path.display()))?;
            }
            std::fs::rename(&tmp_path, &self.path).with_context(|| {
                format!(
                    "failed to atomically replace {} with {}",
                    self.path.display(),
                    tmp_path.display()
                )
            })?;
            Ok(())
        })();

        if write_result.is_err() {
            // Best-effort cleanup; the original error is what the caller needs.
            let _ = std::fs::remove_file(&tmp_path);
            return write_result;
        }

        // Best-effort: sync the directory so the rename itself is durable.
        // Opening a directory as a file is only meaningful on unix.
        if cfg!(unix) {
            if let Ok(dir) = std::fs::File::open(parent) {
                let _ = dir.sync_all();
            }
        }
        Ok(())
    }

    /// Validates `entry` and appends it, returning the stored entry.
    ///
    /// # Errors
    /// Fails if validation fails or an entry with the same `id` already exists.
    pub fn add(&self, entry: MemorySourceEntry) -> Result<MemorySourceEntry> {
        entry.validate().map_err(|e| anyhow!(e))?;
        let mut sources = self.list()?;
        if sources.iter().any(|s| s.id == entry.id) {
            bail!("source with id '{}' already exists", entry.id);
        }
        sources.push(entry.clone());
        self.write_all(&sources)?;
        Ok(entry)
    }

    /// Applies `patch` to the entry with `id`, validates it, and persists the result.
    ///
    /// Validation happens before writing, so a rejected patch leaves the file unchanged.
    ///
    /// # Errors
    /// Fails if no entry has `id` or the patched entry does not validate.
    pub fn update(&self, id: &str, patch: MemorySourcePatch) -> Result<MemorySourceEntry> {
        let mut sources = self.list()?;
        let entry = sources
            .iter_mut()
            .find(|s| s.id == id)
            .ok_or_else(|| anyhow!("source '{id}' not found"))?;
        patch.apply_to(entry);
        entry.validate().map_err(|e| anyhow!(e))?;
        let updated = entry.clone();
        self.write_all(&sources)?;
        Ok(updated)
    }

    /// Removes the entry with `id`; returns whether one was removed.
    ///
    /// The file is only rewritten when something was actually removed.
    pub fn remove(&self, id: &str) -> Result<bool> {
        let mut sources = self.list()?;
        let before = sources.len();
        sources.retain(|s| s.id != id);
        let removed = sources.len() < before;
        if removed {
            self.write_all(&sources)?;
        }
        Ok(removed)
    }

    /// Removes every Composio entry bound to `connection_id`; returns how many were removed.
    ///
    /// The file is only rewritten when the count is non-zero.
    pub fn remove_composio_source_by_connection_id(&self, connection_id: &str) -> Result<usize> {
        let mut sources = self.list()?;
        let before = sources.len();
        sources.retain(|s| {
            !(s.kind == SourceKind::Composio && s.connection_id.as_deref() == Some(connection_id))
        });
        let removed = before - sources.len();
        if removed > 0 {
            self.write_all(&sources)?;
        }
        Ok(removed)
    }

    /// Inserts a Composio entry for `connection_id`, or relabels the existing one,
    /// and persists the result. Always rewrites the file.
    pub fn upsert_composio_source(
        &self,
        toolkit: &str,
        connection_id: &str,
        label: &str,
    ) -> Result<MemorySourceEntry> {
        let mut sources = self.list()?;
        let (entry, _was_insert) =
            upsert_composio_entry_in_place(&mut sources, toolkit, connection_id, label);
        self.write_all(&sources)?;
        Ok(entry)
    }

    /// Enables every source and clears all of its per-source limit overrides
    /// (each limit field is set to `None`), then persists and returns the entries.
    pub fn apply_all_in(&self) -> Result<Vec<MemorySourceEntry>> {
        let mut sources = self.list()?;
        for source in &mut sources {
            source.enabled = true;
            source.max_items = None;
            source.since_days = None;
            source.sync_depth_days = None;
            source.max_commits = None;
            source.max_issues = None;
            source.max_prs = None;
            source.max_tokens_per_sync = None;
            source.max_cost_per_sync_usd = None;
        }
        self.write_all(&sources)?;
        Ok(sources)
    }
}
/// Inserts or refreshes the Composio entry bound to `connection_id` within `sources`.
///
/// If a Composio entry with this connection id already exists (first match wins),
/// only its label is overwritten. Otherwise a new enabled entry is appended. The new
/// entry gets a fresh `src_<uuid>` id and the defaults from
/// `memory_sync_defaults_for_toolkit(toolkit)` for `max_items` / `sync_depth_days`.
///
/// Returns a copy of the resulting entry, plus `true` when it was newly inserted.
/// This function performs no I/O; the caller must write `sources` back to persist it.
pub(crate) fn upsert_composio_entry_in_place(
    sources: &mut Vec<MemorySourceEntry>,
    toolkit: &str,
    connection_id: &str,
    label: &str,
) -> (MemorySourceEntry, bool) {
    let is_same_connection = |candidate: &MemorySourceEntry| {
        candidate.kind == SourceKind::Composio
            && candidate.connection_id.as_deref() == Some(connection_id)
    };

    if let Some(index) = sources.iter().position(is_same_connection) {
        let existing = &mut sources[index];
        existing.label = label.to_string();
        return (existing.clone(), false);
    }

    let (max_items, sync_depth_days) = memory_sync_defaults_for_toolkit(toolkit);
    let created = MemorySourceEntry {
        id: format!("src_{}", uuid::Uuid::new_v4().as_simple()),
        kind: SourceKind::Composio,
        label: label.to_string(),
        enabled: true,
        toolkit: Some(toolkit.to_string()),
        connection_id: Some(connection_id.to_string()),
        path: None,
        glob: None,
        url: None,
        branch: None,
        paths: Vec::new(),
        max_commits: None,
        max_issues: None,
        max_prs: None,
        query: None,
        since_days: None,
        max_items,
        selector: None,
        max_tokens_per_sync: None,
        max_cost_per_sync_usd: None,
        sync_depth_days,
    };
    sources.push(created.clone());
    (created, true)
}
/// Partial update for a `MemorySourceEntry`.
///
/// Every field is optional. A field absent from the input deserializes to `None`,
/// which means "leave the entry's current value as it is". A patch therefore cannot
/// reset an optional entry field back to `None`.
// A container-level `#[serde(default)]` fills each missing field from
// `MemorySourcePatch::default()` (all `None`), which is equivalent to
// annotating every field individually.
#[derive(Debug, Default, serde::Deserialize)]
#[serde(default)]
pub struct MemorySourcePatch {
    pub label: Option<String>,
    pub enabled: Option<bool>,
    pub toolkit: Option<String>,
    pub connection_id: Option<String>,
    pub path: Option<String>,
    pub glob: Option<String>,
    pub url: Option<String>,
    pub branch: Option<String>,
    pub paths: Option<Vec<String>>,
    pub query: Option<String>,
    pub since_days: Option<u32>,
    pub max_items: Option<u32>,
    pub selector: Option<String>,
    pub max_tokens_per_sync: Option<u64>,
    pub max_cost_per_sync_usd: Option<f64>,
    pub sync_depth_days: Option<u32>,
    pub max_commits: Option<u32>,
    pub max_issues: Option<u32>,
    pub max_prs: Option<u32>,
}
impl MemorySourcePatch {
    /// Copies every value present in this patch onto `entry`.
    ///
    /// Fields that are `None` in the patch leave the matching entry field untouched.
    /// No validation is performed here.
    fn apply_to(self, entry: &mut MemorySourceEntry) {
        /// Overwrites `slot` only when the patch supplies a value.
        fn set_if_some<T>(slot: &mut Option<T>, value: Option<T>) {
            if value.is_some() {
                *slot = value;
            }
        }

        // Exhaustive destructuring: adding a patch field without handling it
        // becomes an "unused variable" warning instead of a silent no-op.
        let Self {
            label,
            enabled,
            toolkit,
            connection_id,
            path,
            glob,
            url,
            branch,
            paths,
            query,
            since_days,
            max_items,
            selector,
            max_tokens_per_sync,
            max_cost_per_sync_usd,
            sync_depth_days,
            max_commits,
            max_issues,
            max_prs,
        } = self;

        // Non-optional entry fields.
        if let Some(new_label) = label {
            entry.label = new_label;
        }
        if let Some(new_enabled) = enabled {
            entry.enabled = new_enabled;
        }
        if let Some(new_paths) = paths {
            entry.paths = new_paths;
        }

        // Optional entry fields.
        set_if_some(&mut entry.toolkit, toolkit);
        set_if_some(&mut entry.connection_id, connection_id);
        set_if_some(&mut entry.path, path);
        set_if_some(&mut entry.glob, glob);
        set_if_some(&mut entry.url, url);
        set_if_some(&mut entry.branch, branch);
        set_if_some(&mut entry.query, query);
        set_if_some(&mut entry.since_days, since_days);
        set_if_some(&mut entry.max_items, max_items);
        set_if_some(&mut entry.selector, selector);
        set_if_some(&mut entry.max_tokens_per_sync, max_tokens_per_sync);
        set_if_some(&mut entry.max_cost_per_sync_usd, max_cost_per_sync_usd);
        set_if_some(&mut entry.sync_depth_days, sync_depth_days);
        set_if_some(&mut entry.max_commits, max_commits);
        set_if_some(&mut entry.max_issues, max_issues);
        set_if_some(&mut entry.max_prs, max_prs);
    }
}
// Unit tests for this module live in the sibling file `registry_tests.rs`
// (wired in via `#[path]`) and are compiled only for test builds.
#[cfg(test)]
#[path = "registry_tests.rs"]
mod tests;