tiktokapi_v2/
oauth.rs

1use crate::{
2    error::{Error, OAuthError},
3    options::{apply_options, make_url, TiktokOptions},
4};
5use base64::prelude::{Engine as _, BASE64_URL_SAFE_NO_PAD};
6use itertools::Itertools;
7use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
8use rand::Rng;
9use reqwest::header::CACHE_CONTROL;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13const AUTH_URL: &str = "https://www.tiktok.com/v2/auth/authorize/";
14const TOKEN_URL: &str = "/oauth/token/";
15const REVOKE_URL: &str = "/oauth/revoke/";
16
/// OAuth scopes that can be requested from TikTok.
///
/// The [`Display`](std::fmt::Display) output of each variant is the string placed in the
/// authorization URL's `scope` parameter (e.g. `user.info.basic`); it is the same value
/// returned by [`TiktokScope::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TiktokScope {
    ResearchAdlibBasic,
    ResearchDataBasic,
    UserInfoBasic,
    UserInfoProfile,
    UserInfoStats,
    VideoList,
    VideoPublish,
    VideoUpload,
}

impl TiktokScope {
    /// Returns every scope known to this crate, in declaration order.
    pub fn all() -> Vec<Self> {
        vec![
            Self::ResearchAdlibBasic,
            Self::ResearchDataBasic,
            Self::UserInfoBasic,
            Self::UserInfoProfile,
            Self::UserInfoStats,
            Self::VideoList,
            Self::VideoPublish,
            Self::VideoUpload,
        ]
    }

    /// Returns the scope identifier string for this variant (e.g. `"video.list"`).
    ///
    /// This is the single source of truth for the strings produced by `Display`.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::ResearchAdlibBasic => "research.adlib.basic",
            Self::ResearchDataBasic => "research.data.basic",
            Self::UserInfoBasic => "user.info.basic",
            Self::UserInfoProfile => "user.info.profile",
            Self::UserInfoStats => "user.info.stats",
            Self::VideoList => "video.list",
            Self::VideoPublish => "video.publish",
            Self::VideoUpload => "video.upload",
        }
    }
}

impl std::fmt::Display for TiktokScope {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
57
/// Result of `TiktokOauth::oauth_url`: the authorization URL and the `state` value
/// that was embedded in it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthUrlResult {
    /// Complete authorization URL to send the user to.
    pub oauth_url: String,
    /// The `state` value used in `oauth_url` (either caller-supplied or randomly
    /// generated). Keep it so the value returned to the callback can be checked.
    pub csrf_token: String,
}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct TokenResult {
66    pub open_id: String,
67    pub scope: String,
68    pub access_token: String,
69    pub expires_in: u64,
70    pub refresh_token: String,
71    pub refresh_expires_in: u64,
72    pub token_type: String,
73}
74
75pub struct TiktokOauth {
76    scopes: Vec<TiktokScope>,
77    client_key: String,
78    client_secret: String,
79    callback_url: String,
80    options: Option<TiktokOptions>,
81}
82
83impl TiktokOauth {
84    pub fn new(
85        client_key: &str,
86        client_secret: &str,
87        callback_url: &str,
88        scopes: Vec<TiktokScope>,
89    ) -> Self {
90        Self::new_with_options(client_key, client_secret, callback_url, scopes, None)
91    }
92
93    pub fn new_with_options(
94        client_key: &str,
95        client_secret: &str,
96        callback_url: &str,
97        scopes: Vec<TiktokScope>,
98        options: Option<TiktokOptions>,
99    ) -> Self {
100        Self {
101            callback_url: callback_url.to_owned(),
102            scopes,
103            client_key: client_key.to_owned(),
104            client_secret: client_secret.to_owned(),
105            options,
106        }
107    }
108
109    pub fn oauth_url(&self, state: Option<String>) -> OAuthUrlResult {
110        let csrf_token = state.unwrap_or(csrf_token());
111        let scope = self.scopes.iter().map(|it| it.to_string()).join(",");
112        let redirect_uri = utf8_percent_encode(&self.callback_url, NON_ALPHANUMERIC);
113        let oauth_url = format!(
114            "{}?client_key={}&response_type=code&scope={}&redirect_uri={}&state={}",
115            AUTH_URL, self.client_key, scope, redirect_uri, csrf_token
116        );
117        OAuthUrlResult {
118            oauth_url,
119            csrf_token,
120        }
121    }
122
123    pub async fn token(&self, code: &str) -> Result<TokenResult, Error> {
124        let mut form = HashMap::new();
125        form.insert("client_key", self.client_key.as_str());
126        form.insert("client_secret", self.client_secret.as_str());
127        form.insert("grant_type", "authorization_code");
128        form.insert("code", code);
129        form.insert("redirect_uri", self.callback_url.as_str());
130        execute_token(form, &self.options).await
131    }
132
133    pub async fn refresh(&self, refresh_token: &str) -> Result<TokenResult, Error> {
134        let mut form = HashMap::new();
135        form.insert("client_key", self.client_key.as_str());
136        form.insert("client_secret", self.client_secret.as_str());
137        form.insert("grant_type", "refresh_token");
138        form.insert("refresh_token", refresh_token);
139        execute_token(form, &self.options).await
140    }
141
142    pub async fn revoke(&self, access_token: &str) -> Result<(), Error> {
143        let mut form = HashMap::new();
144        form.insert("client_key", self.client_key.as_str());
145        form.insert("client_secret", self.client_secret.as_str());
146        form.insert("token", access_token);
147        let response = execute_send(REVOKE_URL, &form, &self.options).await?;
148        let status_code = response.status();
149        if status_code.is_success() {
150            Ok(())
151        } else {
152            let json = response.json().await?;
153            let token_error: OAuthError = serde_json::from_value(json)?;
154            Err(Error::OAuth(token_error, status_code))
155        }
156    }
157}
158
159async fn execute_send(
160    url: &str,
161    form: &HashMap<&str, &str>,
162    options: &Option<TiktokOptions>,
163) -> Result<reqwest::Response, reqwest::Error> {
164    let builder = reqwest::Client::new()
165        .post(make_url(url, options))
166        .header(CACHE_CONTROL, "no-cache")
167        .form(form);
168    apply_options(builder, options).send().await
169}
170
171async fn execute_token(
172    form: HashMap<&str, &str>,
173    options: &Option<TiktokOptions>,
174) -> Result<TokenResult, Error> {
175    let response = execute_send(TOKEN_URL, &form, options).await?;
176    let status_code = response.status();
177    let json = response.json().await?;
178    if status_code.is_success() {
179        let token_result: TokenResult = serde_json::from_value(json)?;
180        Ok(token_result)
181    } else {
182        let token_error: OAuthError = serde_json::from_value(json)?;
183        Err(Error::OAuth(token_error, status_code))
184    }
185}
186
187fn csrf_token() -> String {
188    let random_bytes: Vec<u8> = (0..16).map(|_| rand::thread_rng().gen::<u8>()).collect();
189    BASE64_URL_SAFE_NO_PAD.encode(random_bytes)
190}