forge_core_db/models/
merge.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use sqlx::{FromRow, SqlitePool, Type};
4use ts_rs_forge::TS;
5use uuid::Uuid;
6
7#[derive(Debug, Clone, Serialize, Deserialize, TS, Type)]
8#[sqlx(type_name = "merge_status", rename_all = "snake_case")]
9#[serde(rename_all = "snake_case")]
10pub enum MergeStatus {
11    Open,
12    Merged,
13    Closed,
14    Unknown,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize, TS)]
18#[serde(tag = "type", rename_all = "snake_case")]
19pub enum Merge {
20    Direct(DirectMerge),
21    Pr(PrMerge),
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize, TS)]
25pub struct DirectMerge {
26    pub id: Uuid,
27    pub task_attempt_id: Uuid,
28    pub merge_commit: String,
29    pub target_branch_name: String,
30    pub created_at: DateTime<Utc>,
31}
32
33/// PR merge - represents a pull request merge
34#[derive(Debug, Clone, Serialize, Deserialize, TS)]
35pub struct PrMerge {
36    pub id: Uuid,
37    pub task_attempt_id: Uuid,
38    pub created_at: DateTime<Utc>,
39    pub target_branch_name: String,
40    pub pr_info: PullRequestInfo,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize, TS)]
44pub struct PullRequestInfo {
45    pub number: i64,
46    pub url: String,
47    pub status: MergeStatus,
48    pub merged_at: Option<chrono::DateTime<chrono::Utc>>,
49    pub merge_commit_sha: Option<String>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize, Type)]
53#[sqlx(type_name = "TEXT", rename_all = "snake_case")]
54pub enum MergeType {
55    Direct,
56    Pr,
57}
58
59#[derive(FromRow)]
60struct MergeRow {
61    id: Uuid,
62    task_attempt_id: Uuid,
63    merge_type: MergeType,
64    merge_commit: Option<String>,
65    target_branch_name: String,
66    pr_number: Option<i64>,
67    pr_url: Option<String>,
68    pr_status: Option<MergeStatus>,
69    pr_merged_at: Option<DateTime<Utc>>,
70    pr_merge_commit_sha: Option<String>,
71    created_at: DateTime<Utc>,
72}
73
74impl Merge {
75    pub fn merge_commit(&self) -> Option<String> {
76        match self {
77            Merge::Direct(direct) => Some(direct.merge_commit.clone()),
78            Merge::Pr(pr) => pr.pr_info.merge_commit_sha.clone(),
79        }
80    }
81
82    /// Create a direct merge record
83    pub async fn create_direct(
84        pool: &SqlitePool,
85        task_attempt_id: Uuid,
86        target_branch_name: &str,
87        merge_commit: &str,
88    ) -> Result<DirectMerge, sqlx::Error> {
89        let id = Uuid::new_v4();
90        let now = Utc::now();
91
92        sqlx::query_as!(
93            MergeRow,
94            r#"INSERT INTO merges (
95                id, task_attempt_id, merge_type, merge_commit, created_at, target_branch_name
96            ) VALUES ($1, $2, 'direct', $3, $4, $5)
97            RETURNING 
98                id as "id!: Uuid",
99                task_attempt_id as "task_attempt_id!: Uuid",
100                merge_type as "merge_type!: MergeType",
101                merge_commit,
102                pr_number,
103                pr_url,
104                pr_status as "pr_status?: MergeStatus",
105                pr_merged_at as "pr_merged_at?: DateTime<Utc>",
106                pr_merge_commit_sha,
107                created_at as "created_at!: DateTime<Utc>",
108                target_branch_name as "target_branch_name!: String"
109            "#,
110            id,
111            task_attempt_id,
112            merge_commit,
113            now,
114            target_branch_name
115        )
116        .fetch_one(pool)
117        .await
118        .map(Into::into)
119    }
120    /// Create a new PR record (when PR is opened)
121    pub async fn create_pr(
122        pool: &SqlitePool,
123        task_attempt_id: Uuid,
124        target_branch_name: &str,
125        pr_number: i64,
126        pr_url: &str,
127    ) -> Result<PrMerge, sqlx::Error> {
128        let id = Uuid::new_v4();
129        let now = Utc::now();
130
131        sqlx::query_as!(
132            MergeRow,
133            r#"INSERT INTO merges (
134                id, task_attempt_id, merge_type, pr_number, pr_url, pr_status, created_at, target_branch_name
135            ) VALUES ($1, $2, 'pr', $3, $4, 'open', $5, $6)
136            RETURNING 
137                id as "id!: Uuid",
138                task_attempt_id as "task_attempt_id!: Uuid",
139                merge_type as "merge_type!: MergeType",
140                merge_commit,
141                pr_number,
142                pr_url,
143                pr_status as "pr_status?: MergeStatus",
144                pr_merged_at as "pr_merged_at?: DateTime<Utc>",
145                pr_merge_commit_sha,
146                created_at as "created_at!: DateTime<Utc>",
147                target_branch_name as "target_branch_name!: String"
148            "#,
149            id,
150            task_attempt_id,
151            pr_number,
152            pr_url,
153            now,
154            target_branch_name
155        )
156        .fetch_one(pool)
157        .await
158        .map(Into::into)
159    }
160
161    /// Get all open PRs for monitoring
162    pub async fn get_open_prs(pool: &SqlitePool) -> Result<Vec<PrMerge>, sqlx::Error> {
163        let rows = sqlx::query_as!(
164            MergeRow,
165            r#"SELECT 
166                id as "id!: Uuid",
167                task_attempt_id as "task_attempt_id!: Uuid",
168                merge_type as "merge_type!: MergeType",
169                merge_commit,
170                pr_number,
171                pr_url,
172                pr_status as "pr_status?: MergeStatus",
173                pr_merged_at as "pr_merged_at?: DateTime<Utc>",
174                pr_merge_commit_sha,
175                created_at as "created_at!: DateTime<Utc>",
176                target_branch_name as "target_branch_name!: String"
177               FROM merges 
178               WHERE merge_type = 'pr' AND pr_status = 'open'
179               ORDER BY created_at DESC"#,
180        )
181        .fetch_all(pool)
182        .await?;
183
184        Ok(rows.into_iter().map(Into::into).collect())
185    }
186
187    /// Update PR status for a task attempt
188    pub async fn update_status(
189        pool: &SqlitePool,
190        merge_id: Uuid,
191        pr_status: MergeStatus,
192        merge_commit_sha: Option<String>,
193    ) -> Result<(), sqlx::Error> {
194        let merged_at = if matches!(pr_status, MergeStatus::Merged) {
195            Some(Utc::now())
196        } else {
197            None
198        };
199
200        sqlx::query!(
201            r#"UPDATE merges 
202            SET pr_status = $1, 
203                pr_merge_commit_sha = $2,
204                pr_merged_at = $3
205            WHERE id = $4"#,
206            pr_status,
207            merge_commit_sha,
208            merged_at,
209            merge_id
210        )
211        .execute(pool)
212        .await?;
213
214        Ok(())
215    }
216    /// Find all merges for a task attempt (returns both direct and PR merges)
217    pub async fn find_by_task_attempt_id(
218        pool: &SqlitePool,
219        task_attempt_id: Uuid,
220    ) -> Result<Vec<Self>, sqlx::Error> {
221        // Get raw data from database
222        let rows = sqlx::query_as!(
223            MergeRow,
224            r#"SELECT 
225                id as "id!: Uuid",
226                task_attempt_id as "task_attempt_id!: Uuid",
227                merge_type as "merge_type!: MergeType",
228                merge_commit,
229                pr_number,
230                pr_url,
231                pr_status as "pr_status?: MergeStatus",
232                pr_merged_at as "pr_merged_at?: DateTime<Utc>",
233                pr_merge_commit_sha,
234                target_branch_name as "target_branch_name!: String",
235                created_at as "created_at!: DateTime<Utc>"
236            FROM merges 
237            WHERE task_attempt_id = $1
238            ORDER BY created_at DESC"#,
239            task_attempt_id
240        )
241        .fetch_all(pool)
242        .await?;
243
244        // Convert to appropriate types based on merge_type
245        Ok(rows.into_iter().map(Into::into).collect())
246    }
247
248    /// Find the most recent merge for a task attempt
249    pub async fn find_latest_by_task_attempt_id(
250        pool: &SqlitePool,
251        task_attempt_id: Uuid,
252    ) -> Result<Option<Self>, sqlx::Error> {
253        Self::find_by_task_attempt_id(pool, task_attempt_id)
254            .await
255            .map(|mut merges| merges.pop())
256    }
257}
258
259// Conversion implementations
260impl From<MergeRow> for DirectMerge {
261    fn from(row: MergeRow) -> Self {
262        DirectMerge {
263            id: row.id,
264            task_attempt_id: row.task_attempt_id,
265            merge_commit: row
266                .merge_commit
267                .expect("direct merge must have merge_commit"),
268            target_branch_name: row.target_branch_name,
269            created_at: row.created_at,
270        }
271    }
272}
273
274impl From<MergeRow> for PrMerge {
275    fn from(row: MergeRow) -> Self {
276        PrMerge {
277            id: row.id,
278            task_attempt_id: row.task_attempt_id,
279            target_branch_name: row.target_branch_name,
280            pr_info: PullRequestInfo {
281                number: row.pr_number.expect("pr merge must have pr_number"),
282                url: row.pr_url.expect("pr merge must have pr_url"),
283                status: row.pr_status.expect("pr merge must have status"),
284                merged_at: row.pr_merged_at,
285                merge_commit_sha: row.pr_merge_commit_sha,
286            },
287            created_at: row.created_at,
288        }
289    }
290}
291
292impl From<MergeRow> for Merge {
293    fn from(row: MergeRow) -> Self {
294        match row.merge_type {
295            MergeType::Direct => Merge::Direct(DirectMerge::from(row)),
296            MergeType::Pr => Merge::Pr(PrMerge::from(row)),
297        }
298    }
299}