entertainarr_adapter_sqlite/
podcast.rs

1use anyhow::Context;
2use entertainarr_domain::podcast::entity::{
3    ListPodcastParams, Podcast, PodcastEpisodeInput, PodcastInput, PodcastSubscription,
4    PodcastWithSubscription,
5};
6
7use crate::{
8    Wrapper,
9    prelude::{HasAnyOf, QueryBuilderExt},
10};
11
/// Selects one podcast by primary key.
///
/// Columns are listed in the order the `FromRow` implementation for `Wrapper<Podcast>` reads them.
const FIND_PODCAST_BY_ID_QUERY: &str = "select id, feed_url, title, description, image_url, language, website, created_at, updated_at from podcasts where id = ? limit 1";
/// Selects one podcast by its feed URL.
///
/// Uses `=` rather than `like`. With `like`, `_` and `%` (both common in URLs) act as
/// wildcards and ASCII case is ignored, so the lookup could return a different podcast than
/// the one `UPSERT_PODCAST_QUERY` deduplicates on via `on conflict (feed_url)`. A plain `=`
/// compares with the column's collation, the same rule the conflict target's unique constraint uses.
const FIND_PODCAST_BY_FEED_URL_QUERY: &str = "select id, feed_url, title, description, image_url, language, website, created_at, updated_at from podcasts where feed_url = ? limit 1";
/// Inserts a podcast, or refreshes its metadata (and `updated_at`) when the feed URL is
/// already known, returning the stored row.
///
/// Bind order: `feed_url`, `title`, `description`, `image_url`, `language`, `website`.
const UPSERT_PODCAST_QUERY: &str = r#"insert into podcasts (feed_url, title, description, image_url, language, website)
values (?, ?, ?, ?, ?, ?)
on conflict (feed_url) do update set
    title=excluded.title,
    description=excluded.description,
    image_url=excluded.image_url,
    language=excluded.language,
    website=excluded.website,
    updated_at=CURRENT_TIMESTAMP
returning id, feed_url, title, description, image_url, language, website, created_at, updated_at"#;
/// Subscribes a user to a podcast. Subscribing twice is a no-op.
///
/// Bind order: `user_id`, `podcast_id`.
const UPSERT_USER_PODCAST_QUERY: &str = "insert into user_podcasts (user_id, podcast_id) values (?, ?) on conflict (user_id, podcast_id) do nothing";
/// Removes a user's subscription to a podcast.
///
/// Bind order: `user_id`, `podcast_id`.
const DELETE_USER_PODCAST_QUERY: &str =
    "delete from user_podcasts where user_id = ? and podcast_id = ?";
