// robinpath_modules/modules/oauth_mod.rs
use std::collections::HashMap;
use std::sync::{LazyLock, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};

use robinpath::{RobinPath, Value};

/// A token saved under a caller-chosen name in the `TOKENS` registry.
struct TokenStore {
    /// The access token string.
    access_token: String,
    /// Token type; the builtins that create entries default it to "Bearer".
    token_type: String,
    /// Refresh token, if one was supplied or issued.
    refresh_token: Option<String>,
    /// Absolute expiry in milliseconds since the Unix epoch.
    /// `None` means the token is never reported as expired.
    expires_at: Option<u64>,
    /// Granted scope, if one was supplied or reported.
    scope: Option<String>,
}

13static TOKENS: LazyLock<Mutex<HashMap<String, TokenStore>>> =
14 LazyLock::new(|| Mutex::new(HashMap::new()));
15
16pub fn register(rp: &mut RobinPath) {
17 rp.register_builtin("oauth.authUrl", |args, _| {
19 let base_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
20 let opts = args.get(1).cloned().unwrap_or(Value::Null);
21 if base_url.is_empty() {
22 return Err("Authorization URL is required".to_string());
23 }
24 let mut params = Vec::new();
25 if let Value::Object(obj) = &opts {
26 let response_type = obj.get("responseType")
27 .map(|v| v.to_display_string())
28 .unwrap_or_else(|| "code".to_string());
29 params.push(format!("response_type={}", url_encode(&response_type)));
30 if let Some(v) = obj.get("clientId") {
31 params.push(format!("client_id={}", url_encode(&v.to_display_string())));
32 }
33 if let Some(v) = obj.get("redirectUri") {
34 params.push(format!("redirect_uri={}", url_encode(&v.to_display_string())));
35 }
36 if let Some(v) = obj.get("scope") {
37 params.push(format!("scope={}", url_encode(&v.to_display_string())));
38 }
39 if let Some(v) = obj.get("state") {
40 params.push(format!("state={}", url_encode(&v.to_display_string())));
41 }
42 if let Some(v) = obj.get("codeChallenge") {
43 params.push(format!("code_challenge={}", url_encode(&v.to_display_string())));
44 }
45 if let Some(v) = obj.get("codeChallengeMethod") {
46 params.push(format!("code_challenge_method={}", url_encode(&v.to_display_string())));
47 }
48 if let Some(v) = obj.get("accessType") {
49 params.push(format!("access_type={}", url_encode(&v.to_display_string())));
50 }
51 if let Some(v) = obj.get("prompt") {
52 params.push(format!("prompt={}", url_encode(&v.to_display_string())));
53 }
54 }
55 let sep = if base_url.contains('?') { "&" } else { "?" };
56 Ok(Value::String(format!("{}{}{}", base_url, sep, params.join("&"))))
57 });
58
59 rp.register_builtin("oauth.pkceVerifier", |args, _| {
61 let length = args.first().map(|v| v.to_number() as usize).unwrap_or(64);
62 let clamped = length.max(43).min(128);
63 let bytes: Vec<u8> = random_bytes(clamped);
64 let encoded = base64url_encode(&bytes);
65 Ok(Value::String(encoded[..clamped.min(encoded.len())].to_string()))
66 });
67
68 rp.register_builtin("oauth.pkceChallenge", |args, _| {
70 let verifier = args.first().map(|v| v.to_display_string()).unwrap_or_default();
71 let method = args.get(1).map(|v| v.to_display_string()).unwrap_or_else(|| "S256".to_string());
72 let mut obj = indexmap::IndexMap::new();
73 if method == "S256" {
74 use sha2::Digest;
75 let hash = sha2::Sha256::digest(verifier.as_bytes());
76 let challenge = base64url_encode(&hash);
77 obj.insert("challenge".to_string(), Value::String(challenge));
78 obj.insert("method".to_string(), Value::String("S256".to_string()));
79 } else if method == "plain" {
80 obj.insert("challenge".to_string(), Value::String(verifier));
81 obj.insert("method".to_string(), Value::String("plain".to_string()));
82 } else {
83 return Err(format!("Unsupported PKCE method: {}. Use \"S256\" or \"plain\".", method));
84 }
85 Ok(Value::Object(obj))
86 });
87
88 rp.register_builtin("oauth.generateState", |args, _| {
90 let length = args.first().map(|v| v.to_number() as usize).unwrap_or(32);
91 let bytes = random_bytes(length);
92 let hex: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
93 Ok(Value::String(hex))
94 });
95
96 rp.register_builtin("oauth.storeToken", |args, _| {
98 let name = args.first().map(|v| v.to_display_string()).unwrap_or_default();
99 let data = args.get(1).cloned().unwrap_or(Value::Null);
100 if let Value::Object(obj) = &data {
101 let store = TokenStore {
102 access_token: obj.get("accessToken").map(|v| v.to_display_string()).unwrap_or_default(),
103 refresh_token: obj.get("refreshToken").map(|v| v.to_display_string()),
104 token_type: obj.get("tokenType").map(|v| v.to_display_string()).unwrap_or_else(|| "Bearer".to_string()),
105 expires_at: obj.get("expiresAt").map(|v| v.to_number() as u64),
106 scope: obj.get("scope").map(|v| v.to_display_string()),
107 };
108 TOKENS.lock().unwrap().insert(name, store);
109 Ok(Value::Bool(true))
110 } else {
111 Err("Token data must be an object".to_string())
112 }
113 });
114
115 rp.register_builtin("oauth.getToken", |args, _| {
117 let name = args.first().map(|v| v.to_display_string()).unwrap_or_default();
118 let tokens = TOKENS.lock().unwrap();
119 match tokens.get(&name) {
120 Some(store) => {
121 let mut obj = indexmap::IndexMap::new();
122 obj.insert("accessToken".to_string(), Value::String(store.access_token.clone()));
123 if let Some(ref rt) = store.refresh_token {
124 obj.insert("refreshToken".to_string(), Value::String(rt.clone()));
125 }
126 obj.insert("tokenType".to_string(), Value::String(store.token_type.clone()));
127 let now_ms = std::time::SystemTime::now()
128 .duration_since(std::time::UNIX_EPOCH)
129 .unwrap_or_default()
130 .as_millis() as u64;
131 let expired = store.expires_at.map(|ea| now_ms > ea).unwrap_or(false);
132 obj.insert("expired".to_string(), Value::Bool(expired));
133 if let Some(ea) = store.expires_at {
134 obj.insert("expiresAt".to_string(), Value::Number(ea as f64));
135 }
136 if let Some(ref s) = store.scope {
137 obj.insert("scope".to_string(), Value::String(s.clone()));
138 }
139 Ok(Value::Object(obj))
140 }
141 None => Ok(Value::Null),
142 }
143 });
144
145 rp.register_builtin("oauth.isExpired", |args, _| {
147 let name = args.first().map(|v| v.to_display_string()).unwrap_or_default();
148 let buffer_ms = args.get(1).map(|v| v.to_number() as u64).unwrap_or(60000);
149 let tokens = TOKENS.lock().unwrap();
150 match tokens.get(&name) {
151 Some(store) => {
152 if let Some(ea) = store.expires_at {
153 let now_ms = std::time::SystemTime::now()
154 .duration_since(std::time::UNIX_EPOCH)
155 .unwrap_or_default()
156 .as_millis() as u64;
157 Ok(Value::Bool(now_ms + buffer_ms > ea))
158 } else {
159 Ok(Value::Bool(false))
160 }
161 }
162 None => Ok(Value::Bool(true)),
163 }
164 });
165
166 rp.register_builtin("oauth.clearTokens", |args, _| {
168 let name = args.first().map(|v| v.to_display_string());
169 let mut tokens = TOKENS.lock().unwrap();
170 if let Some(n) = name {
171 Ok(Value::Bool(tokens.remove(&n).is_some()))
172 } else {
173 tokens.clear();
174 Ok(Value::Bool(true))
175 }
176 });
177
178 #[cfg(feature = "api")]
180 rp.register_builtin("oauth.exchangeCode", |args, _| {
181 let token_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
182 let opts = args.get(1).cloned().unwrap_or(Value::Null);
183 if token_url.is_empty() {
184 return Err("Token URL is required".to_string());
185 }
186 let mut params = Vec::new();
187 params.push(("grant_type".to_string(), "authorization_code".to_string()));
188 if let Value::Object(obj) = &opts {
189 if let Some(v) = obj.get("code") { params.push(("code".to_string(), v.to_display_string())); }
190 if let Some(v) = obj.get("clientId") { params.push(("client_id".to_string(), v.to_display_string())); }
191 if let Some(v) = obj.get("clientSecret") { params.push(("client_secret".to_string(), v.to_display_string())); }
192 if let Some(v) = obj.get("redirectUri") { params.push(("redirect_uri".to_string(), v.to_display_string())); }
193 if let Some(v) = obj.get("codeVerifier") { params.push(("code_verifier".to_string(), v.to_display_string())); }
194 }
195 do_token_request(&token_url, ¶ms, &opts)
196 });
197
198 #[cfg(feature = "api")]
200 rp.register_builtin("oauth.refreshToken", |args, _| {
201 let token_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
202 let opts = args.get(1).cloned().unwrap_or(Value::Null);
203 if token_url.is_empty() {
204 return Err("Token URL is required".to_string());
205 }
206 let refresh = if let Value::Object(obj) = &opts {
207 obj.get("refreshToken").map(|v| v.to_display_string())
208 .or_else(|| {
209 obj.get("name").and_then(|n| {
210 TOKENS.lock().unwrap().get(&n.to_display_string())
211 .and_then(|s| s.refresh_token.clone())
212 })
213 })
214 } else {
215 None
216 };
217 let refresh = refresh.ok_or("Refresh token is required")?;
218 let mut params = Vec::new();
219 params.push(("grant_type".to_string(), "refresh_token".to_string()));
220 params.push(("refresh_token".to_string(), refresh));
221 if let Value::Object(obj) = &opts {
222 if let Some(v) = obj.get("clientId") { params.push(("client_id".to_string(), v.to_display_string())); }
223 if let Some(v) = obj.get("clientSecret") { params.push(("client_secret".to_string(), v.to_display_string())); }
224 if let Some(v) = obj.get("scope") { params.push(("scope".to_string(), v.to_display_string())); }
225 }
226 do_token_request(&token_url, ¶ms, &opts)
227 });
228
229 #[cfg(feature = "api")]
231 rp.register_builtin("oauth.clientCredentials", |args, _| {
232 let token_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
233 let opts = args.get(1).cloned().unwrap_or(Value::Null);
234 if token_url.is_empty() {
235 return Err("Token URL is required".to_string());
236 }
237 let mut params = Vec::new();
238 params.push(("grant_type".to_string(), "client_credentials".to_string()));
239 if let Value::Object(obj) = &opts {
240 if let Some(v) = obj.get("clientId") { params.push(("client_id".to_string(), v.to_display_string())); }
241 if let Some(v) = obj.get("clientSecret") { params.push(("client_secret".to_string(), v.to_display_string())); }
242 if let Some(v) = obj.get("scope") { params.push(("scope".to_string(), v.to_display_string())); }
243 }
244 do_token_request(&token_url, ¶ms, &opts)
245 });
246}
247
/// POSTs `params` as a URL-encoded form to the token endpoint `url` and
/// converts the JSON reply into a script object.
///
/// The object always has `accessToken` and `tokenType` (default `"Bearer"`).
/// It has `refreshToken`, `expiresIn` (seconds) and `scope` when the server
/// sent them.
///
/// If `opts` is an object with a `name` key, the token is also saved in
/// `TOKENS` under that name. Its `expires_at` is now + `expires_in`, in
/// milliseconds since the Unix epoch.
///
/// # Errors
/// Returns an error string when:
/// - the request or body read fails;
/// - the body is not JSON (reported with the HTTP status on >= 400);
/// - the server reports an error, either via HTTP status >= 400 or via an
///   `error` field in a 2xx body;
/// - `access_token` is missing or empty.
#[cfg(feature = "api")]
fn do_token_request(url: &str, params: &[(String, String)], opts: &Value) -> Result<Value, String> {
    let resp = reqwest::blocking::Client::new()
        .post(url)
        // Some providers reply form-encoded unless JSON is asked for explicitly.
        // `.form()` sets the `application/x-www-form-urlencoded` Content-Type itself.
        .header("Accept", "application/json")
        .form(params)
        .send()
        .map_err(|e| format!("OAuth request error: {}", e))?;

    let status = resp.status().as_u16();
    let body_text = resp.text().map_err(|e| format!("body error: {}", e))?;
    let json: serde_json::Value = match serde_json::from_str(&body_text) {
        Ok(json) => json,
        // A non-JSON error page is still an HTTP failure; keep the status visible.
        Err(_) if status >= 400 => {
            return Err(format!("OAuth error: HTTP {} - {}", status, body_text))
        }
        Err(_) => return Err(format!("Invalid JSON response: {}", body_text)),
    };

    // Some providers signal failures (e.g. a bad code) with HTTP 200 plus an
    // `error` field, so check the body as well as the status.
    let has_error_field = json.get("error").is_some_and(|e| !e.is_null());
    if status >= 400 || has_error_field {
        let error = json.get("error").and_then(|v| v.as_str()).unwrap_or("unknown");
        let desc = json
            .get("error_description")
            .and_then(|v| v.as_str())
            .unwrap_or("");
        return Err(format!("OAuth error: {} - {}", error, desc));
    }

    let access_token = json
        .get("access_token")
        .and_then(|v| v.as_str())
        .filter(|s| !s.is_empty())
        .ok_or_else(|| "OAuth response is missing access_token".to_string())?
        .to_string();
    let refresh_token = json
        .get("refresh_token")
        .and_then(|v| v.as_str())
        .map(str::to_string);
    let token_type = json
        .get("token_type")
        .and_then(|v| v.as_str())
        .unwrap_or("Bearer")
        .to_string();
    // RFC 6749 §5.1 specifies a number, but some servers send a numeric string.
    let expires_in = json.get("expires_in").and_then(|v| {
        v.as_u64()
            .or_else(|| v.as_str().and_then(|s| s.trim().parse().ok()))
    });
    let scope = json.get("scope").and_then(|v| v.as_str()).map(str::to_string);

    if let Value::Object(obj) = opts {
        if let Some(name_val) = obj.get("name") {
            let now_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_millis() as u64;
            let store = TokenStore {
                access_token: access_token.clone(),
                refresh_token: refresh_token.clone(),
                token_type: token_type.clone(),
                // Saturate so that an absurd `expires_in` cannot overflow.
                expires_at: expires_in.map(|secs| now_ms.saturating_add(secs.saturating_mul(1000))),
                scope: scope.clone(),
            };
            TOKENS
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .insert(name_val.to_display_string(), store);
        }
    }

    let mut result = indexmap::IndexMap::new();
    result.insert("accessToken".to_string(), Value::String(access_token));
    if let Some(rt) = refresh_token {
        result.insert("refreshToken".to_string(), Value::String(rt));
    }
    result.insert("tokenType".to_string(), Value::String(token_type));
    if let Some(ei) = expires_in {
        result.insert("expiresIn".to_string(), Value::Number(ei as f64));
    }
    if let Some(s) = scope {
        result.insert("scope".to_string(), Value::String(s));
    }
    Ok(Value::Object(result))
}

