1use anyhow::Context;
2use entertainarr_domain::podcast::entity::{
3 ListPodcastParams, Podcast, PodcastEpisodeInput, PodcastInput, PodcastSubscription,
4 PodcastWithSubscription,
5};
6
7use crate::Wrapper;
8use crate::prelude::{HasAnyOf, QueryBuilderExt};
9
/// Selects one podcast by primary key.
const FIND_PODCAST_BY_ID_QUERY: &str = "select id, feed_url, title, description, image_url, language, website, created_at, updated_at from podcasts where id = ? limit 1";
/// Selects one podcast by its exact feed URL.
///
/// Uses `=` rather than `like`: under `like`, the `_` and `%` characters that routinely
/// appear in URLs (including percent-encoded sequences) act as wildcards, so a lookup for
/// one feed could return another feed's row. Exact equality also agrees with the
/// `on conflict (feed_url)` target of `UPSERT_PODCAST_QUERY`. Note this makes the lookup
/// case-sensitive, whereas SQLite's `like` ignores ASCII case.
const FIND_PODCAST_BY_FEED_URL_QUERY: &str = "select id, feed_url, title, description, image_url, language, website, created_at, updated_at from podcasts where feed_url = ? limit 1";
/// Inserts a podcast or, when its `feed_url` already exists, refreshes the metadata
/// columns and `updated_at`; returns the stored row.
const UPSERT_PODCAST_QUERY: &str = r#"insert into podcasts (feed_url, title, description, image_url, language, website)
values (?, ?, ?, ?, ?, ?)
on conflict (feed_url) do update set
    title=excluded.title,
    description=excluded.description,
    image_url=excluded.image_url,
    language=excluded.language,
    website=excluded.website,
    updated_at=CURRENT_TIMESTAMP
returning id, feed_url, title, description, image_url, language, website, created_at, updated_at"#;
/// Subscribes a user to a podcast; an existing `(user_id, podcast_id)` pair is left as is.
const UPSERT_USER_PODCAST_QUERY: &str = "insert into user_podcasts (user_id, podcast_id) values (?, ?) on conflict (user_id, podcast_id) do nothing";
/// Removes a user's subscription to a podcast.
const DELETE_USER_PODCAST_QUERY: &str =
    "delete from user_podcasts where user_id = ? and podcast_id = ?";
