Skip to main content

engram/storage/
agent_registry.rs

1//! Agent registry storage queries
2//!
3//! Provides CRUD operations for the `agents` table introduced in schema v17.
4//!
5//! Agents represent registered AI agents with capabilities, namespaces,
6//! heartbeat tracking, and lifecycle status.
7
8use chrono::Utc;
9use rusqlite::{params, Connection, OptionalExtension};
10use serde::{Deserialize, Serialize};
11
12use crate::error::{EngramError, Result};
13
14/// A registered AI agent
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct Agent {
17    pub agent_id: String,
18    pub display_name: String,
19    pub capabilities: Vec<String>,
20    pub namespaces: Vec<String>,
21    pub last_heartbeat: Option<String>,
22    pub status: String,
23    pub metadata: serde_json::Value,
24    pub registered_at: String,
25    pub updated_at: String,
26}
27
28/// Input for registering a new agent or updating an existing one (upsert)
29#[derive(Debug, Clone)]
30pub struct RegisterAgentInput {
31    pub agent_id: String,
32    pub display_name: String,
33    pub capabilities: Vec<String>,
34    pub namespaces: Vec<String>,
35    pub metadata: serde_json::Value,
36}
37
38impl Default for RegisterAgentInput {
39    fn default() -> Self {
40        Self {
41            agent_id: String::new(),
42            display_name: String::new(),
43            capabilities: vec![],
44            namespaces: vec!["default".to_string()],
45            metadata: serde_json::Value::Object(serde_json::Map::new()),
46        }
47    }
48}
49
50/// Parse an Agent from a rusqlite row.
51///
52/// Columns expected in order: agent_id, display_name, capabilities, namespaces,
53/// last_heartbeat, status, metadata, registered_at, updated_at
54fn agent_from_row(row: &rusqlite::Row) -> rusqlite::Result<Agent> {
55    let capabilities_str: String = row.get(2)?;
56    let namespaces_str: String = row.get(3)?;
57    let metadata_str: String = row.get(6)?;
58
59    let capabilities: Vec<String> = serde_json::from_str(&capabilities_str).unwrap_or_default();
60    let namespaces: Vec<String> =
61        serde_json::from_str(&namespaces_str).unwrap_or_else(|_| vec!["default".to_string()]);
62    let metadata: serde_json::Value = serde_json::from_str(&metadata_str)
63        .unwrap_or(serde_json::Value::Object(Default::default()));
64
65    Ok(Agent {
66        agent_id: row.get(0)?,
67        display_name: row.get(1)?,
68        capabilities,
69        namespaces,
70        last_heartbeat: row.get(4)?,
71        status: row.get(5)?,
72        metadata,
73        registered_at: row.get(7)?,
74        updated_at: row.get(8)?,
75    })
76}
77
78/// Register a new agent, or update an existing one if the `agent_id` already exists.
79///
80/// On conflict, updates: display_name, capabilities, namespaces, metadata, updated_at.
81/// The `registered_at` timestamp is preserved from the original registration.
82pub fn register_agent(conn: &Connection, input: &RegisterAgentInput) -> Result<Agent> {
83    if input.agent_id.trim().is_empty() {
84        return Err(EngramError::InvalidInput(
85            "agent_id must not be empty".to_string(),
86        ));
87    }
88    if input.display_name.trim().is_empty() {
89        return Err(EngramError::InvalidInput(
90            "display_name must not be empty".to_string(),
91        ));
92    }
93
94    let now = Utc::now().to_rfc3339();
95    let capabilities_json = serde_json::to_string(&input.capabilities)?;
96    let namespaces_json = serde_json::to_string(&input.namespaces)?;
97    let metadata_json = serde_json::to_string(&input.metadata)?;
98
99    conn.execute(
100        r#"
101        INSERT INTO agents
102            (agent_id, display_name, capabilities, namespaces, status, metadata, registered_at, updated_at)
103        VALUES (?, ?, ?, ?, 'active', ?, ?, ?)
104        ON CONFLICT(agent_id) DO UPDATE SET
105            display_name = excluded.display_name,
106            capabilities = excluded.capabilities,
107            namespaces   = excluded.namespaces,
108            metadata     = excluded.metadata,
109            status       = 'active',
110            updated_at   = excluded.updated_at
111        "#,
112        params![
113            input.agent_id,
114            input.display_name,
115            capabilities_json,
116            namespaces_json,
117            metadata_json,
118            now,
119            now,
120        ],
121    )?;
122
123    get_agent(conn, &input.agent_id)?
124        .ok_or_else(|| EngramError::Storage("Agent not found after insert".to_string()))
125}
126
127/// Deregister an agent by setting its status to 'inactive'.
128///
129/// Returns `true` if the agent was found and deregistered, `false` if not found.
130pub fn deregister_agent(conn: &Connection, agent_id: &str) -> Result<bool> {
131    let now = Utc::now().to_rfc3339();
132
133    let affected = conn.execute(
134        "UPDATE agents SET status = 'inactive', updated_at = ? WHERE agent_id = ?",
135        params![now, agent_id],
136    )?;
137
138    Ok(affected > 0)
139}
140
141/// Update the heartbeat timestamp for an agent.
142///
143/// Returns the updated `Agent` if found, or `None` if the agent does not exist.
144pub fn heartbeat_agent(conn: &Connection, agent_id: &str) -> Result<Option<Agent>> {
145    let now = Utc::now().to_rfc3339();
146
147    let affected = conn.execute(
148        "UPDATE agents SET last_heartbeat = ?, updated_at = ? WHERE agent_id = ?",
149        params![now, now, agent_id],
150    )?;
151
152    if affected == 0 {
153        return Ok(None);
154    }
155
156    get_agent(conn, agent_id)
157}
158
159/// Retrieve a single agent by its ID.
160pub fn get_agent(conn: &Connection, agent_id: &str) -> Result<Option<Agent>> {
161    conn.query_row(
162        r#"
163        SELECT agent_id, display_name, capabilities, namespaces,
164               last_heartbeat, status, metadata, registered_at, updated_at
165        FROM agents WHERE agent_id = ?
166        "#,
167        params![agent_id],
168        agent_from_row,
169    )
170    .optional()
171    .map_err(EngramError::from)
172}
173
174/// List all agents, optionally filtered by status.
175///
176/// `status_filter` accepts values like `"active"` or `"inactive"`.
177/// Pass `None` to return all agents regardless of status.
178pub fn list_agents(conn: &Connection, status_filter: Option<&str>) -> Result<Vec<Agent>> {
179    let (sql, param_str): (&str, Option<String>) = match status_filter {
180        Some(s) => (
181            r#"
182            SELECT agent_id, display_name, capabilities, namespaces,
183                   last_heartbeat, status, metadata, registered_at, updated_at
184            FROM agents WHERE status = ?
185            ORDER BY registered_at DESC
186            "#,
187            Some(s.to_string()),
188        ),
189        None => (
190            r#"
191            SELECT agent_id, display_name, capabilities, namespaces,
192                   last_heartbeat, status, metadata, registered_at, updated_at
193            FROM agents
194            ORDER BY registered_at DESC
195            "#,
196            None,
197        ),
198    };
199
200    let mut stmt = conn.prepare(sql)?;
201
202    let agents = if let Some(ref status) = param_str {
203        stmt.query_map(params![status], agent_from_row)?
204            .filter_map(|r| r.ok())
205            .collect()
206    } else {
207        stmt.query_map([], agent_from_row)?
208            .filter_map(|r| r.ok())
209            .collect()
210    };
211
212    Ok(agents)
213}
214
215/// Update the capabilities list for an agent.
216///
217/// Returns the updated `Agent` if found, or `None` if the agent does not exist.
218pub fn update_agent_capabilities(
219    conn: &Connection,
220    agent_id: &str,
221    capabilities: &[String],
222) -> Result<Option<Agent>> {
223    let now = Utc::now().to_rfc3339();
224    let capabilities_json = serde_json::to_string(capabilities)?;
225
226    let affected = conn.execute(
227        "UPDATE agents SET capabilities = ?, updated_at = ? WHERE agent_id = ?",
228        params![capabilities_json, now, agent_id],
229    )?;
230
231    if affected == 0 {
232        return Ok(None);
233    }
234
235    get_agent(conn, agent_id)
236}
237
238/// List all active agents that belong to the given namespace.
239pub fn get_agents_in_namespace(conn: &Connection, namespace: &str) -> Result<Vec<Agent>> {
240    // SQLite JSON array membership: json_each returns rows for each element.
241    let mut stmt = conn.prepare(
242        r#"
243        SELECT a.agent_id, a.display_name, a.capabilities, a.namespaces,
244               a.last_heartbeat, a.status, a.metadata, a.registered_at, a.updated_at
245        FROM agents a
246        WHERE a.status = 'active'
247          AND EXISTS (
248              SELECT 1 FROM json_each(a.namespaces)
249              WHERE value = ?
250          )
251        ORDER BY a.registered_at DESC
252        "#,
253    )?;
254
255    let agents = stmt
256        .query_map(params![namespace], agent_from_row)?
257        .filter_map(|r| r.ok())
258        .collect();
259
260    Ok(agents)
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::migrations::run_migrations;

    /// Open a fresh in-memory database with all migrations applied.
    fn in_memory_conn() -> Connection {
        let conn = Connection::open_in_memory().expect("open in-memory db");
        run_migrations(&conn).expect("run migrations");
        conn
    }

    /// Valid registration input for `agent_id`, in the `default` namespace.
    fn basic_input(agent_id: &str) -> RegisterAgentInput {
        RegisterAgentInput {
            agent_id: agent_id.to_string(),
            display_name: "Test Agent".to_string(),
            capabilities: vec!["read".to_string(), "write".to_string()],
            namespaces: vec!["default".to_string()],
            metadata: serde_json::json!({"version": "1.0"}),
        }
    }

    #[test]
    fn test_register_and_get_agent() {
        let conn = in_memory_conn();
        let input = basic_input("agent-001");

        let agent = register_agent(&conn, &input).expect("register agent");
        assert_eq!(agent.agent_id, "agent-001");
        assert_eq!(agent.display_name, "Test Agent");
        assert_eq!(agent.capabilities, vec!["read", "write"]);
        assert_eq!(agent.namespaces, vec!["default"]);
        assert_eq!(agent.status, "active");
        assert!(agent.last_heartbeat.is_none());

        let fetched = get_agent(&conn, "agent-001")
            .expect("get agent")
            .expect("agent exists");
        assert_eq!(fetched.agent_id, agent.agent_id);
        assert_eq!(fetched.display_name, agent.display_name);
    }

    #[test]
    fn test_deregister_agent() {
        let conn = in_memory_conn();
        register_agent(&conn, &basic_input("agent-deregister")).expect("register");

        let found = deregister_agent(&conn, "agent-deregister").expect("deregister");
        assert!(found, "should return true for existing agent");

        let agent = get_agent(&conn, "agent-deregister")
            .expect("get")
            .expect("exists");
        assert_eq!(agent.status, "inactive");
    }

    #[test]
    fn test_heartbeat_updates_timestamp() {
        let conn = in_memory_conn();
        register_agent(&conn, &basic_input("agent-hb")).expect("register");

        let before = get_agent(&conn, "agent-hb").expect("get").expect("exists");
        assert!(before.last_heartbeat.is_none());

        let updated = heartbeat_agent(&conn, "agent-hb")
            .expect("heartbeat")
            .expect("agent found");
        assert!(
            updated.last_heartbeat.is_some(),
            "last_heartbeat should be set after heartbeat"
        );
    }

    #[test]
    fn test_list_agents_with_filter() {
        let conn = in_memory_conn();
        register_agent(&conn, &basic_input("agent-a1")).expect("register");
        register_agent(&conn, &basic_input("agent-a2")).expect("register");
        deregister_agent(&conn, "agent-a2").expect("deregister");

        let active = list_agents(&conn, Some("active")).expect("list active");
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].agent_id, "agent-a1");

        let inactive = list_agents(&conn, Some("inactive")).expect("list inactive");
        assert_eq!(inactive.len(), 1);
        assert_eq!(inactive[0].agent_id, "agent-a2");

        let all = list_agents(&conn, None).expect("list all");
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn test_update_capabilities() {
        let conn = in_memory_conn();
        register_agent(&conn, &basic_input("agent-caps")).expect("register");

        let updated = update_agent_capabilities(
            &conn,
            "agent-caps",
            &[
                "search".to_string(),
                "create".to_string(),
                "delete".to_string(),
            ],
        )
        .expect("update")
        .expect("found");

        assert_eq!(updated.capabilities, vec!["search", "create", "delete"]);
    }

    #[test]
    fn test_update_capabilities_nonexistent_returns_none() {
        let conn = in_memory_conn();

        let result = update_agent_capabilities(&conn, "ghost-agent", &["read".to_string()])
            .expect("no db error");
        assert!(
            result.is_none(),
            "updating a missing agent should return None"
        );
    }

    #[test]
    fn test_get_agents_in_namespace() {
        let conn = in_memory_conn();

        let mut input_a = basic_input("agent-ns1");
        input_a.namespaces = vec!["default".to_string(), "project-x".to_string()];
        register_agent(&conn, &input_a).expect("register a");

        let mut input_b = basic_input("agent-ns2");
        input_b.namespaces = vec!["project-x".to_string()];
        register_agent(&conn, &input_b).expect("register b");

        let mut input_c = basic_input("agent-ns3");
        input_c.namespaces = vec!["other".to_string()];
        register_agent(&conn, &input_c).expect("register c");

        let in_project_x = get_agents_in_namespace(&conn, "project-x").expect("query");
        let ids: Vec<&str> = in_project_x.iter().map(|a| a.agent_id.as_str()).collect();
        assert!(
            ids.contains(&"agent-ns1"),
            "agent-ns1 should be in project-x"
        );
        assert!(
            ids.contains(&"agent-ns2"),
            "agent-ns2 should be in project-x"
        );
        assert!(
            !ids.contains(&"agent-ns3"),
            "agent-ns3 should not be in project-x"
        );

        let in_default = get_agents_in_namespace(&conn, "default").expect("query default");
        assert_eq!(in_default.len(), 1);
        assert_eq!(in_default[0].agent_id, "agent-ns1");

        // Only active agents are listed: a deregistered member must disappear.
        deregister_agent(&conn, "agent-ns2").expect("deregister b");
        let still_active = get_agents_in_namespace(&conn, "project-x").expect("query again");
        assert_eq!(still_active.len(), 1);
        assert_eq!(still_active[0].agent_id, "agent-ns1");
    }

    #[test]
    fn test_register_duplicate_updates() {
        let conn = in_memory_conn();
        let first = register_agent(&conn, &basic_input("agent-dup")).expect("register first");

        let mut updated_input = basic_input("agent-dup");
        updated_input.display_name = "Updated Agent".to_string();
        updated_input.capabilities = vec!["admin".to_string()];
        let agent = register_agent(&conn, &updated_input).expect("register second (upsert)");

        assert_eq!(agent.display_name, "Updated Agent");
        assert_eq!(agent.capabilities, vec!["admin"]);
        assert_eq!(agent.status, "active");
        assert_eq!(
            agent.registered_at, first.registered_at,
            "registered_at must be preserved across an upsert"
        );

        // Only one row should exist
        let all = list_agents(&conn, None).expect("list");
        assert_eq!(all.len(), 1);
    }

    #[test]
    fn test_reregister_reactivates_agent() {
        let conn = in_memory_conn();
        register_agent(&conn, &basic_input("agent-back")).expect("register");
        deregister_agent(&conn, "agent-back").expect("deregister");

        let agent = register_agent(&conn, &basic_input("agent-back")).expect("re-register");
        assert_eq!(
            agent.status, "active",
            "re-registering should set status back to active"
        );
    }

    #[test]
    fn test_deregister_nonexistent() {
        let conn = in_memory_conn();

        let found = deregister_agent(&conn, "does-not-exist").expect("no db error");
        assert!(!found, "should return false for nonexistent agent");
    }

    #[test]
    fn test_heartbeat_nonexistent_returns_none() {
        let conn = in_memory_conn();

        let result = heartbeat_agent(&conn, "ghost-agent").expect("no db error");
        assert!(
            result.is_none(),
            "heartbeat on missing agent should return None"
        );
    }

    #[test]
    fn test_register_empty_agent_id_fails() {
        let conn = in_memory_conn();

        // Both a truly empty id and a whitespace-only id must be rejected.
        for bad_id in ["", "   "].iter() {
            let result = register_agent(&conn, &basic_input(bad_id));
            assert!(
                matches!(result, Err(EngramError::InvalidInput(_))),
                "agent_id {:?} should be rejected as invalid input",
                bad_id
            );
        }

        // A rejected registration must not leave a row behind.
        let all = list_agents(&conn, None).expect("list");
        assert!(all.is_empty());
    }

    #[test]
    fn test_register_blank_display_name_fails() {
        let conn = in_memory_conn();

        for bad_name in ["", "  \t "].iter() {
            let mut input = basic_input("agent-noname");
            input.display_name = bad_name.to_string();

            let result = register_agent(&conn, &input);
            assert!(
                matches!(result, Err(EngramError::InvalidInput(_))),
                "display_name {:?} should be rejected as invalid input",
                bad_name
            );
        }
    }
}
457}