/// Returns `n` random bytes, used for PKCE verifiers and OAuth `state` values.
///
/// Those values must be unguessable (RFC 7636 §7.1, RFC 6749 §10.10).
/// On Unix-like systems the bytes therefore come from the OS CSPRNG via
/// `/dev/urandom`.
///
/// Elsewhere, or if that read fails, they come from the std default hasher.
/// One `RandomState` (which std initializes with random keys) hashes a running
/// counter, and all 8 bytes of each hash are used. std gives no cryptographic
/// guarantee for that fallback; a dedicated crate such as `getrandom` would be
/// preferable if it becomes available to this module.
fn random_bytes(n: usize) -> Vec<u8> {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    #[cfg(unix)]
    {
        use std::io::Read;
        let mut buf = vec![0u8; n];
        if std::fs::File::open("/dev/urandom")
            .and_then(|mut f| f.read_exact(&mut buf))
            .is_ok()
        {
            return buf;
        }
    }

    let state = RandomState::new();
    let mut bytes = Vec::with_capacity(n);
    let mut counter: u64 = 0;
    while bytes.len() < n {
        let mut hasher = state.build_hasher();
        hasher.write_u64(counter);
        counter += 1;
        let word = hasher.finish().to_le_bytes();
        let take = (n - bytes.len()).min(word.len());
        bytes.extend_from_slice(&word[..take]);
    }
    bytes
}

321fn base64url_encode(data: &[u8]) -> String {
322 use base64::Engine;
323 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
324}
325
/// Percent-encodes `s` for use in a URL query component.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through
/// unchanged. Every other UTF-8 byte becomes `%XX` with uppercase hex digits.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    s.bytes().fold(String::with_capacity(s.len()), |mut out, byte| {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
        out
    })
}