Skip to main content

jamjet_agents/
sqlite_registry.rs

1//! SQLite-backed agent registry.
2//!
3//! Uses the same database as the `jamjet-state` SQLite backend.
4//! The `agents` table is created by the `jamjet-state` migration.
5
6use crate::card::AgentCard;
7use crate::lifecycle::AgentStatus;
8use crate::registry::{Agent, AgentFilter, AgentId, AgentRegistry};
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use sqlx::{Row, SqlitePool};
12use tracing::instrument;
13use uuid::Uuid;
14
15pub struct SqliteAgentRegistry {
16    pool: SqlitePool,
17}
18
19impl SqliteAgentRegistry {
20    pub fn new(pool: SqlitePool) -> Self {
21        Self { pool }
22    }
23
24    /// Connect using an existing database URL (shared with state backend).
25    pub async fn connect(database_url: &str) -> Result<Self, sqlx::Error> {
26        use sqlx::sqlite::SqliteConnectOptions;
27        use std::str::FromStr;
28        let opts = SqliteConnectOptions::from_str(database_url)?.create_if_missing(true);
29        let pool = SqlitePool::connect_with(opts).await?;
30        Ok(Self { pool })
31    }
32}
33
34// ── helpers ───────────────────────────────────────────────────────────────────
35
36fn status_to_str(s: &AgentStatus) -> &'static str {
37    match s {
38        AgentStatus::Registered => "registered",
39        AgentStatus::Active => "active",
40        AgentStatus::Paused => "paused",
41        AgentStatus::Deactivated => "deactivated",
42        AgentStatus::Archived => "archived",
43    }
44}
45
46fn str_to_status(s: &str) -> Result<AgentStatus, String> {
47    match s {
48        "registered" => Ok(AgentStatus::Registered),
49        "active" => Ok(AgentStatus::Active),
50        "paused" => Ok(AgentStatus::Paused),
51        "deactivated" => Ok(AgentStatus::Deactivated),
52        "archived" => Ok(AgentStatus::Archived),
53        other => Err(format!("unknown agent status: {other}")),
54    }
55}
56
57fn parse_dt(s: &str) -> Result<DateTime<Utc>, String> {
58    DateTime::parse_from_rfc3339(s)
59        .map(|dt| dt.with_timezone(&Utc))
60        .map_err(|e| format!("bad datetime: {e}"))
61}
62
63fn row_to_agent(row: &sqlx::sqlite::SqliteRow) -> Result<Agent, String> {
64    let id = Uuid::parse_str(row.try_get::<&str, _>("id").map_err(|e| e.to_string())?)
65        .map_err(|e| e.to_string())?;
66    let card: AgentCard = serde_json::from_str(
67        row.try_get::<&str, _>("card_json")
68            .map_err(|e| e.to_string())?,
69    )
70    .map_err(|e| e.to_string())?;
71    let status = str_to_status(
72        row.try_get::<&str, _>("status")
73            .map_err(|e| e.to_string())?,
74    )?;
75    let registered_at = parse_dt(
76        row.try_get::<&str, _>("registered_at")
77            .map_err(|e| e.to_string())?,
78    )?;
79    let updated_at = parse_dt(
80        row.try_get::<&str, _>("updated_at")
81            .map_err(|e| e.to_string())?,
82    )?;
83    let last_heartbeat: Option<DateTime<Utc>> = row
84        .try_get::<Option<&str>, _>("last_heartbeat")
85        .map_err(|e| e.to_string())?
86        .map(parse_dt)
87        .transpose()?;
88
89    Ok(Agent {
90        id,
91        card,
92        status,
93        registered_at,
94        updated_at,
95        last_heartbeat,
96    })
97}
98
99// ── AgentRegistry impl ────────────────────────────────────────────────────────
100
101#[async_trait]
102impl AgentRegistry for SqliteAgentRegistry {
103    #[instrument(skip(self, card), fields(uri = %card.uri))]
104    async fn register(&self, card: AgentCard) -> Result<AgentId, String> {
105        let id = Uuid::new_v4();
106        let card_json = serde_json::to_string(&card).map_err(|e| e.to_string())?;
107        let now = Utc::now().to_rfc3339();
108
109        sqlx::query(
110            "INSERT INTO agents (id, uri, card_json, status, registered_at, updated_at) VALUES (?, ?, ?, 'registered', ?, ?)",
111        )
112        .bind(id.to_string())
113        .bind(&card.uri)
114        .bind(&card_json)
115        .bind(&now)
116        .bind(&now)
117        .execute(&self.pool)
118        .await
119        .map_err(|e| e.to_string())?;
120
121        Ok(id)
122    }
123
124    async fn get(&self, id: AgentId) -> Result<Option<Agent>, String> {
125        let id_str = id.to_string();
126        let row = sqlx::query("SELECT * FROM agents WHERE id = ?")
127            .bind(&id_str)
128            .fetch_optional(&self.pool)
129            .await
130            .map_err(|e| e.to_string())?;
131
132        row.map(|r| row_to_agent(&r)).transpose()
133    }
134
135    async fn get_by_uri(&self, uri: &str) -> Result<Option<Agent>, String> {
136        let row = sqlx::query("SELECT * FROM agents WHERE uri = ?")
137            .bind(uri)
138            .fetch_optional(&self.pool)
139            .await
140            .map_err(|e| e.to_string())?;
141
142        row.map(|r| row_to_agent(&r)).transpose()
143    }
144
145    async fn find(&self, filter: AgentFilter) -> Result<Vec<Agent>, String> {
146        // Build query dynamically based on filter fields.
147        // For Phase 1, status filter is enough; skill/protocol matching is in-memory.
148        let rows = match &filter.status {
149            Some(s) => {
150                sqlx::query("SELECT * FROM agents WHERE status = ? ORDER BY registered_at DESC")
151                    .bind(status_to_str(s))
152                    .fetch_all(&self.pool)
153                    .await
154                    .map_err(|e| e.to_string())?
155            }
156            None => sqlx::query("SELECT * FROM agents ORDER BY registered_at DESC")
157                .fetch_all(&self.pool)
158                .await
159                .map_err(|e| e.to_string())?,
160        };
161
162        let mut agents: Vec<Agent> = rows.iter().map(row_to_agent).collect::<Result<_, _>>()?;
163
164        // Post-filter by skill and protocol.
165        if let Some(skill) = &filter.skill {
166            agents.retain(|a| a.card.capabilities.skills.iter().any(|s| &s.name == skill));
167        }
168        if let Some(protocol) = &filter.protocol {
169            agents.retain(|a| a.card.capabilities.protocols.contains(protocol));
170        }
171
172        Ok(agents)
173    }
174
175    #[instrument(skip(self), fields(agent_id = %id))]
176    async fn update_status(&self, id: AgentId, status: AgentStatus) -> Result<(), String> {
177        let id_str = id.to_string();
178        let status_str = status_to_str(&status);
179        let now = Utc::now().to_rfc3339();
180
181        let rows = sqlx::query("UPDATE agents SET status = ?, updated_at = ? WHERE id = ?")
182            .bind(status_str)
183            .bind(&now)
184            .bind(&id_str)
185            .execute(&self.pool)
186            .await
187            .map_err(|e| e.to_string())?
188            .rows_affected();
189
190        if rows == 0 {
191            return Err(format!("agent {id} not found"));
192        }
193        Ok(())
194    }
195
196    #[instrument(skip(self), fields(agent_id = %id))]
197    async fn heartbeat(&self, id: AgentId) -> Result<(), String> {
198        let id_str = id.to_string();
199        let now = Utc::now().to_rfc3339();
200
201        sqlx::query("UPDATE agents SET last_heartbeat = ?, updated_at = ? WHERE id = ?")
202            .bind(&now)
203            .bind(&now)
204            .bind(&id_str)
205            .execute(&self.pool)
206            .await
207            .map_err(|e| e.to_string())?;
208
209        Ok(())
210    }
211
212    #[instrument(skip(self), fields(url = url))]
213    async fn discover_remote(&self, url: &str) -> Result<Agent, String> {
214        let agent_card_url = format!("{url}/.well-known/agent.json");
215        let card: AgentCard = reqwest::Client::new()
216            .get(&agent_card_url)
217            .send()
218            .await
219            .map_err(|e| format!("fetch Agent Card: {e}"))?
220            .json()
221            .await
222            .map_err(|e| format!("parse Agent Card: {e}"))?;
223
224        let id = self.register(card).await?;
225        self.get(id)
226            .await?
227            .ok_or_else(|| "agent not found after registration".into())
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::card::{AgentCapabilities, AuthSpec, AutonomyLevel};

    /// Builds a registry over a fresh in-memory SQLite database.
    ///
    /// The `agents` table is created by running the migrations under
    /// `../state/migrations` against the same pool the registry uses.
    async fn open_test_registry() -> SqliteAgentRegistry {
        let pool = SqlitePool::connect("sqlite::memory:").await.expect("pool");
        sqlx::migrate!("../state/migrations")
            .run(&pool)
            .await
            .expect("migrations");
        SqliteAgentRegistry::new(pool)
    }

    /// Minimal card with the given URI, one protocol and no skills.
    fn sample_card(uri: &str) -> AgentCard {
        AgentCard {
            id: uuid::Uuid::new_v4().to_string(),
            uri: uri.to_string(),
            name: "Test Agent".into(),
            description: "A test agent".into(),
            version: "1.0.0".into(),
            capabilities: AgentCapabilities {
                skills: vec![],
                protocols: vec!["mcp_client".into()],
                tools_provided: vec![],
                tools_consumed: vec![],
            },
            autonomy: AutonomyLevel::Guided,
            constraints: None,
            auth: AuthSpec::None,
            latency_class: None,
            cost_class: None,
            reasoning_modes: vec![],
            labels: Default::default(),
        }
    }

    #[tokio::test]
    async fn test_register_and_get() {
        let reg = open_test_registry().await;
        let card = sample_card("jamjet://test/agent1");
        let id = reg.register(card).await.unwrap();
        let agent = reg.get(id).await.unwrap().unwrap();
        assert_eq!(agent.card.uri, "jamjet://test/agent1");
        assert_eq!(agent.status, AgentStatus::Registered);
    }

    #[tokio::test]
    async fn test_status_transition() {
        let reg = open_test_registry().await;
        let id = reg
            .register(sample_card("jamjet://test/agent2"))
            .await
            .unwrap();
        reg.update_status(id, AgentStatus::Active).await.unwrap();
        let agent = reg.get(id).await.unwrap().unwrap();
        assert_eq!(agent.status, AgentStatus::Active);
    }

    #[tokio::test]
    async fn test_find_by_status() {
        let reg = open_test_registry().await;
        let id1 = reg.register(sample_card("jamjet://test/a3")).await.unwrap();
        let _id2 = reg.register(sample_card("jamjet://test/a4")).await.unwrap();
        reg.update_status(id1, AgentStatus::Active).await.unwrap();

        let active = reg
            .find(AgentFilter {
                status: Some(AgentStatus::Active),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].status, AgentStatus::Active);
    }

    #[tokio::test]
    async fn test_heartbeat_records_timestamp() {
        let reg = open_test_registry().await;
        let id = reg.register(sample_card("jamjet://test/a5")).await.unwrap();
        reg.heartbeat(id).await.unwrap();
        let agent = reg.get(id).await.unwrap().unwrap();
        assert!(agent.last_heartbeat.is_some());
    }
}
316}