Skip to main content

roboticus_db/
agents.rs

1use roboticus_core::{RoboticusError, Result};
2use std::collections::HashMap;
3
4use crate::{Database, DbResultExt};
5
/// One row of the `sub_agents` table, as written by [`upsert_sub_agent`] and read back by
/// [`list_sub_agents`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SubAgentRow {
    /// Row identifier. When an upsert hits an existing `name`, the already-stored `id` is
    /// kept and this value is not written.
    pub id: String,
    /// Unique key: upserts conflict on it and [`delete_sub_agent`] matches rows by it.
    pub name: String,
    /// Optional human-readable name.
    pub display_name: Option<String>,
    /// Model identifier stored in the `model` column.
    pub model: String,
    /// Raw fallback-models text. It is normalized on write and on read: `None` or a blank
    /// value becomes `"[]"`, and any other value is trimmed. It is not validated as JSON.
    pub fallback_models_json: Option<String>,
    /// Role label (this module's tests use `"specialist"`).
    pub role: String,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Skills payload, stored and returned as-is (presumably JSON text, per the column
    /// name — not parsed here).
    pub skills_json: Option<String>,
    /// Persisted as an integer column; any non-zero stored value reads back as `true`.
    pub enabled: bool,
    /// Stored session count; every upsert overwrites it with the caller's value.
    pub session_count: i64,
}
19
/// Normalizes a raw `fallback_models_json` value before it is stored or returned.
///
/// Surrounding whitespace is trimmed. `None`, or a value that is empty after trimming,
/// becomes the literal `"[]"`. The text is not parsed or validated as JSON.
fn normalized_fallback_models_json(raw: Option<&str>) -> String {
    // `&str: Default` yields "", so a missing value falls into the empty branch below.
    let trimmed = raw.map(str::trim).unwrap_or_default();
    if trimmed.is_empty() {
        String::from("[]")
    } else {
        trimmed.to_owned()
    }
}
26
27pub fn upsert_sub_agent(db: &Database, agent: &SubAgentRow) -> Result<()> {
28    let conn = db.conn();
29    let fallback_models_json =
30        normalized_fallback_models_json(agent.fallback_models_json.as_deref());
31    conn.execute(
32        "INSERT INTO sub_agents (id, name, display_name, model, fallback_models_json, role, description, skills_json, enabled, session_count)
33         VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
34         ON CONFLICT(name) DO UPDATE SET
35           display_name = excluded.display_name,
36           model = excluded.model,
37           fallback_models_json = excluded.fallback_models_json,
38           role = excluded.role,
39           description = excluded.description,
40           skills_json = excluded.skills_json,
41           enabled = excluded.enabled,
42           session_count = excluded.session_count",
43        rusqlite::params![
44            agent.id,
45            agent.name,
46            agent.display_name,
47            agent.model,
48            fallback_models_json,
49            agent.role,
50            agent.description,
51            agent.skills_json,
52            agent.enabled as i32,
53            agent.session_count,
54        ],
55    )
56    .map_err(|e| RoboticusError::Database(format!("upsert sub_agent: {e}")))?;
57    Ok(())
58}
59
60pub fn list_sub_agents(db: &Database) -> Result<Vec<SubAgentRow>> {
61    let conn = db.conn();
62    let mut stmt = conn
63        .prepare(
64            "SELECT id, name, display_name, model, fallback_models_json, role, description, skills_json, enabled, session_count
65             FROM sub_agents ORDER BY name",
66        )
67        .db_err()?;
68
69    let rows = stmt
70        .query_map([], |row| {
71            Ok(SubAgentRow {
72                id: row.get(0)?,
73                name: row.get(1)?,
74                display_name: row.get(2)?,
75                model: row.get(3)?,
76                fallback_models_json: Some(normalized_fallback_models_json(
77                    row.get::<_, Option<String>>(4)?.as_deref(),
78                )),
79                role: row.get(5)?,
80                description: row.get(6)?,
81                skills_json: row.get(7)?,
82                enabled: row.get::<_, i32>(8)? != 0,
83                session_count: row.get(9)?,
84            })
85        })
86        .db_err()?
87        .collect::<std::result::Result<Vec<_>, _>>()
88        .db_err()?;
89
90    Ok(rows)
91}
92
93pub fn list_enabled_sub_agents(db: &Database) -> Result<Vec<SubAgentRow>> {
94    let all = list_sub_agents(db)?;
95    Ok(all.into_iter().filter(|a| a.enabled).collect())
96}
97
98pub fn list_session_counts_by_agent(db: &Database) -> Result<HashMap<String, i64>> {
99    let conn = db.conn();
100    let mut stmt = conn
101        .prepare("SELECT agent_id, COUNT(*) FROM sessions GROUP BY agent_id")
102        .db_err()?;
103
104    let rows = stmt
105        .query_map([], |row| {
106            let agent_id: String = row.get(0)?;
107            let count: i64 = row.get(1)?;
108            Ok((agent_id, count))
109        })
110        .db_err()?
111        .collect::<std::result::Result<Vec<_>, _>>()
112        .db_err()?;
113
114    Ok(rows.into_iter().collect())
115}
116
117pub fn delete_sub_agent(db: &Database, name: &str) -> Result<bool> {
118    let conn = db.conn();
119    let deleted = conn
120        .execute(
121            "DELETE FROM sub_agents WHERE name = ?1",
122            rusqlite::params![name],
123        )
124        .db_err()?;
125    Ok(deleted > 0)
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory database for each test.
    fn test_db() -> Database {
        Database::new(":memory:").unwrap()
    }

    /// Builds an enabled agent with a random id and the given unique `name`.
    fn sample_agent(name: &str) -> SubAgentRow {
        SubAgentRow {
            id: uuid::Uuid::new_v4().to_string(),
            name: name.to_string(),
            display_name: Some(name.replace('-', " ")),
            model: "test-model".into(),
            fallback_models_json: Some("[]".into()),
            role: "specialist".into(),
            description: Some("Test agent".into()),
            skills_json: None,
            enabled: true,
            session_count: 0,
        }
    }

    #[test]
    fn upsert_and_list() {
        let db = test_db();
        upsert_sub_agent(&db, &sample_agent("alpha")).unwrap();
        upsert_sub_agent(&db, &sample_agent("bravo")).unwrap();
        let agents = list_sub_agents(&db).unwrap();
        assert_eq!(agents.len(), 2);
        assert_eq!(agents[0].name, "alpha");
        assert_eq!(agents[1].name, "bravo");
    }

    #[test]
    fn upsert_updates_existing() {
        let db = test_db();
        let mut agent = sample_agent("alpha");
        upsert_sub_agent(&db, &agent).unwrap();
        agent.model = "updated-model".into();
        agent.session_count = 42;
        upsert_sub_agent(&db, &agent).unwrap();
        let agents = list_sub_agents(&db).unwrap();
        assert_eq!(agents.len(), 1);
        assert_eq!(agents[0].model, "updated-model");
        assert_eq!(agents[0].session_count, 42);
    }

    /// The ON CONFLICT(name) update does not touch `id`, so the first id must survive.
    #[test]
    fn upsert_name_conflict_keeps_original_id() {
        let db = test_db();
        let first = sample_agent("alpha");
        upsert_sub_agent(&db, &first).unwrap();
        let mut second = sample_agent("alpha");
        second.enabled = false;
        assert_ne!(first.id, second.id);
        upsert_sub_agent(&db, &second).unwrap();
        let agents = list_sub_agents(&db).unwrap();
        assert_eq!(agents.len(), 1);
        assert_eq!(agents[0].id, first.id);
        assert!(!agents[0].enabled);
    }

    #[test]
    fn list_enabled_filters() {
        let db = test_db();
        let mut a = sample_agent("enabled-one");
        upsert_sub_agent(&db, &a).unwrap();
        a = sample_agent("disabled-one");
        a.enabled = false;
        upsert_sub_agent(&db, &a).unwrap();
        let enabled = list_enabled_sub_agents(&db).unwrap();
        assert_eq!(enabled.len(), 1);
        assert_eq!(enabled[0].name, "enabled-one");
    }

    #[test]
    fn delete_works() {
        let db = test_db();
        upsert_sub_agent(&db, &sample_agent("doomed")).unwrap();
        assert!(delete_sub_agent(&db, "doomed").unwrap());
        assert!(!delete_sub_agent(&db, "doomed").unwrap());
        assert!(list_sub_agents(&db).unwrap().is_empty());
    }

    #[test]
    fn session_counts_by_agent_reads_sessions_table() {
        let db = test_db();
        {
            let conn = db.conn();
            conn.execute(
                "INSERT INTO sessions (id, agent_id, scope_key, status) VALUES (?1, ?2, 'agent', 'active')",
                rusqlite::params!["s1", "alpha"],
            )
            .unwrap();
            conn.execute(
                "INSERT INTO sessions (id, agent_id, scope_key, status) VALUES (?1, ?2, 'agent', 'archived')",
                rusqlite::params!["s2", "alpha"],
            )
            .unwrap();
            conn.execute(
                "INSERT INTO sessions (id, agent_id, scope_key, status) VALUES (?1, ?2, 'agent', 'active')",
                rusqlite::params!["s3", "bravo"],
            )
            .unwrap();
        }

        let counts = list_session_counts_by_agent(&db).unwrap();
        assert_eq!(counts.get("alpha"), Some(&2));
        assert_eq!(counts.get("bravo"), Some(&1));
    }

    #[test]
    fn upsert_normalizes_missing_fallback_models() {
        let db = test_db();
        let mut agent = sample_agent("fallback-default");
        agent.fallback_models_json = None;
        upsert_sub_agent(&db, &agent).unwrap();
        let stored = list_sub_agents(&db).unwrap();
        assert_eq!(
            stored[0].fallback_models_json.as_deref(),
            Some("[]"),
            "missing fallback models should normalize to JSON empty array"
        );
    }

    /// Pure helper: blank input collapses to "[]", and other input is only trimmed.
    #[test]
    fn normalize_fallback_models_trims_and_defaults() {
        assert_eq!(normalized_fallback_models_json(None), "[]");
        assert_eq!(normalized_fallback_models_json(Some("")), "[]");
        assert_eq!(normalized_fallback_models_json(Some("   \t")), "[]");
        assert_eq!(
            normalized_fallback_models_json(Some("  [\"m1\"]  ")),
            "[\"m1\"]"
        );
    }

    /// A whitespace-only value must also be stored as "[]", not as the blank text.
    #[test]
    fn upsert_normalizes_blank_fallback_models() {
        let db = test_db();
        let mut agent = sample_agent("fallback-blank");
        agent.fallback_models_json = Some("   ".into());
        upsert_sub_agent(&db, &agent).unwrap();
        let stored = list_sub_agents(&db).unwrap();
        assert_eq!(stored[0].fallback_models_json.as_deref(), Some("[]"));
    }
}