1use {
2 super::{client::Client, error::Result as Rs621Result},
3 chrono::{offset::Utc, DateTime},
4 derivative::Derivative,
5 futures::{
6 prelude::*,
7 task::{Context, Poll},
8 },
9 itertools::Itertools,
10 serde::{
11 de::{self, Visitor},
12 Deserialize, Deserializer,
13 },
14 std::{borrow::Borrow, pin::Pin},
15};
16
/// Number of posts requested per page by post searches (sent as the `limit` query parameter).
///
/// NOTE(review): 320 appears to be the API's per-request maximum — confirm against the e621 docs.
const ITER_CHUNK_SIZE: u64 = 320;
19
20#[derive(Debug, PartialEq, Eq, Deserialize)]
21pub enum PostFileExtension {
22 #[serde(rename = "jpg")]
23 Jpeg,
24 #[serde(rename = "png")]
25 Png,
26 #[serde(rename = "gif")]
27 Gif,
28 #[serde(rename = "swf")]
29 Swf,
30 #[serde(rename = "webm")]
31 WebM,
32}
33
/// Metadata about a post's main (full-size) file.
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct PostFile {
    // NOTE(review): dimensions presumably in pixels — confirm.
    pub width: u64,
    pub height: u64,
    /// Format of the file.
    pub ext: PostFileExtension,
    // NOTE(review): presumably the file size in bytes — confirm.
    pub size: u64,
    /// MD5 digest of the file, as sent by the API.
    pub md5: String,
    /// Download URL; `None` when the API sends `null` for it.
    pub url: Option<String>,
}
43
/// Metadata about a post's preview (thumbnail) image.
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct PostPreview {
    pub width: u64,
    pub height: u64,
    /// Preview URL; `None` when the API sends `null` for it.
    pub url: Option<String>,
}
50
/// Metadata about a post's sample (downscaled) image.
///
/// `Post::sample` is an `Option`, so a post's response may omit or null out this object.
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct PostSample {
    pub width: u64,
    pub height: u64,
    /// Sample URL; `None` when the API sends `null` for it.
    pub url: Option<String>,
}
57
/// Vote tallies of a post.
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct PostScore {
    pub up: i64,
    // NOTE(review): signed type; the API may report downvotes as a negative number — confirm.
    pub down: i64,
    pub total: i64,
}
64
/// Tags of a post, grouped by tag category as in the API response.
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct PostTags {
    pub general: Vec<String>,
    pub species: Vec<String>,
    pub character: Vec<String>,
    pub artist: Vec<String>,
    pub invalid: Vec<String>,
    pub lore: Vec<String>,
    pub meta: Vec<String>,
}
75
/// Status flags of a post.
///
/// Every flag is read through `nullable_bool_from_json`, so a `null` in the response is
/// treated as `false` instead of failing deserialization.
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct PostFlags {
    #[serde(deserialize_with = "nullable_bool_from_json")]
    pub pending: bool,
    #[serde(deserialize_with = "nullable_bool_from_json")]
    pub flagged: bool,
    #[serde(deserialize_with = "nullable_bool_from_json")]
    pub note_locked: bool,
    #[serde(deserialize_with = "nullable_bool_from_json")]
    pub status_locked: bool,
    #[serde(deserialize_with = "nullable_bool_from_json")]
    pub rating_locked: bool,
    #[serde(deserialize_with = "nullable_bool_from_json")]
    pub deleted: bool,
}
91
/// Content rating of a post, sent by the API as a one-letter code.
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub enum PostRating {
    /// Code `s`.
    #[serde(rename = "s")]
    Safe,
    /// Code `q`.
    #[serde(rename = "q")]
    Questionable,
    /// Code `e`.
    #[serde(rename = "e")]
    Explicit,
}
101
/// Parent/child links of a post.
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct PostRelationships {
    /// ID of the parent post, or `None` if the API sends `null`.
    pub parent_id: Option<u64>,
    pub has_children: bool,
    // NOTE(review): presumably differs from `has_children` by excluding deleted children —
    // confirm against the API docs.
    pub has_active_children: bool,
    /// IDs of the child posts.
    pub children: Vec<u64>,
}
109
/// A post, as returned in the `posts` array of a `/posts.json` response.
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct Post {
    pub id: u64,
    pub created_at: DateTime<Utc>,
    /// `None` when the API sends `null`.
    pub updated_at: Option<DateTime<Utc>>,
    pub file: PostFile,
    pub preview: PostPreview,
    /// `None` when the API sends `null` for the sample.
    pub sample: Option<PostSample>,
    pub score: PostScore,
    pub tags: PostTags,
    pub locked_tags: Vec<String>,
    // NOTE(review): presumably a change-sequence counter bumped on each edit — confirm.
    pub change_seq: u64,
    pub flags: PostFlags,
    pub rating: PostRating,
    pub fav_count: u64,
    pub sources: Vec<String>,
    /// IDs of the pools the post belongs to.
    pub pools: Vec<u64>,
    pub relationships: PostRelationships,
    /// `None` when the API sends `null`.
    pub approver_id: Option<u64>,
    pub uploader_id: u64,
    pub description: String,
    pub comment_count: u64,
    // NOTE(review): presumably relative to the authenticated user — confirm.
    pub is_favorited: bool,
}
135
/// Body of a `/posts.json` response: `{"posts": [...]}`.
#[derive(Debug, PartialEq, Eq, Deserialize)]
struct PostListApiResponse {
    pub posts: Vec<Post>,
}
140
/// Body of a single-post response: `{"post": {...}}`.
///
/// NOTE(review): nothing visible in this module deserializes this type; presumably it is
/// meant for a "show post" endpoint — confirm it is still needed.
#[derive(Debug, PartialEq, Eq, Deserialize)]
struct PostShowApiResponse {
    pub post: Post,
}
145
146fn nullable_bool_from_json<'de, D>(de: D) -> Result<bool, D::Error>
147where
148 D: Deserializer<'de>,
149{
150 struct NullableBoolVisitor;
151
152 impl<'de> Visitor<'de> for NullableBoolVisitor {
153 type Value = bool;
154
155 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
156 formatter.write_str("null or bool")
157 }
158
159 fn visit_bool<E: de::Error>(self, v: bool) -> Result<bool, E> {
160 Ok(v)
161 }
162
163 fn visit_unit<E: de::Error>(self) -> Result<bool, E> {
164 Ok(false)
165 }
166 }
167
168 de.deserialize_any(NullableBoolVisitor)
169}
170
/// A post-search tag query, stored URL-encoded so it can be dropped straight into a request URL.
///
/// Build one with `Query::from` on a slice of tags.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Query {
    /// The tags joined with spaces, then URL-encoded.
    url_encoded_tags: String,
    /// Whether the query contains an `order:` tag. Ordered queries are paged by page number
    /// instead of by post ID.
    ordered: bool,
}
177
178impl<T> From<&[T]> for Query
179where
180 T: AsRef<str>,
181{
182 fn from(q: &[T]) -> Self {
183 let tags: Vec<&str> = q.iter().map(|t| t.as_ref()).collect();
184 let query_str = tags.join(" ");
185 let url_encoded_tags = urlencoding::encode(&query_str);
186 let ordered = tags.iter().any(|t| t.starts_with("order:"));
187
188 Query {
189 url_encoded_tags,
190 ordered,
191 }
192 }
193}
194
/// The position in a post search from which results are requested.
///
/// Sent as the `page` query parameter: a plain number for `Page`, or the post ID prefixed with
/// `b` (`BeforePost`) or `a` (`AfterPost`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchPage {
    /// A numbered results page; searches start at page 1.
    Page(u64),
    /// Results before the given post ID (`page=b<id>`).
    BeforePost(u64),
    /// Results after the given post ID (`page=a<id>`).
    AfterPost(u64),
}
201
/// Stream of the posts matching a search query, fetched one page at a time.
///
/// Created by `Client::post_search` and `Client::post_search_from_page`.
#[derive(Derivative)]
#[derivative(Debug)]
pub struct PostSearchStream<'a> {
    client: &'a Client,
    query: Query,

    // URL of the most recently issued request (included in the `Debug` output).
    query_url: Option<String>,

    // The in-flight request, if any. Skipped in `Debug` because the boxed future has no
    // `Debug` impl.
    #[derivative(Debug = "ignore")]
    query_future: Option<Pin<Box<dyn Future<Output = Rs621Result<serde_json::Value>> + Send>>>,

    // Page the next request will fetch.
    next_page: SearchPage,
    // Posts of the current page not yet yielded, stored in reverse so `pop()` returns them in
    // the order the API sent them.
    chunk: Vec<Rs621Result<Post>>,
    // Set after an empty page or a failed request; the stream then yields `None`.
    ended: bool,
}
218
219impl<'a> PostSearchStream<'a> {
220 fn new<T: Into<Query>>(client: &'a Client, query: T, page: SearchPage) -> Self {
221 PostSearchStream {
222 client: client,
223 query: query.into(),
224
225 query_url: None,
226 query_future: None,
227
228 next_page: page,
229 chunk: Vec::new(),
230 ended: false,
231 }
232 }
233}
234
impl<'a> Stream for PostSearchStream<'a> {
    type Item = Rs621Result<Post>;

