use crate::ci::{CiPipelineSpec, CiRunRecord, CiSnapshot};
use crate::error::StateError;
use crate::schema::{
    AgentRecord, BranchRecord, CommitId, CommitRecord, DecisionRecord, GraphEdge,
    MemoryProvenanceRecord, MemoryRecord, SnapshotRecord,
};
use crate::storage_traits::{ContentDigest, ReleaseMetadata, ReleaseRecord, StorageResult};
use crate::Result;
use crate::StorageError;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use surrealdb::engine::any::Any;
use surrealdb::opt::auth::{Database, Root};
use surrealdb::sql::Datetime as SurrealDatetime;
use surrealdb::Surreal;
use tracing::{debug, info, instrument, warn};
27
/// Connection settings for a remote (cloud) SurrealDB instance.
///
/// `Debug` is implemented by hand so the password never ends up in logs,
/// tracing output or panic messages.
#[derive(Clone)]
pub struct CloudConfig {
    /// Connection URL handed to `surrealdb::engine::any::connect`.
    pub endpoint: String,
    /// User name used to sign in.
    pub username: String,
    /// Password used to sign in. Redacted from `Debug` output.
    pub password: String,
    /// Namespace selected after sign-in.
    pub namespace: String,
    /// Database selected after sign-in.
    pub database: String,
    /// When `true`, sign in as a root user; otherwise as a database user.
    pub is_root: bool,
}

impl std::fmt::Debug for CloudConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CloudConfig")
            .field("endpoint", &self.endpoint)
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .field("namespace", &self.namespace)
            .field("database", &self.database)
            .field("is_root", &self.is_root)
            .finish()
    }
}

impl CloudConfig {
    /// Namespace used when none is configured explicitly.
    pub const DEFAULT_NAMESPACE: &'static str = "aivcs";
    /// Database used when none is configured explicitly.
    pub const DEFAULT_DATABASE: &'static str = "main";

    /// Creates a database-user configuration using the default namespace and database.
    pub fn new(
        endpoint: impl Into<String>,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        Self {
            endpoint: endpoint.into(),
            username: username.into(),
            password: password.into(),
            namespace: Self::DEFAULT_NAMESPACE.to_string(),
            database: Self::DEFAULT_DATABASE.to_string(),
            is_root: false,
        }
    }

    /// Creates a root-user configuration using the default namespace and database.
    pub fn new_root(
        endpoint: impl Into<String>,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        Self::new(endpoint, username, password).with_root(true)
    }

    /// Overrides the namespace selected after sign-in.
    pub fn with_namespace(mut self, ns: impl Into<String>) -> Self {
        self.namespace = ns.into();
        self
    }

    /// Overrides the database selected after sign-in.
    pub fn with_database(mut self, db: impl Into<String>) -> Self {
        self.database = db.into();
        self
    }

    /// Chooses between root (`true`) and database-user (`false`) authentication.
    pub fn with_root(mut self, is_root: bool) -> Self {
        self.is_root = is_root;
        self
    }