27
28impl super::Pool {
29    #[tracing::instrument(
30        skip_all,
31        fields(
32            otel.kind = "client",
33            db.system = "sqlite",
34            db.name = "podcast",
35            db.operation = "SELECT",
36            db.sql.table = "podcasts",
37            db.query.text = FIND_PODCAST_BY_ID_QUERY,
38            db.response.returned_rows = tracing::field::Empty,
39            error.type = tracing::field::Empty,
40            error.message = tracing::field::Empty,
41            error.stacktrace = tracing::field::Empty,
42        ),
43        err(Debug),
44    )]
45    pub(crate) async fn find_podcast_by_id(
46        &self,
47        podcast_id: u64,
48    ) -> anyhow::Result<Option<Podcast>> {
49        sqlx::query_as(FIND_PODCAST_BY_ID_QUERY)
50            .bind(podcast_id as i64)
51            .fetch_optional(&self.0)
52            .await
53            .inspect(super::record_optional)
54            .inspect_err(super::record_error)
55            .map(Wrapper::maybe_inner)
56            .context("unable to query podcasts by id")
57    }
58    #[tracing::instrument(
59        skip_all,
60        fields(
61            otel.kind = "client",
62            db.system = "sqlite",
63            db.name = "podcast",
64            db.operation = "SELECT",
65            db.sql.table = "podcasts",
66            db.query.text = FIND_PODCAST_BY_FEED_URL_QUERY,
67            db.response.returned_rows = tracing::field::Empty,
68            error.type = tracing::field::Empty,
69            error.message = tracing::field::Empty,
70            error.stacktrace = tracing::field::Empty,
71        ),
72        err(Debug),
73    )]
74    pub(crate) async fn find_podcast_by_feed_url(
75        &self,
76        feed_url: &str,
77    ) -> anyhow::Result<Option<Podcast>> {
78        sqlx::query_as(FIND_PODCAST_BY_FEED_URL_QUERY)
79            .bind(feed_url)
80            .fetch_optional(&self.0)
81            .await
82            .inspect(super::record_optional)
83            .inspect_err(super::record_error)
84            .map(Wrapper::maybe_inner)
85            .context("unable to query podcasts by feed url")
86    }
87
88    #[tracing::instrument(
89        skip_all,
90        fields(
91            otel.kind = "client",
92            db.system = "sqlite",
93            db.name = "podcast",
94            db.operation = "INSERT",
95            db.sql.table = "podcasts",
96            db.query.text = UPSERT_PODCAST_QUERY,
97            db.response.returned_rows = tracing::field::Empty,
98            error.type = tracing::field::Empty,
99            error.message = tracing::field::Empty,
100            error.stacktrace = tracing::field::Empty,
101        ),
102        err(Debug),
103    )]
104    pub(crate) async fn upsert_podcast<'c, E: sqlx::SqliteExecutor<'c>>(
105        &self,
106        executor: E,
107        entity: &PodcastInput,
108    ) -> anyhow::Result<Podcast> {
109        sqlx::query_as(UPSERT_PODCAST_QUERY)
110            .bind(&entity.feed_url)
111            .bind(&entity.title)
112            .bind(entity.description.as_ref())
113            .bind(entity.image_url.as_ref())
114            .bind(entity.language.as_ref())
115            .bind(entity.website.as_ref())
116            .fetch_one(executor)
117            .await
118            .inspect(super::record_one)
119            .inspect_err(super::record_error)
120            .map(Wrapper::inner)
121            .context("unable to upsert podcast")
122    }
123
124    #[tracing::instrument(
125        skip_all,
126        fields(
127            otel.kind = "client",
128            db.system = "sqlite",
129            db.name = "podcast",
130            db.operation = "INSERT",
131            db.sql.table = "podcast_episodes",
132            db.query.text = tracing::field::Empty,
133            db.response.returned_rows = tracing::field::Empty,
134            error.type = tracing::field::Empty,
135            error.message = tracing::field::Empty,
136            error.stacktrace = tracing::field::Empty,
137        ),
138        err(Debug),
139    )]
140    pub(crate) async fn upsert_podcast_episodes<'c, E: sqlx::SqliteExecutor<'c>>(
141        &self,
142        executor: E,
143        podcast_id: u64,
144        episodes: &[PodcastEpisodeInput],
145    ) -> anyhow::Result<()> {
146        let mut qb: sqlx::QueryBuilder<'_, sqlx::Sqlite> = sqlx::QueryBuilder::new(
147            "insert into podcast_episodes (podcast_id, guid, published_at, title, description, link, duration, file_url, file_size, file_type)",
148        );
149        qb.push_values(episodes.iter(), |mut b, item| {
150            b.push_bind(podcast_id as i64)
151                .push_bind(&item.guid)
152                .push_bind(item.published_at)
153                .push_bind(&item.title)
154                .push_bind(&item.description)
155                .push_bind(&item.link)
156                .push_bind(item.duration.as_ref().map(|value| value.as_secs() as i64))
157                .push_bind(&item.file_url)
158                .push_bind(item.file_size.map(|value| value as i64))
159                .push_bind(&item.file_type);
160        });
161        qb.push(" on conflict (podcast_id, guid) do nothing");
162
163        tracing::Span::current().record("db.query.text", qb.sql());
164
165        qb.build()
166            .execute(executor)
167            .await
168            .inspect_err(super::record_error)
169            .map(|_| ())
170            .context("unable to upsert podcast episodes")
171    }
172    #[tracing::instrument(
173        skip_all,
174        fields(
175            otel.kind = "client",
176            db.system = "sqlite",
177            db.name = "podcast",
178            db.operation = "SELECT",
179            db.sql.table = "podcasts",
180            db.query.text = tracing::field::Empty,
181            db.response.returned_rows = tracing::field::Empty,
182            error.type = tracing::field::Empty,
183            error.message = tracing::field::Empty,
184            error.stacktrace = tracing::field::Empty,
185        ),
186        err(Debug),
187    )]
188    pub(crate) async fn list_podcasts<'a>(
189        &self,
190        params: ListPodcastParams<'a>,
191    ) -> anyhow::Result<Vec<Podcast>> {
192        let mut qb = sqlx::QueryBuilder::new(
193            "select podcasts.id, podcasts.feed_url, podcasts.title, podcasts.description, podcasts.image_url, podcasts.language, podcasts.website, podcasts.created_at, podcasts.updated_at from podcasts",
194        );
195        if params.filter.subscribed.is_some() {
196            qb.push(" left outer join user_podcasts on user_podcasts.podcast_id = podcasts.id and user_podcasts.user_id = ").push_bind(params.user_id as i64);
197        }
198        qb.push(" where true");
199        if !params.filter.podcast_ids.is_empty() {
200            qb.push(" and ");
201            qb.push_any("podcasts.id", params.filter.podcast_ids);
202        }
203        match params.filter.subscribed {
204            Some(true) => {
205                qb.push(" and user_podcasts.created_at is not null");
206            }
207            Some(false) => {
208                qb.push(" and user_podcasts.created_at is null");
209            }
210            None => {}
211        }
212        qb.push_pagination(params.page);
213
214        tracing::Span::current().record("db.query.text", qb.sql());
215
216        qb.build_query_as()
217            .fetch_all(&self.0)
218            .await
219            .inspect(super::record_all)
220            .inspect_err(super::record_error)
221            .map(Wrapper::list)
222            .context("unable to query podcasts by feed url")
223    }
224}
225
226impl entertainarr_domain::podcast::prelude::PodcastRepository for super::Pool {
227    async fn find_by_id(&self, podcast_id: u64) -> anyhow::Result<Option<Podcast>> {
228        self.find_podcast_by_id(podcast_id).await
229    }
230
231    async fn find_by_feed_url(&self, feed_url: &str) -> anyhow::Result<Option<Podcast>> {
232        self.find_podcast_by_feed_url(feed_url).await
233    }
234
235    async fn list<'a>(&self, params: ListPodcastParams<'a>) -> anyhow::Result<Vec<Podcast>> {
236        self.list_podcasts(params).await
237    }
238
239    async fn upsert(&self, entity: &PodcastInput) -> anyhow::Result<Podcast> {
240        let mut tx = self
241            .0
242            .begin()
243            .await
244            .context("unable to begin transaction")?;
245
246        let podcast = self.upsert_podcast(&mut *tx, entity).await?;
247        self.upsert_podcast_episodes(&mut *tx, podcast.id, &entity.episodes)
248            .await?;
249
250        tx.commit().await.context("unable to commit transaction")?;
251        Ok(podcast)
252    }
253}
254
255const FIND_PODCAST_SUBSCRIPTION_BY_ID_QUERY: &str = r#"select
256    user_podcasts.user_id,
257    user_podcasts.podcast_id,
258    count(podcast_episodes.id),
259    user_podcasts.min_duration,
260    user_podcasts.max_duration,
261    user_podcasts.created_at
262from user_podcasts
263left outer join podcast_episodes on podcast_episodes.podcast_id = user_podcasts.podcast_id
264where user_podcasts.user_id = ?
265    and user_podcasts.podcast_id = ?
266    and (
267        user_podcasts.min_duration is null
268        or podcast_episodes.duration is null
269        or podcast_episodes.duration > user_podcasts.min_duration
270    )
271    and (
272        user_podcasts.max_duration is null
273        or podcast_episodes.duration is null
274        or podcast_episodes.duration < user_podcasts.max_duration
275    )
276group by user_podcasts.user_id, user_podcasts.podcast_id"#;
277
278const UPDATE_SUBSCRIPTION_QUERY: &str = r#"update user_podcasts set min_duration = ?, max_duration = ? where user_id = ? and podcast_id = ?"#;
279
280impl super::Pool {
281    #[tracing::instrument(
282        skip_all,
283        fields(
284            otel.kind = "client",
285            db.system = "sqlite",
286            db.name = "podcast",
287            db.operation = "UPSERT",
288            db.sql.table = "user_podcasts",
289            db.query.text = UPSERT_USER_PODCAST_QUERY,
290            db.response.returned_rows = tracing::field::Empty,
291            error.type = tracing::field::Empty,
292            error.message = tracing::field::Empty,
293            error.stacktrace = tracing::field::Empty,
294        ),
295        err(Debug),
296    )]
297    pub(crate) async fn create_podcast_subscription(
298        &self,
299        user_id: u64,
300        podcast_id: u64,
301    ) -> anyhow::Result<()> {
302        sqlx::query(UPSERT_USER_PODCAST_QUERY)
303            .bind(user_id as i64)
304            .bind(podcast_id as i64)
305            .execute(&self.0)
306            .await
307            .inspect_err(super::record_error)
308            .map(|_| ())
309            .context("unable to upsert user podcast relation")
310    }
311
312    #[tracing::instrument(
313        skip_all,
314        fields(
315            otel.kind = "client",
316            db.system = "sqlite",
317            db.name = "podcast",
318            db.operation = "update",
319            db.sql.table = "user_podcasts",
320            db.query.text = UPDATE_SUBSCRIPTION_QUERY,
321            db.response.returned_rows = tracing::field::Empty,
322            error.type = tracing::field::Empty,
323            error.message = tracing::field::Empty,
324            error.stacktrace = tracing::field::Empty,
325        ),
326        err(Debug),
327    )]
328    pub(crate) async fn update_podcast_subscription(
329        &self,
330        user_id: u64,
331        podcast_id: u64,
332        body: entertainarr_domain::podcast::entity::PodcastSubscriptionUpdate,
333    ) -> anyhow::Result<()> {
334        sqlx::query(UPDATE_SUBSCRIPTION_QUERY)
335            .bind(body.min_duration.map(|value| value as i64))
336            .bind(body.max_duration.map(|value| value as i64))
337            .bind(user_id as i64)
338            .bind(podcast_id as i64)
339            .fetch_optional(&self.0)
340            .await
341            .inspect_err(super::record_error)
342            .map(|_| ())
343            .context("unable to update podcast subscriptions")
344    }
345
346    #[tracing::instrument(
347        skip_all,
348        fields(
349            otel.kind = "client",
350            db.system = "sqlite",
351            db.name = "podcast",
352            db.operation = "DELETE",
353            db.sql.table = "user_podcasts",
354            db.query.text = DELETE_USER_PODCAST_QUERY,
355            db.response.returned_rows = tracing::field::Empty,
356            error.type = tracing::field::Empty,
357            error.message = tracing::field::Empty,
358            error.stacktrace = tracing::field::Empty,
359        ),
360        err(Debug),
361    )]
362    pub(crate) async fn delete_podcast_subscription(
363        &self,
364        user_id: u64,
365        podcast_id: u64,
366    ) -> anyhow::Result<()> {
367        sqlx::query(DELETE_USER_PODCAST_QUERY)
368            .bind(user_id as i64)
369            .bind(podcast_id as i64)
370            .execute(&self.0)
371            .await
372            .inspect_err(super::record_error)
373            .map(|_| ())
374            .context("unable to delete user podcast relation")
375    }
376
377    #[tracing::instrument(
378        skip_all,
379        fields(
380            otel.kind = "client",
381            db.system = "sqlite",
382            db.name = "podcast",
383            db.operation = "SELECT",
384            db.sql.table = "podcasts",
385            db.query.text = tracing::field::Empty,
386            db.response.returned_rows = tracing::field::Empty,
387            error.type = tracing::field::Empty,
388            error.message = tracing::field::Empty,
389            error.stacktrace = tracing::field::Empty,
390        ),
391        err(Debug),
392    )]
393    pub(crate) async fn list_podcast_subscriptions(
394        &self,
395        user_id: u64,
396        podcast_ids: &[u64],
397    ) -> anyhow::Result<Vec<PodcastSubscription>> {
398        let mut qb: sqlx::QueryBuilder<'_, sqlx::Sqlite> = sqlx::QueryBuilder::new(
399            "select user_podcasts.user_id, user_podcasts.podcast_id, count(podcast_episodes.id), user_podcasts.min_duration, user_podcasts.max_duration, user_podcasts.created_at",
400        );
401        qb.push(" from user_podcasts");
402        qb.push(" left outer join podcast_episodes on podcast_episodes.podcast_id = user_podcasts.podcast_id");
403        qb.push(" where user_podcasts.user_id = ")
404            .push_bind(user_id as i64);
405        if !podcast_ids.is_empty() {
406            qb.push(" and ")
407                .push_any("user_podcasts.podcast_id", podcast_ids);
408        }
409        qb.push(" and (user_podcasts.min_duration is null or podcast_episodes.duration is null or podcast_episodes.duration > user_podcasts.min_duration)");
410        qb.push(" and (user_podcasts.max_duration is null or podcast_episodes.duration is null or podcast_episodes.duration < user_podcasts.max_duration)");
411        qb.push(" group by user_podcasts.user_id, user_podcasts.podcast_id");
412
413        tracing::Span::current().record("db.query.text", qb.sql());
414
415        qb.build_query_as()
416            .fetch_all(&self.0)
417            .await
418            .inspect(super::record_all)
419            .inspect_err(super::record_error)
420            .map(Wrapper::list)
421            .context("unable to list podcast subscriptions")
422    }
423
424    #[tracing::instrument(
425        skip_all,
426        fields(
427            otel.kind = "client",
428            db.system = "sqlite",
429            db.name = "podcast",
430            db.operation = "select",
431            db.sql.table = "podcasts",
432            db.query.text = FIND_PODCAST_SUBSCRIPTION_BY_ID_QUERY,
433            db.response.returned_rows = tracing::field::Empty,
434            error.type = tracing::field::Empty,
435            error.message = tracing::field::Empty,
436            error.stacktrace = tracing::field::Empty,
437        ),
438        err(Debug),
439    )]
440    pub(crate) async fn find_podcast_subscription(
441        &self,
442        user_id: u64,
443        podcast_id: u64,
444    ) -> anyhow::Result<Option<entertainarr_domain::podcast::entity::PodcastSubscription>> {
445        sqlx::query_as(FIND_PODCAST_SUBSCRIPTION_BY_ID_QUERY)
446            .bind(user_id as i64)
447            .bind(podcast_id as i64)
448            .fetch_optional(&self.0)
449            .await
450            .inspect(super::record_optional)
451            .inspect_err(super::record_error)
452            .map(Wrapper::maybe_inner)
453            .context("unable to find podcast subscriptions")
454    }
455}
456
457impl entertainarr_domain::podcast::prelude::PodcastSubscriptionRepository for super::Pool {
458    async fn create(&self, user_id: u64, podcast_id: u64) -> anyhow::Result<()> {
459        self.create_podcast_subscription(user_id, podcast_id).await
460    }
461
462    async fn delete(&self, user_id: u64, podcast_id: u64) -> anyhow::Result<()> {
463        self.delete_podcast_subscription(user_id, podcast_id).await
464    }
465
466    async fn list(
467        &self,
468        user_id: u64,
469        podcast_ids: &[u64],
470    ) -> anyhow::Result<Vec<PodcastSubscription>> {
471        self.list_podcast_subscriptions(user_id, podcast_ids).await
472    }
473
474    async fn update(
475        &self,
476        user_id: u64,
477        podcast_id: u64,
478        body: entertainarr_domain::podcast::entity::PodcastSubscriptionUpdate,
479    ) -> anyhow::Result<()> {
480        self.update_podcast_subscription(user_id, podcast_id, body)
481            .await
482    }
483
484    async fn find(
485        &self,
486        user_id: u64,
487        podcast_id: u64,
488    ) -> anyhow::Result<Option<entertainarr_domain::podcast::entity::PodcastSubscription>> {
489        self.find_podcast_subscription(user_id, podcast_id).await
490    }
491}
492
493impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<Podcast> {
494    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
495        use sqlx::Row;
496
497        Ok(Self(Podcast {
498            id: row.try_get(0)?,
499            feed_url: row.try_get(1)?,
500            title: row.try_get(2)?,
501            description: row.try_get(3)?,
502            image_url: row.try_get(4)?,
503            language: row.try_get(5)?,
504            website: row.try_get(6)?,
505            created_at: row.try_get(7)?,
506            updated_at: row.try_get(8)?,
507        }))
508    }
509}
510
511impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<PodcastWithSubscription> {
512    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
513        use sqlx::Row;
514
515        Ok(Self(PodcastWithSubscription {
516            podcast: Podcast {
517                id: row.try_get(0)?,
518                feed_url: row.try_get(1)?,
519                title: row.try_get(2)?,
520                description: row.try_get(3)?,
521                image_url: row.try_get(4)?,
522                language: row.try_get(5)?,
523                website: row.try_get(6)?,
524                created_at: row.try_get(7)?,
525                updated_at: row.try_get(8)?,
526            },
527            queue_size: row.try_get(9)?,
528        }))
529    }
530}
531
532impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<PodcastSubscription> {
533    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
534        use sqlx::Row;
535
536        Ok(Self(PodcastSubscription {
537            user_id: row.try_get(0)?,
538            podcast_id: row.try_get(1)?,
539            queue_size: row.try_get(2)?,
540            min_duration: row.try_get(3)?,
541            max_duration: row.try_get(4)?,
542            created_at: row.try_get(5)?,
543        }))
544    }
545}
546
#[cfg(test)]
mod tests {
    use entertainarr_domain::podcast::entity::{PodcastEpisodeInput, PodcastInput};
    use entertainarr_domain::podcast::prelude::PodcastRepository;

