entertainarr_adapter_sqlite/
task.rs

1use anyhow::Context;
2use entertainarr_domain::task::entity::{
3    TargetKind, Task, TaskMethod, TaskParams, TaskPayload, TaskStatus,
4};
5use entertainarr_domain::task::prelude::TaskRepository;
6use sqlx::types::{Json, chrono};
7
8use crate::Wrapper;
9
10impl crate::Pool {
11    #[tracing::instrument(
12        skip(self, tasks),
13        fields(
14            otel.kind = "client",
15            db.system = "sqlite",
16            db.name = "tasks",
17            db.operation = "SELECT",
18            db.sql.table = "tasks",
19            db.query.text = tracing::field::Empty,
20            error.type = tracing::field::Empty,
21            error.message = tracing::field::Empty,
22            error.stacktrace = tracing::field::Empty,
23        ),
24        err(Debug),
25    )]
26    pub(crate) async fn insert_tasks<I>(&self, tasks: I, params: &TaskParams) -> anyhow::Result<()>
27    where
28        I: Iterator<Item = TaskPayload>,
29    {
30        let mut qb = sqlx::QueryBuilder::new(
31            "insert into tasks (user_id, after, target_kind, target_id, method, parameters, status, retries, max_retries) ",
32        );
33        qb.push_values(tasks, |mut q, payload| {
34            q.push_bind(params.user_id.map(|v| v as i64))
35                .push_bind(params.after)
36                .push_bind(Wrapper(payload.target_kind()))
37                .push_bind(payload.target_id().map(|value| value as i64))
38                .push_bind(payload.method().method())
39                .push_bind(Json(payload))
40                .push_bind(TaskStatus::Pending.as_str())
41                .push_bind(params.retry)
42                .push_bind(5);
43        });
44        qb.push(" on conflict (target_kind, target_id, method, parameters) where status = 'pending' do update set updated_at = CURRENT_TIMESTAMP");
45        tracing::Span::current().record("qb.query.text", qb.sql());
46        qb.build()
47            .execute(self.as_ref())
48            .await
49            .inspect_err(super::record_error)
50            .map(|_| ())
51            .context("unable to insert task")
52    }
53}
54
/// Claims the next runnable task: picks the oldest `pending` row whose `after` is unset or
/// already reached, flips it to `running` and returns the full row.
const START_TASK_QUERY: &str = r#"with next_task as (
    select id from tasks
    where status = 'pending'
    and (after is null or after <= current_timestamp)
    order by created_at
    limit 1
)
update tasks
set status = 'running', updated_at = current_timestamp
where id in (select id from next_task)
returning id, user_id, after, target_kind, target_id, method, parameters, status, retries, max_retries, error_message, created_at, updated_at"#;

/// Marks the task bound to the single placeholder as `completed`.
///
/// `updated_at` is refreshed like in every other status transition, because the
/// "latest task per target" lookup orders by that column.
const COMPLETE_TASK_QUERY: &str =
    "update tasks set status = 'completed', updated_at = current_timestamp where id = ?";

/// Records a failed attempt for a `running` task.
///
/// Parameters are numbered so they match the bind order used by the caller:
/// `?1` is the task id, `?2` is the error message. The task goes back to `pending` until the
/// incremented retry counter exceeds `max_retries`, at which point it becomes `failed`.
const FAIL_TASK_QUERY: &str = "update tasks
set retries = retries + 1,
    status = case when retries + 1 > max_retries then 'failed' else 'pending' end,
    error_message = ?2,
    updated_at = current_timestamp
