entertainarr_adapter_sqlite/
podcast.rs

1use anyhow::Context;
2use entertainarr_domain::podcast::entity::{
3    ListPodcastParams, Podcast, PodcastEpisodeInput, PodcastInput, PodcastSubscription,
4    PodcastWithSubscription,
5};
6
7use crate::Wrapper;
8use crate::prelude::{HasAnyOf, QueryBuilderExt};
9
/// Selects a single podcast by primary key, columns in `Wrapper<Podcast>` decoding order.
const FIND_PODCAST_BY_ID_QUERY: &str = "select id, feed_url, title, description, image_url, language, website, created_at, updated_at from podcasts where id = ? limit 1";
/// Selects a single podcast by feed URL.
///
/// Uses an exact `=` comparison: `like` would treat `%` (frequent in percent-encoded URLs) and
/// `_` as wildcards and fold ASCII case, so it could return a podcast other than the one keyed by
/// `on conflict (feed_url)` in `UPSERT_PODCAST_QUERY`.
const FIND_PODCAST_BY_FEED_URL_QUERY: &str = "select id, feed_url, title, description, image_url, language, website, created_at, updated_at from podcasts where feed_url = ? limit 1";
/// Inserts a podcast or, when `feed_url` already exists, refreshes its metadata and
/// `updated_at`; returns the stored row in `Wrapper<Podcast>` decoding order.
const UPSERT_PODCAST_QUERY: &str = r#"insert into podcasts (feed_url, title, description, image_url, language, website)
values (?, ?, ?, ?, ?, ?)
on conflict (feed_url) do update set
    title=excluded.title,
    description=excluded.description,
    image_url=excluded.image_url,
    language=excluded.language,
    website=excluded.website,
    updated_at=CURRENT_TIMESTAMP
returning id, feed_url, title, description, image_url, language, website, created_at, updated_at"#;
/// Creates the `(user_id, podcast_id)` subscription row; an existing row is left untouched.
const UPSERT_USER_PODCAST_QUERY: &str = "insert into user_podcasts (user_id, podcast_id) values (?, ?) on conflict (user_id, podcast_id) do nothing";
/// Removes the `(user_id, podcast_id)` subscription row.
const DELETE_USER_PODCAST_QUERY: &str =
    "delete from user_podcasts where user_id = ? and podcast_id = ?";
