notion_async_api/
api.rs

1use std::{
2    fmt::{Debug, Display},
3    sync::LazyLock,
4};
5
6use reqwest::{header, Client, Method, Response, StatusCode, Url};
7use serde::de::DeserializeOwned;
8
9use crate::{
10    block::Block,
11    comment::Comment,
12    database::Database,
13    error::NotionError,
14    fetcher::AnyObject,
15    object::{NextCursor, ObjectList},
16    page::Page,
17    user::User,
18};
19
20const NOTION_API_VERSION: &str = "2022-06-28";
21
22/// Low-level notion Api.
23#[derive(Clone)]
24pub struct Api {
25    client: Client,
26}
27
28#[derive(Debug)]
29pub enum RequestError {
30    InvalidRequest(String),
31    InvalidResponse(String),
32    RetryAfter(u64), // seconds
33    Other(reqwest::Error),
34}
35
36impl RequestError {
37    pub fn invalid_response(s: impl Into<String>) -> Self {
38        Self::InvalidResponse(s.into())
39    }
40}
41
42impl Display for RequestError {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        let s = match self {
45            RequestError::InvalidRequest(s) => format!("invalid request: {s}"),
46            RequestError::InvalidResponse(s) => format!("invalid response: {s}"),
47            RequestError::RetryAfter(s) => format!("retry after: {s}"),
48            RequestError::Other(e) => format!("request error: {e:?}"),
49        };
50        Display::fmt(&s, f)
51    }
52}
53
54impl std::error::Error for RequestError {}
55
56impl Api {
57    pub fn new(token: &str) -> Self {
58        let mut headers = header::HeaderMap::new();
59        headers.insert(
60            "Notion-Version",
61            header::HeaderValue::from_static(NOTION_API_VERSION),
62        );
63        let bearer = format!("Bearer {}", token);
64        let mut auth_value = header::HeaderValue::from_str(&bearer)
65            .expect("token: only visible ASCII characters (32-127) are permitted");
66        auth_value.set_sensitive(true);
67        headers.insert(header::AUTHORIZATION, auth_value);
68
69        Api {
70            client: Client::builder().default_headers(headers).build().unwrap(),
71        }
72    }
73
74    pub async fn get_object<T>(&self, id: &str) -> Result<T, NotionError>
75    where
76        T: DeserializeOwned + Requestable,
77    {
78        let res = self.client.get(T::url(id)).send().await?;
79        check_retry_after(&res)?;
80        let res = check_status_code(res).await?;
81
82        res.json::<T>().await.map_err(|e| {
83            NotionError::RequestFailed(RequestError::InvalidResponse(format!(
84                "decode failed: {e:?}, {}",
85                T::url(id),
86            )))
87        })
88    }
89
90    pub async fn list<T, P>(&self, pagination: &P) -> Result<PaginationResult<T>, NotionError>
91    where
92        T: DeserializeOwned,
93        P: Pagination<T> + NextCursor,
94    {
95        pagination.next_page(&self.client).await
96    }
97}
98
99fn check_retry_after(res: &Response) -> Result<(), NotionError> {
100    if res.status() == StatusCode::TOO_MANY_REQUESTS {
101        // extract Retry-After
102        let Some(retry_after) = res.headers().get(header::RETRY_AFTER) else {
103            return Err(NotionError::invalid_response(
104                "encounter rate limited error without Retry-After",
105            ));
106        };
107        let after: u64 = retry_after
108            .to_str()
109            .map_err(|_| NotionError::invalid_response("invalid Retry-After header"))
110            .and_then(|s| {
111                s.parse()
112                    .map_err(|_| NotionError::invalid_response("invalid Retry-After header"))
113            })?;
114        return Err(NotionError::retry_after(after));
115    };
116    Ok(())
117}
118
119async fn check_status_code(res: Response) -> Result<Response, NotionError> {
120    if !res.status().is_success() {
121        let url = res.url().clone();
122        Err(NotionError::invalid_response(format!(
123            "status: {}, body: {}, url: {url}",
124            res.status(),
125            res.text().await?,
126        )))
127    } else {
128        Ok(res)
129    }
130}
131
132pub trait Pagination<Item>: Debug {
133    fn next_page(
134        &self,
135        client: &Client,
136    ) -> impl std::future::Future<Output = Result<PaginationResult<Item>, NotionError>> + Send;
137}
138
139#[derive(Clone)]
140pub struct PaginationInfo {
141    cursor: Option<String>,
142    url: Url,
143    method: Method,
144    start_index: usize,
145}
146
147impl PaginationInfo {
148    pub fn new<R>(id: &str) -> Self
149    where
150        R: Requestable,
151    {
152        Self::build(R::url(id), R::method())
153    }
154
155    fn build(url: Url, method: Method) -> Self {
156        Self {
157            cursor: None,
158            url,
159            method,
160            start_index: 0,
161        }
162    }
163
164    fn cursor(mut self, cursor: String) -> Self {
165        self.cursor = Some(cursor);
166        self
167    }
168
169    fn start_index(mut self, index: usize) -> Self {
170        self.start_index = index;
171        self
172    }
173}
174
175impl NextCursor for PaginationInfo {
176    fn next_cursor(&self) -> Option<&str> {
177        self.cursor.as_deref()
178    }
179}
180
181impl Debug for PaginationInfo {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        f.debug_struct("PaginationInfo")
184            .field("params", &self.cursor)
185            .field("url", &self.url)
186            .finish()
187    }
188}
189
190impl<T> Pagination<T> for PaginationInfo
191where
192    T: DeserializeOwned + Send,
193{
194    async fn next_page(&self, client: &Client) -> Result<PaginationResult<T>, NotionError> {
195        let mut url = self.url.clone();
196
197        if let Some(ref next_cursor) = self.cursor {
198            // set start_cursor
199            let q = self.url.query_pairs().filter(|(k, _)| k != "start_cursor");
200            url.query_pairs_mut()
201                .clear()
202                .extend_pairs(q)
203                .append_pair("start_cursor", next_cursor)
204                .finish();
205        };
206
207        let res = client.request(self.method.clone(), url).send().await?;
208        check_retry_after(&res)?;
209        let res = check_status_code(res).await?;
210
211        let mut res: ObjectList<T> = res.json().await?;
212        res.start_index = self.start_index;
213        let next_page = res.next_cursor().map(|x| {
214            PaginationInfo::build(self.url.clone(), self.method.clone())
215                .cursor(x.to_owned())
216                .start_index(self.start_index + res.results.len())
217        });
218
219        Ok(PaginationResult::<T> {
220            result: res,
221            pagination: next_page,
222        })
223    }
224}
225
226#[derive(Clone, Debug)]
227pub struct PaginationResult<T> {
228    pub result: ObjectList<T>,
229    pub pagination: Option<PaginationInfo>,
230}
231
232pub trait Requestable {
233    fn url(id: &str) -> Url;
234    fn method() -> Method {
235        Method::GET
236    }
237}
238
239static BASE_URL: LazyLock<Url> =
240    LazyLock::new(|| Url::parse("https://api.notion.com/v1/").unwrap());
241
242impl Requestable for Block {
243    fn url(id: &str) -> Url {
244        BASE_URL.join(&format!("blocks/{id}")).unwrap()
245    }
246}
247
248impl Requestable for Page {
249    fn url(id: &str) -> Url {
250        BASE_URL.join(&format!("pages/{id}")).unwrap()
251    }
252}
253
254impl Requestable for Database {
255    fn url(id: &str) -> Url {
256        BASE_URL.join(&format!("databases/{id}")).unwrap()
257    }
258}
259
260impl Requestable for ObjectList<Block> {
261    fn url(id: &str) -> Url {
262        BASE_URL.join(&format!("blocks/{id}/children")).unwrap()
263    }
264}
265
266impl Requestable for ObjectList<AnyObject> {
267    fn url(id: &str) -> Url {
268        BASE_URL.join(&format!("databases/{id}/query")).unwrap()
269    }
270
271    fn method() -> Method {
272        Method::POST
273    }
274}
275
276impl Requestable for ObjectList<Comment> {
277    fn url(id: &str) -> Url {
278        let mut url = BASE_URL.join("comments").unwrap();
279        url.query_pairs_mut().append_pair("block_id", id).finish();
280        url
281    }
282}
283
284impl Requestable for User {
285    fn url(id: &str) -> Url {
286        BASE_URL.join(&format!("users/{id}")).unwrap()
287    }
288}
289
290impl Requestable for ObjectList<User> {
291    fn url(_: &str) -> Url {
292        BASE_URL.join("users").unwrap()
293    }
294}