1use anyhow::Context;
2use entertainarr_domain::task::entity::{
3 TargetKind, Task, TaskMethod, TaskParams, TaskPayload, TaskStatus,
4};
5use entertainarr_domain::task::prelude::TaskRepository;
6use sqlx::types::{Json, chrono};
7
8use crate::Wrapper;
9
10impl crate::Pool {
11 #[tracing::instrument(
12 skip(self, tasks),
13 fields(
14 otel.kind = "client",
15 db.system = "sqlite",
16 db.name = "tasks",
17 db.operation = "SELECT",
18 db.sql.table = "tasks",
19 db.query.text = tracing::field::Empty,
20 error.type = tracing::field::Empty,
21 error.message = tracing::field::Empty,
22 error.stacktrace = tracing::field::Empty,
23 ),
24 err(Debug),
25 )]
26 pub async fn insert_tasks<I>(&self, tasks: I, params: &TaskParams) -> anyhow::Result<u64>
27 where
28 I: Iterator<Item = TaskPayload>,
29 {
30 let mut qb = sqlx::QueryBuilder::new(
31 "insert into tasks (user_id, after, target_kind, target_id, method, parameters, status, retries, max_retries) ",
32 );
33 let mut count = 0;
34 qb.push_values(tasks, |mut q, payload| {
35 count += 1;
36 q.push_bind(params.user_id.map(|v| v as i64))
37 .push_bind(params.after)
38 .push_bind(Wrapper(payload.target_kind()))
39 .push_bind(payload.target_id().map(|value| value as i64))
40 .push_bind(payload.method().method())
41 .push_bind(Json(payload))
42 .push_bind(TaskStatus::Pending.as_str())
43 .push_bind(params.retry)
44 .push_bind(5);
45 });
46 if count == 0 {
47 return Ok(0);
48 }
49
50 qb.push(" on conflict (target_kind, target_id, method, parameters) where status = 'pending' do update set updated_at = CURRENT_TIMESTAMP");
51 tracing::Span::current().record("qb.query.text", qb.sql());
52 qb.build()
53 .execute(self.as_ref())
54 .await
55 .inspect_err(super::record_error)
56 .map(|res| res.rows_affected())
57 .context("unable to insert task")
58 }
59}
60
/// Picks the oldest pending task whose `after` is unset or already reached,
/// flips it to `running` and returns the full row.
const START_TASK_QUERY: &str = r#"with next_task as (
    select id from tasks
    where status = 'pending'
    and (after is null or after <= current_timestamp)
    order by created_at
    limit 1
)
update tasks
set status = 'running', updated_at = current_timestamp
where id in (select id from next_task)
returning id, user_id, after, target_kind, target_id, method, parameters, status, retries, max_retries, error_message, created_at, updated_at"#;

/// Marks the task with the bound id as `completed`.
///
/// `updated_at` is refreshed like in the other state transitions so that
/// ordering tasks by `updated_at` reflects the completion time.
const COMPLETE_TASK_QUERY: &str =
    "update tasks set status = 'completed', updated_at = current_timestamp where id = ?";

/// Records a failed attempt for a running task: bumps `retries`, stores the
/// error message and either re-queues the task (`pending`) or gives up
/// (`failed`) once the incremented retry count exceeds `max_retries`.
///
/// Parameters are numbered so they match the bind order used by the caller:
/// `?1` is the task id, `?2` is the error message.
const FAIL_TASK_QUERY: &str = "update tasks
set retries = retries + 1,
    status = case when retries + 1 > max_retries then 'failed' else 'pending' end,
    error_message = ?2,
    updated_at = current_timestamp
