Skip to main content

ati/core/
http.rs

1use reqwest::Client;
2use serde_json::Value;
3use std::collections::HashMap;
4use std::net::ToSocketAddrs;
5use std::sync::Mutex;
6use std::time::{Duration, Instant};
7use thiserror::Error;
8
9use crate::core::auth_generator::{self, AuthCache, GenContext};
10use crate::core::keyring::Keyring;
11use crate::core::manifest::{AuthType, HttpMethod, Provider, Tool};
12
13#[derive(Error, Debug)]
14pub enum HttpError {
15    #[error("API key '{0}' not found in keyring")]
16    MissingKey(String),
17    #[error("HTTP request failed: {0}")]
18    Request(#[from] reqwest::Error),
19    #[error("API error ({status}): {body}")]
20    ApiError { status: u16, body: String },
21    #[error("Failed to parse response as JSON: {0}")]
22    ParseError(String),
23    #[error("OAuth2 token exchange failed: {0}")]
24    Oauth2Error(String),
25    #[error("Invalid path parameter '{key}': value '{value}' contains forbidden characters")]
26    InvalidPathParam { key: String, value: String },
27    #[error("Header '{0}' is not allowed as a user-supplied parameter")]
28    DeniedHeader(String),
29    #[error("SSRF protection: URL '{0}' targets a private/internal network address")]
30    SsrfBlocked(String),
31    #[error("OAuth2 token URL must use HTTPS: '{0}'")]
32    InsecureTokenUrl(String),
33}
34
/// Process-wide OAuth2 token cache, keyed by provider name.
///
/// Each entry is `(access_token, expiry_instant)`.
static OAUTH2_CACHE: std::sync::LazyLock<Mutex<HashMap<String, (String, Instant)>>> =
    std::sync::LazyLock::new(Mutex::default);

/// Timeout, in seconds, configured on the HTTP client used for tool calls.
const DEFAULT_TIMEOUT_SECS: u64 = 60;
40
41/// Validate that a URL does not target private/internal network addresses (SSRF protection).
42/// Checks the hostname against deny-listed private IP ranges.
43///
44/// Enforcement is controlled by `ATI_SSRF_PROTECTION` env var:
45/// - "1" or "true": block requests to private addresses (default in proxy mode)
46/// - "warn": log a warning but allow the request
47/// - unset/other: allow the request (for local development/testing)
48pub fn validate_url_not_private(url: &str) -> Result<(), HttpError> {
49    let mode = std::env::var("ATI_SSRF_PROTECTION").unwrap_or_default();
50    let enforce = mode == "1" || mode.eq_ignore_ascii_case("true");
51    let warn_only = mode.eq_ignore_ascii_case("warn");
52
53    if !enforce && !warn_only {
54        return Ok(());
55    }
56    let parsed = match reqwest::Url::parse(url) {
57        Ok(parsed) => parsed,
58        Err(_) => return Ok(()),
59    };
60    let host = parsed.host_str().unwrap_or("");
61    let port = parsed.port_or_known_default().unwrap_or(80);
62    let ip_host = host.trim_matches(['[', ']']);
63
64    if host.is_empty() {
65        return Ok(());
66    }
67
68    // Check common internal hostnames
69    let host_lower = host.to_lowercase();
70    let mut is_private = host_lower == "localhost"
71        || host_lower == "metadata.google.internal"
72        || host_lower.ends_with(".internal")
73        || host_lower.ends_with(".local");
74
75    if !is_private {
76        if let Ok(ip) = ip_host.parse::<std::net::IpAddr>() {
77            is_private = is_private_ip(ip);
78        } else if let Ok(addrs) = (ip_host, port).to_socket_addrs() {
79            is_private = addrs.into_iter().any(|addr| is_private_ip(addr.ip()));
80        }
81    }
82
83    if is_private {
84        if warn_only {
85            tracing::warn!(url, "SSRF protection — URL targets private address");
86            return Ok(());
87        }
88        return Err(HttpError::SsrfBlocked(url.to_string()));
89    }
90
91    Ok(())
92}
93
/// Headers that must never be set by agent-supplied parameters.
///
/// Entries are lowercase; callers lowercase the candidate header name before comparing.
/// The list covers:
/// - credentials and cookies;
/// - message-framing headers that could enable request smuggling;
/// - hop-by-hop connection-management headers;
/// - proxy-set client-identity headers: the `X-Forwarded-*` family and the standard
///   RFC 7239 `Forwarded`.
const DENIED_HEADERS: &[&str] = &[
    // Credentials and session state
    "authorization",
    "proxy-authorization",
    "cookie",
    "set-cookie",
    // Routing and message framing
    "host",
    "content-type",
    "content-length",
    "transfer-encoding",
    "te",
    "trailer",
    // Hop-by-hop connection management
    "connection",
    "keep-alive",
    "proxy-connection",
    "upgrade",
    // Proxy-set client identity
    "forwarded",
    "x-forwarded-for",
    "x-forwarded-host",
    "x-forwarded-proto",
    "x-forwarded-port",
    "x-real-ip",
];
111
112/// Check that classified header parameters don't contain denied headers.
113pub fn validate_headers(
114    headers: &HashMap<String, String>,
115    provider_auth_header: Option<&str>,
116) -> Result<(), HttpError> {
117    for key in headers.keys() {
118        let lower = key.to_lowercase();
119        if DENIED_HEADERS.contains(&lower.as_str()) {
120            return Err(HttpError::DeniedHeader(key.clone()));
121        }
122        if let Some(auth_header) = provider_auth_header {
123            if lower == auth_header.to_lowercase() {
124                return Err(HttpError::DeniedHeader(key.clone()));
125            }
126        }
127    }
128    Ok(())
129}
130
131/// Merge manifest defaults into the args map for any params not provided by caller.
132fn merge_defaults(tool: &Tool, args: &HashMap<String, Value>) -> HashMap<String, Value> {
133    let mut merged = args.clone();
134    if let Some(schema) = &tool.input_schema {
135        if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
136            for (key, prop_def) in props {
137                if !merged.contains_key(key) {
138                    if let Some(default_val) = prop_def.get("default") {
139                        // Skip empty arrays/objects as defaults — they add no value
140                        // and some APIs reject them (e.g. ClinicalTrials `sort=[]`).
141                        let dominated = match default_val {
142                            Value::Array(a) => a.is_empty(),
143                            Value::Object(o) => o.is_empty(),
144                            Value::Null => true,
145                            _ => false,
146                        };
147                        if !dominated {
148                            merged.insert(key.clone(), default_val.clone());
149                        }
150                    }
151                }
152            }
153        }
154    }
155    merged
156}
157
/// Serialization style for array-valued query parameters, selected by the
/// `x-ati-collection-format` schema annotation (`multi`, `csv`, `ssv`, `pipes`).
#[derive(Debug, Clone, Copy, PartialEq)]
enum CollectionFormat {
    /// One `key=value` pair per element: `?status=a&status=b`.
    Multi,
    /// Elements joined with commas: `?status=a,b`.
    Csv,
    /// Elements joined with spaces (percent-encoded on the wire): `?status=a%20b`.
    Ssv,
    /// Elements joined with pipes: `?status=a|b`.
    Pipes,
}
170
/// Wire encoding for body parameters, selected by the schema-level
/// `x-ati-body-encoding` annotation (`form`; anything else means JSON).
#[derive(Debug, Clone, Copy, PartialEq)]
enum BodyEncoding {
    /// `application/json` object built from the body parameters.
    Json,
    /// `application/x-www-form-urlencoded` key/value pairs.
    Form,
}
177
178/// Classified parameter maps, split by location.
179struct ClassifiedParams {
180    path: HashMap<String, String>,
181    query: HashMap<String, String>,
182    query_arrays: HashMap<String, (Vec<String>, CollectionFormat)>,
183    header: HashMap<String, String>,
184    body: HashMap<String, Value>,
185    body_encoding: BodyEncoding,
186}
187
188/// Classify parameters by their `x-ati-param-location` metadata in the input schema.
189/// If no location metadata exists (legacy TOML tools), returns None for legacy fallback.
190fn classify_params(tool: &Tool, args: &HashMap<String, Value>) -> Option<ClassifiedParams> {
191    let schema = tool.input_schema.as_ref()?;
192    let props = schema.get("properties")?.as_object()?;
193
194    // Check if any property has x-ati-param-location — if none do, this is a legacy tool
195    let has_locations = props
196        .values()
197        .any(|p| p.get("x-ati-param-location").is_some());
198
199    if !has_locations {
200        return None;
201    }
202
203    // Detect body encoding from schema-level metadata
204    let body_encoding = match schema.get("x-ati-body-encoding").and_then(|v| v.as_str()) {
205        Some("form") => BodyEncoding::Form,
206        _ => BodyEncoding::Json,
207    };
208
209    let mut classified = ClassifiedParams {
210        path: HashMap::new(),
211        query: HashMap::new(),
212        query_arrays: HashMap::new(),
213        header: HashMap::new(),
214        body: HashMap::new(),
215        body_encoding,
216    };
217
218    for (key, value) in args {
219        let prop_def = props.get(key);
220        let location = prop_def
221            .and_then(|p| p.get("x-ati-param-location"))
222            .and_then(|l| l.as_str())
223            .unwrap_or("body"); // default to body if no location specified
224
225        match location {
226            "path" => {
227                classified.path.insert(key.clone(), value_to_string(value));
228            }
229            "query" => {
230                // Check if this is an array value with a collection format
231                if let Value::Array(arr) = value {
232                    let cf_str = prop_def
233                        .and_then(|p| p.get("x-ati-collection-format"))
234                        .and_then(|v| v.as_str());
235                    let cf = match cf_str {
236                        Some("multi") => CollectionFormat::Multi,
237                        Some("csv") => CollectionFormat::Csv,
238                        Some("ssv") => CollectionFormat::Ssv,
239                        Some("pipes") => CollectionFormat::Pipes,
240                        _ => CollectionFormat::Multi, // default for arrays
241                    };
242                    let values: Vec<String> = arr.iter().map(value_to_string).collect();
243                    classified.query_arrays.insert(key.clone(), (values, cf));
244                } else {
245                    classified.query.insert(key.clone(), value_to_string(value));
246                }
247            }
248            "header" => {
249                classified
250                    .header
251                    .insert(key.clone(), value_to_string(value));
252            }
253            _ => {
254                classified.body.insert(key.clone(), value.clone());
255            }
256        }
257    }
258
259    Some(classified)
260}
261
262/// Substitute path parameters like `{petId}` in the endpoint template.
263/// Rejects values containing path traversal or URL-breaking characters,
264/// then percent-encodes the value before substitution.
265fn substitute_path_params(
266    endpoint: &str,
267    path_args: &HashMap<String, String>,
268) -> Result<String, HttpError> {
269    let mut result = endpoint.to_string();
270    for (key, value) in path_args {
271        if value.contains("..")
272            || value.contains('\\')
273            || value.contains('?')
274            || value.contains('#')
275            || value.contains('\0')
276        {
277            return Err(HttpError::InvalidPathParam {
278                key: key.clone(),
279                value: value.clone(),
280            });
281        }
282        let encoded = percent_encode_path_segment(value);
283        result = result.replace(&format!("{{{key}}}"), &encoded);
284    }
285    Ok(result)
286}
287
/// Percent-encode `s` for use as a single URL path segment.
///
/// Only RFC 3986 unreserved characters (ALPHA / DIGIT / `-` / `.` / `_` / `~`) pass
/// through unchanged. Every other byte, including each byte of a multi-byte UTF-8
/// character, becomes `%XX` with uppercase hex digits.
fn percent_encode_path_segment(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    s.bytes()
        .fold(String::with_capacity(s.len()), |mut out, byte| {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                out.push(char::from(byte));
            } else {
                out.push('%');
                out.push(char::from(HEX[usize::from(byte >> 4)]));
                out.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
            out
        })
}
304
/// Return `true` if `ip` lies in a range that outbound tool calls must not reach when
/// SSRF protection is enforced.
///
/// The ranges are: loopback, RFC 1918 private, link-local, unspecified, 0.0.0.0/8,
/// broadcast, carrier-grade NAT (100.64.0.0/10), IPv6 unique-local, and IPv6 link-local.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are judged by the embedded IPv4
/// address. Without this, `[::ffff:127.0.0.1]` would reach loopback unchecked.
fn is_private_ip(ip: std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(v4) => {
            let [first, second, ..] = v4.octets();
            v4.is_loopback()
                || v4.is_private()
                || v4.is_link_local()
                || v4.is_unspecified()
                || v4.is_broadcast()
                // 0.0.0.0/8 "this network": many stacks route it to the local host.
                || first == 0
                // 100.64.0.0/10 carrier-grade NAT (RFC 6598).
                || (first == 100 && (64..=127).contains(&second))
        }
        std::net::IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(v4) => is_private_ip(std::net::IpAddr::V4(v4)),
            None => {
                v6.is_loopback()
                    || v6.is_unspecified()
                    || v6.is_unique_local()
                    || v6.is_unicast_link_local()
            }
        },
    }
}
322
323/// Convert a serde_json::Value to a URL-safe string.
324fn value_to_string(v: &Value) -> String {
325    match v {
326        Value::String(s) => s.clone(),
327        Value::Number(n) => n.to_string(),
328        Value::Bool(b) => b.to_string(),
329        Value::Null => String::new(),
330        other => other.to_string(),
331    }
332}
333
334/// Apply array query parameters to a request builder using the specified collection format.
335fn apply_query_arrays(
336    mut req: reqwest::RequestBuilder,
337    arrays: &HashMap<String, (Vec<String>, CollectionFormat)>,
338) -> reqwest::RequestBuilder {
339    for (key, (values, format)) in arrays {
340        match format {
341            CollectionFormat::Multi => {
342                // Repeated key: ?status=a&status=b
343                for val in values {
344                    req = req.query(&[(key.as_str(), val.as_str())]);
345                }
346            }
347            CollectionFormat::Csv => {
348                let joined = values.join(",");
349                req = req.query(&[(key.as_str(), joined.as_str())]);
350            }
351            CollectionFormat::Ssv => {
352                let joined = values.join(" ");
353                req = req.query(&[(key.as_str(), joined.as_str())]);
354            }
355            CollectionFormat::Pipes => {
356                let joined = values.join("|");
357                req = req.query(&[(key.as_str(), joined.as_str())]);
358            }
359        }
360    }
361    req
362}
363
364/// Execute an HTTP tool call against a provider's API.
365///
366/// Supports two modes:
367/// 1. **Location-aware** (OpenAPI tools): Parameters are classified by `x-ati-param-location`
368///    metadata in the input schema. Path params are substituted into the URL template,
369///    query params go to the query string, header params become request headers,
370///    and body params go to the JSON body.
371/// 2. **Legacy** (hand-written TOML tools): GET → all args as query params, POST/PUT/DELETE → JSON body.
372pub async fn execute_tool(
373    provider: &Provider,
374    tool: &Tool,
375    args: &HashMap<String, Value>,
376    keyring: &Keyring,
377) -> Result<Value, HttpError> {
378    execute_tool_with_gen(provider, tool, args, keyring, None, None).await
379}
380
381/// Execute an HTTP tool call, optionally using a dynamic auth generator.
382pub async fn execute_tool_with_gen(
383    provider: &Provider,
384    tool: &Tool,
385    args: &HashMap<String, Value>,
386    keyring: &Keyring,
387    gen_ctx: Option<&GenContext>,
388    auth_cache: Option<&AuthCache>,
389) -> Result<Value, HttpError> {
390    // SSRF protection: validate base_url is not targeting private networks
391    validate_url_not_private(&provider.base_url)?;
392
393    let client = Client::builder()
394        .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
395        .build()?;
396
397    // Merge manifest defaults into caller-provided args
398    let merged_args = merge_defaults(tool, args);
399
400    // Try location-aware classification (OpenAPI tools have x-ati-param-location)
401    let mut request = if let Some(classified) = classify_params(tool, &merged_args) {
402        // Validate headers against deny-list before injecting
403        validate_headers(&classified.header, provider.auth_header_name.as_deref())?;
404
405        // Location-aware mode: substitute path params, route by location
406        let resolved_endpoint = substitute_path_params(&tool.endpoint, &classified.path)?;
407        let url = format!(
408            "{}{}",
409            provider.base_url.trim_end_matches('/'),
410            resolved_endpoint
411        );
412
413        let mut req = match tool.method {
414            HttpMethod::Get | HttpMethod::Delete => {
415                let base_req = match tool.method {
416                    HttpMethod::Get => client.get(&url),
417                    HttpMethod::Delete => client.delete(&url),
418                    _ => unreachable!(),
419                };
420                // Query params to query string
421                let mut r = base_req;
422                for (k, v) in &classified.query {
423                    r = r.query(&[(k.as_str(), v.as_str())]);
424                }
425                r = apply_query_arrays(r, &classified.query_arrays);
426                r
427            }
428            HttpMethod::Post | HttpMethod::Put => {
429                let base_req = match tool.method {
430                    HttpMethod::Post => client.post(&url),
431                    HttpMethod::Put => client.put(&url),
432                    _ => unreachable!(),
433                };
434                // Body params: encode as JSON or form-urlencoded based on metadata
435                let mut r = if classified.body.is_empty() {
436                    base_req
437                } else {
438                    match classified.body_encoding {
439                        BodyEncoding::Json => base_req.json(&classified.body),
440                        BodyEncoding::Form => {
441                            let pairs: Vec<(String, String)> = classified
442                                .body
443                                .iter()
444                                .map(|(k, v)| (k.clone(), value_to_string(v)))
445                                .collect();
446                            base_req.form(&pairs)
447                        }
448                    }
449                };
450                // Query params still go to query string
451                for (k, v) in &classified.query {
452                    r = r.query(&[(k.as_str(), v.as_str())]);
453                }
454                r = apply_query_arrays(r, &classified.query_arrays);
455                r
456            }
457        };
458
459        // Inject classified header params
460        for (k, v) in &classified.header {
461            req = req.header(k.as_str(), v.as_str());
462        }
463
464        req
465    } else {
466        // Legacy mode: no x-ati-param-location metadata
467        let url = format!(
468            "{}{}",
469            provider.base_url.trim_end_matches('/'),
470            &tool.endpoint
471        );
472
473        match tool.method {
474            HttpMethod::Get => {
475                let mut req = client.get(&url);
476                for (k, v) in &merged_args {
477                    req = req.query(&[(k.as_str(), value_to_string(v))]);
478                }
479                req
480            }
481            HttpMethod::Post => client.post(&url).json(&merged_args),
482            HttpMethod::Put => client.put(&url).json(&merged_args),
483            HttpMethod::Delete => client.delete(&url).json(&merged_args),
484        }
485    };
486
487    // Inject authentication (generator takes priority over static keyring)
488    request = inject_auth(request, provider, keyring, gen_ctx, auth_cache).await?;
489
490    // Inject extra headers from provider config
491    for (header_name, header_value) in &provider.extra_headers {
492        request = request.header(header_name.as_str(), header_value.as_str());
493    }
494
495    // Execute request
496    let response = request.send().await?;
497    let status = response.status();
498
499    if !status.is_success() {
500        let body = response.text().await.unwrap_or_else(|_| "empty".into());
501        return Err(HttpError::ApiError {
502            status: status.as_u16(),
503            body,
504        });
505    }
506
507    // Parse response
508    let text = response.text().await?;
509    let value: Value = serde_json::from_str(&text).unwrap_or(Value::String(text));
510
511    Ok(value)
512}
513
514/// Inject authentication headers/params based on provider auth_type.
515///
516/// If the provider has an `auth_generator`, the generator is run first to produce
517/// dynamic credentials. Otherwise, static keyring credentials are used.
518async fn inject_auth(
519    request: reqwest::RequestBuilder,
520    provider: &Provider,
521    keyring: &Keyring,
522    gen_ctx: Option<&GenContext>,
523    auth_cache: Option<&AuthCache>,
524) -> Result<reqwest::RequestBuilder, HttpError> {
525    // Dynamic auth generator takes priority
526    if let Some(gen) = &provider.auth_generator {
527        let default_ctx = GenContext::default();
528        let ctx = gen_ctx.unwrap_or(&default_ctx);
529        let default_cache = AuthCache::new();
530        let cache = auth_cache.unwrap_or(&default_cache);
531
532        let cred = auth_generator::generate(provider, gen, ctx, keyring, cache)
533            .await
534            .map_err(|e| HttpError::MissingKey(format!("auth_generator: {e}")))?;
535
536        // Inject primary credential based on auth_type
537        let mut req = match provider.auth_type {
538            AuthType::Bearer => request.bearer_auth(&cred.value),
539            AuthType::Header => {
540                let name = provider.auth_header_name.as_deref().unwrap_or("X-Api-Key");
541                let val = match &provider.auth_value_prefix {
542                    Some(pfx) => format!("{pfx}{}", cred.value),
543                    None => cred.value.clone(),
544                };
545                request.header(name, val)
546            }
547            AuthType::Query => {
548                let name = provider.auth_query_name.as_deref().unwrap_or("api_key");
549                request.query(&[(name, &cred.value)])
550            }
551            _ => request,
552        };
553        // Inject extra headers from JSON inject targets
554        for (name, value) in &cred.extra_headers {
555            req = req.header(name.as_str(), value.as_str());
556        }
557        return Ok(req);
558    }
559
560    match provider.auth_type {
561        AuthType::None => Ok(request),
562        AuthType::Bearer => {
563            let key_name = provider
564                .auth_key_name
565                .as_deref()
566                .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
567            let key_value = keyring
568                .get(key_name)
569                .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
570            Ok(request.bearer_auth(key_value))
571        }
572        AuthType::Header => {
573            let key_name = provider
574                .auth_key_name
575                .as_deref()
576                .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
577            let key_value = keyring
578                .get(key_name)
579                .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
580            let header_name = provider.auth_header_name.as_deref().unwrap_or("X-Api-Key");
581            let final_value = match &provider.auth_value_prefix {
582                Some(prefix) => format!("{}{}", prefix, key_value),
583                None => key_value.to_string(),
584            };
585            Ok(request.header(header_name, final_value))
586        }
587        AuthType::Query => {
588            let key_name = provider
589                .auth_key_name
590                .as_deref()
591                .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
592            let key_value = keyring
593                .get(key_name)
594                .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
595            let query_name = provider.auth_query_name.as_deref().unwrap_or("api_key");
596            Ok(request.query(&[(query_name, key_value)]))
597        }
598        AuthType::Basic => {
599            let key_name = provider
600                .auth_key_name
601                .as_deref()
602                .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
603            let key_value = keyring
604                .get(key_name)
605                .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
606            Ok(request.basic_auth(key_value, None::<&str>))
607        }
608        AuthType::Oauth2 => {
609            let access_token = get_oauth2_token(provider, keyring).await?;
610            Ok(request.bearer_auth(access_token))
611        }
612    }
613}
614
615/// Fetch (or return cached) OAuth2 access token via client_credentials grant.
616async fn get_oauth2_token(provider: &Provider, keyring: &Keyring) -> Result<String, HttpError> {
617    let cache_key = provider.name.clone();
618
619    // Check cache
620    {
621        let cache = OAUTH2_CACHE.lock().unwrap();
622        if let Some((token, expiry)) = cache.get(&cache_key) {
623            // Use cached token if it has at least 60s remaining
624            if Instant::now() + Duration::from_secs(60) < *expiry {
625                return Ok(token.clone());
626            }
627        }
628    }
629
630    // Token expired or not cached — exchange credentials
631    let client_id_key = provider
632        .auth_key_name
633        .as_deref()
634        .ok_or_else(|| HttpError::Oauth2Error("auth_key_name not set for OAuth2".into()))?;
635    let client_id = keyring
636        .get(client_id_key)
637        .ok_or_else(|| HttpError::MissingKey(client_id_key.into()))?;
638
639    let client_secret_key = provider
640        .auth_secret_name
641        .as_deref()
642        .ok_or_else(|| HttpError::Oauth2Error("auth_secret_name not set for OAuth2".into()))?;
643    let client_secret = keyring
644        .get(client_secret_key)
645        .ok_or_else(|| HttpError::MissingKey(client_secret_key.into()))?;
646
647    let token_url = match &provider.oauth2_token_url {
648        Some(url) if url.starts_with("http") => url.clone(),
649        Some(path) => format!("{}{}", provider.base_url.trim_end_matches('/'), path),
650        None => return Err(HttpError::Oauth2Error("oauth2_token_url not set".into())),
651    };
652
653    // Enforce HTTPS for OAuth2 token URLs (credentials are sent in plaintext otherwise)
654    if token_url.starts_with("http://") {
655        return Err(HttpError::InsecureTokenUrl(token_url));
656    }
657
658    let client = Client::builder().timeout(Duration::from_secs(15)).build()?;
659
660    // Two OAuth2 client_credentials modes:
661    // 1. Form body: client_id + client_secret in form data (Amadeus)
662    // 2. Basic Auth: base64(client_id:client_secret) in Authorization header (Sovos)
663    let response = if provider.oauth2_basic_auth {
664        client
665            .post(&token_url)
666            .basic_auth(client_id, Some(client_secret))
667            .form(&[("grant_type", "client_credentials")])
668            .send()
669            .await?
670    } else {
671        client
672            .post(&token_url)
673            .form(&[
674                ("grant_type", "client_credentials"),
675                ("client_id", client_id),
676                ("client_secret", client_secret),
677            ])
678            .send()
679            .await?
680    };
681
682    if !response.status().is_success() {
683        let status = response.status().as_u16();
684        let body = response.text().await.unwrap_or_default();
685        return Err(HttpError::Oauth2Error(format!(
686            "token exchange failed ({status}): {body}"
687        )));
688    }
689
690    let body: Value = response
691        .json()
692        .await
693        .map_err(|e| HttpError::Oauth2Error(format!("failed to parse token response: {e}")))?;
694
695    let access_token = body
696        .get("access_token")
697        .and_then(|v| v.as_str())
698        .ok_or_else(|| HttpError::Oauth2Error("no access_token in response".into()))?
699        .to_string();
700
701    let expires_in = body
702        .get("expires_in")
703        .and_then(|v| v.as_u64())
704        .unwrap_or(1799);
705
706    let expiry = Instant::now() + Duration::from_secs(expires_in);
707
708    // Cache the token
709    {
710        let mut cache = OAUTH2_CACHE.lock().unwrap();
711        cache.insert(cache_key, (access_token.clone(), expiry));
712    }
713
714    Ok(access_token)
715}
716
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a path-argument map from `(name, value)` pairs.
    fn path_args(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    /// Assert that substituting `value` into `/resource/{id}` is rejected.
    fn assert_rejected(value: &str) {
        let outcome = substitute_path_params("/resource/{id}", &path_args(&[("id", value)]));
        assert!(outcome.is_err(), "expected {value:?} to be rejected");
    }

    #[test]
    fn test_substitute_path_params_normal() {
        let args = path_args(&[("petId", "123")]);
        let result = substitute_path_params("/pet/{petId}", &args).unwrap();
        assert_eq!(result, "/pet/123");
    }

    #[test]
    fn test_substitute_path_params_rejects_dotdot() {
        assert_rejected("../admin");
    }

    #[test]
    fn test_substitute_path_params_encodes_slash() {
        let args = path_args(&[("id", "fal-ai/flux/dev")]);
        let result = substitute_path_params("/resource/{id}", &args).unwrap();
        assert_eq!(result, "/resource/fal-ai%2Fflux%2Fdev");
    }

    #[test]
    fn test_substitute_path_params_rejects_backslash() {
        assert_rejected("foo\\bar");
    }

    #[test]
    fn test_substitute_path_params_rejects_question() {
        assert_rejected("foo?bar=1");
    }

    #[test]
    fn test_substitute_path_params_rejects_hash() {
        assert_rejected("foo#bar");
    }

    #[test]
    fn test_substitute_path_params_rejects_null_byte() {
        assert_rejected("foo\0bar");
    }

    #[test]
    fn test_substitute_path_params_encodes_special() {
        let args = path_args(&[("name", "hello world")]);
        let result = substitute_path_params("/users/{name}", &args).unwrap();
        assert_eq!(result, "/users/hello%20world");
    }

    #[test]
    fn test_substitute_path_params_preserves_unreserved() {
        let args = path_args(&[("id", "abc-123_test.v2~draft")]);
        let result = substitute_path_params("/items/{id}", &args).unwrap();
        assert_eq!(result, "/items/abc-123_test.v2~draft");
    }

    #[test]
    fn test_substitute_path_params_encodes_at_sign() {
        let args = path_args(&[("user", "user@domain")]);
        let result = substitute_path_params("/profile/{user}", &args).unwrap();
        assert_eq!(result, "/profile/user%40domain");
    }

    #[test]
    fn test_percent_encode_path_segment_empty() {
        assert_eq!(percent_encode_path_segment(""), "");
    }

    #[test]
    fn test_percent_encode_path_segment_ascii_only() {
        assert_eq!(percent_encode_path_segment("abc123"), "abc123");
    }

    #[test]
    fn test_substitute_path_params_multiple() {
        let args = path_args(&[("owner", "acme"), ("repo", "widgets")]);
        let result = substitute_path_params("/repos/{owner}/{repo}/issues", &args).unwrap();
        assert_eq!(result, "/repos/acme/widgets/issues");
    }

    #[test]
    fn test_substitute_path_params_no_placeholders() {
        let result = substitute_path_params("/health", &path_args(&[])).unwrap();
        assert_eq!(result, "/health");
    }
}