// authkestra_providers_discord/lib.rs

1use async_trait::async_trait;
2use authkestra_core::{
3    error::AuthError,
4    state::{Identity, OAuthToken},
5    OAuthProvider,
6};
7use serde::Deserialize;
8use std::collections::HashMap;
9
10pub struct DiscordProvider {
11    client_id: String,
12    client_secret: String,
13    redirect_uri: String,
14    http_client: reqwest::Client,
15    token_url: String,
16    user_url: String,
17    revoke_url: String,
18}
19
20impl DiscordProvider {
21    pub fn new(client_id: String, client_secret: String, redirect_uri: String) -> Self {
22        Self {
23            client_id,
24            client_secret,
25            redirect_uri,
26            http_client: reqwest::Client::new(),
27            token_url: "https://discord.com/api/oauth2/token".to_string(),
28            user_url: "https://discord.com/api/users/@me".to_string(),
29            revoke_url: "https://discord.com/api/oauth2/token/revoke".to_string(),
30        }
31    }
32
33    pub fn with_test_urls(
34        mut self,
35        token_url: String,
36        user_url: String,
37        revoke_url: String,
38    ) -> Self {
39        self.token_url = token_url;
40        self.user_url = user_url;
41        self.revoke_url = revoke_url;
42        self
43    }
44}
45
46#[derive(Deserialize)]
47struct DiscordAccessTokenResponse {
48    access_token: String,
49    token_type: String,
50    expires_in: Option<u64>,
51    refresh_token: Option<String>,
52    scope: Option<String>,
53    id_token: Option<String>,
54}
55
56#[derive(Deserialize)]
57struct DiscordUserResponse {
58    id: String,
59    username: String,
60    discriminator: String,
61    email: Option<String>,
62}
63
64#[async_trait]
65impl OAuthProvider for DiscordProvider {
66    fn provider_id(&self) -> &str {
67        "discord"
68    }
69
70    fn get_authorization_url(
71        &self,
72        state: &str,
73        scopes: &[&str],
74        code_challenge: Option<&str>,
75    ) -> String {
76        let scope_param = if scopes.is_empty() {
77            "identify email".to_string()
78        } else {
79            scopes.join(" ")
80        };
81
82        let mut url = format!(
83            "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&state={}&scope={}",
84            self.client_id, urlencoding::encode(&self.redirect_uri), state, urlencoding::encode(&scope_param)
85        );
86
87        if let Some(challenge) = code_challenge {
88            url.push_str(&format!(
89                "&code_challenge={}&code_challenge_method=S256",
90                challenge
91            ));
92        }
93
94        url
95    }
96
97    async fn exchange_code_for_identity(
98        &self,
99        code: &str,
100        code_verifier: Option<&str>,
101    ) -> Result<(Identity, OAuthToken), AuthError> {
102        // 1. Exchange code for access token
103        let mut params = vec![
104            ("client_id", self.client_id.clone()),
105            ("client_secret", self.client_secret.clone()),
106            ("grant_type", "authorization_code".to_string()),
107            ("code", code.to_string()),
108            ("redirect_uri", self.redirect_uri.clone()),
109        ];
110
111        if let Some(verifier) = code_verifier {
112            params.push(("code_verifier", verifier.to_string()));
113        }
114
115        let token_response = self
116            .http_client
117            .post(&self.token_url)
118            .form(&params)
119            .send()
120            .await
121            .map_err(|_| AuthError::Network)?
122            .json::<DiscordAccessTokenResponse>()
123            .await
124            .map_err(|e| AuthError::Provider(format!("Failed to parse token response: {}", e)))?;
125
126        // 2. Get user information
127        let user_response = self
128            .http_client
129            .get(&self.user_url)
130            .header(
131                "Authorization",
132                format!("Bearer {}", token_response.access_token),
133            )
134            .send()
135            .await
136            .map_err(|_| AuthError::Network)?
137            .json::<DiscordUserResponse>()
138            .await
139            .map_err(|e| AuthError::Provider(format!("Failed to parse user response: {}", e)))?;
140
141        // 3. Map to Identity
142        let username = if user_response.discriminator == "0" {
143            user_response.username
144        } else {
145            format!("{}#{}", user_response.username, user_response.discriminator)
146        };
147
148        let identity = Identity {
149            provider_id: "discord".to_string(),
150            external_id: user_response.id,
151            email: user_response.email,
152            username: Some(username),
153            attributes: HashMap::new(),
154        };
155
156        let token = OAuthToken {
157            access_token: token_response.access_token,
158            token_type: token_response.token_type,
159            expires_in: token_response.expires_in,
160            refresh_token: token_response.refresh_token,
161            scope: token_response.scope,
162            id_token: token_response.id_token,
163        };
164
165        Ok((identity, token))
166    }
167
168    async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken, AuthError> {
169        let token_response = self
170            .http_client
171            .post(&self.token_url)
172            .form(&[
173                ("client_id", &self.client_id),
174                ("client_secret", &self.client_secret),
175                ("grant_type", &"refresh_token".to_string()),
176                ("refresh_token", &refresh_token.to_string()),
177            ])
178            .send()
179            .await
180            .map_err(|_| AuthError::Network)?
181            .json::<DiscordAccessTokenResponse>()
182            .await
183            .map_err(|e| {
184                AuthError::Provider(format!("Failed to parse refresh token response: {}", e))
185            })?;
186
187        Ok(OAuthToken {
188            access_token: token_response.access_token,
189            token_type: token_response.token_type,
190            expires_in: token_response.expires_in,
191            refresh_token: token_response.refresh_token,
192            scope: token_response.scope,
193            id_token: token_response.id_token,
194        })
195    }
196
197    async fn revoke_token(&self, token: &str) -> Result<(), AuthError> {
198        let response = self
199            .http_client
200            .post(&self.revoke_url)
201            .form(&[
202                ("client_id", &self.client_id),
203                ("client_secret", &self.client_secret),
204                ("token", &token.to_string()),
205            ])
206            .send()
207            .await
208            .map_err(|_| AuthError::Network)?;
209
210        if response.status().is_success() {
211            Ok(())
212        } else {
213            let error_text = response
214                .text()
215                .await
216                .unwrap_or_else(|_| "Unknown error".to_string());
217            Err(AuthError::Provider(format!(
218                "Failed to revoke token: {}",
219                error_text
220            )))
221        }
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a provider whose token, user and revoke endpoints all point at `server`.
    fn provider_for(server: &MockServer) -> DiscordProvider {
        DiscordProvider::new(
            "client_id".to_string(),
            "client_secret".to_string(),
            "http://localhost/callback".to_string(),
        )
        .with_test_urls(
            format!("{}/api/oauth2/token", server.uri()),
            format!("{}/api/users/@me", server.uri()),
            format!("{}/api/oauth2/token/revoke", server.uri()),
        )
    }

    /// Mounts a successful token endpoint, plus a user endpoint that returns `user`.
    async fn mount_login(server: &MockServer, user: serde_json::Value) {
        Mock::given(method("POST"))
            .and(path("/api/oauth2/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(
                serde_json::json!({"access_token": "test_token", "token_type": "Bearer"}),
            ))
            .mount(server)
            .await;

        Mock::given(method("GET"))
            .and(path("/api/users/@me"))
            .respond_with(ResponseTemplate::new(200).set_body_json(user))
            .mount(server)
            .await;
    }

    #[tokio::test]
    async fn test_exchange_code_for_identity() {
        let server = MockServer::start().await;
        mount_login(
            &server,
            serde_json::json!({
                "id": "123456789",
                "username": "testuser",
                "discriminator": "0001",
                "email": "test@example.com"
            }),
        )
        .await;

        let (identity, token) = provider_for(&server)
            .exchange_code_for_identity("test_code", None)
            .await
            .unwrap();

        assert_eq!(identity.provider_id, "discord");
        assert_eq!(identity.external_id, "123456789");
        assert_eq!(identity.username, Some("testuser#0001".to_string()));
        assert_eq!(identity.email, Some("test@example.com".to_string()));
        assert_eq!(token.access_token, "test_token");
    }

    /// A discriminator of "0" yields the bare username; a missing email maps to `None`.
    #[tokio::test]
    async fn test_exchange_code_for_identity_without_discriminator() {
        let server = MockServer::start().await;
        mount_login(
            &server,
            serde_json::json!({
                "id": "987654321",
                "username": "newstyle",
                "discriminator": "0"
            }),
        )
        .await;

        let (identity, _token) = provider_for(&server)
            .exchange_code_for_identity("test_code", Some("verifier"))
            .await
            .unwrap();

        assert_eq!(identity.external_id, "987654321");
        assert_eq!(identity.username, Some("newstyle".to_string()));
        assert_eq!(identity.email, None);
    }

    #[tokio::test]
    async fn test_revoke_token() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/oauth2/token/revoke"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        assert!(provider_for(&server).revoke_token("test_token").await.is_ok());
    }

    /// A non-2xx revoke response surfaces as `AuthError::Provider` containing the body.
    #[tokio::test]
    async fn test_revoke_token_failure_reports_body() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/oauth2/token/revoke"))
            .respond_with(ResponseTemplate::new(400).set_body_string("invalid_token"))
            .mount(&server)
            .await;

        let result = provider_for(&server).revoke_token("test_token").await;
        assert!(matches!(
            &result,
            Err(AuthError::Provider(msg)) if msg.contains("invalid_token")
        ));
    }
}