25
26impl super::Pool {
27    #[tracing::instrument(
28        skip_all,
29        fields(
30            otel.kind = "client",
31            db.system = "sqlite",
32            db.name = "podcast",
33            db.operation = "SELECT",
34            db.sql.table = "podcasts",
35            db.query.text = FIND_PODCAST_BY_ID_QUERY,
36            db.response.returned_rows = tracing::field::Empty,
37            error.type = tracing::field::Empty,
38            error.message = tracing::field::Empty,
39            error.stacktrace = tracing::field::Empty,
40        ),
41        err(Debug),
42    )]
43    pub(crate) async fn find_podcast_by_id(
44        &self,
45        podcast_id: u64,
46    ) -> anyhow::Result<Option<Podcast>> {
47        sqlx::query_as(FIND_PODCAST_BY_ID_QUERY)
48            .bind(podcast_id as i64)
49            .fetch_optional(&self.0)
50            .await
51            .inspect(super::record_optional)
52            .inspect_err(super::record_error)
53            .map(Wrapper::maybe_inner)
54            .context("unable to query podcasts by id")
55    }
56    #[tracing::instrument(
57        skip_all,
58        fields(
59            otel.kind = "client",
60            db.system = "sqlite",
61            db.name = "podcast",
62            db.operation = "SELECT",
63            db.sql.table = "podcasts",
64            db.query.text = FIND_PODCAST_BY_FEED_URL_QUERY,
65            db.response.returned_rows = tracing::field::Empty,
66            error.type = tracing::field::Empty,
67            error.message = tracing::field::Empty,
68            error.stacktrace = tracing::field::Empty,
69        ),
70        err(Debug),
71    )]
72    pub(crate) async fn find_podcast_by_feed_url(
73        &self,
74        feed_url: &str,
75    ) -> anyhow::Result<Option<Podcast>> {
76        sqlx::query_as(FIND_PODCAST_BY_FEED_URL_QUERY)
77            .bind(feed_url)
78            .fetch_optional(&self.0)
79            .await
80            .inspect(super::record_optional)
81            .inspect_err(super::record_error)
82            .map(Wrapper::maybe_inner)
83            .context("unable to query podcasts by feed url")
84    }
85
86    #[tracing::instrument(
87        skip_all,
88        fields(
89            otel.kind = "client",
90            db.system = "sqlite",
91            db.name = "podcast",
92            db.operation = "INSERT",
93            db.sql.table = "podcasts",
94            db.query.text = UPSERT_PODCAST_QUERY,
95            db.response.returned_rows = tracing::field::Empty,
96            error.type = tracing::field::Empty,
97            error.message = tracing::field::Empty,
98            error.stacktrace = tracing::field::Empty,
99        ),
100        err(Debug),
101    )]
102    pub(crate) async fn upsert_podcast<'c, E: sqlx::SqliteExecutor<'c>>(
103        &self,
104        executor: E,
105        entity: &PodcastInput,
106    ) -> anyhow::Result<Podcast> {
107        sqlx::query_as(UPSERT_PODCAST_QUERY)
108            .bind(&entity.feed_url)
109            .bind(&entity.title)
110            .bind(entity.description.as_ref())
111            .bind(entity.image_url.as_ref())
112            .bind(entity.language.as_ref())
113            .bind(entity.website.as_ref())
114            .fetch_one(executor)
115            .await
116            .inspect(super::record_one)
117            .inspect_err(super::record_error)
118            .map(Wrapper::inner)
119            .context("unable to upsert podcast")
120    }
121
122    #[tracing::instrument(
123        skip_all,
124        fields(
125            otel.kind = "client",
126            db.system = "sqlite",
127            db.name = "podcast",
128            db.operation = "INSERT",
129            db.sql.table = "podcast_episodes",
130            db.query.text = tracing::field::Empty,
131            db.response.returned_rows = tracing::field::Empty,
132            error.type = tracing::field::Empty,
133            error.message = tracing::field::Empty,
134            error.stacktrace = tracing::field::Empty,
135        ),
136        err(Debug),
137    )]
138    pub(crate) async fn upsert_podcast_episodes<'c, E: sqlx::SqliteExecutor<'c>>(
139        &self,
140        executor: E,
141        podcast_id: u64,
142        episodes: &[PodcastEpisodeInput],
143    ) -> anyhow::Result<()> {
144        let mut qb: sqlx::QueryBuilder<'_, sqlx::Sqlite> = sqlx::QueryBuilder::new(
145            "insert into podcast_episodes (podcast_id, guid, published_at, title, description, link, duration, file_url, file_size, file_type)",
146        );
147        qb.push_values(episodes.iter(), |mut b, item| {
148            b.push_bind(podcast_id as i64)
149                .push_bind(&item.guid)
150                .push_bind(item.published_at)
151                .push_bind(&item.title)
152                .push_bind(&item.description)
153                .push_bind(&item.link)
154                .push_bind(item.duration.as_ref().map(|value| value.as_secs() as i64))
155                .push_bind(&item.file_url)
156                .push_bind(item.file_size.map(|value| value as i64))
157                .push_bind(&item.file_type);
158        });
159        qb.push(" on conflict (podcast_id, guid) do nothing");
160
161        tracing::Span::current().record("db.query.text", qb.sql());
162
163        qb.build()
164            .execute(executor)
165            .await
166            .inspect_err(super::record_error)
167            .map(|_| ())
168            .context("unable to upsert podcast episodes")
169    }
170    #[tracing::instrument(
171        skip_all,
172        fields(
173            otel.kind = "client",
174            db.system = "sqlite",
175            db.name = "podcast",
176            db.operation = "SELECT",
177            db.sql.table = "podcasts",
178            db.query.text = tracing::field::Empty,
179            db.response.returned_rows = tracing::field::Empty,
180            error.type = tracing::field::Empty,
181            error.message = tracing::field::Empty,
182            error.stacktrace = tracing::field::Empty,
183        ),
184        err(Debug),
185    )]
186    pub(crate) async fn list_podcasts<'a>(
187        &self,
188        params: ListPodcastParams<'a>,
189    ) -> anyhow::Result<Vec<Podcast>> {
190        let mut qb = sqlx::QueryBuilder::new(
191            "select podcasts.id, podcasts.feed_url, podcasts.title, podcasts.description, podcasts.image_url, podcasts.language, podcasts.website, podcasts.created_at, podcasts.updated_at from podcasts",
192        );
193        if let Some(user_id) = params.user_id
194            && params.filter.subscribed.is_some()
195        {
196            qb.push(" left outer join user_podcasts on user_podcasts.podcast_id = podcasts.id and user_podcasts.user_id = ").push_bind(user_id as i64);
197        }
198        qb.push(" where true");
199        if !params.filter.podcast_ids.is_empty() {
200            qb.push(" and ");
201            qb.push_any("podcasts.id", params.filter.podcast_ids);
202        }
203        if params.user_id.is_some() {
204            match params.filter.subscribed {
205                Some(true) => {
206                    qb.push(" and user_podcasts.created_at is not null");
207                }
208                Some(false) => {
209                    qb.push(" and user_podcasts.created_at is null");
210                }
211                None => {}
212            }
213        }
214        qb.push_pagination(params.page);
215
216        tracing::Span::current().record("db.query.text", qb.sql());
217
218        qb.build_query_as()
219            .fetch_all(&self.0)
220            .await
221            .inspect(super::record_all)
222            .inspect_err(super::record_error)
223            .map(Wrapper::list)
224            .context("unable to query podcasts by feed url")
225    }
226}
227
228impl entertainarr_domain::podcast::prelude::PodcastRepository for super::Pool {
229    async fn find_by_id(&self, podcast_id: u64) -> anyhow::Result<Option<Podcast>> {
230        self.find_podcast_by_id(podcast_id).await
231    }
232
233    async fn find_by_feed_url(&self, feed_url: &str) -> anyhow::Result<Option<Podcast>> {
234        self.find_podcast_by_feed_url(feed_url).await
235    }
236
237    async fn list<'a>(&self, params: ListPodcastParams<'a>) -> anyhow::Result<Vec<Podcast>> {
238        self.list_podcasts(params).await
239    }
240
241    async fn upsert(&self, entity: &PodcastInput) -> anyhow::Result<Podcast> {
242        let mut tx = self
243            .0
244            .begin()
245            .await
246            .context("unable to begin transaction")?;
247
248        let podcast = self.upsert_podcast(&mut *tx, entity).await?;
249        self.upsert_podcast_episodes(&mut *tx, podcast.id, &entity.episodes)
250            .await?;
251
252        tx.commit().await.context("unable to commit transaction")?;
253        Ok(podcast)
254    }
255}
256
/// Selects one subscription with its queue size (number of episodes whose duration lies
/// strictly inside the optional `min_duration`/`max_duration` bounds; episodes with unknown
/// duration always count).
///
/// The duration bounds are part of the LEFT JOIN condition, not the WHERE clause. In WHERE they
/// would discard every joined row once all episodes fall outside the bounds, making the
/// subscription itself vanish instead of reporting a queue size of 0.
///
/// Binds: `user_id`, `podcast_id`.
const FIND_PODCAST_SUBSCRIPTION_BY_ID_QUERY: &str = r#"select
    user_podcasts.user_id,
    user_podcasts.podcast_id,
    count(podcast_episodes.id),
    user_podcasts.min_duration,
    user_podcasts.max_duration,
    user_podcasts.created_at
