robinpath_modules/modules/
oauth_mod.rs1use robinpath::{RobinPath, Value};
2use std::collections::HashMap;
3use std::sync::{Arc, Mutex};
4
/// One named token set held in the module's shared, in-memory token map.
struct TokenStore {
    /// The access token string.
    access_token: String,
    /// Refresh token, when one was provided.
    refresh_token: Option<String>,
    /// Token type as reported or supplied (callers default it to "Bearer").
    token_type: String,
    /// Absolute expiry in milliseconds since the Unix epoch, when known.
    expires_at: Option<u64>,
    /// Granted scope string, when provided.
    scope: Option<String>,
}
12
13pub fn register(rp: &mut RobinPath) {
14 let state = Arc::new(Mutex::new(HashMap::<String, TokenStore>::new()));
15
16 rp.register_builtin("oauth.authUrl", |args, _| {
18 let base_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
19 let opts = args.get(1).cloned().unwrap_or(Value::Null);
20 if base_url.is_empty() {
21 return Err("Authorization URL is required".to_string());
22 }
23 let mut params = Vec::new();
24 if let Value::Object(obj) = &opts {
25 let response_type = obj.get("responseType")
26 .map(|v| v.to_display_string())
27 .unwrap_or_else(|| "code".to_string());
28 params.push(format!("response_type={}", url_encode(&response_type)));
29 if let Some(v) = obj.get("clientId") {
30 params.push(format!("client_id={}", url_encode(&v.to_display_string())));
31 }
32 if let Some(v) = obj.get("redirectUri") {
33 params.push(format!("redirect_uri={}", url_encode(&v.to_display_string())));
34 }
35 if let Some(v) = obj.get("scope") {
36 params.push(format!("scope={}", url_encode(&v.to_display_string())));
37 }
38 if let Some(v) = obj.get("state") {
39 params.push(format!("state={}", url_encode(&v.to_display_string())));
40 }
41 if let Some(v) = obj.get("codeChallenge") {
42 params.push(format!("code_challenge={}", url_encode(&v.to_display_string())));
43 }
44 if let Some(v) = obj.get("codeChallengeMethod") {
45 params.push(format!("code_challenge_method={}", url_encode(&v.to_display_string())));
46 }
47 if let Some(v) = obj.get("accessType") {
48 params.push(format!("access_type={}", url_encode(&v.to_display_string())));
49 }
50 if let Some(v) = obj.get("prompt") {
51 params.push(format!("prompt={}", url_encode(&v.to_display_string())));
52 }
53 }
54 let sep = if base_url.contains('?') { "&" } else { "?" };
55 Ok(Value::String(format!("{}{}{}", base_url, sep, params.join("&"))))
56 });
57
58 rp.register_builtin("oauth.pkceVerifier", |args, _| {
60 let length = args.first().map(|v| v.to_number() as usize).unwrap_or(64);
61 let clamped = length.max(43).min(128);
62 let bytes: Vec<u8> = random_bytes(clamped);
63 let encoded = base64url_encode(&bytes);
64 Ok(Value::String(encoded[..clamped.min(encoded.len())].to_string()))
65 });
66
67 rp.register_builtin("oauth.pkceChallenge", |args, _| {
69 let verifier = args.first().map(|v| v.to_display_string()).unwrap_or_default();
70 let method = args.get(1).map(|v| v.to_display_string()).unwrap_or_else(|| "S256".to_string());
71 let mut obj = indexmap::IndexMap::new();
72 if method == "S256" {
73 use sha2::Digest;
74 let hash = sha2::Sha256::digest(verifier.as_bytes());
75 let challenge = base64url_encode(&hash);
76 obj.insert("challenge".to_string(), Value::String(challenge));
77 obj.insert("method".to_string(), Value::String("S256".to_string()));
78 } else if method == "plain" {
79 obj.insert("challenge".to_string(), Value::String(verifier));
80 obj.insert("method".to_string(), Value::String("plain".to_string()));
81 } else {
82 return Err(format!("Unsupported PKCE method: {}. Use \"S256\" or \"plain\".", method));
83 }
84 Ok(Value::Object(obj))
85 });
86
87 rp.register_builtin("oauth.generateState", |args, _| {
89 let length = args.first().map(|v| v.to_number() as usize).unwrap_or(32);
90 let bytes = random_bytes(length);
91 let hex: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
92 Ok(Value::String(hex))
93 });
94
95 let s = state.clone();
97 rp.register_builtin("oauth.storeToken", move |args, _| {
98 let name = args.first().map(|v| v.to_display_string()).unwrap_or_default();
99 let data = args.get(1).cloned().unwrap_or(Value::Null);
100 if let Value::Object(obj) = &data {
101 let store = TokenStore {
102 access_token: obj.get("accessToken").map(|v| v.to_display_string()).unwrap_or_default(),
103 refresh_token: obj.get("refreshToken").map(|v| v.to_display_string()),
104 token_type: obj.get("tokenType").map(|v| v.to_display_string()).unwrap_or_else(|| "Bearer".to_string()),
105 expires_at: obj.get("expiresAt").map(|v| v.to_number() as u64),
106 scope: obj.get("scope").map(|v| v.to_display_string()),
107 };
108 s.lock().unwrap().insert(name, store);
109 Ok(Value::Bool(true))
110 } else {
111 Err("Token data must be an object".to_string())
112 }
113 });
114
115 let s = state.clone();
117 rp.register_builtin("oauth.getToken", move |args, _| {
118 let name = args.first().map(|v| v.to_display_string()).unwrap_or_default();
119 let tokens = s.lock().unwrap();
120 match tokens.get(&name) {
121 Some(store) => {
122 let mut obj = indexmap::IndexMap::new();
123 obj.insert("accessToken".to_string(), Value::String(store.access_token.clone()));
124 if let Some(ref rt) = store.refresh_token {
125 obj.insert("refreshToken".to_string(), Value::String(rt.clone()));
126 }
127 obj.insert("tokenType".to_string(), Value::String(store.token_type.clone()));
128 let now_ms = std::time::SystemTime::now()
129 .duration_since(std::time::UNIX_EPOCH)
130 .unwrap_or_default()
131 .as_millis() as u64;
132 let expired = store.expires_at.map(|ea| now_ms > ea).unwrap_or(false);
133 obj.insert("expired".to_string(), Value::Bool(expired));
134 if let Some(ea) = store.expires_at {
135 obj.insert("expiresAt".to_string(), Value::Number(ea as f64));
136 }
137 if let Some(ref sc) = store.scope {
138 obj.insert("scope".to_string(), Value::String(sc.clone()));
139 }
140 Ok(Value::Object(obj))
141 }
142 None => Ok(Value::Null),
143 }
144 });
145
146 let s = state.clone();
148 rp.register_builtin("oauth.isExpired", move |args, _| {
149 let name = args.first().map(|v| v.to_display_string()).unwrap_or_default();
150 let buffer_ms = args.get(1).map(|v| v.to_number() as u64).unwrap_or(60000);
151 let tokens = s.lock().unwrap();
152 match tokens.get(&name) {
153 Some(store) => {
154 if let Some(ea) = store.expires_at {
155 let now_ms = std::time::SystemTime::now()
156 .duration_since(std::time::UNIX_EPOCH)
157 .unwrap_or_default()
158 .as_millis() as u64;
159 Ok(Value::Bool(now_ms + buffer_ms > ea))
160 } else {
161 Ok(Value::Bool(false))
162 }
163 }
164 None => Ok(Value::Bool(true)),
165 }
166 });
167
168 let s = state.clone();
170 rp.register_builtin("oauth.clearTokens", move |args, _| {
171 let name = args.first().map(|v| v.to_display_string());
172 let mut tokens = s.lock().unwrap();
173 if let Some(n) = name {
174 Ok(Value::Bool(tokens.remove(&n).is_some()))
175 } else {
176 tokens.clear();
177 Ok(Value::Bool(true))
178 }
179 });
180
181 #[cfg(feature = "api")]
183 {
184 let s = state.clone();
185 rp.register_builtin("oauth.exchangeCode", move |args, _| {
186 let token_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
187 let opts = args.get(1).cloned().unwrap_or(Value::Null);
188 if token_url.is_empty() {
189 return Err("Token URL is required".to_string());
190 }
191 let mut params = Vec::new();
192 params.push(("grant_type".to_string(), "authorization_code".to_string()));
193 if let Value::Object(obj) = &opts {
194 if let Some(v) = obj.get("code") { params.push(("code".to_string(), v.to_display_string())); }
195 if let Some(v) = obj.get("clientId") { params.push(("client_id".to_string(), v.to_display_string())); }
196 if let Some(v) = obj.get("clientSecret") { params.push(("client_secret".to_string(), v.to_display_string())); }
197 if let Some(v) = obj.get("redirectUri") { params.push(("redirect_uri".to_string(), v.to_display_string())); }
198 if let Some(v) = obj.get("codeVerifier") { params.push(("code_verifier".to_string(), v.to_display_string())); }
199 }
200 do_token_request(&token_url, ¶ms, &opts, &s)
201 });
202 }
203
204 #[cfg(feature = "api")]
206 {
207 let s = state.clone();
208 rp.register_builtin("oauth.refreshToken", move |args, _| {
209 let token_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
210 let opts = args.get(1).cloned().unwrap_or(Value::Null);
211 if token_url.is_empty() {
212 return Err("Token URL is required".to_string());
213 }
214 let refresh = if let Value::Object(obj) = &opts {
215 obj.get("refreshToken").map(|v| v.to_display_string())
216 .or_else(|| {
217 obj.get("name").and_then(|n| {
218 s.lock().unwrap().get(&n.to_display_string())
219 .and_then(|st| st.refresh_token.clone())
220 })
221 })
222 } else {
223 None
224 };
225 let refresh = refresh.ok_or("Refresh token is required")?;
226 let mut params = Vec::new();
227 params.push(("grant_type".to_string(), "refresh_token".to_string()));
228 params.push(("refresh_token".to_string(), refresh));
229 if let Value::Object(obj) = &opts {
230 if let Some(v) = obj.get("clientId") { params.push(("client_id".to_string(), v.to_display_string())); }
231 if let Some(v) = obj.get("clientSecret") { params.push(("client_secret".to_string(), v.to_display_string())); }
232 if let Some(v) = obj.get("scope") { params.push(("scope".to_string(), v.to_display_string())); }
233 }
234 do_token_request(&token_url, ¶ms, &opts, &s)
235 });
236 }
237
238 #[cfg(feature = "api")]
240 {
241 let s = state.clone();
242 rp.register_builtin("oauth.clientCredentials", move |args, _| {
243 let token_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
244 let opts = args.get(1).cloned().unwrap_or(Value::Null);
245 if token_url.is_empty() {
246 return Err("Token URL is required".to_string());
247 }
248 let mut params = Vec::new();
249 params.push(("grant_type".to_string(), "client_credentials".to_string()));
250 if let Value::Object(obj) = &opts {
251 if let Some(v) = obj.get("clientId") { params.push(("client_id".to_string(), v.to_display_string())); }
252 if let Some(v) = obj.get("clientSecret") { params.push(("client_secret".to_string(), v.to_display_string())); }
253 if let Some(v) = obj.get("scope") { params.push(("scope".to_string(), v.to_display_string())); }
254 }
255 do_token_request(&token_url, ¶ms, &opts, &s)
256 });
257 }
258}
259
/// POSTs `params` form-encoded to an OAuth 2.0 token endpoint and returns
/// the token response as a `Value::Object` with `accessToken`,
/// `refreshToken?`, `tokenType`, `expiresIn?` and `scope?`.
///
/// If `opts` is an object with a `name` key, the token is also saved in
/// `state` under that name, with `expires_at` converted to absolute Unix
/// milliseconds. When a refresh-grant response omits `refresh_token`, the
/// refresh token that was sent is kept, because RFC 6749 §6 says the old one
/// stays valid unless a new one is issued.
///
/// # Errors
/// Returns `Err` on transport failure, an unreadable body, or HTTP status
/// >= 400. In that last case the message uses the RFC 6749 `error` and
/// `error_description` fields when the body is JSON, and the raw status and
/// body otherwise. Also returns `Err` for a non-JSON success body or a
/// success body without `access_token`.
#[cfg(feature = "api")]
fn do_token_request(
    url: &str,
    params: &[(String, String)],
    opts: &Value,
    state: &Arc<Mutex<HashMap<String, TokenStore>>>,
) -> Result<Value, String> {
    let client = reqwest::blocking::Client::new();
    // `form` sets the urlencoded Content-Type itself. Ask for JSON explicitly,
    // because some providers reply form-encoded unless told otherwise.
    let resp = client
        .post(url)
        .header("Accept", "application/json")
        .form(params)
        .send()
        .map_err(|e| format!("OAuth request error: {}", e))?;

    let status = resp.status().as_u16();
    let body_text = resp.text().map_err(|e| format!("body error: {}", e))?;
    let parsed = serde_json::from_str::<serde_json::Value>(&body_text);

    if status >= 400 {
        // Error bodies from gateways or proxies are often not JSON, so still
        // report the HTTP failure rather than a parse error.
        return Err(match parsed {
            Ok(json) => {
                let error = json.get("error").and_then(|v| v.as_str()).unwrap_or("unknown");
                let desc = json
                    .get("error_description")
                    .and_then(|v| v.as_str())
                    .unwrap_or("");
                format!("OAuth error: {} - {}", error, desc)
            }
            Err(_) => format!("OAuth error: HTTP {} - {}", status, body_text),
        });
    }
    let json = parsed.map_err(|_| format!("Invalid JSON response: {}", body_text))?;

    // access_token is REQUIRED in a successful response (RFC 6749 §5.1).
    let access_token = match json.get("access_token").and_then(|v| v.as_str()) {
        Some(tok) if !tok.is_empty() => tok.to_string(),
        _ => return Err("OAuth response is missing access_token".to_string()),
    };
    let refresh_token = json
        .get("refresh_token")
        .and_then(|v| v.as_str())
        .map(|s| s.to_string());
    let token_type = json
        .get("token_type")
        .and_then(|v| v.as_str())
        .unwrap_or("Bearer")
        .to_string();
    // Normally an integer, but some providers send a float or a numeric string.
    let expires_in = json.get("expires_in").and_then(|v| {
        v.as_u64()
            .or_else(|| v.as_f64().filter(|f| f.is_finite() && *f >= 0.0).map(|f| f as u64))
            .or_else(|| v.as_str().and_then(|s| s.trim().parse::<u64>().ok()))
    });
    let scope = json.get("scope").and_then(|v| v.as_str()).map(|s| s.to_string());

    if let Value::Object(obj) = opts {
        if let Some(name_val) = obj.get("name") {
            let now_ms = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_millis() as u64;
            // Only the refresh grant sends a `refresh_token` param. If the
            // server did not rotate it, keep that one instead of dropping it.
            let stored_refresh = refresh_token.clone().or_else(|| {
                params
                    .iter()
                    .find(|(k, _)| k == "refresh_token")
                    .map(|(_, v)| v.clone())
            });
            let store = TokenStore {
                access_token: access_token.clone(),
                refresh_token: stored_refresh,
                token_type: token_type.clone(),
                // Saturating math so an absurd expires_in cannot overflow.
                expires_at: expires_in.map(|ei| now_ms.saturating_add(ei.saturating_mul(1000))),
                scope: scope.clone(),
            };
            state.lock().unwrap().insert(name_val.to_display_string(), store);
        }
    }

    let mut result = indexmap::IndexMap::new();
    result.insert("accessToken".to_string(), Value::String(access_token));
    if let Some(rt) = refresh_token {
        result.insert("refreshToken".to_string(), Value::String(rt));
    }
    result.insert("tokenType".to_string(), Value::String(token_type));
    if let Some(ei) = expires_in {
        result.insert("expiresIn".to_string(), Value::Number(ei as f64));
    }
    if let Some(s) = scope {
        result.insert("scope".to_string(), Value::String(s));
    }
    Ok(Value::Object(result))
}
324
/// Returns `n` random bytes for security-sensitive values (PKCE verifiers,
/// OAuth `state`).
///
/// On Unix the bytes are read from the OS CSPRNG via `/dev/urandom`. If that
/// fails, or on non-Unix targets, the function falls back to hashes from
/// freshly keyed `RandomState` hashers. That fallback is not a documented
/// CSPRNG. NOTE(review): use a proper OS RNG on non-Unix targets if they matter.
fn random_bytes(n: usize) -> Vec<u8> {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let mut bytes = vec![0u8; n];

    #[cfg(unix)]
    {
        use std::io::Read;
        let filled = std::fs::File::open("/dev/urandom")
            .and_then(|mut dev| dev.read_exact(&mut bytes))
            .is_ok();
        if filled {
            return bytes;
        }
    }

    // Fallback: each `RandomState::new()` is initialized with random keys.
    // Use all 8 bytes of every 64-bit hash rather than only the low byte.
    for (i, chunk) in bytes.chunks_mut(8).enumerate() {
        let mut hasher = RandomState::new().build_hasher();
        hasher.write_usize(i);
        let word = hasher.finish().to_le_bytes();
        chunk.copy_from_slice(&word[..chunk.len()]);
    }
    bytes
}
337
338fn base64url_encode(data: &[u8]) -> String {
339 use base64::Engine;
340 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
341}
342
/// Percent-encodes `s` byte-by-byte for use in a URL query component.
///
/// Bytes in the RFC 3986 unreserved set (`A-Z a-z 0-9 - _ . ~`) are kept
/// as-is. Every other byte of the UTF-8 encoding becomes `%XX` with
/// uppercase hex digits.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    s.bytes().fold(String::with_capacity(s.len()), |mut out, byte| {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
        out
    })
}