    use crate::Pool;

    #[tokio::test]
    async fn should_find_podcast_by_feed_url_when_existing() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        let podcast = pool
            .upsert_podcast(
                pool.as_ref(),
                &PodcastInput::builder()
                    .feed_url("http://foo.bar/atom.rss")
                    .title("Foo Bar")
                    .build(),
            )
            .await
            .unwrap();

        let res = pool
            .find_by_feed_url("http://foo.bar/atom.rss")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(podcast.id, res.id);
    }

    #[tokio::test]
    async fn should_find_podcast_by_feed_url_when_missing() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;
        let res = pool.find_by_feed_url("missing_url").await.unwrap();
        assert!(res.is_none());
    }

    /// Counts every row of the `podcasts` table.
    async fn count_podcasts(pool: &Pool) -> u32 {
        sqlx::query_scalar("select count(*) from podcasts")
            .fetch_one(&pool.0)
            .await
            .unwrap()
    }

    /// Counts every row of the `podcast_episodes` table.
    async fn count_podcast_episodes(pool: &Pool) -> u32 {
        sqlx::query_scalar("select count(*) from podcast_episodes")
            .fetch_one(&pool.0)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn should_upsert_podcast_and_episodes() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        let content = PodcastInput::builder()
            .feed_url("http://example.com/atom.rss")
            .title("Example")
            .episodes(vec![
                PodcastEpisodeInput::builder()
                    .guid("aaaaa")
                    .title("First episode")
                    .file_url("http://example.com/first.mp3")
                    .build(),
                PodcastEpisodeInput::builder()
                    .guid("aaaab")
                    .title("Second episode")
                    .file_url("http://example.com/second.mp3")
                    .build(),
            ])
            .build();
        let _res = pool.upsert(&content).await.unwrap();
        assert_eq!(count_podcasts(&pool).await, 1);
        assert_eq!(count_podcast_episodes(&pool).await, 2);
        // A second upsert of the same feed must not duplicate the podcast or its episodes.
        let _res = pool.upsert(&content).await.unwrap();
        assert_eq!(count_podcasts(&pool).await, 1);
        assert_eq!(count_podcast_episodes(&pool).await, 2);
    }
}