1use crate::domain::{Account, Bot, BotStatus, Droplet, Persona, StoredBotConfig, SubscriptionTier};
2use async_trait::async_trait;
3use chrono::Utc;
4use sqlx::{PgPool, Row};
5use std::str::FromStr;
6use thiserror::Error;
7use uuid::Uuid;
8
9#[derive(Error, Debug)]
10pub enum RepositoryError {
11 #[error("Database error: {0}")]
12 DatabaseError(#[from] sqlx::Error),
13 #[error("Not found: {0}")]
14 NotFound(String),
15 #[error("Invalid data: {0}")]
16 InvalidData(String),
17}
18
19#[async_trait]
20pub trait AccountRepository: Send + Sync {
21 #[must_use]
22 async fn create(&self, account: &Account) -> Result<(), RepositoryError>;
23 #[must_use]
24 async fn get_by_id(&self, id: Uuid) -> Result<Account, RepositoryError>;
25 #[must_use]
26 async fn get_by_external_id(&self, external_id: &str) -> Result<Account, RepositoryError>;
27 #[must_use]
28 async fn update_subscription(
29 &self,
30 id: Uuid,
31 tier: SubscriptionTier,
32 ) -> Result<(), RepositoryError>;
33}
34
35#[async_trait]
36pub trait BotRepository: Send + Sync {
37 #[must_use]
38 async fn create(&self, bot: &Bot) -> Result<(), RepositoryError>;
39 #[must_use]
40 async fn get_by_id(&self, id: Uuid) -> Result<Bot, RepositoryError>;
41 #[must_use]
42 async fn get_by_id_with_token(&self, id: Uuid, token: &str) -> Result<Bot, RepositoryError>;
43 #[must_use]
44 async fn list_by_account(&self, account_id: Uuid) -> Result<Vec<Bot>, RepositoryError>;
45 #[must_use]
48 async fn list_by_account_paginated(
49 &self,
50 account_id: Uuid,
51 limit: i64,
52 offset: i64,
53 ) -> Result<Vec<Bot>, RepositoryError>;
54 #[must_use]
57 async fn count_by_account(&self, account_id: Uuid) -> Result<i64, RepositoryError>;
58 #[must_use]
59 async fn update_status(&self, id: Uuid, status: BotStatus) -> Result<(), RepositoryError>;
60 #[must_use]
61 async fn update_droplet(
62 &self,
63 bot_id: Uuid,
64 droplet_id: Option<i64>,
65 ) -> Result<(), RepositoryError>;
66 #[must_use]
67 async fn update_config_version(
68 &self,
69 bot_id: Uuid,
70 desired: Option<Uuid>,
71 applied: Option<Uuid>,
72 ) -> Result<(), RepositoryError>;
73 #[must_use]
74 async fn update_heartbeat(&self, bot_id: Uuid) -> Result<(), RepositoryError>;
75 #[must_use]
76 async fn update_registration_token(
77 &self,
78 bot_id: Uuid,
79 token: &str,
80 ) -> Result<(), RepositoryError>;
81 #[must_use]
82 async fn delete(&self, id: Uuid) -> Result<(), RepositoryError>;
83 #[must_use]
86 async fn increment_bot_counter(
87 &self,
88 account_id: Uuid,
89 ) -> Result<(bool, i32, i32), RepositoryError>;
90 #[must_use]
92 async fn decrement_bot_counter(&self, account_id: Uuid) -> Result<(), RepositoryError>;
93 #[must_use]
95 async fn list_stale_bots(
96 &self,
97 threshold: chrono::DateTime<chrono::Utc>,
98 ) -> Result<Vec<Bot>, RepositoryError>;
99}
100
101#[async_trait]
102pub trait ConfigRepository: Send + Sync {
103 #[must_use]
104 async fn create(&self, config: &StoredBotConfig) -> Result<(), RepositoryError>;
105 #[must_use]
106 async fn get_by_id(&self, id: Uuid) -> Result<StoredBotConfig, RepositoryError>;
107 #[must_use]
108 async fn get_latest_for_bot(
109 &self,
110 bot_id: Uuid,
111 ) -> Result<Option<StoredBotConfig>, RepositoryError>;
112 #[must_use]
113 async fn list_by_bot(&self, bot_id: Uuid) -> Result<Vec<StoredBotConfig>, RepositoryError>;
114 #[must_use]
117 async fn get_next_version_atomic(&self, bot_id: Uuid) -> Result<i32, RepositoryError>;
118}
119
120#[async_trait]
121pub trait DropletRepository: Send + Sync {
122 #[must_use]
123 async fn create(&self, droplet: &Droplet) -> Result<(), RepositoryError>;
124 #[must_use]
125 async fn get_by_id(&self, id: i64) -> Result<Droplet, RepositoryError>;
126 #[must_use]
127 async fn update_bot_assignment(
128 &self,
129 droplet_id: i64,
130 bot_id: Option<Uuid>,
131 ) -> Result<(), RepositoryError>;
132 #[must_use]
133 async fn update_status(&self, droplet_id: i64, status: &str) -> Result<(), RepositoryError>;
134 #[must_use]
135 async fn update_ip(&self, droplet_id: i64, ip: Option<String>) -> Result<(), RepositoryError>;
136 #[must_use]
137 async fn mark_destroyed(&self, droplet_id: i64) -> Result<(), RepositoryError>;
138}
139
140pub struct PostgresAccountRepository {
141 pool: PgPool,
142}
143
144impl PostgresAccountRepository {
145 pub fn new(pool: PgPool) -> Self {
146 Self { pool }
147 }
148}
149
150#[async_trait]
151impl AccountRepository for PostgresAccountRepository {
152 async fn create(&self, account: &Account) -> Result<(), RepositoryError> {
153 let tier_str = match account.subscription_tier {
154 SubscriptionTier::Free => "free",
155 SubscriptionTier::Basic => "basic",
156 SubscriptionTier::Pro => "pro",
157 };
158
159 sqlx::query(
160 r#"
161 INSERT INTO accounts (id, external_id, subscription_tier, max_bots, created_at, updated_at)
162 VALUES ($1, $2, $3, $4, $5, $6)
163 "#,
164 )
165 .bind(account.id)
166 .bind(&account.external_id)
167 .bind(tier_str)
168 .bind(account.max_bots)
169 .bind(account.created_at)
170 .bind(account.updated_at)
171 .execute(&self.pool)
172 .await?;
173
174 Ok(())
175 }
176
177 async fn get_by_id(&self, id: Uuid) -> Result<Account, RepositoryError> {
178 let row = sqlx::query(
179 r#"
180 SELECT id, external_id, subscription_tier, max_bots, created_at, updated_at
181 FROM accounts
182 WHERE id = $1
183 "#,
184 )
185 .bind(id)
186 .fetch_one(&self.pool)
187 .await
188 .map_err(|e| match e {
189 sqlx::Error::RowNotFound => RepositoryError::NotFound(format!("Account {}", id)),
190 _ => RepositoryError::DatabaseError(e),
191 })?;
192
193 Ok(row_to_account(&row)?)
194 }
195
196 async fn get_by_external_id(&self, external_id: &str) -> Result<Account, RepositoryError> {
197 let row = sqlx::query(
198 r#"
199 SELECT id, external_id, subscription_tier, max_bots, created_at, updated_at
200 FROM accounts
201 WHERE external_id = $1
202 "#,
203 )
204 .bind(external_id)
205 .fetch_one(&self.pool)
206 .await
207 .map_err(|e| match e {
208 sqlx::Error::RowNotFound => {
209 RepositoryError::NotFound(format!("Account {}", external_id))
210 }
211 _ => RepositoryError::DatabaseError(e),
212 })?;
213
214 Ok(row_to_account(&row)?)
215 }
216
217 async fn update_subscription(
218 &self,
219 id: Uuid,
220 tier: SubscriptionTier,
221 ) -> Result<(), RepositoryError> {
222 let tier_str = match tier {
223 SubscriptionTier::Free => "free",
224 SubscriptionTier::Basic => "basic",
225 SubscriptionTier::Pro => "pro",
226 };
227
228 let max_bots = match tier {
229 SubscriptionTier::Free => 0,
230 SubscriptionTier::Basic => 2,
231 SubscriptionTier::Pro => 4,
232 };
233
234 sqlx::query(
235 r#"
236 UPDATE accounts
237 SET subscription_tier = $1, max_bots = $2, updated_at = $3
238 WHERE id = $4
239 "#,
240 )
241 .bind(tier_str)
242 .bind(max_bots)
243 .bind(Utc::now())
244 .bind(id)
245 .execute(&self.pool)
246 .await?;
247
248 Ok(())
249 }
250}
251
252fn row_to_account(row: &sqlx::postgres::PgRow) -> Result<Account, RepositoryError> {
253 let tier_str: String = row.try_get("subscription_tier")?;
254 let tier = match tier_str.as_str() {
255 "free" => SubscriptionTier::Free,
256 "basic" => SubscriptionTier::Basic,
257 "pro" => SubscriptionTier::Pro,
258 _ => {
259 return Err(RepositoryError::InvalidData(format!(
260 "Unknown tier: {}",
261 tier_str
262 )))
263 }
264 };
265
266 Ok(Account {
267 id: row.try_get("id")?,
268 external_id: row.try_get("external_id")?,
269 subscription_tier: tier,
270 max_bots: row.try_get("max_bots")?,
271 created_at: row.try_get("created_at")?,
272 updated_at: row.try_get("updated_at")?,
273 })
274}
275
276pub struct PostgresBotRepository {
277 pool: PgPool,
278}
279
280impl PostgresBotRepository {
281 pub fn new(pool: PgPool) -> Self {
282 Self { pool }
283 }
284}
285
286#[async_trait]
287impl BotRepository for PostgresBotRepository {
288 async fn create(&self, bot: &Bot) -> Result<(), RepositoryError> {
289 let status_str = bot.status.to_string();
290 let persona_str = bot.persona.to_string();
291
292 sqlx::query(
293 r#"
294 INSERT INTO bots (id, account_id, name, persona, status, droplet_id,
295 desired_config_version_id, applied_config_version_id,
296 registration_token, created_at, updated_at, last_heartbeat_at)
297 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
298 "#,
299 )
300 .bind(bot.id)
301 .bind(bot.account_id)
302 .bind(&bot.name)
303 .bind(persona_str)
304 .bind(status_str)
305 .bind(bot.droplet_id)
306 .bind(bot.desired_config_version_id)
307 .bind(bot.applied_config_version_id)
308 .bind(&bot.registration_token)
309 .bind(bot.created_at)
310 .bind(bot.updated_at)
311 .bind(bot.last_heartbeat_at)
312 .execute(&self.pool)
313 .await?;
314
315 Ok(())
316 }
317
318 async fn get_by_id(&self, id: Uuid) -> Result<Bot, RepositoryError> {
319 let row = sqlx::query(
320 r#"
321 SELECT id, account_id, name, persona, status, droplet_id,
322 desired_config_version_id, applied_config_version_id,
323 registration_token, created_at, updated_at, last_heartbeat_at
324 FROM bots
325 WHERE id = $1
326 "#,
327 )
328 .bind(id)
329 .fetch_one(&self.pool)
330 .await
331 .map_err(|e| match e {
332 sqlx::Error::RowNotFound => RepositoryError::NotFound(format!("Bot {}", id)),
333 _ => RepositoryError::DatabaseError(e),
334 })?;
335
336 Ok(row_to_bot(&row)?)
337 }
338
339 async fn get_by_id_with_token(&self, id: Uuid, token: &str) -> Result<Bot, RepositoryError> {
340 let row = sqlx::query(
341 r#"
342 SELECT id, account_id, name, persona, status, droplet_id,
343 desired_config_version_id, applied_config_version_id,
344 registration_token, created_at, updated_at, last_heartbeat_at
345 FROM bots
346 WHERE id = $1 AND registration_token = $2
347 "#,
348 )
349 .bind(id)
350 .bind(token)
351 .fetch_one(&self.pool)
352 .await
353 .map_err(|e| match e {
354 sqlx::Error::RowNotFound => {
355 RepositoryError::NotFound(format!("Bot {} with invalid token", id))
356 }
357 _ => RepositoryError::DatabaseError(e),
358 })?;
359
360 Ok(row_to_bot(&row)?)
361 }
362
363 async fn list_by_account(&self, account_id: Uuid) -> Result<Vec<Bot>, RepositoryError> {
364 let rows = sqlx::query(
365 r#"
366 SELECT id, account_id, name, persona, status, droplet_id,
367 desired_config_version_id, applied_config_version_id,
368 registration_token, created_at, updated_at, last_heartbeat_at
369 FROM bots
370 WHERE account_id = $1
371 ORDER BY created_at DESC
372 "#,
373 )
374 .bind(account_id)
375 .fetch_all(&self.pool)
376 .await?;
377
378 rows.iter().map(row_to_bot).collect()
379 }
380
381 async fn count_by_account(&self, account_id: Uuid) -> Result<i64, RepositoryError> {
382 let count: i64 = sqlx::query_scalar(
383 r#"
384 SELECT COUNT(*)
385 FROM bots
386 WHERE account_id = $1
387 "#,
388 )
389 .bind(account_id)
390 .fetch_one(&self.pool)
391 .await?;
392
393 Ok(count)
394 }
395
396 async fn list_by_account_paginated(
397 &self,
398 account_id: Uuid,
399 limit: i64,
400 offset: i64,
401 ) -> Result<Vec<Bot>, RepositoryError> {
402 let rows = sqlx::query(
403 r#"
404 SELECT id, account_id, name, persona, status, droplet_id,
405 desired_config_version_id, applied_config_version_id,
406 registration_token, created_at, updated_at, last_heartbeat_at
407 FROM bots
408 WHERE account_id = $1
409 ORDER BY created_at DESC
410 LIMIT $2 OFFSET $3
411 "#,
412 )
413 .bind(account_id)
414 .bind(limit)
415 .bind(offset)
416 .fetch_all(&self.pool)
417 .await?;
418
419 rows.iter().map(row_to_bot).collect()
420 }
421
422 async fn update_status(&self, id: Uuid, status: BotStatus) -> Result<(), RepositoryError> {
423 let status_str = status.to_string();
424
425 sqlx::query(
426 r#"
427 UPDATE bots
428 SET status = $1, updated_at = $2
429 WHERE id = $3
430 "#,
431 )
432 .bind(status_str)
433 .bind(Utc::now())
434 .bind(id)
435 .execute(&self.pool)
436 .await?;
437
438 Ok(())
439 }
440
441 async fn update_droplet(
442 &self,
443 bot_id: Uuid,
444 droplet_id: Option<i64>,
445 ) -> Result<(), RepositoryError> {
446 sqlx::query(
447 r#"
448 UPDATE bots
449 SET droplet_id = $1, updated_at = $2
450 WHERE id = $3
451 "#,
452 )
453 .bind(droplet_id)
454 .bind(Utc::now())
455 .bind(bot_id)
456 .execute(&self.pool)
457 .await?;
458
459 Ok(())
460 }
461
462 async fn update_config_version(
463 &self,
464 bot_id: Uuid,
465 desired: Option<Uuid>,
466 applied: Option<Uuid>,
467 ) -> Result<(), RepositoryError> {
468 sqlx::query(
469 r#"
470 UPDATE bots
471 SET desired_config_version_id = $1, applied_config_version_id = $2, updated_at = $3
472 WHERE id = $4
473 "#,
474 )
475 .bind(desired)
476 .bind(applied)
477 .bind(Utc::now())
478 .bind(bot_id)
479 .execute(&self.pool)
480 .await?;
481
482 Ok(())
483 }
484
485 async fn update_heartbeat(&self, bot_id: Uuid) -> Result<(), RepositoryError> {
486 sqlx::query(
487 r#"
488 UPDATE bots
489 SET last_heartbeat_at = $1, updated_at = $2
490 WHERE id = $3
491 "#,
492 )
493 .bind(Utc::now())
494 .bind(Utc::now())
495 .bind(bot_id)
496 .execute(&self.pool)
497 .await?;
498
499 Ok(())
500 }
501
502 async fn update_registration_token(
503 &self,
504 bot_id: Uuid,
505 token: &str,
506 ) -> Result<(), RepositoryError> {
507 sqlx::query(
508 r#"
509 UPDATE bots
510 SET registration_token = $1, updated_at = $2
511 WHERE id = $3
512 "#,
513 )
514 .bind(token)
515 .bind(Utc::now())
516 .bind(bot_id)
517 .execute(&self.pool)
518 .await?;
519
520 Ok(())
521 }
522
523 async fn delete(&self, id: Uuid) -> Result<(), RepositoryError> {
524 sqlx::query(
525 r#"
526 UPDATE bots
527 SET status = 'destroyed', updated_at = $1
528 WHERE id = $2
529 "#,
530 )
531 .bind(Utc::now())
532 .bind(id)
533 .execute(&self.pool)
534 .await?;
535
536 Ok(())
537 }
538
539 async fn increment_bot_counter(
540 &self,
541 account_id: Uuid,
542 ) -> Result<(bool, i32, i32), RepositoryError> {
543 let row = sqlx::query(
544 r#"
545 SELECT success, current_count, max_count
546 FROM increment_bot_counter($1)
547 "#,
548 )
549 .bind(account_id)
550 .fetch_one(&self.pool)
551 .await
552 .map_err(|e| match e {
553 sqlx::Error::RowNotFound => {
554 RepositoryError::NotFound(format!("Account counter for {}", account_id))
556 }
557 _ => RepositoryError::DatabaseError(e),
558 })?;
559
560 let success: bool = row.try_get("success")?;
561 let current_count: i32 = row.try_get("current_count")?;
562 let max_count: i32 = row.try_get("max_count")?;
563
564 Ok((success, current_count, max_count))
565 }
566
567 async fn decrement_bot_counter(&self, account_id: Uuid) -> Result<(), RepositoryError> {
568 sqlx::query("SELECT decrement_bot_counter($1)")
569 .bind(account_id)
570 .execute(&self.pool)
571 .await?;
572
573 Ok(())
574 }
575
576 async fn list_stale_bots(
577 &self,
578 threshold: chrono::DateTime<chrono::Utc>,
579 ) -> Result<Vec<Bot>, RepositoryError> {
580 let rows = sqlx::query(
581 r#"
582 SELECT id, account_id, name, persona, status, droplet_id,
583 desired_config_version_id, applied_config_version_id,
584 registration_token, created_at, updated_at, last_heartbeat_at
585 FROM bots
586 WHERE status = 'online'
587 AND (last_heartbeat_at < $1 OR last_heartbeat_at IS NULL)
588 "#,
589 )
590 .bind(threshold)
591 .fetch_all(&self.pool)
592 .await?;
593
594 rows.iter().map(row_to_bot).collect()
595 }
596}
597
598fn row_to_bot(row: &sqlx::postgres::PgRow) -> Result<Bot, RepositoryError> {
603 let status_str: String = row.try_get("status")?;
604 let persona_str: String = row.try_get("persona")?;
605
606 Ok(Bot {
607 id: row.try_get("id")?,
608 account_id: row.try_get("account_id")?,
609 name: row.try_get("name")?,
610 persona: Persona::from_str(&persona_str).map_err(|_| {
611 RepositoryError::InvalidData(format!("Unknown persona: {}", persona_str))
612 })?,
613 status: BotStatus::from_str(&status_str)
614 .map_err(|_| RepositoryError::InvalidData(format!("Unknown status: {}", status_str)))?,
615 droplet_id: row.try_get("droplet_id")?,
616 desired_config_version_id: row.try_get("desired_config_version_id")?,
617 applied_config_version_id: row.try_get("applied_config_version_id")?,
618 registration_token: row.try_get("registration_token")?,
619 created_at: row.try_get("created_at")?,
620 updated_at: row.try_get("updated_at")?,
621 last_heartbeat_at: row.try_get("last_heartbeat_at")?,
622 })
623}