from user_podcasts
left outer join podcast_episodes on podcast_episodes.podcast_id = user_podcasts.podcast_id
    and (
        user_podcasts.min_duration is null
        or podcast_episodes.duration is null
        or podcast_episodes.duration > user_podcasts.min_duration
    )
    and (
        user_podcasts.max_duration is null
        or podcast_episodes.duration is null
        or podcast_episodes.duration < user_podcasts.max_duration
    )
where user_podcasts.user_id = ?
    and user_podcasts.podcast_id = ?
group by user_podcasts.user_id, user_podcasts.podcast_id"#;

/// Overwrites both duration bounds of a subscription.
///
/// Binds: `min_duration`, `max_duration`, `user_id`, `podcast_id`.
const UPDATE_SUBSCRIPTION_QUERY: &str = r#"update user_podcasts set min_duration = ?, max_duration = ? where user_id = ? and podcast_id = ?"#;
281
282impl super::Pool {
283    #[tracing::instrument(
284        skip_all,
285        fields(
286            otel.kind = "client",
287            db.system = "sqlite",
288            db.name = "podcast",
289            db.operation = "UPSERT",
290            db.sql.table = "user_podcasts",
291            db.query.text = UPSERT_USER_PODCAST_QUERY,
292            db.response.returned_rows = tracing::field::Empty,
293            error.type = tracing::field::Empty,
294            error.message = tracing::field::Empty,
295            error.stacktrace = tracing::field::Empty,
296        ),
297        err(Debug),
298    )]
299    pub(crate) async fn create_podcast_subscription(
300        &self,
301        user_id: u64,
302        podcast_id: u64,
303    ) -> anyhow::Result<()> {
304        sqlx::query(UPSERT_USER_PODCAST_QUERY)
305            .bind(user_id as i64)
306            .bind(podcast_id as i64)
307            .execute(&self.0)
308            .await
309            .inspect_err(super::record_error)
310            .map(|_| ())
311            .context("unable to upsert user podcast relation")
312    }
313
314    #[tracing::instrument(
315        skip_all,
316        fields(
317            otel.kind = "client",
318            db.system = "sqlite",
319            db.name = "podcast",
320            db.operation = "update",
321            db.sql.table = "user_podcasts",
322            db.query.text = UPDATE_SUBSCRIPTION_QUERY,
323            db.response.returned_rows = tracing::field::Empty,
324            error.type = tracing::field::Empty,
325            error.message = tracing::field::Empty,
326            error.stacktrace = tracing::field::Empty,
327        ),
328        err(Debug),
329    )]
330    pub(crate) async fn update_podcast_subscription(
331        &self,
332        user_id: u64,
333        podcast_id: u64,
334        body: entertainarr_domain::podcast::entity::PodcastSubscriptionUpdate,
335    ) -> anyhow::Result<()> {
336        sqlx::query(UPDATE_SUBSCRIPTION_QUERY)
337            .bind(body.min_duration.map(|value| value as i64))
338            .bind(body.max_duration.map(|value| value as i64))
339            .bind(user_id as i64)
340            .bind(podcast_id as i64)
341            .fetch_optional(&self.0)
342            .await
343            .inspect_err(super::record_error)
344            .map(|_| ())
345            .context("unable to update podcast subscriptions")
346    }
347
348    #[tracing::instrument(
349        skip_all,
350        fields(
351            otel.kind = "client",
352            db.system = "sqlite",
353            db.name = "podcast",
354            db.operation = "DELETE",
355            db.sql.table = "user_podcasts",
356            db.query.text = DELETE_USER_PODCAST_QUERY,
357            db.response.returned_rows = tracing::field::Empty,
358            error.type = tracing::field::Empty,
359            error.message = tracing::field::Empty,
360            error.stacktrace = tracing::field::Empty,
361        ),
362        err(Debug),
363    )]
364    pub(crate) async fn delete_podcast_subscription(
365        &self,
366        user_id: u64,
367        podcast_id: u64,
368    ) -> anyhow::Result<()> {
369        sqlx::query(DELETE_USER_PODCAST_QUERY)
370            .bind(user_id as i64)
371            .bind(podcast_id as i64)
372            .execute(&self.0)
373            .await
374            .inspect_err(super::record_error)
375            .map(|_| ())
376            .context("unable to delete user podcast relation")
377    }
378
379    #[tracing::instrument(
380        skip_all,
381        fields(
382            otel.kind = "client",
383            db.system = "sqlite",
384            db.name = "podcast",
385            db.operation = "SELECT",
386            db.sql.table = "podcasts",
387            db.query.text = tracing::field::Empty,
388            db.response.returned_rows = tracing::field::Empty,
389            error.type = tracing::field::Empty,
390            error.message = tracing::field::Empty,
391            error.stacktrace = tracing::field::Empty,
392        ),
393        err(Debug),
394    )]
395    pub(crate) async fn list_podcast_subscriptions(
396        &self,
397        user_id: u64,
398        podcast_ids: &[u64],
399    ) -> anyhow::Result<Vec<PodcastSubscription>> {
400        let mut qb: sqlx::QueryBuilder<'_, sqlx::Sqlite> = sqlx::QueryBuilder::new(
401            "select user_podcasts.user_id, user_podcasts.podcast_id, count(podcast_episodes.id), user_podcasts.min_duration, user_podcasts.max_duration, user_podcasts.created_at",
402        );
403        qb.push(" from user_podcasts");
404        qb.push(" left outer join podcast_episodes on podcast_episodes.podcast_id = user_podcasts.podcast_id");
405        qb.push(" where user_podcasts.user_id = ")
406            .push_bind(user_id as i64);
407        if !podcast_ids.is_empty() {
408            qb.push(" and ")
409                .push_any("user_podcasts.podcast_id", podcast_ids);
410        }
411        qb.push(" and (user_podcasts.min_duration is null or podcast_episodes.duration is null or podcast_episodes.duration > user_podcasts.min_duration)");
412        qb.push(" and (user_podcasts.max_duration is null or podcast_episodes.duration is null or podcast_episodes.duration < user_podcasts.max_duration)");
413        qb.push(" group by user_podcasts.user_id, user_podcasts.podcast_id");
414
415        tracing::Span::current().record("db.query.text", qb.sql());
416
417        qb.build_query_as()
418            .fetch_all(&self.0)
419            .await
420            .inspect(super::record_all)
421            .inspect_err(super::record_error)
422            .map(Wrapper::list)
423            .context("unable to list podcast subscriptions")
424    }
425
426    #[tracing::instrument(
427        skip_all,
428        fields(
429            otel.kind = "client",
430            db.system = "sqlite",
431            db.name = "podcast",
432            db.operation = "select",
433            db.sql.table = "podcasts",
434            db.query.text = FIND_PODCAST_SUBSCRIPTION_BY_ID_QUERY,
435            db.response.returned_rows = tracing::field::Empty,
436            error.type = tracing::field::Empty,
437            error.message = tracing::field::Empty,
438            error.stacktrace = tracing::field::Empty,
439        ),
440        err(Debug),
441    )]
442    pub(crate) async fn find_podcast_subscription(
443        &self,
444        user_id: u64,
445        podcast_id: u64,
446    ) -> anyhow::Result<Option<entertainarr_domain::podcast::entity::PodcastSubscription>> {
447        sqlx::query_as(FIND_PODCAST_SUBSCRIPTION_BY_ID_QUERY)
448            .bind(user_id as i64)
449            .bind(podcast_id as i64)
450            .fetch_optional(&self.0)
451            .await
452            .inspect(super::record_optional)
453            .inspect_err(super::record_error)
454            .map(Wrapper::maybe_inner)
455            .context("unable to find podcast subscriptions")
456    }
457}
458
459impl entertainarr_domain::podcast::prelude::PodcastSubscriptionRepository for super::Pool {
460    async fn create(&self, user_id: u64, podcast_id: u64) -> anyhow::Result<()> {
461        self.create_podcast_subscription(user_id, podcast_id).await
462    }
463
464    async fn delete(&self, user_id: u64, podcast_id: u64) -> anyhow::Result<()> {
465        self.delete_podcast_subscription(user_id, podcast_id).await
466    }
467
468    async fn list(
469        &self,
470        user_id: u64,
471        podcast_ids: &[u64],
472    ) -> anyhow::Result<Vec<PodcastSubscription>> {
473        self.list_podcast_subscriptions(user_id, podcast_ids).await
474    }
475
476    async fn update(
477        &self,
478        user_id: u64,
479        podcast_id: u64,
480        body: entertainarr_domain::podcast::entity::PodcastSubscriptionUpdate,
481    ) -> anyhow::Result<()> {
482        self.update_podcast_subscription(user_id, podcast_id, body)
483            .await
484    }
485
486    async fn find(
487        &self,
488        user_id: u64,
489        podcast_id: u64,
490    ) -> anyhow::Result<Option<entertainarr_domain::podcast::entity::PodcastSubscription>> {
491        self.find_podcast_subscription(user_id, podcast_id).await
492    }
493}
494
495impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<Podcast> {
496    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
497        use sqlx::Row;
498
499        Ok(Self(Podcast {
500            id: row.try_get(0)?,
501            feed_url: row.try_get(1)?,
502            title: row.try_get(2)?,
503            description: row.try_get(3)?,
504            image_url: row.try_get(4)?,
505            language: row.try_get(5)?,
506            website: row.try_get(6)?,
507            created_at: row.try_get(7)?,
508            updated_at: row.try_get(8)?,
509        }))
510    }
511}
512
513impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<PodcastWithSubscription> {
514    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
515        use sqlx::Row;
516
517        Ok(Self(PodcastWithSubscription {
518            podcast: Podcast {
519                id: row.try_get(0)?,
520                feed_url: row.try_get(1)?,
521                title: row.try_get(2)?,
522                description: row.try_get(3)?,
523                image_url: row.try_get(4)?,
524                language: row.try_get(5)?,
525                website: row.try_get(6)?,
526                created_at: row.try_get(7)?,
527                updated_at: row.try_get(8)?,
528            },
529            queue_size: row.try_get(9)?,
530        }))
531    }
532}
533
534impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<PodcastSubscription> {
535    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
536        use sqlx::Row;
537
538        Ok(Self(PodcastSubscription {
539            user_id: row.try_get(0)?,
540            podcast_id: row.try_get(1)?,
541            queue_size: row.try_get(2)?,
542            min_duration: row.try_get(3)?,
543            max_duration: row.try_get(4)?,
544            created_at: row.try_get(5)?,
545        }))
546    }
547}
548
#[cfg(test)]
mod tests {
    use entertainarr_domain::podcast::entity::{PodcastEpisodeInput, PodcastInput};
    use entertainarr_domain::podcast::prelude::PodcastRepository;

