dioxus_oauth/
oauth_client.rs

1use std::error::Error;
2
3use serde::{Deserialize, Serialize};
4use url::Url;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct OAuthClient<'a> {
8    pub client_id: &'a str,
9    pub client_secret: Option<&'a str>,
10    pub redirect_uri: &'a str,
11    pub scopes: Vec<String>,
12    pub token_url: &'a str,
13    pub auth_url: &'a str,
14    pub revoke_url: Option<&'a str>,
15    pub userinfo_url: Option<&'a str>,
16    pub openid_url: Option<&'a str>,
17}
18
19impl<'a> OAuthClient<'a> {
20    pub fn new(
21        client_id: &'a str,
22        redirect_uri: &'a str,
23        auth_url: &'a str,
24        token_url: &'a str,
25    ) -> Self {
26        OAuthClient {
27            client_id,
28            client_secret: None,
29            redirect_uri,
30            scopes: vec![],
31            token_url,
32            auth_url,
33            revoke_url: None,
34            userinfo_url: None,
35            openid_url: None,
36        }
37    }
38
39    pub fn set_client_secret(&mut self, client_secret: &'a str) -> Self {
40        self.client_secret = Some(client_secret);
41        self.clone()
42    }
43
44    pub fn set_openid_url(&mut self, openid_url: &'a str) -> Self {
45        self.openid_url = Some(openid_url);
46        self.clone()
47    }
48
49    pub fn add_scope(&mut self, scope: &'a str) -> Self {
50        self.scopes.push(scope.to_string());
51        self.clone()
52    }
53
54    pub async fn get_auth_code(&self) -> Result<String, Box<dyn Error>> {
55        let mut auth_url = Url::parse(self.auth_url)?;
56        auth_url
57            .query_pairs_mut()
58            .append_pair("response_type", "code")
59            .append_pair("client_id", self.client_id)
60            .append_pair("redirect_uri", self.redirect_uri);
61
62        if self.scopes.len() > 0 {
63            auth_url
64                .query_pairs_mut()
65                .append_pair("scope", self.scopes.join(" ").as_str());
66        }
67
68        let auth_url = auth_url.to_string();
69        tracing::debug!("auth_url: {:?}", auth_url);
70
71        #[cfg(feature = "web")]
72        {
73            use wasm_bindgen::JsCast;
74
75            let window = match web_sys::window() {
76                Some(window) => window,
77                None => {
78                    return Err("Window not found".into());
79                }
80            };
81            let w = match window.open_with_url_and_target_and_features(
82                &auth_url.to_string(),
83                "",
84                "popup",
85            ) {
86                Ok(Some(w)) => w,
87                Ok(None) => {
88                    return Err("Window not found".into());
89                }
90                Err(e) => {
91                    return Err(format!("Window open failed: {:?}", e).into());
92                }
93            };
94            tracing::debug!("Window opened: {:?}", w);
95            let location = match w.location().href() {
96                Ok(location) => location,
97                Err(e) => {
98                    return Err(format!("Location failed {:?}", e).into());
99                }
100            };
101            tracing::debug!("Location: {:?}", location);
102            let code = std::sync::Arc::new(std::sync::RwLock::new(String::new()));
103
104            let promise = web_sys::js_sys::Promise::new(&mut |resolve, _reject| {
105                let code_arc = std::sync::Arc::clone(&code);
106
107                let on_message_callback = wasm_bindgen::prelude::Closure::wrap(Box::new(
108                    move |event: web_sys::MessageEvent| {
109                        if let Some(data) = event.data().as_string() {
110                            if data.as_str().starts_with("code=") {
111                                let code = data.as_str().replace("code=", "");
112                                *code_arc.write().unwrap() = code.clone();
113                                tracing::debug!("Code received: {:?}", code);
114                                resolve
115                                    .call1(&wasm_bindgen::JsValue::NULL, &event.data())
116                                    .unwrap();
117                            }
118                        }
119                    },
120                )
121                    as Box<dyn FnMut(web_sys::MessageEvent)>);
122
123                window.set_onmessage(Some(on_message_callback.as_ref().unchecked_ref()));
124                on_message_callback.forget();
125            });
126
127            let _ = wasm_bindgen_futures::JsFuture::from(promise).await;
128            tracing::debug!("oauth login finished");
129
130            let code = code.read().unwrap();
131            tracing::debug!("Code(received): {:?}", code);
132
133            return Ok(code.clone());
134        }
135
136        #[allow(unreachable_code)]
137        Ok(auth_url)
138    }
139
140    pub async fn get_token(&self, code: &str) -> Result<TokenResponse, Box<dyn Error>> {
141        let mut params = std::collections::HashMap::new();
142        params.insert("grant_type", "authorization_code");
143        params.insert("client_id", self.client_id);
144        params.insert("redirect_uri", self.redirect_uri);
145        params.insert("code", code);
146
147        let client = reqwest::Client::new();
148        let res = client.post(self.token_url).form(&params).send().await?;
149
150        Ok(res.json().await?)
151    }
152
153    pub async fn get_openid(&self, id_token: &str) -> Result<OpenIdResponse, Box<dyn Error>> {
154        if self.openid_url.is_none() {
155            return Err("openid_url is not set".into());
156        }
157
158        let mut params = std::collections::HashMap::new();
159        params.insert("id_token", id_token);
160
161        let client = reqwest::Client::new();
162        let res = client
163            .post(self.openid_url.unwrap())
164            .form(&params)
165            .send()
166            .await?;
167
168        Ok(res.json().await?)
169    }
170}
171
172#[derive(serde::Deserialize, Debug)]
173pub struct TokenResponse {
174    pub access_token: String,
175    pub id_token: String,
176    pub token_type: String,
177    pub expires_in: u64,
178}
179
180#[derive(serde::Deserialize, Debug)]
181pub struct OpenIdResponse {
182    pub iss: String,
183    pub sub: String,
184    pub nickname: Option<String>,
185    pub picture: Option<String>,
186    pub email: Option<String>,
187}