entertainarr_adapter_sqlite/
task.rs

1use anyhow::Context;
2use entertainarr_domain::task::entity::{
3    TargetKind, Task, TaskMethod, TaskParams, TaskPayload, TaskStatus,
4};
5use entertainarr_domain::task::prelude::TaskRepository;
6use sqlx::types::{Json, chrono};
7
8use crate::Wrapper;
9
10impl crate::Pool {
11    #[tracing::instrument(
12        skip(self, tasks),
13        fields(
14            otel.kind = "client",
15            db.system = "sqlite",
16            db.name = "tasks",
17            db.operation = "SELECT",
18            db.sql.table = "tasks",
19            db.query.text = tracing::field::Empty,
20            error.type = tracing::field::Empty,
21            error.message = tracing::field::Empty,
22            error.stacktrace = tracing::field::Empty,
23        ),
24        err(Debug),
25    )]
26    pub async fn insert_tasks<I>(&self, tasks: I, params: &TaskParams) -> anyhow::Result<u64>
27    where
28        I: Iterator<Item = TaskPayload>,
29    {
30        let mut qb = sqlx::QueryBuilder::new(
31            "insert into tasks (user_id, after, target_kind, target_id, method, parameters, status, retries, max_retries) ",
32        );
33        let mut count = 0;
34        qb.push_values(tasks, |mut q, payload| {
35            count += 1;
36            q.push_bind(params.user_id.map(|v| v as i64))
37                .push_bind(params.after)
38                .push_bind(Wrapper(payload.target_kind()))
39                .push_bind(payload.target_id().map(|value| value as i64))
40                .push_bind(payload.method().method())
41                .push_bind(Json(payload))
42                .push_bind(TaskStatus::Pending.as_str())
43                .push_bind(params.retry)
44                .push_bind(5);
45        });
46        if count == 0 {
47            return Ok(0);
48        }
49
50        qb.push(" on conflict (target_kind, target_id, method, parameters) where status = 'pending' do update set updated_at = CURRENT_TIMESTAMP");
51        tracing::Span::current().record("qb.query.text", qb.sql());
52        qb.build()
53            .execute(self.as_ref())
54            .await
55            .inspect_err(super::record_error)
56            .map(|res| res.rows_affected())
57            .context("unable to insert task")
58    }
59}
60
/// Claims the oldest runnable `pending` task by switching it to `running`, and
/// returns the full row in the column order expected by `Wrapper<Task>::from_row`.
// NOTE(review): `after <= current_timestamp` compares stored values directly. If
// `after` is written in a format other than SQLite's `YYYY-MM-DD HH:MM:SS`
// (e.g. RFC 3339 with a `T` separator), the comparison is lexical and can delay
// tasks scheduled later the same day; confirm how `after` is encoded.
const START_TASK_QUERY: &str = r#"with next_task as (
    select id from tasks
    where status = 'pending'
    and (after is null or after <= current_timestamp)
    order by created_at
    limit 1
)
update tasks
set status = 'running', updated_at = current_timestamp
where id in (select id from next_task)
returning id, user_id, after, target_kind, target_id, method, parameters, status, retries, max_retries, error_message, created_at, updated_at"#;

/// Marks a task as `completed`; the only parameter is the task id.
///
/// `updated_at` is bumped like in every other state transition, so the
/// "latest task per target" lookup sees the completion time.
const COMPLETE_TASK_QUERY: &str =
    "update tasks set status = 'completed', updated_at = current_timestamp where id = ?";

/// Records a failed attempt of a `running` task.
///
/// Increments `retries` and sends the task back to `pending`, or to `failed` once
/// the incremented value exceeds `max_retries`. The error message is stored in
/// `error_message` so it surfaces in `TaskStatus::Failed { message }`.
///
/// Parameters are numbered: `?1` is the task id and `?2` the error message,
/// matching the bind order used by `fail_task`.
const FAIL_TASK_QUERY: &str = "update tasks
set retries = retries + 1,
    status = case when retries + 1 > max_retries then 'failed' else 'pending' end,
    error_message = ?2,
    updated_at = current_timestamp