where id = ?1 and status = 'running'";
79
80impl crate::Pool {
81 #[tracing::instrument(
82 skip(self),
83 fields(
84 otel.kind = "client",
85 db.system = "sqlite",
86 db.name = "tasks",
87 db.operation = "SELECT",
88 db.sql.table = "tasks",
89 db.query.text = tracing::field::Empty,
90 db.response.returned_rows = tracing::field::Empty,
91 error.type = tracing::field::Empty,
92 error.message = tracing::field::Empty,
93 error.stacktrace = tracing::field::Empty,
94 ),
95 err(Debug),
96 )]
97 pub(crate) async fn last_tasks_by_target_ids(
98 &self,
99 task_method: TaskMethod,
100 target_ids: &[u64],
101 ) -> anyhow::Result<Vec<entertainarr_domain::task::entity::Task>> {
102 if target_ids.is_empty() {
103 return Ok(Vec::new());
104 }
105
106 let (target_kind, method) = Wrapper(task_method).tuple();
107
108 let mut qb = sqlx::QueryBuilder::new(
109 "with subset_tasks as (select *, row_number() over (partition by target_id order by updated_at desc) task_index",
110 );
111 qb.push(" from tasks where target_kind = ")
112 .push_bind(Wrapper(target_kind));
113 qb.push(" and method = ").push_bind(method);
114 qb.push(" and (");
115 for (index, task_id) in target_ids.iter().enumerate() {
116 if index > 0 {
117 qb.push(" or");
118 }
119 qb.push("target_id = ").push_bind(*task_id as i64);
120 }
121 qb.push("))");
122
123 qb.push(" select id, user_id, after, target_kind, target_id, method, parameters, status, retries, max_retries, error_message, created_at, updated_at");
124 qb.push(" from subset_tasks");
125 qb.push(" where task_index = 1");
126
127 let span = tracing::Span::current();
128 span.record("db.query.text", qb.sql());
129
130 qb.build_query_as()
131 .fetch_all(self.as_ref())
132 .await
133 .inspect(super::record_all)
134 .inspect_err(super::record_error)
135 .map(super::Wrapper::list)
136 .context("unable to fetch podcast tasks")
137 }
138}
139
140impl TaskRepository for crate::Pool {
141 #[tracing::instrument(
142 skip(self),
143 fields(
144 otel.kind = "client",
145 db.system = "sqlite",
146 db.name = "tasks",
147 db.operation = "UPDATE",
148 db.sql.table = "tasks",
149 db.query.text = START_TASK_QUERY,
150 db.response.returned_rows = tracing::field::Empty,
151 error.type = tracing::field::Empty,
152 error.message = tracing::field::Empty,
153 error.stacktrace = tracing::field::Empty,
154 ),
155 err(Debug),
156 )]
157 async fn start_task(&self) -> anyhow::Result<Option<Task>> {
158 sqlx::query_as(START_TASK_QUERY)
159 .fetch_optional(self.as_ref())
160 .await
161 .inspect(crate::record_optional)
162 .inspect_err(crate::record_error)
163 .map(crate::Wrapper::maybe_inner)
164 .context("unable to start task")
165 }
166
167 #[tracing::instrument(
168 skip(self),
169 fields(
170 otel.kind = "client",
171 db.system = "sqlite",
172 db.name = "tasks",
173 db.operation = "UPDATE",
174 db.sql.table = "tasks",
175 db.query.text = COMPLETE_TASK_QUERY,
176 db.response.returned_rows = 0,
177 error.type = tracing::field::Empty,
178 error.message = tracing::field::Empty,
179 error.stacktrace = tracing::field::Empty,
180 ),
181 err(Debug),
182 )]
183 async fn complete_task(&self, task_id: u64) -> anyhow::Result<()> {
184 sqlx::query(COMPLETE_TASK_QUERY)
185 .bind(task_id as i64)
186 .execute(self.as_ref())
187 .await
188 .inspect_err(crate::record_error)
189 .map(|_| ())
190 .context("unable to complete task")
191 }
192
193 #[tracing::instrument(
194 skip(self),
195 fields(
196 otel.kind = "client",
197 db.system = "sqlite",
198 db.name = "tasks",
199 db.operation = "UPDATE",
200 db.sql.table = "tasks",
201 db.query.text = FAIL_TASK_QUERY,
202 db.response.returned_rows = 0,
203 error.type = tracing::field::Empty,
204 error.message = tracing::field::Empty,
205 error.stacktrace = tracing::field::Empty,
206 ),
207 err(Debug),
208 )]
209 async fn fail_task(&self, task_id: u64, error: String) -> anyhow::Result<()> {
210 sqlx::query(FAIL_TASK_QUERY)
211 .bind(task_id as i64)
212 .bind(error.as_str())
213 .execute(self.as_ref())
214 .await
215 .inspect_err(crate::record_error)
216 .map(|_| ())
217 .context("unable to complete task")
218 }
219
220 async fn find_by_target_ids(
221 &self,
222 task_method: TaskMethod,
223 target_ids: &[u64],
224 ) -> anyhow::Result<Vec<entertainarr_domain::task::entity::Task>> {
225 self.last_tasks_by_target_ids(task_method, target_ids).await
226 }
227}
228
229impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<Task> {
230 fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
231 use sqlx::Row;
232
233 let payload: TaskPayload = row.try_get(6).map(|Json(inner)| inner)?;
234
235 let status: String = row.try_get(7)?;
236 let retries: u8 = row.try_get(8)?;
237 let max_retries: u8 = row.try_get(9)?;
238 let error_message: Option<String> = row.try_get(10)?;
239 let created_at: chrono::DateTime<chrono::Utc> = row.try_get(11)?;
240 let updated_at: chrono::DateTime<chrono::Utc> = row.try_get(12)?;
241
242 Ok(Self(Task {
243 id: row.try_get(0)?,
244 user_id: row.try_get(1)?,
245 after: row.try_get(2)?,
246 payload,
247 status: build_task_status(status, error_message),
248 retries,
249 max_retries,
250 created_at,
251 updated_at,
252 }))
253 }
254}
255
256fn build_task_status(status: String, message: Option<String>) -> TaskStatus {
257 match status.as_str() {
258 "pending" => TaskStatus::Pending,
259 "running" => TaskStatus::Running,
260 "completed" => TaskStatus::Completed,
261 "failed" => TaskStatus::Failed { message },
262 _ => {
263 tracing::error!(%status, "invalid status");
264 TaskStatus::Pending
265 }
266 }
267}
268
269#[derive(Debug, thiserror::Error)]
270#[error("invalid target kind {_0:?}")]
271pub struct TargetKindParseError(String);
272
273impl std::str::FromStr for Wrapper<TargetKind> {
274 type Err = TargetKindParseError;
275
276 fn from_str(input: &str) -> Result<Self, Self::Err> {
277 match input {
278 "noop" => Ok(Wrapper(TargetKind::Noop)),
279 "podcast" => Ok(Wrapper(TargetKind::Podcast)),
280 "tvshow" => Ok(Wrapper(TargetKind::TvShow)),
281 other => Err(TargetKindParseError(other.to_string())),
282 }
283 }
284}
285
286impl std::fmt::Display for Wrapper<TargetKind> {
287 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 f.write_str(self.as_str())
289 }
290}
291
292impl Wrapper<TargetKind> {
293 const fn as_str(&self) -> &'static str {
294 match self.0 {
295 TargetKind::Media => "media",
296 TargetKind::Noop => "noop",
297 TargetKind::Podcast => "podcast",
298 TargetKind::TvShow => "tvshow",
299 }
300 }
301}
302
303impl sqlx::Type<sqlx::Sqlite> for Wrapper<TargetKind> {
304 fn type_info() -> <sqlx::Sqlite as sqlx::Database>::TypeInfo {
305 <String as sqlx::Type<sqlx::Sqlite>>::type_info()
306 }
307
308 fn compatible(ty: &<sqlx::Sqlite as sqlx::Database>::TypeInfo) -> bool {
309 <String as sqlx::Type<sqlx::Sqlite>>::compatible(ty)
310 }
311}
312
313impl<'q> sqlx::Encode<'q, sqlx::Sqlite> for Wrapper<TargetKind> {
314 fn encode_by_ref(
315 &self,
316 buf: &mut <sqlx::Sqlite as sqlx::Database>::ArgumentBuffer<'q>,
317 ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
318 <String as sqlx::Encode<'q, sqlx::Sqlite>>::encode(self.to_string(), buf)
319 }
320}
321
322impl<'r> sqlx::Decode<'r, sqlx::Sqlite> for Wrapper<TargetKind> {
323 fn decode(
324 value: <sqlx::Sqlite as sqlx::Database>::ValueRef<'r>,
325 ) -> Result<Self, sqlx::error::BoxDynError> {
326 use std::str::FromStr;
327
328 <String as sqlx::Decode<'r, sqlx::Sqlite>>::decode(value).and_then(|value| {
329 Wrapper::<TargetKind>::from_str(value.as_str())
330 .map_err(|err| Box::new(err) as Box<dyn std::error::Error + Send + Sync>)
331 })
332 }
333}
334
335impl sqlx::Type<sqlx::Sqlite> for Wrapper<TaskStatus> {
336 fn type_info() -> <sqlx::Sqlite as sqlx::Database>::TypeInfo {
337 <String as sqlx::Type<sqlx::Sqlite>>::type_info()
338 }
339
340 fn compatible(ty: &<sqlx::Sqlite as sqlx::Database>::TypeInfo) -> bool {
341 <String as sqlx::Type<sqlx::Sqlite>>::compatible(ty)
342 }
343}
344
345impl Wrapper<TaskMethod> {
346 fn tuple(&self) -> (TargetKind, &'static str) {
347 match self.0 {
348 TaskMethod::MediaSynchronize => (TargetKind::Media, "synchronize"),
349 TaskMethod::MediaSynchronizeAll => (TargetKind::Media, "synchronize-all"),
350 TaskMethod::Noop => (TargetKind::Noop, "noop"),
351 TaskMethod::PodcastSynchronize => (TargetKind::Podcast, "synchronize"),
352 TaskMethod::PodcastSynchronizeAll => (TargetKind::Podcast, "synchronize-all"),
353 TaskMethod::TvShowSynchronize => (TargetKind::TvShow, "synchronize"),
354 TaskMethod::TvShowSynchronizeAll => (TargetKind::TvShow, "synchronize-all"),
355 TaskMethod::TvShowSynchronizeLocated => (TargetKind::TvShow, "synchronize-located"),
356 TaskMethod::TvShowSynchronizeSeasonLocated => {
357 (TargetKind::TvShow, "synchronize-season-located")
358 }
359 }
360 }
361}
362
#[cfg(test)]
mod tests {
    use entertainarr_domain::podcast::entity::PodcastTask;
    use entertainarr_domain::task::entity::TaskPayload;
    use entertainarr_domain::task::prelude::TaskRepository;
    use sqlx::types::Json;

