// agent_tools_interface/core/http.rs

1use reqwest::Client;
2use serde_json::Value;
3use std::collections::HashMap;
4use std::sync::Mutex;
5use std::time::{Duration, Instant};
6use thiserror::Error;
7
8use crate::core::auth_generator::{self, AuthCache, GenContext};
9use crate::core::keyring::Keyring;
10use crate::core::manifest::{AuthType, HttpMethod, Provider, Tool};
11
/// Errors produced while building, authenticating, sending or decoding a tool HTTP call.
#[derive(Error, Debug)]
pub enum HttpError {
    /// A credential could not be resolved: `auth_key_name` is unset, the named key is absent
    /// from the keyring, or the auth generator failed (message prefixed with `auth_generator:`).
    #[error("API key '{0}' not found in keyring")]
    MissingKey(String),
    /// Transport-level failure reported by `reqwest` (e.g. client build, send, or reading the body).
    #[error("HTTP request failed: {0}")]
    Request(#[from] reqwest::Error),
    /// The upstream API answered with a non-success status; `body` is the response text
    /// (or `"empty"` if the body could not be read).
    #[error("API error ({status}): {body}")]
    ApiError { status: u16, body: String },
    /// NOTE(review): not constructed in this module — non-JSON success bodies are returned as
    /// `Value::String` instead. Possibly used by callers elsewhere; verify before removing.
    #[error("Failed to parse response as JSON: {0}")]
    ParseError(String),
    /// OAuth2 configuration is incomplete, or the token exchange / token response was unusable.
    #[error("OAuth2 token exchange failed: {0}")]
    Oauth2Error(String),
    /// A path parameter value contained `..`, `\`, `?`, `#` or a NUL byte.
    #[error("Invalid path parameter '{key}': value '{value}' contains forbidden characters")]
    InvalidPathParam { key: String, value: String },
    /// A user-supplied header parameter is on the deny-list or matches the provider's
    /// configured auth header name (compared case-insensitively).
    #[error("Header '{0}' is not allowed as a user-supplied parameter")]
    DeniedHeader(String),
    /// SSRF protection refused a URL whose host is a private/internal address.
    #[error("SSRF protection: URL '{0}' targets a private/internal network address")]
    SsrfBlocked(String),
    /// The resolved OAuth2 token URL uses plain `http://`.
    #[error("OAuth2 token URL must use HTTPS: '{0}'")]
    InsecureTokenUrl(String),
}
33
/// Process-wide OAuth2 token cache. Each entry maps a cache key (the provider name, as used by
/// `get_oauth2_token`) to `(access_token, expiry_instant)`.
static OAUTH2_CACHE: std::sync::LazyLock<Mutex<HashMap<String, (String, Instant)>>> =
    std::sync::LazyLock::new(Default::default);

/// Timeout, in seconds, configured on the HTTP client used for tool calls.
const DEFAULT_TIMEOUT_SECS: u64 = 60;
39
40/// Validate that a URL does not target private/internal network addresses (SSRF protection).
41/// Checks the hostname against deny-listed private IP ranges.
42///
43/// Enforcement is controlled by `ATI_SSRF_PROTECTION` env var:
44/// - "1" or "true": block requests to private addresses (default in proxy mode)
45/// - "warn": log a warning but allow the request
46/// - unset/other: allow the request (for local development/testing)
47pub fn validate_url_not_private(url: &str) -> Result<(), HttpError> {
48    let mode = std::env::var("ATI_SSRF_PROTECTION").unwrap_or_default();
49    let enforce = mode == "1" || mode.eq_ignore_ascii_case("true");
50    let warn_only = mode.eq_ignore_ascii_case("warn");
51
52    if !enforce && !warn_only {
53        return Ok(());
54    }
55    // Parse just the host part
56    let host = url
57        .strip_prefix("http://")
58        .or_else(|| url.strip_prefix("https://"))
59        .unwrap_or(url)
60        .split('/')
61        .next()
62        .unwrap_or("")
63        .split(':')
64        .next()
65        .unwrap_or("");
66
67    if host.is_empty() {
68        return Ok(());
69    }
70
71    let mut is_private = false;
72
73    // Check literal IP addresses
74    if let Ok(ip) = host.parse::<std::net::Ipv4Addr>() {
75        if ip.is_loopback()           // 127.0.0.0/8
76            || ip.is_private()        // 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
77            || ip.is_link_local()     // 169.254.0.0/16
78            || ip.is_unspecified()    // 0.0.0.0
79            || (ip.octets()[0] == 100 && ip.octets()[1] >= 64 && ip.octets()[1] <= 127)
80        // CGNAT 100.64.0.0/10
81        {
82            is_private = true;
83        }
84    }
85
86    // Check common internal hostnames
87    let host_lower = host.to_lowercase();
88    if host_lower == "localhost"
89        || host_lower == "metadata.google.internal"
90        || host_lower.ends_with(".internal")
91        || host_lower.ends_with(".local")
92    {
93        is_private = true;
94    }
95
96    if is_private {
97        if warn_only {
98            eprintln!("Warning: SSRF protection — URL targets private address: {url}");
99            return Ok(());
100        }
101        return Err(HttpError::SsrfBlocked(url.to_string()));
102    }
103
104    Ok(())
105}
106
/// Headers that must never be set by agent-supplied parameters.
///
/// Entries are lower-case; `validate_headers` compares each parameter name lower-cased.
/// The list covers four groups:
/// - credentials and session state (`authorization`, `proxy-authorization`, `cookie`, ...);
/// - framing and hop-by-hop headers the HTTP client must own (`content-length`,
///   `transfer-encoding`, `connection`, `te`, `upgrade`, ...). Letting an agent set these
///   could enable request smuggling or protocol switches through intermediaries.
/// - client-identity headers that proxies trust (`x-forwarded-*`, `forwarded`, `x-real-ip`, ...);
/// - method/URL override headers honoured by some frameworks. These would let a tool
///   declared as one method or path act as another (`x-http-method-override`,
///   `x-original-url`, ...).
const DENIED_HEADERS: &[&str] = &[
    // Credentials / session
    "authorization",
    "proxy-authorization",
    "cookie",
    "set-cookie",
    // Routing, framing and hop-by-hop (owned by the HTTP client)
    "host",
    "content-type",
    "content-length",
    "transfer-encoding",
    "connection",
    "keep-alive",
    "proxy-connection",
    "te",
    "trailer",
    "upgrade",
    "expect",
    // Client identity trusted by proxies
    "forwarded",
    "x-forwarded-for",
    "x-forwarded-host",
    "x-forwarded-proto",
    "x-forwarded-port",
    "x-forwarded-server",
    "x-real-ip",
    "x-client-ip",
    "true-client-ip",
    // Method / URL overrides
    "x-http-method-override",
    "x-http-method",
    "x-method-override",
    "x-original-url",
    "x-rewrite-url",
];
124
125/// Check that classified header parameters don't contain denied headers.
126pub fn validate_headers(
127    headers: &HashMap<String, String>,
128    provider_auth_header: Option<&str>,
129) -> Result<(), HttpError> {
130    for key in headers.keys() {
131        let lower = key.to_lowercase();
132        if DENIED_HEADERS.contains(&lower.as_str()) {
133            return Err(HttpError::DeniedHeader(key.clone()));
134        }
135        if let Some(auth_header) = provider_auth_header {
136            if lower == auth_header.to_lowercase() {
137                return Err(HttpError::DeniedHeader(key.clone()));
138            }
139        }
140    }
141    Ok(())
142}
143
144/// Merge manifest defaults into the args map for any params not provided by caller.
145fn merge_defaults(tool: &Tool, args: &HashMap<String, Value>) -> HashMap<String, Value> {
146    let mut merged = args.clone();
147    if let Some(schema) = &tool.input_schema {
148        if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
149            for (key, prop_def) in props {
150                if !merged.contains_key(key) {
151                    if let Some(default_val) = prop_def.get("default") {
152                        // Skip empty arrays/objects as defaults — they add no value
153                        // and some APIs reject them (e.g. ClinicalTrials `sort=[]`).
154                        let dominated = match default_val {
155                            Value::Array(a) => a.is_empty(),
156                            Value::Object(o) => o.is_empty(),
157                            Value::Null => true,
158                            _ => false,
159                        };
160                        if !dominated {
161                            merged.insert(key.clone(), default_val.clone());
162                        }
163                    }
164                }
165            }
166        }
167    }
168    merged
169}
170
/// Serialization style for array-valued query parameters, chosen per property via the
/// `x-ati-collection-format` schema key.
#[derive(Clone, Copy, Debug, PartialEq)]
enum CollectionFormat {
    /// `multi`: one pair per element (`?status=a&status=b`). Also the fallback when the key
    /// is missing or unrecognised.
    Multi,
    /// `csv`: elements joined with commas into a single value (`status=a,b`).
    Csv,
    /// `ssv`: elements joined with single spaces into a single value. reqwest's query
    /// serializer encodes the space on the wire.
    Ssv,
    /// `pipes`: elements joined with `|` into a single value (`status=a|b`).
    Pipes,
}
183
/// Wire encoding for body-located parameters. Selected by the schema-level
/// `x-ati-body-encoding` key: `"form"` gives `Form`, anything else gives `Json`.
#[derive(Clone, Copy, Debug, PartialEq)]
enum BodyEncoding {
    /// A JSON object of the body parameters (the default).
    Json,
    /// `application/x-www-form-urlencoded` pairs, each value stringified.
    Form,
}
190
/// Tool arguments split by their `x-ati-param-location`, as produced by `classify_params`.
struct ClassifiedParams {
    /// `path` params, stringified; substituted into `{name}` placeholders of the endpoint.
    path: HashMap<String, String>,
    /// Scalar `query` params, stringified.
    query: HashMap<String, String>,
    /// Array-valued `query` params: element strings plus their serialization format.
    query_arrays: HashMap<String, (Vec<String>, CollectionFormat)>,
    /// `header` params, stringified; checked against the deny-list before being sent.
    header: HashMap<String, String>,
    /// Params located in `body`, at any other location, or with no location; kept as JSON.
    body: HashMap<String, Value>,
    /// How `body` is encoded (from the schema-level `x-ati-body-encoding` key).
    body_encoding: BodyEncoding,
}
200
201/// Classify parameters by their `x-ati-param-location` metadata in the input schema.
202/// If no location metadata exists (legacy TOML tools), returns None for legacy fallback.
203fn classify_params(tool: &Tool, args: &HashMap<String, Value>) -> Option<ClassifiedParams> {
204    let schema = tool.input_schema.as_ref()?;
205    let props = schema.get("properties")?.as_object()?;
206
207    // Check if any property has x-ati-param-location — if none do, this is a legacy tool
208    let has_locations = props
209        .values()
210        .any(|p| p.get("x-ati-param-location").is_some());
211
212    if !has_locations {
213        return None;
214    }
215
216    // Detect body encoding from schema-level metadata
217    let body_encoding = match schema.get("x-ati-body-encoding").and_then(|v| v.as_str()) {
218        Some("form") => BodyEncoding::Form,
219        _ => BodyEncoding::Json,
220    };
221
222    let mut classified = ClassifiedParams {
223        path: HashMap::new(),
224        query: HashMap::new(),
225        query_arrays: HashMap::new(),
226        header: HashMap::new(),
227        body: HashMap::new(),
228        body_encoding,
229    };
230
231    for (key, value) in args {
232        let prop_def = props.get(key);
233        let location = prop_def
234            .and_then(|p| p.get("x-ati-param-location"))
235            .and_then(|l| l.as_str())
236            .unwrap_or("body"); // default to body if no location specified
237
238        match location {
239            "path" => {
240                classified.path.insert(key.clone(), value_to_string(value));
241            }
242            "query" => {
243                // Check if this is an array value with a collection format
244                if let Value::Array(arr) = value {
245                    let cf_str = prop_def
246                        .and_then(|p| p.get("x-ati-collection-format"))
247                        .and_then(|v| v.as_str());
248                    let cf = match cf_str {
249                        Some("multi") => CollectionFormat::Multi,
250                        Some("csv") => CollectionFormat::Csv,
251                        Some("ssv") => CollectionFormat::Ssv,
252                        Some("pipes") => CollectionFormat::Pipes,
253                        _ => CollectionFormat::Multi, // default for arrays
254                    };
255                    let values: Vec<String> = arr.iter().map(|v| value_to_string(v)).collect();
256                    classified.query_arrays.insert(key.clone(), (values, cf));
257                } else {
258                    classified.query.insert(key.clone(), value_to_string(value));
259                }
260            }
261            "header" => {
262                classified
263                    .header
264                    .insert(key.clone(), value_to_string(value));
265            }
266            _ => {
267                classified.body.insert(key.clone(), value.clone());
268            }
269        }
270    }
271
272    Some(classified)
273}
274
275/// Substitute path parameters like `{petId}` in the endpoint template.
276/// Rejects values containing path traversal or URL-breaking characters,
277/// then percent-encodes the value before substitution.
278fn substitute_path_params(
279    endpoint: &str,
280    path_args: &HashMap<String, String>,
281) -> Result<String, HttpError> {
282    let mut result = endpoint.to_string();
283    for (key, value) in path_args {
284        if value.contains("..")
285            || value.contains('\\')
286            || value.contains('?')
287            || value.contains('#')
288            || value.contains('\0')
289        {
290            return Err(HttpError::InvalidPathParam {
291                key: key.clone(),
292                value: value.clone(),
293            });
294        }
295        let encoded = percent_encode_path_segment(value);
296        result = result.replace(&format!("{{{key}}}"), &encoded);
297    }
298    Ok(result)
299}
300
/// Percent-encode a value for substitution into a URL path.
///
/// Bytes in the RFC 3986 §2.3 unreserved set (`ALPHA / DIGIT / "-" / "." / "_" / "~"`) are
/// kept as-is, and so is `/`. This lets multi-segment identifiers such as `fal-ai/flux/dev`
/// survive intact. Every other byte of the UTF-8 encoding becomes `%XX` with uppercase hex.
fn percent_encode_path_segment(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(s.len());
    for b in s.bytes() {
        let keep = b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~' | b'/');
        if keep {
            out.push(char::from(b));
        } else {
            out.push('%');
            out.push(char::from(HEX_DIGITS[usize::from(b >> 4)]));
            out.push(char::from(HEX_DIGITS[usize::from(b & 0x0F)]));
        }
    }
    out
}
317
318/// Convert a serde_json::Value to a URL-safe string.
319fn value_to_string(v: &Value) -> String {
320    match v {
321        Value::String(s) => s.clone(),
322        Value::Number(n) => n.to_string(),
323        Value::Bool(b) => b.to_string(),
324        Value::Null => String::new(),
325        other => other.to_string(),
326    }
327}
328
329/// Apply array query parameters to a request builder using the specified collection format.
330fn apply_query_arrays(
331    mut req: reqwest::RequestBuilder,
332    arrays: &HashMap<String, (Vec<String>, CollectionFormat)>,
333) -> reqwest::RequestBuilder {
334    for (key, (values, format)) in arrays {
335        match format {
336            CollectionFormat::Multi => {
337                // Repeated key: ?status=a&status=b
338                for val in values {
339                    req = req.query(&[(key.as_str(), val.as_str())]);
340                }
341            }
342            CollectionFormat::Csv => {
343                let joined = values.join(",");
344                req = req.query(&[(key.as_str(), joined.as_str())]);
345            }
346            CollectionFormat::Ssv => {
347                let joined = values.join(" ");
348                req = req.query(&[(key.as_str(), joined.as_str())]);
349            }
350            CollectionFormat::Pipes => {
351                let joined = values.join("|");
352                req = req.query(&[(key.as_str(), joined.as_str())]);
353            }
354        }
355    }
356    req
357}
358
359/// Execute an HTTP tool call against a provider's API.
360///
361/// Supports two modes:
362/// 1. **Location-aware** (OpenAPI tools): Parameters are classified by `x-ati-param-location`
363///    metadata in the input schema. Path params are substituted into the URL template,
364///    query params go to the query string, header params become request headers,
365///    and body params go to the JSON body.
366/// 2. **Legacy** (hand-written TOML tools): GET → all args as query params, POST/PUT/DELETE → JSON body.
367pub async fn execute_tool(
368    provider: &Provider,
369    tool: &Tool,
370    args: &HashMap<String, Value>,
371    keyring: &Keyring,
372) -> Result<Value, HttpError> {
373    execute_tool_with_gen(provider, tool, args, keyring, None, None).await
374}
375
376/// Execute an HTTP tool call, optionally using a dynamic auth generator.
377pub async fn execute_tool_with_gen(
378    provider: &Provider,
379    tool: &Tool,
380    args: &HashMap<String, Value>,
381    keyring: &Keyring,
382    gen_ctx: Option<&GenContext>,
383    auth_cache: Option<&AuthCache>,
384) -> Result<Value, HttpError> {
385    // SSRF protection: validate base_url is not targeting private networks
386    validate_url_not_private(&provider.base_url)?;
387
388    let client = Client::builder()
389        .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
390        .build()?;
391
392    // Merge manifest defaults into caller-provided args
393    let merged_args = merge_defaults(tool, args);
394
395    // Try location-aware classification (OpenAPI tools have x-ati-param-location)
396    let mut request = if let Some(classified) = classify_params(tool, &merged_args) {
397        // Validate headers against deny-list before injecting
398        validate_headers(&classified.header, provider.auth_header_name.as_deref())?;
399
400        // Location-aware mode: substitute path params, route by location
401        let resolved_endpoint = substitute_path_params(&tool.endpoint, &classified.path)?;
402        let url = format!(
403            "{}{}",
404            provider.base_url.trim_end_matches('/'),
405            resolved_endpoint
406        );
407
408        let mut req = match tool.method {
409            HttpMethod::Get | HttpMethod::Delete => {
410                let base_req = match tool.method {
411                    HttpMethod::Get => client.get(&url),
412                    HttpMethod::Delete => client.delete(&url),
413                    _ => unreachable!(),
414                };
415                // Query params to query string
416                let mut r = base_req;
417                for (k, v) in &classified.query {
418                    r = r.query(&[(k.as_str(), v.as_str())]);
419                }
420                r = apply_query_arrays(r, &classified.query_arrays);
421                r
422            }
423            HttpMethod::Post | HttpMethod::Put => {
424                let base_req = match tool.method {
425                    HttpMethod::Post => client.post(&url),
426                    HttpMethod::Put => client.put(&url),
427                    _ => unreachable!(),
428                };
429                // Body params: encode as JSON or form-urlencoded based on metadata
430                let mut r = if classified.body.is_empty() {
431                    base_req
432                } else {
433                    match classified.body_encoding {
434                        BodyEncoding::Json => base_req.json(&classified.body),
435                        BodyEncoding::Form => {
436                            let pairs: Vec<(String, String)> = classified
437                                .body
438                                .iter()
439                                .map(|(k, v)| (k.clone(), value_to_string(v)))
440                                .collect();
441                            base_req.form(&pairs)
442                        }
443                    }
444                };
445                // Query params still go to query string
446                for (k, v) in &classified.query {
447                    r = r.query(&[(k.as_str(), v.as_str())]);
448                }
449                r = apply_query_arrays(r, &classified.query_arrays);
450                r
451            }
452        };
453
454        // Inject classified header params
455        for (k, v) in &classified.header {
456            req = req.header(k.as_str(), v.as_str());
457        }
458
459        req
460    } else {
461        // Legacy mode: no x-ati-param-location metadata
462        let url = format!(
463            "{}{}",
464            provider.base_url.trim_end_matches('/'),
465            &tool.endpoint
466        );
467
468        match tool.method {
469            HttpMethod::Get => {
470                let mut req = client.get(&url);
471                for (k, v) in &merged_args {
472                    req = req.query(&[(k.as_str(), value_to_string(v))]);
473                }
474                req
475            }
476            HttpMethod::Post => client.post(&url).json(&merged_args),
477            HttpMethod::Put => client.put(&url).json(&merged_args),
478            HttpMethod::Delete => client.delete(&url).json(&merged_args),
479        }
480    };
481
482    // Inject authentication (generator takes priority over static keyring)
483    request = inject_auth(request, provider, keyring, gen_ctx, auth_cache).await?;
484
485    // Inject extra headers from provider config
486    for (header_name, header_value) in &provider.extra_headers {
487        request = request.header(header_name.as_str(), header_value.as_str());
488    }
489
490    // Execute request
491    let response = request.send().await?;
492    let status = response.status();
493
494    if !status.is_success() {
495        let body = response.text().await.unwrap_or_else(|_| "empty".into());
496        return Err(HttpError::ApiError {
497            status: status.as_u16(),
498            body,
499        });
500    }
501
502    // Parse response
503    let text = response.text().await?;
504    let value: Value = serde_json::from_str(&text).unwrap_or_else(|_| Value::String(text));
505
506    Ok(value)
507}
508
509/// Inject authentication headers/params based on provider auth_type.
510///
511/// If the provider has an `auth_generator`, the generator is run first to produce
512/// dynamic credentials. Otherwise, static keyring credentials are used.
513async fn inject_auth(
514    request: reqwest::RequestBuilder,
515    provider: &Provider,
516    keyring: &Keyring,
517    gen_ctx: Option<&GenContext>,
518    auth_cache: Option<&AuthCache>,
519) -> Result<reqwest::RequestBuilder, HttpError> {
520    // Dynamic auth generator takes priority
521    if let Some(gen) = &provider.auth_generator {
522        let default_ctx = GenContext::default();
523        let ctx = gen_ctx.unwrap_or(&default_ctx);
524        let default_cache = AuthCache::new();
525        let cache = auth_cache.unwrap_or(&default_cache);
526
527        let cred = auth_generator::generate(provider, gen, ctx, keyring, cache)
528            .await
529            .map_err(|e| HttpError::MissingKey(format!("auth_generator: {e}")))?;
530
531        // Inject primary credential based on auth_type
532        let mut req = match provider.auth_type {
533            AuthType::Bearer => request.bearer_auth(&cred.value),
534            AuthType::Header => {
535                let name = provider.auth_header_name.as_deref().unwrap_or("X-Api-Key");
536                let val = match &provider.auth_value_prefix {
537                    Some(pfx) => format!("{pfx}{}", cred.value),
538                    None => cred.value.clone(),
539                };
540                request.header(name, val)
541            }
542            AuthType::Query => {
543                let name = provider.auth_query_name.as_deref().unwrap_or("api_key");
544                request.query(&[(name, &cred.value)])
545            }
546            _ => request,
547        };
548        // Inject extra headers from JSON inject targets
549        for (name, value) in &cred.extra_headers {
550            req = req.header(name.as_str(), value.as_str());
551        }
552        return Ok(req);
553    }
554
555    match provider.auth_type {
556        AuthType::None => Ok(request),
557        AuthType::Bearer => {
558            let key_name = provider
559                .auth_key_name
560                .as_deref()
561                .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
562            let key_value = keyring
563                .get(key_name)
564                .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
565            Ok(request.bearer_auth(key_value))
566        }
567        AuthType::Header => {
568            let key_name = provider
569                .auth_key_name
570                .as_deref()
571                .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
572            let key_value = keyring
573                .get(key_name)
574                .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
575            let header_name = provider.auth_header_name.as_deref().unwrap_or("X-Api-Key");
576            let final_value = match &provider.auth_value_prefix {
577                Some(prefix) => format!("{}{}", prefix, key_value),
578                None => key_value.to_string(),
579            };
580            Ok(request.header(header_name, final_value))
581        }
582        AuthType::Query => {
583            let key_name = provider
584                .auth_key_name
585                .as_deref()
586                .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
587            let key_value = keyring
588                .get(key_name)
589                .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
590            let query_name = provider.auth_query_name.as_deref().unwrap_or("api_key");
591            Ok(request.query(&[(query_name, key_value)]))
592        }
593        AuthType::Basic => {
594            let key_name = provider
595                .auth_key_name
596                .as_deref()
597                .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
598            let key_value = keyring
599                .get(key_name)
600                .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
601            Ok(request.basic_auth(key_value, None::<&str>))
602        }
603        AuthType::Oauth2 => {
604            let access_token = get_oauth2_token(provider, keyring).await?;
605            Ok(request.bearer_auth(access_token))
606        }
607    }
608}
609
610/// Fetch (or return cached) OAuth2 access token via client_credentials grant.
611async fn get_oauth2_token(provider: &Provider, keyring: &Keyring) -> Result<String, HttpError> {
612    let cache_key = provider.name.clone();
613
614    // Check cache
615    {
616        let cache = OAUTH2_CACHE.lock().unwrap();
617        if let Some((token, expiry)) = cache.get(&cache_key) {
618            // Use cached token if it has at least 60s remaining
619            if Instant::now() + Duration::from_secs(60) < *expiry {
620                return Ok(token.clone());
621            }
622        }
623    }
624
625    // Token expired or not cached — exchange credentials
626    let client_id_key = provider
627        .auth_key_name
628        .as_deref()
629        .ok_or_else(|| HttpError::Oauth2Error("auth_key_name not set for OAuth2".into()))?;
630    let client_id = keyring
631        .get(client_id_key)
632        .ok_or_else(|| HttpError::MissingKey(client_id_key.into()))?;
633
634    let client_secret_key = provider
635        .auth_secret_name
636        .as_deref()
637        .ok_or_else(|| HttpError::Oauth2Error("auth_secret_name not set for OAuth2".into()))?;
638    let client_secret = keyring
639        .get(client_secret_key)
640        .ok_or_else(|| HttpError::MissingKey(client_secret_key.into()))?;
641
642    let token_url = match &provider.oauth2_token_url {
643        Some(url) if url.starts_with("http") => url.clone(),
644        Some(path) => format!("{}{}", provider.base_url.trim_end_matches('/'), path),
645        None => return Err(HttpError::Oauth2Error("oauth2_token_url not set".into())),
646    };
647
648    // Enforce HTTPS for OAuth2 token URLs (credentials are sent in plaintext otherwise)
649    if token_url.starts_with("http://") {
650        return Err(HttpError::InsecureTokenUrl(token_url));
651    }
652
653    let client = Client::builder().timeout(Duration::from_secs(15)).build()?;
654
655    // Two OAuth2 client_credentials modes:
656    // 1. Form body: client_id + client_secret in form data (Amadeus)
657    // 2. Basic Auth: base64(client_id:client_secret) in Authorization header (Sovos)
658    let response = if provider.oauth2_basic_auth {
659        client
660            .post(&token_url)
661            .basic_auth(client_id, Some(client_secret))
662            .form(&[("grant_type", "client_credentials")])
663            .send()
664            .await?
665    } else {
666        client
667            .post(&token_url)
668            .form(&[
669                ("grant_type", "client_credentials"),
670                ("client_id", client_id),
671                ("client_secret", client_secret),
672            ])
673            .send()
674            .await?
675    };
676
677    if !response.status().is_success() {
678        let status = response.status().as_u16();
679        let body = response.text().await.unwrap_or_default();
680        return Err(HttpError::Oauth2Error(format!(
681            "token exchange failed ({status}): {body}"
682        )));
683    }
684
685    let body: Value = response
686        .json()
687        .await
688        .map_err(|e| HttpError::Oauth2Error(format!("failed to parse token response: {e}")))?;
689
690    let access_token = body
691        .get("access_token")
692        .and_then(|v| v.as_str())
693        .ok_or_else(|| HttpError::Oauth2Error("no access_token in response".into()))?
694        .to_string();
695
696    let expires_in = body
697        .get("expires_in")
698        .and_then(|v| v.as_u64())
699        .unwrap_or(1799);
700
701    let expiry = Instant::now() + Duration::from_secs(expires_in);
702
703    // Cache the token
704    {
705        let mut cache = OAUTH2_CACHE.lock().unwrap();
706        cache.insert(cache_key, (access_token.clone(), expiry));
707    }
708
709    Ok(access_token)
710}
711
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `String -> String` map from `(name, value)` pairs.
    fn string_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn test_substitute_path_params_normal() {
        let args = string_map(&[("petId", "123")]);
        let result = substitute_path_params("/pet/{petId}", &args).unwrap();
        assert_eq!(result, "/pet/123");
    }

    #[test]
    fn test_substitute_path_params_rejects_dotdot() {
        let args = string_map(&[("id", "../admin")]);
        assert!(substitute_path_params("/resource/{id}", &args).is_err());
    }

    #[test]
    fn test_substitute_path_params_allows_slash() {
        let args = string_map(&[("id", "fal-ai/flux/dev")]);
        let result = substitute_path_params("/resource/{id}", &args).unwrap();
        assert_eq!(result, "/resource/fal-ai/flux/dev");
    }

    #[test]
    fn test_substitute_path_params_rejects_backslash() {
        let args = string_map(&[("id", "foo\\bar")]);
        assert!(substitute_path_params("/resource/{id}", &args).is_err());
    }

    #[test]
    fn test_substitute_path_params_rejects_question() {
        let args = string_map(&[("id", "foo?bar=1")]);
        assert!(substitute_path_params("/resource/{id}", &args).is_err());
    }

    #[test]
    fn test_substitute_path_params_rejects_hash() {
        let args = string_map(&[("id", "foo#bar")]);
        assert!(substitute_path_params("/resource/{id}", &args).is_err());
    }

    #[test]
    fn test_substitute_path_params_rejects_null_byte() {
        let args = string_map(&[("id", "foo\0bar")]);
        assert!(substitute_path_params("/resource/{id}", &args).is_err());
    }

    #[test]
    fn test_substitute_path_params_encodes_special() {
        let args = string_map(&[("name", "hello world")]);
        let result = substitute_path_params("/users/{name}", &args).unwrap();
        assert_eq!(result, "/users/hello%20world");
    }

    #[test]
    fn test_substitute_path_params_preserves_unreserved() {
        let args = string_map(&[("id", "abc-123_test.v2~draft")]);
        let result = substitute_path_params("/items/{id}", &args).unwrap();
        assert_eq!(result, "/items/abc-123_test.v2~draft");
    }

    #[test]
    fn test_substitute_path_params_encodes_at_sign() {
        let args = string_map(&[("user", "user@domain")]);
        let result = substitute_path_params("/profile/{user}", &args).unwrap();
        assert_eq!(result, "/profile/user%40domain");
    }

    #[test]
    fn test_percent_encode_path_segment_empty() {
        assert_eq!(percent_encode_path_segment(""), "");
    }

    #[test]
    fn test_percent_encode_path_segment_ascii_only() {
        assert_eq!(percent_encode_path_segment("abc123"), "abc123");
    }

    #[test]
    fn test_percent_encode_path_segment_keeps_slash_encodes_utf8() {
        assert_eq!(percent_encode_path_segment("a/b c"), "a/b%20c");
        assert_eq!(percent_encode_path_segment("é"), "%C3%A9");
    }

    #[test]
    fn test_substitute_path_params_multiple() {
        let args = string_map(&[("owner", "acme"), ("repo", "widgets")]);
        let result = substitute_path_params("/repos/{owner}/{repo}/issues", &args).unwrap();
        assert_eq!(result, "/repos/acme/widgets/issues");
    }

    #[test]
    fn test_substitute_path_params_no_placeholders() {
        let args = HashMap::new();
        let result = substitute_path_params("/health", &args).unwrap();
        assert_eq!(result, "/health");
    }

    #[test]
    fn test_validate_headers_rejects_denied_case_insensitively() {
        let headers = string_map(&[("AUTHORIZATION", "x")]);
        assert!(matches!(
            validate_headers(&headers, None),
            Err(HttpError::DeniedHeader(h)) if h == "AUTHORIZATION"
        ));
    }

    #[test]
    fn test_validate_headers_rejects_provider_auth_header() {
        let headers = string_map(&[("x-api-key", "k")]);
        assert!(validate_headers(&headers, Some("X-Api-Key")).is_err());
    }

    #[test]
    fn test_validate_headers_allows_custom_header() {
        let headers = string_map(&[("X-Request-Id", "abc")]);
        assert!(validate_headers(&headers, Some("X-Api-Key")).is_ok());
    }

    #[test]
    fn test_value_to_string_scalars() {
        assert_eq!(value_to_string(&Value::String("a b".into())), "a b");
        assert_eq!(value_to_string(&Value::Bool(true)), "true");
        assert_eq!(value_to_string(&Value::from(42)), "42");
        assert_eq!(value_to_string(&Value::Null), "");
    }
}