25
26impl super::Pool {
27 #[tracing::instrument(
28 skip_all,
29 fields(
30 otel.kind = "client",
31 db.system = "sqlite",
32 db.name = "podcast",
33 db.operation = "SELECT",
34 db.sql.table = "podcasts",
35 db.query.text = FIND_PODCAST_BY_ID_QUERY,
36 db.response.returned_rows = tracing::field::Empty,
37 error.type = tracing::field::Empty,
38 error.message = tracing::field::Empty,
39 error.stacktrace = tracing::field::Empty,
40 ),
41 err(Debug),
42 )]
43 pub(crate) async fn find_podcast_by_id(
44 &self,
45 podcast_id: u64,
46 ) -> anyhow::Result<Option<Podcast>> {
47 sqlx::query_as(FIND_PODCAST_BY_ID_QUERY)
48 .bind(podcast_id as i64)
49 .fetch_optional(&self.0)
50 .await
51 .inspect(super::record_optional)
52 .inspect_err(super::record_error)
53 .map(Wrapper::maybe_inner)
54 .context("unable to query podcasts by id")
55 }
56 #[tracing::instrument(
57 skip_all,
58 fields(
59 otel.kind = "client",
60 db.system = "sqlite",
61 db.name = "podcast",
62 db.operation = "SELECT",
63 db.sql.table = "podcasts",
64 db.query.text = FIND_PODCAST_BY_FEED_URL_QUERY,
65 db.response.returned_rows = tracing::field::Empty,
66 error.type = tracing::field::Empty,
67 error.message = tracing::field::Empty,
68 error.stacktrace = tracing::field::Empty,
69 ),
70 err(Debug),
71 )]
72 pub(crate) async fn find_podcast_by_feed_url(
73 &self,
74 feed_url: &str,
75 ) -> anyhow::Result<Option<Podcast>> {
76 sqlx::query_as(FIND_PODCAST_BY_FEED_URL_QUERY)
77 .bind(feed_url)
78 .fetch_optional(&self.0)
79 .await
80 .inspect(super::record_optional)
81 .inspect_err(super::record_error)
82 .map(Wrapper::maybe_inner)
83 .context("unable to query podcasts by feed url")
84 }
85
86 #[tracing::instrument(
87 skip_all,
88 fields(
89 otel.kind = "client",
90 db.system = "sqlite",
91 db.name = "podcast",
92 db.operation = "INSERT",
93 db.sql.table = "podcasts",
94 db.query.text = UPSERT_PODCAST_QUERY,
95 db.response.returned_rows = tracing::field::Empty,
96 error.type = tracing::field::Empty,
97 error.message = tracing::field::Empty,
98 error.stacktrace = tracing::field::Empty,
99 ),
100 err(Debug),
101 )]
102 pub(crate) async fn upsert_podcast<'c, E: sqlx::SqliteExecutor<'c>>(
103 &self,
104 executor: E,
105 entity: &PodcastInput,
106 ) -> anyhow::Result<Podcast> {
107 sqlx::query_as(UPSERT_PODCAST_QUERY)
108 .bind(&entity.feed_url)
109 .bind(&entity.title)
110 .bind(entity.description.as_ref())
111 .bind(entity.image_url.as_ref())
112 .bind(entity.language.as_ref())
113 .bind(entity.website.as_ref())
114 .fetch_one(executor)
115 .await
116 .inspect(super::record_one)
117 .inspect_err(super::record_error)
118 .map(Wrapper::inner)
119 .context("unable to upsert podcast")
120 }
121
122 #[tracing::instrument(
123 skip_all,
124 fields(
125 otel.kind = "client",
126 db.system = "sqlite",
127 db.name = "podcast",
128 db.operation = "INSERT",
129 db.sql.table = "podcast_episodes",
130 db.query.text = tracing::field::Empty,
131 db.response.returned_rows = tracing::field::Empty,
132 error.type = tracing::field::Empty,
133 error.message = tracing::field::Empty,
134 error.stacktrace = tracing::field::Empty,
135 ),
136 err(Debug),
137 )]
138 pub(crate) async fn upsert_podcast_episodes<'c, E: sqlx::SqliteExecutor<'c>>(
139 &self,
140 executor: E,
141 podcast_id: u64,
142 episodes: &[PodcastEpisodeInput],
143 ) -> anyhow::Result<()> {
144 let mut qb: sqlx::QueryBuilder<'_, sqlx::Sqlite> = sqlx::QueryBuilder::new(
145 "insert into podcast_episodes (podcast_id, guid, published_at, title, description, link, duration, file_url, file_size, file_type)",
146 );
147 qb.push_values(episodes.iter(), |mut b, item| {
148 b.push_bind(podcast_id as i64)
149 .push_bind(&item.guid)
150 .push_bind(item.published_at)
151 .push_bind(&item.title)
152 .push_bind(&item.description)
153 .push_bind(&item.link)
154 .push_bind(item.duration.as_ref().map(|value| value.as_secs() as i64))
155 .push_bind(&item.file_url)
156 .push_bind(item.file_size.map(|value| value as i64))
157 .push_bind(&item.file_type);
158 });
159 qb.push(" on conflict (podcast_id, guid) do nothing");
160
161 tracing::Span::current().record("db.query.text", qb.sql());
162
163 qb.build()
164 .execute(executor)
165 .await
166 .inspect_err(super::record_error)
167 .map(|_| ())
168 .context("unable to upsert podcast episodes")
169 }
170 #[tracing::instrument(
171 skip_all,
172 fields(
173 otel.kind = "client",
174 db.system = "sqlite",
175 db.name = "podcast",
176 db.operation = "SELECT",
177 db.sql.table = "podcasts",
178 db.query.text = tracing::field::Empty,
179 db.response.returned_rows = tracing::field::Empty,
180 error.type = tracing::field::Empty,
181 error.message = tracing::field::Empty,
182 error.stacktrace = tracing::field::Empty,
183 ),
184 err(Debug),
185 )]
186 pub(crate) async fn list_podcasts<'a>(
187 &self,
188 params: ListPodcastParams<'a>,
189 ) -> anyhow::Result<Vec<Podcast>> {
190 let mut qb = sqlx::QueryBuilder::new(
191 "select podcasts.id, podcasts.feed_url, podcasts.title, podcasts.description, podcasts.image_url, podcasts.language, podcasts.website, podcasts.created_at, podcasts.updated_at from podcasts",
192 );
193 if let Some(user_id) = params.user_id
194 && params.filter.subscribed.is_some()
195 {
196 qb.push(" left outer join user_podcasts on user_podcasts.podcast_id = podcasts.id and user_podcasts.user_id = ").push_bind(user_id as i64);
197 }
198 qb.push(" where true");
199 if !params.filter.podcast_ids.is_empty() {
200 qb.push(" and ");
201 qb.push_any("podcasts.id", params.filter.podcast_ids);
202 }
203 if params.user_id.is_some() {
204 match params.filter.subscribed {
205 Some(true) => {
206 qb.push(" and user_podcasts.created_at is not null");
207 }
208 Some(false) => {
209 qb.push(" and user_podcasts.created_at is null");
210 }
211 None => {}
212 }
213 }
214 qb.push_pagination(params.page);
215
216 tracing::Span::current().record("db.query.text", qb.sql());
217
218 qb.build_query_as()
219 .fetch_all(&self.0)
220 .await
221 .inspect(super::record_all)
222 .inspect_err(super::record_error)
223 .map(Wrapper::list)
224 .context("unable to query podcasts by feed url")
225 }
226}
227
228impl entertainarr_domain::podcast::prelude::PodcastRepository for super::Pool {
229 async fn find_by_id(&self, podcast_id: u64) -> anyhow::Result<Option<Podcast>> {
230 self.find_podcast_by_id(podcast_id).await
231 }
232
233 async fn find_by_feed_url(&self, feed_url: &str) -> anyhow::Result<Option<Podcast>> {
234 self.find_podcast_by_feed_url(feed_url).await
235 }
236
237 async fn list<'a>(&self, params: ListPodcastParams<'a>) -> anyhow::Result<Vec<Podcast>> {
238 self.list_podcasts(params).await
239 }
240
241 async fn upsert(&self, entity: &PodcastInput) -> anyhow::Result<Podcast> {
242 let mut tx = self
243 .0
244 .begin()
245 .await
246 .context("unable to begin transaction")?;
247
248 let podcast = self.upsert_podcast(&mut *tx, entity).await?;
249 self.upsert_podcast_episodes(&mut *tx, podcast.id, &entity.episodes)
250 .await?;
251
252 tx.commit().await.context("unable to commit transaction")?;
253 Ok(podcast)
254 }
255}
256
/// Fetches one subscription together with the number of the podcast's episodes that fit
/// the subscription's duration bounds (episodes with an unknown duration always count).
///
/// The duration bounds belong to the `left outer join ... on` clause, not to `where`:
/// filtering the joined table in `where` discards the subscription row itself whenever
/// every episode is out of bounds, so an existing subscription would disappear instead
/// of reporting a count of 0.
///
/// Bind order: `user_id`, `podcast_id`.
const FIND_PODCAST_SUBSCRIPTION_BY_ID_QUERY: &str = r#"select
    user_podcasts.user_id,
    user_podcasts.podcast_id,
    count(podcast_episodes.id),
    user_podcasts.min_duration,
    user_podcasts.max_duration,
    user_podcasts.created_at
