1use super::types::{SessionId, SessionInfo, SessionState};
2use crate::activations::storage::init_sqlite_pool;
3use crate::activation_db_path_from_module;
4use sqlx::{sqlite::SqlitePool, Row};
5use std::collections::HashMap;
6use std::path::PathBuf;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
/// Settings used to open the SQLite database backing `OrchaStorage`.
#[derive(Clone, Debug)]
pub struct OrchaStorageConfig {
    /// Filesystem path of the SQLite database file.
    pub db_path: PathBuf,
}
15
16impl Default for OrchaStorageConfig {
17 fn default() -> Self {
18 Self {
19 db_path: activation_db_path_from_module!("orcha.db"),
20 }
21 }
22}
23
24pub struct OrchaStorage {
26 pool: SqlitePool,
27 sessions: Arc<RwLock<HashMap<SessionId, SessionInfo>>>,
29}
30
31impl OrchaStorage {
32 pub async fn new(config: OrchaStorageConfig) -> Result<Self, String> {
34 let pool = init_sqlite_pool(config.db_path).await?;
35
36 let storage = Self {
37 pool,
38 sessions: Arc::new(RwLock::new(HashMap::new())),
39 };
40
41 storage.init_schema().await?;
42 storage.load_sessions().await?;
43
44 Ok(storage)
45 }
46
47 async fn init_schema(&self) -> Result<(), String> {
49 sqlx::query(
51 r#"
52 CREATE TABLE IF NOT EXISTS orcha_sessions (
53 session_id TEXT PRIMARY KEY,
54 model TEXT NOT NULL,
55 working_directory TEXT NOT NULL,
56 rules TEXT,
57 max_retries INTEGER NOT NULL DEFAULT 3,
58 retry_count INTEGER NOT NULL DEFAULT 0,
59 created_at INTEGER NOT NULL,
60 last_activity INTEGER NOT NULL,
61 state_type TEXT NOT NULL,
62 state_data TEXT,
63 UNIQUE(session_id)
64 )
65 "#,
66 )
67 .execute(&self.pool)
68 .await
69 .map_err(|e| format!("Failed to create orcha_sessions table: {}", e))?;
70
71 let rows = sqlx::query("PRAGMA table_info(orcha_sessions)")
75 .fetch_all(&self.pool)
76 .await
77 .map_err(|e| format!("Failed to get table info: {}", e))?;
78
79 let column_names: Vec<String> = rows.iter()
80 .filter_map(|row| match row.try_get::<String, _>("name") {
81 Ok(name) => Some(name),
82 Err(e) => {
83 tracing::warn!("Failed to read column name from PRAGMA table_info: {}", e);
84 None
85 }
86 })
87 .collect();
88
89 let has_agent_mode = column_names.iter().any(|name| name == "agent_mode");
90 if !has_agent_mode {
91 sqlx::query("ALTER TABLE orcha_sessions ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'")
92 .execute(&self.pool)
93 .await
94 .map_err(|e| format!("Failed to add agent_mode column: {}", e))?;
95 }
96
97 let has_primary_agent_id = column_names.iter().any(|name| name == "primary_agent_id");
98 if !has_primary_agent_id {
99 sqlx::query("ALTER TABLE orcha_sessions ADD COLUMN primary_agent_id TEXT")
100 .execute(&self.pool)
101 .await
102 .map_err(|e| format!("Failed to add primary_agent_id column: {}", e))?;
103 }
104
105 let has_tree_id = column_names.iter().any(|name| name == "tree_id");
106 if !has_tree_id {
107 sqlx::query("ALTER TABLE orcha_sessions ADD COLUMN tree_id TEXT")
108 .execute(&self.pool)
109 .await
110 .map_err(|e| format!("Failed to add tree_id column: {}", e))?;
111 }
112
113 sqlx::query(
115 r#"
116 CREATE TABLE IF NOT EXISTS orcha_agents (
117 agent_id TEXT PRIMARY KEY,
118 session_id TEXT NOT NULL,
119 claudecode_session_id TEXT NOT NULL,
120 subtask TEXT NOT NULL,
121 state_type TEXT NOT NULL,
122 state_data TEXT,
123 is_primary INTEGER NOT NULL DEFAULT 0,
124 parent_agent_id TEXT,
125 created_at INTEGER NOT NULL,
126 last_activity INTEGER NOT NULL,
127 completed_at INTEGER,
128 error_message TEXT,
129 FOREIGN KEY (session_id) REFERENCES orcha_sessions(session_id) ON DELETE CASCADE,
130 FOREIGN KEY (parent_agent_id) REFERENCES orcha_agents(agent_id) ON DELETE SET NULL
131 )
132 "#,
133 )
134 .execute(&self.pool)
135 .await
136 .map_err(|e| format!("Failed to create orcha_agents table: {}", e))?;
137
138 sqlx::query("CREATE INDEX IF NOT EXISTS idx_agents_session ON orcha_agents(session_id)")
140 .execute(&self.pool)
141 .await
142 .map_err(|e| format!("Failed to create session index: {}", e))?;
143
144 sqlx::query("CREATE INDEX IF NOT EXISTS idx_agents_state ON orcha_agents(state_type)")
145 .execute(&self.pool)
146 .await
147 .map_err(|e| format!("Failed to create state index: {}", e))?;
148
149 Ok(())
150 }
151
152 async fn load_sessions(&self) -> Result<(), String> {
154 let rows = sqlx::query("SELECT * FROM orcha_sessions")
155 .fetch_all(&self.pool)
156 .await
157 .map_err(|e| format!("Failed to load sessions: {}", e))?;
158
159 let mut sessions = self.sessions.write().await;
160
161 for row in rows {
162 let session_id: String = row.get("session_id");
163 let model: String = row.get("model");
164 let created_at: i64 = row.get("created_at");
165 let last_activity: i64 = row.get("last_activity");
166 let retry_count: i64 = row.get("retry_count");
167 let max_retries: i64 = row.get("max_retries");
168 let state_type: String = row.get("state_type");
169 let state_data: Option<String> = row.get("state_data");
170
171 let agent_mode_str: Option<String> = row.try_get("agent_mode").ok();
173 let agent_mode = agent_mode_str
174 .and_then(|s| serde_json::from_str(&format!("\"{}\"", s)).ok())
175 .unwrap_or(super::types::AgentMode::Single);
176
177 let primary_agent_id: Option<String> = row.try_get("primary_agent_id").ok().flatten();
178 let tree_id: Option<String> = row.try_get("tree_id").ok().flatten();
179
180 let state = self.deserialize_state(&state_type, state_data.as_deref())?;
181
182 let info = SessionInfo {
183 session_id: session_id.clone(),
184 model,
185 created_at,
186 last_activity,
187 state,
188 retry_count: retry_count as u32,
189 max_retries: max_retries as u32,
190 agent_mode,
191 primary_agent_id,
192 tree_id,
193 };
194
195 sessions.insert(session_id, info);
196 }
197
198 Ok(())
199 }
200
201 pub async fn create_session(
203 &self,
204 session_id: SessionId,
205 model: String,
206 working_directory: String,
207 rules: Option<String>,
208 max_retries: u32,
209 agent_mode: super::types::AgentMode,
210 tree_id: Option<String>,
211 ) -> Result<SessionInfo, String> {
212 let now = chrono::Utc::now().timestamp();
213
214 let agent_mode_str = match agent_mode {
215 super::types::AgentMode::Single => "single",
216 super::types::AgentMode::Multi => "multi",
217 };
218
219 let info = SessionInfo {
220 session_id: session_id.clone(),
221 model: model.clone(),
222 created_at: now,
223 last_activity: now,
224 state: SessionState::Idle,
225 retry_count: 0,
226 max_retries,
227 agent_mode,
228 primary_agent_id: None,
229 tree_id: tree_id.clone(),
230 };
231
232 sqlx::query(
234 r#"
235 INSERT INTO orcha_sessions (
236 session_id, model, working_directory, rules, max_retries,
237 retry_count, created_at, last_activity, state_type, state_data,
238 agent_mode, primary_agent_id, tree_id
239 ) VALUES (?, ?, ?, ?, ?, 0, ?, ?, 'idle', NULL, ?, NULL, ?)
240 "#,
241 )
242 .bind(&session_id)
243 .bind(&model)
244 .bind(&working_directory)
245 .bind(&rules)
246 .bind(max_retries as i64)
247 .bind(now)
248 .bind(now)
249 .bind(agent_mode_str)
250 .bind(&tree_id)
251 .execute(&self.pool)
252 .await
253 .map_err(|e| format!("Failed to create session: {}", e))?;
254
255 self.sessions.write().await.insert(session_id.clone(), info.clone());
257
258 Ok(info)
259 }
260
261 pub async fn get_session(&self, session_id: &SessionId) -> Result<SessionInfo, String> {
263 let sessions = self.sessions.read().await;
264 sessions
265 .get(session_id)
266 .cloned()
267 .ok_or_else(|| format!("Session not found: {}", session_id))
268 }
269
270 pub async fn update_state(
272 &self,
273 session_id: &SessionId,
274 state: SessionState,
275 ) -> Result<(), String> {
276 let now = chrono::Utc::now().timestamp();
277 let (state_type, state_data) = self.serialize_state(&state);
278
279 sqlx::query(
280 r#"
281 UPDATE orcha_sessions
282 SET state_type = ?, state_data = ?, last_activity = ?
283 WHERE session_id = ?
284 "#,
285 )
286 .bind(&state_type)
287 .bind(&state_data)
288 .bind(now)
289 .bind(session_id)
290 .execute(&self.pool)
291 .await
292 .map_err(|e| format!("Failed to update state: {}", e))?;
293
294 if let Some(info) = self.sessions.write().await.get_mut(session_id) {
296 info.state = state;
297 info.last_activity = now;
298 }
299
300 Ok(())
301 }
302
303 pub async fn increment_retry(&self, session_id: &SessionId) -> Result<u32, String> {
305 sqlx::query(
306 r#"
307 UPDATE orcha_sessions
308 SET retry_count = retry_count + 1
309 WHERE session_id = ?
310 "#,
311 )
312 .bind(session_id)
313 .execute(&self.pool)
314 .await
315 .map_err(|e| format!("Failed to increment retry count: {}", e))?;
316
317 if let Some(info) = self.sessions.write().await.get_mut(session_id) {
319 info.retry_count += 1;
320 Ok(info.retry_count)
321 } else {
322 Err(format!("Session not found: {}", session_id))
323 }
324 }
325
326 pub async fn list_sessions(&self) -> Vec<SessionInfo> {
328 self.sessions.read().await.values().cloned().collect()
329 }
330
331 pub async fn delete_session(&self, session_id: &SessionId) -> Result<(), String> {
333 sqlx::query("DELETE FROM orcha_sessions WHERE session_id = ?")
334 .bind(session_id)
335 .execute(&self.pool)
336 .await
337 .map_err(|e| format!("Failed to delete session: {}", e))?;
338
339 self.sessions.write().await.remove(session_id);
340
341 Ok(())
342 }
343
344 pub async fn create_agent(
350 &self,
351 session_id: &SessionId,
352 claudecode_session_id: String,
353 subtask: String,
354 is_primary: bool,
355 parent_agent_id: Option<super::types::AgentId>,
356 ) -> Result<super::types::AgentInfo, String> {
357 let agent_id = format!("agent-{}", uuid::Uuid::new_v4());
358 let now = chrono::Utc::now().timestamp();
359
360 sqlx::query(
361 r#"
362 INSERT INTO orcha_agents (
363 agent_id, session_id, claudecode_session_id, subtask,
364 state_type, state_data, is_primary, parent_agent_id,
365 created_at, last_activity
366 ) VALUES (?, ?, ?, ?, 'idle', NULL, ?, ?, ?, ?)
367 "#,
368 )
369 .bind(&agent_id)
370 .bind(session_id)
371 .bind(&claudecode_session_id)
372 .bind(&subtask)
373 .bind(if is_primary { 1 } else { 0 })
374 .bind(&parent_agent_id)
375 .bind(now)
376 .bind(now)
377 .execute(&self.pool)
378 .await
379 .map_err(|e| format!("Failed to create agent: {}", e))?;
380
381 Ok(super::types::AgentInfo {
382 agent_id,
383 session_id: session_id.clone(),
384 claudecode_session_id,
385 subtask,
386 state: super::types::AgentState::Idle,
387 is_primary,
388 parent_agent_id,
389 created_at: now,
390 last_activity: now,
391 completed_at: None,
392 error_message: None,
393 })
394 }
395
396 pub async fn get_agent(&self, agent_id: &super::types::AgentId) -> Result<super::types::AgentInfo, String> {
398 let row = sqlx::query("SELECT * FROM orcha_agents WHERE agent_id = ?")
399 .bind(agent_id)
400 .fetch_optional(&self.pool)
401 .await
402 .map_err(|e| format!("Failed to fetch agent: {}", e))?
403 .ok_or_else(|| format!("Agent not found: {}", agent_id))?;
404
405 self.row_to_agent(row)
406 }
407
408 pub async fn list_agents(&self, session_id: &SessionId) -> Result<Vec<super::types::AgentInfo>, String> {
410 let rows = sqlx::query(
411 "SELECT * FROM orcha_agents WHERE session_id = ? ORDER BY created_at ASC"
412 )
413 .bind(session_id)
414 .fetch_all(&self.pool)
415 .await
416 .map_err(|e| format!("Failed to list agents: {}", e))?;
417
418 rows.into_iter()
419 .map(|row| self.row_to_agent(row))
420 .collect()
421 }
422
423 pub async fn update_agent_state(
425 &self,
426 agent_id: &super::types::AgentId,
427 state: super::types::AgentState,
428 ) -> Result<(), String> {
429 let now = chrono::Utc::now().timestamp();
430 let (state_type, state_data) = self.serialize_agent_state(&state);
431
432 let completed_at = match state {
434 super::types::AgentState::Complete | super::types::AgentState::Failed { .. } => Some(now),
435 _ => None,
436 };
437
438 let error_message = match &state {
440 super::types::AgentState::Failed { error } => Some(error.clone()),
441 _ => None,
442 };
443
444 if completed_at.is_some() {
445 sqlx::query(
446 r#"
447 UPDATE orcha_agents
448 SET state_type = ?, state_data = ?, last_activity = ?, completed_at = ?, error_message = ?
449 WHERE agent_id = ?
450 "#,
451 )
452 .bind(&state_type)
453 .bind(&state_data)
454 .bind(now)
455 .bind(completed_at)
456 .bind(&error_message)
457 .bind(agent_id)
458 .execute(&self.pool)
459 .await
460 .map_err(|e| format!("Failed to update agent state: {}", e))?;
461 } else {
462 sqlx::query(
463 r#"
464 UPDATE orcha_agents
465 SET state_type = ?, state_data = ?, last_activity = ?
466 WHERE agent_id = ?
467 "#,
468 )
469 .bind(&state_type)
470 .bind(&state_data)
471 .bind(now)
472 .bind(agent_id)
473 .execute(&self.pool)
474 .await
475 .map_err(|e| format!("Failed to update agent state: {}", e))?;
476 }
477
478 Ok(())
479 }
480
481 pub async fn get_agent_counts(&self, session_id: &SessionId) -> Result<(u32, u32, u32), String> {
483 let row = sqlx::query(
484 r#"
485 SELECT
486 COUNT(CASE WHEN state_type IN ('idle', 'running', 'waiting_approval', 'validating') THEN 1 END) as active,
487 COUNT(CASE WHEN state_type = 'complete' THEN 1 END) as completed,
488 COUNT(CASE WHEN state_type = 'failed' THEN 1 END) as failed
489 FROM orcha_agents WHERE session_id = ?
490 "#
491 )
492 .bind(session_id)
493 .fetch_one(&self.pool)
494 .await
495 .map_err(|e| format!("Failed to get agent counts: {}", e))?;
496
497 let active: i64 = row.get("active");
498 let completed: i64 = row.get("completed");
499 let failed: i64 = row.get("failed");
500
501 Ok((active as u32, completed as u32, failed as u32))
502 }
503
504 fn row_to_agent(&self, row: sqlx::sqlite::SqliteRow) -> Result<super::types::AgentInfo, String> {
506 let state_type: String = row.get("state_type");
507 let state_data: Option<String> = row.get("state_data");
508 let state = self.deserialize_agent_state(&state_type, state_data.as_deref())?;
509
510 Ok(super::types::AgentInfo {
511 agent_id: row.get("agent_id"),
512 session_id: row.get("session_id"),
513 claudecode_session_id: row.get("claudecode_session_id"),
514 subtask: row.get("subtask"),
515 state,
516 is_primary: row.get::<i64, _>("is_primary") == 1,
517 parent_agent_id: row.get("parent_agent_id"),
518 created_at: row.get("created_at"),
519 last_activity: row.get("last_activity"),
520 completed_at: row.get("completed_at"),
521 error_message: row.get("error_message"),
522 })
523 }
524
525 fn serialize_agent_state(&self, state: &super::types::AgentState) -> (String, Option<String>) {
527 match state {
528 super::types::AgentState::Idle => ("idle".to_string(), None),
529 super::types::AgentState::Running { sequence } => (
530 "running".to_string(),
531 Some(serde_json::json!({ "sequence": sequence }).to_string()),
532 ),
533 super::types::AgentState::WaitingApproval { approval_id } => (
534 "waiting_approval".to_string(),
535 Some(serde_json::json!({ "approval_id": approval_id }).to_string()),
536 ),
537 super::types::AgentState::Validating { test_command } => (
538 "validating".to_string(),
539 Some(serde_json::json!({ "test_command": test_command }).to_string()),
540 ),
541 super::types::AgentState::Complete => ("complete".to_string(), None),
542 super::types::AgentState::Failed { error } => (
543 "failed".to_string(),
544 Some(serde_json::json!({ "error": error }).to_string()),
545 ),
546 }
547 }
548
549 fn deserialize_agent_state(&self, state_type: &str, state_data: Option<&str>) -> Result<super::types::AgentState, String> {
551 match state_type {
552 "idle" => Ok(super::types::AgentState::Idle),
553 "running" => {
554 let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
555 .map_err(|e| format!("Failed to parse running state: {}", e))?;
556 Ok(super::types::AgentState::Running {
557 sequence: data["sequence"].as_u64().unwrap_or(0),
558 })
559 }
560 "waiting_approval" => {
561 let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
562 .map_err(|e| format!("Failed to parse waiting_approval state: {}", e))?;
563 Ok(super::types::AgentState::WaitingApproval {
564 approval_id: data["approval_id"].as_str().unwrap_or("").to_string(),
565 })
566 }
567 "validating" => {
568 let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
569 .map_err(|e| format!("Failed to parse validating state: {}", e))?;
570 Ok(super::types::AgentState::Validating {
571 test_command: data["test_command"].as_str().unwrap_or("").to_string(),
572 })
573 }
574 "complete" => Ok(super::types::AgentState::Complete),
575 "failed" => {
576 let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
577 .map_err(|e| format!("Failed to parse failed state: {}", e))?;
578 Ok(super::types::AgentState::Failed {
579 error: data["error"].as_str().unwrap_or("Unknown error").to_string(),
580 })
581 }
582 _ => Err(format!("Unknown agent state type: {}", state_type)),
583 }
584 }
585
586 fn serialize_state(&self, state: &SessionState) -> (String, Option<String>) {
591 match state {
592 SessionState::Idle => ("idle".to_string(), None),
593 SessionState::Running { stream_id, sequence, active_agents, completed_agents, failed_agents } => (
594 "running".to_string(),
595 Some(serde_json::json!({
596 "stream_id": stream_id,
597 "sequence": sequence,
598 "active_agents": active_agents,
599 "completed_agents": completed_agents,
600 "failed_agents": failed_agents,
601 }).to_string()),
602 ),
603 SessionState::WaitingApproval { approval_id } => (
604 "waiting_approval".to_string(),
605 Some(serde_json::json!({
606 "approval_id": approval_id,
607 }).to_string()),
608 ),
609 SessionState::Validating { test_command } => (
610 "validating".to_string(),
611 Some(serde_json::json!({
612 "test_command": test_command,
613 }).to_string()),
614 ),
615 SessionState::Complete => ("complete".to_string(), None),
616 SessionState::Failed { error } => (
617 "failed".to_string(),
618 Some(serde_json::json!({
619 "error": error,
620 }).to_string()),
621 ),
622 }
623 }
624
625 fn deserialize_state(&self, state_type: &str, state_data: Option<&str>) -> Result<SessionState, String> {
626 match state_type {
627 "idle" => Ok(SessionState::Idle),
628 "running" => {
629 let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
630 .map_err(|e| format!("Failed to parse running state: {}", e))?;
631 Ok(SessionState::Running {
632 stream_id: data["stream_id"].as_str().unwrap_or("").to_string(),
633 sequence: data["sequence"].as_u64().unwrap_or(0),
634 active_agents: data["active_agents"].as_u64().unwrap_or(0) as u32,
635 completed_agents: data["completed_agents"].as_u64().unwrap_or(0) as u32,
636 failed_agents: data["failed_agents"].as_u64().unwrap_or(0) as u32,
637 })
638 }
639 "waiting_approval" => {
640 let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
641 .map_err(|e| format!("Failed to parse waiting_approval state: {}", e))?;
642 Ok(SessionState::WaitingApproval {
643 approval_id: data["approval_id"].as_str().unwrap_or("").to_string(),
644 })
645 }
646 "validating" => {
647 let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
648 .map_err(|e| format!("Failed to parse validating state: {}", e))?;
649 Ok(SessionState::Validating {
650 test_command: data["test_command"].as_str().unwrap_or("").to_string(),
651 })
652 }
653 "complete" => Ok(SessionState::Complete),
654 "failed" => {
655 let data: serde_json::Value = serde_json::from_str(state_data.unwrap_or("{}"))
656 .map_err(|e| format!("Failed to parse failed state: {}", e))?;
657 Ok(SessionState::Failed {
658 error: data["error"].as_str().unwrap_or("Unknown error").to_string(),
659 })
660 }
661 _ => Err(format!("Unknown state type: {}", state_type)),
662 }
663 }
664}