1use std::time::Duration;
2
3use async_rate_limiter::RateLimiter;
4use futures::{
5 channel::mpsc::{channel, Sender},
6 future::BoxFuture,
7 FutureExt, SinkExt, Stream, StreamExt,
8};
9use serde::{Deserialize, Serialize};
10use tokio::spawn;
11
12use crate::{
13 api::{PaginationInfo, PaginationResult},
14 block::Block,
15 comment::Comment,
16 database::Database,
17 error::NotionError,
18 object::{Object, ObjectList, ObjectType},
19 page::Page,
20 user::User,
21 Api,
22};
23
24#[derive(Clone)]
25pub struct Fetcher {
26 api: Api,
27 rate_limiter: RateLimiter,
28}
29
30#[derive(Serialize, Deserialize, Debug, Clone)]
31pub enum AnyObject {
32 Block(Block),
33 Page(Page),
34 Database(Database),
35 User(User),
36 Comment(Comment),
37}
38
39impl Object for AnyObject {
40 fn id(&self) -> &str {
41 match self {
42 AnyObject::Block(x) => x.id(),
43 AnyObject::Page(x) => x.id(),
44 AnyObject::Database(x) => x.id(),
45 AnyObject::User(x) => x.id(),
46 AnyObject::Comment(x) => x.id(),
47 }
48 }
49
50 fn object_type(&self) -> crate::object::ObjectType {
51 match self {
52 AnyObject::Block(_) => ObjectType::Block,
53 AnyObject::Page(_) => ObjectType::Page,
54 AnyObject::Database(_) => ObjectType::Database,
55 AnyObject::User(_) => ObjectType::User,
56 AnyObject::Comment(_) => ObjectType::Comment,
57 }
58 }
59}
60
61#[derive(Debug, Clone)]
62struct Task {
63 req_type: ReqType,
64}
65
66#[derive(Clone, Debug)]
67enum ReqType {
68 Block(String),
69 Page(String),
70 Database(String),
71
72 BlockChildren(PaginationInfo),
73 DatabaseQuery(PaginationInfo),
74 Comments(PaginationInfo),
75}
76
77enum TaskOutput {
78 Block(Block),
79 Page(Page),
80 Database(Database),
81
82 BlockChildren(PaginationResult<Block>),
83 QueryDatabase(PaginationResult<AnyObject>),
84 Comments(PaginationResult<Comment>),
85}
86
87impl<E> TryFrom<Result<PaginationResult<Block>, E>> for TaskOutput {
88 type Error = E;
89 fn try_from(value: Result<PaginationResult<Block>, E>) -> Result<Self, Self::Error> {
90 match value {
91 Ok(x) => Ok(TaskOutput::BlockChildren(x)),
92 Err(e) => Err(e),
93 }
94 }
95}
96
97impl<E> TryFrom<Result<PaginationResult<AnyObject>, E>> for TaskOutput {
98 type Error = E;
99 fn try_from(value: Result<PaginationResult<AnyObject>, E>) -> Result<Self, Self::Error> {
100 match value {
101 Ok(x) => Ok(TaskOutput::QueryDatabase(x)),
102 Err(e) => Err(e),
103 }
104 }
105}
106
107impl<E> TryFrom<Result<Block, E>> for TaskOutput {
108 type Error = E;
109 fn try_from(value: Result<Block, E>) -> Result<Self, Self::Error> {
110 match value {
111 Ok(x) => Ok(TaskOutput::Block(x)),
112 Err(e) => Err(e),
113 }
114 }
115}
116
117impl Fetcher {
118 pub fn new(token: &str) -> Fetcher {
119 Fetcher {
120 api: Api::new(token),
121 rate_limiter: {
122 let rl = RateLimiter::new(3);
123 rl.burst(5);
124 rl
125 },
126 }
127 }
128
129 pub async fn fetch(&self, id: &str) -> impl Stream<Item = Result<AnyObject, NotionError>> {
130 let (res_tx, res_rx) = channel::<Result<AnyObject, NotionError>>(10);
131
132 let task = Task {
134 req_type: ReqType::Block(id.to_owned()),
135 };
136
137 let this = self.clone();
138 spawn(async move {
139 this.do_task_recurs(task, res_tx).await;
140 });
141
142 res_rx
143 }
144
145 fn do_task_recurs(
147 &self,
148 task: Task,
149 res_tx: Sender<Result<AnyObject, NotionError>>,
150 ) -> BoxFuture<'static, ()> {
151 let this = self.clone();
152 async move {
153 let (task_tx, mut task_rx) = channel(10);
154
155 {
156 let this = this.clone();
157 let res_tx = res_tx.clone();
158 spawn(async move {
159 while let Some(task) = task_rx.next().await {
160 this.do_task_recurs(task, res_tx.clone()).await;
161 }
162 });
163 }
164
165 this.do_task(task, res_tx.clone(), task_tx).await;
166 }
167 .boxed()
168 }
169
170 async fn do_task(
171 &self,
172 task: Task,
173 mut res_tx: Sender<Result<AnyObject, NotionError>>,
174 mut task_tx: Sender<Task>,
175 ) {
176 let res = self.do_request(task).await;
177 match res {
178 Ok(obj) => {
179 match obj {
180 TaskOutput::Page(page) => {
181 let task = Task {
183 req_type: ReqType::BlockChildren(PaginationInfo::new::<
184 ObjectList<Block>,
185 >(
186 page.id()
187 )),
188 };
189 task_tx.send(task).await.unwrap();
190
191 let task = Task {
193 req_type: ReqType::Comments(
194 PaginationInfo::new::<ObjectList<Comment>>(page.id()),
195 ),
196 };
197 task_tx.send(task).await.unwrap();
198
199 res_tx.send(Ok(AnyObject::Page(page))).await.unwrap();
200 }
201 TaskOutput::Database(database) => {
202 let task = Task {
203 req_type: ReqType::DatabaseQuery(PaginationInfo::new::<
204 ObjectList<Block>,
205 >(
206 database.id()
207 )),
208 };
209 task_tx.send(task).await.unwrap();
210 res_tx
211 .send(Ok(AnyObject::Database(database)))
212 .await
213 .unwrap();
214 }
215 TaskOutput::BlockChildren(result) => {
216 for (idx, mut block) in result.result.results.into_iter().enumerate() {
217 block.child_index = result.result.start_index + idx;
218 if let Some(task) = get_task_for_block(&block) {
219 task_tx.send(task).await.unwrap();
220 }
221 res_tx.send(Ok(AnyObject::Block(block))).await.unwrap();
222 }
223 if let Some(pagination) = result.pagination {
224 task_tx
225 .send(Task {
226 req_type: ReqType::BlockChildren(pagination),
227 })
228 .await
229 .unwrap();
230 }
231 }
232 TaskOutput::QueryDatabase(result) => {
233 for obj in result.result.results {
234 let task = match obj {
235 AnyObject::Database(_) => Task {
236 req_type: ReqType::DatabaseQuery(PaginationInfo::new::<
237 ObjectList<AnyObject>,
238 >(
239 obj.id()
240 )),
241 },
242 AnyObject::Page(_) => Task {
243 req_type: ReqType::BlockChildren(PaginationInfo::new::<
244 ObjectList<Block>,
245 >(
246 obj.id()
247 )),
248 },
249 AnyObject::Block(_) => unreachable!("shouldn't be a block"),
250 AnyObject::User(_) => unreachable!("shouldn't be a user"),
251 AnyObject::Comment(_) => unreachable!("shouldn't be a comment"),
252 };
253 task_tx.send(task).await.unwrap();
254 res_tx.send(Ok(obj)).await.unwrap();
255 }
256 if let Some(pagination) = result.pagination {
257 task_tx
258 .send(Task {
259 req_type: ReqType::DatabaseQuery(pagination),
260 })
261 .await
262 .unwrap();
263 }
264 }
265 TaskOutput::Block(block) => {
266 if let Some(task) = get_task_for_block(&block) {
267 task_tx.send(task).await.unwrap();
268 }
269 res_tx.send(Ok(AnyObject::Block(block))).await.unwrap();
270 }
271 TaskOutput::Comments(comments) => {
272 for obj in comments.result.results {
273 res_tx.send(Ok(AnyObject::Comment(obj))).await.unwrap();
274 }
275 if let Some(pagination) = comments.pagination {
276 task_tx
277 .send(Task {
278 req_type: ReqType::Comments(pagination),
279 })
280 .await
281 .unwrap();
282 }
283 }
284 };
285 }
286 Err(e) => res_tx.send(Err(e)).await.unwrap(),
287 }
288 }
289
290 async fn do_request(&self, task: Task) -> Result<TaskOutput, NotionError> {
291 loop {
294 self.rate_limiter.acquire().await;
295
296 let res = match task.req_type {
297 ReqType::Block(ref id) => self
298 .api
299 .get_object::<Block>(id)
300 .await
301 .map(TaskOutput::Block),
302 ReqType::Page(ref id) => {
303 self.api.get_object::<Page>(id).await.map(TaskOutput::Page)
304 }
305 ReqType::Database(ref id) => self
306 .api
307 .get_object::<Database>(id)
308 .await
309 .map(TaskOutput::Database),
310 ReqType::BlockChildren(ref pagination) => self
311 .api
312 .list(pagination)
313 .await
314 .map(TaskOutput::BlockChildren),
315 ReqType::DatabaseQuery(ref pagination) => self
316 .api
317 .list(pagination)
318 .await
319 .map(TaskOutput::QueryDatabase),
320 ReqType::Comments(ref pagination) => {
321 self.api.list(pagination).await.map(TaskOutput::Comments)
322 }
323 };
324
325 let Err(err) = &res else {
326 break res;
327 };
328
329 let crate::error::NotionError::RequestFailed(err) = err else {
330 break res;
331 };
332
333 let crate::api::RequestError::RetryAfter(secs) = err else {
334 break res;
335 };
336
337 tokio::time::sleep(Duration::from_secs(*secs)).await;
338 }
340 }
341}
342
343fn get_task_for_block(block: &Block) -> Option<Task> {
344 let block_type = &block.block_type;
345 let id = block.id().to_owned();
346 match block_type {
347 crate::block::BlockType::ChildPage => Some(Task {
348 req_type: ReqType::Page(id),
349 }),
350 crate::block::BlockType::ChildDatabase => Some(Task {
351 req_type: ReqType::Database(id),
352 }),
353 _ => {
354 if block.has_children {
355 Some(Task {
356 req_type: ReqType::BlockChildren(PaginationInfo::new::<ObjectList<Block>>(&id)),
357 })
358 } else {
359 None
360 }
361 }
362 }
363}