from user_podcasts
left outer join podcast_episodes
    on podcast_episodes.podcast_id = user_podcasts.podcast_id
    and (
        user_podcasts.min_duration is null
        or podcast_episodes.duration is null
        or podcast_episodes.duration > user_podcasts.min_duration
    )
    and (
        user_podcasts.max_duration is null
        or podcast_episodes.duration is null
        or podcast_episodes.duration < user_podcasts.max_duration
    )
where user_podcasts.user_id = ?
    and user_podcasts.podcast_id = ?
group by user_podcasts.user_id, user_podcasts.podcast_id"#;

/// Replaces the duration bounds of a subscription.
///
/// Bind order: `min_duration`, `max_duration`, `user_id`, `podcast_id`.
const UPDATE_SUBSCRIPTION_QUERY: &str = r#"update user_podcasts set min_duration = ?, max_duration = ? where user_id = ? and podcast_id = ?"#;
281
282impl super::Pool {
283 #[tracing::instrument(
284 skip_all,
285 fields(
286 otel.kind = "client",
287 db.system = "sqlite",
288 db.name = "podcast",
289 db.operation = "UPSERT",
290 db.sql.table = "user_podcasts",
291 db.query.text = UPSERT_USER_PODCAST_QUERY,
292 db.response.returned_rows = tracing::field::Empty,
293 error.type = tracing::field::Empty,
294 error.message = tracing::field::Empty,
295 error.stacktrace = tracing::field::Empty,
296 ),
297 err(Debug),
298 )]
299 pub(crate) async fn create_podcast_subscription(
300 &self,
301 user_id: u64,
302 podcast_id: u64,
303 ) -> anyhow::Result<()> {
304 sqlx::query(UPSERT_USER_PODCAST_QUERY)
305 .bind(user_id as i64)
306 .bind(podcast_id as i64)
307 .execute(&self.0)
308 .await
309 .inspect_err(super::record_error)
310 .map(|_| ())
311 .context("unable to upsert user podcast relation")
312 }
313
314 #[tracing::instrument(
315 skip_all,
316 fields(
317 otel.kind = "client",
318 db.system = "sqlite",
319 db.name = "podcast",
320 db.operation = "update",
321 db.sql.table = "user_podcasts",
322 db.query.text = UPDATE_SUBSCRIPTION_QUERY,
323 db.response.returned_rows = tracing::field::Empty,
324 error.type = tracing::field::Empty,
325 error.message = tracing::field::Empty,
326 error.stacktrace = tracing::field::Empty,
327 ),
328 err(Debug),
329 )]
330 pub(crate) async fn update_podcast_subscription(
331 &self,
332 user_id: u64,
333 podcast_id: u64,
334 body: entertainarr_domain::podcast::entity::PodcastSubscriptionUpdate,
335 ) -> anyhow::Result<()> {
336 sqlx::query(UPDATE_SUBSCRIPTION_QUERY)
337 .bind(body.min_duration.map(|value| value as i64))
338 .bind(body.max_duration.map(|value| value as i64))
339 .bind(user_id as i64)
340 .bind(podcast_id as i64)
341 .fetch_optional(&self.0)
342 .await
343 .inspect_err(super::record_error)
344 .map(|_| ())
345 .context("unable to update podcast subscriptions")
346 }
347
348 #[tracing::instrument(
349 skip_all,
350 fields(
351 otel.kind = "client",
352 db.system = "sqlite",
353 db.name = "podcast",
354 db.operation = "DELETE",
355 db.sql.table = "user_podcasts",
356 db.query.text = DELETE_USER_PODCAST_QUERY,
357 db.response.returned_rows = tracing::field::Empty,
358 error.type = tracing::field::Empty,
359 error.message = tracing::field::Empty,
360 error.stacktrace = tracing::field::Empty,
361 ),
362 err(Debug),
363 )]
364 pub(crate) async fn delete_podcast_subscription(
365 &self,
366 user_id: u64,
367 podcast_id: u64,
368 ) -> anyhow::Result<()> {
369 sqlx::query(DELETE_USER_PODCAST_QUERY)
370 .bind(user_id as i64)
371 .bind(podcast_id as i64)
372 .execute(&self.0)
373 .await
374 .inspect_err(super::record_error)
375 .map(|_| ())
376 .context("unable to delete user podcast relation")
377 }
378
379 #[tracing::instrument(
380 skip_all,
381 fields(
382 otel.kind = "client",
383 db.system = "sqlite",
384 db.name = "podcast",
385 db.operation = "SELECT",
386 db.sql.table = "podcasts",
387 db.query.text = tracing::field::Empty,
388 db.response.returned_rows = tracing::field::Empty,
389 error.type = tracing::field::Empty,
390 error.message = tracing::field::Empty,
391 error.stacktrace = tracing::field::Empty,
392 ),
393 err(Debug),
394 )]
395 pub(crate) async fn list_podcast_subscriptions(
396 &self,
397 user_id: u64,
398 podcast_ids: &[u64],
399 ) -> anyhow::Result<Vec<PodcastSubscription>> {
400 let mut qb: sqlx::QueryBuilder<'_, sqlx::Sqlite> = sqlx::QueryBuilder::new(
401 "select user_podcasts.user_id, user_podcasts.podcast_id, count(podcast_episodes.id), user_podcasts.min_duration, user_podcasts.max_duration, user_podcasts.created_at",
402 );
403 qb.push(" from user_podcasts");
404 qb.push(" left outer join podcast_episodes on podcast_episodes.podcast_id = user_podcasts.podcast_id");
405 qb.push(" where user_podcasts.user_id = ")
406 .push_bind(user_id as i64);
407 if !podcast_ids.is_empty() {
408 qb.push(" and ")
409 .push_any("user_podcasts.podcast_id", podcast_ids);
410 }
411 qb.push(" and (user_podcasts.min_duration is null or podcast_episodes.duration is null or podcast_episodes.duration > user_podcasts.min_duration)");
412 qb.push(" and (user_podcasts.max_duration is null or podcast_episodes.duration is null or podcast_episodes.duration < user_podcasts.max_duration)");
413 qb.push(" group by user_podcasts.user_id, user_podcasts.podcast_id");
414
415 tracing::Span::current().record("db.query.text", qb.sql());
416
417 qb.build_query_as()
418 .fetch_all(&self.0)
419 .await
420 .inspect(super::record_all)
421 .inspect_err(super::record_error)
422 .map(Wrapper::list)
423 .context("unable to list podcast subscriptions")
424 }
425
426 #[tracing::instrument(
427 skip_all,
428 fields(
429 otel.kind = "client",
430 db.system = "sqlite",
431 db.name = "podcast",
432 db.operation = "select",
433 db.sql.table = "podcasts",
434 db.query.text = FIND_PODCAST_SUBSCRIPTION_BY_ID_QUERY,
435 db.response.returned_rows = tracing::field::Empty,
436 error.type = tracing::field::Empty,
437 error.message = tracing::field::Empty,
438 error.stacktrace = tracing::field::Empty,
439 ),
440 err(Debug),
441 )]
442 pub(crate) async fn find_podcast_subscription(
443 &self,
444 user_id: u64,
445 podcast_id: u64,
446 ) -> anyhow::Result<Option<entertainarr_domain::podcast::entity::PodcastSubscription>> {
447 sqlx::query_as(FIND_PODCAST_SUBSCRIPTION_BY_ID_QUERY)
448 .bind(user_id as i64)
449 .bind(podcast_id as i64)
450 .fetch_optional(&self.0)
451 .await
452 .inspect(super::record_optional)
453 .inspect_err(super::record_error)
454 .map(Wrapper::maybe_inner)
455 .context("unable to find podcast subscriptions")
456 }
457}
458
459impl entertainarr_domain::podcast::prelude::PodcastSubscriptionRepository for super::Pool {
460 async fn create(&self, user_id: u64, podcast_id: u64) -> anyhow::Result<()> {
461 self.create_podcast_subscription(user_id, podcast_id).await
462 }
463
464 async fn delete(&self, user_id: u64, podcast_id: u64) -> anyhow::Result<()> {
465 self.delete_podcast_subscription(user_id, podcast_id).await
466 }
467
468 async fn list(
469 &self,
470 user_id: u64,
471 podcast_ids: &[u64],
472 ) -> anyhow::Result<Vec<PodcastSubscription>> {
473 self.list_podcast_subscriptions(user_id, podcast_ids).await
474 }
475
476 async fn update(
477 &self,
478 user_id: u64,
479 podcast_id: u64,
480 body: entertainarr_domain::podcast::entity::PodcastSubscriptionUpdate,
481 ) -> anyhow::Result<()> {
482 self.update_podcast_subscription(user_id, podcast_id, body)
483 .await
484 }
485
486 async fn find(
487 &self,
488 user_id: u64,
489 podcast_id: u64,
490 ) -> anyhow::Result<Option<entertainarr_domain::podcast::entity::PodcastSubscription>> {
491 self.find_podcast_subscription(user_id, podcast_id).await
492 }
493}
494
495impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<Podcast> {
496 fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
497 use sqlx::Row;
498
499 Ok(Self(Podcast {
500 id: row.try_get(0)?,
501 feed_url: row.try_get(1)?,
502 title: row.try_get(2)?,
503 description: row.try_get(3)?,
504 image_url: row.try_get(4)?,
505 language: row.try_get(5)?,
506 website: row.try_get(6)?,
507 created_at: row.try_get(7)?,
508 updated_at: row.try_get(8)?,
509 }))
510 }
511}
512
513impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<PodcastWithSubscription> {
514 fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
515 use sqlx::Row;
516
517 Ok(Self(PodcastWithSubscription {
518 podcast: Podcast {
519 id: row.try_get(0)?,
520 feed_url: row.try_get(1)?,
521 title: row.try_get(2)?,
522 description: row.try_get(3)?,
523 image_url: row.try_get(4)?,
524 language: row.try_get(5)?,
525 website: row.try_get(6)?,
526 created_at: row.try_get(7)?,
527 updated_at: row.try_get(8)?,
528 },
529 queue_size: row.try_get(9)?,
530 }))
531 }
532}
533
534impl<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> for super::Wrapper<PodcastSubscription> {
535 fn from_row(row: &'r sqlx::sqlite::SqliteRow) -> Result<Self, sqlx::Error> {
536 use sqlx::Row;
537
538 Ok(Self(PodcastSubscription {
539 user_id: row.try_get(0)?,
540 podcast_id: row.try_get(1)?,
541 queue_size: row.try_get(2)?,
542 min_duration: row.try_get(3)?,
543 max_duration: row.try_get(4)?,
544 created_at: row.try_get(5)?,
545 }))
546 }
547}
548
#[cfg(test)]
mod tests {
    use entertainarr_domain::podcast::entity::{PodcastEpisodeInput, PodcastInput};
    use entertainarr_domain::podcast::prelude::PodcastRepository;

