ornithology_cli/
api.rs

1use anyhow::Context;
2use futures_util::StreamExt;
3use indicatif::ProgressBar;
4use oauth2::basic::BasicClient;
5use oauth2::{AuthUrl, AuthorizationCode, CsrfToken, PkceCodeChallenge, Scope, TokenUrl};
6use serde::{Deserialize, Serialize};
7use tower::{Service, ServiceExt};
8
9mod oauth;
10
11/// A client that knows how to authenticate Twitter API requests.
12#[derive(Debug, Clone)]
13struct RawClient {
14    http: reqwest::Client,
15    oauth:
16        oauth2::StandardTokenResponse<oauth2::EmptyExtraTokenFields, oauth2::basic::BasicTokenType>,
17}
18
19impl RawClient {
20    async fn new(client_id: oauth2::ClientId) -> anyhow::Result<Self> {
21        // Stand up a localhost server to receive the OAuth redirect.
22        let (redirect, auth) = oauth::redirect_server()
23            .await
24            .context("start auth callback server")?;
25
26        // OAuth time!
27        let client = BasicClient::new(
28            client_id.clone(),
29            None,
30            AuthUrl::new("https://twitter.com/i/oauth2/authorize".to_string())?,
31            Some(TokenUrl::new(
32                "https://api.twitter.com/2/oauth2/token".to_string(),
33            )?),
34        )
35        .set_redirect_uri(redirect);
36        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
37        let (auth_url, csrf_token) = client
38            .authorize_url(|| CsrfToken::new_random())
39            .add_scope(Scope::new("tweet.read".to_string()))
40            .add_scope(Scope::new("users.read".to_string()))
41            .add_scope(Scope::new("follows.read".to_string()))
42            .set_pkce_challenge(pkce_challenge)
43            .url();
44
45        // Now the user needs to auth us, so open that in their browser. Once they click authorize
46        // (or not), the localhost webserver will catch the redirect and we'll have the
47        // authorization code that we can then exchange for a token.
48        open::that(auth_url.to_string()).context("forward to Twitter for authorization")?;
49        let authorization_code = match auth.await.context("oauth callback is called")? {
50            oauth::Redirect::Authorized { state, .. } if &*state != csrf_token.secret() => {
51                anyhow::bail!("bad csrf token")
52            }
53            oauth::Redirect::Authorized { code, .. } => code,
54            oauth::Redirect::Error(e) => anyhow::bail!(e),
55        };
56
57        // Exchange the one-time auth code for a longer-lived multi-use auth token.
58        // XXX: refresh token after 2h? request offline.access? spawn that?
59        // https://developer.twitter.com/en/docs/authentication/oauth-2-0/user-access-token
60        let http = reqwest::Client::new();
61        let token_response = client
62            .exchange_code(AuthorizationCode::new(authorization_code))
63            .set_pkce_verifier(pkce_verifier)
64            // Twitter's API requires we supply this.
65            .add_extra_param("client_id", client_id.as_str())
66            .request_async(|req| oauth::async_client_request(&http, req))
67            .await;
68        let token = match token_response {
69            Ok(token) => Ok(token),
70            Err(oauth2::RequestTokenError::ServerResponse(r)) => {
71                let e = Err(anyhow::anyhow!(r.error().clone()));
72                match (r.error_description(), r.error_uri()) {
73                    (Some(desc), Some(url)) => {
74                        e.context(url.to_string()).context(desc.to_string()).into()
75                    }
76                    (Some(desc), None) => e.context(desc.to_string()).into(),
77                    (None, Some(url)) => e.context(url.to_string()).into(),
78                    (None, None) => e,
79                }
80            }
81            Err(e) => Err(anyhow::anyhow!(e)),
82        }
83        .context("exchange oauth code")?;
84
85        Ok(Self { http, oauth: token })
86    }
87}
88
89impl tower::Service<reqwest::RequestBuilder> for RawClient {
90    type Response = reqwest::Response;
91    type Error = anyhow::Error;
92    type Future =
93        std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>>>>;
94
95    fn poll_ready(
96        &mut self,
97        _: &mut std::task::Context<'_>,
98    ) -> std::task::Poll<Result<(), Self::Error>> {
99        std::task::Poll::Ready(Ok(()))
100    }
101
102    fn call(&mut self, req: reqwest::RequestBuilder) -> Self::Future {
103        use oauth2::TokenResponse;
104        let fut = req.bearer_auth(self.oauth.access_token().secret()).send();
105        Box::pin(async move { Ok(fut.await.context("request")?) })
106    }
107}
108
109/// Retry policy that knows to look for Twitter's special HTTP reply + header.
110/// <https://developer.twitter.com/en/docs/twitter-api/rate-limits#headers-and-codes>
111#[derive(Copy, Clone)]
112struct TwitterRateLimitPolicy;
113impl tower::retry::Policy<reqwest::RequestBuilder, reqwest::Response, anyhow::Error>
114    for TwitterRateLimitPolicy
115{
116    type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Self>>>;
117
118    fn retry(
119        &self,
120        _: &reqwest::RequestBuilder,
121        result: Result<&reqwest::Response, &anyhow::Error>,
122    ) -> Option<Self::Future> {
123        let r = match result {
124            Err(_) => return None,
125            Ok(r) if r.status() == reqwest::StatusCode::TOO_MANY_REQUESTS => r,
126            Ok(_) => return None,
127        };
128
129        let reset = r
130            .headers()
131            .get("x-rate-limit-reset")
132            .expect("Twitter promised");
133        let reset: u64 = reset
134            .to_str()
135            .expect("x-rate-limit-reset as str")
136            .parse()
137            .expect("x-rate-limit-reset is a number");
138        let time = std::time::UNIX_EPOCH + std::time::Duration::from_secs(reset);
139
140        Some(Box::pin(async move {
141            match time.duration_since(std::time::SystemTime::now()) {
142                Ok(d) if d.as_secs() > 1 => {
143                    tokio::time::sleep(d).await;
144                }
145                _ => {
146                    // Not worth waiting -- can just retry immediately.
147                }
148            }
149            Self
150        }))
151    }
152
153    fn clone_request(&self, req: &reqwest::RequestBuilder) -> Option<reqwest::RequestBuilder> {
154        req.try_clone()
155    }
156}
157
158/// A Twitter API client that authenticates requests and respects rate limitations.
159///
160/// Note that this client does not try to proactively follow rate limits, since the limits depend
161/// on the endpoint, and this is generic over all endpoints. It's up to the caller (see
162/// page_at_a_time! for an instance of this) to wrap the inner service in an appropriate-limited
163/// `tower::limit::RateLimit` as needed for repeated requests.
164pub struct Client(tower::retry::Retry<TwitterRateLimitPolicy, RawClient>);
165
/// Fetches one `$t` per id in `$ids` from a batch-lookup endpoint, `$pagesize` ids per request,
/// with several requests in flight at once and a progress bar.
///
/// Arguments:
/// - `$this`: the `Client` (normally `self`). For the duration of this call, its inner service
///   is wrapped in a `tower::limit::RateLimit` configured with `$rate`.
/// - `$msg`: label shown next to the progress bar.
/// - `$ids`: a local whose `into_iter()` is an `ExactSizeIterator` (its `len()` sizes the bar and
///   the page count); items are formatted with `Display`.
/// - `$pagesize`: maximum number of ids per request (integer literal).
/// - `$url`: a `format!` string literal with one `{}` that receives the comma-separated ids.
/// - `$t`: element type of each page's `data` array.
///
/// Must be expanded inside `impl Client` (it calls `Self::parse`), in a function whose error
/// type accepts `anyhow::Error` (it uses `?`). It evaluates to `Ok(Vec<$t>)`.
///
/// Results are collected as pages complete (`FuturesUnordered`), so their order need not match
/// `$ids`. The macro does not check that every id produced a result. On the first error it
/// returns early; in-flight requests are dropped and the progress bar is left unfinished.
macro_rules! page_at_a_time {
    ($this:ident, $msg:expr, $ids:ident, $pagesize:literal, $rate:expr, $url:literal, $t:ty) => {{
        let mut ids = $ids.into_iter();
        let n = ids.len();
        let bar = ProgressBar::new(n as u64)
            .with_style(
                indicatif::ProgressStyle::default_bar()
                    .template("{msg:>15} {bar:40} {percent:>3}% [{elapsed}]"),
            )
            .with_message($msg);
        let mut all: Vec<$t> = Vec::with_capacity(n);
        // Rate-limit only for this batch; the limiter mutably borrows the client's service.
        let mut svc = tower::limit::RateLimit::new(&mut $this.0, $rate);
        let mut futs = futures_util::stream::FuturesUnordered::new();
        for page in 0.. {
            const PAGE_SIZE: usize = $pagesize;
            // Stop once every id has been put into some page.
            let i = page * PAGE_SIZE;
            if i >= n {
                break;
            }
            // Build the comma-separated id list for this page.
            let mut idsstr = String::new();
            for id in (&mut ids).take(PAGE_SIZE) {
                use std::fmt::Write;
                if idsstr.is_empty() {
                    write!(&mut idsstr, "{}", id)
                } else {
                    write!(&mut idsstr, ",{}", id)
                }
                .expect("this is fine");
            }
            let url = format!($url, idsstr);
            let req = svc.get_mut().get_mut().http.get(&url);
            loop {
                // The service may not be ready until we make progress on one of the in-flight
                // requests, so make sure to drive those forward too.
                tokio::select! {
                    ready = svc.ready() => {
                        let _ = ready.context("Service::poll_ready")?;
                        break;
                    }
                    chunk = futs.next(), if !futs.is_empty() => {
                        // This annotation also fixes the output type of the futures pushed below.
                        let chunk: anyhow::Result<Vec<$t>> = chunk.expect("!futs.is_empty()");
                        all.extend(chunk.context("grab next chunk")?);
                    }
                };
            }
            // Start the request now; parsing and progress accounting happen when it completes.
            let res = tower::Service::call(&mut svc, req);
            let bar = bar.clone();
            futs.push(async move {
                let data: Vec<$t> = Self::parse(
                    res.await
                        .with_context(|| format!("Service::call('{}')", url))?,
                )
                .await
                .with_context(|| format!("parse('{}')", url))?
                .0;
                bar.inc(data.len() as u64);
                Ok(data)
            });
        }

        // All pages are issued; drain whatever is still in flight.
        while let Some(chunk) = futs.next().await.transpose().context("grab chunks")? {
            all.extend(chunk);
        }
        bar.finish();

        Ok(all)
    }};
}
234
235impl Client {
236    pub async fn new(client_id: oauth2::ClientId) -> anyhow::Result<Self> {
237        RawClient::new(client_id)
238            .await
239            .map(|svc| tower::retry::Retry::new(TwitterRateLimitPolicy, svc))
240            .map(Self)
241    }
242
243    pub async fn whoami(&mut self) -> anyhow::Result<WhoAmI> {
244        let req = self
245            .0
246            .get_mut()
247            .http
248            .get("https://api.twitter.com/2/users/me");
249        let data: WhoAmI = Self::parse(
250            self.0
251                .ready()
252                .await
253                .context("Service::poll_ready")?
254                .call(req)
255                .await
256                .context("Service::call")?,
257        )
258        .await
259        .context("parse whoami")?
260        .0;
261        Ok(data)
262    }
263
264    pub async fn tweets<I>(&mut self, ids: I) -> anyhow::Result<Vec<Tweet>>
265    where
266        I: IntoIterator<Item = u64>,
267        I::IntoIter: ExactSizeIterator,
268    {
269        page_at_a_time!(
270            self,
271            "Fetch tweets",
272            ids,
273            100,
274            tower::limit::rate::Rate::new(900, std::time::Duration::from_secs(15 * 60)),
275            "https://api.twitter.com/2/tweets?tweet.fields=id,created_at,public_metrics&ids={}",
276            Tweet
277        )
278    }
279
280    pub async fn users<I>(&mut self, ids: I) -> anyhow::Result<Vec<User>>
281    where
282        I: IntoIterator<Item = u64>,
283        I::IntoIter: ExactSizeIterator,
284    {
285        page_at_a_time!(
286            self,
287            "Fetch followers",
288            ids,
289            100,
290            tower::limit::rate::Rate::new(900, std::time::Duration::from_secs(15 * 60)),
291            "https://api.twitter.com/2/users?user.fields=username,public_metrics&ids={}",
292            User
293        )
294    }
295
296    async fn parse<T>(res: reqwest::Response) -> anyhow::Result<(T, Option<Meta>)>
297    where
298        T: serde::de::DeserializeOwned,
299    {
300        // This is th general structure of all Twitter API responses.
301        #[derive(Debug, Deserialize)]
302        struct Data<T> {
303            data: T,
304            meta: Option<Meta>,
305        }
306
307        // We _could_ do:
308        // let data: Data<T> = res.json().await.context("parse")?;
309        // but that would make for unhelpful error messages if parsing fails, so we do:
310        let data = res.text().await.context("get body")?;
311        let data: Data<T> = serde_json::from_str(&data)
312            .with_context(|| data)
313            .context("parse")?;
314        Ok((data.data, data.meta))
315    }
316}
317
318/*
319pub async fn from_pages<T, TT, TR, F, FT, C>(
320    http: &reqwest::Client,
321    token: &TR,
322    url: impl Into<url::Url>,
323    mut map: F,
324) -> anyhow::Result<C>
325where
326    T: serde::de::DeserializeOwned,
327    TT: oauth2::TokenType,
328    TR: oauth2::TokenResponse<TT>,
329    F: FnMut(T, &Meta) -> FT,
330    C: Default + Extend<FT>,
331{
332    let mut all: C = Default::default();
333    let mut next = None::<String>;
334    let url = url.into();
335    loop {
336        let url = match next.as_ref() {
337            None => url.clone(),
338            Some(p) => {
339                let mut url = url.clone();
340                url.query_pairs_mut().append_pair("pagination_token", p);
341                url
342            }
343        };
344
345        let (page, meta): (Vec<T>, _) = grab(http, token, url.to_string())
346            .await
347            .context("followers")?;
348        let meta = meta.expect("always meta for this");
349
350        assert_eq!(page.len(), meta.results);
351        all.extend(page.into_iter().map(|t| (map)(t, &meta)));
352        if meta.next.is_none() {
353            break;
354        }
355        next = meta.next;
356    }
357    Ok(all)
358}
359*/
360
361#[derive(Debug, Deserialize)]
362pub struct WhoAmI {
363    pub id: String,
364    pub username: String,
365}
366
367#[derive(Debug, Serialize, Deserialize)]
368pub struct PublicTweetMetrics {
369    #[serde(rename = "retweet_count")]
370    pub retweets: usize,
371    #[serde(rename = "reply_count")]
372    pub replies: usize,
373    #[serde(rename = "like_count")]
374    pub likes: usize,
375    #[serde(rename = "quote_count")]
376    pub quotations: usize,
377}
378
379#[derive(Debug, Serialize, Deserialize)]
380pub struct Tweet {
381    #[serde(rename = "id", with = "u64_but_str")]
382    pub id: u64,
383    #[serde(rename = "created_at", with = "time::serde::rfc3339")]
384    pub created: time::OffsetDateTime,
385    #[serde(rename = "public_metrics")]
386    pub metrics: PublicTweetMetrics,
387    // not reading in text: String here
388    // would be great to read non_public_metrics, but those aren't available >30 days
389}
390
391impl Tweet {
392    pub fn goodness(&self) -> usize {
393        self.metrics.likes
394            + 2 * self.metrics.retweets
395            + 3 * self.metrics.quotations
396            + self.metrics.replies / 2
397    }
398}
399
400#[derive(Debug, Serialize, Deserialize)]
401pub struct PublicUserMetrics {
402    #[serde(rename = "followers_count")]
403    pub followers: usize,
404    #[serde(rename = "following_count")]
405    pub following: usize,
406}
407
408#[derive(Debug, Serialize, Deserialize)]
409pub struct User {
410    pub username: String,
411    #[serde(rename = "public_metrics")]
412    pub metrics: PublicUserMetrics,
413}
414
415// Keeping this around for if I ever need to add pagination.
416#[derive(Debug, Deserialize)]
417#[allow(dead_code)]
418struct Meta {
419    #[serde(rename = "result_count")]
420    results: usize,
421    #[serde(rename = "next_token")]
422    next: Option<String>,
423}
424
425mod u64_but_str {
426    use std::fmt::Display;
427
428    use serde::{de, Deserialize, Deserializer, Serializer};
429
430    pub fn serialize<T, S>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
431    where
432        T: Display,
433        S: Serializer,
434    {
435        serializer.collect_str(value)
436    }
437
438    pub fn deserialize<'de, D>(deserializer: D) -> Result<u64, D::Error>
439    where
440        D: Deserializer<'de>,
441    {
442        let s = String::deserialize(deserializer)?;
443        s.parse().map_err(de::Error::custom)
444    }
445}