Skip to main content

plexus_substrate/activations/orcha/
storage.rs

1use super::types::{SessionId, SessionInfo, SessionState};
2use crate::activations::storage::init_sqlite_pool;
3use crate::activation_db_path_from_module;
4use sqlx::{sqlite::SqlitePool, Row};
5use std::collections::HashMap;
6use std::path::PathBuf;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10/// Configuration for Orcha storage
11#[derive(Debug, Clone)]
12pub struct OrchaStorageConfig {
13    pub db_path: PathBuf,
14}
15
16impl Default for OrchaStorageConfig {
17    fn default() -> Self {
18        Self {
19            db_path: activation_db_path_from_module!("orcha.db"),
20        }
21    }
22}
23
/// Storage for orcha sessions backed by SQLite.
///
/// Session rows are mirrored in an in-memory map: `get_session` and
/// `list_sessions` answer from that map, while the agent methods query the
/// database directly.
pub struct OrchaStorage {
    /// Connection pool used for every query issued by this store.
    pool: SqlitePool,
    /// In-memory cache of sessions keyed by session id. It is populated from
    /// the `orcha_sessions` table by `load_sessions` and updated by the
    /// session write methods after their database statement succeeds.
    sessions: Arc<RwLock<HashMap<SessionId, SessionInfo>>>,
}
30
31impl OrchaStorage {
32    /// Create new storage with the given configuration
33    pub async fn new(config: OrchaStorageConfig) -> Result<Self, String> {
34        let pool = init_sqlite_pool(config.db_path).await?;
35
36        let storage = Self {
37            pool,
38            sessions: Arc::new(RwLock::new(HashMap::new())),
39        };
40
41        storage.init_schema().await?;
42        storage.load_sessions().await?;
43
44        Ok(storage)
45    }
46
47    /// Initialize database schema
48    async fn init_schema(&self) -> Result<(), String> {
49        // Create orcha_sessions table
50        sqlx::query(
51            r#"
52            CREATE TABLE IF NOT EXISTS orcha_sessions (
53                session_id TEXT PRIMARY KEY,
54                model TEXT NOT NULL,
55                working_directory TEXT NOT NULL,
56                rules TEXT,
57                max_retries INTEGER NOT NULL DEFAULT 3,
58                retry_count INTEGER NOT NULL DEFAULT 0,
59                created_at INTEGER NOT NULL,
60                last_activity INTEGER NOT NULL,
61                state_type TEXT NOT NULL,
62                state_data TEXT,
63                UNIQUE(session_id)
64            )
65            "#,
66        )
67        .execute(&self.pool)
68        .await
69        .map_err(|e| format!("Failed to create orcha_sessions table: {}", e))?;
70
71        // Migrate orcha_sessions: add agent_mode column if not exists
72        // SQLite doesn't have a nice IF NOT EXISTS for ALTER TABLE, so we use PRAGMA
73        // PRAGMA table_info returns: (cid, name, type, notnull, dflt_value, pk)
74        let rows = sqlx::query("PRAGMA table_info(orcha_sessions)")
75            .fetch_all(&self.pool)
76            .await
77            .map_err(|e| format!("Failed to get table info: {}", e))?;
78
79        let column_names: Vec<String> = rows.iter()
80            .filter_map(|row| match row.try_get::<String, _>("name") {
81                Ok(name) => Some(name),
82                Err(e) => {
83                    tracing::warn!("Failed to read column name from PRAGMA table_info: {}", e);
84                    None
85                }
86            })
87            .collect();
88
89        let has_agent_mode = column_names.iter().any(|name| name == "agent_mode");
90        if !has_agent_mode {
91            sqlx::query("ALTER TABLE orcha_sessions ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'")
92                .execute(&self.pool)
93                .await
94                .map_err(|e| format!("Failed to add agent_mode column: {}", e))?;
95        }
96
97        let has_primary_agent_id = column_names.iter().any(|name| name == "primary_agent_id");
98        if !has_primary_agent_id {
99            sqlx::query("ALTER TABLE orcha_sessions ADD COLUMN primary_agent_id TEXT")
100                .execute(&self.pool)
101                .await
102                .map_err(|e| format!("Failed to add primary_agent_id column: {}", e))?;
103        }
104
105        let has_tree_id = column_names.iter().any(|name| name == "tree_id");
106        if !has_tree_id {
107            sqlx::query("ALTER TABLE orcha_sessions ADD COLUMN tree_id TEXT")
108                .execute(&self.pool)
109                .await
110                .map_err(|e| format!("Failed to add tree_id column: {}", e))?;
111        }
112
113        // Create orcha_agents table
114        sqlx::query(
115            r#"
116            CREATE TABLE IF NOT EXISTS orcha_agents (
117                agent_id TEXT PRIMARY KEY,
118                session_id TEXT NOT NULL,
119                claudecode_session_id TEXT NOT NULL,
120                subtask TEXT NOT NULL,
121                state_type TEXT NOT NULL,
122                state_data TEXT,
123                is_primary INTEGER NOT NULL DEFAULT 0,
124                parent_agent_id TEXT,
125                created_at INTEGER NOT NULL,
126                last_activity INTEGER NOT NULL,
127                completed_at INTEGER,
128                error_message TEXT,
129                FOREIGN KEY (session_id) REFERENCES orcha_sessions(session_id) ON DELETE CASCADE,
130                FOREIGN KEY (parent_agent_id) REFERENCES orcha_agents(agent_id) ON DELETE SET NULL
131            )
132            "#,
133        )
134        .execute(&self.pool)
135        .await
136        .map_err(|e| format!("Failed to create orcha_agents table: {}", e))?;
137
138        // Create indexes for orcha_agents
139        sqlx::query("CREATE INDEX IF NOT EXISTS idx_agents_session ON orcha_agents(session_id)")
140            .execute(&self.pool)
141            .await
142            .map_err(|e| format!("Failed to create session index: {}", e))?;
143
144        sqlx::query("CREATE INDEX IF NOT EXISTS idx_agents_state ON orcha_agents(state_type)")
145            .execute(&self.pool)
146            .await
147            .map_err(|e| format!("Failed to create state index: {}", e))?;
148
149        Ok(())
150    }
151
152    /// Load all sessions from database into memory cache
153    async fn load_sessions(&self) -> Result<(), String> {
154        let rows = sqlx::query("SELECT * FROM orcha_sessions")
155            .fetch_all(&self.pool)
156            .await
157            .map_err(|e| format!("Failed to load sessions: {}", e))?;
158
159        let mut sessions = self.sessions.write().await;
160
161        for row in rows {
162            let session_id: String = row.get("session_id");
163            let model: String = row.get("model");
164            let created_at: i64 = row.get("created_at");
165            let last_activity: i64 = row.get("last_activity");
166            let retry_count: i64 = row.get("retry_count");
167            let max_retries: i64 = row.get("max_retries");
168            let state_type: String = row.get("state_type");
169            let state_data: Option<String> = row.get("state_data");
170
171            // Try to get new fields, default if not present (for backward compat)
172            let agent_mode_str: Option<String> = row.try_get("agent_mode").ok();
173            let agent_mode = agent_mode_str
174                .and_then(|s| serde_json::from_str(&format!("\"{}\"", s)).ok())
175                .unwrap_or(super::types::AgentMode::Single);
176
177            let primary_agent_id: Option<String> = row.try_get("primary_agent_id").ok().flatten();
178            let tree_id: Option<String> = row.try_get("tree_id").ok().flatten();
179
180            let state = self.deserialize_state(&state_type, state_data.as_deref())?;
181
182            let info = SessionInfo {
183                session_id: session_id.clone(),
184                model,
185                created_at,
186                last_activity,
187                state,
188                retry_count: retry_count as u32,
189                max_retries: max_retries as u32,
190                agent_mode,
191                primary_agent_id,
192                tree_id,
193            };
194
195            sessions.insert(session_id, info);
196        }
197
198        Ok(())
199    }
200
201    /// Create a new session
202    pub async fn create_session(
203        &self,
204        session_id: SessionId,
205        model: String,
206        working_directory: String,
207        rules: Option<String>,
208        max_retries: u32,
209        agent_mode: super::types::AgentMode,
210        tree_id: Option<String>,
211    ) -> Result<SessionInfo, String> {
212        let now = chrono::Utc::now().timestamp();
213
214        let agent_mode_str = match agent_mode {
215            super::types::AgentMode::Single => "single",
216            super::types::AgentMode::Multi => "multi",
217        };
218
219        let info = SessionInfo {
220            session_id: session_id.clone(),
221            model: model.clone(),
222            created_at: now,
223            last_activity: now,
224            state: SessionState::Idle,
225            retry_count: 0,
226            max_retries,
227            agent_mode,
228            primary_agent_id: None,
229            tree_id: tree_id.clone(),
230        };
231
232        // Insert into database
233        sqlx::query(
234            r#"
235            INSERT INTO orcha_sessions (
236                session_id, model, working_directory, rules, max_retries,
237                retry_count, created_at, last_activity, state_type, state_data,
238                agent_mode, primary_agent_id, tree_id
239            ) VALUES (?, ?, ?, ?, ?, 0, ?, ?, 'idle', NULL, ?, NULL, ?)
240            "#,
241        )
242        .bind(&session_id)
243        .bind(&model)
244        .bind(&working_directory)
245        .bind(&rules)
246        .bind(max_retries as i64)
247        .bind(now)
248        .bind(now)
249        .bind(agent_mode_str)
250        .bind(&tree_id)
251        .execute(&self.pool)
252        .await
253        .map_err(|e| format!("Failed to create session: {}", e))?;
254
255        // Add to cache
256        self.sessions.write().await.insert(session_id.clone(), info.clone());
257
258        Ok(info)
259    }
260
261    /// Get session info
262    pub async fn get_session(&self, session_id: &SessionId) -> Result<SessionInfo, String> {
263        let sessions = self.sessions.read().await;
264        sessions
265            .get(session_id)
266            .cloned()
267            .ok_or_else(|| format!("Session not found: {}", session_id))
268    }
269
270    /// Update session state
271    pub async fn update_state(
272        &self,
273        session_id: &SessionId,
274        state: SessionState,
275    ) -> Result<(), String> {
276        let now = chrono::Utc::now().timestamp();
277        let (state_type, state_data) = self.serialize_state(&state);
278
279        sqlx::query(
280            r#"
281            UPDATE orcha_sessions
282            SET state_type = ?, state_data = ?, last_activity = ?
283            WHERE session_id = ?
284            "#,
285        )
286        .bind(&state_type)
287        .bind(&state_data)
288        .bind(now)
289        .bind(session_id)
290        .execute(&self.pool)
291        .await
292        .map_err(|e| format!("Failed to update state: {}", e))?;
293
294        // Update cache
295        if let Some(info) = self.sessions.write().await.get_mut(session_id) {
296            info.state = state;
297            info.last_activity = now;
298        }
299
300        Ok(())
301    }
302
303    /// Increment retry count
304    pub async fn increment_retry(&self, session_id: &SessionId) -> Result<u32, String> {
305        sqlx::query(
306            r#"
307            UPDATE orcha_sessions
308            SET retry_count = retry_count + 1
309            WHERE session_id = ?
310            "#,
311        )
312        .bind(session_id)
313        .execute(&self.pool)
314        .await
315        .map_err(|e| format!("Failed to increment retry count: {}", e))?;
316
317        // Update cache and return new count
318        if let Some(info) = self.sessions.write().await.get_mut(session_id) {
319            info.retry_count += 1;
320            Ok(info.retry_count)
321        } else {
322            Err(format!("Session not found: {}", session_id))
323        }
324    }
325
326    /// List all sessions
327    pub async fn list_sessions(&self) -> Vec<SessionInfo> {
328        self.sessions.read().await.values().cloned().collect()
329    }
330
331    /// Delete a session
332    pub async fn delete_session(&self, session_id: &SessionId) -> Result<(), String> {
333        sqlx::query("DELETE FROM orcha_sessions WHERE session_id = ?")
334            .bind(session_id)
335            .execute(&self.pool)
336            .await
337            .map_err(|e| format!("Failed to delete session: {}", e))?;
338
339        self.sessions.write().await.remove(session_id);
340
341        Ok(())
342    }
343
344    // ═══════════════════════════════════════════════════════════════════════
345    // Agent Management (Multi-Agent Orchestration)
346    // ═══════════════════════════════════════════════════════════════════════
347
348    /// Create a new agent for a session
349    pub async fn create_agent(
350        &self,
351        session_id: &SessionId,
352        claudecode_session_id: String,
353        subtask: String,
354        is_primary: bool,
355        parent_agent_id: Option<super::types::AgentId>,
356    ) -> Result<super::types::AgentInfo, String> {
357        let agent_id = format!("agent-{}", uuid::Uuid::new_v4());
358        let now = chrono::Utc::now().timestamp();
359
360        sqlx::query(
361            r#"
362            INSERT INTO orcha_agents (
363                agent_id, session_id, claudecode_session_id, subtask,
364                state_type, state_data, is_primary, parent_agent_id,
365                created_at, last_activity
366            ) VALUES (?, ?, ?, ?, 'idle', NULL, ?, ?, ?, ?)
367            "#,
368        )
369        .bind(&agent_id)
370        .bind(session_id)
371        .bind(&claudecode_session_id)
372        .bind(&subtask)
373        .bind(if is_primary { 1 } else { 0 })
374        .bind(&parent_agent_id)
375        .bind(now)
376        .bind(now)
377        .execute(&self.pool)
378        .await
379        .map_err(|e| format!("Failed to create agent: {}", e))?;
380
381        Ok(super::types::AgentInfo {
382            agent_id,
383            session_id: session_id.clone(),
384            claudecode_session_id,
385            subtask,
386            state: super::types::AgentState::Idle,
387            is_primary,
388            parent_agent_id,
389            created_at: now,
390            last_activity: now,
391            completed_at: None,
392            error_message: None,
393        })
394    }
395
396    /// Get agent by ID
397    pub async fn get_agent(&self, agent_id: &super::types::AgentId) -> Result<super::types::AgentInfo, String> {
398        let row = sqlx::query("SELECT * FROM orcha_agents WHERE agent_id = ?")
399            .bind(agent_id)
400            .fetch_optional(&self.pool)
401            .await
402            .map_err(|e| format!("Failed to fetch agent: {}", e))?
403            .ok_or_else(|| format!("Agent not found: {}", agent_id))?;
404
405        self.row_to_agent(row)
406    }
407
408    /// List all agents for a session
409    pub async fn list_agents(&self, session_id: &SessionId) -> Result<Vec<super::types::AgentInfo>, String> {
410        let rows = sqlx::query(
411            "SELECT * FROM orcha_agents WHERE session_id = ? ORDER BY created_at ASC"
412        )
413        .bind(session_id)
414        .fetch_all(&self.pool)
415        .await
416        .map_err(|e| format!("Failed to list agents: {}", e))?;
417
418        rows.into_iter()
419            .map(|row| self.row_to_agent(row))
420            .collect()
421    }
422
423    /// Update agent state
424    pub async fn update_agent_state(
425        &self,
426        agent_id: &super::types::AgentId,
427        state: super::types::AgentState,
428    ) -> Result<(), String> {
429        let now = chrono::Utc::now().timestamp();
430        let (state_type, state_data) = self.serialize_agent_state(&state);
431
432        // Also update completed_at if state is Complete or Failed
433        let completed_at = match state {
434            super::types::AgentState::Complete | super::types::AgentState::Failed { .. } => Some(now),
435            _ => None,
436        };
437
438        // Extract error message if failed
439        let error_message = match &state {
440            super::types::AgentState::Failed { error } => Some(error.clone()),
441            _ => None,
442        };
443
444        if completed_at.is_some() {
445            sqlx::query(
446                r#"
447                UPDATE orcha_agents
448                SET state_type = ?, state_data = ?, last_activity = ?, completed_at = ?, error_message = ?
449                WHERE agent_id = ?
450                "#,
451            )
452            .bind(&state_type)
453            .bind(&state_data)
454            .bind(now)
455            .bind(completed_at)
456            .bind(&error_message)
457            .bind(agent_id)
458            .execute(&self.pool)
459            .await
460            .map_err(|e| format!("Failed to update agent state: {}", e))?;
461        } else {
462            sqlx::query(
463                r#"
464                UPDATE orcha_agents
465                SET state_type = ?, state_data = ?, last_activity = ?
466                WHERE agent_id = ?
467                "#,
468            )
469            .bind(&state_type)
470            .bind(&state_data)
471            .bind(now)
472            .bind(agent_id)
473            .execute(&self.pool)
474            .await
475            .map_err(|e| format!("Failed to update agent state: {}", e))?;
476        }
477
478        Ok(())
479    }
480
481    /// Get session agent counts (active, completed, failed)
482    pub async fn get_agent_counts(&self, session_id: &SessionId) -> Result<(u32, u32, u32), String> {
483        let row = sqlx::query(
484            r#"
485            SELECT
486                COUNT(CASE WHEN state_type IN ('idle', 'running', 'waiting_approval', 'validating') THEN 1 END) as active,
487                COUNT(CASE WHEN state_type = 'complete' THEN 1 END) as completed,
488                COUNT(CASE WHEN state_type = 'failed' THEN 1 END) as failed
489            FROM orcha_agents WHERE session_id = ?
490            "#
491        )
492        .bind(session_id)
493        .fetch_one(&self.pool)
494        .await
495        .map_err(|e| format!("Failed to get agent counts: {}", e))?;
496
497        let active: i64 = row.get("active");
498        let completed: i64 = row.get("completed");
499        let failed: i64 = row.get("failed");
500
501        Ok((active as u32, completed as u32, failed as u32))
502    }
503
504    /// Helper: Convert row to AgentInfo
505    fn row_to_agent(&self, row: sqlx::sqlite::SqliteRow) -> Result<super::types::AgentInfo, String> {
506        let state_type: String = row.get("state_type");
507        let state_data: Option<String> = row.get("state_data");
508        let state = self.deserialize_agent_state(&state_type, state_data.as_deref())?;
509
510        Ok(super::types::AgentInfo {
511            agent_id: row.get("agent_id"),
512            session_id: row.get("session_id"),
513            claudecode_session_id: row.get("claudecode_session_id"),
514            subtask: row.get("subtask"),
515            state,
516            is_primary: row.get::<i64, _>("is_primary") == 1,
517            parent_agent_id: row.get("parent_agent_id"),
518            created_at: row.get("created_at"),
519            last_activity: row.get("last_activity"),
520            completed_at: row.get("completed_at"),
521            error_message: row.get("error_message"),
522        })
523    }
524
525    /// Helper: Serialize agent state
526    fn serialize_agent_state(&self, state: &super::types::AgentState) -> (String, Option<String>) {
527        match state {
528            super::types::AgentState::Idle => ("idle".to_string(), None),
529            super::types::AgentState::Running { sequence } => (
530                "running".to_string(),
531                Some(serde_json::json!({ "sequence": sequence }).to_string()),
532            ),
533            super::types::AgentState::WaitingApproval { approval_id } => (
534                "waiting_approval".to_string(),
535                Some(serde_json::json!({ "approval_id": approval_id }).to_string()),
536            ),
537            super::types::AgentState::Validating { test_command } => (
538                "validating".to_string(),
539                Some(serde_json::json!({ "test_command": test_command }).to_string()),
540            ),
541            super::types::AgentState::Complete => ("complete".to_string(), None),
542            super::types::AgentState::Failed { error } => (
543                "failed".to_string(),
544                Some(serde_json::json!({ "error": error }).to_string()),
545            ),
546        }
547    }
548
549    /// Helper: Deserialize agent state
550    fn deserialize_agent_state(&self, state_type: &str, state_data: Option<&str>) -> Result<super::types::AgentState, String> {
551        match state_type {
552            "idle" => Ok(super::types::AgentState::Idle),
553            "running" => {
554                let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
555                    .map_err(|e| format!("Failed to parse running state: {}", e))?;
556                Ok(super::types::AgentState::Running {
557                    sequence: data["sequence"].as_u64().unwrap_or(0),
558                })
559            }
560            "waiting_approval" => {
561                let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
562                    .map_err(|e| format!("Failed to parse waiting_approval state: {}", e))?;
563                Ok(super::types::AgentState::WaitingApproval {
564                    approval_id: data["approval_id"].as_str().unwrap_or("").to_string(),
565                })
566            }
567            "validating" => {
568                let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
569                    .map_err(|e| format!("Failed to parse validating state: {}", e))?;
570                Ok(super::types::AgentState::Validating {
571                    test_command: data["test_command"].as_str().unwrap_or("").to_string(),
572                })
573            }
574            "complete" => Ok(super::types::AgentState::Complete),
575            "failed" => {
576                let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
577                    .map_err(|e| format!("Failed to parse failed state: {}", e))?;
578                Ok(super::types::AgentState::Failed {
579                    error: data["error"].as_str().unwrap_or("Unknown error").to_string(),
580                })
581            }
582            _ => Err(format!("Unknown agent state type: {}", state_type)),
583        }
584    }
585
586    // ═══════════════════════════════════════════════════════════════════════
587    // State Serialization Helpers
588    // ═══════════════════════════════════════════════════════════════════════
589
590    fn serialize_state(&self, state: &SessionState) -> (String, Option<String>) {
591        match state {
592            SessionState::Idle => ("idle".to_string(), None),
593            SessionState::Running { stream_id, sequence, active_agents, completed_agents, failed_agents } => (
594                "running".to_string(),
595                Some(serde_json::json!({
596                    "stream_id": stream_id,
597                    "sequence": sequence,
598                    "active_agents": active_agents,
599                    "completed_agents": completed_agents,
600                    "failed_agents": failed_agents,
601                }).to_string()),
602            ),
603            SessionState::WaitingApproval { approval_id } => (
604                "waiting_approval".to_string(),
605                Some(serde_json::json!({
606                    "approval_id": approval_id,
607                }).to_string()),
608            ),
609            SessionState::Validating { test_command } => (
610                "validating".to_string(),
611                Some(serde_json::json!({
612                    "test_command": test_command,
613                }).to_string()),
614            ),
615            SessionState::Complete => ("complete".to_string(), None),
616            SessionState::Failed { error } => (
617                "failed".to_string(),
618                Some(serde_json::json!({
619                    "error": error,
620                }).to_string()),
621            ),
622        }
623    }
624
625    fn deserialize_state(&self, state_type: &str, state_data: Option<&str>) -> Result<SessionState, String> {
626        match state_type {
627            "idle" => Ok(SessionState::Idle),
628            "running" => {
629                let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
630                    .map_err(|e| format!("Failed to parse running state: {}", e))?;
631                Ok(SessionState::Running {
632                    stream_id: data["stream_id"].as_str().unwrap_or("").to_string(),
633                    sequence: data["sequence"].as_u64().unwrap_or(0),
634                    active_agents: data["active_agents"].as_u64().unwrap_or(0) as u32,
635                    completed_agents: data["completed_agents"].as_u64().unwrap_or(0) as u32,
636                    failed_agents: data["failed_agents"].as_u64().unwrap_or(0) as u32,
637                })
638            }
639            "waiting_approval" => {
640                let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
641                    .map_err(|e| format!("Failed to parse waiting_approval state: {}", e))?;
642                Ok(SessionState::WaitingApproval {
643                    approval_id: data["approval_id"].as_str().unwrap_or("").to_string(),
644                })
645            }
646            "validating" => {
647                let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
648                    .map_err(|e| format!("Failed to parse validating state: {}", e))?;
649                Ok(SessionState::Validating {
650                    test_command: data["test_command"].as_str().unwrap_or("").to_string(),
651                })
652            }
653            "complete" => Ok(SessionState::Complete),
654            "failed" => {
655                let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
656                    .map_err(|e| format!("Failed to parse failed state: {}", e))?;
657                Ok(SessionState::Failed {
658                    error: data["error"].as_str().unwrap_or("Unknown error").to_string(),
659                })
660            }
661            _ => Err(format!("Unknown state type: {}", state_type)),
662        }
663    }
664}