where id = ?1 and status = 'running'";
73
74impl crate::Pool {
75    #[tracing::instrument(
76        skip(self),
77        fields(
78            otel.kind = "client",
79            db.system = "sqlite",
80            db.name = "tasks",
81            db.operation = "SELECT",
82            db.sql.table = "tasks",
83            db.query.text = tracing::field::Empty,
84            db.response.returned_rows = tracing::field::Empty,
85            error.type = tracing::field::Empty,
86            error.message = tracing::field::Empty,
87            error.stacktrace = tracing::field::Empty,
88        ),
89        err(Debug),
90    )]
91    pub(crate) async fn last_tasks_by_target_ids(
92        &self,
93        task_method: TaskMethod,
94        target_ids: &[u64],
95    ) -> anyhow::Result<Vec<entertainarr_domain::task::entity::Task>> {
96        if target_ids.is_empty() {
97            return Ok(Vec::new());
98        }
99
100        let (target_kind, method) = Wrapper(task_method).tuple();
101
102        let mut qb = sqlx::QueryBuilder::new(
103            "with subset_tasks as (select *, row_number() over (partition by target_id order by updated_at desc) task_index",
104        );
105        qb.push(" from tasks where target_kind = ")
106            .push_bind(Wrapper(target_kind));
107        qb.push(" and method = ").push_bind(method);
108        qb.push(" and (");
109        for (index, task_id) in target_ids.iter().enumerate() {
110            if index > 0 {
111                qb.push(" or");
112            }
113            qb.push("target_id = ").push_bind(*task_id as i64);
114        }
115        qb.push("))");
116
117        qb.push(" select id, user_id, after, target_kind, target_id, method, parameters, status, retries, max_retries, error_message, created_at, updated_at");
118        qb.push(" from subset_tasks");
119        qb.push(" where task_index = 1");
120
121        let span = tracing::Span::current();
122        span.record("db.query.text", qb.sql());
123
124        qb.build_query_as()
125            .fetch_all(self.as_ref())
126            .await
127            .inspect(super::record_all)
128            .inspect_err(super::record_error)
129            .map(super::Wrapper::list)
130            .context("unable to fetch podcast tasks")
131    }
132}
133
134impl TaskRepository for crate::Pool {
135    #[tracing::instrument(
136        skip(self),
137        fields(
138            otel.kind = "client",
139            db.system = "sqlite",
140            db.name = "tasks",
141            db.operation = "UPDATE",
142            db.sql.table = "tasks",
143            db.query.text = START_TASK_QUERY,
144            db.response.returned_rows = tracing::field::Empty,
145            error.type = tracing::field::Empty,
146            error.message = tracing::field::Empty,
147            error.stacktrace = tracing::field::Empty,
148        ),
149        err(Debug),
150    )]
151    async fn start_task(&self) -> anyhow::Result<Option<Task>> {
152        sqlx::query_as(START_TASK_QUERY)
153            .fetch_optional(self.as_ref())
154            .await
155            .inspect(crate::record_optional)
156            .inspect_err(crate::record_error)
157            .map(crate::Wrapper::maybe_inner)
158            .context("unable to start task")
159    }
160
161    #[tracing::instrument(
162        skip(self),
163        fields(
164            otel.kind = "client",
165            db.system = "sqlite",
166            db.name = "tasks",
167            db.operation = "UPDATE",
168            db.sql.table = "tasks",
169            db.query.text = COMPLETE_TASK_QUERY,
170            db.response.returned_rows = 0,
171            error.type = tracing::field::Empty,
172            error.message = tracing::field::Empty,
173            error.stacktrace = tracing::field::Empty,
174        ),
175        err(Debug),
176    )]
177    async fn complete_task(&self, task_id: u64) -> anyhow::Result<()> {
178        sqlx::query(COMPLETE_TASK_QUERY)
179            .bind(task_id as i64)
180            .execute(self.as_ref())
181            .await
182            .inspect_err(crate::record_error)
183            .map(|_| ())
184            .context("unable to complete task")
185    }
186
187    #[tracing::instrument(
188        skip(self),
189        fields(
190            otel.kind = "client",
191            db.system = "sqlite",
192            db.name = "tasks",
193            db.operation = "UPDATE",
194            db.sql.table = "tasks",
195            db.query.text = FAIL_TASK_QUERY,
196            db.response.returned_rows = 0,
197            error.type = tracing::field::Empty,
198            error.message = tracing::field::Empty,
199            error.stacktrace = tracing::field::Empty,
200        ),
201        err(Debug),
202    )]
203    async fn fail_task(&self, task_id: u64, error: String) -> anyhow::Result<()> {
204        sqlx::query(FAIL_TASK_QUERY)
205            .bind(task_id as i64)
206            .bind(error.as_str())
207            .execute(self.as_ref())
208            .await
209            .inspect_err(crate::record_error)
210            .map(|_| ())
211            .context("unable to complete task")
212    }
213
214    async fn find_by_target_ids(
215        &self,
216        task_method: TaskMethod,
217        target_ids: &[u64],
218    ) -> anyhow::Result<Vec<entertainarr_domain::task::entity::Task>> {
219        self.last_tasks_by_target_ids(task_method, target_ids).await
220    }
221}
222
223impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<Task> {
224    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
225        use sqlx::Row;
226
227        let payload: TaskPayload = row.try_get(6).map(|Json(inner)| inner)?;
228
229        let status: String = row.try_get(7)?;
230        let retries: u8 = row.try_get(8)?;
231        let max_retries: u8 = row.try_get(9)?;
232        let error_message: Option<String> = row.try_get(10)?;
233        let created_at: chrono::DateTime<chrono::Utc> = row.try_get(11)?;
234        let updated_at: chrono::DateTime<chrono::Utc> = row.try_get(12)?;
235
236        Ok(Self(Task {
237            id: row.try_get(0)?,
238            user_id: row.try_get(1)?,
239            after: row.try_get(2)?,
240            payload,
241            status: build_task_status(status, error_message),
242            retries,
243            max_retries,
244            created_at,
245            updated_at,
246        }))
247    }
248}
249
250fn build_task_status(status: String, message: Option<String>) -> TaskStatus {
251    match status.as_str() {
252        "pending" => TaskStatus::Pending,
253        "running" => TaskStatus::Running,
254        "completed" => TaskStatus::Completed,
255        "failed" => TaskStatus::Failed { message },
256        _ => {
257            tracing::error!(%status, "invalid status");
258            TaskStatus::Pending
259        }
260    }
261}
262
263#[derive(Debug, thiserror::Error)]
264#[error("invalid target kind {_0:?}")]
265pub struct TargetKindParseError(String);
266
267impl std::str::FromStr for Wrapper<TargetKind> {
268    type Err = TargetKindParseError;
269
270    fn from_str(input: &str) -> Result<Self, Self::Err> {
271        match input {
272            "noop" => Ok(Wrapper(TargetKind::Noop)),
273            "podcast" => Ok(Wrapper(TargetKind::Podcast)),
274            "tvshow" => Ok(Wrapper(TargetKind::TvShow)),
275            other => Err(TargetKindParseError(other.to_string())),
276        }
277    }
278}
279
280impl std::fmt::Display for Wrapper<TargetKind> {
281    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282        f.write_str(self.as_str())
283    }
284}
285
286impl Wrapper<TargetKind> {
287    const fn as_str(&self) -> &'static str {
288        match self.0 {
289            TargetKind::Media => "media",
290            TargetKind::Noop => "noop",
291            TargetKind::Podcast => "podcast",
292            TargetKind::TvShow => "tvshow",
293        }
294    }
295}
296
297impl sqlx::Type<sqlx::Sqlite> for Wrapper<TargetKind> {
298    fn type_info() -> <sqlx::Sqlite as sqlx::Database>::TypeInfo {
299        <String as sqlx::Type<sqlx::Sqlite>>::type_info()
300    }
301
302    fn compatible(ty: &<sqlx::Sqlite as sqlx::Database>::TypeInfo) -> bool {
303        <String as sqlx::Type<sqlx::Sqlite>>::compatible(ty)
304    }
305}
306
307impl<'q> sqlx::Encode<'q, sqlx::Sqlite> for Wrapper<TargetKind> {
308    fn encode_by_ref(
309        &self,
310        buf: &mut <sqlx::Sqlite as sqlx::Database>::ArgumentBuffer<'q>,
311    ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
312        <String as sqlx::Encode<'q, sqlx::Sqlite>>::encode(self.to_string(), buf)
313    }
314}
315
316impl<'r> sqlx::Decode<'r, sqlx::Sqlite> for Wrapper<TargetKind> {
317    fn decode(
318        value: <sqlx::Sqlite as sqlx::Database>::ValueRef<'r>,
319    ) -> Result<Self, sqlx::error::BoxDynError> {
320        use std::str::FromStr;
321
322        <String as sqlx::Decode<'r, sqlx::Sqlite>>::decode(value).and_then(|value| {
323            Wrapper::<TargetKind>::from_str(value.as_str())
324                .map_err(|err| Box::new(err) as Box<dyn std::error::Error + Send + Sync>)
325        })
326    }
327}
328
329impl sqlx::Type<sqlx::Sqlite> for Wrapper<TaskStatus> {
330    fn type_info() -> <sqlx::Sqlite as sqlx::Database>::TypeInfo {
331        <String as sqlx::Type<sqlx::Sqlite>>::type_info()
332    }
333
334    fn compatible(ty: &<sqlx::Sqlite as sqlx::Database>::TypeInfo) -> bool {
335        <String as sqlx::Type<sqlx::Sqlite>>::compatible(ty)
336    }
337}
338
339impl Wrapper<TaskMethod> {
340    fn tuple(&self) -> (TargetKind, &'static str) {
341        match self.0 {
342            TaskMethod::MediaSynchronize => (TargetKind::Media, "synchronize"),
343            TaskMethod::MediaSynchronizeAll => (TargetKind::Media, "synchronize-all"),
344            TaskMethod::Noop => (TargetKind::Noop, "noop"),
345            TaskMethod::PodcastSynchronize => (TargetKind::Podcast, "synchronize"),
346            TaskMethod::PodcastSynchronizeAll => (TargetKind::Podcast, "synchronize-all"),
347            TaskMethod::TvShowSynchronize => (TargetKind::TvShow, "synchronize"),
348            TaskMethod::TvShowSynchronizeAll => (TargetKind::TvShow, "synchronize-all"),
349            TaskMethod::TvShowSynchronizeLocated => (TargetKind::TvShow, "synchronize-located"),
350            TaskMethod::TvShowSynchronizeSeasonLocated => {
351                (TargetKind::TvShow, "synchronize-season-located")
352            }
353        }
354    }
355}
356
#[cfg(test)]
mod tests {
    use entertainarr_domain::podcast::entity::PodcastTask;
    use entertainarr_domain::task::entity::{TaskMethod, TaskPayload};
    use entertainarr_domain::task::prelude::TaskRepository;
    use sqlx::types::Json;