    /// Inserts the users and podcasts the task fixtures refer to.
    async fn seed(pool: &crate::Pool) {
        let _: Vec<u64> = sqlx::query_scalar("insert into users (id, email, password) values (1, 'user1@example.com', 'password'), (2, 'user2@example.com', 'password') returning id").fetch_all(pool.as_ref()).await.unwrap();
        let _: Vec<u64> = sqlx::query_scalar("insert into podcasts (id, feed_url, title) values (1, 'first', 'first'), (2, 'second', 'second'), (3, 'third', 'third') returning id").fetch_all(pool.as_ref()).await.unwrap();
    }

    /// Inserts a podcast `synchronize` task owned by user 1 and returns its id.
    async fn insert_podcast_task(
        pool: &crate::Pool,
        target_id: i64,
        payload: TaskPayload,
        status: &str,
        updated_at: i64,
    ) -> u64 {
        sqlx::query_scalar("insert into tasks (user_id, target_kind, target_id, method, parameters, status, updated_at) values (1, 'podcast', ?, 'synchronize', ?, ?, ?) returning id")
            .bind(target_id)
            .bind(Json(payload))
            .bind(status)
            .bind(updated_at)
            .fetch_one(pool.as_ref())
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn should_find_by_target_ids() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        seed(&pool).await;

        // Two tasks for podcast 1: the pending one is the most recently updated.
        let _first = insert_podcast_task(
            &pool,
            1,
            TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 1 }),
            "completed",
            0,
        )
        .await;
        let second = insert_podcast_task(
            &pool,
            1,
            TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 1 }),
            "pending",
            1,
        )
        .await;

        // Tasks for podcast 2 must not leak into a lookup for podcast 1.
        let _third = insert_podcast_task(
            &pool,
            2,
            TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 2 }),
            "completed",
            2,
        )
        .await;
        let _fourth = insert_podcast_task(
            &pool,
            2,
            TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 2 }),
            "completed",
            3,
        )
        .await;

        let mut list = pool
            .find_by_target_ids(
                entertainarr_domain::task::entity::TaskMethod::PodcastSynchronize,
                &[1],
            )
            .await
            .unwrap();
        assert_eq!(list.len(), 1);
        let value = list.pop().unwrap();
        assert_eq!(value.id, second);
    }
}