notion_async_api/
fetcher.rs

1use std::time::Duration;
2
3use async_rate_limiter::RateLimiter;
4use futures::{
5    channel::mpsc::{channel, Sender},
6    future::BoxFuture,
7    FutureExt, SinkExt, Stream, StreamExt,
8};
9use serde::{Deserialize, Serialize};
10use tokio::spawn;
11
12use crate::{
13    api::{PaginationInfo, PaginationResult},
14    block::Block,
15    comment::Comment,
16    database::Database,
17    error::NotionError,
18    object::{Object, ObjectList, ObjectType},
19    page::Page,
20    user::User,
21    Api,
22};
23
/// Recursively downloads a Notion object tree; see [`Fetcher::fetch`].
///
/// The fetcher is cloned into every spawned crawl task, so all requests are
/// meant to share one rate limiter — NOTE(review): that relies on
/// `RateLimiter::clone` sharing its state rather than copying it; confirm.
#[derive(Clone)]
pub struct Fetcher {
    /// Client used for every Notion API call.
    api: Api,
    /// Throttles requests; acquired once per request attempt.
    rate_limiter: RateLimiter,
}
29
/// Any object the fetcher can emit; the item type of the stream returned by
/// [`Fetcher::fetch`].
///
/// Uses serde's default (externally tagged) enum representation, so variant
/// names and order are part of the serialized format — do not reorder.
/// NOTE(review): database query results are decoded as
/// `PaginationResult<AnyObject>`; with this representation each item would be
/// expected in `{"Page": {...}}` form — confirm how `Api::list` decodes them.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum AnyObject {
    /// A block (fetched directly or listed as a child).
    Block(Block),
    /// A page.
    Page(Page),
    /// A database.
    Database(Database),
    /// A user.
    User(User),
    /// A comment attached to a page.
    Comment(Comment),
}
38
39impl Object for AnyObject {
40    fn id(&self) -> &str {
41        match self {
42            AnyObject::Block(x) => x.id(),
43            AnyObject::Page(x) => x.id(),
44            AnyObject::Database(x) => x.id(),
45            AnyObject::User(x) => x.id(),
46            AnyObject::Comment(x) => x.id(),
47        }
48    }
49
50    fn object_type(&self) -> crate::object::ObjectType {
51        match self {
52            AnyObject::Block(_) => ObjectType::Block,
53            AnyObject::Page(_) => ObjectType::Page,
54            AnyObject::Database(_) => ObjectType::Database,
55            AnyObject::User(_) => ObjectType::User,
56            AnyObject::Comment(_) => ObjectType::Comment,
57        }
58    }
59}
60
/// One unit of crawl work: a single API request to perform.
#[derive(Debug, Clone)]
struct Task {
    /// Which request to make, and for which object.
    req_type: ReqType,
}
65
/// The API request a [`Task`] performs.
#[derive(Clone, Debug)]
enum ReqType {
    /// Retrieve one block by id.
    Block(String),
    /// Retrieve one page by id.
    Page(String),
    /// Retrieve one database by id.
    Database(String),

    // Paginated requests: each task fetches one page of results, and a
    // follow-up task with the next `PaginationInfo` is queued when the
    // response reports more.
    /// List a block's (or page's) children.
    BlockChildren(PaginationInfo),
    /// Query a database's entries.
    DatabaseQuery(PaginationInfo),
    /// List the comments on a page.
    Comments(PaginationInfo),
}
76
/// The decoded response to a [`Task`]'s request; one variant per [`ReqType`]
/// (`QueryDatabase` answers `ReqType::DatabaseQuery`).
enum TaskOutput {
    /// Response to `ReqType::Block`.
    Block(Block),
    /// Response to `ReqType::Page`.
    Page(Page),
    /// Response to `ReqType::Database`.
    Database(Database),

