1use anyhow::Context;
2use entertainarr_domain::task::entity::{
3 TargetKind, Task, TaskMethod, TaskParams, TaskPayload, TaskStatus,
4};
5use entertainarr_domain::task::prelude::TaskRepository;
6use sqlx::types::{Json, chrono};
7
8use crate::Wrapper;
9
10impl crate::Pool {
11 #[tracing::instrument(
12 skip(self, tasks),
13 fields(
14 otel.kind = "client",
15 db.system = "sqlite",
16 db.name = "tasks",
17 db.operation = "SELECT",
18 db.sql.table = "tasks",
19 db.query.text = tracing::field::Empty,
20 error.type = tracing::field::Empty,
21 error.message = tracing::field::Empty,
22 error.stacktrace = tracing::field::Empty,
23 ),
24 err(Debug),
25 )]
26 pub(crate) async fn insert_tasks<I>(&self, tasks: I, params: &TaskParams) -> anyhow::Result<()>
27 where
28 I: Iterator<Item = TaskPayload>,
29 {
30 let mut qb = sqlx::QueryBuilder::new(
31 "insert into tasks (user_id, after, target_kind, target_id, method, parameters, status, retries, max_retries) ",
32 );
33 qb.push_values(tasks, |mut q, payload| {
34 q.push_bind(params.user_id.map(|v| v as i64))
35 .push_bind(params.after)
36 .push_bind(Wrapper(payload.target_kind()))
37 .push_bind(payload.target_id().map(|value| value as i64))
38 .push_bind(payload.method().method())
39 .push_bind(Json(payload))
40 .push_bind(TaskStatus::Pending.as_str())
41 .push_bind(params.retry)
42 .push_bind(5);
43 });
44 qb.push(" on conflict (target_kind, target_id, method, parameters) where status = 'pending' do update set updated_at = CURRENT_TIMESTAMP");
45 tracing::Span::current().record("qb.query.text", qb.sql());
46 qb.build()
47 .execute(self.as_ref())
48 .await
49 .inspect_err(super::record_error)
50 .map(|_| ())
51 .context("unable to insert task")
52 }
53}
54
// Atomically claims the oldest runnable pending task: the CTE picks the single
// oldest `pending` task whose `after` timestamp (if any) has elapsed, the UPDATE
// flips it to `running`, and the full row is returned. Doing select-and-claim in
// one statement keeps the claim atomic under SQLite's single-writer model.
// The `returning` column order must match the `FromRow` impl for `Wrapper<Task>`.
const START_TASK_QUERY: &str = r#"with next_task as (
    select id from tasks
    where status = 'pending'
    and (after is null or after <= current_timestamp)
    order by created_at
    limit 1
)
update tasks
set status = 'running', updated_at = current_timestamp
where id in (select id from next_task)
returning id, user_id, after, target_kind, target_id, method, parameters, status, retries, max_retries, error_message, created_at, updated_at"#;
66
// Marks a task as completed. Also bumps `updated_at` so the row records when the
// task finished, consistently with the other status-transition queries in this
// module (start/fail both touch `updated_at`; this one previously did not).
const COMPLETE_TASK_QUERY: &str =
    "update tasks set status = 'completed', updated_at = current_timestamp where id = ?";
// Records a task failure: increments the retry counter, stores the error message,
// and either re-queues the task (`pending`) or gives up (`failed`) once the retry
// budget is exhausted.
//
// `fail_task` binds two values in order — the task id first, then the error
// message — but the previous query exposed only a single `?` placeholder, so the
// second bind had nowhere to go and `error_message` was never persisted. Numbered
// placeholders (`?1` = id, `?2` = error message) keep the existing bind order
// working even though the error message appears before the id in the query text.
const FAIL_TASK_QUERY: &str = "update tasks
set retries = retries + 1,
    status = case when retries + 1 > max_retries then 'failed' else 'pending' end,
    error_message = ?2,
    updated_at = current_timestamp
