// entertainarr_adapter_sqlite/podcast_episode.rs

1use anyhow::Context;
2use entertainarr_domain::podcast::entity::{ListPodcastEpisodeParams, PodcastEpisodeField};
3use entertainarr_domain::podcast::entity::{PodcastEpisode, PodcastEpisodeProgress};
4use entertainarr_domain::prelude::SortOrder;
5
6use crate::Wrapper;
7use crate::prelude::HasAnyOf;
8
/// Fetches a single `podcast_episodes` row by primary key.
///
/// The column order here is the one the positional `FromRow` decoder for
/// `Wrapper<PodcastEpisode>` expects.
const FIND_BY_ID_QUERY: &str = concat!(
    "select id, podcast_id, guid, published_at, title, description, link, duration, file_url, file_size, file_type, created_at, updated_at\n",
    "from podcast_episodes\n",
    "where id = ?\n",
    "limit 1",
);
13
14impl super::Pool {
15    #[tracing::instrument(
16        skip_all,
17        fields(
18            otel.kind = "client",
19            db.system = "sqlite",
20            db.name = "podcast",
21            db.operation = "SELECT",
22            db.sql.table = "podcast_episodes",
23            db.query.text = tracing::field::Empty,
24            db.response.returned_rows = tracing::field::Empty,
25            error.type = tracing::field::Empty,
26            error.message = tracing::field::Empty,
27            error.stacktrace = tracing::field::Empty,
28        ),
29        err(Debug),
30    )]
31    pub(crate) async fn list_podcast_episodes<'a>(
32        &self,
33        params: ListPodcastEpisodeParams<'a>,
34    ) -> anyhow::Result<Vec<PodcastEpisode>> {
35        let mut qb: sqlx::QueryBuilder<'_, sqlx::Sqlite> = sqlx::QueryBuilder::new(
36            r#"select
37    podcast_episodes.id,
38    podcast_episodes.podcast_id,
39    podcast_episodes.guid,
40    podcast_episodes.published_at,
41    podcast_episodes.title,
42    podcast_episodes.description,
43    podcast_episodes.link,
44    podcast_episodes.duration,
45    podcast_episodes.file_url,
46    podcast_episodes.file_size,
47    podcast_episodes.file_type,
48    podcast_episodes.created_at,
49    podcast_episodes.updated_at
50from podcast_episodes"#,
51        );
52
53        if let Some(subscribed) = params.filter.subscribed {
54            if subscribed {
55                qb.push(" join user_podcasts on user_podcasts.podcast_id = podcast_episodes.podcast_id and user_podcasts.user_id = ").push_bind(params.user_id as i64);
56            } else {
57                // TODO handle filtered those where the user is not subscribed
58            }
59        } else {
60            // nothing to do here
61        }
62
63        if params.filter.filtered.is_some() {
64            qb.push(" left outer join user_podcasts filters on filters.podcast_id = podcast_episodes.podcast_id and filters.user_id = ").push_bind(params.user_id as i64);
65        }
66
67        if params.filter.watched.is_some() {
68            qb.push(" left outer join user_podcast_episodes on user_podcast_episodes.podcast_episode_id = podcast_episodes.id ");
69            qb.push(" and user_podcast_episodes.user_id = ")
70                .push_bind(params.user_id as i64);
71        }
72
73        // WHERE
74
75        qb.push(" where true");
76
77        if let Some(watched) = params.filter.watched {
78            if watched {
79                qb.push(" and user_podcast_episodes.completed");
80            } else {
81                qb.push(" and (user_podcast_episodes.completed is null or not user_podcast_episodes.completed)");
82            }
83        }
84
85        if let Some(filtered) = params.filter.filtered {
86            if filtered {
87                // maybe keep only the filtered
88            } else {
89                qb.push(" and (filters.user_id is null or podcast_episodes.duration is null or filters.min_duration is null or filters.min_duration < podcast_episodes.duration)");
90                qb.push(" and (filters.user_id is null or podcast_episodes.duration is null or filters.max_duration is null or filters.max_duration > podcast_episodes.duration)");
91            }
92        }
93
94        if !params.filter.podcast_ids.is_empty() {
95            qb.push(" and ")
96                .push_any("podcast_episodes.podcast_id", params.filter.podcast_ids);
97        }
98
99        match params.sort.field {
100            PodcastEpisodeField::PublishedAt => qb.push(" order by podcast_episodes.published_at"),
101        };
102        match params.sort.order {
103            SortOrder::Asc => qb.push(" asc"),
104            SortOrder::Desc => qb.push(" desc"),
105        };
106
107        qb.push(" limit ")
108            .push_bind(params.page.limit)
109            .push(" offset ")
110            .push_bind(params.page.offset);
111
112        let span = tracing::Span::current();
113        span.record("db.query.text", qb.sql());
114
115        qb.build_query_as()
116            .fetch_all(&self.0)
117            .await
118            .inspect(super::record_all)
119            .inspect_err(super::record_error)
120            .map(Wrapper::list)
121            .context("unable to query podcast episodes")
122    }
123
124    #[tracing::instrument(
125        skip_all,
126        fields(
127            otel.kind = "client",
128            db.system = "sqlite",
129            db.name = "podcast",
130            db.operation = "SELECT",
131            db.sql.table = "podcast_episodes",
132            db.query.text = FIND_BY_ID_QUERY,
133            db.response.returned_rows = tracing::field::Empty,
134            error.type = tracing::field::Empty,
135            error.message = tracing::field::Empty,
136            error.stacktrace = tracing::field::Empty,
137        ),
138        err(Debug),
139    )]
140    pub(crate) async fn find_podcast_episode_by_id(
141        &self,
142        episode_id: u64,
143    ) -> anyhow::Result<Option<PodcastEpisode>> {
144        sqlx::query_as(FIND_BY_ID_QUERY)
145            .bind(episode_id as i64)
146            .fetch_optional(self.as_ref())
147            .await
148            .inspect(super::record_optional)
149            .inspect_err(super::record_error)
150            .map(super::Wrapper::maybe_inner)
151            .context("unable to query podcast episodes")
152    }
153
154    #[tracing::instrument(
155        skip_all,
156        fields(
157            otel.kind = "client",
158            db.system = "sqlite",
159            db.name = "podcast",
160            db.operation = "DELETE",
161            db.sql.table = "user_podcast_episodes",
162            db.query.text = tracing::field::Empty,
163            db.response.returned_rows = tracing::field::Empty,
164            error.type = tracing::field::Empty,
165            error.message = tracing::field::Empty,
166            error.stacktrace = tracing::field::Empty,
167        ),
168        err(Debug),
169    )]
170    pub(crate) async fn delete_podcast_episode_progressions(
171        &self,
172        user_id: u64,
173        episode_ids: &[u64],
174    ) -> anyhow::Result<()> {
175        if episode_ids.is_empty() {
176            return Ok(());
177        }
178
179        let mut qb = sqlx::QueryBuilder::new("delete from user_podcast_episodes");
180        qb.push(" where user_id = ").push_bind(user_id as i64);
181        qb.push(" and ( ");
182        for (index, id) in episode_ids.iter().enumerate() {
183            if index > 0 {
184                qb.push(" or");
185            }
186            qb.push(" podcast_episode_id = ").push_bind(*id as i64);
187        }
188        qb.push(")");
189
190        let span = tracing::Span::current();
191        span.record("db.query.text", qb.sql());
192
193        qb.build()
194            .execute(self.as_ref())
195            .await
196            .map(|_| ())
197            .inspect_err(super::record_error)
198            .context("unable to delete podcast episode progresses")
199    }
200
201    #[tracing::instrument(
202        skip_all,
203        fields(
204            otel.kind = "client",
205            db.system = "sqlite",
206            db.name = "podcast",
207            db.operation = "SELECT",
208            db.sql.table = "user_podcast_episodes",
209            db.query.text = tracing::field::Empty,
210            db.response.returned_rows = tracing::field::Empty,
211            error.type = tracing::field::Empty,
212            error.message = tracing::field::Empty,
213            error.stacktrace = tracing::field::Empty,
214        ),
215        err(Debug),
216    )]
217    pub(crate) async fn list_podcast_episode_progressions(
218        &self,
219        user_id: u64,
220        episode_ids: &[u64],
221    ) -> anyhow::Result<Vec<entertainarr_domain::podcast::entity::PodcastEpisodeProgress>> {
222        if episode_ids.is_empty() {
223            return Ok(Default::default());
224        }
225
226        let mut qb = sqlx::QueryBuilder::new(
227            "select user_id, podcast_episode_id, progress, completed, created_at, updated_at from user_podcast_episodes",
228        );
229        qb.push(" where user_id = ").push_bind(user_id as i64);
230        qb.push(" and ( ");
231        for (index, id) in episode_ids.iter().enumerate() {
232            if index > 0 {
233                qb.push(" or");
234            }
235            qb.push(" podcast_episode_id = ").push_bind(*id as i64);
236        }
237        qb.push(")");
238
239        let span = tracing::Span::current();
240        span.record("db.query.text", qb.sql());
241
242        qb.build_query_as::<'_, super::Wrapper<PodcastEpisodeProgress>>()
243            .fetch_all(self.as_ref())
244            .await
245            .inspect(super::record_all)
246            .inspect_err(super::record_error)
247            .map(super::Wrapper::list)
248            .context("unable to query podcast episode progresses")
249    }
250
251    #[tracing::instrument(
252        skip_all,
253        fields(
254            otel.kind = "client",
255            db.system = "sqlite",
256            db.name = "podcast",
257            db.operation = "UPSERT",
258            db.sql.table = "user_podcast_episodes",
259            db.query.text = tracing::field::Empty,
260            db.response.returned_rows = tracing::field::Empty,
261            error.type = tracing::field::Empty,
262            error.message = tracing::field::Empty,
263            error.stacktrace = tracing::field::Empty,
264        ),
265        err(Debug),
266    )]
267    pub(crate) async fn upsert_podcast_episode_progressions(
268        &self,
269        inputs: &[entertainarr_domain::podcast::entity::PodcastEpisodeProgressInput],
270    ) -> anyhow::Result<Vec<PodcastEpisodeProgress>> {
271        if inputs.is_empty() {
272            return Ok(Default::default());
273        }
274
275        let mut qb = sqlx::QueryBuilder::new(
276            "insert into user_podcast_episodes (user_id, podcast_episode_id, progress, completed) ",
277        );
278        qb.push_values(inputs, |mut q, item| {
279            q.push_bind(item.user_id as i64)
280                .push_bind(item.podcast_episode_id as i64)
281                .push_bind(item.progress as i64)
282                .push_bind(item.completed);
283        });
284        qb.push(" on conflict (user_id, podcast_episode_id) do update set progress = excluded.progress, completed = excluded.completed, updated_at = CURRENT_TIMESTAMP");
285        qb.push(
286            " returning user_id, podcast_episode_id, progress, completed, created_at, updated_at",
287        );
288
289        let span = tracing::Span::current();
290        span.record("db.query.text", qb.sql());
291
292        qb.build_query_as()
293            .fetch_all(self.as_ref())
294            .await
295            .inspect(super::record_all)
296            .inspect_err(super::record_error)
297            .map(super::Wrapper::list)
298            .context("unable to upsert podcast episode progresses")
299    }
300}
301
302impl entertainarr_domain::podcast::prelude::PodcastEpisodeRepository for super::Pool {
303    async fn find_by_id(&self, episode_id: u64) -> anyhow::Result<Option<PodcastEpisode>> {
304        self.find_podcast_episode_by_id(episode_id).await
305    }
306
307    async fn list<'a>(
308        &self,
309        params: ListPodcastEpisodeParams<'a>,
310    ) -> anyhow::Result<Vec<PodcastEpisode>> {
311        self.list_podcast_episodes(params).await
312    }
313
314    async fn delete_progressions(&self, user_id: u64, episode_ids: &[u64]) -> anyhow::Result<()> {
315        self.delete_podcast_episode_progressions(user_id, episode_ids)
316            .await
317    }
318
319    async fn list_progressions(
320        &self,
321        user_id: u64,
322        episode_ids: &[u64],
323    ) -> anyhow::Result<Vec<entertainarr_domain::podcast::entity::PodcastEpisodeProgress>> {
324        self.list_podcast_episode_progressions(user_id, episode_ids)
325            .await
326    }
327
328    async fn upsert_progressions(
329        &self,
330        inputs: &[entertainarr_domain::podcast::entity::PodcastEpisodeProgressInput],
331    ) -> anyhow::Result<Vec<PodcastEpisodeProgress>> {
332        self.upsert_podcast_episode_progressions(inputs).await
333    }
334}
335
336impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<PodcastEpisode> {
337    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
338        use sqlx::Row;
339
340        Ok(Self(PodcastEpisode {
341            id: row.try_get(0)?,
342            podcast_id: row.try_get(1)?,
343            guid: row.try_get(2)?,
344            published_at: row.try_get(3)?,
345            title: row.try_get(4)?,
346            description: row.try_get(5)?,
347            link: row.try_get(6)?,
348            duration: row
349                .try_get(7)
350                .map(|value: Option<u64>| value.map(std::time::Duration::from_secs))?,
351            file_url: row.try_get(8)?,
352            file_size: row.try_get(9)?,
353            file_type: row.try_get(10)?,
354            created_at: row.try_get(11)?,
355            updated_at: row.try_get(12)?,
356        }))
357    }
358}
359
360impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<PodcastEpisodeProgress> {
361    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
362        use sqlx::Row;
363
364        Ok(Self(PodcastEpisodeProgress {
365            user_id: row.try_get(0)?,
366            podcast_episode_id: row.try_get(1)?,
367            progress: row.try_get(2)?,
368            completed: row.try_get(3)?,
369            created_at: row.try_get(4)?,
370            updated_at: row.try_get(5)?,
371        }))
372    }
373}
374
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::time::Duration;

    use entertainarr_domain::podcast::entity::{
        ListPodcastEpisodeFilter, ListPodcastEpisodeParams, PodcastEpisodeField,
    };
    use entertainarr_domain::podcast::entity::{
        PodcastEpisodeInput, PodcastEpisodeProgressInput, PodcastInput, PodcastSubscriptionUpdate,
    };
    use entertainarr_domain::podcast::prelude::PodcastEpisodeRepository;
    use entertainarr_domain::prelude::{Page, Sort, SortOrder};

    use crate::Pool;

    /// Seeds 2 users, 3 podcasts, 6 episodes, 4 subscriptions and 4 progress rows.
    ///
    /// User 1 is subscribed to podcasts 1 (min_duration 80) and 2, which hold episodes 1-5.
    /// User 1 has completed episode 1 and partially played episode 4.
    async fn seed(pool: &Pool) {
        let _: Vec<u64> = sqlx::query_scalar("insert into users (id, email, password) values (1, 'user1@example.com', 'password'), (2, 'user2@example.com', 'password') returning id").fetch_all(pool.as_ref()).await.unwrap();
        let _: Vec<u64> = sqlx::query_scalar("insert into podcasts (id, feed_url, title) values (1, 'first', 'first'), (2, 'second', 'second'), (3, 'third', 'third') returning id").fetch_all(pool.as_ref()).await.unwrap();
        let _: Vec<u64> = sqlx::query_scalar("insert into podcast_episodes (id, podcast_id, title, file_url, duration) values (1, 1, 'title 1', 'url 1', 100), (2, 1, 'title 2', 'url 2', 30), (3, 1, 'title 3', 'url 3', 90), (4, 2, 'title 4', 'url 4', 120), (5, 2, 'title 5', 'url 5', 100), (6, 3, 'title 6', 'url 6', 90) returning id").fetch_all(pool.as_ref()).await.unwrap();
        let _: Vec<(u64, u64)> = sqlx::query_as(
            "insert into user_podcasts (user_id, podcast_id, min_duration) values (1, 1, 80), (1, 2, null), (2, 2, null), (2, 3, 80) returning user_id, podcast_id",
        )
        .fetch_all(pool.as_ref())
        .await
        .unwrap();
        let _: Vec<(u64, u64)> = sqlx::query_as("insert into user_podcast_episodes (user_id, podcast_episode_id, progress, completed) values (1, 1, 123, true), (1, 4, 10, false), (2, 4, 153, true), (2, 5, 10, false) returning user_id, podcast_episode_id").fetch_all(pool.as_ref()).await.unwrap();
    }

    #[tokio::test]
    async fn should_list_all_subscribed_episodes() {
        let _ = tracing_subscriber::fmt::try_init();

        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        seed(&pool).await;
        let list = pool
            .list(ListPodcastEpisodeParams {
                user_id: 1,
                filter: ListPodcastEpisodeFilter {
                    podcast_ids: &[],
                    filtered: None,
                    subscribed: Some(true),
                    watched: None,
                },
                sort: Sort {
                    field: PodcastEpisodeField::PublishedAt,
                    order: SortOrder::Asc,
                },
                page: Page {
                    limit: 10,
                    offset: 0,
                },
            })
            .await
            .unwrap();
        assert_eq!(list.len(), 5);
    }

    #[tokio::test]
    async fn should_paginate_subscribed_episodes_without_overlap() {
        let _ = tracing_subscriber::fmt::try_init();

        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        seed(&pool).await;

        // The seed does not set `published_at`, so rows may tie on the sort key.
        // Walking small pages must still return each subscribed episode exactly once.
        let mut seen = HashSet::new();
        let mut offset = 0;
        loop {
            let batch = pool
                .list(ListPodcastEpisodeParams {
                    user_id: 1,
                    filter: ListPodcastEpisodeFilter {
                        podcast_ids: &[],
                        filtered: None,
                        subscribed: Some(true),
                        watched: None,
                    },
                    sort: Sort {
                        field: PodcastEpisodeField::PublishedAt,
                        order: SortOrder::Desc,
                    },
                    page: Page { limit: 2, offset },
                })
                .await
                .unwrap();
            if batch.is_empty() {
                break;
            }
            for episode in batch {
                assert!(seen.insert(episode.id), "episode returned on two pages");
            }
            offset += 2;
        }
        assert_eq!(seen.len(), 5);
    }

    #[tokio::test]
    async fn should_list_all_watched_episodes() {
        let _ = tracing_subscriber::fmt::try_init();

        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        seed(&pool).await;
        let list = pool
            .list(ListPodcastEpisodeParams {
                user_id: 1,
                filter: ListPodcastEpisodeFilter {
                    podcast_ids: &[],
                    filtered: None,
                    subscribed: Some(true),
                    watched: Some(true),
                },
                sort: Sort {
                    field: PodcastEpisodeField::PublishedAt,
                    order: SortOrder::Asc,
                },
                page: Page {
                    limit: 10,
                    offset: 0,
                },
            })
            .await
            .unwrap();
        assert_eq!(list.len(), 1);
    }

    #[tokio::test]
    async fn should_list_all_unwatched_episodes() {
        let _ = tracing_subscriber::fmt::try_init();

        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        seed(&pool).await;
        let list = pool
            .list(ListPodcastEpisodeParams {
                user_id: 1,
                filter: ListPodcastEpisodeFilter {
                    podcast_ids: &[],
                    filtered: None,
                    subscribed: Some(true),
                    watched: Some(false),
                },
                sort: Sort {
                    field: PodcastEpisodeField::PublishedAt,
                    order: SortOrder::Asc,
                },
                page: Page {
                    limit: 10,
                    offset: 0,
                },
            })
            .await
            .unwrap();
        assert_eq!(list.len(), 4);
    }

    #[tokio::test]
    async fn should_upsert_progress() {
        let _ = tracing_subscriber::fmt::try_init();

        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        seed(&pool).await;
        // Episode 4 gets completed (update), episode 2 gets a first progress row (insert).
        let updated = pool
            .upsert_progressions(&[
                PodcastEpisodeProgressInput {
                    user_id: 1,
                    podcast_episode_id: 4,
                    progress: 100,
                    completed: true,
                },
                PodcastEpisodeProgressInput {
                    user_id: 1,
                    podcast_episode_id: 2,
                    progress: 60,
                    completed: false,
                },
            ])
            .await
            .unwrap();
        assert_eq!(updated.len(), 2);

        let list = pool
            .list(ListPodcastEpisodeParams {
                user_id: 1,
                filter: ListPodcastEpisodeFilter {
                    podcast_ids: &[],
                    filtered: None,
                    subscribed: Some(true),
                    watched: Some(false),
                },
                sort: Sort {
                    field: PodcastEpisodeField::PublishedAt,
                    order: SortOrder::Asc,
                },
                page: Page {
                    limit: 10,
                    offset: 0,
                },
            })
            .await
            .unwrap();
        assert_eq!(list.len(), 3);
    }

    #[tokio::test]
    async fn should_list_only_filtered_episodes() {
        let _ = tracing_subscriber::fmt::try_init();

        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        let user = pool
            .create_user("user@example.com", "password")
            .await
            .unwrap();

        let input = PodcastInput::builder()
            .title("title")
            .feed_url("http://feed_url")
            .episodes(vec![
                PodcastEpisodeInput::builder()
                    .title("first title")
                    .file_url("http://file.mp3")
                    .duration(Duration::from_secs(4000))
                    .build(),
                PodcastEpisodeInput::builder()
                    .title("second title")
                    .file_url("http://file.mp3")
                    .duration(Duration::from_secs(1200))
                    .build(),
                PodcastEpisodeInput::builder()
                    .title("third title")
                    .file_url("http://file.mp3")
                    .build(),
            ])
            .build();
        let podcast = pool.upsert_podcast(pool.as_ref(), &input).await.unwrap();
        pool.upsert_podcast_episodes(pool.as_ref(), podcast.id, &input.episodes)
            .await
            .unwrap();

        pool.create_podcast_subscription(user.id, podcast.id)
            .await
            .unwrap();

        pool.update_podcast_subscription(
            user.id,
            podcast.id,
            PodcastSubscriptionUpdate {
                min_duration: Some(2000),
                max_duration: None,
            },
        )
        .await
        .unwrap();

        // The subscribed user's min_duration (2000) hides the 1200s episode; the 4000s
        // episode and the one without a duration remain.
        let list = pool
            .list_podcast_episodes(ListPodcastEpisodeParams {
                user_id: user.id,
                filter: ListPodcastEpisodeFilter {
                    podcast_ids: &[],
                    filtered: Some(false),
                    subscribed: None,
                    watched: Some(false),
                },
                sort: Sort {
                    field: PodcastEpisodeField::PublishedAt,
                    order: SortOrder::Asc,
                },
                page: Page {
                    limit: 100,
                    offset: 0,
                },
            })
            .await
            .unwrap();
        assert_eq!(list.len(), 2);

        // An unknown user (id 0) has no subscription settings, so nothing is filtered out.
        let list = pool
            .list_podcast_episodes(ListPodcastEpisodeParams {
                user_id: 0,
                filter: ListPodcastEpisodeFilter {
                    podcast_ids: &[],
                    filtered: Some(false),
                    subscribed: None,
                    watched: Some(false),
                },
                sort: Sort {
                    field: PodcastEpisodeField::PublishedAt,
                    order: SortOrder::Asc,
                },
                page: Page {
                    limit: 100,
                    offset: 0,
                },
            })
            .await
            .unwrap();
        assert_eq!(list.len(), 3);

        // An unknown user (id 0) has no subscriptions, so the subscribed listing is empty.
        let list = pool
            .list_podcast_episodes(ListPodcastEpisodeParams {
                user_id: 0,
                filter: ListPodcastEpisodeFilter {
                    podcast_ids: &[],
                    filtered: Some(false),
                    subscribed: Some(true),
                    watched: Some(false),
                },
                sort: Sort {
                    field: PodcastEpisodeField::PublishedAt,
                    order: SortOrder::Asc,
                },
                page: Page {
                    limit: 100,
                    offset: 0,
                },
            })
            .await
            .unwrap();
        assert!(list.is_empty());
    }
}