// entertainarr_adapter_sqlite/podcast_episode.rs

use anyhow::Context;
use entertainarr_domain::podcast::entity::{
    ListPodcastEpisodeParams, PodcastEpisode, PodcastEpisodeField, PodcastEpisodeProgress,
    PodcastEpisodeProgressInput,
};
use entertainarr_domain::prelude::SortOrder;

use crate::Wrapper;
use crate::prelude::HasAnyOf;

/// Fetches at most one `podcast_episodes` row by primary key (bind: the episode id).
///
/// The selected column order must stay aligned with the positional decoding done by the
/// `FromRow` implementation for `Wrapper<PodcastEpisode>`.
const FIND_BY_ID_QUERY: &str = concat!(
    "select id, podcast_id, guid, published_at, title, description, link, duration, file_url, ",
    "file_size, file_type, created_at, updated_at\n",
    "from podcast_episodes\n",
    "where id = ?\n",
    "limit 1",
);

15impl super::Pool {
16    #[tracing::instrument(
17        skip_all,
18        fields(
19            otel.kind = "client",
20            db.system = "sqlite",
21            db.name = "podcast",
22            db.operation = "SELECT",
23            db.sql.table = "podcast_episodes",
24            db.query.text = tracing::field::Empty,
25            db.response.returned_rows = tracing::field::Empty,
26            error.type = tracing::field::Empty,
27            error.message = tracing::field::Empty,
28            error.stacktrace = tracing::field::Empty,
29        ),
30        err(Debug),
31    )]
32    pub(crate) async fn list_podcast_episodes<'a>(
33        &self,
34        params: ListPodcastEpisodeParams<'a>,
35    ) -> anyhow::Result<Vec<PodcastEpisode>> {
36        let mut qb: sqlx::QueryBuilder<'_, sqlx::Sqlite> = sqlx::QueryBuilder::new(
37            r#"select
38    podcast_episodes.id,
39    podcast_episodes.podcast_id,
40    podcast_episodes.guid,
41    podcast_episodes.published_at,
42    podcast_episodes.title,
43    podcast_episodes.description,
44    podcast_episodes.link,
45    podcast_episodes.duration,
46    podcast_episodes.file_url,
47    podcast_episodes.file_size,
48    podcast_episodes.file_type,
49    podcast_episodes.created_at,
50    podcast_episodes.updated_at
51from podcast_episodes"#,
52        );
53
54        if let Some(subscribed) = params.filter.subscribed {
55            if subscribed {
56                qb.push(" join user_podcasts on user_podcasts.podcast_id = podcast_episodes.podcast_id and user_podcasts.user_id = ").push_bind(params.user_id as i64);
57            } else {
58                // TODO handle filtered those where the user is not subscribed
59            }
60        } else {
61            // nothing to do here
62        }
63
64        if params.filter.filtered.is_some() {
65            qb.push(" left outer join user_podcasts filters on filters.podcast_id = podcast_episodes.podcast_id and filters.user_id = ").push_bind(params.user_id as i64);
66        }
67
68        if params.filter.watched.is_some() {
69            qb.push(" left outer join user_podcast_episodes on user_podcast_episodes.podcast_episode_id = podcast_episodes.id ");
70            qb.push(" and user_podcast_episodes.user_id = ")
71                .push_bind(params.user_id as i64);
72        }
73
74        // WHERE
75
76        qb.push(" where true");
77
78        if let Some(watched) = params.filter.watched {
79            if watched {
80                qb.push(" and user_podcast_episodes.completed");
81            } else {
82                qb.push(" and (user_podcast_episodes.completed is null or not user_podcast_episodes.completed)");
83            }
84        }
85
86        if let Some(filtered) = params.filter.filtered {
87            if filtered {
88                // maybe keep only the filtered
89            } else {
90                qb.push(" and (filters.user_id is null or podcast_episodes.duration is null or filters.min_duration is null or filters.min_duration < podcast_episodes.duration)");
91                qb.push(" and (filters.user_id is null or podcast_episodes.duration is null or filters.max_duration is null or filters.max_duration > podcast_episodes.duration)");
92            }
93        }
94
95        if !params.filter.podcast_ids.is_empty() {
96            qb.push(" and ")
97                .push_any("podcast_episodes.podcast_id", params.filter.podcast_ids);
98        }
99
100        match params.sort.field {
101            PodcastEpisodeField::PublishedAt => qb.push(" order by podcast_episodes.published_at"),
102        };
103        match params.sort.order {
104            SortOrder::Asc => qb.push(" asc"),
105            SortOrder::Desc => qb.push(" desc"),
106        };
107
108        qb.push(" limit ")
109            .push_bind(params.page.limit)
110            .push(" offset ")
111            .push_bind(params.page.offset);
112
113        let span = tracing::Span::current();
114        span.record("db.query.text", qb.sql());
115
116        qb.build_query_as()
117            .fetch_all(&self.0)
118            .await
119            .inspect(super::record_all)
120            .inspect_err(super::record_error)
121            .map(Wrapper::list)
122            .context("unable to query podcast episodes")
123    }
124
125    #[tracing::instrument(
126        skip_all,
127        fields(
128            otel.kind = "client",
129            db.system = "sqlite",
130            db.name = "podcast",
131            db.operation = "SELECT",
132            db.sql.table = "podcast_episodes",
133            db.query.text = FIND_BY_ID_QUERY,
134            db.response.returned_rows = tracing::field::Empty,
135            error.type = tracing::field::Empty,
136            error.message = tracing::field::Empty,
137            error.stacktrace = tracing::field::Empty,
138        ),
139        err(Debug),
140    )]
141    pub(crate) async fn find_podcast_episode_by_id(
142        &self,
143        episode_id: u64,
144    ) -> anyhow::Result<Option<PodcastEpisode>> {
145        sqlx::query_as(FIND_BY_ID_QUERY)
146            .bind(episode_id as i64)
147            .fetch_optional(self.as_ref())
148            .await
149            .inspect(super::record_optional)
150            .inspect_err(super::record_error)
151            .map(super::Wrapper::maybe_inner)
152            .context("unable to query podcast episodes")
153    }
154
155    #[tracing::instrument(
156        skip_all,
157        fields(
158            otel.kind = "client",
159            db.system = "sqlite",
160            db.name = "podcast",
161            db.operation = "DELETE",
162            db.sql.table = "user_podcast_episodes",
163            db.query.text = tracing::field::Empty,
164            db.response.returned_rows = tracing::field::Empty,
165            error.type = tracing::field::Empty,
166            error.message = tracing::field::Empty,
167            error.stacktrace = tracing::field::Empty,
168        ),
169        err(Debug),
170    )]
171    pub(crate) async fn delete_podcast_episode_progressions(
172        &self,
173        user_id: u64,
174        episode_ids: &[u64],
175    ) -> anyhow::Result<()> {
176        if episode_ids.is_empty() {
177            return Ok(());
178        }
179
180        let mut qb = sqlx::QueryBuilder::new("delete from user_podcast_episodes");
181        qb.push(" where user_id = ").push_bind(user_id as i64);
182        qb.push(" and ( ");
183        for (index, id) in episode_ids.iter().enumerate() {
184            if index > 0 {
185                qb.push(" or");
186            }
187            qb.push(" podcast_episode_id = ").push_bind(*id as i64);
188        }
189        qb.push(")");
190
191        let span = tracing::Span::current();
192        span.record("db.query.text", qb.sql());
193
194        qb.build()
195            .execute(self.as_ref())
196            .await
197            .map(|_| ())
198            .inspect_err(super::record_error)
199            .context("unable to delete podcast episode progresses")
200    }
201
202    #[tracing::instrument(
203        skip_all,
204        fields(
205            otel.kind = "client",
206            db.system = "sqlite",
207            db.name = "podcast",
208            db.operation = "SELECT",
209            db.sql.table = "user_podcast_episodes",
210            db.query.text = tracing::field::Empty,
211            db.response.returned_rows = tracing::field::Empty,
212            error.type = tracing::field::Empty,
213            error.message = tracing::field::Empty,
214            error.stacktrace = tracing::field::Empty,
215        ),
216        err(Debug),
217    )]
218    pub(crate) async fn list_podcast_episode_progressions(
219        &self,
220        user_id: u64,
221        episode_ids: &[u64],
222    ) -> anyhow::Result<Vec<entertainarr_domain::podcast::entity::PodcastEpisodeProgress>> {
223        if episode_ids.is_empty() {
224            return Ok(Default::default());
225        }
226
227        let mut qb = sqlx::QueryBuilder::new(
228            "select user_id, podcast_episode_id, progress, completed, created_at, updated_at from user_podcast_episodes",
229        );
230        qb.push(" where user_id = ").push_bind(user_id as i64);
231        qb.push(" and ( ");
232        for (index, id) in episode_ids.iter().enumerate() {
233            if index > 0 {
234                qb.push(" or");
235            }
236            qb.push(" podcast_episode_id = ").push_bind(*id as i64);
237        }
238        qb.push(")");
239
240        let span = tracing::Span::current();
241        span.record("db.query.text", qb.sql());
242
243        qb.build_query_as::<'_, super::Wrapper<PodcastEpisodeProgress>>()
244            .fetch_all(self.as_ref())
245            .await
246            .inspect(super::record_all)
247            .inspect_err(super::record_error)
248            .map(super::Wrapper::list)
249            .context("unable to query podcast episode progresses")
250    }
251
252    #[tracing::instrument(
253        skip_all,
254        fields(
255            otel.kind = "client",
256            db.system = "sqlite",
257            db.name = "podcast",
258            db.operation = "UPSERT",
259            db.sql.table = "user_podcast_episodes",
260            db.query.text = tracing::field::Empty,
261            db.response.returned_rows = tracing::field::Empty,
262            error.type = tracing::field::Empty,
263            error.message = tracing::field::Empty,
264            error.stacktrace = tracing::field::Empty,
265        ),
266        err(Debug),
267    )]
268    pub(crate) async fn upsert_podcast_episode_progressions(
269        &self,
270        inputs: &[entertainarr_domain::podcast::entity::PodcastEpisodeProgressInput],
271    ) -> anyhow::Result<Vec<PodcastEpisodeProgress>> {
272        if inputs.is_empty() {
273            return Ok(Default::default());
274        }
275
276        let mut qb = sqlx::QueryBuilder::new(
277            "insert into user_podcast_episodes (user_id, podcast_episode_id, progress, completed) ",
278        );
279        qb.push_values(inputs, |mut q, item| {
280            q.push_bind(item.user_id as i64)
281                .push_bind(item.podcast_episode_id as i64)
282                .push_bind(item.progress as i64)
283                .push_bind(item.completed);
284        });
285        qb.push(" on conflict (user_id, podcast_episode_id) do update set progress = excluded.progress, completed = excluded.completed, updated_at = CURRENT_TIMESTAMP");
286        qb.push(
287            " returning user_id, podcast_episode_id, progress, completed, created_at, updated_at",
288        );
289
290        let span = tracing::Span::current();
291        span.record("db.query.text", qb.sql());
292
293        qb.build_query_as()
294            .fetch_all(self.as_ref())
295            .await
296            .inspect(super::record_all)
297            .inspect_err(super::record_error)
298            .map(super::Wrapper::list)
299            .context("unable to upsert podcast episode progresses")
300    }
301}
302
303impl entertainarr_domain::podcast::prelude::PodcastEpisodeRepository for super::Pool {
304    async fn find_by_id(&self, episode_id: u64) -> anyhow::Result<Option<PodcastEpisode>> {
305        self.find_podcast_episode_by_id(episode_id).await
306    }
307
308    async fn list<'a>(
309        &self,
310        params: ListPodcastEpisodeParams<'a>,
311    ) -> anyhow::Result<Vec<PodcastEpisode>> {
312        self.list_podcast_episodes(params).await
313    }
314
315    async fn delete_progressions(&self, user_id: u64, episode_ids: &[u64]) -> anyhow::Result<()> {
316        self.delete_podcast_episode_progressions(user_id, episode_ids)
317            .await
318    }
319
320    async fn list_progressions(
321        &self,
322        user_id: u64,
323        episode_ids: &[u64],
324    ) -> anyhow::Result<Vec<entertainarr_domain::podcast::entity::PodcastEpisodeProgress>> {
325        self.list_podcast_episode_progressions(user_id, episode_ids)
326            .await
327    }
328
329    async fn upsert_progressions(
330        &self,
331        inputs: &[entertainarr_domain::podcast::entity::PodcastEpisodeProgressInput],
332    ) -> anyhow::Result<Vec<PodcastEpisodeProgress>> {
333        self.upsert_podcast_episode_progressions(inputs).await
334    }
335}
336
337impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<PodcastEpisode> {
338    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
339        use sqlx::Row;
340
341        Ok(Self(PodcastEpisode {
342            id: row.try_get(0)?,
343            podcast_id: row.try_get(1)?,
344            guid: row.try_get(2)?,
345            published_at: row.try_get(3)?,
346            title: row.try_get(4)?,
347            description: row.try_get(5)?,
348            link: row.try_get(6)?,
349            duration: row
350                .try_get(7)
351                .map(|value: Option<u64>| value.map(std::time::Duration::from_secs))?,
352            file_url: row.try_get(8)?,
353            file_size: row.try_get(9)?,
354            file_type: row.try_get(10)?,
355            created_at: row.try_get(11)?,
356            updated_at: row.try_get(12)?,
357        }))
358    }
359}
360
361impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<PodcastEpisodeProgress> {
362    fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
363        use sqlx::Row;
364
365        Ok(Self(PodcastEpisodeProgress {
366            user_id: row.try_get(0)?,
367            podcast_episode_id: row.try_get(1)?,
368            progress: row.try_get(2)?,
369            completed: row.try_get(3)?,
370            created_at: row.try_get(4)?,
371            updated_at: row.try_get(5)?,
372        }))
373    }
374}
375
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use entertainarr_domain::podcast::entity::{
        ListPodcastEpisodeFilter, ListPodcastEpisodeParams, PodcastEpisodeField,
        PodcastEpisodeInput, PodcastEpisodeProgressInput, PodcastInput, PodcastSubscriptionUpdate,
    };
    use entertainarr_domain::podcast::prelude::PodcastEpisodeRepository;
    use entertainarr_domain::prelude::{Page, Sort, SortOrder};

    use crate::Pool;

    /// Opens an empty test database inside a fresh temporary directory.
    ///
    /// Keep the returned `TempDir` alive for the whole test: dropping it removes the directory
    /// that holds the database file.
    async fn open_pool() -> (tempfile::TempDir, Pool) {
        let _ = tracing_subscriber::fmt::try_init();

        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;
        (tmpdir, pool)
    }

    /// Inserts the shared fixture:
    /// - users 1 and 2;
    /// - podcast 1 (episodes 1, 2, 3), podcast 2 (episodes 4, 5), podcast 3 (episode 6);
    /// - user 1 follows podcasts 1 (min duration 80) and 2, user 2 follows 2 and 3 (min 80);
    /// - user 1 completed episode 1 and has partial progress on 4, user 2 completed 4 and has
    ///   partial progress on 5.
    async fn seed(pool: &Pool) {
        let _: Vec<u64> = sqlx::query_scalar(
            "insert into users (id, email, password) values (1, 'user1@example.com', 'password'), (2, 'user2@example.com', 'password') returning id",
        )
        .fetch_all(pool.as_ref())
        .await
        .unwrap();
        let _: Vec<u64> = sqlx::query_scalar(
            "insert into podcasts (id, feed_url, title) values (1, 'first', 'first'), (2, 'second', 'second'), (3, 'third', 'third') returning id",
        )
        .fetch_all(pool.as_ref())
        .await
        .unwrap();
        let _: Vec<u64> = sqlx::query_scalar(
            "insert into podcast_episodes (id, podcast_id, title, file_url, duration) values (1, 1, 'title 1', 'url 1', 100), (2, 1, 'title 2', 'url 2', 30), (3, 1, 'title 3', 'url 3', 90), (4, 2, 'title 4', 'url 4', 120), (5, 2, 'title 5', 'url 5', 100), (6, 3, 'title 6', 'url 6', 90) returning id",
        )
        .fetch_all(pool.as_ref())
        .await
        .unwrap();
        let _: Vec<(u64, u64)> = sqlx::query_as(
            "insert into user_podcasts (user_id, podcast_id, min_duration) values (1, 1, 80), (1, 2, null), (2, 2, null), (2, 3, 80) returning user_id, podcast_id",
        )
        .fetch_all(pool.as_ref())
        .await
        .unwrap();
        let _: Vec<(u64, u64)> = sqlx::query_as(
            "insert into user_podcast_episodes (user_id, podcast_episode_id, progress, completed) values (1, 1, 123, true), (1, 4, 10, false), (2, 4, 153, true), (2, 5, 10, false) returning user_id, podcast_episode_id",
        )
        .fetch_all(pool.as_ref())
        .await
        .unwrap();
    }

    /// Listing parameters: no podcast restriction, ascending publication date, first 10 rows.
    ///
    /// `user_id` is left at 0; tests override it (and `page` when needed) with struct update
    /// syntax.
    fn by_publication_date(
        subscribed: Option<bool>,
        watched: Option<bool>,
        filtered: Option<bool>,
    ) -> ListPodcastEpisodeParams<'static> {
        ListPodcastEpisodeParams {
            user_id: 0,
            filter: ListPodcastEpisodeFilter {
                podcast_ids: &[],
                filtered,
                subscribed,
                watched,
            },
            sort: Sort {
                field: PodcastEpisodeField::PublishedAt,
                order: SortOrder::Asc,
            },
            page: Page {
                limit: 10,
                offset: 0,
            },
        }
    }

    #[tokio::test]
    async fn should_list_all_subscribed_episodes() {
        let (_tmpdir, pool) = open_pool().await;
        seed(&pool).await;

        // user 1 follows podcast 1 (3 episodes) and podcast 2 (2 episodes)
        let list = pool
            .list(ListPodcastEpisodeParams {
                user_id: 1,
                ..by_publication_date(Some(true), None, None)
            })
            .await
            .unwrap();
        assert_eq!(list.len(), 5);
    }

    #[tokio::test]
    async fn should_list_all_watched_episodes() {
        let (_tmpdir, pool) = open_pool().await;
        seed(&pool).await;

        // among the subscribed episodes, user 1 completed only episode 1
        let list = pool
            .list(ListPodcastEpisodeParams {
                user_id: 1,
                ..by_publication_date(Some(true), Some(true), None)
            })
            .await
            .unwrap();
        assert_eq!(list.len(), 1);
    }

    #[tokio::test]
    async fn should_list_all_unwatched_episodes() {
        let (_tmpdir, pool) = open_pool().await;
        seed(&pool).await;

        // subscribed episodes 2, 3, 4 and 5 are not completed (4 only has partial progress)
        let list = pool
            .list(ListPodcastEpisodeParams {
                user_id: 1,
                ..by_publication_date(Some(true), Some(false), None)
            })
            .await
            .unwrap();
        assert_eq!(list.len(), 4);
    }

    #[tokio::test]
    async fn should_upsert_progress() {
        let (_tmpdir, pool) = open_pool().await;
        seed(&pool).await;

        // episode 4 already has a progress row (update), episode 2 has none (insert)
        let updated = pool
            .upsert_progressions(&[
                PodcastEpisodeProgressInput {
                    user_id: 1,
                    podcast_episode_id: 4,
                    progress: 100,
                    completed: true,
                },
                PodcastEpisodeProgressInput {
                    user_id: 1,
                    podcast_episode_id: 2,
                    progress: 60,
                    completed: false,
                },
            ])
            .await
            .unwrap();
        assert_eq!(updated.len(), 2);

        // episode 4 is now completed, leaving episodes 2, 3 and 5 unwatched
        let list = pool
            .list(ListPodcastEpisodeParams {
                user_id: 1,
                ..by_publication_date(Some(true), Some(false), None)
            })
            .await
            .unwrap();
        assert_eq!(list.len(), 3);
    }

    #[tokio::test]
    async fn should_list_only_filtered_episodes() {
        let (_tmpdir, pool) = open_pool().await;

        let user = pool
            .create_user("user@example.com", "password")
            .await
            .unwrap();

        // three episodes: 4000s, 1200s, and one without a known duration
        let input = PodcastInput::builder()
            .title("title")
            .feed_url("http://feed_url")
            .episodes(vec![
                PodcastEpisodeInput::builder()
                    .title("first title")
                    .file_url("http://file.mp3")
                    .duration(Duration::from_secs(4000))
                    .build(),
                PodcastEpisodeInput::builder()
                    .title("second title")
                    .file_url("http://file.mp3")
                    .duration(Duration::from_secs(1200))
                    .build(),
                PodcastEpisodeInput::builder()
                    .title("third title")
                    .file_url("http://file.mp3")
                    .build(),
            ])
            .build();
        let podcast = pool.upsert_podcast(pool.as_ref(), &input).await.unwrap();
        pool.upsert_podcast_episodes(pool.as_ref(), podcast.id, &input.episodes)
            .await
            .unwrap();

        pool.create_podcast_subscription(user.id, podcast.id)
            .await
            .unwrap();

        // the user's subscription gets a 2000s minimum duration
        pool.update_podcast_subscription(
            user.id,
            podcast.id,
            PodcastSubscriptionUpdate {
                min_duration: Some(2000),
                max_duration: None,
            },
        )
        .await
        .unwrap();

        // the 1200s episode is hidden; the episode without duration is kept
        let list = pool
            .list_podcast_episodes(ListPodcastEpisodeParams {
                user_id: user.id,
                page: Page {
                    limit: 100,
                    offset: 0,
                },
                ..by_publication_date(None, Some(false), Some(false))
            })
            .await
            .unwrap();
        assert_eq!(list.len(), 2);

        // user 0 has no subscription, so no duration bound applies and every episode is listed
        let list = pool
            .list_podcast_episodes(ListPodcastEpisodeParams {
                user_id: 0,
                page: Page {
                    limit: 100,
                    offset: 0,
                },
                ..by_publication_date(None, Some(false), Some(false))
            })
            .await
            .unwrap();
        assert_eq!(list.len(), 3);

        // user 0 has no subscription, so restricting to subscribed podcasts yields nothing
        let list = pool
            .list_podcast_episodes(ListPodcastEpisodeParams {
                user_id: 0,
                page: Page {
                    limit: 100,
                    offset: 0,
                },
                ..by_publication_date(Some(true), Some(false), Some(false))
            })
            .await
            .unwrap();
        assert!(list.is_empty());
    }
}