// agentics_persistence/db/pioneer_codes.rs
//! Persistence for pioneer-code gated registration of humans and agents.
2
3use chrono::{DateTime, Utc};
4use sqlx::{PgPool, Postgres, Row, Transaction};
5use uuid::Uuid;
6
7use agentics_domain::models::ids::{AdminServiceTokenId, AgentId, HumanId, PioneerCodeId};
8use agentics_domain::models::pioneer_codes::{
9    INVALID_OR_UNAVAILABLE_PIONEER_CODE, PioneerCodeSubjectKind,
10};
11use agentics_error::{Result, ServiceError};
12
/// Registration flow that consumed a pioneer code.
///
/// Persisted as text in `pioneer_code_uses.registration_kind`; see
/// [`PioneerCodeRegistrationKind::as_str`] for the stored values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PioneerCodeRegistrationKind {
    /// Human account creation during GitHub sign-in callback.
    HumanGithubSignIn,
    /// Direct agent registration through `/api/agents/register`.
    AgentApi,
}

impl PioneerCodeRegistrationKind {
    /// Storage value written to `pioneer_code_uses.registration_kind` for this flow.
    ///
    /// These strings are persisted, so renaming one would leave existing rows holding the
    /// old value.
    pub fn as_str(self) -> &'static str {
        match self {
            PioneerCodeRegistrationKind::AgentApi => "agent_api",
            PioneerCodeRegistrationKind::HumanGithubSignIn => "human_github_sign_in",
        }
    }
}
31
32/// Input used by admins to create a pioneer code.
33#[derive(Debug, Clone)]
34pub struct CreatePioneerCodeInput {
35    pub id: PioneerCodeId,
36    pub code_display: String,
37    pub code_hash: String,
38    pub label: Option<String>,
39    pub note: String,
40    pub max_uses: i64,
41    pub expires_at: Option<DateTime<Utc>>,
42    pub created_by_human_id: Option<HumanId>,
43    pub created_by_admin_service_token_id: Option<AdminServiceTokenId>,
44    pub created_by_display: String,
45}
46
47/// Persisted pioneer-code row returned to admins.
48#[derive(Debug, Clone)]
49pub struct PioneerCodeRecord {
50    pub id: PioneerCodeId,
51    pub code_display: String,
52    pub label: Option<String>,
53    pub note: String,
54    pub max_uses: i64,
55    pub use_count: i64,
56    pub status: String,
57    pub expires_at: Option<DateTime<Utc>>,
58    pub created_by_display: String,
59    pub created_at: DateTime<Utc>,
60    pub revoked_at: Option<DateTime<Utc>>,
61}
62
63/// One account that was created with a pioneer code.
64#[derive(Debug, Clone)]
65pub struct PioneerCodeUseRecord {
66    pub subject_kind: PioneerCodeSubjectKind,
67    pub human_id: Option<HumanId>,
68    pub human_github_login: Option<String>,
69    pub agent_id: Option<AgentId>,
70    pub agent_display_name: Option<String>,
71    pub registration_kind: String,
72    pub used_at: DateTime<Utc>,
73}
74
/// Result returned after revoking a pioneer code and disabling derived agents.
///
/// Each field is the number of rows the corresponding revocation step changed.
#[derive(Clone, Debug)]
pub struct RevokePioneerCodeOutcome {
    /// Humans reset to `setup_required`.
    pub revoked_human_count: i64,
    /// Human sessions deleted.
    pub revoked_human_session_count: i64,
    /// Admin service tokens created by those humans that were revoked.
    pub revoked_admin_service_token_count: i64,
    /// Creator API tokens created by those humans that were revoked.
    pub revoked_creator_api_token_count: i64,
    /// Agents disabled.
    pub revoked_agent_count: i64,
    /// Agent tokens revoked.
    pub revoked_token_count: i64,
}
85
86/// Insert a newly generated or admin-supplied pioneer code.
87pub async fn create_pioneer_code(
88    pool: &PgPool,
89    input: &CreatePioneerCodeInput,
90) -> Result<PioneerCodeRecord> {
91    let row = sqlx::query(
92        r#"
93        INSERT INTO pioneer_codes (
94            id,
95            code_display,
96            code_hash,
97            label,
98            note,
99            max_uses,
100            expires_at,
101            created_by_human_id,
102            created_by_admin_service_token_id,
103            created_by_display
104        )
105        VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8::uuid, $9::uuid, $10)
106        RETURNING
107            id::text AS id,
108            code_display,
109            label,
110            note,
111            max_uses,
112            use_count,
113            status,
114            expires_at,
115            created_by_display,
116            created_at,
117            revoked_at
118        "#,
119    )
120    .bind(input.id.as_str())
121    .bind(&input.code_display)
122    .bind(&input.code_hash)
123    .bind(&input.label)
124    .bind(&input.note)
125    .bind(input.max_uses)
126    .bind(input.expires_at)
127    .bind(input.created_by_human_id.as_ref().map(HumanId::as_str))
128    .bind(
129        input
130            .created_by_admin_service_token_id
131            .as_ref()
132            .map(AdminServiceTokenId::as_str),
133    )
134    .bind(&input.created_by_display)
135    .fetch_one(pool)
136    .await?;
137
138    pioneer_code_record_from_row(&row)
139}
140
141/// List pioneer codes for the admin console.
142pub async fn list_pioneer_codes(pool: &PgPool) -> Result<Vec<PioneerCodeRecord>> {
143    let rows = sqlx::query(
144        r#"
145        SELECT
146            id::text AS id,
147            code_display,
148            label,
149            note,
150            max_uses,
151            use_count,
152            status,
153            expires_at,
154            created_by_display,
155            created_at,
156            revoked_at
157        FROM pioneer_codes
158        ORDER BY created_at DESC
159        "#,
160    )
161    .fetch_all(pool)
162    .await?;
163
164    rows.iter().map(pioneer_code_record_from_row).collect()
165}
166
167/// Fetch a pioneer code and all agents created through it.
168pub async fn get_pioneer_code_detail(
169    pool: &PgPool,
170    id: &PioneerCodeId,
171) -> Result<(PioneerCodeRecord, Vec<PioneerCodeUseRecord>)> {
172    let code_row = sqlx::query(
173        r#"
174        SELECT
175            id::text AS id,
176            code_display,
177            label,
178            note,
179            max_uses,
180            use_count,
181            status,
182            expires_at,
183            created_by_display,
184            created_at,
185            revoked_at
186        FROM pioneer_codes
187        WHERE id = $1::uuid
188        "#,
189    )
190    .bind(id.as_str())
191    .fetch_optional(pool)
192    .await?
193    .ok_or(ServiceError::NotFound)?;
194
195    let use_rows = sqlx::query(
196        r#"
197        SELECT
198            u.subject_kind,
199            u.human_id::text AS human_id,
200            h_e.provider_login AS human_github_login,
201            u.agent_id::text AS agent_id,
202            a.display_name AS agent_display_name,
203            u.registration_kind,
204            u.used_at
205        FROM pioneer_code_uses u
206        LEFT JOIN agents a ON a.id = u.agent_id
207        LEFT JOIN human_external_identities h_e ON h_e.human_id = u.human_id AND h_e.provider = 'github'
208        WHERE u.pioneer_code_id = $1::uuid
209        ORDER BY u.used_at DESC
210        "#,
211    )
212    .bind(id.as_str())
213    .fetch_all(pool)
214    .await?;
215
216    let uses = use_rows
217        .iter()
218        .map(pioneer_code_use_from_row)
219        .collect::<Result<_>>()?;
220    Ok((pioneer_code_record_from_row(&code_row)?, uses))
221}
222
223/// Verify that a pioneer code can currently start a registration flow.
224pub async fn ensure_pioneer_code_available(pool: &PgPool, code_hash: &str) -> Result<()> {
225    let row = sqlx::query(
226        r#"
227        SELECT max_uses, use_count, status, expires_at
228        FROM pioneer_codes
229        WHERE code_hash = $1
230        "#,
231    )
232    .bind(code_hash)
233    .fetch_optional(pool)
234    .await?;
235
236    let Some(row) = row else {
237        return Err(unavailable_pioneer_code());
238    };
239
240    let status: String = row.try_get("status")?;
241    let expires_at: Option<DateTime<Utc>> = row.try_get("expires_at")?;
242    let max_uses: i64 = row.try_get("max_uses")?;
243    let use_count: i64 = row.try_get("use_count")?;
244    if status != "active"
245        || expires_at.is_some_and(|expires_at| Utc::now() >= expires_at)
246        || (max_uses != -1 && use_count >= max_uses)
247    {
248        return Err(unavailable_pioneer_code());
249    }
250
251    Ok(())
252}
253
254/// Consume a pioneer code inside the transaction that creates the agent.
255pub async fn consume_pioneer_code_for_agent_tx(
256    tx: &mut Transaction<'_, Postgres>,
257    code_hash: &str,
258    agent_id: &str,
259    registration_kind: PioneerCodeRegistrationKind,
260) -> Result<()> {
261    let row = sqlx::query(
262        r#"
263        SELECT id::text AS id, max_uses, use_count, status, expires_at
264        FROM pioneer_codes
265        WHERE code_hash = $1
266        FOR UPDATE
267        "#,
268    )
269    .bind(code_hash)
270    .fetch_optional(&mut **tx)
271    .await?;
272
273    let Some(row) = row else {
274        return Err(unavailable_pioneer_code());
275    };
276
277    let status: String = row.try_get("status")?;
278    let expires_at: Option<DateTime<Utc>> = row.try_get("expires_at")?;
279    let max_uses: i64 = row.try_get("max_uses")?;
280    let use_count: i64 = row.try_get("use_count")?;
281    if status != "active"
282        || expires_at.is_some_and(|expires_at| Utc::now() >= expires_at)
283        || (max_uses != -1 && use_count >= max_uses)
284    {
285        return Err(unavailable_pioneer_code());
286    }
287
288    let pioneer_code_id: String = row.try_get("id")?;
289    sqlx::query(
290        r#"
291        INSERT INTO pioneer_code_uses (
292            id,
293            pioneer_code_id,
294            subject_kind,
295            agent_id,
296            registration_kind
297        )
298        VALUES ($1::uuid, $2::uuid, 'agent', $3::uuid, $4)
299        "#,
300    )
301    .bind(Uuid::new_v4().to_string())
302    .bind(&pioneer_code_id)
303    .bind(agent_id)
304    .bind(registration_kind.as_str())
305    .execute(&mut **tx)
306    .await?;
307
308    sqlx::query("UPDATE pioneer_codes SET use_count = use_count + 1 WHERE id = $1::uuid")
309        .bind(&pioneer_code_id)
310        .execute(&mut **tx)
311        .await?;
312
313    Ok(())
314}
315
316/// Consume a pioneer code inside the transaction that creates the human.
317pub async fn consume_pioneer_code_for_human_tx(
318    tx: &mut Transaction<'_, Postgres>,
319    code_hash: &str,
320    human_id: &str,
321) -> Result<()> {
322    let row = sqlx::query(
323        r#"
324        SELECT id::text AS id, max_uses, use_count, status, expires_at
325        FROM pioneer_codes
326        WHERE code_hash = $1
327        FOR UPDATE
328        "#,
329    )
330    .bind(code_hash)
331    .fetch_optional(&mut **tx)
332    .await?;
333
334    let Some(row) = row else {
335        return Err(unavailable_pioneer_code());
336    };
337
338    let status: String = row.try_get("status")?;
339    let expires_at: Option<DateTime<Utc>> = row.try_get("expires_at")?;
340    let max_uses: i64 = row.try_get("max_uses")?;
341    let use_count: i64 = row.try_get("use_count")?;
342    if status != "active"
343        || expires_at.is_some_and(|expires_at| Utc::now() >= expires_at)
344        || (max_uses != -1 && use_count >= max_uses)
345    {
346        return Err(unavailable_pioneer_code());
347    }
348
349    let pioneer_code_id: String = row.try_get("id")?;
350    sqlx::query(
351        r#"
352        INSERT INTO pioneer_code_uses (
353            id,
354            pioneer_code_id,
355            subject_kind,
356            human_id,
357            registration_kind
358        )
359        VALUES ($1::uuid, $2::uuid, 'human', $3::uuid, $4)
360        "#,
361    )
362    .bind(Uuid::new_v4().to_string())
363    .bind(&pioneer_code_id)
364    .bind(human_id)
365    .bind(PioneerCodeRegistrationKind::HumanGithubSignIn.as_str())
366    .execute(&mut **tx)
367    .await?;
368
369    sqlx::query("UPDATE pioneer_codes SET use_count = use_count + 1 WHERE id = $1::uuid")
370        .bind(&pioneer_code_id)
371        .execute(&mut **tx)
372        .await?;
373
374    Ok(())
375}
376
377/// Revoke a pioneer code, rescind human setup, and disable agents created through it.
378pub async fn revoke_pioneer_code(
379    pool: &PgPool,
380    id: &PioneerCodeId,
381) -> Result<RevokePioneerCodeOutcome> {
382    let mut tx = pool.begin().await?;
383
384    let row = sqlx::query(
385        r#"
386        UPDATE pioneer_codes
387        SET status = 'revoked',
388            revoked_at = COALESCE(revoked_at, NOW())
389        WHERE id = $1::uuid
390        RETURNING id
391        "#,
392    )
393    .bind(id.as_str())
394    .fetch_optional(&mut *tx)
395    .await?;
396    if row.is_none() {
397        return Err(ServiceError::NotFound);
398    }
399
400    let agent_id_rows = sqlx::query(
401        r#"
402        SELECT agent_id
403        FROM pioneer_code_uses
404        WHERE pioneer_code_id = $1::uuid
405          AND agent_id IS NOT NULL
406        "#,
407    )
408    .bind(id.as_str())
409    .fetch_all(&mut *tx)
410    .await?;
411    let agent_ids = agent_id_rows
412        .iter()
413        .map(|row| row.try_get::<Uuid, _>("agent_id"))
414        .collect::<std::result::Result<Vec<_>, _>>()?;
415
416    let human_id_rows = sqlx::query(
417        r#"
418        SELECT human_id
419        FROM pioneer_code_uses
420        WHERE pioneer_code_id = $1::uuid
421          AND human_id IS NOT NULL
422        "#,
423    )
424    .bind(id.as_str())
425    .fetch_all(&mut *tx)
426    .await?;
427    let human_ids = human_id_rows
428        .iter()
429        .map(|row| row.try_get::<Uuid, _>("human_id"))
430        .collect::<std::result::Result<Vec<_>, _>>()?;
431
432    let revoked_human_count = if human_ids.is_empty() {
433        0
434    } else {
435        let result = sqlx::query(
436            r#"
437            UPDATE humans
438            SET status = 'setup_required',
439                disabled_at = NULL,
440                deleted_at = NULL
441            WHERE id = ANY($1::uuid[])
442              AND status NOT IN ('disabled', 'deleted')
443            "#,
444        )
445        .bind(&human_ids)
446        .execute(&mut *tx)
447        .await?;
448        i64::try_from(result.rows_affected())
449            .map_err(|_| ServiceError::Internal("revoked human count overflow".to_string()))?
450    };
451
452    if !human_ids.is_empty() {
453        sqlx::query(
454            r#"
455            UPDATE human_roles
456            SET revoked_at = COALESCE(revoked_at, NOW())
457            WHERE human_id = ANY($1::uuid[])
458              AND role = 'creator'
459              AND revoked_at IS NULL
460            "#,
461        )
462        .bind(&human_ids)
463        .execute(&mut *tx)
464        .await?;
465    }
466
467    let revoked_human_session_count = if human_ids.is_empty() {
468        0
469    } else {
470        let result = sqlx::query(
471            r#"
472            DELETE FROM human_sessions
473            WHERE human_id = ANY($1::uuid[])
474            "#,
475        )
476        .bind(&human_ids)
477        .execute(&mut *tx)
478        .await?;
479        i64::try_from(result.rows_affected()).map_err(|_| {
480            ServiceError::Internal("revoked human session count overflow".to_string())
481        })?
482    };
483
484    let revoked_admin_service_token_count = if human_ids.is_empty() {
485        0
486    } else {
487        let result = sqlx::query(
488            r#"
489            UPDATE admin_service_tokens
490            SET status = 'revoked',
491                revoked_at = COALESCE(revoked_at, NOW())
492            WHERE created_by_human_id = ANY($1::uuid[])
493              AND status = 'active'
494            "#,
495        )
496        .bind(&human_ids)
497        .execute(&mut *tx)
498        .await?;
499        i64::try_from(result.rows_affected()).map_err(|_| {
500            ServiceError::Internal("revoked admin service token count overflow".to_string())
501        })?
502    };
503
504    let revoked_creator_api_token_count = if human_ids.is_empty() {
505        0
506    } else {
507        let result = sqlx::query(
508            r#"
509            UPDATE creator_api_tokens
510            SET status = 'revoked',
511                revoked_at = COALESCE(revoked_at, NOW())
512            WHERE created_by_human_id = ANY($1::uuid[])
513              AND status = 'active'
514            "#,
515        )
516        .bind(&human_ids)
517        .execute(&mut *tx)
518        .await?;
519        i64::try_from(result.rows_affected()).map_err(|_| {
520            ServiceError::Internal("revoked creator API token count overflow".to_string())
521        })?
522    };
523
524    let revoked_agent_count = if agent_ids.is_empty() {
525        0
526    } else {
527        let result = sqlx::query(
528            r#"
529            UPDATE agents
530            SET status = 'disabled'
531            WHERE id = ANY($1::uuid[])
532              AND status = 'active'
533            "#,
534        )
535        .bind(&agent_ids)
536        .execute(&mut *tx)
537        .await?;
538        i64::try_from(result.rows_affected())
539            .map_err(|_| ServiceError::Internal("revoked agent count overflow".to_string()))?
540    };
541
542    let revoked_token_count = if agent_ids.is_empty() {
543        0
544    } else {
545        let result = sqlx::query(
546            r#"
547            UPDATE agent_tokens
548            SET revoked_at = COALESCE(revoked_at, NOW())
549            WHERE agent_id = ANY($1::uuid[])
550              AND revoked_at IS NULL
551            "#,
552        )
553        .bind(&agent_ids)
554        .execute(&mut *tx)
555        .await?;
556        i64::try_from(result.rows_affected())
557            .map_err(|_| ServiceError::Internal("revoked token count overflow".to_string()))?
558    };
559
560    tx.commit().await?;
561
562    Ok(RevokePioneerCodeOutcome {
563        revoked_human_count,
564        revoked_human_session_count,
565        revoked_admin_service_token_count,
566        revoked_creator_api_token_count,
567        revoked_agent_count,
568        revoked_token_count,
569    })
570}
571
572/// Convert a unavailable-code condition into the public generic error.
573fn unavailable_pioneer_code() -> ServiceError {
574    ServiceError::Forbidden(INVALID_OR_UNAVAILABLE_PIONEER_CODE.to_string())
575}
576
577/// Parse a pioneer-code row into the typed DB record.
578fn pioneer_code_record_from_row(row: &sqlx::postgres::PgRow) -> Result<PioneerCodeRecord> {
579    let id: String = row.try_get("id")?;
580    Ok(PioneerCodeRecord {
581        id: PioneerCodeId::try_new(id)
582            .map_err(|e| ServiceError::Internal(format!("stored invalid pioneer code id: {e}")))?,
583        code_display: row.try_get("code_display")?,
584        label: row.try_get("label")?,
585        note: row.try_get("note")?,
586        max_uses: row.try_get("max_uses")?,
587        use_count: row.try_get("use_count")?,
588        status: row.try_get("status")?,
589        expires_at: row.try_get("expires_at")?,
590        created_by_display: row.try_get("created_by_display")?,
591        created_at: row.try_get("created_at")?,
592        revoked_at: row.try_get("revoked_at")?,
593    })
594}
595
596/// Parse a pioneer-code use row into the typed DB record.
597fn pioneer_code_use_from_row(row: &sqlx::postgres::PgRow) -> Result<PioneerCodeUseRecord> {
598    let subject_kind: String = row.try_get("subject_kind")?;
599    let subject_kind =
600        PioneerCodeSubjectKind::from_storage_value(&subject_kind).ok_or_else(|| {
601            ServiceError::Internal(format!(
602                "stored invalid pioneer-code subject `{subject_kind}`"
603            ))
604        })?;
605    let human_id = row
606        .try_get::<Option<String>, _>("human_id")?
607        .map(HumanId::try_new)
608        .transpose()
609        .map_err(|e| {
610            ServiceError::Internal(format!("stored invalid pioneer-code human id: {e}"))
611        })?;
612    let agent_id = row
613        .try_get::<Option<String>, _>("agent_id")?
614        .map(AgentId::try_new)
615        .transpose()
616        .map_err(|e| {
617            ServiceError::Internal(format!("stored invalid pioneer-code agent id: {e}"))
618        })?;
619    Ok(PioneerCodeUseRecord {
620        subject_kind,
621        human_id,
622        human_github_login: row.try_get("human_github_login")?,
623        agent_id,
624        agent_display_name: row.try_get("agent_display_name")?,
625        registration_kind: row.try_get("registration_kind")?,
626        used_at: row.try_get("used_at")?,
627    })
628}