Skip to main content

threads_rs/
auth.rs

1use std::collections::HashMap;
2
3use base64::Engine;
4use chrono::{DateTime, Utc};
5use rand::Rng;
6use serde::Deserialize;
7
8use crate::client::{Client, TokenInfo};
9use crate::error;
10use crate::http::RequestBody;
11
12// ---------------------------------------------------------------------------
13// OAuth response types
14// ---------------------------------------------------------------------------
15
16/// Response from the short-lived token exchange (`/oauth/access_token`).
17#[derive(Debug, Deserialize)]
18pub struct TokenResponse {
19    /// The OAuth access token.
20    pub access_token: String,
21    /// Token type (usually "bearer"). Not always returned by the API.
22    #[serde(default)]
23    pub token_type: Option<String>,
24    /// Token lifetime in seconds.
25    pub expires_in: Option<i64>,
26    /// App-scoped user ID.
27    pub user_id: Option<i64>,
28}
29
30/// Response from the long-lived token exchange (`/access_token`).
31#[derive(Debug, Deserialize)]
32pub struct LongLivedTokenResponse {
33    /// The long-lived access token.
34    pub access_token: String,
35    /// Token type (usually "bearer"). Not always returned by the API.
36    #[serde(default)]
37    pub token_type: Option<String>,
38    /// Token lifetime in seconds (typically 5184000 for 60 days).
39    pub expires_in: i64,
40}
41
42/// Response from the debug token endpoint (`/debug_token`).
43#[derive(Debug, Deserialize)]
44pub struct DebugTokenResponse {
45    /// Token introspection data.
46    pub data: DebugTokenData,
47}
48
49/// Inner payload of a debug-token response.
50#[derive(Debug, Deserialize)]
51pub struct DebugTokenData {
52    /// Whether the token is currently valid.
53    pub is_valid: bool,
54    /// Unix timestamp when the token expires.
55    pub expires_at: i64,
56    /// Unix timestamp when the token was issued.
57    pub issued_at: i64,
58    /// OAuth scopes granted to the token.
59    pub scopes: Vec<String>,
60    /// App-scoped user ID.
61    pub user_id: String,
62    /// Token type: "USER" or "APP".
63    #[serde(default, rename = "type")]
64    pub token_type: Option<String>,
65    /// Name of the application.
66    #[serde(default)]
67    pub application: Option<String>,
68    /// Unix timestamp when the app's data access expires.
69    #[serde(default)]
70    pub data_access_expires_at: Option<i64>,
71}
72
73/// Response from the app access token endpoint.
74#[derive(Debug, Deserialize)]
75pub struct AppAccessTokenResponse {
76    /// The app access token.
77    pub access_token: String,
78    /// Token type (usually "bearer").
79    pub token_type: String,
80}
81
82// ---------------------------------------------------------------------------
83// Helpers
84// ---------------------------------------------------------------------------
85
/// Compose the `TH|{client_id}|{client_secret}` shorthand app access token.
///
/// Yields an empty string unless both `client_id` and `client_secret` are
/// non-empty.
fn app_access_token_shorthand(client_id: &str, client_secret: &str) -> String {
    let both_present = !client_id.is_empty() && !client_secret.is_empty();
    if both_present {
        format!("TH|{client_id}|{client_secret}")
    } else {
        String::new()
    }
}
95
96/// Generate a cryptographically-random state parameter (base64url, 32 bytes).
97fn generate_state() -> String {
98    let bytes: [u8; 32] = rand::rng().random();
99    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
100}
101
102// ---------------------------------------------------------------------------
103// Auth methods on Client
104// ---------------------------------------------------------------------------
105
106impl Client {
107    /// Build the OAuth authorization URL that the user should visit.
108    ///
109    /// `scopes` overrides the scopes from the client config. Pass an empty
110    /// slice to use the config defaults.
111    /// Returns `(url, state)` — the caller must store `state` and verify it
112    /// matches the `state` query parameter on the OAuth callback to prevent CSRF.
113    pub fn get_auth_url(&self, scopes: &[String]) -> (String, String) {
114        let cfg = self.config();
115        let effective_scopes = if scopes.is_empty() {
116            &cfg.scopes
117        } else {
118            scopes
119        };
120
121        let scope = effective_scopes.join(",");
122        let state = generate_state();
123
124        let mut url = url::Url::parse("https://www.threads.net/oauth/authorize")
125            .expect("static URL is valid");
126
127        url.query_pairs_mut()
128            .append_pair("client_id", &cfg.client_id)
129            .append_pair("redirect_uri", &cfg.redirect_uri)
130            .append_pair("scope", &scope)
131            .append_pair("response_type", "code")
132            .append_pair("state", &state);
133
134        (url.into(), state)
135    }
136
137    /// Get an app access token using client credentials.
138    ///
139    /// This does NOT store the token in the client (matches Go behavior).
140    /// The caller should use the returned token as needed.
141    pub async fn get_app_access_token(&self) -> crate::Result<AppAccessTokenResponse> {
142        let cfg = self.config();
143
144        // SECURITY: The Graph API requires client_secret as a query parameter for
145        // app access token requests (GET /oauth/access_token). This means the secret
146        // appears in server/proxy access logs. Always use HTTPS and ensure log access
147        // is restricted.
148        let mut params = HashMap::new();
149        params.insert("client_id".into(), cfg.client_id.clone());
150        params.insert("client_secret".into(), cfg.client_secret.clone());
151        params.insert("grant_type".into(), "client_credentials".into());
152
153        let resp = self
154            .http_client
155            .get("/oauth/access_token", params, "")
156            .await?;
157
158        resp.json()
159    }
160
161    /// Get an app access token in shorthand format.
162    ///
163    /// Returns `"TH|{client_id}|{client_secret}"` or an empty string if
164    /// `client_id` or `client_secret` are empty.
165    pub fn get_app_access_token_shorthand(&self) -> String {
166        let cfg = self.config();
167        app_access_token_shorthand(&cfg.client_id, &cfg.client_secret)
168    }
169
170    /// Exchange an authorization code for a short-lived access token.
171    ///
172    /// On success the token is stored via `set_token_info`.
173    pub async fn exchange_code_for_token(&self, code: &str) -> crate::Result<()> {
174        let cfg = self.config().clone();
175
176        let mut form = HashMap::new();
177        form.insert("client_id".into(), cfg.client_id);
178        form.insert("client_secret".into(), cfg.client_secret);
179        form.insert("grant_type".into(), "authorization_code".into());
180        form.insert("redirect_uri".into(), cfg.redirect_uri);
181        form.insert("code".into(), code.to_owned());
182
183        let resp = self
184            .http_client
185            .post("/oauth/access_token", Some(RequestBody::Form(form)), "")
186            .await?;
187
188        let token_resp: TokenResponse = resp.json()?;
189
190        let expires_in = token_resp.expires_in.unwrap_or(3600);
191        let user_id = token_resp
192            .user_id
193            .map(|id| id.to_string())
194            .unwrap_or_default();
195
196        let token_info = TokenInfo {
197            access_token: token_resp.access_token,
198            token_type: token_resp
199                .token_type
200                .unwrap_or_else(|| "bearer".to_string()),
201            expires_at: Utc::now() + chrono::Duration::seconds(expires_in),
202            user_id,
203            created_at: Utc::now(),
204        };
205
206        self.set_token_info(token_info).await
207    }
208
209    /// Convert the current short-lived token into a long-lived token (60 days).
210    ///
211    /// Requires that the client already holds a valid short-lived token.
212    pub async fn get_long_lived_token(&self) -> crate::Result<()> {
213        let access_token = self.access_token().await;
214        if access_token.is_empty() {
215            return Err(error::new_authentication_error(
216                401,
217                "No access token available",
218                "Call exchange_code_for_token first",
219            ));
220        }
221
222        let cfg = self.config();
223
224        // SECURITY: The Graph API requires client_secret as a query parameter for
225        // long-lived token exchange (GET /access_token). This means the secret appears
226        // in server/proxy access logs. Always use HTTPS and ensure log access is restricted.
227        let mut params = HashMap::new();
228        params.insert("grant_type".into(), "th_exchange_token".into());
229        params.insert("client_secret".into(), cfg.client_secret.clone());
230        params.insert("access_token".into(), access_token.clone());
231
232        let resp = self
233            .http_client
234            .get("/access_token", params, &access_token)
235            .await?;
236
237        let long_resp: LongLivedTokenResponse = resp.json()?;
238
239        let user_id = self.user_id().await;
240
241        let token_info = TokenInfo {
242            access_token: long_resp.access_token,
243            token_type: long_resp.token_type.unwrap_or_else(|| "bearer".to_string()),
244            expires_at: Utc::now() + chrono::Duration::seconds(long_resp.expires_in),
245            user_id,
246            created_at: Utc::now(),
247        };
248
249        self.set_token_info(token_info).await
250    }
251
252    /// Refresh the current long-lived token, extending its expiry.
253    ///
254    /// The token must still be valid (not expired) to be refreshed.
255    pub async fn refresh_token(&self) -> crate::Result<()> {
256        let access_token = self.access_token().await;
257        if access_token.is_empty() {
258            return Err(error::new_authentication_error(
259                401,
260                "No access token available",
261                "Cannot refresh without a valid token",
262            ));
263        }
264
265        let mut params = HashMap::new();
266        params.insert("grant_type".into(), "th_refresh_token".into());
267        params.insert("access_token".into(), access_token.clone());
268
269        let resp = self
270            .http_client
271            .get("/refresh_access_token", params, &access_token)
272            .await?;
273
274        let long_resp: LongLivedTokenResponse = resp.json()?;
275
276        let user_id = self.user_id().await;
277
278        let token_info = TokenInfo {
279            access_token: long_resp.access_token,
280            token_type: long_resp.token_type.unwrap_or_else(|| "bearer".to_string()),
281            expires_at: Utc::now() + chrono::Duration::seconds(long_resp.expires_in),
282            user_id,
283            created_at: Utc::now(),
284        };
285
286        self.set_token_info(token_info).await
287    }
288
289    /// Inspect a token via the `/debug_token` endpoint.
290    pub async fn debug_token(&self, input_token: &str) -> crate::Result<DebugTokenResponse> {
291        let token = self.access_token().await;
292        if token.is_empty() {
293            return Err(crate::error::new_authentication_error(
294                401,
295                "Access token is required to call debug_token",
296                "",
297            ));
298        }
299
300        let mut params = HashMap::new();
301        params.insert("input_token".into(), input_token.to_owned());
302
303        let resp = self.http_client.get("/debug_token", params, &token).await?;
304
305        resp.json()
306    }
307
308    /// Validate the current token locally: non-empty and not expired.
309    pub async fn validate_token(&self) -> crate::Result<()> {
310        let state = self.get_token_info().await;
311        match state {
312            Some(info) => {
313                if info.access_token.is_empty() {
314                    return Err(error::new_authentication_error(401, "Token is empty", ""));
315                }
316                if Utc::now() > info.expires_at {
317                    return Err(error::new_authentication_error(
318                        401,
319                        "Token has expired",
320                        "",
321                    ));
322                }
323                Ok(())
324            }
325            None => Err(error::new_authentication_error(
326                401,
327                "No token available",
328                "",
329            )),
330        }
331    }
332
333    /// Validate the current token and auto-refresh if expired.
334    ///
335    /// Only attempts a refresh when the token exists but has expired.
336    /// Returns the original error for other failures (no token, empty token).
337    pub async fn ensure_valid_token(&self) -> crate::Result<()> {
338        match self.validate_token().await {
339            Ok(()) => Ok(()),
340            Err(e) => {
341                // Only refresh if we have a token that expired
342                if self.is_token_expired().await && self.get_token_info().await.is_some() {
343                    self.refresh_token().await
344                } else {
345                    Err(e)
346                }
347            }
348        }
349    }
350
351    /// Return debug information about the current token.
352    ///
353    /// The access token is masked (first 4 + last 4 characters shown).
354    pub async fn get_token_debug_info(&self) -> HashMap<String, String> {
355        let mut info = HashMap::new();
356        let state = self.get_token_info().await;
357        match state {
358            Some(token_info) => {
359                let masked = if token_info.access_token.len() > 8 {
360                    let len = token_info.access_token.len();
361                    format!(
362                        "{}...{}",
363                        &token_info.access_token[..4],
364                        &token_info.access_token[len - 4..]
365                    )
366                } else {
367                    "****".to_owned()
368                };
369                info.insert("access_token".into(), masked);
370                info.insert("token_type".into(), token_info.token_type.clone());
371                info.insert("expires_at".into(), token_info.expires_at.to_rfc3339());
372                info.insert("user_id".into(), token_info.user_id.clone());
373                info.insert("created_at".into(), token_info.created_at.to_rfc3339());
374                info.insert(
375                    "is_expired".into(),
376                    (Utc::now() > token_info.expires_at).to_string(),
377                );
378            }
379            None => {
380                info.insert("status".into(), "no_token".into());
381            }
382        }
383        info
384    }
385
386    /// Explicitly reload the token from storage.
387    pub async fn load_token_from_storage(&self) -> crate::Result<()> {
388        let loaded = self.token_storage.load().await?;
389        self.set_token_info(loaded).await
390    }
391
392    /// Store a token built from a previous `debug_token` response.
393    ///
394    /// Useful for bootstrapping the client from a known-valid token without
395    /// going through the full OAuth flow again.
396    pub async fn set_token_from_debug_info(
397        &self,
398        access_token: &str,
399        debug_resp: &DebugTokenResponse,
400    ) -> crate::Result<()> {
401        let data = &debug_resp.data;
402
403        if !data.is_valid {
404            return Err(error::new_authentication_error(
405                401,
406                "Cannot set token from invalid debug info: token is not valid",
407                "",
408            ));
409        }
410
411        let expires_at =
412            DateTime::<Utc>::from_timestamp(data.expires_at, 0).unwrap_or_else(Utc::now);
413
414        let created_at =
415            DateTime::<Utc>::from_timestamp(data.issued_at, 0).unwrap_or_else(Utc::now);
416
417        let token_info = TokenInfo {
418            access_token: access_token.to_owned(),
419            token_type: "bearer".into(),
420            expires_at,
421            user_id: data.user_id.clone(),
422            created_at,
423        };
424
425        self.set_token_info(token_info).await
426    }
427}
428
429// ---------------------------------------------------------------------------
430// Tests
431// ---------------------------------------------------------------------------
432
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::Config;

    /// Configuration with fixed dummy credentials shared by the tests below.
    fn config_fixture() -> Config {
        Config::new(
            "test-client-id",
            "test-secret",
            "https://example.com/callback",
        )
    }

    /// A client created from `config_fixture` with no token loaded.
    async fn fresh_client() -> Client {
        Client::new(config_fixture()).await.unwrap()
    }

    /// A `Bearer` token for user `u-1` with the given value and timestamps.
    fn sample_token(
        access_token: &str,
        expires_at: DateTime<Utc>,
        created_at: DateTime<Utc>,
    ) -> TokenInfo {
        TokenInfo {
            access_token: access_token.into(),
            token_type: "Bearer".into(),
            expires_at,
            user_id: "u-1".into(),
            created_at,
        }
    }

    #[test]
    fn test_generate_state_unique() {
        let (first, second) = (generate_state(), generate_state());
        assert_ne!(first, second);
        // 32 bytes in unpadded base64url is 43 characters.
        assert_eq!(first.len(), 43);
    }

    #[test]
    fn test_generate_state_is_valid_base64url() {
        let raw = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(generate_state())
            .expect("should be valid base64url");
        assert_eq!(raw.len(), 32);
    }

    #[tokio::test]
    async fn test_get_auth_url_contains_required_params() {
        let (url, state) = fresh_client().await.get_auth_url(&[]);

        assert!(url.starts_with("https://www.threads.net/oauth/authorize?"));
        for expected in [
            "client_id=test-client-id",
            "redirect_uri=",
            "response_type=code",
            "state=",
            "scope=",
        ] {
            assert!(url.contains(expected));
        }
        assert!(
            !state.is_empty(),
            "state must be returned for CSRF verification"
        );
        assert!(url.contains(&format!("state={state}")));
    }

    #[tokio::test]
    async fn test_get_auth_url_uses_custom_scopes() {
        let client = fresh_client().await;
        let requested: Vec<String> = ["threads_basic", "threads_manage_replies"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let (url, _) = client.get_auth_url(&requested);

        // Scopes are comma-joined; the comma is percent-encoded in the query.
        assert!(url.contains("scope=threads_basic%2Cthreads_manage_replies"));
    }

    #[tokio::test]
    async fn test_get_auth_url_uses_config_scopes_when_empty() {
        let (url, _) = fresh_client().await.get_auth_url(&[]);

        // The default config scopes include threads_basic.
        assert!(url.contains("threads_basic"));
    }

    #[test]
    fn test_token_response_deserialize() {
        let json = r#"{
            "access_token": "tok_abc",
            "token_type": "bearer",
            "expires_in": 3600,
            "user_id": 12345
        }"#;
        let parsed: TokenResponse = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.access_token, "tok_abc");
        assert_eq!(parsed.token_type.as_deref(), Some("bearer"));
        assert_eq!(parsed.expires_in, Some(3600));
        assert_eq!(parsed.user_id, Some(12345));
    }

    #[test]
    fn test_token_response_deserialize_optional_fields() {
        let json = r#"{
            "access_token": "tok_abc",
            "token_type": "bearer"
        }"#;
        let parsed: TokenResponse = serde_json::from_str(json).unwrap();
        assert_eq!((parsed.expires_in, parsed.user_id), (None, None));
    }

    #[test]
    fn test_long_lived_token_response_deserialize() {
        let json = r#"{
            "access_token": "long_tok",
            "token_type": "bearer",
            "expires_in": 5184000
        }"#;
        let parsed: LongLivedTokenResponse = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.access_token, "long_tok");
        assert_eq!(parsed.expires_in, 5184000);
    }

    #[test]
    fn test_debug_token_response_deserialize() {
        let json = r#"{
            "data": {
                "is_valid": true,
                "expires_at": 1700000000,
                "issued_at": 1699900000,
                "scopes": ["threads_basic", "threads_content_publish"],
                "user_id": "987654"
            }
        }"#;
        let parsed: DebugTokenResponse = serde_json::from_str(json).unwrap();
        let data = &parsed.data;
        assert!(data.is_valid);
        assert_eq!(data.expires_at, 1700000000);
        assert_eq!(data.issued_at, 1699900000);
        assert_eq!(data.scopes.len(), 2);
        assert_eq!(data.user_id, "987654");
    }

    #[tokio::test]
    async fn test_validate_token_no_token() {
        let outcome = fresh_client().await.validate_token().await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_validate_token_valid() {
        let client = fresh_client().await;
        let now = Utc::now();
        let token = sample_token("valid-tok", now + chrono::Duration::hours(1), now);
        client.set_token_info(token).await.unwrap();
        assert!(client.validate_token().await.is_ok());
    }

    #[tokio::test]
    async fn test_validate_token_expired() {
        let client = fresh_client().await;
        let now = Utc::now();
        let stale = sample_token(
            "expired-tok",
            now - chrono::Duration::hours(1),
            now - chrono::Duration::hours(2),
        );
        client.set_token_info(stale).await.unwrap();
        assert!(client.validate_token().await.is_err());
    }

    #[tokio::test]
    async fn test_get_token_debug_info_no_token() {
        let info = fresh_client().await.get_token_debug_info().await;
        assert_eq!(info.get("status").map(String::as_str), Some("no_token"));
    }

    #[tokio::test]
    async fn test_get_token_debug_info_with_token() {
        let client = fresh_client().await;
        let now = Utc::now();
        let token = sample_token("abcdefghijklmnop", now + chrono::Duration::hours(1), now);
        client.set_token_info(token).await.unwrap();

        let info = client.get_token_debug_info().await;
        let masked = info.get("access_token").unwrap();
        assert!(masked.starts_with("abcd"));
        assert!(masked.ends_with("mnop"));
        assert!(masked.contains("..."));
        assert_eq!(info.get("user_id").map(String::as_str), Some("u-1"));
        assert_eq!(info.get("is_expired").map(String::as_str), Some("false"));
    }

    #[tokio::test]
    async fn test_load_token_from_storage_empty() {
        // Nothing has been persisted yet, so loading must fail.
        assert!(fresh_client().await.load_token_from_storage().await.is_err());
    }

    #[test]
    fn test_app_access_token_response_deserialize() {
        let json = r#"{
            "access_token": "app_tok_abc",
            "token_type": "bearer"
        }"#;
        let parsed: AppAccessTokenResponse = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.access_token, "app_tok_abc");
        assert_eq!(parsed.token_type, "bearer");
    }

    #[tokio::test]
    async fn test_get_app_access_token_shorthand() {
        let client = fresh_client().await;
        assert_eq!(
            client.get_app_access_token_shorthand(),
            "TH|test-client-id|test-secret"
        );
    }

    #[test]
    fn test_app_access_token_shorthand_empty_client_id() {
        assert!(app_access_token_shorthand("", "secret").is_empty());
    }

    #[test]
    fn test_app_access_token_shorthand_empty_secret() {
        assert!(app_access_token_shorthand("id", "").is_empty());
    }

    #[test]
    fn test_app_access_token_shorthand_both_empty() {
        assert!(app_access_token_shorthand("", "").is_empty());
    }
}