    use crate::Pool;

    /// Returns the number of rows in `podcasts`.
    async fn count_podcasts(pool: &Pool) -> u32 {
        sqlx::query_scalar("select count(*) from podcasts")
            .fetch_one(&pool.0)
            .await
            .unwrap()
    }

    /// Returns the number of rows in `podcast_episodes`.
    async fn count_podcast_episodes(pool: &Pool) -> u32 {
        sqlx::query_scalar("select count(*) from podcast_episodes")
            .fetch_one(&pool.0)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn should_find_podcast_by_feed_url_when_existing() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        let podcast = pool
            .upsert_podcast(
                pool.as_ref(),
                &PodcastInput::builder()
                    .feed_url("http://foo.bar/atom.rss")
                    .title("Foo Bar")
                    .build(),
            )
            .await
            .unwrap();

        let res = pool
            .find_by_feed_url("http://foo.bar/atom.rss")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(podcast.id, res.id);
    }

    #[tokio::test]
    async fn should_find_podcast_by_feed_url_when_missing() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;
        let res = pool.find_by_feed_url("missing_url").await.unwrap();
        assert!(res.is_none());
    }

    #[tokio::test]
    async fn should_upsert_podcast_and_episodes() {
        let tmpdir = tempfile::tempdir().unwrap();
        let pool = crate::Pool::test(&tmpdir.path().join("db")).await;

        let content = PodcastInput::builder()
            .feed_url("http://example.com/atom.rss")
            .title("Example")
            .episodes(vec![
                PodcastEpisodeInput::builder()
                    .guid("aaaaa")
                    .title("First episode")
                    .file_url("http://example.com/first.mp3")
                    .build(),
                PodcastEpisodeInput::builder()
                    .guid("aaaab")
                    .title("Second episode")
                    .file_url("http://example.com/second.mp3")
                    .build(),
            ])
            .build();
        let _res = pool.upsert(&content).await.unwrap();
        assert_eq!(count_podcasts(&pool).await, 1);
        assert_eq!(count_podcast_episodes(&pool).await, 2);

        // Upserting the same feed again must duplicate neither the podcast nor its episodes.
        let _res = pool.upsert(&content).await.unwrap();
        assert_eq!(count_podcasts(&pool).await, 1);
        assert_eq!(count_podcast_episodes(&pool).await, 2);
    }
}