Skip to main content

robinpath_modules/modules/
oauth_mod.rs

1use robinpath::{RobinPath, Value};
2use std::collections::HashMap;
3use std::sync::{Arc, Mutex};
4
/// A token set remembered under a caller-chosen name in the module's shared map.
struct TokenStore {
    /// Token type as reported or supplied; defaults to `"Bearer"` when absent.
    token_type: String,
    /// The access token string (empty if none was supplied to `oauth.storeToken`).
    access_token: String,
    /// Refresh token, if one was supplied or issued.
    refresh_token: Option<String>,
    /// Granted scope, if known.
    scope: Option<String>,
    /// Absolute expiry in milliseconds since the Unix epoch; `None` means "never expires".
    expires_at: Option<u64>,
}
12
13pub fn register(rp: &mut RobinPath) {
14    let state = Arc::new(Mutex::new(HashMap::<String, TokenStore>::new()));
15
16    // oauth.authUrl baseUrl options → URL string
17    rp.register_builtin("oauth.authUrl", |args, _| {
18        let base_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
19        let opts = args.get(1).cloned().unwrap_or(Value::Null);
20        if base_url.is_empty() {
21            return Err("Authorization URL is required".to_string());
22        }
23        let mut params = Vec::new();
24        if let Value::Object(obj) = &opts {
25            let response_type = obj.get("responseType")
26                .map(|v| v.to_display_string())
27                .unwrap_or_else(|| "code".to_string());
28            params.push(format!("response_type={}", url_encode(&response_type)));
29            if let Some(v) = obj.get("clientId") {
30                params.push(format!("client_id={}", url_encode(&v.to_display_string())));
31            }
32            if let Some(v) = obj.get("redirectUri") {
33                params.push(format!("redirect_uri={}", url_encode(&v.to_display_string())));
34            }
35            if let Some(v) = obj.get("scope") {
36                params.push(format!("scope={}", url_encode(&v.to_display_string())));
37            }
38            if let Some(v) = obj.get("state") {
39                params.push(format!("state={}", url_encode(&v.to_display_string())));
40            }
41            if let Some(v) = obj.get("codeChallenge") {
42                params.push(format!("code_challenge={}", url_encode(&v.to_display_string())));
43            }
44            if let Some(v) = obj.get("codeChallengeMethod") {
45                params.push(format!("code_challenge_method={}", url_encode(&v.to_display_string())));
46            }
47            if let Some(v) = obj.get("accessType") {
48                params.push(format!("access_type={}", url_encode(&v.to_display_string())));
49            }
50            if let Some(v) = obj.get("prompt") {
51                params.push(format!("prompt={}", url_encode(&v.to_display_string())));
52            }
53        }
54        let sep = if base_url.contains('?') { "&" } else { "?" };
55        Ok(Value::String(format!("{}{}{}", base_url, sep, params.join("&"))))
56    });
57
58    // oauth.pkceVerifier length? → random base64url string
59    rp.register_builtin("oauth.pkceVerifier", |args, _| {
60        let length = args.first().map(|v| v.to_number() as usize).unwrap_or(64);
61        let clamped = length.max(43).min(128);
62        let bytes: Vec<u8> = random_bytes(clamped);
63        let encoded = base64url_encode(&bytes);
64        Ok(Value::String(encoded[..clamped.min(encoded.len())].to_string()))
65    });
66
67    // oauth.pkceChallenge verifier method? → {challenge, method}
68    rp.register_builtin("oauth.pkceChallenge", |args, _| {
69        let verifier = args.first().map(|v| v.to_display_string()).unwrap_or_default();
70        let method = args.get(1).map(|v| v.to_display_string()).unwrap_or_else(|| "S256".to_string());
71        let mut obj = indexmap::IndexMap::new();
72        if method == "S256" {
73            use sha2::Digest;
74            let hash = sha2::Sha256::digest(verifier.as_bytes());
75            let challenge = base64url_encode(&hash);
76            obj.insert("challenge".to_string(), Value::String(challenge));
77            obj.insert("method".to_string(), Value::String("S256".to_string()));
78        } else if method == "plain" {
79            obj.insert("challenge".to_string(), Value::String(verifier));
80            obj.insert("method".to_string(), Value::String("plain".to_string()));
81        } else {
82            return Err(format!("Unsupported PKCE method: {}. Use \"S256\" or \"plain\".", method));
83        }
84        Ok(Value::Object(obj))
85    });
86
87    // oauth.generateState length? → random hex string
88    rp.register_builtin("oauth.generateState", |args, _| {
89        let length = args.first().map(|v| v.to_number() as usize).unwrap_or(32);
90        let bytes = random_bytes(length);
91        let hex: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
92        Ok(Value::String(hex))
93    });
94
95    // oauth.storeToken name tokenData → stores a token
96    let s = state.clone();
97    rp.register_builtin("oauth.storeToken", move |args, _| {
98        let name = args.first().map(|v| v.to_display_string()).unwrap_or_default();
99        let data = args.get(1).cloned().unwrap_or(Value::Null);
100        if let Value::Object(obj) = &data {
101            let store = TokenStore {
102                access_token: obj.get("accessToken").map(|v| v.to_display_string()).unwrap_or_default(),
103                refresh_token: obj.get("refreshToken").map(|v| v.to_display_string()),
104                token_type: obj.get("tokenType").map(|v| v.to_display_string()).unwrap_or_else(|| "Bearer".to_string()),
105                expires_at: obj.get("expiresAt").map(|v| v.to_number() as u64),
106                scope: obj.get("scope").map(|v| v.to_display_string()),
107            };
108            s.lock().unwrap().insert(name, store);
109            Ok(Value::Bool(true))
110        } else {
111            Err("Token data must be an object".to_string())
112        }
113    });
114
115    // oauth.getToken name → token object or null
116    let s = state.clone();
117    rp.register_builtin("oauth.getToken", move |args, _| {
118        let name = args.first().map(|v| v.to_display_string()).unwrap_or_default();
119        let tokens = s.lock().unwrap();
120        match tokens.get(&name) {
121            Some(store) => {
122                let mut obj = indexmap::IndexMap::new();
123                obj.insert("accessToken".to_string(), Value::String(store.access_token.clone()));
124                if let Some(ref rt) = store.refresh_token {
125                    obj.insert("refreshToken".to_string(), Value::String(rt.clone()));
126                }
127                obj.insert("tokenType".to_string(), Value::String(store.token_type.clone()));
128                let now_ms = std::time::SystemTime::now()
129                    .duration_since(std::time::UNIX_EPOCH)
130                    .unwrap_or_default()
131                    .as_millis() as u64;
132                let expired = store.expires_at.map(|ea| now_ms > ea).unwrap_or(false);
133                obj.insert("expired".to_string(), Value::Bool(expired));
134                if let Some(ea) = store.expires_at {
135                    obj.insert("expiresAt".to_string(), Value::Number(ea as f64));
136                }
137                if let Some(ref sc) = store.scope {
138                    obj.insert("scope".to_string(), Value::String(sc.clone()));
139                }
140                Ok(Value::Object(obj))
141            }
142            None => Ok(Value::Null),
143        }
144    });
145
146    // oauth.isExpired name bufferMs? → bool
147    let s = state.clone();
148    rp.register_builtin("oauth.isExpired", move |args, _| {
149        let name = args.first().map(|v| v.to_display_string()).unwrap_or_default();
150        let buffer_ms = args.get(1).map(|v| v.to_number() as u64).unwrap_or(60000);
151        let tokens = s.lock().unwrap();
152        match tokens.get(&name) {
153            Some(store) => {
154                if let Some(ea) = store.expires_at {
155                    let now_ms = std::time::SystemTime::now()
156                        .duration_since(std::time::UNIX_EPOCH)
157                        .unwrap_or_default()
158                        .as_millis() as u64;
159                    Ok(Value::Bool(now_ms + buffer_ms > ea))
160                } else {
161                    Ok(Value::Bool(false))
162                }
163            }
164            None => Ok(Value::Bool(true)),
165        }
166    });
167
168    // oauth.clearTokens name? → bool
169    let s = state.clone();
170    rp.register_builtin("oauth.clearTokens", move |args, _| {
171        let name = args.first().map(|v| v.to_display_string());
172        let mut tokens = s.lock().unwrap();
173        if let Some(n) = name {
174            Ok(Value::Bool(tokens.remove(&n).is_some()))
175        } else {
176            tokens.clear();
177            Ok(Value::Bool(true))
178        }
179    });
180
181    // oauth.exchangeCode tokenUrl options → {accessToken, ...} (requires reqwest)
182    #[cfg(feature = "api")]
183    {
184        let s = state.clone();
185        rp.register_builtin("oauth.exchangeCode", move |args, _| {
186            let token_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
187            let opts = args.get(1).cloned().unwrap_or(Value::Null);
188            if token_url.is_empty() {
189                return Err("Token URL is required".to_string());
190            }
191            let mut params = Vec::new();
192            params.push(("grant_type".to_string(), "authorization_code".to_string()));
193            if let Value::Object(obj) = &opts {
194                if let Some(v) = obj.get("code") { params.push(("code".to_string(), v.to_display_string())); }
195                if let Some(v) = obj.get("clientId") { params.push(("client_id".to_string(), v.to_display_string())); }
196                if let Some(v) = obj.get("clientSecret") { params.push(("client_secret".to_string(), v.to_display_string())); }
197                if let Some(v) = obj.get("redirectUri") { params.push(("redirect_uri".to_string(), v.to_display_string())); }
198                if let Some(v) = obj.get("codeVerifier") { params.push(("code_verifier".to_string(), v.to_display_string())); }
199            }
200            do_token_request(&token_url, &params, &opts, &s)
201        });
202    }
203
204    // oauth.refreshToken tokenUrl options → {accessToken, ...}
205    #[cfg(feature = "api")]
206    {
207        let s = state.clone();
208        rp.register_builtin("oauth.refreshToken", move |args, _| {
209            let token_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
210            let opts = args.get(1).cloned().unwrap_or(Value::Null);
211            if token_url.is_empty() {
212                return Err("Token URL is required".to_string());
213            }
214            let refresh = if let Value::Object(obj) = &opts {
215                obj.get("refreshToken").map(|v| v.to_display_string())
216                    .or_else(|| {
217                        obj.get("name").and_then(|n| {
218                            s.lock().unwrap().get(&n.to_display_string())
219                                .and_then(|st| st.refresh_token.clone())
220                        })
221                    })
222            } else {
223                None
224            };
225            let refresh = refresh.ok_or("Refresh token is required")?;
226            let mut params = Vec::new();
227            params.push(("grant_type".to_string(), "refresh_token".to_string()));
228            params.push(("refresh_token".to_string(), refresh));
229            if let Value::Object(obj) = &opts {
230                if let Some(v) = obj.get("clientId") { params.push(("client_id".to_string(), v.to_display_string())); }
231                if let Some(v) = obj.get("clientSecret") { params.push(("client_secret".to_string(), v.to_display_string())); }
232                if let Some(v) = obj.get("scope") { params.push(("scope".to_string(), v.to_display_string())); }
233            }
234            do_token_request(&token_url, &params, &opts, &s)
235        });
236    }
237
238    // oauth.clientCredentials tokenUrl options → {accessToken, ...}
239    #[cfg(feature = "api")]
240    {
241        let s = state.clone();
242        rp.register_builtin("oauth.clientCredentials", move |args, _| {
243            let token_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
244            let opts = args.get(1).cloned().unwrap_or(Value::Null);
245            if token_url.is_empty() {
246                return Err("Token URL is required".to_string());
247            }
248            let mut params = Vec::new();
249            params.push(("grant_type".to_string(), "client_credentials".to_string()));
250            if let Value::Object(obj) = &opts {
251                if let Some(v) = obj.get("clientId") { params.push(("client_id".to_string(), v.to_display_string())); }
252                if let Some(v) = obj.get("clientSecret") { params.push(("client_secret".to_string(), v.to_display_string())); }
253                if let Some(v) = obj.get("scope") { params.push(("scope".to_string(), v.to_display_string())); }
254            }
255            do_token_request(&token_url, &params, &opts, &s)
256        });
257    }
258}
259
#[cfg(feature = "api")]
/// POSTs `params` as a form to the token endpoint `url` and converts the JSON reply into
/// an object with `accessToken`, `refreshToken?`, `tokenType`, `expiresIn?` and `scope?`.
///
/// When `opts` is an object with a `name` key, the token is also stored in `state` under
/// that name. Its `expires_at` is set to now + `expires_in` seconds, in epoch milliseconds.
///
/// # Errors
/// Returns `Err` for:
/// - transport failures and unreadable bodies;
/// - non-JSON bodies;
/// - error responses (HTTP status >= 400, or an `error` member even with a 2xx status);
/// - successful responses that lack `access_token`.
fn do_token_request(
    url: &str,
    params: &[(String, String)],
    opts: &Value,
    state: &Arc<Mutex<HashMap<String, TokenStore>>>,
) -> Result<Value, String> {
    let client = reqwest::blocking::Client::new();
    // `.form()` sets the form Content-Type itself. Ask for JSON explicitly: some
    // providers (e.g. GitHub) reply form-encoded unless told otherwise.
    let resp = client
        .post(url)
        .header("Accept", "application/json")
        .form(params)
        .send()
        .map_err(|e| format!("OAuth request error: {}", e))?;

    let status = resp.status().as_u16();
    let body_text = resp.text().map_err(|e| format!("body error: {}", e))?;
    let json: serde_json::Value = match serde_json::from_str(&body_text) {
        Ok(json) => json,
        // Error pages are often HTML/plain text; surface the HTTP status, not a JSON error.
        Err(_) if status >= 400 => {
            return Err(format!("OAuth error: HTTP {} - {}", status, body_text))
        }
        Err(_) => return Err(format!("Invalid JSON response: {}", body_text)),
    };

    // RFC 6749 §5.2 error responses; some providers send them with a 2xx status.
    let error = json.get("error").and_then(|v| v.as_str());
    if status >= 400 || error.is_some() {
        let desc = json
            .get("error_description")
            .and_then(|v| v.as_str())
            .unwrap_or("");
        return Err(format!("OAuth error: {} - {}", error.unwrap_or("unknown"), desc));
    }

    // RFC 6749 §5.1: `access_token` is required in a successful response.
    let access_token = json
        .get("access_token")
        .and_then(|v| v.as_str())
        .ok_or_else(|| "OAuth response is missing access_token".to_string())?
        .to_string();
    let refresh_token = json
        .get("refresh_token")
        .and_then(|v| v.as_str())
        .map(|s| s.to_string());
    let token_type = json
        .get("token_type")
        .and_then(|v| v.as_str())
        .unwrap_or("Bearer")
        .to_string();
    // Specified as a number, but some providers send a numeric string.
    let expires_in = json.get("expires_in").and_then(|v| {
        v.as_u64()
            .or_else(|| v.as_str().and_then(|s| s.trim().parse().ok()))
    });
    let scope = json
        .get("scope")
        .and_then(|v| v.as_str())
        .map(|s| s.to_string());

    // Store if name provided
    if let Value::Object(obj) = opts {
        if let Some(name_val) = obj.get("name") {
            let now_ms = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_millis() as u64;
            // RFC 6749 §6: if no new refresh token is issued, the one sent with this
            // request stays valid. Keep it so later refreshes by `name` still work.
            let stored_refresh = refresh_token.clone().or_else(|| {
                params
                    .iter()
                    .find(|(key, _)| key == "refresh_token")
                    .map(|(_, value)| value.clone())
            });
            let store = TokenStore {
                access_token: access_token.clone(),
                refresh_token: stored_refresh,
                token_type: token_type.clone(),
                // Saturate: `expires_in` comes from the server and may be absurdly large.
                expires_at: expires_in
                    .map(|secs| now_ms.saturating_add(secs.saturating_mul(1000))),
                scope: scope.clone(),
            };
            state
                .lock()
                .unwrap()
                .insert(name_val.to_display_string(), store);
        }
    }

    let mut result = indexmap::IndexMap::new();
    result.insert("accessToken".to_string(), Value::String(access_token));
    if let Some(rt) = refresh_token {
        result.insert("refreshToken".to_string(), Value::String(rt));
    }
    result.insert("tokenType".to_string(), Value::String(token_type));
    if let Some(ei) = expires_in {
        result.insert("expiresIn".to_string(), Value::Number(ei as f64));
    }
    if let Some(s) = scope {
        result.insert("scope".to_string(), Value::String(s));
    }
    Ok(Value::Object(result))
}
324
/// Returns `n` random bytes for security-sensitive values (PKCE verifiers, OAuth `state`).
///
/// On Unix the bytes are read from `/dev/urandom` (the kernel CSPRNG).
///
/// On other platforms, or if that device cannot be read, it falls back to SipHash
/// outputs from freshly keyed `std` `RandomState` hashers. That fallback is NOT a
/// cryptographically secure generator. NOTE(review): consider a dedicated CSPRNG crate
/// for that path.
fn random_bytes(n: usize) -> Vec<u8> {
    #[cfg(unix)]
    {
        use std::io::Read;
        let mut buf = vec![0u8; n];
        if let Ok(mut urandom) = std::fs::File::open("/dev/urandom") {
            if urandom.read_exact(&mut buf).is_ok() {
                return buf;
            }
        }
    }

    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    let mut bytes = Vec::with_capacity(n);
    let mut counter: u64 = 0;
    while bytes.len() < n {
        // Each `RandomState::new()` is initialized with random keys. Use all 8 bytes of
        // each hash rather than only the low byte.
        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u64(counter);
        counter = counter.wrapping_add(1);
        let word = hasher.finish().to_le_bytes();
        let take = word.len().min(n - bytes.len());
        bytes.extend_from_slice(&word[..take]);
    }
    bytes
}
337
338fn base64url_encode(data: &[u8]) -> String {
339    use base64::Engine;
340    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
341}
342
/// Percent-encodes every byte of `s` except the unreserved set `A-Z a-z 0-9 - _ . ~`,
/// using uppercase hex digits (so a space becomes `%20`, never `+`).
fn url_encode(s: &str) -> String {
    s.bytes().fold(String::with_capacity(s.len()), |mut out, byte| {
        let unreserved = byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            out.push(char::from(byte));
        } else {
            out.push_str(&format!("%{:02X}", byte));
        }
        out
    })
}