1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use sqlx::{FromRow, SqlitePool, Type};
4use ts_rs_forge::TS;
5use uuid::Uuid;
6
7#[derive(Debug, Clone, Serialize, Deserialize, TS, Type)]
8#[sqlx(type_name = "merge_status", rename_all = "snake_case")]
9#[serde(rename_all = "snake_case")]
10pub enum MergeStatus {
11 Open,
12 Merged,
13 Closed,
14 Unknown,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize, TS)]
18#[serde(tag = "type", rename_all = "snake_case")]
19pub enum Merge {
20 Direct(DirectMerge),
21 Pr(PrMerge),
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize, TS)]
25pub struct DirectMerge {
26 pub id: Uuid,
27 pub task_attempt_id: Uuid,
28 pub merge_commit: String,
29 pub target_branch_name: String,
30 pub created_at: DateTime<Utc>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, TS)]
35pub struct PrMerge {
36 pub id: Uuid,
37 pub task_attempt_id: Uuid,
38 pub created_at: DateTime<Utc>,
39 pub target_branch_name: String,
40 pub pr_info: PullRequestInfo,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize, TS)]
44pub struct PullRequestInfo {
45 pub number: i64,
46 pub url: String,
47 pub status: MergeStatus,
48 pub merged_at: Option<chrono::DateTime<chrono::Utc>>,
49 pub merge_commit_sha: Option<String>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize, Type)]
53#[sqlx(type_name = "TEXT", rename_all = "snake_case")]
54pub enum MergeType {
55 Direct,
56 Pr,
57}
58
59#[derive(FromRow)]
60struct MergeRow {
61 id: Uuid,
62 task_attempt_id: Uuid,
63 merge_type: MergeType,
64 merge_commit: Option<String>,
65 target_branch_name: String,
66 pr_number: Option<i64>,
67 pr_url: Option<String>,
68 pr_status: Option<MergeStatus>,
69 pr_merged_at: Option<DateTime<Utc>>,
70 pr_merge_commit_sha: Option<String>,
71 created_at: DateTime<Utc>,
72}
73
74impl Merge {
75 pub fn merge_commit(&self) -> Option<String> {
76 match self {
77 Merge::Direct(direct) => Some(direct.merge_commit.clone()),
78 Merge::Pr(pr) => pr.pr_info.merge_commit_sha.clone(),
79 }
80 }
81
82 pub async fn create_direct(
84 pool: &SqlitePool,
85 task_attempt_id: Uuid,
86 target_branch_name: &str,
87 merge_commit: &str,
88 ) -> Result<DirectMerge, sqlx::Error> {
89 let id = Uuid::new_v4();
90 let now = Utc::now();
91
92 sqlx::query_as!(
93 MergeRow,
94 r#"INSERT INTO merges (
95 id, task_attempt_id, merge_type, merge_commit, created_at, target_branch_name
96 ) VALUES ($1, $2, 'direct', $3, $4, $5)
97 RETURNING
98 id as "id!: Uuid",
99 task_attempt_id as "task_attempt_id!: Uuid",
100 merge_type as "merge_type!: MergeType",
101 merge_commit,
102 pr_number,
103 pr_url,
104 pr_status as "pr_status?: MergeStatus",
105 pr_merged_at as "pr_merged_at?: DateTime<Utc>",
106 pr_merge_commit_sha,
107 created_at as "created_at!: DateTime<Utc>",
108 target_branch_name as "target_branch_name!: String"
109 "#,
110 id,
111 task_attempt_id,
112 merge_commit,
113 now,
114 target_branch_name
115 )
116 .fetch_one(pool)
117 .await
118 .map(Into::into)
119 }
120 pub async fn create_pr(
122 pool: &SqlitePool,
123 task_attempt_id: Uuid,
124 target_branch_name: &str,
125 pr_number: i64,
126 pr_url: &str,
127 ) -> Result<PrMerge, sqlx::Error> {
128 let id = Uuid::new_v4();
129 let now = Utc::now();
130
131 sqlx::query_as!(
132 MergeRow,
133 r#"INSERT INTO merges (
134 id, task_attempt_id, merge_type, pr_number, pr_url, pr_status, created_at, target_branch_name
135 ) VALUES ($1, $2, 'pr', $3, $4, 'open', $5, $6)
136 RETURNING
137 id as "id!: Uuid",
138 task_attempt_id as "task_attempt_id!: Uuid",
139 merge_type as "merge_type!: MergeType",
140 merge_commit,
141 pr_number,
142 pr_url,
143 pr_status as "pr_status?: MergeStatus",
144 pr_merged_at as "pr_merged_at?: DateTime<Utc>",
145 pr_merge_commit_sha,
146 created_at as "created_at!: DateTime<Utc>",
147 target_branch_name as "target_branch_name!: String"
148 "#,
149 id,
150 task_attempt_id,
151 pr_number,
152 pr_url,
153 now,
154 target_branch_name
155 )
156 .fetch_one(pool)
157 .await
158 .map(Into::into)
159 }
160
161 pub async fn get_open_prs(pool: &SqlitePool) -> Result<Vec<PrMerge>, sqlx::Error> {
163 let rows = sqlx::query_as!(
164 MergeRow,
165 r#"SELECT
166 id as "id!: Uuid",
167 task_attempt_id as "task_attempt_id!: Uuid",
168 merge_type as "merge_type!: MergeType",
169 merge_commit,
170 pr_number,
171 pr_url,
172 pr_status as "pr_status?: MergeStatus",
173 pr_merged_at as "pr_merged_at?: DateTime<Utc>",
174 pr_merge_commit_sha,
175 created_at as "created_at!: DateTime<Utc>",
176 target_branch_name as "target_branch_name!: String"
177 FROM merges
178 WHERE merge_type = 'pr' AND pr_status = 'open'
179 ORDER BY created_at DESC"#,
180 )
181 .fetch_all(pool)
182 .await?;
183
184 Ok(rows.into_iter().map(Into::into).collect())
185 }
186
187 pub async fn update_status(
189 pool: &SqlitePool,
190 merge_id: Uuid,
191 pr_status: MergeStatus,
192 merge_commit_sha: Option<String>,
193 ) -> Result<(), sqlx::Error> {
194 let merged_at = if matches!(pr_status, MergeStatus::Merged) {
195 Some(Utc::now())
196 } else {
197 None
198 };
199
200 sqlx::query!(
201 r#"UPDATE merges
202 SET pr_status = $1,
203 pr_merge_commit_sha = $2,
204 pr_merged_at = $3
205 WHERE id = $4"#,
206 pr_status,
207 merge_commit_sha,
208 merged_at,
209 merge_id
210 )
211 .execute(pool)
212 .await?;
213
214 Ok(())
215 }
216 pub async fn find_by_task_attempt_id(
218 pool: &SqlitePool,
219 task_attempt_id: Uuid,
220 ) -> Result<Vec<Self>, sqlx::Error> {
221 let rows = sqlx::query_as!(
223 MergeRow,
224 r#"SELECT
225 id as "id!: Uuid",
226 task_attempt_id as "task_attempt_id!: Uuid",
227 merge_type as "merge_type!: MergeType",
228 merge_commit,
229 pr_number,
230 pr_url,
231 pr_status as "pr_status?: MergeStatus",
232 pr_merged_at as "pr_merged_at?: DateTime<Utc>",
233 pr_merge_commit_sha,
234 target_branch_name as "target_branch_name!: String",
235 created_at as "created_at!: DateTime<Utc>"
236 FROM merges
237 WHERE task_attempt_id = $1
238 ORDER BY created_at DESC"#,
239 task_attempt_id
240 )
241 .fetch_all(pool)
242 .await?;
243
244 Ok(rows.into_iter().map(Into::into).collect())
246 }
247
248 pub async fn find_latest_by_task_attempt_id(
250 pool: &SqlitePool,
251 task_attempt_id: Uuid,
252 ) -> Result<Option<Self>, sqlx::Error> {
253 Self::find_by_task_attempt_id(pool, task_attempt_id)
254 .await
255 .map(|mut merges| merges.pop())
256 }
257}
258
259impl From<MergeRow> for DirectMerge {
261 fn from(row: MergeRow) -> Self {
262 DirectMerge {
263 id: row.id,
264 task_attempt_id: row.task_attempt_id,
265 merge_commit: row
266 .merge_commit
267 .expect("direct merge must have merge_commit"),
268 target_branch_name: row.target_branch_name,
269 created_at: row.created_at,
270 }
271 }
272}
273
274impl From<MergeRow> for PrMerge {
275 fn from(row: MergeRow) -> Self {
276 PrMerge {
277 id: row.id,
278 task_attempt_id: row.task_attempt_id,
279 target_branch_name: row.target_branch_name,
280 pr_info: PullRequestInfo {
281 number: row.pr_number.expect("pr merge must have pr_number"),
282 url: row.pr_url.expect("pr merge must have pr_url"),
283 status: row.pr_status.expect("pr merge must have status"),
284 merged_at: row.pr_merged_at,
285 merge_commit_sha: row.pr_merge_commit_sha,
286 },
287 created_at: row.created_at,
288 }
289 }
290}
291
292impl From<MergeRow> for Merge {
293 fn from(row: MergeRow) -> Self {
294 match row.merge_type {
295 MergeType::Direct => Merge::Direct(DirectMerge::from(row)),
296 MergeType::Pr => Merge::Pr(PrMerge::from(row)),
297 }
298 }
299}