Skip to main content

claw_spawn/infrastructure/
repository.rs

1use crate::domain::{Account, Bot, BotStatus, Droplet, Persona, StoredBotConfig, SubscriptionTier};
2use async_trait::async_trait;
3use chrono::Utc;
4use sqlx::{PgPool, Row};
5use std::str::FromStr;
6use thiserror::Error;
7use uuid::Uuid;
8
9#[derive(Error, Debug)]
10pub enum RepositoryError {
11    #[error("Database error: {0}")]
12    DatabaseError(#[from] sqlx::Error),
13    #[error("Not found: {0}")]
14    NotFound(String),
15    #[error("Invalid data: {0}")]
16    InvalidData(String),
17}
18
19#[async_trait]
20pub trait AccountRepository: Send + Sync {
21    #[must_use]
22    async fn create(&self, account: &Account) -> Result<(), RepositoryError>;
23    #[must_use]
24    async fn get_by_id(&self, id: Uuid) -> Result<Account, RepositoryError>;
25    #[must_use]
26    async fn get_by_external_id(&self, external_id: &str) -> Result<Account, RepositoryError>;
27    #[must_use]
28    async fn update_subscription(
29        &self,
30        id: Uuid,
31        tier: SubscriptionTier,
32    ) -> Result<(), RepositoryError>;
33}
34
35#[async_trait]
36pub trait BotRepository: Send + Sync {
37    #[must_use]
38    async fn create(&self, bot: &Bot) -> Result<(), RepositoryError>;
39    #[must_use]
40    async fn get_by_id(&self, id: Uuid) -> Result<Bot, RepositoryError>;
41    #[must_use]
42    async fn get_by_id_with_token(&self, id: Uuid, token: &str) -> Result<Bot, RepositoryError>;
43    #[must_use]
44    async fn list_by_account(&self, account_id: Uuid) -> Result<Vec<Bot>, RepositoryError>;
45    /// PERF-002: Paginated list of bots for account
46    /// Use limit/offset for pagination instead of loading all bots
47    #[must_use]
48    async fn list_by_account_paginated(
49        &self,
50        account_id: Uuid,
51        limit: i64,
52        offset: i64,
53    ) -> Result<Vec<Bot>, RepositoryError>;
54    /// PERF-001: Count bots for account without fetching all rows
55    /// Use SQL COUNT(*) instead of list_by_account().len()
56    #[must_use]
57    async fn count_by_account(&self, account_id: Uuid) -> Result<i64, RepositoryError>;
58    #[must_use]
59    async fn update_status(&self, id: Uuid, status: BotStatus) -> Result<(), RepositoryError>;
60    #[must_use]
61    async fn update_droplet(
62        &self,
63        bot_id: Uuid,
64        droplet_id: Option<i64>,
65    ) -> Result<(), RepositoryError>;
66    #[must_use]
67    async fn update_config_version(
68        &self,
69        bot_id: Uuid,
70        desired: Option<Uuid>,
71        applied: Option<Uuid>,
72    ) -> Result<(), RepositoryError>;
73    #[must_use]
74    async fn update_heartbeat(&self, bot_id: Uuid) -> Result<(), RepositoryError>;
75    #[must_use]
76    async fn update_registration_token(
77        &self,
78        bot_id: Uuid,
79        token: &str,
80    ) -> Result<(), RepositoryError>;
81    #[must_use]
82    async fn delete(&self, id: Uuid) -> Result<(), RepositoryError>;
83    /// Atomically increment bot counter for account, returning (success, current_count, max_count)
84    /// CRIT-002: Prevents race conditions in account limit checking
85    #[must_use]
86    async fn increment_bot_counter(
87        &self,
88        account_id: Uuid,
89    ) -> Result<(bool, i32, i32), RepositoryError>;
90    /// Decrement bot counter when bot is destroyed
91    #[must_use]
92    async fn decrement_bot_counter(&self, account_id: Uuid) -> Result<(), RepositoryError>;
93    /// List bots with stale heartbeats (HIGH-001)
94    #[must_use]
95    async fn list_stale_bots(
96        &self,
97        threshold: chrono::DateTime<chrono::Utc>,
98    ) -> Result<Vec<Bot>, RepositoryError>;
99}
100
101#[async_trait]
102pub trait ConfigRepository: Send + Sync {
103    #[must_use]
104    async fn create(&self, config: &StoredBotConfig) -> Result<(), RepositoryError>;
105    #[must_use]
106    async fn get_by_id(&self, id: Uuid) -> Result<StoredBotConfig, RepositoryError>;
107    #[must_use]
108    async fn get_latest_for_bot(
109        &self,
110        bot_id: Uuid,
111    ) -> Result<Option<StoredBotConfig>, RepositoryError>;
112    #[must_use]
113    async fn list_by_bot(&self, bot_id: Uuid) -> Result<Vec<StoredBotConfig>, RepositoryError>;
114    /// Get next config version atomically using advisory locks
115    /// CRIT-007: Prevents duplicate version numbers under concurrent updates
116    #[must_use]
117    async fn get_next_version_atomic(&self, bot_id: Uuid) -> Result<i32, RepositoryError>;
118}
119
120#[async_trait]
121pub trait DropletRepository: Send + Sync {
122    #[must_use]
123    async fn create(&self, droplet: &Droplet) -> Result<(), RepositoryError>;
124    #[must_use]
125    async fn get_by_id(&self, id: i64) -> Result<Droplet, RepositoryError>;
126    #[must_use]
127    async fn update_bot_assignment(
128        &self,
129        droplet_id: i64,
130        bot_id: Option<Uuid>,
131    ) -> Result<(), RepositoryError>;
132    #[must_use]
133    async fn update_status(&self, droplet_id: i64, status: &str) -> Result<(), RepositoryError>;
134    #[must_use]
135    async fn update_ip(&self, droplet_id: i64, ip: Option<String>) -> Result<(), RepositoryError>;
136    #[must_use]
137    async fn mark_destroyed(&self, droplet_id: i64) -> Result<(), RepositoryError>;
138}
139
140pub struct PostgresAccountRepository {
141    pool: PgPool,
142}
143
144impl PostgresAccountRepository {
145    pub fn new(pool: PgPool) -> Self {
146        Self { pool }
147    }
148}
149
150#[async_trait]
151impl AccountRepository for PostgresAccountRepository {
152    async fn create(&self, account: &Account) -> Result<(), RepositoryError> {
153        let tier_str = match account.subscription_tier {
154            SubscriptionTier::Free => "free",
155            SubscriptionTier::Basic => "basic",
156            SubscriptionTier::Pro => "pro",
157        };
158
159        sqlx::query(
160            r#"
161            INSERT INTO accounts (id, external_id, subscription_tier, max_bots, created_at, updated_at)
162            VALUES ($1, $2, $3, $4, $5, $6)
163            "#,
164        )
165        .bind(account.id)
166        .bind(&account.external_id)
167        .bind(tier_str)
168        .bind(account.max_bots)
169        .bind(account.created_at)
170        .bind(account.updated_at)
171        .execute(&self.pool)
172        .await?;
173
174        Ok(())
175    }
176
177    async fn get_by_id(&self, id: Uuid) -> Result<Account, RepositoryError> {
178        let row = sqlx::query(
179            r#"
180            SELECT id, external_id, subscription_tier, max_bots, created_at, updated_at
181            FROM accounts
182            WHERE id = $1
183            "#,
184        )
185        .bind(id)
186        .fetch_one(&self.pool)
187        .await
188        .map_err(|e| match e {
189            sqlx::Error::RowNotFound => RepositoryError::NotFound(format!("Account {}", id)),
190            _ => RepositoryError::DatabaseError(e),
191        })?;
192
193        Ok(row_to_account(&row)?)
194    }
195
196    async fn get_by_external_id(&self, external_id: &str) -> Result<Account, RepositoryError> {
197        let row = sqlx::query(
198            r#"
199            SELECT id, external_id, subscription_tier, max_bots, created_at, updated_at
200            FROM accounts
201            WHERE external_id = $1
202            "#,
203        )
204        .bind(external_id)
205        .fetch_one(&self.pool)
206        .await
207        .map_err(|e| match e {
208            sqlx::Error::RowNotFound => {
209                RepositoryError::NotFound(format!("Account {}", external_id))
210            }
211            _ => RepositoryError::DatabaseError(e),
212        })?;
213
214        Ok(row_to_account(&row)?)
215    }
216
217    async fn update_subscription(
218        &self,
219        id: Uuid,
220        tier: SubscriptionTier,
221    ) -> Result<(), RepositoryError> {
222        let tier_str = match tier {
223            SubscriptionTier::Free => "free",
224            SubscriptionTier::Basic => "basic",
225            SubscriptionTier::Pro => "pro",
226        };
227
228        let max_bots = match tier {
229            SubscriptionTier::Free => 0,
230            SubscriptionTier::Basic => 2,
231            SubscriptionTier::Pro => 4,
232        };
233
234        sqlx::query(
235            r#"
236            UPDATE accounts
237            SET subscription_tier = $1, max_bots = $2, updated_at = $3
238            WHERE id = $4
239            "#,
240        )
241        .bind(tier_str)
242        .bind(max_bots)
243        .bind(Utc::now())
244        .bind(id)
245        .execute(&self.pool)
246        .await?;
247
248        Ok(())
249    }
250}
251
252fn row_to_account(row: &sqlx::postgres::PgRow) -> Result<Account, RepositoryError> {
253    let tier_str: String = row.try_get("subscription_tier")?;
254    let tier = match tier_str.as_str() {
255        "free" => SubscriptionTier::Free,
256        "basic" => SubscriptionTier::Basic,
257        "pro" => SubscriptionTier::Pro,
258        _ => {
259            return Err(RepositoryError::InvalidData(format!(
260                "Unknown tier: {}",
261                tier_str
262            )))
263        }
264    };
265
266    Ok(Account {
267        id: row.try_get("id")?,
268        external_id: row.try_get("external_id")?,
269        subscription_tier: tier,
270        max_bots: row.try_get("max_bots")?,
271        created_at: row.try_get("created_at")?,
272        updated_at: row.try_get("updated_at")?,
273    })
274}
275
276pub struct PostgresBotRepository {
277    pool: PgPool,
278}
279
280impl PostgresBotRepository {
281    pub fn new(pool: PgPool) -> Self {
282        Self { pool }
283    }
284}
285
286#[async_trait]
287impl BotRepository for PostgresBotRepository {
288    async fn create(&self, bot: &Bot) -> Result<(), RepositoryError> {
289        let status_str = bot.status.to_string();
290        let persona_str = bot.persona.to_string();
291
292        sqlx::query(
293            r#"
294            INSERT INTO bots (id, account_id, name, persona, status, droplet_id, 
295                             desired_config_version_id, applied_config_version_id, 
296                             registration_token, created_at, updated_at, last_heartbeat_at)
297            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
298            "#,
299        )
300        .bind(bot.id)
301        .bind(bot.account_id)
302        .bind(&bot.name)
303        .bind(persona_str)
304        .bind(status_str)
305        .bind(bot.droplet_id)
306        .bind(bot.desired_config_version_id)
307        .bind(bot.applied_config_version_id)
308        .bind(&bot.registration_token)
309        .bind(bot.created_at)
310        .bind(bot.updated_at)
311        .bind(bot.last_heartbeat_at)
312        .execute(&self.pool)
313        .await?;
314
315        Ok(())
316    }
317
318    async fn get_by_id(&self, id: Uuid) -> Result<Bot, RepositoryError> {
319        let row = sqlx::query(
320            r#"
321            SELECT id, account_id, name, persona, status, droplet_id,
322                   desired_config_version_id, applied_config_version_id,
323                   registration_token, created_at, updated_at, last_heartbeat_at
324            FROM bots
325            WHERE id = $1
326            "#,
327        )
328        .bind(id)
329        .fetch_one(&self.pool)
330        .await
331        .map_err(|e| match e {
332            sqlx::Error::RowNotFound => RepositoryError::NotFound(format!("Bot {}", id)),
333            _ => RepositoryError::DatabaseError(e),
334        })?;
335
336        Ok(row_to_bot(&row)?)
337    }
338
339    async fn get_by_id_with_token(&self, id: Uuid, token: &str) -> Result<Bot, RepositoryError> {
340        let row = sqlx::query(
341            r#"
342            SELECT id, account_id, name, persona, status, droplet_id,
343                   desired_config_version_id, applied_config_version_id,
344                   registration_token, created_at, updated_at, last_heartbeat_at
345            FROM bots
346            WHERE id = $1 AND registration_token = $2
347            "#,
348        )
349        .bind(id)
350        .bind(token)
351        .fetch_one(&self.pool)
352        .await
353        .map_err(|e| match e {
354            sqlx::Error::RowNotFound => {
355                RepositoryError::NotFound(format!("Bot {} with invalid token", id))
356            }
357            _ => RepositoryError::DatabaseError(e),
358        })?;
359
360        Ok(row_to_bot(&row)?)
361    }
362
363    async fn list_by_account(&self, account_id: Uuid) -> Result<Vec<Bot>, RepositoryError> {
364        let rows = sqlx::query(
365            r#"
366            SELECT id, account_id, name, persona, status, droplet_id,
367                   desired_config_version_id, applied_config_version_id,
368                   registration_token, created_at, updated_at, last_heartbeat_at
369            FROM bots
370            WHERE account_id = $1
371            ORDER BY created_at DESC
372            "#,
373        )
374        .bind(account_id)
375        .fetch_all(&self.pool)
376        .await?;
377
378        rows.iter().map(row_to_bot).collect()
379    }
380
381    async fn count_by_account(&self, account_id: Uuid) -> Result<i64, RepositoryError> {
382        let count: i64 = sqlx::query_scalar(
383            r#"
384            SELECT COUNT(*) 
385            FROM bots 
386            WHERE account_id = $1
387            "#,
388        )
389        .bind(account_id)
390        .fetch_one(&self.pool)
391        .await?;
392
393        Ok(count)
394    }
395
396    async fn list_by_account_paginated(
397        &self,
398        account_id: Uuid,
399        limit: i64,
400        offset: i64,
401    ) -> Result<Vec<Bot>, RepositoryError> {
402        let rows = sqlx::query(
403            r#"
404            SELECT id, account_id, name, persona, status, droplet_id,
405                   desired_config_version_id, applied_config_version_id,
406                   registration_token, created_at, updated_at, last_heartbeat_at
407            FROM bots
408            WHERE account_id = $1
409            ORDER BY created_at DESC
410            LIMIT $2 OFFSET $3
411            "#,
412        )
413        .bind(account_id)
414        .bind(limit)
415        .bind(offset)
416        .fetch_all(&self.pool)
417        .await?;
418
419        rows.iter().map(row_to_bot).collect()
420    }
421
422    async fn update_status(&self, id: Uuid, status: BotStatus) -> Result<(), RepositoryError> {
423        let status_str = status.to_string();
424
425        sqlx::query(
426            r#"
427            UPDATE bots
428            SET status = $1, updated_at = $2
429            WHERE id = $3
430            "#,
431        )
432        .bind(status_str)
433        .bind(Utc::now())
434        .bind(id)
435        .execute(&self.pool)
436        .await?;
437
438        Ok(())
439    }
440
441    async fn update_droplet(
442        &self,
443        bot_id: Uuid,
444        droplet_id: Option<i64>,
445    ) -> Result<(), RepositoryError> {
446        sqlx::query(
447            r#"
448            UPDATE bots
449            SET droplet_id = $1, updated_at = $2
450            WHERE id = $3
451            "#,
452        )
453        .bind(droplet_id)
454        .bind(Utc::now())
455        .bind(bot_id)
456        .execute(&self.pool)
457        .await?;
458
459        Ok(())
460    }
461
462    async fn update_config_version(
463        &self,
464        bot_id: Uuid,
465        desired: Option<Uuid>,
466        applied: Option<Uuid>,
467    ) -> Result<(), RepositoryError> {
468        sqlx::query(
469            r#"
470            UPDATE bots
471            SET desired_config_version_id = $1, applied_config_version_id = $2, updated_at = $3
472            WHERE id = $4
473            "#,
474        )
475        .bind(desired)
476        .bind(applied)
477        .bind(Utc::now())
478        .bind(bot_id)
479        .execute(&self.pool)
480        .await?;
481
482        Ok(())
483    }
484
485    async fn update_heartbeat(&self, bot_id: Uuid) -> Result<(), RepositoryError> {
486        sqlx::query(
487            r#"
488            UPDATE bots
489            SET last_heartbeat_at = $1, updated_at = $2
490            WHERE id = $3
491            "#,
492        )
493        .bind(Utc::now())
494        .bind(Utc::now())
495        .bind(bot_id)
496        .execute(&self.pool)
497        .await?;
498
499        Ok(())
500    }
501
502    async fn update_registration_token(
503        &self,
504        bot_id: Uuid,
505        token: &str,
506    ) -> Result<(), RepositoryError> {
507        sqlx::query(
508            r#"
509            UPDATE bots
510            SET registration_token = $1, updated_at = $2
511            WHERE id = $3
512            "#,
513        )
514        .bind(token)
515        .bind(Utc::now())
516        .bind(bot_id)
517        .execute(&self.pool)
518        .await?;
519
520        Ok(())
521    }
522
523    async fn delete(&self, id: Uuid) -> Result<(), RepositoryError> {
524        sqlx::query(
525            r#"
526            UPDATE bots
527            SET status = 'destroyed', updated_at = $1
528            WHERE id = $2
529            "#,
530        )
531        .bind(Utc::now())
532        .bind(id)
533        .execute(&self.pool)
534        .await?;
535
536        Ok(())
537    }
538
539    async fn increment_bot_counter(
540        &self,
541        account_id: Uuid,
542    ) -> Result<(bool, i32, i32), RepositoryError> {
543        let row = sqlx::query(
544            r#"
545            SELECT success, current_count, max_count
546            FROM increment_bot_counter($1)
547            "#,
548        )
549        .bind(account_id)
550        .fetch_one(&self.pool)
551        .await
552        .map_err(|e| match e {
553            sqlx::Error::RowNotFound => {
554                // Counter doesn't exist yet - query current state
555                RepositoryError::NotFound(format!("Account counter for {}", account_id))
556            }
557            _ => RepositoryError::DatabaseError(e),
558        })?;
559
560        let success: bool = row.try_get("success")?;
561        let current_count: i32 = row.try_get("current_count")?;
562        let max_count: i32 = row.try_get("max_count")?;
563
564        Ok((success, current_count, max_count))
565    }
566
567    async fn decrement_bot_counter(&self, account_id: Uuid) -> Result<(), RepositoryError> {
568        sqlx::query("SELECT decrement_bot_counter($1)")
569            .bind(account_id)
570            .execute(&self.pool)
571            .await?;
572
573        Ok(())
574    }
575
576    async fn list_stale_bots(
577        &self,
578        threshold: chrono::DateTime<chrono::Utc>,
579    ) -> Result<Vec<Bot>, RepositoryError> {
580        let rows = sqlx::query(
581            r#"
582            SELECT id, account_id, name, persona, status, droplet_id,
583                   desired_config_version_id, applied_config_version_id,
584                   registration_token, created_at, updated_at, last_heartbeat_at
585            FROM bots
586            WHERE status = 'online'
587              AND (last_heartbeat_at < $1 OR last_heartbeat_at IS NULL)
588            "#,
589        )
590        .bind(threshold)
591        .fetch_all(&self.pool)
592        .await?;
593
594        rows.iter().map(row_to_bot).collect()
595    }
596}
597
598// MED-007: Status and persona mapping now handled by strum derive macros
599// BotStatus and Persona enums use #[derive(Display, EnumString)] for automatic
600// String <-> Enum conversion with snake_case serialization.
601
602fn row_to_bot(row: &sqlx::postgres::PgRow) -> Result<Bot, RepositoryError> {
603    let status_str: String = row.try_get("status")?;
604    let persona_str: String = row.try_get("persona")?;
605
606    Ok(Bot {
607        id: row.try_get("id")?,
608        account_id: row.try_get("account_id")?,
609        name: row.try_get("name")?,
610        persona: Persona::from_str(&persona_str).map_err(|_| {
611            RepositoryError::InvalidData(format!("Unknown persona: {}", persona_str))
612        })?,
613        status: BotStatus::from_str(&status_str)
614            .map_err(|_| RepositoryError::InvalidData(format!("Unknown status: {}", status_str)))?,
615        droplet_id: row.try_get("droplet_id")?,
616        desired_config_version_id: row.try_get("desired_config_version_id")?,
617        applied_config_version_id: row.try_get("applied_config_version_id")?,
618        registration_token: row.try_get("registration_token")?,
619        created_at: row.try_get("created_at")?,
620        updated_at: row.try_get("updated_at")?,
621        last_heartbeat_at: row.try_get("last_heartbeat_at")?,
622    })
623}