Skip to main content

oxidized_state/
handle.rs

1//! SurrealDB Handle - Connection and Operations
2//!
3//! Manages connection and provides methods for:
4//! - save_snapshot / load_snapshot
5//! - save_commit_graph_edge
6//! - get_branch_head
7//! - CRUD for commits, branches, agents, memories, and CI records
8//!
//! Supports local (in-memory or SurrealKV on-disk) and cloud (WebSocket) connections.
10
11use crate::ci::{CiPipelineSpec, CiRunRecord, CiSnapshot};
12use crate::error::StateError;
13use crate::schema::{
14    AgentRecord, BranchRecord, CommitId, CommitRecord, DecisionRecord, GraphEdge,
15    MemoryProvenanceRecord, MemoryRecord, SnapshotRecord,
16};
17use crate::storage_traits::{ContentDigest, ReleaseMetadata, ReleaseRecord, StorageResult};
18use crate::Result;
19use crate::StorageError;
20use chrono::{DateTime, Utc};
21use serde::{Deserialize, Serialize};
22use surrealdb::engine::any::Any;
23use surrealdb::opt::auth::{Database, Root};
24use surrealdb::sql::Datetime as SurrealDatetime;
25use surrealdb::Surreal;
26use tracing::{debug, info, instrument};
27
/// Configuration for SurrealDB Cloud connection
///
/// `Debug` is implemented by hand (not derived) so that the password is never
/// rendered into logs, tracing spans, or panic messages.
#[derive(Clone)]
pub struct CloudConfig {
    /// WebSocket endpoint URL (e.g., "wss://xxx.aws-use1.surrealdb.cloud")
    pub endpoint: String,
    /// Database username
    pub username: String,
    /// Database password
    pub password: String,
    /// Namespace (default: "aivcs")
    pub namespace: String,
    /// Database name (default: "main")
    pub database: String,
    /// Whether this is a root user (true) or database user (false)
    pub is_root: bool,
}

impl std::fmt::Debug for CloudConfig {
    /// Formats every field verbatim except `password`, which is replaced by a
    /// fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CloudConfig")
            .field("endpoint", &self.endpoint)
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .field("namespace", &self.namespace)
            .field("database", &self.database)
            .field("is_root", &self.is_root)
            .finish()
    }
}

impl CloudConfig {
    /// Namespace used when the caller or environment does not supply one.
    const DEFAULT_NAMESPACE: &'static str = "aivcs";
    /// Database name used when the caller or environment does not supply one.
    const DEFAULT_DATABASE: &'static str = "main";

    /// Create a new cloud configuration for a database user
    ///
    /// Namespace and database start at their defaults ("aivcs" / "main") and
    /// `is_root` is `false`; use the `with_*` builders to override them.
    pub fn new(
        endpoint: impl Into<String>,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        Self {
            endpoint: endpoint.into(),
            username: username.into(),
            password: password.into(),
            namespace: Self::DEFAULT_NAMESPACE.to_string(),
            database: Self::DEFAULT_DATABASE.to_string(),
            is_root: false,
        }
    }

    /// Create a new cloud configuration for a root user
    ///
    /// Identical to [`CloudConfig::new`] except that `is_root` is `true`.
    pub fn new_root(
        endpoint: impl Into<String>,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        Self::new(endpoint, username, password).with_root(true)
    }

    /// Set custom namespace
    pub fn with_namespace(mut self, ns: impl Into<String>) -> Self {
        self.namespace = ns.into();
        self
    }

    /// Set custom database
    pub fn with_database(mut self, db: impl Into<String>) -> Self {
        self.database = db.into();
        self
    }

    /// Set whether this is a root user
    pub fn with_root(mut self, is_root: bool) -> Self {
        self.is_root = is_root;
        self
    }

    /// Create from environment variables
    ///
    /// Reads:
    /// - SURREALDB_ENDPOINT (required)
    /// - SURREALDB_USERNAME (required)
    /// - SURREALDB_PASSWORD (required)
    /// - SURREALDB_NAMESPACE (optional, default: "aivcs")
    /// - SURREALDB_DATABASE (optional, default: "main")
    /// - SURREALDB_ROOT (optional, default: "false") - set to "true" for root users
    ///   (case-insensitive, surrounding whitespace ignored)
    ///
    /// # Errors
    /// Returns `"<NAME> not set"` for the first required variable that cannot
    /// be read.
    pub fn from_env() -> std::result::Result<Self, String> {
        let endpoint = Self::required_env("SURREALDB_ENDPOINT")?;
        let username = Self::required_env("SURREALDB_USERNAME")?;
        let password = Self::required_env("SURREALDB_PASSWORD")?;
        let namespace = std::env::var("SURREALDB_NAMESPACE")
            .unwrap_or_else(|_| Self::DEFAULT_NAMESPACE.to_string());
        let database = std::env::var("SURREALDB_DATABASE")
            .unwrap_or_else(|_| Self::DEFAULT_DATABASE.to_string());
        // Compare without allocating a lowercased copy; tolerate stray whitespace
        // such as a trailing newline from a shell or .env file.
        let is_root = std::env::var("SURREALDB_ROOT")
            .map(|v| v.trim().eq_ignore_ascii_case("true"))
            .unwrap_or(false);

        Ok(Self {
            endpoint,
            username,
            password,
            namespace,
            database,
            is_root,
        })
    }