where id = ?1 and status = 'running'";
79
80impl crate::Pool {
81    #[tracing::instrument(
82        skip(self),
83        fields(
84            otel.kind = "client",
85            db.system = "sqlite",
86            db.name = "tasks",
87            db.operation = "SELECT",
88            db.sql.table = "tasks",
89            db.query.text = tracing::field::Empty,
90            db.response.returned_rows = tracing::field::Empty,
91            error.type = tracing::field::Empty,
92            error.message = tracing::field::Empty,
93            error.stacktrace = tracing::field::Empty,
94        ),
95        err(Debug),
96    )]
97    pub(crate) async fn last_tasks_by_target_ids(
98        &self,
99        task_method: TaskMethod,
100        target_ids: &[u64],
101    ) -> anyhow::Result<Vec<entertainarr_domain::task::entity::Task>> {
102        if target_ids.is_empty() {
103            return Ok(Vec::new());
104        }
105
106        let (target_kind, method) = Wrapper(task_method).tuple();
107
108        let mut qb = sqlx::QueryBuilder::new(
109            "with subset_tasks as (select *, row_number() over (partition by target_id order by updated_at desc) task_index",
110        );
111        qb.push(" from tasks where target_kind = ")
112            .push_bind(Wrapper(target_kind));
113        qb.push(" and method = ").push_bind(method);
114        qb.push(" and (");
115        for (index, task_id) in target_ids.iter().enumerate() {
116            if index > 0 {
117                qb.push(" or");
118            }
119            qb.push("target_id = ").push_bind(*task_id as i64);
120        }
121        qb.push("))");
122
123        qb.push(" select id, user_id, after, target_kind, target_id, method, parameters, status, retries, max_retries, error_message, created_at, updated_at");
124        qb.push(" from subset_tasks");
125        qb.push(" where task_index = 1");
126
127        let span = tracing::Span::current();
128        span.record("db.query.text", qb.sql());
129
130        qb.build_query_as()
131            .fetch_all(self.as_ref())
132            .await
133            .inspect(super::record_all)
134            .inspect_err(super::record_error)
135            .map(super::Wrapper::list)
136            .context("unable to fetch podcast tasks")
137    }
138}
139
140impl TaskRepository for crate::Pool {
141    #[tracing::instrument(
142        skip(self),
143        fields(
144            otel.kind = "client",
145            db.system = "sqlite",
146            db.name = "tasks",
147            db.operation = "UPDATE",
148            db.sql.table = "tasks",
149            db.query.text = START_TASK_QUERY,
150            db.response.returned_rows = tracing::field::Empty,
151            error.type = tracing::field::Empty,
152            error.message = tracing::field::Empty,
153            error.stacktrace = tracing::field::Empty,
154        ),
155        err(Debug),
156    )]
157    async fn start_task(&self) -> anyhow::Result<Option<Task>> {
158        sqlx::query_as(START_TASK_QUERY)
159            .fetch_optional(self.as_ref())
160            .await
161            .inspect(crate::record_optional)
162            .inspect_err(crate::record_error)
163            .map(crate::Wrapper::maybe_inner)
164            .context("unable to start task")
165    }
166
167    #[tracing::instrument(
168        skip(self),
169        fields(
170            otel.kind = "client",
171            db.system = "sqlite",
172            db.name = "tasks",
173            db.operation = "UPDATE",
174            db.sql.table = "tasks",
175            db.query.text = COMPLETE_TASK_QUERY,
176            db.response.returned_rows = 0,
177            error.type = tracing::field::Empty,
178            error.message = tracing::field::Empty,
179            error.stacktrace = tracing::field::Empty,
180        ),
181        err(Debug),
182    )]
183    async fn complete_task(&self, task_id: u64) -> anyhow::Result<()> {
184        sqlx::query(COMPLETE_TASK_QUERY)
185            .bind(task_id as i64)
186            .execute(self.as_ref())
187            .await
188            .inspect_err(crate::record_error)
189            .map(|_| ())
190            .context("unable to complete task")
191    }
192
193    #[tracing::instrument(
194        skip(self),
195        fields(
196            otel.kind = "client",
197            db.system = "sqlite",
198            db.name = "tasks",
199            db.operation = "UPDATE",
200            db.sql.table = "tasks",
201            db.query.text = FAIL_TASK_QUERY,
202            db.response.returned_rows = 0,
203            error.type = tracing::field::Empty,
204            error.message = tracing::field::Empty,
205            error.stacktrace = tracing::field::Empty,
206        ),
207        err(Debug),
208    )]
209    async fn fail_task(&self, task_id: u64, error: String) -> anyhow::Result<()> {
210        sqlx::query(FAIL_TASK_QUERY)
211            .bind(task_id as i64)
212            .bind(error.as_str())
213            .execute(self.as_ref())
214            .await
215            .inspect_err(crate::record_error)
216            .map(|_| ())
217            .context("unable to complete task")
218    }
219
220    async fn find_by_target_ids(
221        &self,
222        task_method: TaskMethod,
223        target_ids: &[u64],
224    ) -> anyhow::Result<Vec<entertainarr_domain::task::entity::Task>> {
225        self.last_tasks_by_target_ids(task_method, target_ids).await
226    }
227}
228
229impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<Task> {
230    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
231        use sqlx::Row;
232
233        let payload: TaskPayload = row.try_get(6).map(|Json(inner)| inner)?;
234
235        let status: String = row.try_get(7)?;
236        let retries: u8 = row.try_get(8)?;
237        let max_retries: u8 = row.try_get(9)?;
238        let error_message: Option<String> = row.try_get(10)?;
239        let created_at: chrono::DateTime<chrono::Utc> = row.try_get(11)?;
240        let updated_at: chrono::DateTime<chrono::Utc> = row.try_get(12)?;
241
242        Ok(Self(Task {
243            id: row.try_get(0)?,
244            user_id: row.try_get(1)?,
245            after: row.try_get(2)?,
246            payload,
247            status: build_task_status(status, error_message),
248            retries,
249            max_retries,
250            created_at,
251            updated_at,
252        }))
253    }
254}
255
256fn build_task_status(status: String, message: Option<String>) -> TaskStatus {
257    match status.as_str() {
258        "pending" => TaskStatus::Pending,
259        "running" => TaskStatus::Running,
260        "completed" => TaskStatus::Completed,
261        "failed" => TaskStatus::Failed { message },
262        _ => {
263            tracing::error!(%status, "invalid status");
264            TaskStatus::Pending
265        }
266    }
267}
268
269#[derive(Debug, thiserror::Error)]
270#[error("invalid target kind {_0:?}")]
271pub struct TargetKindParseError(String);
272
273impl std::str::FromStr for Wrapper<TargetKind> {
274    type Err = TargetKindParseError;
275
276    fn from_str(input: &str) -> Result<Self, Self::Err> {
277        match input {
278            "noop" => Ok(Wrapper(TargetKind::Noop)),
279            "podcast" => Ok(Wrapper(TargetKind::Podcast)),
280            "tvshow" => Ok(Wrapper(TargetKind::TvShow)),
281            other => Err(TargetKindParseError(other.to_string())),
282        }
283    }
284}
285
286impl std::fmt::Display for Wrapper<TargetKind> {
287    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288        f.write_str(self.as_str())
289    }
290}
291
292impl Wrapper<TargetKind> {
293    const fn as_str(&self) -> &'static str {
294        match self.0 {
295            TargetKind::Media => "media",
296            TargetKind::Noop => "noop",
297            TargetKind::Podcast => "podcast",
298            TargetKind::TvShow => "tvshow",
299        }
300    }
301}
302
303impl sqlx::Type<sqlx::Sqlite> for Wrapper<TargetKind> {
304    fn type_info() -> <sqlx::Sqlite as sqlx::Database>::TypeInfo {
305        <String as sqlx::Type<sqlx::Sqlite>>::type_info()
306    }
307
308    fn compatible(ty: &<sqlx::Sqlite as sqlx::Database>::TypeInfo) -> bool {
309        <String as sqlx::Type<sqlx::Sqlite>>::compatible(ty)
310    }
311}
312
313impl<'q> sqlx::Encode<'q, sqlx::Sqlite> for Wrapper<TargetKind> {
314    fn encode_by_ref(
315        &self,
316        buf: &mut <sqlx::Sqlite as sqlx::Database>::ArgumentBuffer<'q>,
317    ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
318        <String as sqlx::Encode<'q, sqlx::Sqlite>>::encode(self.to_string(), buf)
319    }
320}
321
322impl<'r> sqlx::Decode<'r, sqlx::Sqlite> for Wrapper<TargetKind> {
323    fn decode(
324        value: <sqlx::Sqlite as sqlx::Database>::ValueRef<'r>,
325    ) -> Result<Self, sqlx::error::BoxDynError> {
326        use std::str::FromStr;
327
328        <String as sqlx::Decode<'r, sqlx::Sqlite>>::decode(value).and_then(|value| {
329            Wrapper::<TargetKind>::from_str(value.as_str())
330                .map_err(|err| Box::new(err) as Box<dyn std::error::Error + Send + Sync>)
331        })
332    }
333}
334
335impl sqlx::Type<sqlx::Sqlite> for Wrapper<TaskStatus> {
336    fn type_info() -> <sqlx::Sqlite as sqlx::Database>::TypeInfo {
337        <String as sqlx::Type<sqlx::Sqlite>>::type_info()
338    }
339
340    fn compatible(ty: &<sqlx::Sqlite as sqlx::Database>::TypeInfo) -> bool {
341        <String as sqlx::Type<sqlx::Sqlite>>::compatible(ty)
342    }
343}
344
345impl Wrapper<TaskMethod> {
346    fn tuple(&self) -> (TargetKind, &'static str) {
347        match self.0 {
348            TaskMethod::MediaSynchronize => (TargetKind::Media, "synchronize"),
349            TaskMethod::MediaSynchronizeAll => (TargetKind::Media, "synchronize-all"),
350            TaskMethod::Noop => (TargetKind::Noop, "noop"),
351            TaskMethod::PodcastSynchronize => (TargetKind::Podcast, "synchronize"),
352            TaskMethod::PodcastSynchronizeAll => (TargetKind::Podcast, "synchronize-all"),
353            TaskMethod::TvShowSynchronize => (TargetKind::TvShow, "synchronize"),
354            TaskMethod::TvShowSynchronizeAll => (TargetKind::TvShow, "synchronize-all"),
355            TaskMethod::TvShowSynchronizeLocated => (TargetKind::TvShow, "synchronize-located"),
356            TaskMethod::TvShowSynchronizeSeasonLocated => {
357                (TargetKind::TvShow, "synchronize-season-located")
358            }
359        }
360    }
361}
362
#[cfg(test)]
mod tests {
    use entertainarr_domain::podcast::entity::PodcastTask;
    use entertainarr_domain::task::entity::TaskPayload;
    use entertainarr_domain::task::prelude::TaskRepository;
    use sqlx::types::Json;