where id = ?1 and status = 'running'";
73
74impl crate::Pool {
75 #[tracing::instrument(
76 skip(self),
77 fields(
78 otel.kind = "client",
79 db.system = "sqlite",
80 db.name = "tasks",
81 db.operation = "SELECT",
82 db.sql.table = "tasks",
83 db.query.text = tracing::field::Empty,
84 db.response.returned_rows = tracing::field::Empty,
85 error.type = tracing::field::Empty,
86 error.message = tracing::field::Empty,
87 error.stacktrace = tracing::field::Empty,
88 ),
89 err(Debug),
90 )]
91 pub(crate) async fn last_tasks_by_target_ids(
92 &self,
93 task_method: TaskMethod,
94 target_ids: &[u64],
95 ) -> anyhow::Result<Vec<entertainarr_domain::task::entity::Task>> {
96 if target_ids.is_empty() {
97 return Ok(Vec::new());
98 }
99
100 let (target_kind, method) = Wrapper(task_method).tuple();
101
102 let mut qb = sqlx::QueryBuilder::new(
103 "with subset_tasks as (select *, row_number() over (partition by target_id order by updated_at desc) task_index",
104 );
105 qb.push(" from tasks where target_kind = ")
106 .push_bind(Wrapper(target_kind));
107 qb.push(" and method = ").push_bind(method);
108 qb.push(" and (");
109 for (index, task_id) in target_ids.iter().enumerate() {
110 if index > 0 {
111 qb.push(" or");
112 }
113 qb.push("target_id = ").push_bind(*task_id as i64);
114 }
115 qb.push("))");
116
117 qb.push(" select id, user_id, after, target_kind, target_id, method, parameters, status, retries, max_retries, error_message, created_at, updated_at");
118 qb.push(" from subset_tasks");
119 qb.push(" where task_index = 1");
120
121 let span = tracing::Span::current();
122 span.record("db.query.text", qb.sql());
123
124 qb.build_query_as()
125 .fetch_all(self.as_ref())
126 .await
127 .inspect(super::record_all)
128 .inspect_err(super::record_error)
129 .map(super::Wrapper::list)
130 .context("unable to fetch podcast tasks")
131 }
132}
133
134impl TaskRepository for crate::Pool {
135 #[tracing::instrument(
136 skip(self),
137 fields(
138 otel.kind = "client",
139 db.system = "sqlite",
140 db.name = "tasks",
141 db.operation = "UPDATE",
142 db.sql.table = "tasks",
143 db.query.text = START_TASK_QUERY,
144 db.response.returned_rows = tracing::field::Empty,
145 error.type = tracing::field::Empty,
146 error.message = tracing::field::Empty,
147 error.stacktrace = tracing::field::Empty,
148 ),
149 err(Debug),
150 )]
151 async fn start_task(&self) -> anyhow::Result<Option<Task>> {
152 sqlx::query_as(START_TASK_QUERY)
153 .fetch_optional(self.as_ref())
154 .await
155 .inspect(crate::record_optional)
156 .inspect_err(crate::record_error)
157 .map(crate::Wrapper::maybe_inner)
158 .context("unable to start task")
159 }
160
161 #[tracing::instrument(
162 skip(self),
163 fields(
164 otel.kind = "client",
165 db.system = "sqlite",
166 db.name = "tasks",
167 db.operation = "UPDATE",
168 db.sql.table = "tasks",
169 db.query.text = COMPLETE_TASK_QUERY,
170 db.response.returned_rows = 0,
171 error.type = tracing::field::Empty,
172 error.message = tracing::field::Empty,
173 error.stacktrace = tracing::field::Empty,
174 ),
175 err(Debug),
176 )]
177 async fn complete_task(&self, task_id: u64) -> anyhow::Result<()> {
178 sqlx::query(COMPLETE_TASK_QUERY)
179 .bind(task_id as i64)
180 .execute(self.as_ref())
181 .await
182 .inspect_err(crate::record_error)
183 .map(|_| ())
184 .context("unable to complete task")
185 }
186
187 #[tracing::instrument(
188 skip(self),
189 fields(
190 otel.kind = "client",
191 db.system = "sqlite",
192 db.name = "tasks",
193 db.operation = "UPDATE",
194 db.sql.table = "tasks",
195 db.query.text = FAIL_TASK_QUERY,
196 db.response.returned_rows = 0,
197 error.type = tracing::field::Empty,
198 error.message = tracing::field::Empty,
199 error.stacktrace = tracing::field::Empty,
200 ),
201 err(Debug),
202 )]
203 async fn fail_task(&self, task_id: u64, error: String) -> anyhow::Result<()> {
204 sqlx::query(FAIL_TASK_QUERY)
205 .bind(task_id as i64)
206 .bind(error.as_str())
207 .execute(self.as_ref())
208 .await
209 .inspect_err(crate::record_error)
210 .map(|_| ())
211 .context("unable to complete task")
212 }
213
214 async fn find_by_target_ids(
215 &self,
216 task_method: TaskMethod,
217 target_ids: &[u64],
218 ) -> anyhow::Result<Vec<entertainarr_domain::task::entity::Task>> {
219 self.last_tasks_by_target_ids(task_method, target_ids).await
220 }
221}
222
impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<Task> {
    /// Decodes a full `tasks` row into a domain [`Task`].
    ///
    /// Column order must match the `returning`/`select` lists used by the queries
    /// in this module: id(0), user_id(1), after(2), target_kind(3), target_id(4),
    /// method(5), parameters(6), status(7), retries(8), max_retries(9),
    /// error_message(10), created_at(11), updated_at(12).
    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
        use sqlx::Row;

        // The payload is stored as JSON in `parameters`; columns 3-5
        // (target_kind, target_id, method) are derived from it on insert and are
        // intentionally not re-read here.
        let payload: TaskPayload = row.try_get(6).map(|Json(inner)| inner)?;

        let status: String = row.try_get(7)?;
        let retries: u8 = row.try_get(8)?;
        let max_retries: u8 = row.try_get(9)?;
        let error_message: Option<String> = row.try_get(10)?;
        let created_at: chrono::DateTime<chrono::Utc> = row.try_get(11)?;
        let updated_at: chrono::DateTime<chrono::Utc> = row.try_get(12)?;

        Ok(Self(Task {
            id: row.try_get(0)?,
            user_id: row.try_get(1)?,
            after: row.try_get(2)?,
            payload,
            // Unknown status strings fall back to `Pending` (see build_task_status).
            status: build_task_status(status, error_message),
            retries,
            max_retries,
            created_at,
            updated_at,
        }))
    }
}
249
250fn build_task_status(status: String, message: Option<String>) -> TaskStatus {
251 match status.as_str() {
252 "pending" => TaskStatus::Pending,
253 "running" => TaskStatus::Running,
254 "completed" => TaskStatus::Completed,
255 "failed" => TaskStatus::Failed { message },
256 _ => {
257 tracing::error!(%status, "invalid status");
258 TaskStatus::Pending
259 }
260 }
261}
262
/// Error returned when an unknown `target_kind` TEXT value is read back from the
/// database; carries the offending raw string.
#[derive(Debug, thiserror::Error)]
#[error("invalid target kind {_0:?}")]
pub struct TargetKindParseError(String);
266
267impl std::str::FromStr for Wrapper<TargetKind> {
268 type Err = TargetKindParseError;
269
270 fn from_str(input: &str) -> Result<Self, Self::Err> {
271 match input {
272 "noop" => Ok(Wrapper(TargetKind::Noop)),
273 "podcast" => Ok(Wrapper(TargetKind::Podcast)),
274 "tvshow" => Ok(Wrapper(TargetKind::TvShow)),
275 other => Err(TargetKindParseError(other.to_string())),
276 }
277 }
278}
279
280impl std::fmt::Display for Wrapper<TargetKind> {
281 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282 f.write_str(self.as_str())
283 }
284}
285
impl Wrapper<TargetKind> {
    /// Canonical database TEXT representation of the wrapped [`TargetKind`].
    ///
    /// Must stay in sync with the `FromStr` implementation, which parses these
    /// exact strings back when decoding rows.
    const fn as_str(&self) -> &'static str {
        match self.0 {
            TargetKind::Media => "media",
            TargetKind::Noop => "noop",
            TargetKind::Podcast => "podcast",
            TargetKind::TvShow => "tvshow",
        }
    }
}
296
// Target kinds are stored as TEXT; delegate type info/compatibility to String.
impl sqlx::Type<sqlx::Sqlite> for Wrapper<TargetKind> {
    fn type_info() -> <sqlx::Sqlite as sqlx::Database>::TypeInfo {
        <String as sqlx::Type<sqlx::Sqlite>>::type_info()
    }