    use crate::Pool;

    #[tokio::test]
    async fn should_find_podcast_by_feed_url_when_existing() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        let podcast = pool
            .upsert_podcast(
                pool.as_ref(),
                &PodcastInput::builder()
                    .feed_url("http://foo.bar/atom.rss")
                    .title("Foo Bar")
                    .build(),
            )
            .await
            .unwrap();

        let res = pool
            .find_by_feed_url("http://foo.bar/atom.rss")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(podcast.id, res.id);
    }

    #[tokio::test]
    async fn should_find_podcast_by_feed_url_when_missing() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;
        let res = pool.find_by_feed_url("missing_url").await.unwrap();
        assert!(res.is_none());
    }

    #[tokio::test]
    async fn should_find_podcast_by_id_when_existing() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        let podcast = pool
            .upsert_podcast(
                pool.as_ref(),
                &PodcastInput::builder()
                    .feed_url("http://foo.bar/atom.rss")
                    .title("Foo Bar")
                    .build(),
            )
            .await
            .unwrap();

        let res = pool.find_by_id(podcast.id).await.unwrap().unwrap();
        assert_eq!(podcast.id, res.id);
    }

    #[tokio::test]
    async fn should_find_podcast_by_id_when_missing() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;
        let res = pool.find_by_id(42).await.unwrap();
        assert!(res.is_none());
    }

    /// Number of rows in `podcasts`.
    async fn count_podcasts(pool: &Pool) -> u32 {
        sqlx::query_scalar("select count(*) from podcasts")
            .fetch_one(&pool.0)
            .await
            .unwrap()
    }

    /// Number of rows in `podcast_episodes`.
    async fn count_podcast_episodes(pool: &Pool) -> u32 {
        sqlx::query_scalar("select count(*) from podcast_episodes")
            .fetch_one(&pool.0)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn should_upsert_podcast_and_episodes() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        let content = PodcastInput::builder()
            .feed_url("http://example.com/atom.rss")
            .title("Example")
            .episodes(vec![
                PodcastEpisodeInput::builder()
                    .guid("aaaaa")
                    .title("First episode")
                    .file_url("http://example.com/first.mp3")
                    .build(),
                PodcastEpisodeInput::builder()
                    .guid("aaaab")
                    .title("Second episode")
                    .file_url("http://example.com/second.mp3")
                    .build(),
            ])
            .build();
        let _res = pool.upsert(&content).await.unwrap();
        assert_eq!(count_podcasts(&pool).await, 1);
        assert_eq!(count_podcast_episodes(&pool).await, 2);
        // a second upsert of the same feed must neither duplicate the podcast nor its episodes
        let _res = pool.upsert(&content).await.unwrap();
        assert_eq!(count_podcasts(&pool).await, 1);
        assert_eq!(count_podcast_episodes(&pool).await, 2);
    }
}