1use anyhow::Context;
2use futures_util::StreamExt;
3use indicatif::ProgressBar;
4use oauth2::basic::BasicClient;
5use oauth2::{AuthUrl, AuthorizationCode, CsrfToken, PkceCodeChallenge, Scope, TokenUrl};
6use serde::{Deserialize, Serialize};
7use tower::{Service, ServiceExt};
8
9mod oauth;
10
11#[derive(Debug, Clone)]
13struct RawClient {
14 http: reqwest::Client,
15 oauth:
16 oauth2::StandardTokenResponse<oauth2::EmptyExtraTokenFields, oauth2::basic::BasicTokenType>,
17}
18
19impl RawClient {
20 async fn new(client_id: oauth2::ClientId) -> anyhow::Result<Self> {
21 let (redirect, auth) = oauth::redirect_server()
23 .await
24 .context("start auth callback server")?;
25
26 let client = BasicClient::new(
28 client_id.clone(),
29 None,
30 AuthUrl::new("https://twitter.com/i/oauth2/authorize".to_string())?,
31 Some(TokenUrl::new(
32 "https://api.twitter.com/2/oauth2/token".to_string(),
33 )?),
34 )
35 .set_redirect_uri(redirect);
36 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
37 let (auth_url, csrf_token) = client
38 .authorize_url(|| CsrfToken::new_random())
39 .add_scope(Scope::new("tweet.read".to_string()))
40 .add_scope(Scope::new("users.read".to_string()))
41 .add_scope(Scope::new("follows.read".to_string()))
42 .set_pkce_challenge(pkce_challenge)
43 .url();
44
45 open::that(auth_url.to_string()).context("forward to Twitter for authorization")?;
49 let authorization_code = match auth.await.context("oauth callback is called")? {
50 oauth::Redirect::Authorized { state, .. } if &*state != csrf_token.secret() => {
51 anyhow::bail!("bad csrf token")
52 }
53 oauth::Redirect::Authorized { code, .. } => code,
54 oauth::Redirect::Error(e) => anyhow::bail!(e),
55 };
56
57 let http = reqwest::Client::new();
61 let token_response = client
62 .exchange_code(AuthorizationCode::new(authorization_code))
63 .set_pkce_verifier(pkce_verifier)
64 .add_extra_param("client_id", client_id.as_str())
66 .request_async(|req| oauth::async_client_request(&http, req))
67 .await;
68 let token = match token_response {
69 Ok(token) => Ok(token),
70 Err(oauth2::RequestTokenError::ServerResponse(r)) => {
71 let e = Err(anyhow::anyhow!(r.error().clone()));
72 match (r.error_description(), r.error_uri()) {
73 (Some(desc), Some(url)) => {
74 e.context(url.to_string()).context(desc.to_string()).into()
75 }
76 (Some(desc), None) => e.context(desc.to_string()).into(),
77 (None, Some(url)) => e.context(url.to_string()).into(),
78 (None, None) => e,
79 }
80 }
81 Err(e) => Err(anyhow::anyhow!(e)),
82 }
83 .context("exchange oauth code")?;
84
85 Ok(Self { http, oauth: token })
86 }
87}
88
89impl tower::Service<reqwest::RequestBuilder> for RawClient {
90 type Response = reqwest::Response;
91 type Error = anyhow::Error;
92 type Future =
93 std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>>>>;
94
95 fn poll_ready(
96 &mut self,
97 _: &mut std::task::Context<'_>,
98 ) -> std::task::Poll<Result<(), Self::Error>> {
99 std::task::Poll::Ready(Ok(()))
100 }
101
102 fn call(&mut self, req: reqwest::RequestBuilder) -> Self::Future {
103 use oauth2::TokenResponse;
104 let fut = req.bearer_auth(self.oauth.access_token().secret()).send();
105 Box::pin(async move { Ok(fut.await.context("request")?) })
106 }
107}
108
109#[derive(Copy, Clone)]
112struct TwitterRateLimitPolicy;
113impl tower::retry::Policy<reqwest::RequestBuilder, reqwest::Response, anyhow::Error>
114 for TwitterRateLimitPolicy
115{
116 type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Self>>>;
117
118 fn retry(
119 &self,
120 _: &reqwest::RequestBuilder,
121 result: Result<&reqwest::Response, &anyhow::Error>,
122 ) -> Option<Self::Future> {
123 let r = match result {
124 Err(_) => return None,
125 Ok(r) if r.status() == reqwest::StatusCode::TOO_MANY_REQUESTS => r,
126 Ok(_) => return None,
127 };
128
129 let reset = r
130 .headers()
131 .get("x-rate-limit-reset")
132 .expect("Twitter promised");
133 let reset: u64 = reset
134 .to_str()
135 .expect("x-rate-limit-reset as str")
136 .parse()
137 .expect("x-rate-limit-reset is a number");
138 let time = std::time::UNIX_EPOCH + std::time::Duration::from_secs(reset);
139
140 Some(Box::pin(async move {
141 match time.duration_since(std::time::SystemTime::now()) {
142 Ok(d) if d.as_secs() > 1 => {
143 tokio::time::sleep(d).await;
144 }
145 _ => {
146 }
148 }
149 Self
150 }))
151 }
152
153 fn clone_request(&self, req: &reqwest::RequestBuilder) -> Option<reqwest::RequestBuilder> {
154 req.try_clone()
155 }
156}
157
158pub struct Client(tower::retry::Retry<TwitterRateLimitPolicy, RawClient>);
165
/// Fetches items for a list of ids in pages, issuing page requests through a
/// rate-limited view of the client and collecting every page's items.
///
/// Arguments, in order:
/// - `$this`: the `Client` (its `.0` service is borrowed mutably),
/// - `$msg`: progress-bar message,
/// - `$ids`: identifier of a value whose `into_iter()` yields the ids and
///   whose iterator reports an exact `len()`,
/// - `$pagesize`: number of ids per request,
/// - `$rate`: `tower::limit::rate::Rate` applied to this call's requests,
/// - `$url`: format-string literal with one `{}` for the comma-joined ids,
/// - `$t`: item type each page deserializes into.
///
/// Expands to a block evaluating to `anyhow::Result<Vec<$t>>`; it uses `?`
/// and `Self::parse`, so it must be used inside an `impl Client` method
/// returning `anyhow::Result`. Pages are appended in completion order.
macro_rules! page_at_a_time {
    ($this:ident, $msg:expr, $ids:ident, $pagesize:literal, $rate:expr, $url:literal, $t:ty) => {{
        let mut ids = $ids.into_iter();
        let n = ids.len();
        let style = indicatif::ProgressStyle::default_bar()
            .template("{msg:>15} {bar:40} {percent:>3}% [{elapsed}]");
        let bar = ProgressBar::new(n as u64)
            .with_style(style)
            .with_message($msg);
        let mut all: Vec<$t> = Vec::with_capacity(n);
        let mut svc = tower::limit::RateLimit::new(&mut $this.0, $rate);
        let mut futs = futures_util::stream::FuturesUnordered::new();

        // `offset` is the index of the first id of the page being requested.
        let mut offset = 0;
        while offset < n {
            const PAGE_SIZE: usize = $pagesize;

            // Comma-separated list of this page's ids.
            let idsstr = (&mut ids)
                .take(PAGE_SIZE)
                .map(|id| id.to_string())
                .collect::<Vec<_>>()
                .join(",");
            let url = format!($url, idsstr);
            let req = svc.get_mut().get_mut().http.get(&url);

            // Wait until the rate limiter admits another request; meanwhile
            // keep harvesting pages that have already completed.
            loop {
                tokio::select! {
                    ready = svc.ready() => {
                        let _ = ready.context("Service::poll_ready")?;
                        break;
                    }
                    chunk = futs.next(), if !futs.is_empty() => {
                        let chunk: anyhow::Result<Vec<$t>> = chunk.expect("!futs.is_empty()");
                        all.extend(chunk.context("grab next chunk")?);
                    }
                };
            }

            let pending = tower::Service::call(&mut svc, req);
            let bar = bar.clone();
            futs.push(async move {
                let response = pending
                    .await
                    .with_context(|| format!("Service::call('{}')", url))?;
                let (data, _meta): (Vec<$t>, _) = Self::parse(response)
                    .await
                    .with_context(|| format!("parse('{}')", url))?;
                bar.inc(data.len() as u64);
                Ok(data)
            });

            offset += PAGE_SIZE;
        }

        // Every page has been requested; drain the ones still in flight.
        while let Some(chunk) = futs.next().await {
            all.extend(chunk.context("grab chunks")?);
        }
        bar.finish();

        Ok(all)
    }};
}
234
235impl Client {
236 pub async fn new(client_id: oauth2::ClientId) -> anyhow::Result<Self> {
237 RawClient::new(client_id)
238 .await
239 .map(|svc| tower::retry::Retry::new(TwitterRateLimitPolicy, svc))
240 .map(Self)
241 }
242
243 pub async fn whoami(&mut self) -> anyhow::Result<WhoAmI> {
244 let req = self
245 .0
246 .get_mut()
247 .http
248 .get("https://api.twitter.com/2/users/me");
249 let data: WhoAmI = Self::parse(
250 self.0
251 .ready()
252 .await
253 .context("Service::poll_ready")?
254 .call(req)
255 .await
256 .context("Service::call")?,
257 )
258 .await
259 .context("parse whoami")?
260 .0;
261 Ok(data)
262 }
263
264 pub async fn tweets<I>(&mut self, ids: I) -> anyhow::Result<Vec<Tweet>>
265 where
266 I: IntoIterator<Item = u64>,
267 I::IntoIter: ExactSizeIterator,
268 {
269 page_at_a_time!(
270 self,
271 "Fetch tweets",
272 ids,
273 100,
274 tower::limit::rate::Rate::new(900, std::time::Duration::from_secs(15 * 60)),
275 "https://api.twitter.com/2/tweets?tweet.fields=id,created_at,public_metrics&ids={}",
276 Tweet
277 )
278 }
279
280 pub async fn users<I>(&mut self, ids: I) -> anyhow::Result<Vec<User>>
281 where
282 I: IntoIterator<Item = u64>,
283 I::IntoIter: ExactSizeIterator,
284 {
285 page_at_a_time!(
286 self,
287 "Fetch followers",
288 ids,
289 100,
290 tower::limit::rate::Rate::new(900, std::time::Duration::from_secs(15 * 60)),
291 "https://api.twitter.com/2/users?user.fields=username,public_metrics&ids={}",
292 User
293 )
294 }
295
296 async fn parse<T>(res: reqwest::Response) -> anyhow::Result<(T, Option<Meta>)>
297 where
298 T: serde::de::DeserializeOwned,
299 {
300 #[derive(Debug, Deserialize)]
302 struct Data<T> {
303 data: T,
304 meta: Option<Meta>,
305 }
306
307 let data = res.text().await.context("get body")?;
311 let data: Data<T> = serde_json::from_str(&data)
312 .with_context(|| data)
313 .context("parse")?;
314 Ok((data.data, data.meta))
315 }
316}
317
318#[derive(Debug, Deserialize)]
362pub struct WhoAmI {
363 pub id: String,
364 pub username: String,
365}
366
367#[derive(Debug, Serialize, Deserialize)]
368pub struct PublicTweetMetrics {
369 #[serde(rename = "retweet_count")]
370 pub retweets: usize,
371 #[serde(rename = "reply_count")]
372 pub replies: usize,
373 #[serde(rename = "like_count")]
374 pub likes: usize,
375 #[serde(rename = "quote_count")]
376 pub quotations: usize,
377}
378
379#[derive(Debug, Serialize, Deserialize)]
380pub struct Tweet {
381 #[serde(rename = "id", with = "u64_but_str")]
382 pub id: u64,
383 #[serde(rename = "created_at", with = "time::serde::rfc3339")]
384 pub created: time::OffsetDateTime,
385 #[serde(rename = "public_metrics")]
386 pub metrics: PublicTweetMetrics,
387 }
390
391impl Tweet {
392 pub fn goodness(&self) -> usize {
393 self.metrics.likes
394 + 2 * self.metrics.retweets
395 + 3 * self.metrics.quotations
396 + self.metrics.replies / 2
397 }
398}
399
400#[derive(Debug, Serialize, Deserialize)]
401pub struct PublicUserMetrics {
402 #[serde(rename = "followers_count")]
403 pub followers: usize,
404 #[serde(rename = "following_count")]
405 pub following: usize,
406}
407
408#[derive(Debug, Serialize, Deserialize)]
409pub struct User {
410 pub username: String,
411 #[serde(rename = "public_metrics")]
412 pub metrics: PublicUserMetrics,
413}
414
415#[derive(Debug, Deserialize)]
417#[allow(dead_code)]
418struct Meta {
419 #[serde(rename = "result_count")]
420 results: usize,
421 #[serde(rename = "next_token")]
422 next: Option<String>,
423}
424
425mod u64_but_str {
426 use std::fmt::Display;
427
428 use serde::{de, Deserialize, Deserializer, Serializer};
429
430 pub fn serialize<T, S>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
431 where
432 T: Display,
433 S: Serializer,
434 {
435 serializer.collect_str(value)
436 }
437
438 pub fn deserialize<'de, D>(deserializer: D) -> Result<u64, D::Error>
439 where
440 D: Deserializer<'de>,
441 {
442 let s = String::deserialize(deserializer)?;
443 s.parse().map_err(de::Error::custom)
444 }
445}