1use std::{
2 fmt::{Debug, Display},
3 sync::LazyLock,
4};
5
6use reqwest::{header, Client, Method, Response, StatusCode, Url};
7use serde::de::DeserializeOwned;
8
9use crate::{
10 block::Block,
11 comment::Comment,
12 database::Database,
13 error::NotionError,
14 fetcher::AnyObject,
15 object::{NextCursor, ObjectList},
16 page::Page,
17 user::User,
18};
19
20const NOTION_API_VERSION: &str = "2022-06-28";
21
22#[derive(Clone)]
24pub struct Api {
25 client: Client,
26}
27
28#[derive(Debug)]
29pub enum RequestError {
30 InvalidRequest(String),
31 InvalidResponse(String),
32 RetryAfter(u64), Other(reqwest::Error),
34}
35
36impl RequestError {
37 pub fn invalid_response(s: impl Into<String>) -> Self {
38 Self::InvalidResponse(s.into())
39 }
40}
41
42impl Display for RequestError {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 let s = match self {
45 RequestError::InvalidRequest(s) => format!("invalid request: {s}"),
46 RequestError::InvalidResponse(s) => format!("invalid response: {s}"),
47 RequestError::RetryAfter(s) => format!("retry after: {s}"),
48 RequestError::Other(e) => format!("request error: {e:?}"),
49 };
50 Display::fmt(&s, f)
51 }
52}
53
54impl std::error::Error for RequestError {}
55
56impl Api {
57 pub fn new(token: &str) -> Self {
58 let mut headers = header::HeaderMap::new();
59 headers.insert(
60 "Notion-Version",
61 header::HeaderValue::from_static(NOTION_API_VERSION),
62 );
63 let bearer = format!("Bearer {}", token);
64 let mut auth_value = header::HeaderValue::from_str(&bearer)
65 .expect("token: only visible ASCII characters (32-127) are permitted");
66 auth_value.set_sensitive(true);
67 headers.insert(header::AUTHORIZATION, auth_value);
68
69 Api {
70 client: Client::builder().default_headers(headers).build().unwrap(),
71 }
72 }
73
74 pub async fn get_object<T>(&self, id: &str) -> Result<T, NotionError>
75 where
76 T: DeserializeOwned + Requestable,
77 {
78 let res = self.client.get(T::url(id)).send().await?;
79 check_retry_after(&res)?;
80 let res = check_status_code(res).await?;
81
82 res.json::<T>().await.map_err(|e| {
83 NotionError::RequestFailed(RequestError::InvalidResponse(format!(
84 "decode failed: {e:?}, {}",
85 T::url(id),
86 )))
87 })
88 }
89
90 pub async fn list<T, P>(&self, pagination: &P) -> Result<PaginationResult<T>, NotionError>
91 where
92 T: DeserializeOwned,
93 P: Pagination<T> + NextCursor,
94 {
95 pagination.next_page(&self.client).await
96 }
97}
98
99fn check_retry_after(res: &Response) -> Result<(), NotionError> {
100 if res.status() == StatusCode::TOO_MANY_REQUESTS {
101 let Some(retry_after) = res.headers().get(header::RETRY_AFTER) else {
103 return Err(NotionError::invalid_response(
104 "encounter rate limited error without Retry-After",
105 ));
106 };
107 let after: u64 = retry_after
108 .to_str()
109 .map_err(|_| NotionError::invalid_response("invalid Retry-After header"))
110 .and_then(|s| {
111 s.parse()
112 .map_err(|_| NotionError::invalid_response("invalid Retry-After header"))
113 })?;
114 return Err(NotionError::retry_after(after));
115 };
116 Ok(())
117}
118
119async fn check_status_code(res: Response) -> Result<Response, NotionError> {
120 if !res.status().is_success() {
121 let url = res.url().clone();
122 Err(NotionError::invalid_response(format!(
123 "status: {}, body: {}, url: {url}",
124 res.status(),
125 res.text().await?,
126 )))
127 } else {
128 Ok(res)
129 }
130}
131
132pub trait Pagination<Item>: Debug {
133 fn next_page(
134 &self,
135 client: &Client,
136 ) -> impl std::future::Future<Output = Result<PaginationResult<Item>, NotionError>> + Send;
137}
138
139#[derive(Clone)]
140pub struct PaginationInfo {
141 cursor: Option<String>,
142 url: Url,
143 method: Method,
144 start_index: usize,
145}
146
147impl PaginationInfo {
148 pub fn new<R>(id: &str) -> Self
149 where
150 R: Requestable,
151 {
152 Self::build(R::url(id), R::method())
153 }
154
155 fn build(url: Url, method: Method) -> Self {
156 Self {
157 cursor: None,
158 url,
159 method,
160 start_index: 0,
161 }
162 }
163
164 fn cursor(mut self, cursor: String) -> Self {
165 self.cursor = Some(cursor);
166 self
167 }
168
169 fn start_index(mut self, index: usize) -> Self {
170 self.start_index = index;
171 self
172 }
173}
174
175impl NextCursor for PaginationInfo {
176 fn next_cursor(&self) -> Option<&str> {
177 self.cursor.as_deref()
178 }
179}
180
181impl Debug for PaginationInfo {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.debug_struct("PaginationInfo")
184 .field("params", &self.cursor)
185 .field("url", &self.url)
186 .finish()
187 }
188}
189
190impl<T> Pagination<T> for PaginationInfo
191where
192 T: DeserializeOwned + Send,
193{
194 async fn next_page(&self, client: &Client) -> Result<PaginationResult<T>, NotionError> {
195 let mut url = self.url.clone();
196
197 if let Some(ref next_cursor) = self.cursor {
198 let q = self.url.query_pairs().filter(|(k, _)| k != "start_cursor");
200 url.query_pairs_mut()
201 .clear()
202 .extend_pairs(q)
203 .append_pair("start_cursor", next_cursor)
204 .finish();
205 };
206
207 let res = client.request(self.method.clone(), url).send().await?;
208 check_retry_after(&res)?;
209 let res = check_status_code(res).await?;
210
211 let mut res: ObjectList<T> = res.json().await?;
212 res.start_index = self.start_index;
213 let next_page = res.next_cursor().map(|x| {
214 PaginationInfo::build(self.url.clone(), self.method.clone())
215 .cursor(x.to_owned())
216 .start_index(self.start_index + res.results.len())
217 });
218
219 Ok(PaginationResult::<T> {
220 result: res,
221 pagination: next_page,
222 })
223 }
224}
225
226#[derive(Clone, Debug)]
227pub struct PaginationResult<T> {
228 pub result: ObjectList<T>,
229 pub pagination: Option<PaginationInfo>,
230}
231
232pub trait Requestable {
233 fn url(id: &str) -> Url;
234 fn method() -> Method {
235 Method::GET
236 }
237}
238
239static BASE_URL: LazyLock<Url> =
240 LazyLock::new(|| Url::parse("https://api.notion.com/v1/").unwrap());
241
242impl Requestable for Block {
243 fn url(id: &str) -> Url {
244 BASE_URL.join(&format!("blocks/{id}")).unwrap()
245 }
246}
247
248impl Requestable for Page {
249 fn url(id: &str) -> Url {
250 BASE_URL.join(&format!("pages/{id}")).unwrap()
251 }
252}
253
254impl Requestable for Database {
255 fn url(id: &str) -> Url {
256 BASE_URL.join(&format!("databases/{id}")).unwrap()
257 }
258}
259
260impl Requestable for ObjectList<Block> {
261 fn url(id: &str) -> Url {
262 BASE_URL.join(&format!("blocks/{id}/children")).unwrap()
263 }
264}
265
266impl Requestable for ObjectList<AnyObject> {
267 fn url(id: &str) -> Url {
268 BASE_URL.join(&format!("databases/{id}/query")).unwrap()
269 }
270
271 fn method() -> Method {
272 Method::POST
273 }
274}
275
276impl Requestable for ObjectList<Comment> {
277 fn url(id: &str) -> Url {
278 let mut url = BASE_URL.join("comments").unwrap();
279 url.query_pairs_mut().append_pair("block_id", id).finish();
280 url
281 }
282}
283
284impl Requestable for User {
285 fn url(id: &str) -> Url {
286 BASE_URL.join(&format!("users/{id}")).unwrap()
287 }
288}
289
290impl Requestable for ObjectList<User> {
291 fn url(_: &str) -> Url {
292 BASE_URL.join("users").unwrap()
293 }
294}