authkestra_providers_discord/
lib.rs1use async_trait::async_trait;
2use authkestra_core::{
3 error::AuthError,
4 state::{Identity, OAuthToken},
5 OAuthProvider,
6};
7use serde::Deserialize;
8use std::collections::HashMap;
9
10pub struct DiscordProvider {
11 client_id: String,
12 client_secret: String,
13 redirect_uri: String,
14 http_client: reqwest::Client,
15 token_url: String,
16 user_url: String,
17 revoke_url: String,
18}
19
20impl DiscordProvider {
21 pub fn new(client_id: String, client_secret: String, redirect_uri: String) -> Self {
22 Self {
23 client_id,
24 client_secret,
25 redirect_uri,
26 http_client: reqwest::Client::new(),
27 token_url: "https://discord.com/api/oauth2/token".to_string(),
28 user_url: "https://discord.com/api/users/@me".to_string(),
29 revoke_url: "https://discord.com/api/oauth2/token/revoke".to_string(),
30 }
31 }
32
33 pub fn with_test_urls(
34 mut self,
35 token_url: String,
36 user_url: String,
37 revoke_url: String,
38 ) -> Self {
39 self.token_url = token_url;
40 self.user_url = user_url;
41 self.revoke_url = revoke_url;
42 self
43 }
44}
45
46#[derive(Deserialize)]
47struct DiscordAccessTokenResponse {
48 access_token: String,
49 token_type: String,
50 expires_in: Option<u64>,
51 refresh_token: Option<String>,
52 scope: Option<String>,
53 id_token: Option<String>,
54}
55
56#[derive(Deserialize)]
57struct DiscordUserResponse {
58 id: String,
59 username: String,
60 discriminator: String,
61 email: Option<String>,
62}
63
64#[async_trait]
65impl OAuthProvider for DiscordProvider {
66 fn provider_id(&self) -> &str {
67 "discord"
68 }
69
70 fn get_authorization_url(
71 &self,
72 state: &str,
73 scopes: &[&str],
74 code_challenge: Option<&str>,
75 ) -> String {
76 let scope_param = if scopes.is_empty() {
77 "identify email".to_string()
78 } else {
79 scopes.join(" ")
80 };
81
82 let mut url = format!(
83 "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&state={}&scope={}",
84 self.client_id, urlencoding::encode(&self.redirect_uri), state, urlencoding::encode(&scope_param)
85 );
86
87 if let Some(challenge) = code_challenge {
88 url.push_str(&format!(
89 "&code_challenge={}&code_challenge_method=S256",
90 challenge
91 ));
92 }
93
94 url
95 }
96
97 async fn exchange_code_for_identity(
98 &self,
99 code: &str,
100 code_verifier: Option<&str>,
101 ) -> Result<(Identity, OAuthToken), AuthError> {
102 let mut params = vec![
104 ("client_id", self.client_id.clone()),
105 ("client_secret", self.client_secret.clone()),
106 ("grant_type", "authorization_code".to_string()),
107 ("code", code.to_string()),
108 ("redirect_uri", self.redirect_uri.clone()),
109 ];
110
111 if let Some(verifier) = code_verifier {
112 params.push(("code_verifier", verifier.to_string()));
113 }
114
115 let token_response = self
116 .http_client
117 .post(&self.token_url)
118 .form(¶ms)
119 .send()
120 .await
121 .map_err(|_| AuthError::Network)?
122 .json::<DiscordAccessTokenResponse>()
123 .await
124 .map_err(|e| AuthError::Provider(format!("Failed to parse token response: {}", e)))?;
125
126 let user_response = self
128 .http_client
129 .get(&self.user_url)
130 .header(
131 "Authorization",
132 format!("Bearer {}", token_response.access_token),
133 )
134 .send()
135 .await
136 .map_err(|_| AuthError::Network)?
137 .json::<DiscordUserResponse>()
138 .await
139 .map_err(|e| AuthError::Provider(format!("Failed to parse user response: {}", e)))?;
140
141 let username = if user_response.discriminator == "0" {
143 user_response.username
144 } else {
145 format!("{}#{}", user_response.username, user_response.discriminator)
146 };
147
148 let identity = Identity {
149 provider_id: "discord".to_string(),
150 external_id: user_response.id,
151 email: user_response.email,
152 username: Some(username),
153 attributes: HashMap::new(),
154 };
155
156 let token = OAuthToken {
157 access_token: token_response.access_token,
158 token_type: token_response.token_type,
159 expires_in: token_response.expires_in,
160 refresh_token: token_response.refresh_token,
161 scope: token_response.scope,
162 id_token: token_response.id_token,
163 };
164
165 Ok((identity, token))
166 }
167
168 async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken, AuthError> {
169 let token_response = self
170 .http_client
171 .post(&self.token_url)
172 .form(&[
173 ("client_id", &self.client_id),
174 ("client_secret", &self.client_secret),
175 ("grant_type", &"refresh_token".to_string()),
176 ("refresh_token", &refresh_token.to_string()),
177 ])
178 .send()
179 .await
180 .map_err(|_| AuthError::Network)?
181 .json::<DiscordAccessTokenResponse>()
182 .await
183 .map_err(|e| {
184 AuthError::Provider(format!("Failed to parse refresh token response: {}", e))
185 })?;
186
187 Ok(OAuthToken {
188 access_token: token_response.access_token,
189 token_type: token_response.token_type,
190 expires_in: token_response.expires_in,
191 refresh_token: token_response.refresh_token,
192 scope: token_response.scope,
193 id_token: token_response.id_token,
194 })
195 }
196
197 async fn revoke_token(&self, token: &str) -> Result<(), AuthError> {
198 let response = self
199 .http_client
200 .post(&self.revoke_url)
201 .form(&[
202 ("client_id", &self.client_id),
203 ("client_secret", &self.client_secret),
204 ("token", &token.to_string()),
205 ])
206 .send()
207 .await
208 .map_err(|_| AuthError::Network)?;
209
210 if response.status().is_success() {
211 Ok(())
212 } else {
213 let error_text = response
214 .text()
215 .await
216 .unwrap_or_else(|_| "Unknown error".to_string());
217 Err(AuthError::Provider(format!(
218 "Failed to revoke token: {}",
219 error_text
220 )))
221 }
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Runs the code exchange against a mock server and checks that the
    /// identity and token are populated from the mocked token and user bodies.
    #[tokio::test]
    async fn test_exchange_code_for_identity() {
        let server = MockServer::start().await;
        let base_uri = server.uri();

        // Token endpoint: minimal successful token body.
        let token_body = serde_json::json!({"access_token": "test_token", "token_type": "Bearer"});
        Mock::given(method("POST"))
            .and(path("/api/oauth2/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_body))
            .mount(&server)
            .await;

        // User endpoint: profile with a non-"0" discriminator.
        let user_body = serde_json::json!({
            "id": "123456789",
            "username": "testuser",
            "discriminator": "0001",
            "email": "test@example.com"
        });
        Mock::given(method("GET"))
            .and(path("/api/users/@me"))
            .respond_with(ResponseTemplate::new(200).set_body_json(user_body))
            .mount(&server)
            .await;

        let provider = DiscordProvider::new(
            "client_id".to_string(),
            "client_secret".to_string(),
            "http://localhost/callback".to_string(),
        )
        .with_test_urls(
            format!("{}/api/oauth2/token", base_uri),
            format!("{}/api/users/@me", base_uri),
            format!("{}/api/oauth2/token/revoke", base_uri),
        );

        let exchanged = provider
            .exchange_code_for_identity("test_code", None)
            .await;
        let (identity, token) = exchanged.unwrap();

        assert_eq!(identity.provider_id, "discord");
        assert_eq!(identity.external_id, "123456789");
        assert_eq!(identity.username, Some("testuser#0001".to_string()));
        assert_eq!(identity.email, Some("test@example.com".to_string()));
        assert_eq!(token.access_token, "test_token");
    }
}