    /// Read a mandatory environment variable, mapping any failure to the
    /// message `"<name> not set"`.
    fn required_env(name: &str) -> std::result::Result<String, String> {
        std::env::var(name).map_err(|_| format!("{} not set", name))
    }
}
129
130/// SurrealDB connection handle for AIVCS
131#[derive(Clone)]
132pub struct SurrealHandle {
133    db: Surreal<Any>,
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
137struct DbReleaseRecord {
138    name: String,
139    spec_digest: ContentDigest,
140    metadata: ReleaseMetadata,
141    version_label: Option<String>,
142    promoted_by: String,
143    notes: Option<String>,
144    created_at: SurrealDatetime,
145}
146
147impl DbReleaseRecord {
148    fn into_release_record(self) -> ReleaseRecord {
149        ReleaseRecord {
150            name: self.name,
151            spec_digest: self.spec_digest,
152            metadata: self.metadata,
153            created_at: DateTime::<Utc>::from(self.created_at),
154        }
155    }
156}
157
158impl SurrealHandle {
159    /// Connect to SurrealDB in-memory and set up schema
160    ///
161    /// # TDD: test_surreal_connection_and_schema_creation
162    #[instrument(skip_all)]
163    pub async fn setup_db() -> Result<Self> {
164        info!("Connecting to SurrealDB (in-memory)");
165
166        let db = surrealdb::engine::any::connect("mem://")
167            .await
168            .map_err(|e| StateError::Connection(e.to_string()))?;
169
170        // Select namespace and database
171        db.use_ns("aivcs")
172            .use_db("main")
173            .await
174            .map_err(|e| StateError::Connection(e.to_string()))?;
175
176        let handle = SurrealHandle { db };
177        handle.init_schema().await?;
178
179        info!("SurrealDB connected and schema initialized");
180        Ok(handle)
181    }
182
183    /// Connect to SurrealDB Cloud
184    ///
185    /// # Example
186    /// ```ignore
187    /// let config = CloudConfig::new(
188    ///     "wss://xxx.aws-use1.surrealdb.cloud",
189    ///     "your_username",
190    ///     "your_password",
191    /// );
192    /// let handle = SurrealHandle::setup_cloud(config).await?;
193    /// ```
194    #[instrument(skip(config), fields(endpoint = %config.endpoint, namespace = %config.namespace, database = %config.database))]
195    pub async fn setup_cloud(config: CloudConfig) -> Result<Self> {
196        info!("Connecting to SurrealDB Cloud (root={})", config.is_root);
197
198        let db = surrealdb::engine::any::connect(&config.endpoint)
199            .await
200            .map_err(|e| {
201                StateError::Connection(format!("Failed to connect to {}: {}", config.endpoint, e))
202            })?;
203
204        // Authenticate based on user type
205        if config.is_root {
206            // Root user authentication
207            db.signin(Root {
208                username: &config.username,
209                password: &config.password,
210            })
211            .await
212            .map_err(|e| StateError::Connection(format!("Root authentication failed: {}", e)))?;
213        } else {
214            // Database user authentication - requires namespace and database
215            db.signin(Database {
216                namespace: &config.namespace,
217                database: &config.database,
218                username: &config.username,
219                password: &config.password,
220            })
221            .await
222            .map_err(|e| {
223                StateError::Connection(format!("Database authentication failed: {}", e))
224            })?;
225        }
226
227        // Select namespace and database
228        db.use_ns(&config.namespace)
229            .use_db(&config.database)
230            .await
231            .map_err(|e| {
232                StateError::Connection(format!("Failed to select namespace/database: {}", e))
233            })?;
234
235        let handle = SurrealHandle { db };
236        handle.init_schema().await?;
237
238        info!("SurrealDB Cloud connected and schema initialized");
239        Ok(handle)
240    }
241
242    /// Connect using environment variables
243    ///
244    /// If SURREALDB_ENDPOINT is set, connects to cloud.
245    /// If SURREALDB_URL is set, connects to that URL.
246    /// Otherwise, falls back to local persistence in `.aivcs/db` using SurrealKV.
247    #[instrument(skip_all)]
248    pub async fn setup_from_env() -> Result<Self> {
249        if let Ok(config) = CloudConfig::from_env() {
250            info!("Cloud config found, connecting to SurrealDB Cloud");
251            return Self::setup_cloud(config).await;
252        }
253
254        let url = if let Ok(url) = std::env::var("SURREALDB_URL") {
255            info!("SURREALDB_URL found, connecting to {}", url);
256            url
257        } else {
258            // Default to local persistence in .aivcs/db
259            let path = ".aivcs/db";
260            std::fs::create_dir_all(path).map_err(|e| {
261                StateError::Connection(format!(
262                    "Failed to create database directory {}: {}",
263                    path, e
264                ))
265            })?;
266            let url = format!("surrealkv://{}", path);
267            info!(
268                "No cloud config or SURREALDB_URL found, using local persistence: {}",
269                url
270            );
271            url
272        };
273
274        let db = surrealdb::engine::any::connect(&url)
275            .await
276            .map_err(|e| StateError::Connection(format!("Failed to connect to {}: {}", url, e)))?;
277
278        db.use_ns("aivcs")
279            .use_db("main")
280            .await
281            .map_err(|e| StateError::Connection(e.to_string()))?;
282
283        let handle = SurrealHandle { db };
284        handle.init_schema().await?;
285        Ok(handle)
286    }
287
288    /// Initialize the database schema
289    async fn init_schema(&self) -> Result<()> {
290        crate::migrations::init_schema(&self.db).await?;
291        Ok(())
292    }
293
294    // ========== Commit Operations ==========
295
296    /// Save a new commit record
297    #[instrument(skip(self, record), fields(commit_id = %record.commit_id))]
298    pub async fn save_commit(&self, record: &CommitRecord) -> Result<CommitRecord> {
299        debug!("Saving commit");
300
301        // Clone to owned value to satisfy SurrealDB lifetime requirements
302        let record_owned = record.clone();
303
304        let created: Option<CommitRecord> = self.db.create("commits").content(record_owned).await?;
305
306        created.ok_or_else(|| StateError::Transaction("Failed to create commit".to_string()))
307    }
308
309    /// Get a commit by its hash
310    #[instrument(skip(self))]
311    pub async fn get_commit(&self, commit_hash: &str) -> Result<Option<CommitRecord>> {
312        debug!("Getting commit");
313
314        let hash_owned = commit_hash.to_string();
315
316        let mut result = self
317            .db
318            .query("SELECT * FROM commits WHERE commit_id.hash = $hash")
319            .bind(("hash", hash_owned))
320            .await?;
321
322        let commits: Vec<CommitRecord> = result.take(0)?;
323        Ok(commits.into_iter().next())
324    }
325
326    // ========== Snapshot Operations ==========
327
328    /// Save a snapshot (agent state)
329    ///
330    /// # TDD: test_snapshot_is_atomic_and_retrievable
331    #[instrument(skip(self, commit_id, state))]
332    pub async fn save_snapshot(
333        &self,
334        commit_id: &CommitId,
335        state: serde_json::Value,
336    ) -> Result<()> {
337        debug!("Saving snapshot for commit {}", commit_id.short());
338
339        let record = SnapshotRecord::new(&commit_id.hash, state);
340
341        let _created: Option<SnapshotRecord> =
342            self.db.create("snapshots").content(record.clone()).await?;
343
344        info!(
345            "Snapshot saved: {} ({} bytes)",
346            commit_id.short(),
347            record.size_bytes
348        );
349        Ok(())
350    }
351
352    /// Load a snapshot by commit ID
353    #[instrument(skip(self))]
354    pub async fn load_snapshot(&self, commit_id: &str) -> Result<SnapshotRecord> {
355        debug!("Loading snapshot");
356
357        let id_owned = commit_id.to_string();
358
359        let mut result = self
360            .db
361            .query("SELECT * FROM snapshots WHERE commit_id = $id")
362            .bind(("id", id_owned))
363            .await?;
364
365        let snapshots: Vec<SnapshotRecord> = result.take(0)?;
366        snapshots
367            .into_iter()
368            .next()
369            .ok_or_else(|| StateError::CommitNotFound(commit_id.to_string()))
370    }
371
372    // ========== Graph Edge Operations ==========
373
374    /// Save a commit graph edge (parent -> child relationship)
375    ///
376    /// # TDD: test_parent_child_edge_is_created
377    #[instrument(skip(self))]
378    pub async fn save_commit_graph_edge(&self, child_id: &str, parent_id: &str) -> Result<()> {
379        debug!("Saving graph edge: {} -> {}", parent_id, child_id);
380
381        let edge = GraphEdge::new(child_id, parent_id);
382
383        let _created: Option<GraphEdge> = self.db.create("graph_edges").content(edge).await?;
384
385        info!("Graph edge saved: {} -> {}", parent_id, child_id);
386        Ok(())
387    }
388
389    /// Get parent commit ID for a given commit
390    #[instrument(skip(self))]
391    pub async fn get_parent(&self, child_id: &str) -> Result<Option<String>> {
392        let id_owned = child_id.to_string();
393
394        let mut result = self
395            .db
396            .query("SELECT parent_id FROM graph_edges WHERE child_id = $id")
397            .bind(("id", id_owned))
398            .await?;
399
400        #[derive(serde::Deserialize)]
401        struct ParentResult {
402            parent_id: String,
403        }
404
405        let parents: Vec<ParentResult> = result.take(0)?;
406        Ok(parents.into_iter().next().map(|p| p.parent_id))
407    }
408
409    /// Get all children of a commit (for branch visualization)
410    #[instrument(skip(self))]
411    pub async fn get_children(&self, parent_id: &str) -> Result<Vec<String>> {
412        let id_owned = parent_id.to_string();
413
414        let mut result = self
415            .db
416            .query("SELECT child_id FROM graph_edges WHERE parent_id = $id")
417            .bind(("id", id_owned))
418            .await?;
419
420        #[derive(serde::Deserialize)]
421        struct ChildResult {
422            child_id: String,
423        }
424
425        let children: Vec<ChildResult> = result.take(0)?;
426        Ok(children.into_iter().map(|c| c.child_id).collect())
427    }
428
429    // ========== Branch Operations ==========
430
431    /// Create or update a branch
432    #[instrument(skip(self))]
433    pub async fn save_branch(&self, record: &BranchRecord) -> Result<BranchRecord> {
434        debug!("Saving branch: {}", record.name);
435
436        // Check if branch exists
437        let existing = self.get_branch(&record.name).await?;
438
439        if existing.is_some() {
440            // Update existing branch
441            let head = record.head_commit_id.clone();
442            let now = SurrealDatetime::from(chrono::Utc::now());
443            let name = record.name.clone();
444
445            let mut result = self
446                .db
447                .query("UPDATE branches SET head_commit_id = $head, updated_at = $now WHERE name = $name")
448                .bind(("head", head))
449                .bind(("now", now))
450                .bind(("name", name))
451                .await?;
452
453            let updated: Vec<BranchRecord> = result.take(0)?;
454            updated
455                .into_iter()
456                .next()
457                .ok_or_else(|| StateError::Transaction("Failed to update branch".to_string()))
458        } else {
459            // Create new branch - clone to owned
460            let record_owned = record.clone();
461
462            let created: Option<BranchRecord> =
463                self.db.create("branches").content(record_owned).await?;
464
465            created.ok_or_else(|| StateError::Transaction("Failed to create branch".to_string()))
466        }
467    }
468
469    /// Get a branch by name
470    #[instrument(skip(self))]
471    pub async fn get_branch(&self, name: &str) -> Result<Option<BranchRecord>> {
472        let name_owned = name.to_string();
473
474        let mut result = self
475            .db
476            .query("SELECT * FROM branches WHERE name = $name")
477            .bind(("name", name_owned))
478            .await?;
479
480        let branches: Vec<BranchRecord> = result.take(0)?;
481        Ok(branches.into_iter().next())
482    }
483
484    /// Get branch head commit ID
485    #[instrument(skip(self))]
486    pub async fn get_branch_head(&self, branch_name: &str) -> Result<String> {
487        let branch = self
488            .get_branch(branch_name)
489            .await?
490            .ok_or_else(|| StateError::BranchNotFound(branch_name.to_string()))?;
491
492        Ok(branch.head_commit_id)
493    }
494
495    /// List all branches
496    #[instrument(skip(self))]
497    pub async fn list_branches(&self) -> Result<Vec<BranchRecord>> {
498        let mut result = self
499            .db
500            .query("SELECT * FROM branches ORDER BY name")
501            .await?;
502
503        let branches: Vec<BranchRecord> = result.take(0)?;
504        Ok(branches)
505    }
506
507    /// Delete a branch
508    #[instrument(skip(self))]
509    pub async fn delete_branch(&self, name: &str) -> Result<()> {
510        debug!("Deleting branch: {}", name);
511
512        let branch = self
513            .get_branch(name)
514            .await?
515            .ok_or_else(|| StateError::BranchNotFound(name.to_string()))?;
516
517        if branch.is_default {
518            return Err(StateError::Transaction(
519                "Cannot delete the default branch".to_string(),
520            ));
521        }
522
523        let name_owned = name.to_string();
524
525        let _result = self
526            .db
527            .query("DELETE FROM branches WHERE name = $name")
528            .bind(("name", name_owned))
529            .await?;
530
531        Ok(())
532    }
533
534    // ========== Agent Operations ==========
535
536    /// Register an agent
537    #[instrument(skip(self, record), fields(agent_name = %record.name))]
538    pub async fn register_agent(&self, record: &AgentRecord) -> Result<AgentRecord> {
539        debug!("Registering agent");
540
541        let record_owned = record.clone();
542
543        let created: Option<AgentRecord> = self.db.create("agents").content(record_owned).await?;
544
545        created.ok_or_else(|| StateError::Transaction("Failed to register agent".to_string()))
546    }
547
548    /// Get agent by ID
549    #[instrument(skip(self))]
550    pub async fn get_agent(&self, agent_id: &str) -> Result<Option<AgentRecord>> {
551        let id_owned = agent_id.to_string();
552
553        let mut result = self
554            .db
555            .query("SELECT * FROM agents WHERE agent_id = $id")
556            .bind(("id", id_owned))
557            .await?;
558
559        let agents: Vec<AgentRecord> = result.take(0)?;
560        Ok(agents.into_iter().next())
561    }
562
563    // ========== Memory Operations ==========
564
565    /// Save a memory record
566    #[instrument(skip(self, record), fields(key = %record.key))]
567    pub async fn save_memory(&self, record: &MemoryRecord) -> Result<MemoryRecord> {
568        debug!("Saving memory");
569
570        let record_owned = record.clone();
571
572        let created: Option<MemoryRecord> =
573            self.db.create("memories").content(record_owned).await?;
574
575        created.ok_or_else(|| StateError::Transaction("Failed to save memory".to_string()))
576    }
577
578    /// Get all memories for a commit
579    #[instrument(skip(self))]
580    pub async fn get_memories(&self, commit_id: &str) -> Result<Vec<MemoryRecord>> {
581        let id_owned = commit_id.to_string();
582
583        let mut result = self
584            .db
585            .query("SELECT * FROM memories WHERE commit_id = $id ORDER BY created_at")
586            .bind(("id", id_owned))
587            .await?;
588
589        let memories: Vec<MemoryRecord> = result.take(0)?;
590        Ok(memories)
591    }
592
593    // ========== Release Registry Operations ==========
594
595    /// Promote a new release for an agent.
596    #[instrument(skip(self, spec_digest, metadata), fields(name = %name, digest = %spec_digest))]
597    pub async fn release_promote(
598        &self,
599        name: &str,
600        spec_digest: &ContentDigest,
601        metadata: ReleaseMetadata,
602    ) -> StorageResult<ReleaseRecord> {
603        let record = DbReleaseRecord {
604            name: name.to_string(),
605            spec_digest: spec_digest.clone(),
606            version_label: metadata.version_label.clone(),
607            promoted_by: metadata.promoted_by.clone(),
608            notes: metadata.notes.clone(),
609            metadata,
610            created_at: SurrealDatetime::from(Utc::now()),
611        };
612
613        let created: Option<DbReleaseRecord> = self
614            .db
615            .create("releases")
616            .content(record.clone())
617            .await
618            .map_err(|e| StorageError::Backend(e.to_string()))?;
619
620        created
621            .map(DbReleaseRecord::into_release_record)
622            .ok_or_else(|| StorageError::Backend("failed to create release record".to_string()))
623    }
624
625    /// Roll back to the previous release for an agent by re-appending it.
626    #[instrument(skip(self), fields(name = %name))]
627    pub async fn release_rollback(&self, name: &str) -> StorageResult<ReleaseRecord> {
628        let history = self.release_history(name).await?;
629        if history.is_empty() {
630            return Err(StorageError::ReleaseNotFound {
631                name: name.to_string(),
632            });
633        }
634        if history.len() < 2 {
635            return Err(StorageError::NoPreviousRelease {
636                name: name.to_string(),
637            });
638        }
639
640        let previous = &history[1];
641        self.release_promote(name, &previous.spec_digest, previous.metadata.clone())
642            .await
643    }
644
645    /// Get the current release (most recent) for an agent.
646    #[instrument(skip(self), fields(name = %name))]
647    pub async fn release_current(&self, name: &str) -> StorageResult<Option<ReleaseRecord>> {
648        let name_owned = name.to_string();
649
650        let mut result = self
651            .db
652            .query("SELECT * FROM releases WHERE name = $name ORDER BY created_at DESC LIMIT 1")
653            .bind(("name", name_owned))
654            .await
655            .map_err(|e| StorageError::Backend(e.to_string()))?;
656
657        let releases: Vec<DbReleaseRecord> = result
658            .take(0)
659            .map_err(|e| StorageError::Backend(e.to_string()))?;
660        Ok(releases
661            .into_iter()
662            .next()
663            .map(DbReleaseRecord::into_release_record))
664    }
665
666    /// Get full release history (newest first) for an agent.
667    #[instrument(skip(self), fields(name = %name))]
668    pub async fn release_history(&self, name: &str) -> StorageResult<Vec<ReleaseRecord>> {
669        let name_owned = name.to_string();
670
671        let mut result = self
672            .db
673            .query("SELECT * FROM releases WHERE name = $name ORDER BY created_at DESC")
674            .bind(("name", name_owned))
675            .await
676            .map_err(|e| StorageError::Backend(e.to_string()))?;
677
678        let releases: Vec<DbReleaseRecord> = result
679            .take(0)
680            .map_err(|e| StorageError::Backend(e.to_string()))?;
681        Ok(releases
682            .into_iter()
683            .map(DbReleaseRecord::into_release_record)
684            .collect())
685    }
686
687    // ========== CI Operations ==========
688
689    /// Save a CI snapshot as a content-addressed object.
690    #[instrument(skip(self, snapshot))]
691    pub async fn save_ci_snapshot(&self, snapshot: &CiSnapshot) -> Result<String> {
692        #[derive(serde::Serialize, serde::Deserialize)]
693        struct CiSnapshotStore {
694            digest: String,
695            snapshot_json: String,
696        }
697
698        let digest = snapshot.digest();
699        let snapshot_json = serde_json::to_string(snapshot)?;
700        let payload = CiSnapshotStore {
701            digest: digest.clone(),
702            snapshot_json,
703        };
704
705        let _created: Option<CiSnapshotStore> =
706            self.db.create("ci_snapshots").content(payload).await?;
707        Ok(digest)
708    }
709
710    /// Load a CI snapshot by digest.
711    #[instrument(skip(self))]
712    pub async fn load_ci_snapshot(&self, digest: &str) -> Result<Option<CiSnapshot>> {
713        #[derive(serde::Deserialize)]
714        struct CiSnapshotStore {
715            snapshot_json: String,
716        }
717
718        let digest_owned = digest.to_string();
719        let mut result = self
720            .db
721            .query("SELECT snapshot_json FROM ci_snapshots WHERE digest = $digest")
722            .bind(("digest", digest_owned))
723            .await?;
724
725        let rows: Vec<CiSnapshotStore> = result.take(0)?;
726        rows.into_iter()
727            .next()
728            .map(|r| serde_json::from_str(&r.snapshot_json))
729            .transpose()
730            .map_err(Into::into)
731    }
732
733    /// Save a CI pipeline as a content-addressed object.
734    #[instrument(skip(self, pipeline))]
735    pub async fn save_ci_pipeline(&self, pipeline: &CiPipelineSpec) -> Result<String> {
736        #[derive(serde::Serialize, serde::Deserialize)]
737        struct CiPipelineStore {
738            digest: String,
739            pipeline_json: String,
740        }
741
742        let digest = pipeline.digest();
743        let pipeline_json = serde_json::to_string(pipeline)?;
744        let payload = CiPipelineStore {
745            digest: digest.clone(),
746            pipeline_json,
747        };
748
749        let _created: Option<CiPipelineStore> =
750            self.db.create("ci_pipelines").content(payload).await?;
751        Ok(digest)
752    }
753
754    /// Load a CI pipeline by digest.
755    #[instrument(skip(self))]
756    pub async fn load_ci_pipeline(&self, digest: &str) -> Result<Option<CiPipelineSpec>> {
757        #[derive(serde::Deserialize)]
758        struct CiPipelineStore {
759            pipeline_json: String,
760        }
761
762        let digest_owned = digest.to_string();
763        let mut result = self
764            .db
765            .query("SELECT pipeline_json FROM ci_pipelines WHERE digest = $digest")
766            .bind(("digest", digest_owned))
767            .await?;
768
769        let rows: Vec<CiPipelineStore> = result.take(0)?;
770        rows.into_iter()
771            .next()
772            .map(|r| serde_json::from_str(&r.pipeline_json))
773            .transpose()
774            .map_err(Into::into)
775    }
776
777    /// Save a CI run record.
778    #[instrument(skip(self, run), fields(run_id = %run.run_id))]
779    pub async fn save_ci_run(&self, run: &CiRunRecord) -> Result<CiRunRecord> {
780        #[derive(serde::Serialize, serde::Deserialize)]
781        struct CiRunStore {
782            run_id: String,
783            snapshot_digest: String,
784            pipeline_digest: String,
785            status: String,
786            run_json: String,
787            started_at: Option<String>,
788            finished_at: Option<String>,
789        }
790
791        let payload = CiRunStore {
792            run_id: run.run_id.clone(),
793            snapshot_digest: run.snapshot_digest.clone(),
794            pipeline_digest: run.pipeline_digest.clone(),
795            status: serde_json::to_string(&run.status)?
796                .trim_matches('"')
797                .to_string(),
798            run_json: serde_json::to_string(run)?,
799            started_at: run.started_at.clone(),
800            finished_at: run.finished_at.clone(),
801        };
802
803        let created: Option<CiRunStore> = self.db.create("ci_runs").content(payload).await?;
804        if created.is_some() {
805            Ok(run.clone())
806        } else {
807            Err(StateError::Transaction(
808                "Failed to create CI run".to_string(),
809            ))
810        }
811    }
812
813    /// Get a CI run by run ID.
814    #[instrument(skip(self))]
815    pub async fn get_ci_run(&self, run_id: &str) -> Result<Option<CiRunRecord>> {
816        #[derive(serde::Deserialize)]
817        struct CiRunStore {
818            run_json: String,
819        }
820
821        let run_id_owned = run_id.to_string();
822        let mut result = self
823            .db
824            .query("SELECT run_json FROM ci_runs WHERE run_id = $run_id")
825            .bind(("run_id", run_id_owned))
826            .await?;
827        let runs: Vec<CiRunStore> = result.take(0)?;
828        runs.into_iter()
829            .next()
830            .map(|r| serde_json::from_str(&r.run_json))
831            .transpose()
832            .map_err(Into::into)
833    }
834
835    /// List CI runs for a given snapshot digest.
836    #[instrument(skip(self))]
837    pub async fn list_ci_runs_by_snapshot(
838        &self,
839        snapshot_digest: &str,
840    ) -> Result<Vec<CiRunRecord>> {
841        #[derive(serde::Deserialize)]
842        struct CiRunStore {
843            run_json: String,
844        }
845
846        let snapshot_digest_owned = snapshot_digest.to_string();
847        let mut result = self
848            .db
849            .query("SELECT run_json FROM ci_runs WHERE snapshot_digest = $snapshot_digest")
850            .bind(("snapshot_digest", snapshot_digest_owned))
851            .await?;
852        let runs: Vec<CiRunStore> = result.take(0)?;
853        runs.into_iter()
854            .map(|r| serde_json::from_str::<CiRunRecord>(&r.run_json))
855            .collect::<std::result::Result<Vec<_>, _>>()
856            .map_err(Into::into)
857    }
858
859    // ========== Decision and Provenance Operations (EPIC5) ==========
860
861    /// Save a decision record
862    #[instrument(skip(self, record))]
863    pub async fn save_decision(&self, record: &DecisionRecord) -> Result<DecisionRecord> {
864        debug!("Saving decision");
865
866        let record_owned = record.clone();
867
868        let created: Option<DecisionRecord> =
869            self.db.create("decisions").content(record_owned).await?;
870
871        created.ok_or_else(|| StateError::Transaction("Failed to save decision".to_string()))
872    }
873
874    /// Get decision by decision ID
875    #[instrument(skip(self))]
876    pub async fn get_decision(&self, decision_id: &str) -> Result<Option<DecisionRecord>> {
877        let id_owned = decision_id.to_string();
878
879        let mut result = self
880            .db
881            .query("SELECT * FROM decisions WHERE decision_id = $id")
882            .bind(("id", id_owned))
883            .await?;
884
885        let decisions: Vec<DecisionRecord> = result.take(0)?;
886        Ok(decisions.into_iter().next())
887    }
888
889    /// Update decision outcome by decision ID
890    #[instrument(skip(self))]
891    pub async fn update_decision_outcome(
892        &self,
893        decision_id: &str,
894        outcome_json: String,
895    ) -> Result<DecisionRecord> {
896        let id_owned = decision_id.to_string();
897        let now = SurrealDatetime::from(Utc::now());
898
899        let mut result = self
900            .db
901            .query(
902                "UPDATE decisions SET outcome = $outcome, outcome_at = $outcome_at WHERE decision_id = $id RETURN AFTER",
903            )
904            .bind(("id", id_owned))
905            .bind(("outcome", outcome_json))
906            .bind(("outcome_at", now))
907            .await?;
908
909        let decisions: Vec<DecisionRecord> = result.take(0)?;
910        decisions
911            .into_iter()
912            .next()
913            .ok_or_else(|| StateError::Transaction("Decision not found for update".to_string()))
914    }
915
916    /// Get decision history for a task
917    #[instrument(skip(self))]
918    pub async fn get_decision_history(
919        &self,
920        task: &str,
921        limit: usize,
922    ) -> Result<Vec<DecisionRecord>> {
923        let task_owned = task.to_string();
924
925        let mut result = self
926            .db
927            .query(
928                "SELECT * FROM decisions WHERE task = $task ORDER BY timestamp DESC LIMIT $limit",
929            )
930            .bind(("task", task_owned))
931            .bind(("limit", limit as i64))
932            .await?;
933
934        let decisions: Vec<DecisionRecord> = result.take(0)?;
935        Ok(decisions)
936    }
937
938    /// Save a memory provenance record
939    #[instrument(skip(self, record))]
940    pub async fn save_provenance(
941        &self,
942        record: &MemoryProvenanceRecord,
943    ) -> Result<MemoryProvenanceRecord> {
944        debug!("Saving memory provenance");
945
946        let record_owned = record.clone();
947
948        let created: Option<MemoryProvenanceRecord> = self
949            .db
950            .create("memory_provenances")
951            .content(record_owned)
952            .await?;
953
954        created.ok_or_else(|| StateError::Transaction("Failed to save provenance".to_string()))
955    }
956
957    /// Get provenance records for a memory
958    #[instrument(skip(self))]
959    pub async fn get_provenance(&self, memory_id: &str) -> Result<Vec<MemoryProvenanceRecord>> {
960        let memory_id_owned = memory_id.to_string();
961
962        let mut result = self
963            .db
964            .query("SELECT * FROM memory_provenances WHERE memory_id = $memory_id")
965            .bind(("memory_id", memory_id_owned))
966            .await?;
967
968        let provenances: Vec<MemoryProvenanceRecord> = result.take(0)?;
969        Ok(provenances)
970    }
971
972    // ========== History Operations ==========
973
974    /// Get commit history (walk back from a commit)
975    #[instrument(skip(self))]
976    pub async fn get_commit_history(
977        &self,
978        start_commit: &str,
979        limit: usize,
980    ) -> Result<Vec<CommitRecord>> {
981        let mut history = Vec::new();
982        let mut current = Some(start_commit.to_string());
983
984        while let Some(commit_hash) = current {
985            if history.len() >= limit {
986                break;
987            }
988
989            if let Some(commit) = self.get_commit(&commit_hash).await? {
990                // For linear history, we follow the first parent
991                current = commit.parent_ids.first().cloned();
992                history.push(commit);
993            } else {
994                break;
995            }
996        }
997
998        Ok(history)
999    }
1000
1001    /// Get the reasoning trace (CoT) for time-travel debugging
1002    ///
1003    /// # TDD: test_get_trace_for_commit_id_returns_correct_CoT
1004    #[instrument(skip(self))]
1005    pub async fn get_reasoning_trace(&self, commit_id: &str) -> Result<Vec<SnapshotRecord>> {
1006        // Get commit history
1007        let history = self.get_commit_history(commit_id, 100).await?;
1008
1009        // Load snapshots for each commit
1010        let mut trace = Vec::new();
1011        for commit in history {
1012            if let Ok(snapshot) = self.load_snapshot(&commit.commit_id.hash).await {
1013                trace.push(snapshot);
1014            }
1015        }
1016
1017        Ok(trace)
1018    }
1019}
1020
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    /// `setup_db` must return a usable handle.
    #[tokio::test]
    async fn test_surreal_connection_and_schema_creation() {
        let connection = SurrealHandle::setup_db().await;
        assert!(
            connection.is_ok(),
            "Failed to connect: {:?}",
            connection.err()
        );
    }

    /// A saved non-default branch can be deleted and is gone afterwards.
    #[tokio::test]
    async fn test_branch_deletion() {
        let db = SurrealHandle::setup_db().await.unwrap();
        let branch_name = "feature/test";

        // Create the branch and confirm it is visible.
        db.save_branch(&BranchRecord::new(branch_name, "commit-123", false))
            .await
            .unwrap();
        let before = db.get_branch(branch_name).await.unwrap();
        assert!(before.is_some());

        // Delete it and confirm it can no longer be found.
        db.delete_branch(branch_name).await.unwrap();
        let after = db.get_branch(branch_name).await.unwrap();
        assert!(after.is_none());
    }

    /// Deleting a branch that was never saved reports "Branch not found".
    #[tokio::test]
    async fn test_delete_nonexistent_branch() {
        let db = SurrealHandle::setup_db().await.unwrap();

        let outcome = db.delete_branch("nonexistent").await;
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Branch not found"));
    }

    /// The default branch is protected from deletion.
    #[tokio::test]
    async fn test_delete_default_branch() {
        let db = SurrealHandle::setup_db().await.unwrap();

        // Save a branch flagged as the default one.
        db.save_branch(&BranchRecord::new("main", "commit-123", true))
            .await
            .unwrap();

        let outcome = db.delete_branch("main").await;
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Cannot delete the default branch"));
    }

    /// Updating a decision outcome stores both the outcome and its timestamp.
    #[tokio::test]
    async fn test_update_decision_outcome_persists_fields() {
        let db = SurrealHandle::setup_db().await.unwrap();
        let outcome_json = r#"{"status":"success","duration_ms":123}"#;

        let seed = DecisionRecord::new(
            "dec-outcome-1".to_string(),
            "commit-123".to_string(),
            "task-123".to_string(),
            "action-123".to_string(),
            "because".to_string(),
            0.9,
        );
        db.save_decision(&seed).await.unwrap();

        let refreshed = db
            .update_decision_outcome("dec-outcome-1", outcome_json.to_string())
            .await
            .unwrap();

        assert_eq!(refreshed.outcome, Some(outcome_json.to_string()));
        assert!(refreshed.outcome_at.is_some());
    }

    /// A saved snapshot loads back with the same commit id and state.
    #[tokio::test]
    async fn test_snapshot_is_atomic_and_retrievable() {
        let db = SurrealHandle::setup_db().await.unwrap();

        let state = serde_json::json!({
            "agent_name": "test-agent",
            "step": 1,
            "variables": {"x": 42, "y": "hello"}
        });

        // The commit id is derived from the serialized state bytes.
        let state_bytes = serde_json::to_vec(&state).unwrap();
        let commit_id = CommitId::from_state(state_bytes.as_slice());

        db.save_snapshot(&commit_id, state.clone()).await.unwrap();

        let restored = db.load_snapshot(&commit_id.hash).await.unwrap();

        assert_eq!(restored.commit_id, commit_id.hash);
        assert_eq!(restored.state, state);
    }

    /// A saved graph edge is visible from both the child and the parent side.
    #[tokio::test]
    async fn test_parent_child_edge_is_created() {
        let db = SurrealHandle::setup_db().await.unwrap();

        let parent_hash = "parent-commit-hash";
        let child_hash = "child-commit-hash";

        db.save_commit_graph_edge(child_hash, parent_hash)
            .await
            .unwrap();

        // Child -> parent lookup.
        let found_parent = db.get_parent(child_hash).await.unwrap();
        assert_eq!(found_parent, Some(parent_hash.to_string()));

        // Parent -> children lookup.
        let found_children = db.get_children(parent_hash).await.unwrap();
        assert!(found_children.contains(&child_hash.to_string()));
    }

    /// Saving a branch twice under the same name moves its head.
    #[tokio::test]
    async fn test_branch_operations() {
        let db = SurrealHandle::setup_db().await.unwrap();

        db.save_branch(&BranchRecord::new("main", "commit-abc123", true))
            .await
            .unwrap();

        // Full record lookup.
        let fetched = db.get_branch("main").await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().head_commit_id, "commit-abc123");

        // Head-only lookup.
        let initial_head = db.get_branch_head("main").await.unwrap();
        assert_eq!(initial_head, "commit-abc123");

        // Saving again with a new head commit updates the branch.
        db.save_branch(&BranchRecord::new("main", "commit-def456", true))
            .await
            .unwrap();

        let moved_head = db.get_branch_head("main").await.unwrap();
        assert_eq!(moved_head, "commit-def456");
    }

    /// A commit record can be saved and then fetched by its hash.
    #[tokio::test]
    async fn test_commit_record_operations() {
        let db = SurrealHandle::setup_db().await.unwrap();

        let commit_id = CommitId::from_state(b"test state");
        let record = CommitRecord::new(commit_id.clone(), vec![], "Initial commit", "test-agent");

        let stored = db.save_commit(&record).await.unwrap();
        assert_eq!(stored.commit_id.hash, commit_id.hash);

        let fetched = db.get_commit(&commit_id.hash).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().message, "Initial commit");
    }

    /// The reasoning trace of the newest commit in a four-commit chain returns
    /// all four snapshots, newest first, with their thoughts intact.
    #[tokio::test]
    async fn test_get_trace_for_commit_id_returns_correct_cot() {
        let db = SurrealHandle::setup_db().await.unwrap();

        // Chain under test: initial -> step1 -> step2 -> step3
        let id_0 = CommitId::from_state(b"state-0");
        let id_1 = CommitId::from_state(b"state-1");
        let id_2 = CommitId::from_state(b"state-2");
        let id_3 = CommitId::from_state(b"state-3");

        let snapshots = vec![
            (
                id_0.clone(),
                serde_json::json!({"step": 0, "thought": "Starting exploration"}),
            ),
            (
                id_1.clone(),
                serde_json::json!({"step": 1, "thought": "Trying strategy A"}),
            ),
            (
                id_2.clone(),
                serde_json::json!({"step": 2, "thought": "Strategy A failed, pivoting"}),
            ),
            (
                id_3.clone(),
                serde_json::json!({"step": 3, "thought": "Strategy B succeeded"}),
            ),
        ];
        for (id, state) in &snapshots {
            db.save_snapshot(id, state.clone()).await.unwrap();
        }

        // Each commit names the previous one as its only parent.
        let commits = vec![
            CommitRecord::new(id_0.clone(), vec![], "Step 0", "agent"),
            CommitRecord::new(id_1.clone(), vec![id_0.hash.clone()], "Step 1", "agent"),
            CommitRecord::new(id_2.clone(), vec![id_1.hash.clone()], "Step 2", "agent"),
            CommitRecord::new(id_3.clone(), vec![id_2.hash.clone()], "Step 3", "agent"),
        ];
        for commit in &commits {
            db.save_commit(commit).await.unwrap();
        }

        let trace = db.get_reasoning_trace(&id_3.hash).await.unwrap();

        assert_eq!(trace.len(), 4, "Trace should contain all 4 commits");

        // Expected contents, most recent first.
        let expected: [(i64, &str); 4] = [
            (3, "Strategy B succeeded"),
            (2, "Strategy A failed, pivoting"),
            (1, "Trying strategy A"),
            (0, "Starting exploration"),
        ];

        // Ordering check.
        for (position, (step, _)) in expected.iter().enumerate() {
            assert_eq!(trace[position].state["step"], *step);
        }

        // Chain-of-Thought content check.
        for (position, (_, thought)) in expected.iter().enumerate() {
            assert_eq!(trace[position].state["thought"], *thought);
        }
    }

    /// CI snapshots, pipelines and runs survive a save/load cycle, and runs
    /// can be listed by the snapshot digest they belong to.
    #[tokio::test]
    async fn test_ci_records_roundtrip() {
        use crate::ci::{CiCommand, CiStepSpec};

        let db = SurrealHandle::setup_db().await.unwrap();

        // Snapshot round trip.
        let snapshot = CiSnapshot {
            repo_sha: "abc123".to_string(),
            workspace_hash: "work-1".to_string(),
            local_ci_config_hash: "cfg-1".to_string(),
            env_hash: "env-1".to_string(),
        };
        let snapshot_digest = db.save_ci_snapshot(&snapshot).await.unwrap();
        let snapshot_back = db.load_ci_snapshot(&snapshot_digest).await.unwrap();
        assert_eq!(snapshot_back, Some(snapshot.clone()));

        // Pipeline round trip.
        let test_step = CiStepSpec {
            name: "test".to_string(),
            command: CiCommand {
                program: "cargo".to_string(),
                args: vec!["test".to_string()],
                env: BTreeMap::new(),
                cwd: None,
            },
            timeout_secs: Some(300),
            allow_failure: false,
        };
        let pipeline = CiPipelineSpec {
            name: "default".to_string(),
            steps: vec![test_step],
        };
        let pipeline_digest = db.save_ci_pipeline(&pipeline).await.unwrap();
        let pipeline_back = db.load_ci_pipeline(&pipeline_digest).await.unwrap();
        assert_eq!(pipeline_back, Some(pipeline.clone()));

        // Run round trip.
        let queued_run = CiRunRecord::queued(&snapshot_digest, &pipeline_digest);
        let stored_run = db.save_ci_run(&queued_run).await.unwrap();
        let run_back = db.get_ci_run(&stored_run.run_id).await.unwrap();
        assert_eq!(run_back, Some(stored_run.clone()));

        // Listing by snapshot digest returns exactly that run.
        let listed = db
            .list_ci_runs_by_snapshot(&snapshot_digest)
            .await
            .unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].run_id, stored_run.run_id);
    }

    /// Promoting a release writes the metadata into top-level columns of the
    /// `releases` table and the same values come back through the history API.
    #[tokio::test]
    async fn test_release_fields_roundtrip_in_surreal() {
        // Shape of the raw row selected straight from the `releases` table.
        #[derive(serde::Deserialize)]
        struct RawRelease {
            name: String,
            version_label: Option<String>,
            promoted_by: String,
            notes: Option<String>,
        }

        let db = SurrealHandle::setup_db().await.unwrap();

        let metadata = ReleaseMetadata {
            version_label: Some("v1.2.3".to_string()),
            promoted_by: "test-user".to_string(),
            notes: Some("Release notes here".to_string()),
        };
        let digest = ContentDigest::from_bytes(b"spec-data");

        let promoted = db
            .release_promote("test-agent", &digest, metadata.clone())
            .await
            .unwrap();

        assert_eq!(promoted.name, "test-agent");
        assert_eq!(promoted.metadata.version_label, Some("v1.2.3".to_string()));

        // Check raw DB record to ensure top-level fields are set (since table is SCHEMAFULL)
        let mut response = db
            .db
            .query("SELECT name, version_label, promoted_by, notes FROM releases WHERE name = 'test-agent'")
            .await
            .unwrap();

        let raw_rows: Vec<RawRelease> = response.take(0).unwrap();
        assert_eq!(raw_rows.len(), 1);
        let raw = &raw_rows[0];

        assert_eq!(raw.version_label, Some("v1.2.3".to_string()));
        assert_eq!(raw.promoted_by, "test-user");
        assert_eq!(raw.notes, Some("Release notes here".to_string()));
        assert_eq!(raw.name, "test-agent");

        // The history API must report the same metadata.
        let history = db.release_history("test-agent").await.unwrap();
        assert_eq!(history.len(), 1);

        let latest = &history[0];
        assert_eq!(latest.metadata.version_label, Some("v1.2.3".to_string()));
        assert_eq!(latest.metadata.promoted_by, "test-user");
    }
}