// claw_spawn/infrastructure/postgres_config_repo.rs

1use crate::domain::{EncryptedBotSecrets, RiskConfig, StoredBotConfig, TradingConfig};
2use crate::infrastructure::{ConfigRepository, RepositoryError};
3use async_trait::async_trait;
4use sqlx::{PgPool, Row};
5use uuid::Uuid;
6
7pub struct PostgresConfigRepository {
8    pool: PgPool,
9}
10
11impl PostgresConfigRepository {
12    pub fn new(pool: PgPool) -> Self {
13        Self { pool }
14    }
15}
16
17#[async_trait]
18impl ConfigRepository for PostgresConfigRepository {
19    async fn create(&self, config: &StoredBotConfig) -> Result<(), RepositoryError> {
20        let trading_json = serde_json::to_value(&config.trading_config).map_err(|e| {
21            RepositoryError::InvalidData(format!("Failed to serialize trading config: {}", e))
22        })?;
23        let risk_json = serde_json::to_value(&config.risk_config).map_err(|e| {
24            RepositoryError::InvalidData(format!("Failed to serialize risk config: {}", e))
25        })?;
26
27        sqlx::query(
28            r#"
29            INSERT INTO bot_configs (id, bot_id, version, trading_config, risk_config, secrets_encrypted, llm_provider, created_at)
30            VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
31            "#,
32        )
33        .bind(config.id)
34        .bind(config.bot_id)
35        .bind(config.version)
36        .bind(trading_json)
37        .bind(risk_json)
38        .bind(&config.secrets.llm_api_key_encrypted)
39        .bind(&config.secrets.llm_provider)
40        .bind(config.created_at)
41        .execute(&self.pool)
42        .await?;
43
44        Ok(())
45    }
46
47    async fn get_by_id(&self, id: Uuid) -> Result<StoredBotConfig, RepositoryError> {
48        let row = sqlx::query(
49            r#"
50            SELECT id, bot_id, version, trading_config, risk_config, secrets_encrypted, llm_provider, created_at
51            FROM bot_configs
52            WHERE id = $1
53            "#,
54        )
55        .bind(id)
56        .fetch_one(&self.pool)
57        .await
58        .map_err(|e| match e {
59            sqlx::Error::RowNotFound => RepositoryError::NotFound(format!("Config {}", id)),
60            _ => RepositoryError::DatabaseError(e),
61        })?;
62
63        Ok(row_to_config(&row)?)
64    }
65
66    async fn get_latest_for_bot(
67        &self,
68        bot_id: Uuid,
69    ) -> Result<Option<StoredBotConfig>, RepositoryError> {
70        let row = sqlx::query(
71            r#"
72            SELECT id, bot_id, version, trading_config, risk_config, secrets_encrypted, llm_provider, created_at
73            FROM bot_configs
74            WHERE bot_id = $1
75            ORDER BY version DESC
76            LIMIT 1
77            "#,
78        )
79        .bind(bot_id)
80        .fetch_optional(&self.pool)
81        .await?;
82
83        match row {
84            Some(r) => Ok(Some(row_to_config(&r)?)),
85            None => Ok(None),
86        }
87    }
88
89    async fn list_by_bot(&self, bot_id: Uuid) -> Result<Vec<StoredBotConfig>, RepositoryError> {
90        let rows = sqlx::query(
91            r#"
92            SELECT id, bot_id, version, trading_config, risk_config, secrets_encrypted, llm_provider, created_at
93            FROM bot_configs
94            WHERE bot_id = $1
95            ORDER BY version ASC
96            "#,
97        )
98        .bind(bot_id)
99        .fetch_all(&self.pool)
100        .await?;
101
102        rows.iter().map(row_to_config).collect()
103    }
104
105    async fn get_next_version_atomic(&self, bot_id: Uuid) -> Result<i32, RepositoryError> {
106        let row = sqlx::query(
107            r#"
108            SELECT get_next_config_version_atomic($1) as version
109            "#,
110        )
111        .bind(bot_id)
112        .fetch_one(&self.pool)
113        .await?;
114
115        let version: i32 = row.try_get("version")?;
116        Ok(version)
117    }
118}
119
120fn row_to_config(row: &sqlx::postgres::PgRow) -> Result<StoredBotConfig, RepositoryError> {
121    let trading_json: serde_json::Value = row.try_get("trading_config")?;
122    let risk_json: serde_json::Value = row.try_get("risk_config")?;
123    let encrypted_secrets: Vec<u8> = row.try_get("secrets_encrypted")?;
124
125    let trading_config: TradingConfig = serde_json::from_value(trading_json).map_err(|e| {
126        RepositoryError::InvalidData(format!("Failed to deserialize trading config: {}", e))
127    })?;
128    let risk_config: RiskConfig = serde_json::from_value(risk_json).map_err(|e| {
129        RepositoryError::InvalidData(format!("Failed to deserialize risk config: {}", e))
130    })?;
131
132    let llm_provider: String = row.try_get("llm_provider")?;
133
134    Ok(StoredBotConfig {
135        id: row.try_get("id")?,
136        bot_id: row.try_get("bot_id")?,
137        version: row.try_get("version")?,
138        trading_config,
139        risk_config,
140        secrets: EncryptedBotSecrets {
141            llm_provider,
142            llm_api_key_encrypted: encrypted_secrets,
143        },
144        created_at: row.try_get("created_at")?,
145    })
146}