    /// Creates the users and podcasts referenced by the task fixtures.
    async fn seed(pool: &crate::Pool) {
        let _: Vec<u64> = sqlx::query_scalar("insert into users (id, email, password) values (1, 'user1@example.com', 'password'), (2, 'user2@example.com', 'password') returning id").fetch_all(pool.as_ref()).await.unwrap();
        let _: Vec<u64> = sqlx::query_scalar("insert into podcasts (id, feed_url, title) values (1, 'first', 'first'), (2, 'second', 'second'), (3, 'third', 'third') returning id").fetch_all(pool.as_ref()).await.unwrap();
    }

    /// Inserts a podcast `synchronize` task owned by user 1 and returns its id.
    ///
    /// `status` must be one of the labels the adapter itself writes (`pending`, `running`,
    /// `completed`, `failed`); `updated_at` drives which task counts as the latest.
    async fn insert_podcast_sync_task(
        pool: &crate::Pool,
        target_id: i64,
        payload: TaskPayload,
        status: &str,
        updated_at: i64,
    ) -> u64 {
        sqlx::query_scalar("insert into tasks (user_id, target_kind, target_id, method, parameters, status, updated_at) values (1, 'podcast', ?, 'synchronize', ?, ?, ?) returning id")
            .bind(target_id)
            .bind(Json(payload))
            .bind(status)
            .bind(updated_at)
            .fetch_one(pool.as_ref())
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn should_find_by_target_ids() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        seed(&pool).await;

        // Two tasks for podcast 1: the pending one is the most recently updated.
        let _first = insert_podcast_sync_task(
            &pool,
            1,
            TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 1 }),
            "completed",
            0,
        )
        .await;
        let second = insert_podcast_sync_task(
            &pool,
            1,
            TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 1 }),
            "pending",
            1,
        )
        .await;

        // Tasks for podcast 2 must not leak into a lookup for podcast 1.
        let _third = insert_podcast_sync_task(
            &pool,
            2,
            TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 2 }),
            "completed",
            2,
        )
        .await;
        let _fourth = insert_podcast_sync_task(
            &pool,
            2,
            TaskPayload::Podcast(PodcastTask::Synchronize { podcast_id: 2 }),
            "completed",
            3,
        )
        .await;

        let mut list = pool
            .find_by_target_ids(TaskMethod::PodcastSynchronize, &[1])
            .await
            .unwrap();
        assert_eq!(list.len(), 1);
        let value = list.pop().unwrap();
        assert_eq!(value.id, second);
    }
}
393}