    /// Polls for the next search result.
    ///
    /// Runs a small state machine in a loop:
    /// 1. Buffered posts in `chunk` are yielded first.
    /// 2. When `chunk` is empty, a request for `next_page` is issued.
    /// 3. When that request completes, `chunk` is refilled and the page after it is computed.
    ///
    /// A failed request is yielded once as an `Err` item and then ends the stream. An empty
    /// page also ends it.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Rs621Result<Post>>> {
        // Outcome of polling the in-flight request (if any) during one loop iteration.
        enum QueryPollRes {
            Pending,
            Err(crate::error::Error),
            NotFetching,
        }

        let this = self.get_mut();

        loop {
            let query_status = if let Some(ref mut fut) = this.query_future {
                match fut.as_mut().poll(cx) {
                    Poll::Ready(res) => {
                        this.query_future = None;

                        match res {
                            Ok(body) => {
                                // Buffer the page reversed, so that `pop()` yields posts in the
                                // order the API returned them. A body that fails to decode
                                // becomes a single `Err` item.
                                this.chunk =
                                    match serde_json::from_value::<PostListApiResponse>(body) {
                                        Ok(res) => res
                                            .posts
                                            .into_iter()
                                            .rev()
                                            .map(|post| Ok(post))
                                            .collect(),
                                        Err(e) => vec![Err(e.into())],
                                    };

                                // Because of the reversal, `first()` is the last post of the page
                                // as the API sent it. It falls back to 0 when the page is empty or
                                // failed to decode. In that case an unordered search continues
                                // from `BeforePost(0)`.
                                let last_id = match this.chunk.first() {
                                    Some(Ok(post)) => post.id,
                                    _ => 0,
                                };

                                this.next_page = if this.query.ordered {
                                    // Ordered queries are walked by page number.
                                    match this.next_page {
                                        SearchPage::Page(i) => SearchPage::Page(i + 1),
                                        _ => SearchPage::Page(1),
                                    }
                                } else {
                                    // Unordered queries continue relative to the last post seen.
                                    // NOTE(review): `AfterPost` uses the same `last_id` boundary
                                    // as `BeforePost`. Whether that is correct depends on the
                                    // order in which the API returns `a<id>` pages — confirm.
                                    match this.next_page {
                                        SearchPage::Page(_) => SearchPage::BeforePost(last_id),
                                        SearchPage::BeforePost(_) => {
                                            SearchPage::BeforePost(last_id)
                                        }
                                        SearchPage::AfterPost(_) => SearchPage::AfterPost(last_id),
                                    }
                                };

                                // An empty page means there are no further results.
                                this.ended = this.chunk.is_empty();
                                QueryPollRes::NotFetching
                            }

                            Err(e) => {
                                // Report the request error once, then end the stream.
                                this.ended = true;
                                QueryPollRes::Err(e)
                            }
                        }
                    }

                    Poll::Pending => QueryPollRes::Pending,
                }
            } else {
                QueryPollRes::NotFetching
            };

            match query_status {
                QueryPollRes::Err(e) => return Poll::Ready(Some(Err(e))),
                QueryPollRes::Pending => return Poll::Pending,
                QueryPollRes::NotFetching if this.ended => {
                    return Poll::Ready(None);
                }
                QueryPollRes::NotFetching if !this.chunk.is_empty() => {
                    let post = this.chunk.pop().unwrap();

                    return Poll::Ready(Some(post));
                }
                QueryPollRes::NotFetching => {
                    // Buffer drained and the stream not ended: request the next page. The next
                    // loop iteration polls the new future right away.
                    let url = format!(
                        "/posts.json?limit={}&page={}&tags={}",
                        ITER_CHUNK_SIZE,
                        match this.next_page {
                            SearchPage::Page(i) => format!("{}", i),
                            SearchPage::BeforePost(i) => format!("b{}", i),
                            SearchPage::AfterPost(i) => format!("a{}", i),
                        },
                        this.query.url_encoded_tags
                    );
                    this.query_url = Some(url);

                    this.query_future = Some(Box::pin(
                        this.client
                            .get_json_endpoint(this.query_url.as_ref().unwrap()),
                    ));
                }
            }
        }
    }
}
349
/// Stream of posts looked up by ID.
///
/// Created by `Client::get_posts`.
#[derive(Derivative)]
#[derivative(Debug)]
pub struct PostStream<'a, I, T>
where
    T: Borrow<u64> + Unpin,
    I: Iterator<Item = T> + Unpin,
{
    client: &'a Client,
    // IDs still to be looked up; drained a batch at a time.
    ids: I,

    // URL of the most recently issued request (included in the `Debug` output).
    query_url: Option<String>,

    // The in-flight request, if any. Skipped in `Debug` because the boxed future has no
    // `Debug` impl.
    #[derivative(Debug = "ignore")]
    query_future: Option<Pin<Box<dyn Future<Output = Rs621Result<serde_json::Value>> + Send>>>,

    // Posts from the last response not yet yielded, stored in reverse so `pop()` returns them
    // in the order the API sent them.
    chunk: Vec<Rs621Result<Post>>,
}
368
369impl<'a, I, T> PostStream<'a, I, T>
370where
371 T: Borrow<u64> + Unpin,
372 I: Iterator<Item = T> + Unpin,
373{
374 fn new(client: &'a Client, ids: I) -> Self {
375 PostStream {
376 client,
377 ids,
378 query_url: None,
379 query_future: None,
380 chunk: Vec::new(),
381 }
382 }
383}
384
385impl<'a, I, T> Stream for PostStream<'a, I, T>
386where
387 T: Borrow<u64> + Unpin,
388 I: Iterator<Item = T> + Unpin,
389{
390 type Item = Rs621Result<Post>;
391
392 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Rs621Result<Post>>> {
393 enum QueryPollRes {
394 Pending,
395 Err(crate::error::Error),
396 NotFetching,
397 }
398
399 let this = self.get_mut();
400
401 loop {
402 let query_status = if let Some(ref mut fut) = this.query_future {
404 match fut.as_mut().poll(cx) {
405 Poll::Ready(res) => {
406 this.query_future = None;
408
409 match res {
410 Ok(body) => {
411 this.chunk =
413 match serde_json::from_value::<PostListApiResponse>(body) {
414 Ok(res) => res
415 .posts
416 .into_iter()
417 .rev()
418 .map(|post| Ok(post))
419 .collect(),
420 Err(e) => vec![Err(e.into())],
421 };
422
423 QueryPollRes::NotFetching
424 }
425
426 Err(e) => QueryPollRes::Err(e),
428 }
429 }
430
431 Poll::Pending => QueryPollRes::Pending,
432 }
433 } else {
434 QueryPollRes::NotFetching
435 };
436
437 match query_status {
438 QueryPollRes::Err(e) => return Poll::Ready(Some(Err(e))),
439 QueryPollRes::Pending => return Poll::Pending,
440 QueryPollRes::NotFetching if !this.chunk.is_empty() => {
441 let post = this.chunk.pop().unwrap();
443
444 return Poll::Ready(Some(post));
446 }
447 QueryPollRes::NotFetching => {
448 let id_list = this.ids.by_ref().take(100).map(|x| *x.borrow()).join(",");
450
451 if id_list.is_empty() {
452 return Poll::Ready(None);
454 }
455
456 let url = format!("/posts.json?tags=id%3A{}", id_list);
457 this.query_url = Some(url);
458
459 this.query_future = Some(Box::pin(
461 this.client
462 .get_json_endpoint(this.query_url.as_ref().unwrap()),
463 ));
464 }
465 }
466 }
467 }
468}
469
470impl Client {
471 pub fn get_posts<'a, I, J, T>(&'a self, ids: I) -> PostStream<'a, J, T>
488 where
489 T: Borrow<u64> + Unpin,
490 J: Iterator<Item = T> + Unpin,
491 I: IntoIterator<Item = T, IntoIter = J> + Unpin,
492 {
493 PostStream::new(self, ids.into_iter())
494 }
495
496 pub fn post_search<'a, T: Into<Query>>(&'a self, tags: T) -> PostSearchStream<'a> {
514 self.post_search_from_page(tags, SearchPage::Page(1))
515 }
516
517 pub fn post_search_from_page<'a, T: Into<Query>>(
564 &'a self,
565 tags: T,
566 page: SearchPage,
567 ) -> PostSearchStream<'a> {
568 PostSearchStream::new(self, tags, page)
569 }
570}
571
// Stream tests against a local mockito server. Each test registers the exact request URL it
// expects and serves a recorded JSON fixture from `mocked/`. The `_m` bindings keep the mocks
// alive until the end of each test.
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::{mock, Matcher};

    // An ordered query (`order:score`) starts at `page=1`, and its posts are yielded in the
    // order the API returned them.
    #[tokio::test]
    async fn search_ordered() {
        let client = Client::new(&mockito::server_url(), b"rs621/unit_test").unwrap();

        let query = Query::from(&["fluffy", "rating:s", "order:score"][..]);

        let _m = mock(
            "GET",
            Matcher::Exact(format!(
                "/posts.json?limit={}&page=1&tags={}",
                ITER_CHUNK_SIZE, query.url_encoded_tags
            )),
        )
        .with_body(include_str!(
            "mocked/320_page-1_fluffy_rating-s_order-score.json"
        ))
        .create();

        assert_eq!(
            client
                .post_search(query)
                .take(100)
                .collect::<Vec<_>>()
                .await,
            serde_json::from_str::<PostListApiResponse>(include_str!(
                "mocked/320_page-1_fluffy_rating-s_order-score.json"
            ))
            .unwrap()
            .posts
            .into_iter()
            .take(100)
            .map(|x| Ok(x))
            .collect::<Vec<_>>()
        );
    }

    // Reading past the first page of an ordered query requests `page=2` next (numbered
    // paging), and the two pages are concatenated in order.
    #[tokio::test]
    async fn search_above_limit_ordered() {
        let client = Client::new(&mockito::server_url(), b"rs621/unit_test").unwrap();

        let query = Query::from(&["fluffy", "rating:s", "order:score"][..]);
        const PAGES: [&str; 2] = [
            include_str!("mocked/320_page-1_fluffy_rating-s_order-score.json"),
            include_str!("mocked/320_page-2_fluffy_rating-s_order-score.json"),
        ];

        let _m = [
            mock(
                "GET",
                Matcher::Exact(format!(
                    "/posts.json?limit={}&page=1&tags={}",
                    ITER_CHUNK_SIZE, query.url_encoded_tags
                )),
            )
            .with_body(PAGES[0])
            .create(),
            mock(
                "GET",
                Matcher::Exact(format!(
                    "/posts.json?limit={}&page=2&tags={}",
                    ITER_CHUNK_SIZE, query.url_encoded_tags
                )),
            )
            .with_body(PAGES[1])
            .create(),
        ];

        assert_eq!(
            client
                .post_search(query)
                .take(400)
                .collect::<Vec<_>>()
                .await,
            serde_json::from_str::<PostListApiResponse>(PAGES[0])
                .unwrap()
                .posts
                .into_iter()
                .chain(
                    serde_json::from_str::<PostListApiResponse>(PAGES[1])
                        .unwrap()
                        .posts
                        .into_iter()
                )
                .take(400)
                .map(|x| Ok(x))
                .collect::<Vec<_>>()
        );
    }

    // Starting from `SearchPage::BeforePost(id)` requests `page=b<id>`.
    #[tokio::test]
    async fn search_before_id() {
        let client = Client::new(&mockito::server_url(), b"rs621/unit_test").unwrap();

        let query = Query::from(&["fluffy", "rating:s"][..]);
        let response_json = include_str!("mocked/320_fluffy_rating-s_before-2269211.json");
        let response: PostListApiResponse = serde_json::from_str(response_json).unwrap();
        let expected: Vec<_> = response.posts.into_iter().take(80).map(|x| Ok(x)).collect();

        let _m = mock(
            "GET",
            Matcher::Exact(format!(
                "/posts.json?limit={}&page=b2269211&tags={}",
                ITER_CHUNK_SIZE, query.url_encoded_tags
            )),
        )
        .with_body(response_json)
        .create();

        assert_eq!(
            client
                .post_search_from_page(query, SearchPage::BeforePost(2269211))
                .take(80)
                .collect::<Vec<_>>()
                .await,
            expected
        );
    }

    // Reading past the first page of an unordered query continues with `page=b<id>`, where the
    // ID comes from the last post of the first page. The fixture presumably ends with post
    // 2269211.
    #[tokio::test]
    async fn search_above_limit() {
        let client = Client::new(&mockito::server_url(), b"rs621/unit_test").unwrap();

        let query = Query::from(&["fluffy", "rating:s"][..]);
        let responses_json: [&str; 2] = [
            include_str!("mocked/320_fluffy_rating-s.json"),
            include_str!("mocked/320_fluffy_rating-s_before-2269211.json"),
        ];
        let mut responses: [Option<PostListApiResponse>; 2] = [
            Some(serde_json::from_str(responses_json[0]).unwrap()),
            Some(serde_json::from_str(responses_json[1]).unwrap()),
        ];
        let expected: Vec<_> = responses[0]
            .take()
            .unwrap()
            .posts
            .into_iter()
            .chain(responses[1].take().unwrap().posts.into_iter())
            .take(400)
            .map(|x| Ok(x))
            .collect();

        let _m = [
            mock(
                "GET",
                Matcher::Exact(format!(
                    "/posts.json?limit={}&page=1&tags={}",
                    ITER_CHUNK_SIZE, query.url_encoded_tags
                )),
            )
            .with_body(responses_json[0])
            .create(),
            mock(
                "GET",
                Matcher::Exact(format!(
                    "/posts.json?limit={}&page=b2269211&tags={}",
                    ITER_CHUNK_SIZE, query.url_encoded_tags
                )),
            )
            .with_body(responses_json[1])
            .create(),
        ];

        assert_eq!(
            client
                .post_search(query)
                .take(400)
                .collect::<Vec<_>>()
                .await,
            expected
        );
    }

    // An empty first page ends the stream without yielding anything.
    #[tokio::test]
    async fn search_no_result() {
        let client = Client::new(&mockito::server_url(), b"rs621/unit_test").unwrap();

        let query = Query::from(&["fluffy", "rating:s"][..]);
        let response = "{\"posts\":[]}";

        let _m = mock(
            "GET",
            Matcher::Exact(format!(
                "/posts.json?limit={}&page=1&tags={}",
                ITER_CHUNK_SIZE, query.url_encoded_tags
            )),
        )
        .with_body(response)
        .create();

        assert_eq!(
            client.post_search(query).take(5).collect::<Vec<_>>().await,
            vec![]
        );
    }

    // The first few results of a plain unordered query come back in API order.
    #[tokio::test]
    async fn search_simple() {
        let client = Client::new(&mockito::server_url(), b"rs621/unit_test").unwrap();

        let query = Query::from(&["fluffy", "rating:s"][..]);
        let response_json = include_str!("mocked/320_fluffy_rating-s.json");
        let response: PostListApiResponse = serde_json::from_str(response_json).unwrap();
        let expected: Vec<_> = response.posts.into_iter().take(5).map(|x| Ok(x)).collect();

        let _m = mock(
            "GET",
            Matcher::Exact(format!(
                "/posts.json?limit={}&page=1&tags={}",
                ITER_CHUNK_SIZE, query.url_encoded_tags
            )),
        )
        .with_body(response_json)
        .create();

        assert_eq!(
            client.post_search(query).take(5).collect::<Vec<_>>().await,
            expected
        );
    }

    // Looking up a few IDs issues a single `id:` search with the IDs comma-separated and yields
    // the response's posts in order.
    #[tokio::test]
    async fn get_posts_by_id() {
        let client = Client::new(&mockito::server_url(), b"rs621/unit_test").unwrap();

        let response_json = include_str!("mocked/id_8595_535_2105_1470.json");
        let response: PostListApiResponse = serde_json::from_str(response_json).unwrap();
        let expected = response.posts;

        let _m = mock("GET", "/posts.json?tags=id%3A8595,535,2105,1470")
            .with_body(response_json)
            .create();

        assert_eq!(
            client
                .get_posts(&[8595, 535, 2105, 1470])
                .collect::<Vec<_>>()
                .await,
            expected.into_iter().map(|p| Ok(p)).collect::<Vec<_>>(),
        );
    }
}
817}