    /// Inserts two users and three podcasts used as fixtures by the tests below.
    async fn seed(pool: &crate::Pool) {
        let _: Vec<u64> = sqlx::query_scalar("insert into users (id, email, password) values (1, 'user1@example.com', 'password'), (2, 'user2@example.com', 'password') returning id").fetch_all(pool.as_ref()).await.unwrap();
        let _: Vec<u64> = sqlx::query_scalar("insert into podcasts (id, feed_url, title) values (1, 'first', 'first'), (2, 'second', 'second'), (3, 'third', 'third') returning id").fetch_all(pool.as_ref()).await.unwrap();
    }

    #[tokio::test]
    async fn should_find_by_target_ids() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        seed(&pool).await;

        // Statuses must use the spelling `build_task_status` understands
        // (`completed`, not `complete`); otherwise rows decode as `Pending`
        // with an "invalid status" error log.
        let _first: u64 = sqlx::query_scalar("insert into tasks (user_id, target_kind, target_id, method, parameters, status, updated_at) values (1, 'podcast', 1, 'synchronize', ?, 'completed', 0) returning id").bind(Json(TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 1 }))).fetch_one(pool.as_ref()).await.unwrap();
        let second: u64 = sqlx::query_scalar("insert into tasks (user_id, target_kind, target_id, method, parameters, status, updated_at) values (1, 'podcast', 1, 'synchronize', ?, 'pending', 1) returning id").bind(Json(TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 1 }))).fetch_one(pool.as_ref()).await.unwrap();

        let _third: u64 = sqlx::query_scalar("insert into tasks (user_id, target_kind, target_id, method, parameters, status, updated_at) values (1, 'podcast', 2, 'synchronize', ?, 'completed', 2) returning id").bind(Json(TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 2 }))).fetch_one(pool.as_ref()).await.unwrap();
        let _fourth: u64 = sqlx::query_scalar("insert into tasks (user_id, target_kind, target_id, method, parameters, status, updated_at) values (1, 'podcast', 2, 'synchronize', ?, 'completed', 3) returning id").bind(Json(TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 2 }))).fetch_one(pool.as_ref()).await.unwrap();

        let mut list = pool
            .find_by_target_ids(
                entertainarr_domain::task::entity::TaskMethod::PodcastSynchronize,
                &[1],
            )
            .await
            .unwrap();
        assert_eq!(list.len(), 1);
        let value = list.pop().unwrap();
        assert_eq!(value.id, second);
    }

    #[tokio::test]
    async fn should_find_nothing_without_target_ids() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        let list = pool
            .find_by_target_ids(
                entertainarr_domain::task::entity::TaskMethod::PodcastSynchronize,
                &[],
            )
            .await
            .unwrap();
        assert!(list.is_empty());
    }
}