    fn compatible(ty: &<sqlx::Sqlite as sqlx::Database>::TypeInfo) -> bool {
        <String as sqlx::Type<sqlx::Sqlite>>::compatible(ty)
    }
}
306
307impl<'q> sqlx::Encode<'q, sqlx::Sqlite> for Wrapper<TargetKind> {
308 fn encode_by_ref(
309 &self,
310 buf: &mut <sqlx::Sqlite as sqlx::Database>::ArgumentBuffer<'q>,
311 ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
312 <String as sqlx::Encode<'q, sqlx::Sqlite>>::encode(self.to_string(), buf)
313 }
314}
315
316impl<'r> sqlx::Decode<'r, sqlx::Sqlite> for Wrapper<TargetKind> {
317 fn decode(
318 value: <sqlx::Sqlite as sqlx::Database>::ValueRef<'r>,
319 ) -> Result<Self, sqlx::error::BoxDynError> {
320 use std::str::FromStr;
321
322 <String as sqlx::Decode<'r, sqlx::Sqlite>>::decode(value).and_then(|value| {
323 Wrapper::<TargetKind>::from_str(value.as_str())
324 .map_err(|err| Box::new(err) as Box<dyn std::error::Error + Send + Sync>)
325 })
326 }
327}
328
// Task statuses are stored as TEXT; delegate type info/compatibility to String.
impl sqlx::Type<sqlx::Sqlite> for Wrapper<TaskStatus> {
    fn type_info() -> <sqlx::Sqlite as sqlx::Database>::TypeInfo {
        <String as sqlx::Type<sqlx::Sqlite>>::type_info()
    }

