1use crate::cli::CliOutput;
10use crate::{db, identity, models, tls, validate};
11use anyhow::Result;
12use clap::Args;
13use std::path::{Path, PathBuf};
14use std::sync::Arc;
15use tracing_subscriber::EnvFilter;
16
// Command-line arguments for a one-shot sync between the local database and
// a second database file.
//
// Plain `//` comments are used here on purpose: clap's derive turns `///`
// doc comments into `--help` text, so `///` would change the CLI output.
#[derive(Args)]
pub struct SyncArgs {
    // Positional path of the other ("remote") database; `run` opens it with
    // `db::open`.
    pub remote_db: PathBuf,
    // Sync direction: `pull`, `push`, or `merge` (the default). When actually
    // syncing, `run` bails with an "invalid direction" error for anything else.
    #[arg(long, short, default_value = "merge")]
    pub direction: String,
    // When false (the default), `run` rewrites the `agent_id` metadata of every
    // memory it copies from the remote into the local database to the caller's
    // id. When true, the remote's metadata is written as-is.
    #[arg(long, default_value_t = false)]
    pub trust_source: bool,
    // When set, `run` delegates to `cmd_sync_dry_run`, which prints counts and
    // calls no insert/link-creation functions.
    #[arg(long, default_value_t = false)]
    pub dry_run: bool,
}
35
// Command-line arguments for the long-running sync daemon (`run_daemon`).
//
// Plain `//` comments are used here on purpose: clap's derive turns `///`
// doc comments into `--help` text, so `///` would change the CLI output.
#[derive(Args)]
pub struct SyncDaemonArgs {
    // Peers to sync with, given as one comma-separated `--peers` value.
    // `run_daemon` refuses to start when the list is empty.
    #[arg(long, value_delimiter = ',')]
    pub peers: Vec<String>,
    // Sync interval; `run_daemon` clamps it to at least 1 and logs it with an
    // `s` suffix (presumably seconds — confirm against `daemon_runtime`).
    #[arg(long, default_value_t = 2)]
    pub interval: u64,
    // Optional API key, handed unchanged to the daemon runtime.
    #[arg(long)]
    pub api_key: Option<String>,
    // Batch size handed to the daemon runtime; `run_daemon` clamps it to at
    // least 1.
    #[arg(long, default_value_t = 500)]
    pub batch_size: usize,
    // Client certificate path for mTLS. clap requires `--client-key` alongside.
    #[arg(long, requires = "client_key")]
    pub client_cert: Option<PathBuf>,
    // Client private-key path for mTLS. clap requires `--client-cert` alongside.
    #[arg(long, requires = "client_cert")]
    pub client_key: Option<PathBuf>,
    // Disables verification of peer server certificates. `run_daemon` rejects
    // this flag unless both `client_cert` and `client_key` are also set.
    #[arg(long, default_value_t = false)]
    pub insecure_skip_server_verify: bool,
}
61
62fn restamp_agent_id(mem: &mut models::Memory, caller_id: &str) {
66 let original = mem
67 .metadata
68 .get("agent_id")
69 .and_then(serde_json::Value::as_str)
70 .map(ToString::to_string);
71 if let Some(obj) = mem.metadata.as_object_mut() {
72 obj.insert(
73 "agent_id".to_string(),
74 serde_json::Value::String(caller_id.to_string()),
75 );
76 if let Some(orig) = original
77 && orig != caller_id
78 {
79 obj.insert(
80 "imported_from_agent_id".to_string(),
81 serde_json::Value::String(orig),
82 );
83 }
84 }
85}
86
/// Counters reported by `cmd_sync_dry_run`; all start at zero via `Default`.
#[derive(Default)]
struct SyncPreview {
    /// Remote memories whose id does not exist locally.
    would_pull_new: usize,
    /// Remote memories whose `updated_at` is greater than the local copy's.
    would_pull_update: usize,
    /// Remote memories that are not newer than the local copy.
    would_pull_noop: usize,
    /// Local memories whose id does not exist remotely.
    would_push_new: usize,
    /// Local memories whose `updated_at` is greater than the remote copy's.
    would_push_update: usize,
    /// Local memories that are not newer than the remote copy.
    would_push_noop: usize,
    /// Total number of links exported from the remote database.
    would_pull_links: usize,
    /// Total number of links exported from the local database.
    would_push_links: usize,
}
98
99impl SyncPreview {
100 fn classify(local: Option<&models::Memory>, remote: &models::Memory) -> MergeOutcome {
101 match local {
102 None => MergeOutcome::New,
103 Some(existing) => {
104 if remote.updated_at > existing.updated_at {
105 MergeOutcome::Update
106 } else {
107 MergeOutcome::Noop
108 }
109 }
110 }
111 }
112}
113
/// Result of `SyncPreview::classify` for one incoming memory.
enum MergeOutcome {
    /// The destination has no memory with this id.
    New,
    /// The incoming memory's `updated_at` is greater than the destination's.
    Update,
    /// The destination's copy is at least as recent as the incoming one.
    Noop,
}
119
120#[allow(clippy::too_many_lines)]
122pub fn run(
123 db_path: &Path,
124 args: &SyncArgs,
125 json_out: bool,
126 cli_agent_id: Option<&str>,
127 out: &mut CliOutput<'_>,
128) -> Result<()> {
129 let local_conn = db::open(db_path)?;
130 let remote_conn = db::open(&args.remote_db)?;
131 let caller_id = identity::resolve_agent_id(cli_agent_id, None)?;
132
133 if args.dry_run {
134 return cmd_sync_dry_run(&local_conn, &remote_conn, &args.direction, json_out, out);
135 }
136
137 match args.direction.as_str() {
138 "pull" => {
139 let mems = db::export_all(&remote_conn)?;
140 let links = db::export_links(&remote_conn)?;
141 let mut n = 0;
142 for mem in &mems {
143 let mut owned = mem.clone();
144 if !args.trust_source {
145 restamp_agent_id(&mut owned, &caller_id);
146 }
147 if let Err(e) = validate::validate_memory(&owned) {
148 tracing::warn!("sync: skipping invalid memory {}: {}", owned.id, e);
149 continue;
150 }
151 if db::insert(&local_conn, &owned).is_ok() {
152 n += 1;
153 }
154 }
155 for link in &links {
156 if validate::validate_link(&link.source_id, &link.target_id, &link.relation)
157 .is_err()
158 {
159 continue;
160 }
161 let _ = db::create_link(
162 &local_conn,
163 &link.source_id,
164 &link.target_id,
165 &link.relation,
166 );
167 }
168 if json_out {
169 writeln!(
170 out.stdout,
171 "{}",
172 serde_json::json!({"direction": "pull", "imported": n})
173 )?;
174 } else {
175 writeln!(out.stdout, "pulled {n} memories from remote")?;
176 }
177 }
178 "push" => {
179 let mems = db::export_all(&local_conn)?;
180 let links = db::export_links(&local_conn)?;
181 let mut n = 0;
182 for mem in &mems {
183 if let Err(e) = validate::validate_memory(mem) {
184 tracing::warn!("sync: skipping invalid memory {}: {}", mem.id, e);
185 continue;
186 }
187 if db::insert(&remote_conn, mem).is_ok() {
188 n += 1;
189 }
190 }
191 for link in &links {
192 if validate::validate_link(&link.source_id, &link.target_id, &link.relation)
193 .is_err()
194 {
195 continue;
196 }
197 let _ = db::create_link(
198 &remote_conn,
199 &link.source_id,
200 &link.target_id,
201 &link.relation,
202 );
203 }
204 if json_out {
205 writeln!(
206 out.stdout,
207 "{}",
208 serde_json::json!({"direction": "push", "exported": n})
209 )?;
210 } else {
211 writeln!(out.stdout, "pushed {n} memories to remote")?;
212 }
213 }
214 "merge" => {
215 let r_mems = db::export_all(&remote_conn)?;
216 let r_links = db::export_links(&remote_conn)?;
217 let l_mems = db::export_all(&local_conn)?;
218 let l_links = db::export_links(&local_conn)?;
219 let (mut pulled, mut pushed) = (0, 0);
220 for mem in &r_mems {
221 let mut owned = mem.clone();
222 if !args.trust_source {
223 restamp_agent_id(&mut owned, &caller_id);
224 }
225 if validate::validate_memory(&owned).is_err() {
226 continue;
227 }
228 if db::insert_if_newer(&local_conn, &owned).is_ok() {
229 pulled += 1;
230 }
231 }
232 for link in &r_links {
233 if validate::validate_link(&link.source_id, &link.target_id, &link.relation)
234 .is_err()
235 {
236 continue;
237 }
238 let _ = db::create_link(
239 &local_conn,
240 &link.source_id,
241 &link.target_id,
242 &link.relation,
243 );
244 }
245 for mem in &l_mems {
246 if validate::validate_memory(mem).is_err() {
247 continue;
248 }
249 if db::insert_if_newer(&remote_conn, mem).is_ok() {
250 pushed += 1;
251 }
252 }
253 for link in &l_links {
254 if validate::validate_link(&link.source_id, &link.target_id, &link.relation)
255 .is_err()
256 {
257 continue;
258 }
259 let _ = db::create_link(
260 &remote_conn,
261 &link.source_id,
262 &link.target_id,
263 &link.relation,
264 );
265 }
266 if json_out {
267 writeln!(
268 out.stdout,
269 "{}",
270 serde_json::json!({"direction": "merge", "pulled": pulled, "pushed": pushed})
271 )?;
272 } else {
273 writeln!(out.stdout, "merged: pulled {pulled}, pushed {pushed}")?;
274 }
275 }
276 _ => anyhow::bail!(
277 "invalid direction: {} (use pull, push, merge)",
278 args.direction
279 ),
280 }
281 Ok(())
282}
283
284fn cmd_sync_dry_run(
285 local_conn: &rusqlite::Connection,
286 remote_conn: &rusqlite::Connection,
287 direction: &str,
288 json_out: bool,
289 out: &mut CliOutput<'_>,
290) -> Result<()> {
291 let l_mems = db::export_all(local_conn)?;
292 let r_mems = db::export_all(remote_conn)?;
293 let l_links = db::export_links(local_conn)?;
294 let r_links = db::export_links(remote_conn)?;
295
296 let local_by_id: std::collections::HashMap<&str, &models::Memory> =
297 l_mems.iter().map(|m| (m.id.as_str(), m)).collect();
298 let remote_by_id: std::collections::HashMap<&str, &models::Memory> =
299 r_mems.iter().map(|m| (m.id.as_str(), m)).collect();
300
301 let mut preview = SyncPreview::default();
302
303 let classify_pull = direction != "push";
304 let classify_push = direction != "pull";
305
306 if classify_pull {
307 for mem in &r_mems {
308 match SyncPreview::classify(local_by_id.get(mem.id.as_str()).copied(), mem) {
309 MergeOutcome::New => preview.would_pull_new += 1,
310 MergeOutcome::Update => preview.would_pull_update += 1,
311 MergeOutcome::Noop => preview.would_pull_noop += 1,
312 }
313 }
314 preview.would_pull_links = r_links.len();
315 }
316
317 if classify_push {
318 for mem in &l_mems {
319 match SyncPreview::classify(remote_by_id.get(mem.id.as_str()).copied(), mem) {
320 MergeOutcome::New => preview.would_push_new += 1,
321 MergeOutcome::Update => preview.would_push_update += 1,
322 MergeOutcome::Noop => preview.would_push_noop += 1,
323 }
324 }
325 preview.would_push_links = l_links.len();
326 }
327
328 if json_out {
329 writeln!(
330 out.stdout,
331 "{}",
332 serde_json::json!({
333 "dry_run": true,
334 "direction": direction,
335 "pull": {
336 "new": preview.would_pull_new,
337 "update": preview.would_pull_update,
338 "noop": preview.would_pull_noop,
339 "links": preview.would_pull_links,
340 },
341 "push": {
342 "new": preview.would_push_new,
343 "update": preview.would_push_update,
344 "noop": preview.would_push_noop,
345 "links": preview.would_push_links,
346 }
347 })
348 )?;
349 } else {
350 writeln!(
351 out.stdout,
352 "DRY RUN — no changes written. Direction: {direction}"
353 )?;
354 if classify_pull {
355 writeln!(
356 out.stdout,
357 " pull: {} new, {} update, {} noop, {} links",
358 preview.would_pull_new,
359 preview.would_pull_update,
360 preview.would_pull_noop,
361 preview.would_pull_links
362 )?;
363 }
364 if classify_push {
365 writeln!(
366 out.stdout,
367 " push: {} new, {} update, {} noop, {} links",
368 preview.would_push_new,
369 preview.would_push_update,
370 preview.would_push_noop,
371 preview.would_push_links
372 )?;
373 }
374 }
375 Ok(())
376}
377
378pub async fn run_daemon(
380 db_path: &Path,
381 args: SyncDaemonArgs,
382 cli_agent_id: Option<&str>,
383) -> Result<()> {
384 if args.peers.is_empty() {
385 anyhow::bail!("at least one --peers URL is required");
386 }
387 let interval = args.interval.max(1);
388 let batch_size = args.batch_size.max(1);
389 let local_agent_id = identity::resolve_agent_id(cli_agent_id, None)?;
390
391 let _ = tracing_subscriber::fmt()
392 .with_env_filter(
393 EnvFilter::from_default_env()
394 .add_directive("ai_memory=info".parse()?)
395 .add_directive("tower_http=info".parse()?),
396 )
397 .try_init();
398
399 let _ = rustls::crypto::ring::default_provider().install_default();
400 if args.insecure_skip_server_verify && (args.client_cert.is_none() || args.client_key.is_none())
401 {
402 anyhow::bail!(
403 "sync-daemon: --insecure-skip-server-verify requires both --client-cert \
404 and --client-key as a compensating mTLS control. Running with neither side \
405 of the TLS handshake verified is an open MITM surface and is refused."
406 );
407 }
408
409 let client = if let (Some(cert_path), Some(key_path)) = (&args.client_cert, &args.client_key) {
410 let rustls_config = tls::build_rustls_client_config(cert_path, key_path).await?;
411 let mut builder = reqwest::Client::builder()
412 .timeout(std::time::Duration::from_secs(30))
413 .use_preconfigured_tls(rustls_config);
414 if args.insecure_skip_server_verify {
415 tracing::warn!(
416 "sync-daemon: --insecure-skip-server-verify set with --client-cert — \
417 peer server certificates will NOT be validated; peer authenticates us \
418 via mTLS allowlist (compensating control). Do NOT use in production."
419 );
420 builder = builder.danger_accept_invalid_certs(true);
421 }
422 builder.build()?
423 } else {
424 reqwest::Client::builder()
425 .timeout(std::time::Duration::from_secs(30))
426 .build()?
427 };
428
429 tracing::info!(
430 "sync-daemon: local_agent_id={local_agent_id} peers={peers:?} interval={interval}s",
431 peers = args.peers
432 );
433
434 let shutdown = Arc::new(tokio::sync::Notify::new());
435 let shutdown_for_signal = shutdown.clone();
436 tokio::spawn(async move {
437 let _ = tokio::signal::ctrl_c().await;
438 shutdown_for_signal.notify_one();
439 });
440
441 crate::daemon_runtime::run_sync_daemon_with_shutdown_using_client(
442 client,
443 db_path.to_path_buf(),
444 local_agent_id,
445 args.peers,
446 args.api_key,
447 interval,
448 batch_size,
449 shutdown,
450 )
451 .await
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::test_utils::{TestEnv, seed_memory};

    /// Builds `SyncArgs` with every flag off.
    fn args_for(remote_db: PathBuf, direction: &str) -> SyncArgs {
        SyncArgs {
            remote_db,
            direction: direction.to_string(),
            trust_source: false,
            dry_run: false,
        }
    }

    /// Creates two fresh databases, seeds each with an optional
    /// `(title, content)` memory in namespace `ns`, runs `run` once as
    /// `test-agent`, and returns its result plus the captured stdout.
    fn sync_once(
        direction: &str,
        dry_run: bool,
        json_out: bool,
        local_seed: Option<(&str, &str)>,
        remote_seed: Option<(&str, &str)>,
    ) -> (Result<()>, String) {
        let mut env = TestEnv::fresh();
        let local = env.db_path.clone();
        // Kept alive until the end so the remote database outlives the run.
        let remote_env = TestEnv::fresh();
        let remote = remote_env.db_path.clone();

        if let Some((title, content)) = local_seed {
            seed_memory(&local, "ns", title, content);
        }
        if let Some((title, content)) = remote_seed {
            seed_memory(&remote, "ns", title, content);
        }

        let mut args = args_for(remote, direction);
        args.dry_run = dry_run;

        let mut out = env.output();
        let outcome = run(&local, &args, json_out, Some("test-agent"), &mut out);
        // Release the output handle before reading what was captured.
        drop(out);

        (outcome, env.stdout_str().to_string())
    }

    /// A minimal memory whose metadata carries the given `agent_id`.
    fn memory_with_agent(agent_id: &str) -> models::Memory {
        models::Memory {
            id: "m1".to_string(),
            tier: models::Tier::Mid,
            namespace: "ns".to_string(),
            title: "t".to_string(),
            content: "c".to_string(),
            tags: vec![],
            priority: 5,
            confidence: 1.0,
            source: "test".to_string(),
            access_count: 0,
            created_at: "2026-01-01T00:00:00Z".to_string(),
            updated_at: "2026-01-01T00:00:00Z".to_string(),
            last_accessed_at: None,
            expires_at: None,
            metadata: serde_json::json!({"agent_id": agent_id}),
        }
    }

    /// Daemon arguments with default tuning and no mTLS material.
    fn daemon_args(peers: Vec<String>, insecure_skip_server_verify: bool) -> SyncDaemonArgs {
        SyncDaemonArgs {
            peers,
            interval: 2,
            api_key: None,
            batch_size: 500,
            client_cert: None,
            client_key: None,
            insecure_skip_server_verify,
        }
    }

    #[test]
    fn test_sync_dry_run_merge() {
        let (outcome, stdout) = sync_once(
            "merge",
            true,
            true,
            Some(("local-only", "L")),
            Some(("remote-only", "R")),
        );
        outcome.unwrap();
        let report: serde_json::Value = serde_json::from_str(stdout.trim()).unwrap();
        assert!(report["dry_run"].as_bool().unwrap());
        assert_eq!(report["direction"].as_str().unwrap(), "merge");
    }

    #[test]
    fn test_sync_pull_direction() {
        let (outcome, stdout) = sync_once("pull", false, false, None, Some(("from-remote", "data")));
        outcome.unwrap();
        assert!(stdout.contains("pulled"));
    }

    #[test]
    fn test_sync_push_direction() {
        let (outcome, stdout) = sync_once("push", false, false, Some(("to-remote", "data")), None);
        outcome.unwrap();
        assert!(stdout.contains("pushed"));
    }

    #[test]
    fn test_sync_merge_direction() {
        let (outcome, stdout) = sync_once("merge", false, false, Some(("L", "L")), Some(("R", "R")));
        outcome.unwrap();
        assert!(stdout.contains("merged:"));
    }

    #[test]
    fn test_sync_invalid_direction_errors() {
        let (outcome, _stdout) = sync_once("sideways", false, false, None, None);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_sync_dry_run_pull_only() {
        let (outcome, stdout) = sync_once("pull", true, true, None, Some(("remote", "x")));
        outcome.unwrap();
        let report: serde_json::Value = serde_json::from_str(stdout.trim()).unwrap();
        assert!(report["pull"]["new"].as_u64().unwrap() >= 1);
    }

    #[test]
    fn test_restamp_agent_id_preserves_original() {
        let mut mem = memory_with_agent("remote-agent");
        restamp_agent_id(&mut mem, "local-agent");
        assert_eq!(mem.metadata["agent_id"].as_str().unwrap(), "local-agent");
        assert_eq!(
            mem.metadata["imported_from_agent_id"].as_str().unwrap(),
            "remote-agent"
        );
    }

    #[test]
    fn test_restamp_same_agent_no_imported_from() {
        let mut mem = memory_with_agent("same-agent");
        restamp_agent_id(&mut mem, "same-agent");
        assert_eq!(mem.metadata["agent_id"].as_str().unwrap(), "same-agent");
        assert!(mem.metadata.get("imported_from_agent_id").is_none());
    }

    #[tokio::test]
    async fn test_sync_daemon_empty_peers_errors() {
        let env = TestEnv::fresh();
        let db = env.db_path.clone();
        let res = run_daemon(&db, daemon_args(Vec::new(), false), Some("test-agent")).await;
        assert!(res.is_err());
        assert!(res.unwrap_err().to_string().contains("--peers"));
    }

    #[tokio::test]
    async fn test_sync_daemon_insecure_without_mtls_errors() {
        let env = TestEnv::fresh();
        let db = env.db_path.clone();
        let args = daemon_args(vec!["http://example.com:9077".to_string()], true);
        let res = run_daemon(&db, args, Some("test-agent")).await;
        assert!(res.is_err());
        assert!(
            res.unwrap_err()
                .to_string()
                .contains("insecure-skip-server-verify")
        );
    }
}