    /// Builds a configuration from environment variables.
    ///
    /// Required: `SURREALDB_ENDPOINT`, `SURREALDB_USERNAME`, `SURREALDB_PASSWORD`.
    /// Optional: `SURREALDB_NAMESPACE` (default `"aivcs"`), `SURREALDB_DATABASE`
    /// (default `"main"`), `SURREALDB_ROOT` (root auth only when it equals
    /// `"true"`, case-insensitively).
    ///
    /// # Errors
    ///
    /// Returns `"<VAR> not set"` for the first required variable that is missing
    /// or not valid Unicode, checked in the order listed above.
    pub fn from_env() -> std::result::Result<Self, String> {
        let require = |key: &str| std::env::var(key).map_err(|_| format!("{} not set", key));

        let endpoint = require("SURREALDB_ENDPOINT")?;
        let username = require("SURREALDB_USERNAME")?;
        let password = require("SURREALDB_PASSWORD")?;
        let namespace = std::env::var("SURREALDB_NAMESPACE")
            .unwrap_or_else(|_| Self::DEFAULT_NAMESPACE.to_string());
        let database = std::env::var("SURREALDB_DATABASE")
            .unwrap_or_else(|_| Self::DEFAULT_DATABASE.to_string());
        // Only a case-insensitive "true" enables root sign-in; any other value,
        // or an unset variable, means database-user authentication.
        let is_root = std::env::var("SURREALDB_ROOT")
            .map(|v| v.eq_ignore_ascii_case("true"))
            .unwrap_or(false);

        Ok(Self {
            endpoint,
            username,
            password,
            namespace,
            database,
            is_root,
        })
    }
}
129
/// Handle to a SurrealDB connection; all persistence methods in the
/// `impl SurrealHandle` block run their queries through `db`.
///
/// Cloning the handle clones the inner `Surreal<Any>` client.
/// NOTE(review): presumably clones share one underlying connection — confirm
/// against the surrealdb documentation.
#[derive(Clone)]
pub struct SurrealHandle {
    // Connected client; the `setup_*` constructors select namespace/database.
    db: Surreal<Any>,
}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
137struct DbReleaseRecord {
138 name: String,
139 spec_digest: ContentDigest,
140 metadata: ReleaseMetadata,
141 version_label: Option<String>,
142 promoted_by: String,
143 notes: Option<String>,
144 created_at: SurrealDatetime,
145}
146
147impl DbReleaseRecord {
148 fn into_release_record(self) -> ReleaseRecord {
149 ReleaseRecord {
150 name: self.name,
151 spec_digest: self.spec_digest,
152 metadata: self.metadata,
153 created_at: DateTime::<Utc>::from(self.created_at),
154 }
155 }
156}
157
158impl SurrealHandle {
159 #[instrument(skip_all)]
163 pub async fn setup_db() -> Result<Self> {
164 info!("Connecting to SurrealDB (in-memory)");
165
166 let db = surrealdb::engine::any::connect("mem://")
167 .await
168 .map_err(|e| StateError::Connection(e.to_string()))?;
169
170 db.use_ns("aivcs")
172 .use_db("main")
173 .await
174 .map_err(|e| StateError::Connection(e.to_string()))?;
175
176 let handle = SurrealHandle { db };
177 handle.init_schema().await?;
178
179 info!("SurrealDB connected and schema initialized");
180 Ok(handle)
181 }
182
183 #[instrument(skip(config), fields(endpoint = %config.endpoint, namespace = %config.namespace, database = %config.database))]
195 pub async fn setup_cloud(config: CloudConfig) -> Result<Self> {
196 info!("Connecting to SurrealDB Cloud (root={})", config.is_root);
197
198 let db = surrealdb::engine::any::connect(&config.endpoint)
199 .await
200 .map_err(|e| {
201 StateError::Connection(format!("Failed to connect to {}: {}", config.endpoint, e))
202 })?;
203
204 if config.is_root {
206 db.signin(Root {
208 username: &config.username,
209 password: &config.password,
210 })
211 .await
212 .map_err(|e| StateError::Connection(format!("Root authentication failed: {}", e)))?;
213 } else {
214 db.signin(Database {
216 namespace: &config.namespace,
217 database: &config.database,
218 username: &config.username,
219 password: &config.password,
220 })
221 .await
222 .map_err(|e| {
223 StateError::Connection(format!("Database authentication failed: {}", e))
224 })?;
225 }
226
227 db.use_ns(&config.namespace)
229 .use_db(&config.database)
230 .await
231 .map_err(|e| {
232 StateError::Connection(format!("Failed to select namespace/database: {}", e))
233 })?;
234
235 let handle = SurrealHandle { db };
236 handle.init_schema().await?;
237
238 info!("SurrealDB Cloud connected and schema initialized");
239 Ok(handle)
240 }
241
242 #[instrument(skip_all)]
248 pub async fn setup_from_env() -> Result<Self> {
249 if let Ok(config) = CloudConfig::from_env() {
250 info!("Cloud config found, connecting to SurrealDB Cloud");
251 return Self::setup_cloud(config).await;
252 }
253
254 let url = if let Ok(url) = std::env::var("SURREALDB_URL") {
255 info!("SURREALDB_URL found, connecting to {}", url);
256 url
257 } else {
258 let path = ".aivcs/db";
260 std::fs::create_dir_all(path).map_err(|e| {
261 StateError::Connection(format!(
262 "Failed to create database directory {}: {}",
263 path, e
264 ))
265 })?;
266 let url = format!("surrealkv://{}", path);
267 info!(
268 "No cloud config or SURREALDB_URL found, using local persistence: {}",
269 url
270 );
271 url
272 };
273
274 let db = surrealdb::engine::any::connect(&url)
275 .await
276 .map_err(|e| StateError::Connection(format!("Failed to connect to {}: {}", url, e)))?;
277
278 db.use_ns("aivcs")
279 .use_db("main")
280 .await
281 .map_err(|e| StateError::Connection(e.to_string()))?;
282
283 let handle = SurrealHandle { db };
284 handle.init_schema().await?;
285 Ok(handle)
286 }
287
288 async fn init_schema(&self) -> Result<()> {
290 crate::migrations::init_schema(&self.db).await?;
291 Ok(())
292 }
293
294 #[instrument(skip(self, record), fields(commit_id = %record.commit_id))]
298 pub async fn save_commit(&self, record: &CommitRecord) -> Result<CommitRecord> {
299 debug!("Saving commit");
300
301 let record_owned = record.clone();
303
304 let created: Option<CommitRecord> = self.db.create("commits").content(record_owned).await?;
305
306 created.ok_or_else(|| StateError::Transaction("Failed to create commit".to_string()))
307 }
308
309 #[instrument(skip(self))]
311 pub async fn get_commit(&self, commit_hash: &str) -> Result<Option<CommitRecord>> {
312 debug!("Getting commit");
313
314 let hash_owned = commit_hash.to_string();
315
316 let mut result = self
317 .db
318 .query("SELECT * FROM commits WHERE commit_id.hash = $hash")
319 .bind(("hash", hash_owned))
320 .await?;
321
322 let commits: Vec<CommitRecord> = result.take(0)?;
323 Ok(commits.into_iter().next())
324 }
325
326 #[instrument(skip(self, commit_id, state))]
332 pub async fn save_snapshot(
333 &self,
334 commit_id: &CommitId,
335 state: serde_json::Value,
336 ) -> Result<()> {
337 debug!("Saving snapshot for commit {}", commit_id.short());
338
339 let record = SnapshotRecord::new(&commit_id.hash, state);
340
341 let _created: Option<SnapshotRecord> =
342 self.db.create("snapshots").content(record.clone()).await?;
343
344 info!(
345 "Snapshot saved: {} ({} bytes)",
346 commit_id.short(),
347 record.size_bytes
348 );
349 Ok(())
350 }
351
352 #[instrument(skip(self))]
354 pub async fn load_snapshot(&self, commit_id: &str) -> Result<SnapshotRecord> {
355 debug!("Loading snapshot");
356
357 let id_owned = commit_id.to_string();
358
359 let mut result = self
360 .db
361 .query("SELECT * FROM snapshots WHERE commit_id = $id")
362 .bind(("id", id_owned))
363 .await?;
364
365 let snapshots: Vec<SnapshotRecord> = result.take(0)?;
366 snapshots
367 .into_iter()
368 .next()
369 .ok_or_else(|| StateError::CommitNotFound(commit_id.to_string()))
370 }
371
372 #[instrument(skip(self))]
378 pub async fn save_commit_graph_edge(&self, child_id: &str, parent_id: &str) -> Result<()> {
379 debug!("Saving graph edge: {} -> {}", parent_id, child_id);
380
381 let edge = GraphEdge::new(child_id, parent_id);
382
383 let _created: Option<GraphEdge> = self.db.create("graph_edges").content(edge).await?;
384
385 info!("Graph edge saved: {} -> {}", parent_id, child_id);
386 Ok(())
387 }
388
389 #[instrument(skip(self))]
391 pub async fn get_parent(&self, child_id: &str) -> Result<Option<String>> {
392 let id_owned = child_id.to_string();
393
394 let mut result = self
395 .db
396 .query("SELECT parent_id FROM graph_edges WHERE child_id = $id")
397 .bind(("id", id_owned))
398 .await?;
399
400 #[derive(serde::Deserialize)]
401 struct ParentResult {
402 parent_id: String,
403 }
404
405 let parents: Vec<ParentResult> = result.take(0)?;
406 Ok(parents.into_iter().next().map(|p| p.parent_id))
407 }
408
409 #[instrument(skip(self))]
411 pub async fn get_children(&self, parent_id: &str) -> Result<Vec<String>> {
412 let id_owned = parent_id.to_string();
413
414 let mut result = self
415 .db
416 .query("SELECT child_id FROM graph_edges WHERE parent_id = $id")
417 .bind(("id", id_owned))
418 .await?;
419
420 #[derive(serde::Deserialize)]
421 struct ChildResult {
422 child_id: String,
423 }
424
425 let children: Vec<ChildResult> = result.take(0)?;
426 Ok(children.into_iter().map(|c| c.child_id).collect())
427 }
428
429 #[instrument(skip(self))]
433 pub async fn save_branch(&self, record: &BranchRecord) -> Result<BranchRecord> {
434 debug!("Saving branch: {}", record.name);
435
436 let existing = self.get_branch(&record.name).await?;
438
439 if existing.is_some() {
440 let head = record.head_commit_id.clone();
442 let now = SurrealDatetime::from(chrono::Utc::now());
443 let name = record.name.clone();
444
445 let mut result = self
446 .db
447 .query("UPDATE branches SET head_commit_id = $head, updated_at = $now WHERE name = $name")
448 .bind(("head", head))
449 .bind(("now", now))
450 .bind(("name", name))
451 .await?;
452
453 let updated: Vec<BranchRecord> = result.take(0)?;
454 updated
455 .into_iter()
456 .next()
457 .ok_or_else(|| StateError::Transaction("Failed to update branch".to_string()))
458 } else {
459 let record_owned = record.clone();
461
462 let created: Option<BranchRecord> =
463 self.db.create("branches").content(record_owned).await?;
464
465 created.ok_or_else(|| StateError::Transaction("Failed to create branch".to_string()))
466 }
467 }
468
469 #[instrument(skip(self))]
471 pub async fn get_branch(&self, name: &str) -> Result<Option<BranchRecord>> {
472 let name_owned = name.to_string();
473
474 let mut result = self
475 .db
476 .query("SELECT * FROM branches WHERE name = $name")
477 .bind(("name", name_owned))
478 .await?;
479
480 let branches: Vec<BranchRecord> = result.take(0)?;
481 Ok(branches.into_iter().next())
482 }
483
484 #[instrument(skip(self))]
486 pub async fn get_branch_head(&self, branch_name: &str) -> Result<String> {
487 let branch = self
488 .get_branch(branch_name)
489 .await?
490 .ok_or_else(|| StateError::BranchNotFound(branch_name.to_string()))?;
491
492 Ok(branch.head_commit_id)
493 }
494
495 #[instrument(skip(self))]
497 pub async fn list_branches(&self) -> Result<Vec<BranchRecord>> {
498 let mut result = self
499 .db
500 .query("SELECT * FROM branches ORDER BY name")
501 .await?;
502
503 let branches: Vec<BranchRecord> = result.take(0)?;
504 Ok(branches)
505 }
506
507 #[instrument(skip(self))]
509 pub async fn delete_branch(&self, name: &str) -> Result<()> {
510 debug!("Deleting branch: {}", name);
511
512 let branch = self
513 .get_branch(name)
514 .await?
515 .ok_or_else(|| StateError::BranchNotFound(name.to_string()))?;
516
517 if branch.is_default {
518 return Err(StateError::Transaction(
519 "Cannot delete the default branch".to_string(),
520 ));
521 }
522
523 let name_owned = name.to_string();
524
525 let _result = self
526 .db
527 .query("DELETE FROM branches WHERE name = $name")
528 .bind(("name", name_owned))
529 .await?;
530
531 Ok(())
532 }
533
534 #[instrument(skip(self, record), fields(agent_name = %record.name))]
538 pub async fn register_agent(&self, record: &AgentRecord) -> Result<AgentRecord> {
539 debug!("Registering agent");
540
541 let record_owned = record.clone();
542
543 let created: Option<AgentRecord> = self.db.create("agents").content(record_owned).await?;
544
545 created.ok_or_else(|| StateError::Transaction("Failed to register agent".to_string()))
546 }
547
548 #[instrument(skip(self))]
550 pub async fn get_agent(&self, agent_id: &str) -> Result<Option<AgentRecord>> {
551 let id_owned = agent_id.to_string();
552
553 let mut result = self
554 .db
555 .query("SELECT * FROM agents WHERE agent_id = $id")
556 .bind(("id", id_owned))
557 .await?;
558
559 let agents: Vec<AgentRecord> = result.take(0)?;
560 Ok(agents.into_iter().next())
561 }
562
563 #[instrument(skip(self, record), fields(key = %record.key))]
567 pub async fn save_memory(&self, record: &MemoryRecord) -> Result<MemoryRecord> {
568 debug!("Saving memory");
569
570 let record_owned = record.clone();
571
572 let created: Option<MemoryRecord> =
573 self.db.create("memories").content(record_owned).await?;
574
575 created.ok_or_else(|| StateError::Transaction("Failed to save memory".to_string()))
576 }
577
578 #[instrument(skip(self))]
580 pub async fn get_memories(&self, commit_id: &str) -> Result<Vec<MemoryRecord>> {
581 let id_owned = commit_id.to_string();
582
583 let mut result = self
584 .db
585 .query("SELECT * FROM memories WHERE commit_id = $id ORDER BY created_at")
586 .bind(("id", id_owned))
587 .await?;
588
589 let memories: Vec<MemoryRecord> = result.take(0)?;
590 Ok(memories)
591 }
592
593 #[instrument(skip(self, spec_digest, metadata), fields(name = %name, digest = %spec_digest))]
597 pub async fn release_promote(
598 &self,
599 name: &str,
600 spec_digest: &ContentDigest,
601 metadata: ReleaseMetadata,
602 ) -> StorageResult<ReleaseRecord> {
603 let record = DbReleaseRecord {
604 name: name.to_string(),
605 spec_digest: spec_digest.clone(),
606 version_label: metadata.version_label.clone(),
607 promoted_by: metadata.promoted_by.clone(),
608 notes: metadata.notes.clone(),
609 metadata,
610 created_at: SurrealDatetime::from(Utc::now()),
611 };
612
613 let created: Option<DbReleaseRecord> = self
614 .db
615 .create("releases")
616 .content(record.clone())
617 .await
618 .map_err(|e| StorageError::Backend(e.to_string()))?;
619
620 created
621 .map(DbReleaseRecord::into_release_record)
622 .ok_or_else(|| StorageError::Backend("failed to create release record".to_string()))
623 }
624
625 #[instrument(skip(self), fields(name = %name))]
627 pub async fn release_rollback(&self, name: &str) -> StorageResult<ReleaseRecord> {
628 let history = self.release_history(name).await?;
629 if history.is_empty() {
630 return Err(StorageError::ReleaseNotFound {
631 name: name.to_string(),
632 });
633 }
634 if history.len() < 2 {
635 return Err(StorageError::NoPreviousRelease {
636 name: name.to_string(),
637 });
638 }
639
640 let previous = &history[1];
641 self.release_promote(name, &previous.spec_digest, previous.metadata.clone())
642 .await
643 }
644
645 #[instrument(skip(self), fields(name = %name))]
647 pub async fn release_current(&self, name: &str) -> StorageResult<Option<ReleaseRecord>> {
648 let name_owned = name.to_string();
649
650 let mut result = self
651 .db
652 .query("SELECT * FROM releases WHERE name = $name ORDER BY created_at DESC LIMIT 1")
653 .bind(("name", name_owned))
654 .await
655 .map_err(|e| StorageError::Backend(e.to_string()))?;
656
657 let releases: Vec<DbReleaseRecord> = result
658 .take(0)
659 .map_err(|e| StorageError::Backend(e.to_string()))?;
660 Ok(releases
661 .into_iter()
662 .next()
663 .map(DbReleaseRecord::into_release_record))
664 }
665
666 #[instrument(skip(self), fields(name = %name))]
668 pub async fn release_history(&self, name: &str) -> StorageResult<Vec<ReleaseRecord>> {
669 let name_owned = name.to_string();
670
671 let mut result = self
672 .db
673 .query("SELECT * FROM releases WHERE name = $name ORDER BY created_at DESC")
674 .bind(("name", name_owned))
675 .await
676 .map_err(|e| StorageError::Backend(e.to_string()))?;
677
678 let releases: Vec<DbReleaseRecord> = result
679 .take(0)
680 .map_err(|e| StorageError::Backend(e.to_string()))?;
681 Ok(releases
682 .into_iter()
683 .map(DbReleaseRecord::into_release_record)
684 .collect())
685 }
686
687 #[instrument(skip(self, snapshot))]
691 pub async fn save_ci_snapshot(&self, snapshot: &CiSnapshot) -> Result<String> {
692 #[derive(serde::Serialize, serde::Deserialize)]
693 struct CiSnapshotStore {
694 digest: String,
695 snapshot_json: String,
696 }
697
698 let digest = snapshot.digest();
699 let snapshot_json = serde_json::to_string(snapshot)?;
700 let payload = CiSnapshotStore {
701 digest: digest.clone(),
702 snapshot_json,
703 };
704
705 let _created: Option<CiSnapshotStore> =
706 self.db.create("ci_snapshots").content(payload).await?;
707 Ok(digest)
708 }
709
710 #[instrument(skip(self))]
712 pub async fn load_ci_snapshot(&self, digest: &str) -> Result<Option<CiSnapshot>> {
713 #[derive(serde::Deserialize)]
714 struct CiSnapshotStore {
715 snapshot_json: String,
716 }
717
718 let digest_owned = digest.to_string();
719 let mut result = self
720 .db
721 .query("SELECT snapshot_json FROM ci_snapshots WHERE digest = $digest")
722 .bind(("digest", digest_owned))
723 .await?;
724
725 let rows: Vec<CiSnapshotStore> = result.take(0)?;
726 rows.into_iter()
727 .next()
728 .map(|r| serde_json::from_str(&r.snapshot_json))
729 .transpose()
730 .map_err(Into::into)
731 }
732
733 #[instrument(skip(self, pipeline))]
735 pub async fn save_ci_pipeline(&self, pipeline: &CiPipelineSpec) -> Result<String> {
736 #[derive(serde::Serialize, serde::Deserialize)]
737 struct CiPipelineStore {
738 digest: String,
739 pipeline_json: String,
740 }
741
742 let digest = pipeline.digest();
743 let pipeline_json = serde_json::to_string(pipeline)?;
744 let payload = CiPipelineStore {
745 digest: digest.clone(),
746 pipeline_json,
747 };
748
749 let _created: Option<CiPipelineStore> =
750 self.db.create("ci_pipelines").content(payload).await?;
751 Ok(digest)
752 }
753
754 #[instrument(skip(self))]
756 pub async fn load_ci_pipeline(&self, digest: &str) -> Result<Option<CiPipelineSpec>> {
757 #[derive(serde::Deserialize)]
758 struct CiPipelineStore {
759 pipeline_json: String,
760 }
761
762 let digest_owned = digest.to_string();
763 let mut result = self
764 .db
765 .query("SELECT pipeline_json FROM ci_pipelines WHERE digest = $digest")
766 .bind(("digest", digest_owned))
767 .await?;
768
769 let rows: Vec<CiPipelineStore> = result.take(0)?;
770 rows.into_iter()
771 .next()
772 .map(|r| serde_json::from_str(&r.pipeline_json))
773 .transpose()
774 .map_err(Into::into)
775 }
776
777 #[instrument(skip(self, run), fields(run_id = %run.run_id))]
779 pub async fn save_ci_run(&self, run: &CiRunRecord) -> Result<CiRunRecord> {
780 #[derive(serde::Serialize, serde::Deserialize)]
781 struct CiRunStore {
782 run_id: String,
783 snapshot_digest: String,
784 pipeline_digest: String,
785 status: String,
786 run_json: String,
787 started_at: Option<String>,
788 finished_at: Option<String>,
789 }
790
791 let payload = CiRunStore {
792 run_id: run.run_id.clone(),
793 snapshot_digest: run.snapshot_digest.clone(),
794 pipeline_digest: run.pipeline_digest.clone(),
795 status: serde_json::to_string(&run.status)?
796 .trim_matches('"')
797 .to_string(),
798 run_json: serde_json::to_string(run)?,
799 started_at: run.started_at.clone(),
800 finished_at: run.finished_at.clone(),
801 };
802
803 let created: Option<CiRunStore> = self.db.create("ci_runs").content(payload).await?;
804 if created.is_some() {
805 Ok(run.clone())
806 } else {
807 Err(StateError::Transaction(
808 "Failed to create CI run".to_string(),
809 ))
810 }
811 }
812
813 #[instrument(skip(self))]
815 pub async fn get_ci_run(&self, run_id: &str) -> Result<Option<CiRunRecord>> {
816 #[derive(serde::Deserialize)]
817 struct CiRunStore {
818 run_json: String,
819 }
820
821 let run_id_owned = run_id.to_string();
822 let mut result = self
823 .db
824 .query("SELECT run_json FROM ci_runs WHERE run_id = $run_id")
825 .bind(("run_id", run_id_owned))
826 .await?;
827 let runs: Vec<CiRunStore> = result.take(0)?;
828 runs.into_iter()
829 .next()
830 .map(|r| serde_json::from_str(&r.run_json))
831 .transpose()
832 .map_err(Into::into)
833 }
834
835 #[instrument(skip(self))]
837 pub async fn list_ci_runs_by_snapshot(
838 &self,
839 snapshot_digest: &str,
840 ) -> Result<Vec<CiRunRecord>> {
841 #[derive(serde::Deserialize)]
842 struct CiRunStore {
843 run_json: String,
844 }
845
846 let snapshot_digest_owned = snapshot_digest.to_string();
847 let mut result = self
848 .db
849 .query("SELECT run_json FROM ci_runs WHERE snapshot_digest = $snapshot_digest")
850 .bind(("snapshot_digest", snapshot_digest_owned))
851 .await?;
852 let runs: Vec<CiRunStore> = result.take(0)?;
853 runs.into_iter()
854 .map(|r| serde_json::from_str::<CiRunRecord>(&r.run_json))
855 .collect::<std::result::Result<Vec<_>, _>>()
856 .map_err(Into::into)
857 }
858
859 #[instrument(skip(self, record))]
863 pub async fn save_decision(&self, record: &DecisionRecord) -> Result<DecisionRecord> {
864 debug!("Saving decision");
865
866 let record_owned = record.clone();
867
868 let created: Option<DecisionRecord> =
869 self.db.create("decisions").content(record_owned).await?;
870
871 created.ok_or_else(|| StateError::Transaction("Failed to save decision".to_string()))
872 }
873
874 #[instrument(skip(self))]
876 pub async fn get_decision(&self, decision_id: &str) -> Result<Option<DecisionRecord>> {
877 let id_owned = decision_id.to_string();
878
879 let mut result = self
880 .db
881 .query("SELECT * FROM decisions WHERE decision_id = $id")
882 .bind(("id", id_owned))
883 .await?;
884
885 let decisions: Vec<DecisionRecord> = result.take(0)?;
886 Ok(decisions.into_iter().next())
887 }
888
889 #[instrument(skip(self))]
891 pub async fn update_decision_outcome(
892 &self,
893 decision_id: &str,
894 outcome_json: String,
895 ) -> Result<DecisionRecord> {
896 let id_owned = decision_id.to_string();
897 let now = SurrealDatetime::from(Utc::now());
898
899 let mut result = self
900 .db
901 .query(
902 "UPDATE decisions SET outcome = $outcome, outcome_at = $outcome_at WHERE decision_id = $id RETURN AFTER",
903 )
904 .bind(("id", id_owned))
905 .bind(("outcome", outcome_json))
906 .bind(("outcome_at", now))
907 .await?;
908
909 let decisions: Vec<DecisionRecord> = result.take(0)?;
910 decisions
911 .into_iter()
912 .next()
913 .ok_or_else(|| StateError::Transaction("Decision not found for update".to_string()))
914 }
915
916 #[instrument(skip(self))]
918 pub async fn get_decision_history(
919 &self,
920 task: &str,
921 limit: usize,
922 ) -> Result<Vec<DecisionRecord>> {
923 let task_owned = task.to_string();
924
925 let mut result = self
926 .db
927 .query(
928 "SELECT * FROM decisions WHERE task = $task ORDER BY timestamp DESC LIMIT $limit",
929 )
930 .bind(("task", task_owned))
931 .bind(("limit", limit as i64))
932 .await?;
933
934 let decisions: Vec<DecisionRecord> = result.take(0)?;
935 Ok(decisions)
936 }
937
938 #[instrument(skip(self, record))]
940 pub async fn save_provenance(
941 &self,
942 record: &MemoryProvenanceRecord,
943 ) -> Result<MemoryProvenanceRecord> {
944 debug!("Saving memory provenance");
945
946 let record_owned = record.clone();
947
948 let created: Option<MemoryProvenanceRecord> = self
949 .db
950 .create("memory_provenances")
951 .content(record_owned)
952 .await?;
953
954 created.ok_or_else(|| StateError::Transaction("Failed to save provenance".to_string()))
955 }
956
957 #[instrument(skip(self))]
959 pub async fn get_provenance(&self, memory_id: &str) -> Result<Vec<MemoryProvenanceRecord>> {
960 let memory_id_owned = memory_id.to_string();
961
962 let mut result = self
963 .db
964 .query("SELECT * FROM memory_provenances WHERE memory_id = $memory_id")
965 .bind(("memory_id", memory_id_owned))
966 .await?;
967
968 let provenances: Vec<MemoryProvenanceRecord> = result.take(0)?;
969 Ok(provenances)
970 }
971
972 #[instrument(skip(self))]
976 pub async fn get_commit_history(
977 &self,
978 start_commit: &str,
979 limit: usize,
980 ) -> Result<Vec<CommitRecord>> {
981 let mut history = Vec::new();
982 let mut current = Some(start_commit.to_string());
983
984 while let Some(commit_hash) = current {
985 if history.len() >= limit {
986 break;
987 }
988
989 if let Some(commit) = self.get_commit(&commit_hash).await? {
990 current = commit.parent_ids.first().cloned();
992 history.push(commit);
993 } else {
994 break;
995 }
996 }
997
998 Ok(history)
999 }
1000
1001 #[instrument(skip(self))]
1005 pub async fn get_reasoning_trace(&self, commit_id: &str) -> Result<Vec<SnapshotRecord>> {
1006 let history = self.get_commit_history(commit_id, 100).await?;
1008
1009 let mut trace = Vec::new();
1011 for commit in history {
1012 if let Ok(snapshot) = self.load_snapshot(&commit.commit_id.hash).await {
1013 trace.push(snapshot);
1014 }
1015 }
1016
1017 Ok(trace)
1018 }
1019}
1020
/// Integration tests for `SurrealHandle`; each test uses its own in-memory
/// database created via `SurrealHandle::setup_db`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    // Smoke test: the in-memory constructor connects and initializes the schema.
    #[tokio::test]
    async fn test_surreal_connection_and_schema_creation() {
        let handle = SurrealHandle::setup_db().await;
        assert!(handle.is_ok(), "Failed to connect: {:?}", handle.err());
    }

    // A saved non-default branch can be deleted and is no longer found afterwards.
    #[tokio::test]
    async fn test_branch_deletion() {
        let handle = SurrealHandle::setup_db().await.unwrap();

        let branch = BranchRecord::new("feature/test", "commit-123", false);
        handle.save_branch(&branch).await.unwrap();

        let loaded = handle.get_branch("feature/test").await.unwrap();
        assert!(loaded.is_some());

        handle.delete_branch("feature/test").await.unwrap();

        let deleted = handle.get_branch("feature/test").await.unwrap();
        assert!(deleted.is_none());
    }

    // Deleting an unknown branch fails with a "Branch not found" error message.
    #[tokio::test]
    async fn test_delete_nonexistent_branch() {
        let handle = SurrealHandle::setup_db().await.unwrap();
        let result = handle.delete_branch("nonexistent").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Branch not found"));
    }

    // The default branch is protected from deletion.
    #[tokio::test]
    async fn test_delete_default_branch() {
        let handle = SurrealHandle::setup_db().await.unwrap();

        let branch = BranchRecord::new("main", "commit-123", true);
        handle.save_branch(&branch).await.unwrap();

        let result = handle.delete_branch("main").await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Cannot delete the default branch"));
    }

    // `update_decision_outcome` returns the row with `outcome` and `outcome_at` set.
    #[tokio::test]
    async fn test_update_decision_outcome_persists_fields() {
        let handle = SurrealHandle::setup_db().await.unwrap();

        let decision = DecisionRecord::new(
            "dec-outcome-1".to_string(),
            "commit-123".to_string(),
            "task-123".to_string(),
            "action-123".to_string(),
            "because".to_string(),
            0.9,
        );
        handle.save_decision(&decision).await.unwrap();

        let updated = handle
            .update_decision_outcome(
                "dec-outcome-1",
                r#"{"status":"success","duration_ms":123}"#.to_string(),
            )
            .await
            .unwrap();

        assert_eq!(
            updated.outcome,
            Some(r#"{"status":"success","duration_ms":123}"#.to_string())
        );
        assert!(updated.outcome_at.is_some());
    }

    // A saved snapshot round-trips: loading by commit hash returns the same state.
    #[tokio::test]
    async fn test_snapshot_is_atomic_and_retrievable() {
        let handle = SurrealHandle::setup_db().await.unwrap();

        let state = serde_json::json!({
            "agent_name": "test-agent",
            "step": 1,
            "variables": {"x": 42, "y": "hello"}
        });

        let commit_id = CommitId::from_state(serde_json::to_vec(&state).unwrap().as_slice());

        handle
            .save_snapshot(&commit_id, state.clone())
            .await
            .unwrap();

        let loaded = handle.load_snapshot(&commit_id.hash).await.unwrap();

        assert_eq!(loaded.commit_id, commit_id.hash);
        assert_eq!(loaded.state, state);
    }

    // A saved edge is visible from both directions (child → parent, parent → children).
    #[tokio::test]
    async fn test_parent_child_edge_is_created() {
        let handle = SurrealHandle::setup_db().await.unwrap();

        let parent_id = "parent-commit-hash";
        let child_id = "child-commit-hash";

        handle
            .save_commit_graph_edge(child_id, parent_id)
            .await
            .unwrap();

        let parent = handle.get_parent(child_id).await.unwrap();
        assert_eq!(parent, Some(parent_id.to_string()));

        let children = handle.get_children(parent_id).await.unwrap();
        assert!(children.contains(&child_id.to_string()));
    }

    // Creating a branch, reading its head, then saving again moves the head.
    #[tokio::test]
    async fn test_branch_operations() {
        let handle = SurrealHandle::setup_db().await.unwrap();

        let branch = BranchRecord::new("main", "commit-abc123", true);
        handle.save_branch(&branch).await.unwrap();

        let loaded = handle.get_branch("main").await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().head_commit_id, "commit-abc123");

        let head = handle.get_branch_head("main").await.unwrap();
        assert_eq!(head, "commit-abc123");

        // Saving an existing branch updates its head instead of creating a duplicate row.
        let updated_branch = BranchRecord::new("main", "commit-def456", true);
        handle.save_branch(&updated_branch).await.unwrap();

        let new_head = handle.get_branch_head("main").await.unwrap();
        assert_eq!(new_head, "commit-def456");
    }

    // A saved commit can be looked up again by its hash.
    #[tokio::test]
    async fn test_commit_record_operations() {
        let handle = SurrealHandle::setup_db().await.unwrap();

        let commit_id = CommitId::from_state(b"test state");
        let commit = CommitRecord::new(commit_id.clone(), vec![], "Initial commit", "test-agent");

        let saved = handle.save_commit(&commit).await.unwrap();
        assert_eq!(saved.commit_id.hash, commit_id.hash);

        let loaded = handle.get_commit(&commit_id.hash).await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().message, "Initial commit");
    }

    // Builds a linear 4-commit chain with snapshots, then checks that the
    // reasoning trace returns all snapshots newest-first.
    #[tokio::test]
    async fn test_get_trace_for_commit_id_returns_correct_cot() {
        let handle = SurrealHandle::setup_db().await.unwrap();

        let state_0 = serde_json::json!({"step": 0, "thought": "Starting exploration"});
        let state_1 = serde_json::json!({"step": 1, "thought": "Trying strategy A"});
        let state_2 = serde_json::json!({"step": 2, "thought": "Strategy A failed, pivoting"});
        let state_3 = serde_json::json!({"step": 3, "thought": "Strategy B succeeded"});

        let id_0 = CommitId::from_state(b"state-0");
        let id_1 = CommitId::from_state(b"state-1");
        let id_2 = CommitId::from_state(b"state-2");
        let id_3 = CommitId::from_state(b"state-3");

        handle.save_snapshot(&id_0, state_0.clone()).await.unwrap();
        handle.save_snapshot(&id_1, state_1.clone()).await.unwrap();
        handle.save_snapshot(&id_2, state_2.clone()).await.unwrap();
        handle.save_snapshot(&id_3, state_3.clone()).await.unwrap();

        // Each commit lists the previous one as its only parent.
        let commit_0 = CommitRecord::new(id_0.clone(), vec![], "Step 0", "agent");
        let commit_1 = CommitRecord::new(id_1.clone(), vec![id_0.hash.clone()], "Step 1", "agent");
        let commit_2 = CommitRecord::new(id_2.clone(), vec![id_1.hash.clone()], "Step 2", "agent");
        let commit_3 = CommitRecord::new(id_3.clone(), vec![id_2.hash.clone()], "Step 3", "agent");

        handle.save_commit(&commit_0).await.unwrap();
        handle.save_commit(&commit_1).await.unwrap();
        handle.save_commit(&commit_2).await.unwrap();
        handle.save_commit(&commit_3).await.unwrap();

        let trace = handle.get_reasoning_trace(&id_3.hash).await.unwrap();

        assert_eq!(trace.len(), 4, "Trace should contain all 4 commits");

        assert_eq!(trace[0].state["step"], 3);
        assert_eq!(trace[1].state["step"], 2);
        assert_eq!(trace[2].state["step"], 1);
        assert_eq!(trace[3].state["step"], 0);

        assert_eq!(trace[0].state["thought"], "Strategy B succeeded");
        assert_eq!(trace[1].state["thought"], "Strategy A failed, pivoting");
        assert_eq!(trace[2].state["thought"], "Trying strategy A");
        assert_eq!(trace[3].state["thought"], "Starting exploration");
    }

    // CI snapshot, pipeline and run records each round-trip through storage,
    // and runs can be listed by snapshot digest.
    #[tokio::test]
    async fn test_ci_records_roundtrip() {
        let handle = SurrealHandle::setup_db().await.unwrap();

        let snapshot = CiSnapshot {
            repo_sha: "abc123".to_string(),
            workspace_hash: "work-1".to_string(),
            local_ci_config_hash: "cfg-1".to_string(),
            env_hash: "env-1".to_string(),
        };
        let snapshot_digest = handle.save_ci_snapshot(&snapshot).await.unwrap();
        let loaded_snapshot = handle.load_ci_snapshot(&snapshot_digest).await.unwrap();
        assert_eq!(loaded_snapshot, Some(snapshot.clone()));

        let pipeline = CiPipelineSpec {
            name: "default".to_string(),
            steps: vec![crate::ci::CiStepSpec {
                name: "test".to_string(),
                command: crate::ci::CiCommand {
                    program: "cargo".to_string(),
                    args: vec!["test".to_string()],
                    env: BTreeMap::new(),
                    cwd: None,
                },
                timeout_secs: Some(300),
                allow_failure: false,
            }],
        };
        let pipeline_digest = handle.save_ci_pipeline(&pipeline).await.unwrap();
        let loaded_pipeline = handle.load_ci_pipeline(&pipeline_digest).await.unwrap();
        assert_eq!(loaded_pipeline, Some(pipeline.clone()));

        let run = CiRunRecord::queued(&snapshot_digest, &pipeline_digest);
        let saved_run = handle.save_ci_run(&run).await.unwrap();
        let loaded_run = handle.get_ci_run(&saved_run.run_id).await.unwrap();
        assert_eq!(loaded_run, Some(saved_run.clone()));

        let runs = handle
            .list_ci_runs_by_snapshot(&snapshot_digest)
            .await
            .unwrap();
        assert_eq!(runs.len(), 1);
        assert_eq!(runs[0].run_id, saved_run.run_id);
    }

    // Promoting a release stores the metadata both inside `metadata` and as the
    // top-level `version_label` / `promoted_by` / `notes` columns.
    #[tokio::test]
    async fn test_release_fields_roundtrip_in_surreal() {
        let handle = SurrealHandle::setup_db().await.unwrap();

        let metadata = ReleaseMetadata {
            version_label: Some("v1.2.3".to_string()),
            promoted_by: "test-user".to_string(),
            notes: Some("Release notes here".to_string()),
        };
        let digest = ContentDigest::from_bytes(b"spec-data");

        let release = handle
            .release_promote("test-agent", &digest, metadata.clone())
            .await
            .unwrap();

        assert_eq!(release.name, "test-agent");
        assert_eq!(release.metadata.version_label, Some("v1.2.3".to_string()));

        // Query the raw table directly to check the denormalized columns.
        let mut result = handle
            .db
            .query("SELECT name, version_label, promoted_by, notes FROM releases WHERE name = 'test-agent'")
            .await
            .unwrap();

        #[derive(serde::Deserialize)]
        struct RawRelease {
            name: String,
            version_label: Option<String>,
            promoted_by: String,
            notes: Option<String>,
        }

        let rows: Vec<RawRelease> = result.take(0).unwrap();
        assert_eq!(rows.len(), 1);
        let row = &rows[0];

        assert_eq!(row.version_label, Some("v1.2.3".to_string()));
        assert_eq!(row.promoted_by, "test-user");
        assert_eq!(row.notes, Some("Release notes here".to_string()));
        assert_eq!(row.name, "test-agent");

        let history = handle.release_history("test-agent").await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(
            history[0].metadata.version_label,
            Some("v1.2.3".to_string())
        );
        assert_eq!(history[0].metadata.promoted_by, "test-user");
    }
}