    fn compatible(ty: &<sqlx::Sqlite as sqlx::Database>::TypeInfo) -> bool {
        <String as sqlx::Type<sqlx::Sqlite>>::compatible(ty)
    }
}
338
impl Wrapper<TaskMethod> {
    /// Splits a [`TaskMethod`] into its database representation: the kind stored
    /// in the `target_kind` column and the method name stored in `method`.
    ///
    /// The returned strings must match the `method` values written by
    /// `insert_tasks` (via `payload.method().method()`) so lookups by method hit.
    fn tuple(&self) -> (TargetKind, &'static str) {
        match self.0 {
            TaskMethod::MediaSynchronize => (TargetKind::Media, "synchronize"),
            TaskMethod::MediaSynchronizeAll => (TargetKind::Media, "synchronize-all"),
            TaskMethod::Noop => (TargetKind::Noop, "noop"),
            TaskMethod::PodcastSynchronize => (TargetKind::Podcast, "synchronize"),
            TaskMethod::PodcastSynchronizeAll => (TargetKind::Podcast, "synchronize-all"),
            TaskMethod::TvShowSynchronize => (TargetKind::TvShow, "synchronize"),
            TaskMethod::TvShowSynchronizeAll => (TargetKind::TvShow, "synchronize-all"),
            TaskMethod::TvShowSynchronizeLocated => (TargetKind::TvShow, "synchronize-located"),
            TaskMethod::TvShowSynchronizeSeasonLocated => {
                (TargetKind::TvShow, "synchronize-season-located")
            }
        }
    }
}
356
#[cfg(test)]
mod tests {
    use entertainarr_domain::podcast::entity::PodcastTask;
    use entertainarr_domain::task::entity::TaskPayload;
    use entertainarr_domain::task::prelude::TaskRepository;
    use sqlx::types::Json;

    /// Inserts the users and podcasts the task fixtures below reference.
    async fn seed(pool: &crate::Pool) {
        let _: Vec<u64> = sqlx::query_scalar("insert into users (id, email, password) values (1, 'user1@example.com', 'password'), (2, 'user2@example.com', 'password') returning id").fetch_all(pool.as_ref()).await.unwrap();
        let _: Vec<u64> = sqlx::query_scalar("insert into podcasts (id, feed_url, title) values (1, 'first', 'first'), (2, 'second', 'second'), (3, 'third', 'third') returning id").fetch_all(pool.as_ref()).await.unwrap();
    }

    #[tokio::test]
    async fn should_find_by_target_ids() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        seed(&pool).await;

        // Fixtures previously used status 'complete', which is not a valid status
        // value ('completed' is the canonical one, see build_task_status); decoded
        // rows silently fell back to Pending with an error log. Use 'completed'.
        // Target 1: an old completed task and a newer pending one — the pending
        // one (greater updated_at) must be the row returned.
        let _first: u64 = sqlx::query_scalar("insert into tasks (user_id, target_kind, target_id, method, parameters, status, updated_at) values (1, 'podcast', 1, 'synchronize', ?, 'completed', 0) returning id").bind(Json(TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 1 }))).fetch_one(pool.as_ref()).await.unwrap();
        let second: u64 = sqlx::query_scalar("insert into tasks (user_id, target_kind, target_id, method, parameters, status, updated_at) values (1, 'podcast', 1, 'synchronize', ?, 'pending', 1) returning id").bind(Json(TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 1 }))).fetch_one(pool.as_ref()).await.unwrap();

        // Target 2: only completed tasks; must not show up when querying target 1.
        let _third: u64 = sqlx::query_scalar("insert into tasks (user_id, target_kind, target_id, method, parameters, status, updated_at) values (1, 'podcast', 2, 'synchronize', ?, 'completed', 2) returning id").bind(Json(TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 2 }))).fetch_one(pool.as_ref()).await.unwrap();
        let _fourth: u64 = sqlx::query_scalar("insert into tasks (user_id, target_kind, target_id, method, parameters, status, updated_at) values (1, 'podcast', 2, 'synchronize', ?, 'completed', 3) returning id").bind(Json(TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 2 }))).fetch_one(pool.as_ref()).await.unwrap();

        let mut list = pool
            .find_by_target_ids(
                entertainarr_domain::task::entity::TaskMethod::PodcastSynchronize,
                &[1],
            )
            .await
            .unwrap();
        assert_eq!(list.len(), 1);
        let value = list.pop().unwrap();
        assert_eq!(value.id, second);
    }
}