Skip to main content

robinpath_modules/modules/
oauth_mod.rs

1use robinpath::{RobinPath, Value};
2use std::sync::{LazyLock, Mutex};
3use std::collections::HashMap;
4
/// A single named OAuth token as held in the in-memory token store.
struct TokenStore {
    /// The access token string.
    access_token: String,
    /// Token type reported for this token (e.g. `Bearer`).
    token_type: String,
    /// Refresh token, if one was supplied.
    refresh_token: Option<String>,
    /// Granted scope string, if one was supplied.
    scope: Option<String>,
    /// Absolute expiry in milliseconds since the UNIX epoch; `None` means no known expiry.
    expires_at: Option<u64>,
}
12
13static TOKENS: LazyLock<Mutex<HashMap<String, TokenStore>>> =
14    LazyLock::new(|| Mutex::new(HashMap::new()));
15
16pub fn register(rp: &mut RobinPath) {
17    // oauth.authUrl baseUrl options → URL string
18    rp.register_builtin("oauth.authUrl", |args, _| {
19        let base_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
20        let opts = args.get(1).cloned().unwrap_or(Value::Null);
21        if base_url.is_empty() {
22            return Err("Authorization URL is required".to_string());
23        }
24        let mut params = Vec::new();
25        if let Value::Object(obj) = &opts {
26            let response_type = obj.get("responseType")
27                .map(|v| v.to_display_string())
28                .unwrap_or_else(|| "code".to_string());
29            params.push(format!("response_type={}", url_encode(&response_type)));
30            if let Some(v) = obj.get("clientId") {
31                params.push(format!("client_id={}", url_encode(&v.to_display_string())));
32            }
33            if let Some(v) = obj.get("redirectUri") {
34                params.push(format!("redirect_uri={}", url_encode(&v.to_display_string())));
35            }
36            if let Some(v) = obj.get("scope") {
37                params.push(format!("scope={}", url_encode(&v.to_display_string())));
38            }
39            if let Some(v) = obj.get("state") {
40                params.push(format!("state={}", url_encode(&v.to_display_string())));
41            }
42            if let Some(v) = obj.get("codeChallenge") {
43                params.push(format!("code_challenge={}", url_encode(&v.to_display_string())));
44            }
45            if let Some(v) = obj.get("codeChallengeMethod") {
46                params.push(format!("code_challenge_method={}", url_encode(&v.to_display_string())));
47            }
48            if let Some(v) = obj.get("accessType") {
49                params.push(format!("access_type={}", url_encode(&v.to_display_string())));
50            }
51            if let Some(v) = obj.get("prompt") {
52                params.push(format!("prompt={}", url_encode(&v.to_display_string())));
53            }
54        }
55        let sep = if base_url.contains('?') { "&" } else { "?" };
56        Ok(Value::String(format!("{}{}{}", base_url, sep, params.join("&"))))
57    });
58
59    // oauth.pkceVerifier length? → random base64url string
60    rp.register_builtin("oauth.pkceVerifier", |args, _| {
61        let length = args.first().map(|v| v.to_number() as usize).unwrap_or(64);
62        let clamped = length.max(43).min(128);
63        let bytes: Vec<u8> = random_bytes(clamped);
64        let encoded = base64url_encode(&bytes);
65        Ok(Value::String(encoded[..clamped.min(encoded.len())].to_string()))
66    });
67
68    // oauth.pkceChallenge verifier method? → {challenge, method}
69    rp.register_builtin("oauth.pkceChallenge", |args, _| {
70        let verifier = args.first().map(|v| v.to_display_string()).unwrap_or_default();
71        let method = args.get(1).map(|v| v.to_display_string()).unwrap_or_else(|| "S256".to_string());
72        let mut obj = indexmap::IndexMap::new();
73        if method == "S256" {
74            use sha2::Digest;
75            let hash = sha2::Sha256::digest(verifier.as_bytes());
76            let challenge = base64url_encode(&hash);
77            obj.insert("challenge".to_string(), Value::String(challenge));
78            obj.insert("method".to_string(), Value::String("S256".to_string()));
79        } else if method == "plain" {
80            obj.insert("challenge".to_string(), Value::String(verifier));
81            obj.insert("method".to_string(), Value::String("plain".to_string()));
82        } else {
83            return Err(format!("Unsupported PKCE method: {}. Use \"S256\" or \"plain\".", method));
84        }
85        Ok(Value::Object(obj))
86    });
87
88    // oauth.generateState length? → random hex string
89    rp.register_builtin("oauth.generateState", |args, _| {
90        let length = args.first().map(|v| v.to_number() as usize).unwrap_or(32);
91        let bytes = random_bytes(length);
92        let hex: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
93        Ok(Value::String(hex))
94    });
95
96    // oauth.storeToken name tokenData → stores a token
97    rp.register_builtin("oauth.storeToken", |args, _| {
98        let name = args.first().map(|v| v.to_display_string()).unwrap_or_default();
99        let data = args.get(1).cloned().unwrap_or(Value::Null);
100        if let Value::Object(obj) = &data {
101            let store = TokenStore {
102                access_token: obj.get("accessToken").map(|v| v.to_display_string()).unwrap_or_default(),
103                refresh_token: obj.get("refreshToken").map(|v| v.to_display_string()),
104                token_type: obj.get("tokenType").map(|v| v.to_display_string()).unwrap_or_else(|| "Bearer".to_string()),
105                expires_at: obj.get("expiresAt").map(|v| v.to_number() as u64),
106                scope: obj.get("scope").map(|v| v.to_display_string()),
107            };
108            TOKENS.lock().unwrap().insert(name, store);
109            Ok(Value::Bool(true))
110        } else {
111            Err("Token data must be an object".to_string())
112        }
113    });
114
115    // oauth.getToken name → token object or null
116    rp.register_builtin("oauth.getToken", |args, _| {
117        let name = args.first().map(|v| v.to_display_string()).unwrap_or_default();
118        let tokens = TOKENS.lock().unwrap();
119        match tokens.get(&name) {
120            Some(store) => {
121                let mut obj = indexmap::IndexMap::new();
122                obj.insert("accessToken".to_string(), Value::String(store.access_token.clone()));
123                if let Some(ref rt) = store.refresh_token {
124                    obj.insert("refreshToken".to_string(), Value::String(rt.clone()));
125                }
126                obj.insert("tokenType".to_string(), Value::String(store.token_type.clone()));
127                let now_ms = std::time::SystemTime::now()
128                    .duration_since(std::time::UNIX_EPOCH)
129                    .unwrap_or_default()
130                    .as_millis() as u64;
131                let expired = store.expires_at.map(|ea| now_ms > ea).unwrap_or(false);
132                obj.insert("expired".to_string(), Value::Bool(expired));
133                if let Some(ea) = store.expires_at {
134                    obj.insert("expiresAt".to_string(), Value::Number(ea as f64));
135                }
136                if let Some(ref s) = store.scope {
137                    obj.insert("scope".to_string(), Value::String(s.clone()));
138                }
139                Ok(Value::Object(obj))
140            }
141            None => Ok(Value::Null),
142        }
143    });
144
145    // oauth.isExpired name bufferMs? → bool
146    rp.register_builtin("oauth.isExpired", |args, _| {
147        let name = args.first().map(|v| v.to_display_string()).unwrap_or_default();
148        let buffer_ms = args.get(1).map(|v| v.to_number() as u64).unwrap_or(60000);
149        let tokens = TOKENS.lock().unwrap();
150        match tokens.get(&name) {
151            Some(store) => {
152                if let Some(ea) = store.expires_at {
153                    let now_ms = std::time::SystemTime::now()
154                        .duration_since(std::time::UNIX_EPOCH)
155                        .unwrap_or_default()
156                        .as_millis() as u64;
157                    Ok(Value::Bool(now_ms + buffer_ms > ea))
158                } else {
159                    Ok(Value::Bool(false))
160                }
161            }
162            None => Ok(Value::Bool(true)),
163        }
164    });
165
166    // oauth.clearTokens name? → bool
167    rp.register_builtin("oauth.clearTokens", |args, _| {
168        let name = args.first().map(|v| v.to_display_string());
169        let mut tokens = TOKENS.lock().unwrap();
170        if let Some(n) = name {
171            Ok(Value::Bool(tokens.remove(&n).is_some()))
172        } else {
173            tokens.clear();
174            Ok(Value::Bool(true))
175        }
176    });
177
178    // oauth.exchangeCode tokenUrl options → {accessToken, ...} (requires reqwest)
179    #[cfg(feature = "api")]
180    rp.register_builtin("oauth.exchangeCode", |args, _| {
181        let token_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
182        let opts = args.get(1).cloned().unwrap_or(Value::Null);
183        if token_url.is_empty() {
184            return Err("Token URL is required".to_string());
185        }
186        let mut params = Vec::new();
187        params.push(("grant_type".to_string(), "authorization_code".to_string()));
188        if let Value::Object(obj) = &opts {
189            if let Some(v) = obj.get("code") { params.push(("code".to_string(), v.to_display_string())); }
190            if let Some(v) = obj.get("clientId") { params.push(("client_id".to_string(), v.to_display_string())); }
191            if let Some(v) = obj.get("clientSecret") { params.push(("client_secret".to_string(), v.to_display_string())); }
192            if let Some(v) = obj.get("redirectUri") { params.push(("redirect_uri".to_string(), v.to_display_string())); }
193            if let Some(v) = obj.get("codeVerifier") { params.push(("code_verifier".to_string(), v.to_display_string())); }
194        }
195        do_token_request(&token_url, &params, &opts)
196    });
197
198    // oauth.refreshToken tokenUrl options → {accessToken, ...}
199    #[cfg(feature = "api")]
200    rp.register_builtin("oauth.refreshToken", |args, _| {
201        let token_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
202        let opts = args.get(1).cloned().unwrap_or(Value::Null);
203        if token_url.is_empty() {
204            return Err("Token URL is required".to_string());
205        }
206        let refresh = if let Value::Object(obj) = &opts {
207            obj.get("refreshToken").map(|v| v.to_display_string())
208                .or_else(|| {
209                    obj.get("name").and_then(|n| {
210                        TOKENS.lock().unwrap().get(&n.to_display_string())
211                            .and_then(|s| s.refresh_token.clone())
212                    })
213                })
214        } else {
215            None
216        };
217        let refresh = refresh.ok_or("Refresh token is required")?;
218        let mut params = Vec::new();
219        params.push(("grant_type".to_string(), "refresh_token".to_string()));
220        params.push(("refresh_token".to_string(), refresh));
221        if let Value::Object(obj) = &opts {
222            if let Some(v) = obj.get("clientId") { params.push(("client_id".to_string(), v.to_display_string())); }
223            if let Some(v) = obj.get("clientSecret") { params.push(("client_secret".to_string(), v.to_display_string())); }
224            if let Some(v) = obj.get("scope") { params.push(("scope".to_string(), v.to_display_string())); }
225        }
226        do_token_request(&token_url, &params, &opts)
227    });
228
229    // oauth.clientCredentials tokenUrl options → {accessToken, ...}
230    #[cfg(feature = "api")]
231    rp.register_builtin("oauth.clientCredentials", |args, _| {
232        let token_url = args.first().map(|v| v.to_display_string()).unwrap_or_default();
233        let opts = args.get(1).cloned().unwrap_or(Value::Null);
234        if token_url.is_empty() {
235            return Err("Token URL is required".to_string());
236        }
237        let mut params = Vec::new();
238        params.push(("grant_type".to_string(), "client_credentials".to_string()));
239        if let Value::Object(obj) = &opts {
240            if let Some(v) = obj.get("clientId") { params.push(("client_id".to_string(), v.to_display_string())); }
241            if let Some(v) = obj.get("clientSecret") { params.push(("client_secret".to_string(), v.to_display_string())); }
242            if let Some(v) = obj.get("scope") { params.push(("scope".to_string(), v.to_display_string())); }
243        }
244        do_token_request(&token_url, &params, &opts)
245    });
246}
247
/// Sends an OAuth 2.0 token-endpoint request and normalises the JSON reply.
///
/// `params` are POSTed form-encoded to `url`. On success the result object has
/// `accessToken` and `tokenType`, plus `refreshToken`, `expiresIn` (seconds) and `scope`
/// when the server returned them. If `opts` is an object with a `name` key, the token is
/// also saved in [`TOKENS`] under that name, with `expires_at` as absolute epoch ms.
///
/// # Errors
/// Transport or body-read failures; a non-JSON body; an OAuth error reply (HTTP status
/// >= 400, or a non-null `error` field at any status); a reply without a non-empty
/// `access_token`.
#[cfg(feature = "api")]
fn do_token_request(url: &str, params: &[(String, String)], opts: &Value) -> Result<Value, String> {
    let client = reqwest::blocking::Client::new();
    // `.form()` sets `Content-Type: application/x-www-form-urlencoded` itself. Explicitly
    // asking for JSON matters for providers that otherwise reply form-encoded (e.g. GitHub).
    let resp = client
        .post(url)
        .header("Accept", "application/json")
        .form(params)
        .send()
        .map_err(|e| format!("OAuth request error: {}", e))?;

    let status = resp.status().as_u16();
    let body_text = resp.text().map_err(|e| format!("body error: {}", e))?;
    let json: serde_json::Value = match serde_json::from_str(&body_text) {
        Ok(json) => json,
        // Keep the HTTP status visible when an error page is not JSON.
        Err(_) if status >= 400 => {
            return Err(format!("OAuth error: HTTP {} - {}", status, body_text));
        }
        Err(_) => return Err(format!("Invalid JSON response: {}", body_text)),
    };

    // RFC 6749 §5.2 error replies carry `error`; some servers send them with HTTP 200.
    let error_field = json.get("error").filter(|v| !v.is_null());
    if status >= 400 || error_field.is_some() {
        let error = error_field.and_then(|v| v.as_str()).unwrap_or("unknown");
        let desc = json.get("error_description").and_then(|v| v.as_str()).unwrap_or("");
        return Err(format!("OAuth error: {} - {}", error, desc));
    }

    // `access_token` is required in a successful reply (RFC 6749 §5.1); never store "".
    let access_token = json
        .get("access_token")
        .and_then(|v| v.as_str())
        .filter(|s| !s.is_empty())
        .ok_or_else(|| "OAuth response missing access_token".to_string())?
        .to_string();
    let refresh_token = json.get("refresh_token").and_then(|v| v.as_str()).map(str::to_string);
    let token_type = json.get("token_type").and_then(|v| v.as_str()).unwrap_or("Bearer").to_string();
    // `expires_in` is meant to be an integer number of seconds; tolerate string/float forms.
    let expires_in = json.get("expires_in").and_then(|v| {
        v.as_u64()
            .or_else(|| v.as_str().and_then(|s| s.trim().parse::<u64>().ok()))
            .or_else(|| v.as_f64().filter(|f| f.is_finite() && *f >= 0.0).map(|f| f as u64))
    });
    let scope = json.get("scope").and_then(|v| v.as_str()).map(str::to_string);

    // Store if name provided
    if let Value::Object(obj) = opts {
        if let Some(name_val) = obj.get("name") {
            let now_ms = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_millis() as u64;
            // RFC 6749 §6: a refresh reply may omit `refresh_token`, in which case the token
            // that was just sent stays valid. Keep it so later refreshes by name still work.
            let stored_refresh = refresh_token.clone().or_else(|| {
                params
                    .iter()
                    .find(|(k, _)| k == "refresh_token")
                    .map(|(_, v)| v.clone())
            });
            let store = TokenStore {
                access_token: access_token.clone(),
                refresh_token: stored_refresh,
                token_type: token_type.clone(),
                expires_at: expires_in.map(|ei| now_ms.saturating_add(ei.saturating_mul(1000))),
                scope: scope.clone(),
            };
            TOKENS
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .insert(name_val.to_display_string(), store);
        }
    }

    let mut result = indexmap::IndexMap::new();
    result.insert("accessToken".to_string(), Value::String(access_token));
    if let Some(rt) = refresh_token {
        result.insert("refreshToken".to_string(), Value::String(rt));
    }
    result.insert("tokenType".to_string(), Value::String(token_type));
    if let Some(ei) = expires_in {
        result.insert("expiresIn".to_string(), Value::Number(ei as f64));
    }
    if let Some(s) = scope {
        result.insert("scope".to_string(), Value::String(s));
    }
    Ok(Value::Object(result))
}
307
/// Returns `n` random bytes, used for PKCE verifiers and OAuth `state` values.
///
/// These values must be unguessable, so the bytes come from the operating system's CSPRNG
/// (`/dev/urandom`) when it is available. On platforms without it (e.g. Windows) this falls
/// back to the standard library's randomly keyed SipHash (`RandomState`). That fallback is
/// NOT a cryptographically secure generator; it is only a best effort.
fn random_bytes(n: usize) -> Vec<u8> {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::io::Read;

    let mut bytes = vec![0u8; n];
    if n == 0 {
        return bytes;
    }
    if let Ok(mut source) = std::fs::File::open("/dev/urandom") {
        if source.read_exact(&mut bytes).is_ok() {
            return bytes;
        }
    }
    // Fallback: each `RandomState::new()` gets fresh keys, and every output byte of the
    // 64-bit hash is used (the whole buffer is rewritten, discarding any partial read).
    for (i, chunk) in bytes.chunks_mut(8).enumerate() {
        let mut hasher = RandomState::new().build_hasher();
        hasher.write_usize(i);
        let word = hasher.finish().to_le_bytes();
        chunk.copy_from_slice(&word[..chunk.len()]);
    }
    bytes
}
320
321fn base64url_encode(data: &[u8]) -> String {
322    use base64::Engine;
323    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
324}
325
/// Percent-encodes `s` byte-by-byte. Only RFC 3986 unreserved characters
/// (`A-Z a-z 0-9 - _ . ~`) pass through; every other byte becomes `%XX` in uppercase hex.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved = byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}