    /// One page of a block's children, plus the next cursor if any.
    BlockChildren(PaginationResult<Block>),
    /// One page of a database query, plus the next cursor if any.
    QueryDatabase(PaginationResult<AnyObject>),
    /// One page of comments, plus the next cursor if any.
    Comments(PaginationResult<Comment>),
}
86
87impl<E> TryFrom<Result<PaginationResult<Block>, E>> for TaskOutput {
88    type Error = E;
89    fn try_from(value: Result<PaginationResult<Block>, E>) -> Result<Self, Self::Error> {
90        match value {
91            Ok(x) => Ok(TaskOutput::BlockChildren(x)),
92            Err(e) => Err(e),
93        }
94    }
95}
96
97impl<E> TryFrom<Result<PaginationResult<AnyObject>, E>> for TaskOutput {
98    type Error = E;
99    fn try_from(value: Result<PaginationResult<AnyObject>, E>) -> Result<Self, Self::Error> {
100        match value {
101            Ok(x) => Ok(TaskOutput::QueryDatabase(x)),
102            Err(e) => Err(e),
103        }
104    }
105}
106
107impl<E> TryFrom<Result<Block, E>> for TaskOutput {
108    type Error = E;
109    fn try_from(value: Result<Block, E>) -> Result<Self, Self::Error> {
110        match value {
111            Ok(x) => Ok(TaskOutput::Block(x)),
112            Err(e) => Err(e),
113        }
114    }
115}
116
117impl Fetcher {
118    pub fn new(token: &str) -> Fetcher {
119        Fetcher {
120            api: Api::new(token),
121            rate_limiter: {
122                let rl = RateLimiter::new(3);
123                rl.burst(5);
124                rl
125            },
126        }
127    }
128
129    pub async fn fetch(&self, id: &str) -> impl Stream<Item = Result<AnyObject, NotionError>> {
130        let (res_tx, res_rx) = channel::<Result<AnyObject, NotionError>>(10);
131
132        // Initial task
133        let task = Task {
134            req_type: ReqType::Block(id.to_owned()),
135        };
136
137        let this = self.clone();
138        spawn(async move {
139            this.do_task_recurs(task, res_tx).await;
140        });
141
142        res_rx
143    }
144
145    // Recursive async fn need to be boxed in BoxFuture
146    fn do_task_recurs(
147        &self,
148        task: Task,
149        res_tx: Sender<Result<AnyObject, NotionError>>,
150    ) -> BoxFuture<'static, ()> {
151        let this = self.clone();
152        async move {
153            let (task_tx, mut task_rx) = channel(10);
154
155            {
156                let this = this.clone();
157                let res_tx = res_tx.clone();
158                spawn(async move {
159                    while let Some(task) = task_rx.next().await {
160                        this.do_task_recurs(task, res_tx.clone()).await;
161                    }
162                });
163            }
164
165            this.do_task(task, res_tx.clone(), task_tx).await;
166        }
167        .boxed()
168    }
169
170    async fn do_task(
171        &self,
172        task: Task,
173        mut res_tx: Sender<Result<AnyObject, NotionError>>,
174        mut task_tx: Sender<Task>,
175    ) {
176        let res = self.do_request(task).await;
177        match res {
178            Ok(obj) => {
179                match obj {
180                    TaskOutput::Page(page) => {
181                        // get children
182                        let task = Task {
183                            req_type: ReqType::BlockChildren(PaginationInfo::new::<
184                                ObjectList<Block>,
185                            >(
186                                page.id()
187                            )),
188                        };
189                        task_tx.send(task).await.unwrap();
190
191                        // get comments
192                        let task = Task {
193                            req_type: ReqType::Comments(
194                                PaginationInfo::new::<ObjectList<Comment>>(page.id()),
195                            ),
196                        };
197                        task_tx.send(task).await.unwrap();
198
199                        res_tx.send(Ok(AnyObject::Page(page))).await.unwrap();
200                    }
201                    TaskOutput::Database(database) => {
202                        let task = Task {
203                            req_type: ReqType::DatabaseQuery(PaginationInfo::new::<
204                                ObjectList<Block>,
205                            >(
206                                database.id()
207                            )),
208                        };
209                        task_tx.send(task).await.unwrap();
210                        res_tx
211                            .send(Ok(AnyObject::Database(database)))
212                            .await
213                            .unwrap();
214                    }
215                    TaskOutput::BlockChildren(result) => {
216                        for (idx, mut block) in result.result.results.into_iter().enumerate() {
217                            block.child_index = result.result.start_index + idx;
218                            if let Some(task) = get_task_for_block(&block) {
219                                task_tx.send(task).await.unwrap();
220                            }
221                            res_tx.send(Ok(AnyObject::Block(block))).await.unwrap();
222                        }
223                        if let Some(pagination) = result.pagination {
224                            task_tx
225                                .send(Task {
226                                    req_type: ReqType::BlockChildren(pagination),
227                                })
228                                .await
229                                .unwrap();
230                        }
231                    }
232                    TaskOutput::QueryDatabase(result) => {
233                        for obj in result.result.results {
234                            let task = match obj {
235                                AnyObject::Database(_) => Task {
236                                    req_type: ReqType::DatabaseQuery(PaginationInfo::new::<
237                                        ObjectList<AnyObject>,
238                                    >(
239                                        obj.id()
240                                    )),
241                                },
242                                AnyObject::Page(_) => Task {
243                                    req_type: ReqType::BlockChildren(PaginationInfo::new::<
244                                        ObjectList<Block>,
245                                    >(
246                                        obj.id()
247                                    )),
248                                },
249                                AnyObject::Block(_) => unreachable!("shouldn't be a block"),
250                                AnyObject::User(_) => unreachable!("shouldn't be a user"),
251                                AnyObject::Comment(_) => unreachable!("shouldn't be a comment"),
252                            };
253                            task_tx.send(task).await.unwrap();
254                            res_tx.send(Ok(obj)).await.unwrap();
255                        }
256                        if let Some(pagination) = result.pagination {
257                            task_tx
258                                .send(Task {
259                                    req_type: ReqType::DatabaseQuery(pagination),
260                                })
261                                .await
262                                .unwrap();
263                        }
264                    }
265                    TaskOutput::Block(block) => {
266                        if let Some(task) = get_task_for_block(&block) {
267                            task_tx.send(task).await.unwrap();
268                        }
269                        res_tx.send(Ok(AnyObject::Block(block))).await.unwrap();
270                    }
271                    TaskOutput::Comments(comments) => {
272                        for obj in comments.result.results {
273                            res_tx.send(Ok(AnyObject::Comment(obj))).await.unwrap();
274                        }
275                        if let Some(pagination) = comments.pagination {
276                            task_tx
277                                .send(Task {
278                                    req_type: ReqType::Comments(pagination),
279                                })
280                                .await
281                                .unwrap();
282                        }
283                    }
284                };
285            }
286            Err(e) => res_tx.send(Err(e)).await.unwrap(),
287        }
288    }
289
290    async fn do_request(&self, task: Task) -> Result<TaskOutput, NotionError> {
291        // Repeatly send request if there is a RetryAfter error, otherwise send
292        // the result to the channel.
293        loop {
294            self.rate_limiter.acquire().await;
295
296            let res = match task.req_type {
297                ReqType::Block(ref id) => self
298                    .api
299                    .get_object::<Block>(id)
300                    .await
301                    .map(TaskOutput::Block),
302                ReqType::Page(ref id) => {
303                    self.api.get_object::<Page>(id).await.map(TaskOutput::Page)
304                }
305                ReqType::Database(ref id) => self
306                    .api
307                    .get_object::<Database>(id)
308                    .await
309                    .map(TaskOutput::Database),
310                ReqType::BlockChildren(ref pagination) => self
311                    .api
312                    .list(pagination)
313                    .await
314                    .map(TaskOutput::BlockChildren),
315                ReqType::DatabaseQuery(ref pagination) => self
316                    .api
317                    .list(pagination)
318                    .await
319                    .map(TaskOutput::QueryDatabase),
320                ReqType::Comments(ref pagination) => {
321                    self.api.list(pagination).await.map(TaskOutput::Comments)
322                }
323            };
324
325            let Err(err) = &res else {
326                break res;
327            };
328
329            let crate::error::NotionError::RequestFailed(err) = err else {
330                break res;
331            };
332
333            let crate::api::RequestError::RetryAfter(secs) = err else {
334                break res;
335            };
336
337            tokio::time::sleep(Duration::from_secs(*secs)).await;
338            // should we reset the rate_limiter here?
339        }
340    }
341}
342
343fn get_task_for_block(block: &Block) -> Option<Task> {
344    let block_type = &block.block_type;
345    let id = block.id().to_owned();
346    match block_type {
347        crate::block::BlockType::ChildPage => Some(Task {
348            req_type: ReqType::Page(id),
349        }),
350        crate::block::BlockType::ChildDatabase => Some(Task {
351            req_type: ReqType::Database(id),
352        }),
353        _ => {
354            if block.has_children {
355                Some(Task {
356                    req_type: ReqType::BlockChildren(PaginationInfo::new::<ObjectList<Block>>(&id)),
357                })
358            } else {
359                None
360            }
361        }
362    }
363}