1use roboticus_core::{RoboticusError, Result};
2use std::collections::HashMap;
3
4use crate::{Database, DbResultExt};
5
/// One row of the `sub_agents` table.
///
/// The `*_json` fields carry raw JSON text exactly as stored; this type does
/// not parse or validate them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SubAgentRow {
    /// Row identifier. Upserts resolve conflicts on `name`, so an existing row
    /// keeps the `id` it was first inserted with.
    pub id: String,
    /// Name used as the conflict key for upserts and as the key for deletion.
    pub name: String,
    /// Optional human-readable label.
    pub display_name: Option<String>,
    /// Primary model identifier for this agent.
    pub model: String,
    /// JSON text for the fallback model list. On write, a missing or blank
    /// value is stored as `"[]"`; rows read back by `list_sub_agents` always
    /// carry `Some(..)`.
    pub fallback_models_json: Option<String>,
    /// Role label (the tests use `"specialist"`).
    pub role: String,
    /// Optional free-text description.
    pub description: Option<String>,
    /// Optional JSON text describing the agent's skills; stored verbatim.
    pub skills_json: Option<String>,
    /// Whether the agent is enabled; persisted as an INTEGER 0/1.
    pub enabled: bool,
    /// Session counter stored on the row; an upsert overwrites it.
    pub session_count: i64,
}
19
/// Canonicalizes a fallback-model JSON value before storage or after loading.
///
/// Surrounding whitespace is trimmed. An absent or whitespace-only value
/// becomes the empty JSON array `"[]"`. Anything else is returned as-is, and
/// no JSON validation is performed here.
fn normalized_fallback_models_json(raw: Option<&str>) -> String {
    raw.map(str::trim)
        .filter(|trimmed| !trimmed.is_empty())
        .map_or_else(|| String::from("[]"), str::to_owned)
}
26
27pub fn upsert_sub_agent(db: &Database, agent: &SubAgentRow) -> Result<()> {
28 let conn = db.conn();
29 let fallback_models_json =
30 normalized_fallback_models_json(agent.fallback_models_json.as_deref());
31 conn.execute(
32 "INSERT INTO sub_agents (id, name, display_name, model, fallback_models_json, role, description, skills_json, enabled, session_count)
33 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
34 ON CONFLICT(name) DO UPDATE SET
35 display_name = excluded.display_name,
36 model = excluded.model,
37 fallback_models_json = excluded.fallback_models_json,
38 role = excluded.role,
39 description = excluded.description,
40 skills_json = excluded.skills_json,
41 enabled = excluded.enabled,
42 session_count = excluded.session_count",
43 rusqlite::params![
44 agent.id,
45 agent.name,
46 agent.display_name,
47 agent.model,
48 fallback_models_json,
49 agent.role,
50 agent.description,
51 agent.skills_json,
52 agent.enabled as i32,
53 agent.session_count,
54 ],
55 )
56 .map_err(|e| RoboticusError::Database(format!("upsert sub_agent: {e}")))?;
57 Ok(())
58}
59
60pub fn list_sub_agents(db: &Database) -> Result<Vec<SubAgentRow>> {
61 let conn = db.conn();
62 let mut stmt = conn
63 .prepare(
64 "SELECT id, name, display_name, model, fallback_models_json, role, description, skills_json, enabled, session_count
65 FROM sub_agents ORDER BY name",
66 )
67 .db_err()?;
68
69 let rows = stmt
70 .query_map([], |row| {
71 Ok(SubAgentRow {
72 id: row.get(0)?,
73 name: row.get(1)?,
74 display_name: row.get(2)?,
75 model: row.get(3)?,
76 fallback_models_json: Some(normalized_fallback_models_json(
77 row.get::<_, Option<String>>(4)?.as_deref(),
78 )),
79 role: row.get(5)?,
80 description: row.get(6)?,
81 skills_json: row.get(7)?,
82 enabled: row.get::<_, i32>(8)? != 0,
83 session_count: row.get(9)?,
84 })
85 })
86 .db_err()?
87 .collect::<std::result::Result<Vec<_>, _>>()
88 .db_err()?;
89
90 Ok(rows)
91}
92
93pub fn list_enabled_sub_agents(db: &Database) -> Result<Vec<SubAgentRow>> {
94 let all = list_sub_agents(db)?;
95 Ok(all.into_iter().filter(|a| a.enabled).collect())
96}
97
98pub fn list_session_counts_by_agent(db: &Database) -> Result<HashMap<String, i64>> {
99 let conn = db.conn();
100 let mut stmt = conn
101 .prepare("SELECT agent_id, COUNT(*) FROM sessions GROUP BY agent_id")
102 .db_err()?;
103
104 let rows = stmt
105 .query_map([], |row| {
106 let agent_id: String = row.get(0)?;
107 let count: i64 = row.get(1)?;
108 Ok((agent_id, count))
109 })
110 .db_err()?
111 .collect::<std::result::Result<Vec<_>, _>>()
112 .db_err()?;
113
114 Ok(rows.into_iter().collect())
115}
116
117pub fn delete_sub_agent(db: &Database, name: &str) -> Result<bool> {
118 let conn = db.conn();
119 let deleted = conn
120 .execute(
121 "DELETE FROM sub_agents WHERE name = ?1",
122 rusqlite::params![name],
123 )
124 .db_err()?;
125 Ok(deleted > 0)
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory database for each test.
    fn test_db() -> Database {
        Database::new(":memory:").unwrap()
    }

    /// Builds an enabled agent with a random id and an empty fallback list.
    fn sample_agent(name: &str) -> SubAgentRow {
        SubAgentRow {
            id: uuid::Uuid::new_v4().to_string(),
            name: name.to_string(),
            display_name: Some(name.replace('-', " ")),
            model: "test-model".into(),
            fallback_models_json: Some("[]".into()),
            role: "specialist".into(),
            description: Some("Test agent".into()),
            skills_json: None,
            enabled: true,
            session_count: 0,
        }
    }

    #[test]
    fn upsert_and_list() {
        let db = test_db();
        upsert_sub_agent(&db, &sample_agent("alpha")).unwrap();
        upsert_sub_agent(&db, &sample_agent("bravo")).unwrap();
        let agents = list_sub_agents(&db).unwrap();
        assert_eq!(agents.len(), 2);
        assert_eq!(agents[0].name, "alpha");
        assert_eq!(agents[1].name, "bravo");
    }

    #[test]
    fn upsert_updates_existing() {
        let db = test_db();
        let mut agent = sample_agent("alpha");
        upsert_sub_agent(&db, &agent).unwrap();
        agent.model = "updated-model".into();
        agent.session_count = 42;
        upsert_sub_agent(&db, &agent).unwrap();
        let agents = list_sub_agents(&db).unwrap();
        assert_eq!(agents.len(), 1);
        assert_eq!(agents[0].model, "updated-model");
        assert_eq!(agents[0].session_count, 42);
    }

    #[test]
    fn upsert_on_name_conflict_keeps_original_id() {
        let db = test_db();
        let first = sample_agent("alpha");
        upsert_sub_agent(&db, &first).unwrap();
        // Same name, different id: the conflict is on `name`, so the stored
        // id must stay the original one while other columns are replaced.
        let mut second = sample_agent("alpha");
        second.model = "replacement-model".into();
        upsert_sub_agent(&db, &second).unwrap();
        let agents = list_sub_agents(&db).unwrap();
        assert_eq!(agents.len(), 1);
        assert_eq!(agents[0].id, first.id);
        assert_eq!(agents[0].model, "replacement-model");
    }

    #[test]
    fn list_enabled_filters() {
        let db = test_db();
        let mut a = sample_agent("enabled-one");
        upsert_sub_agent(&db, &a).unwrap();
        a = sample_agent("disabled-one");
        a.enabled = false;
        upsert_sub_agent(&db, &a).unwrap();
        let enabled = list_enabled_sub_agents(&db).unwrap();
        assert_eq!(enabled.len(), 1);
        assert_eq!(enabled[0].name, "enabled-one");
    }

    #[test]
    fn delete_works() {
        let db = test_db();
        upsert_sub_agent(&db, &sample_agent("doomed")).unwrap();
        assert!(delete_sub_agent(&db, "doomed").unwrap());
        assert!(!delete_sub_agent(&db, "doomed").unwrap());
        assert!(list_sub_agents(&db).unwrap().is_empty());
    }

    #[test]
    fn session_counts_by_agent_reads_sessions_table() {
        let db = test_db();
        {
            let conn = db.conn();
            conn.execute(
                "INSERT INTO sessions (id, agent_id, scope_key, status) VALUES (?1, ?2, 'agent', 'active')",
                rusqlite::params!["s1", "alpha"],
            )
            .unwrap();
            conn.execute(
                "INSERT INTO sessions (id, agent_id, scope_key, status) VALUES (?1, ?2, 'agent', 'archived')",
                rusqlite::params!["s2", "alpha"],
            )
            .unwrap();
            conn.execute(
                "INSERT INTO sessions (id, agent_id, scope_key, status) VALUES (?1, ?2, 'agent', 'active')",
                rusqlite::params!["s3", "bravo"],
            )
            .unwrap();
        }

        let counts = list_session_counts_by_agent(&db).unwrap();
        assert_eq!(counts.get("alpha"), Some(&2));
        assert_eq!(counts.get("bravo"), Some(&1));
    }

    #[test]
    fn session_counts_empty_without_sessions() {
        let db = test_db();
        assert!(list_session_counts_by_agent(&db).unwrap().is_empty());
    }

    #[test]
    fn upsert_normalizes_missing_fallback_models() {
        let db = test_db();
        let mut agent = sample_agent("fallback-default");
        agent.fallback_models_json = None;
        upsert_sub_agent(&db, &agent).unwrap();
        let stored = list_sub_agents(&db).unwrap();
        assert_eq!(
            stored[0].fallback_models_json.as_deref(),
            Some("[]"),
            "missing fallback models should normalize to JSON empty array"
        );
    }

    #[test]
    fn upsert_normalizes_blank_fallback_models() {
        let db = test_db();
        let mut agent = sample_agent("fallback-blank");
        agent.fallback_models_json = Some("   ".into());
        upsert_sub_agent(&db, &agent).unwrap();
        let stored = list_sub_agents(&db).unwrap();
        assert_eq!(
            stored[0].fallback_models_json.as_deref(),
            Some("[]"),
            "whitespace-only fallback models should normalize to JSON empty array"
        );
    }

    #[test]
    fn upsert_trims_fallback_models_whitespace() {
        let db = test_db();
        let mut agent = sample_agent("fallback-trim");
        agent.fallback_models_json = Some("  [\"backup-model\"]\n".into());
        upsert_sub_agent(&db, &agent).unwrap();
        let stored = list_sub_agents(&db).unwrap();
        assert_eq!(
            stored[0].fallback_models_json.as_deref(),
            Some("[\"backup-model\"]"),
            "surrounding whitespace should be trimmed before storage"
        );
    }
}