Skip to main content

ati/core/
http.rs

1use reqwest::Client;
2use serde_json::Value;
3use std::collections::HashMap;
4use std::net::ToSocketAddrs;
5use std::sync::Mutex;
6use std::time::{Duration, Instant};
7use thiserror::Error;
8
9use crate::core::auth_generator::{self, AuthCache, GenContext};
10use crate::core::keyring::Keyring;
11use crate::core::manifest::{AuthType, HttpMethod, Provider, Tool};
12
13#[derive(Error, Debug)]
14pub enum HttpError {
15    #[error("API key '{0}' not found in keyring")]
16    MissingKey(String),
17    #[error("HTTP request failed: {0}")]
18    Request(#[from] reqwest::Error),
19    #[error("API error ({status}): {body}")]
20    ApiError {
21        status: u16,
22        body: String,
23        /// Parsed from the upstream JSON body when present (e.g. "not_found",
24        /// "invalid_request"). Used for Sentry tagging.
25        error_type: Option<String>,
26        /// Parsed human-readable message from the upstream body. Used for
27        /// Sentry tagging after scrubbing.
28        error_message: Option<String>,
29    },
30    /// Upstream returned a 404 with a body shape that signals "no records
31    /// match" (PDL, Middesk, etc.). The caller should treat this as a legit
32    /// empty result rather than a failure. Carries the parsed message for
33    /// optional logging.
34    #[error("No records found ({status})")]
35    NoRecordsFound { status: u16 },
36    #[error("Failed to parse response as JSON: {0}")]
37    ParseError(String),
38    #[error("OAuth2 token exchange failed: {0}")]
39    Oauth2Error(String),
40    #[error("Invalid path parameter '{key}': value '{value}' contains forbidden characters")]
41    InvalidPathParam { key: String, value: String },
42    #[error("Header '{0}' is not allowed as a user-supplied parameter")]
43    DeniedHeader(String),
44    #[error("SSRF protection: URL '{0}' targets a private/internal network address")]
45    SsrfBlocked(String),
46    #[error("OAuth2 token URL must use HTTPS: '{0}'")]
47    InsecureTokenUrl(String),
48}
49
/// Process-wide OAuth2 token cache keyed by provider name.
///
/// Each entry is `(access_token, expiry)`, where `expiry` is the `Instant` after
/// which the token must no longer be used. Starts empty on first access.
static OAUTH2_CACHE: std::sync::LazyLock<Mutex<HashMap<String, (String, Instant)>>> =
    std::sync::LazyLock::new(Default::default);
53
/// Whole-request timeout, in seconds, for the client built per tool call in
/// `execute_tool_with_gen` (the OAuth2 token exchange builds its own 15 s client).
const DEFAULT_TIMEOUT_SECS: u64 = 60;
55
56/// Validate that a URL does not target private/internal network addresses (SSRF protection).
57/// Checks the hostname against deny-listed private IP ranges.
58///
59/// Enforcement is controlled by `ATI_SSRF_PROTECTION` env var:
60/// - "1" or "true": block requests to private addresses (default in proxy mode)
61/// - "warn": log a warning but allow the request
62/// - unset/other: allow the request (for local development/testing)
63pub fn validate_url_not_private(url: &str) -> Result<(), HttpError> {
64    let mode = std::env::var("ATI_SSRF_PROTECTION").unwrap_or_default();
65    let enforce = mode == "1" || mode.eq_ignore_ascii_case("true");
66    let warn_only = mode.eq_ignore_ascii_case("warn");
67
68    if !enforce && !warn_only {
69        return Ok(());
70    }
71    let parsed = match reqwest::Url::parse(url) {
72        Ok(parsed) => parsed,
73        Err(_) => return Ok(()),
74    };
75    let host = parsed.host_str().unwrap_or("");
76    let port = parsed.port_or_known_default().unwrap_or(80);
77    let ip_host = host.trim_matches(['[', ']']);
78
79    if host.is_empty() {
80        return Ok(());
81    }
82
83    // Check common internal hostnames
84    let host_lower = host.to_lowercase();
85    let mut is_private = host_lower == "localhost"
86        || host_lower == "metadata.google.internal"
87        || host_lower.ends_with(".internal")
88        || host_lower.ends_with(".local");
89
90    if !is_private {
91        if let Ok(ip) = ip_host.parse::<std::net::IpAddr>() {
92            is_private = is_private_ip(ip);
93        } else if let Ok(addrs) = (ip_host, port).to_socket_addrs() {
94            is_private = addrs.into_iter().any(|addr| is_private_ip(addr.ip()));
95        }
96    }
97
98    if is_private {
99        if warn_only {
100            tracing::warn!(url, "SSRF protection — URL targets private address");
101            return Ok(());
102        }
103        return Err(HttpError::SsrfBlocked(url.to_string()));
104    }
105
106    Ok(())
107}
108
/// Header names (lowercase) that agent-supplied parameters may never set.
///
/// `validate_headers` compares each lowercased parameter name against this list.
const DENIED_HEADERS: &[&str] = &[
    // Credentials and session state.
    "authorization",
    "proxy-authorization",
    "cookie",
    "set-cookie",
    // Target and body framing, which the HTTP client sets itself.
    "host",
    "content-type",
    "content-length",
    "transfer-encoding",
    // Connection-specific (hop-by-hop) fields, cf. RFC 9110 §7.6.1.
    "connection",
    "keep-alive",
    "proxy-connection",
    "te",
    "upgrade",
    // Client-identity / proxy-forwarding fields an upstream may trust. `forwarded`
    // is the standardized (RFC 7239) form of the `x-forwarded-*` family.
    "forwarded",
    "x-forwarded-for",
    "x-forwarded-host",
    "x-forwarded-port",
    "x-forwarded-proto",
    "x-real-ip",
];
126
127/// Check that classified header parameters don't contain denied headers.
128pub fn validate_headers(
129    headers: &HashMap<String, String>,
130    provider_auth_header: Option<&str>,
131) -> Result<(), HttpError> {
132    for key in headers.keys() {
133        let lower = key.to_lowercase();
134        if DENIED_HEADERS.contains(&lower.as_str()) {
135            return Err(HttpError::DeniedHeader(key.clone()));
136        }
137        if let Some(auth_header) = provider_auth_header {
138            if lower == auth_header.to_lowercase() {
139                return Err(HttpError::DeniedHeader(key.clone()));
140            }
141        }
142    }
143    Ok(())
144}
145
146/// Merge manifest defaults into the args map for any params not provided by caller.
147fn merge_defaults(tool: &Tool, args: &HashMap<String, Value>) -> HashMap<String, Value> {
148    let mut merged = args.clone();
149    if let Some(schema) = &tool.input_schema {
150        if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
151            for (key, prop_def) in props {
152                if !merged.contains_key(key) {
153                    if let Some(default_val) = prop_def.get("default") {
154                        // Skip empty arrays/objects as defaults — they add no value
155                        // and some APIs reject them (e.g. ClinicalTrials `sort=[]`).
156                        let dominated = match default_val {
157                            Value::Array(a) => a.is_empty(),
158                            Value::Object(o) => o.is_empty(),
159                            Value::Null => true,
160                            _ => false,
161                        };
162                        if !dominated {
163                            merged.insert(key.clone(), default_val.clone());
164                        }
165                    }
166                }
167            }
168        }
169    }
170    merged
171}
172
/// How an array-valued query parameter is serialized.
///
/// Selected per property by `x-ati-collection-format`, and mirrors the OpenAPI 2.0
/// `collectionFormat` values (minus `tsv`). The examples below show the logical
/// value. `reqwest`'s `RequestBuilder::query` form-urlencodes it on the wire, so a
/// space travels as `+` and `,`/`|` as `%2C`/`%7C`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CollectionFormat {
    /// Key repeated once per element: `status=a&status=b`. Also the fallback when
    /// no (or an unrecognized) format is declared.
    Multi,
    /// Single value joined with commas: `status=a,b`.
    Csv,
    /// Single value joined with spaces: `status=a b`.
    Ssv,
    /// Single value joined with pipes: `status=a|b`.
    Pipes,
}
185
/// How POST/PUT body parameters are encoded.
///
/// Selected by the schema-level `x-ati-body-encoding` key: `"form"` picks
/// [`BodyEncoding::Form`], anything else (or absence) picks [`BodyEncoding::Json`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BodyEncoding {
    /// JSON body (`RequestBuilder::json`) built from the raw parameter values.
    Json,
    /// `application/x-www-form-urlencoded` body (`RequestBuilder::form`); each
    /// value is stringified first.
    Form,
}
192
193/// Classified parameter maps, split by location.
194struct ClassifiedParams {
195    path: HashMap<String, String>,
196    query: HashMap<String, String>,
197    query_arrays: HashMap<String, (Vec<String>, CollectionFormat)>,
198    header: HashMap<String, String>,
199    body: HashMap<String, Value>,
200    body_encoding: BodyEncoding,
201}
202
203/// Classify parameters by their `x-ati-param-location` metadata in the input schema.
204/// If no location metadata exists (legacy TOML tools), returns None for legacy fallback.
205fn classify_params(tool: &Tool, args: &HashMap<String, Value>) -> Option<ClassifiedParams> {
206    let schema = tool.input_schema.as_ref()?;
207    let props = schema.get("properties")?.as_object()?;
208
209    // Check if any property has x-ati-param-location — if none do, this is a legacy tool
210    let has_locations = props
211        .values()
212        .any(|p| p.get("x-ati-param-location").is_some());
213
214    if !has_locations {
215        return None;
216    }
217
218    // Detect body encoding from schema-level metadata
219    let body_encoding = match schema.get("x-ati-body-encoding").and_then(|v| v.as_str()) {
220        Some("form") => BodyEncoding::Form,
221        _ => BodyEncoding::Json,
222    };
223
224    let mut classified = ClassifiedParams {
225        path: HashMap::new(),
226        query: HashMap::new(),
227        query_arrays: HashMap::new(),
228        header: HashMap::new(),
229        body: HashMap::new(),
230        body_encoding,
231    };
232
233    for (key, value) in args {
234        let prop_def = props.get(key);
235        let location = prop_def
236            .and_then(|p| p.get("x-ati-param-location"))
237            .and_then(|l| l.as_str())
238            .unwrap_or("body"); // default to body if no location specified
239
240        match location {
241            "path" => {
242                classified.path.insert(key.clone(), value_to_string(value));
243            }
244            "query" => {
245                // Check if this is an array value with a collection format
246                if let Value::Array(arr) = value {
247                    let cf_str = prop_def
248                        .and_then(|p| p.get("x-ati-collection-format"))
249                        .and_then(|v| v.as_str());
250                    let cf = match cf_str {
251                        Some("multi") => CollectionFormat::Multi,
252                        Some("csv") => CollectionFormat::Csv,
253                        Some("ssv") => CollectionFormat::Ssv,
254                        Some("pipes") => CollectionFormat::Pipes,
255                        _ => CollectionFormat::Multi, // default for arrays
256                    };
257                    let values: Vec<String> = arr.iter().map(value_to_string).collect();
258                    classified.query_arrays.insert(key.clone(), (values, cf));
259                } else {
260                    classified.query.insert(key.clone(), value_to_string(value));
261                }
262            }
263            "header" => {
264                classified
265                    .header
266                    .insert(key.clone(), value_to_string(value));
267            }
268            _ => {
269                classified.body.insert(key.clone(), value.clone());
270            }
271        }
272    }
273
274    Some(classified)
275}
276
277/// Substitute path parameters like `{petId}` in the endpoint template.
278/// Rejects values containing path traversal or URL-breaking characters,
279/// then percent-encodes the value before substitution.
280fn substitute_path_params(
281    endpoint: &str,
282    path_args: &HashMap<String, String>,
283) -> Result<String, HttpError> {
284    let mut result = endpoint.to_string();
285    for (key, value) in path_args {
286        if value.contains("..")
287            || value.contains('\\')
288            || value.contains('?')
289            || value.contains('#')
290            || value.contains('\0')
291        {
292            return Err(HttpError::InvalidPathParam {
293                key: key.clone(),
294                value: value.clone(),
295            });
296        }
297        let encoded = percent_encode_path_segment(value);
298        result = result.replace(&format!("{{{key}}}"), &encoded);
299    }
300    Ok(result)
301}
302
/// Percent-encode `s` for use as a single URL path segment.
///
/// Only RFC 3986 §2.3 unreserved bytes (`A-Z a-z 0-9 - . _ ~`) pass through. Every
/// other byte of the UTF-8 encoding, including `/`, becomes `%XX` with uppercase hex.
pub(crate) fn percent_encode_path_segment(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    s.bytes().fold(String::with_capacity(s.len()), |mut out, byte| {
        let unreserved = byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
        out
    })
}
319
/// Return `true` if `ip` lies in a range that must not be reachable via SSRF.
///
/// - IPv4: loopback, RFC 1918 private, link-local (`169.254.0.0/16`, which holds
///   cloud metadata endpoints), `0.0.0.0/8` ("this network"; on Linux `0.0.0.0`
///   reaches the local host), and CGNAT shared space `100.64.0.0/10`.
/// - IPv6: loopback, unspecified, unique-local (`fc00::/7`), link-local unicast
///   (`fe80::/10`), and IPv4-mapped addresses (`::ffff:a.b.c.d`). Mapped addresses
///   are judged by their embedded IPv4 address; otherwise a literal like
///   `[::ffff:127.0.0.1]` would bypass every IPv4 check.
fn is_private_ip(ip: std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(v4) => {
            let [first, second, ..] = v4.octets();
            v4.is_loopback()
                || v4.is_private()
                || v4.is_link_local()
                || v4.is_unspecified()
                || first == 0
                || (first == 100 && (64..=127).contains(&second))
        }
        std::net::IpAddr::V6(v6) => {
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_private_ip(std::net::IpAddr::V4(mapped));
            }
            v6.is_loopback()
                || v6.is_unspecified()
                || v6.is_unique_local()
                || v6.is_unicast_link_local()
        }
    }
}
337
338/// Convert a serde_json::Value to a URL-safe string.
339fn value_to_string(v: &Value) -> String {
340    match v {
341        Value::String(s) => s.clone(),
342        Value::Number(n) => n.to_string(),
343        Value::Bool(b) => b.to_string(),
344        Value::Null => String::new(),
345        other => other.to_string(),
346    }
347}
348
349/// Apply array query parameters to a request builder using the specified collection format.
350fn apply_query_arrays(
351    mut req: reqwest::RequestBuilder,
352    arrays: &HashMap<String, (Vec<String>, CollectionFormat)>,
353) -> reqwest::RequestBuilder {
354    for (key, (values, format)) in arrays {
355        match format {
356            CollectionFormat::Multi => {
357                // Repeated key: ?status=a&status=b
358                for val in values {
359                    req = req.query(&[(key.as_str(), val.as_str())]);
360                }
361            }
362            CollectionFormat::Csv => {
363                let joined = values.join(",");
364                req = req.query(&[(key.as_str(), joined.as_str())]);
365            }
366            CollectionFormat::Ssv => {
367                let joined = values.join(" ");
368                req = req.query(&[(key.as_str(), joined.as_str())]);
369            }
370            CollectionFormat::Pipes => {
371                let joined = values.join("|");
372                req = req.query(&[(key.as_str(), joined.as_str())]);
373            }
374        }
375    }
376    req
377}
378
379/// Execute an HTTP tool call against a provider's API.
380///
381/// Supports two modes:
382/// 1. **Location-aware** (OpenAPI tools): Parameters are classified by `x-ati-param-location`
383///    metadata in the input schema. Path params are substituted into the URL template,
384///    query params go to the query string, header params become request headers,
385///    and body params go to the JSON body.
386/// 2. **Legacy** (hand-written TOML tools): GET → all args as query params, POST/PUT/DELETE → JSON body.
387pub async fn execute_tool(
388    provider: &Provider,
389    tool: &Tool,
390    args: &HashMap<String, Value>,
391    keyring: &Keyring,
392) -> Result<Value, HttpError> {
393    execute_tool_with_gen(provider, tool, args, keyring, None, None).await
394}
395
396/// Execute an HTTP tool call, optionally using a dynamic auth generator.
397pub async fn execute_tool_with_gen(
398    provider: &Provider,
399    tool: &Tool,
400    args: &HashMap<String, Value>,
401    keyring: &Keyring,
402    gen_ctx: Option<&GenContext>,
403    auth_cache: Option<&AuthCache>,
404) -> Result<Value, HttpError> {
405    // SSRF protection: validate base_url is not targeting private networks
406    validate_url_not_private(&provider.base_url)?;
407
408    let client = Client::builder()
409        .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
410        .build()?;
411
412    // Merge manifest defaults into caller-provided args
413    let merged_args = merge_defaults(tool, args);
414
415    // Try location-aware classification (OpenAPI tools have x-ati-param-location)
416    let mut request = if let Some(classified) = classify_params(tool, &merged_args) {
417        // Validate headers against deny-list before injecting
418        validate_headers(&classified.header, provider.auth_header_name.as_deref())?;
419
420        // Location-aware mode: substitute path params, route by location
421        let resolved_endpoint = substitute_path_params(&tool.endpoint, &classified.path)?;
422        let url = format!(
423            "{}{}",
424            provider.base_url.trim_end_matches('/'),
425            resolved_endpoint
426        );
427
428        let mut req = match tool.method {
429            HttpMethod::Get | HttpMethod::Delete => {
430                let base_req = match tool.method {
431                    HttpMethod::Get => client.get(&url),
432                    HttpMethod::Delete => client.delete(&url),
433                    _ => unreachable!(),
434                };
435                // Query params to query string
436                let mut r = base_req;
437                for (k, v) in &classified.query {
438                    r = r.query(&[(k.as_str(), v.as_str())]);
439                }
440                r = apply_query_arrays(r, &classified.query_arrays);
441                r
442            }
443            HttpMethod::Post | HttpMethod::Put => {
444                let base_req = match tool.method {
445                    HttpMethod::Post => client.post(&url),
446                    HttpMethod::Put => client.put(&url),
447                    _ => unreachable!(),
448                };
449                // Body params: encode as JSON or form-urlencoded based on metadata
450                let mut r = if classified.body.is_empty() {
451                    base_req
452                } else {
453                    match classified.body_encoding {
454                        BodyEncoding::Json => base_req.json(&classified.body),
455                        BodyEncoding::Form => {
456                            let pairs: Vec<(String, String)> = classified
457                                .body
458                                .iter()
459                                .map(|(k, v)| (k.clone(), value_to_string(v)))
460                                .collect();
461                            base_req.form(&pairs)
462                        }
463                    }
464                };
465                // Query params still go to query string
466                for (k, v) in &classified.query {
467                    r = r.query(&[(k.as_str(), v.as_str())]);
468                }
469                r = apply_query_arrays(r, &classified.query_arrays);
470                r
471            }
472        };
473
474        // Inject classified header params
475        for (k, v) in &classified.header {
476            req = req.header(k.as_str(), v.as_str());
477        }
478
479        req
480    } else {
481        // Legacy mode: no x-ati-param-location metadata
482        let url = format!(
483            "{}{}",
484            provider.base_url.trim_end_matches('/'),
485            &tool.endpoint
486        );
487
488        match tool.method {
489            HttpMethod::Get => {
490                let mut req = client.get(&url);
491                for (k, v) in &merged_args {
492                    req = req.query(&[(k.as_str(), value_to_string(v))]);
493                }
494                req
495            }
496            HttpMethod::Post => client.post(&url).json(&merged_args),
497            HttpMethod::Put => client.put(&url).json(&merged_args),
498            HttpMethod::Delete => client.delete(&url).json(&merged_args),
499        }
500    };
501
502    // Inject authentication (generator takes priority over static keyring)
503    request = inject_auth(request, provider, keyring, gen_ctx, auth_cache).await?;
504
505    // Inject extra headers from provider config
506    for (header_name, header_value) in &provider.extra_headers {
507        request = request.header(header_name.as_str(), header_value.as_str());
508    }
509
510    // Execute request
511    let response = request.send().await?;
512    let status = response.status();
513
514    if !status.is_success() {
515        let body = response.text().await.unwrap_or_else(|_| "empty".into());
516        let status_u16 = status.as_u16();
517        let (error_type, error_message) = crate::core::sentry_scope::parse_upstream_error(&body);
518        if status_u16 == 404
519            && crate::core::sentry_scope::is_no_records_body(
520                error_type.as_deref(),
521                error_message.as_deref(),
522            )
523        {
524            return Err(HttpError::NoRecordsFound { status: status_u16 });
525        }
526        return Err(HttpError::ApiError {
527            status: status_u16,
528            body,
529            error_type,
530            error_message,
531        });
532    }
533
534    // Parse response
535    let text = response.text().await?;
536    let value: Value = serde_json::from_str(&text).unwrap_or(Value::String(text));
537
538    Ok(value)
539}
540
541/// Inject authentication headers/params based on provider auth_type.
542///
543/// If the provider has an `auth_generator`, the generator is run first to produce
544/// dynamic credentials. Otherwise, static keyring credentials are used.
545async fn inject_auth(
546    request: reqwest::RequestBuilder,
547    provider: &Provider,
548    keyring: &Keyring,
549    gen_ctx: Option<&GenContext>,
550    auth_cache: Option<&AuthCache>,
551) -> Result<reqwest::RequestBuilder, HttpError> {
552    // Dynamic auth generator takes priority
553    if let Some(gen) = &provider.auth_generator {
554        let default_ctx = GenContext::default();
555        let ctx = gen_ctx.unwrap_or(&default_ctx);
556        let default_cache = AuthCache::new();
557        let cache = auth_cache.unwrap_or(&default_cache);
558
559        let cred = auth_generator::generate(provider, gen, ctx, keyring, cache)
560            .await
561            .map_err(|e| HttpError::MissingKey(format!("auth_generator: {e}")))?;
562
563        // Inject primary credential based on auth_type
564        let mut req = match provider.auth_type {
565            AuthType::Bearer => request.bearer_auth(&cred.value),
566            AuthType::Header => {
567                let name = provider.auth_header_name.as_deref().unwrap_or("X-Api-Key");
568                let val = match &provider.auth_value_prefix {
569                    Some(pfx) => format!("{pfx}{}", cred.value),
570                    None => cred.value.clone(),
571                };
572                request.header(name, val)
573            }
574            AuthType::Query => {
575                let name = provider.auth_query_name.as_deref().unwrap_or("api_key");
576                request.query(&[(name, &cred.value)])
577            }
578            _ => request,
579        };
580        // Inject extra headers from JSON inject targets
581        for (name, value) in &cred.extra_headers {
582            req = req.header(name.as_str(), value.as_str());
583        }
584        return Ok(req);
585    }
586
587    match provider.auth_type {
588        AuthType::None => Ok(request),
589        AuthType::Bearer => {
590            let key_name = provider
591                .auth_key_name
592                .as_deref()
593                .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
594            let key_value = keyring
595                .get(key_name)
596                .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
597            Ok(request.bearer_auth(key_value))
598        }
599        AuthType::Header => {
600            let key_name = provider
601                .auth_key_name
602                .as_deref()
603                .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
604            let key_value = keyring
605                .get(key_name)
606                .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
607            let header_name = provider.auth_header_name.as_deref().unwrap_or("X-Api-Key");
608            let final_value = match &provider.auth_value_prefix {
609                Some(prefix) => format!("{}{}", prefix, key_value),
610                None => key_value.to_string(),
611            };
612            Ok(request.header(header_name, final_value))
613        }
614        AuthType::Query => {
615            let key_name = provider
616                .auth_key_name
617                .as_deref()
618                .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
619            let key_value = keyring
620                .get(key_name)
621                .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
622            let query_name = provider.auth_query_name.as_deref().unwrap_or("api_key");
623            Ok(request.query(&[(query_name, key_value)]))
624        }
625        AuthType::Basic => {
626            let key_name = provider
627                .auth_key_name
628                .as_deref()
629                .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
630            let key_value = keyring
631                .get(key_name)
632                .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
633            Ok(request.basic_auth(key_value, None::<&str>))
634        }
635        AuthType::Oauth2 => {
636            let access_token = get_oauth2_token(provider, keyring).await?;
637            Ok(request.bearer_auth(access_token))
638        }
639        AuthType::Url => {
640            // Auth key is already interpolated into the URL via
641            // ${key_name} placeholders resolved at connection time.
642            // No header or query param injection needed.
643            Ok(request)
644        }
645    }
646}
647
648/// Fetch (or return cached) OAuth2 access token via client_credentials grant.
649async fn get_oauth2_token(provider: &Provider, keyring: &Keyring) -> Result<String, HttpError> {
650    let cache_key = provider.name.clone();
651
652    // Check cache
653    {
654        let cache = OAUTH2_CACHE.lock().unwrap();
655        if let Some((token, expiry)) = cache.get(&cache_key) {
656            // Use cached token if it has at least 60s remaining
657            if Instant::now() + Duration::from_secs(60) < *expiry {
658                return Ok(token.clone());
659            }
660        }
661    }
662
663    // Token expired or not cached — exchange credentials
664    let client_id_key = provider
665        .auth_key_name
666        .as_deref()
667        .ok_or_else(|| HttpError::Oauth2Error("auth_key_name not set for OAuth2".into()))?;
668    let client_id = keyring
669        .get(client_id_key)
670        .ok_or_else(|| HttpError::MissingKey(client_id_key.into()))?;
671
672    let client_secret_key = provider
673        .auth_secret_name
674        .as_deref()
675        .ok_or_else(|| HttpError::Oauth2Error("auth_secret_name not set for OAuth2".into()))?;
676    let client_secret = keyring
677        .get(client_secret_key)
678        .ok_or_else(|| HttpError::MissingKey(client_secret_key.into()))?;
679
680    let token_url = match &provider.oauth2_token_url {
681        Some(url) if url.starts_with("http") => url.clone(),
682        Some(path) => format!("{}{}", provider.base_url.trim_end_matches('/'), path),
683        None => return Err(HttpError::Oauth2Error("oauth2_token_url not set".into())),
684    };
685
686    // Enforce HTTPS for OAuth2 token URLs (credentials are sent in plaintext otherwise)
687    if token_url.starts_with("http://") {
688        return Err(HttpError::InsecureTokenUrl(token_url));
689    }
690
691    let client = Client::builder().timeout(Duration::from_secs(15)).build()?;
692
693    // Two OAuth2 client_credentials modes:
694    // 1. Form body: client_id + client_secret in form data (Amadeus)
695    // 2. Basic Auth: base64(client_id:client_secret) in Authorization header (Sovos)
696    let response = if provider.oauth2_basic_auth {
697        client
698            .post(&token_url)
699            .basic_auth(client_id, Some(client_secret))
700            .form(&[("grant_type", "client_credentials")])
701            .send()
702            .await?
703    } else {
704        client
705            .post(&token_url)
706            .form(&[
707                ("grant_type", "client_credentials"),
708                ("client_id", client_id),
709                ("client_secret", client_secret),
710            ])
711            .send()
712            .await?
713    };
714
715    if !response.status().is_success() {
716        let status = response.status().as_u16();
717        let body = response.text().await.unwrap_or_default();
718        return Err(HttpError::Oauth2Error(format!(
719            "token exchange failed ({status}): {body}"
720        )));
721    }
722
723    let body: Value = response
724        .json()
725        .await
726        .map_err(|e| HttpError::Oauth2Error(format!("failed to parse token response: {e}")))?;
727
728    let access_token = body
729        .get("access_token")
730        .and_then(|v| v.as_str())
731        .ok_or_else(|| HttpError::Oauth2Error("no access_token in response".into()))?
732        .to_string();
733
734    let expires_in = body
735        .get("expires_in")
736        .and_then(|v| v.as_u64())
737        .unwrap_or(1799);
738
739    let expiry = Instant::now() + Duration::from_secs(expires_in);
740
741    // Cache the token
742    {
743        let mut cache = OAUTH2_CACHE.lock().unwrap();
744        cache.insert(cache_key, (access_token.clone(), expiry));
745    }
746
747    Ok(access_token)
748}
749
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a path-argument map from `(name, value)` pairs.
    fn path_args(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_owned(), value.to_owned()))
            .collect()
    }

    /// Assert that substituting `value` for `{id}` in `/resource/{id}` is rejected.
    fn assert_rejected(value: &str) {
        let args = path_args(&[("id", value)]);
        assert!(
            substitute_path_params("/resource/{id}", &args).is_err(),
            "expected {value:?} to be rejected"
        );
    }

    #[test]
    fn test_substitute_path_params_normal() {
        let args = path_args(&[("petId", "123")]);
        assert_eq!(substitute_path_params("/pet/{petId}", &args).unwrap(), "/pet/123");
    }

    #[test]
    fn test_substitute_path_params_rejects_dotdot() {
        assert_rejected("../admin");
    }

    #[test]
    fn test_substitute_path_params_encodes_slash() {
        let args = path_args(&[("id", "fal-ai/flux/dev")]);
        assert_eq!(
            substitute_path_params("/resource/{id}", &args).unwrap(),
            "/resource/fal-ai%2Fflux%2Fdev"
        );
    }

    #[test]
    fn test_substitute_path_params_rejects_backslash() {
        assert_rejected("foo\\bar");
    }

    #[test]
    fn test_substitute_path_params_rejects_question() {
        assert_rejected("foo?bar=1");
    }

    #[test]
    fn test_substitute_path_params_rejects_hash() {
        assert_rejected("foo#bar");
    }

    #[test]
    fn test_substitute_path_params_rejects_null_byte() {
        assert_rejected("foo\0bar");
    }

    #[test]
    fn test_substitute_path_params_encodes_special() {
        let args = path_args(&[("name", "hello world")]);
        assert_eq!(
            substitute_path_params("/users/{name}", &args).unwrap(),
            "/users/hello%20world"
        );
    }

    #[test]
    fn test_substitute_path_params_preserves_unreserved() {
        let args = path_args(&[("id", "abc-123_test.v2~draft")]);
        assert_eq!(
            substitute_path_params("/items/{id}", &args).unwrap(),
            "/items/abc-123_test.v2~draft"
        );
    }

    #[test]
    fn test_substitute_path_params_encodes_at_sign() {
        let args = path_args(&[("user", "user@domain")]);
        assert_eq!(
            substitute_path_params("/profile/{user}", &args).unwrap(),
            "/profile/user%40domain"
        );
    }

    #[test]
    fn test_percent_encode_path_segment_empty() {
        assert_eq!(percent_encode_path_segment(""), "");
    }

    #[test]
    fn test_percent_encode_path_segment_ascii_only() {
        assert_eq!(percent_encode_path_segment("abc123"), "abc123");
    }

    #[test]
    fn test_substitute_path_params_multiple() {
        let args = path_args(&[("owner", "acme"), ("repo", "widgets")]);
        assert_eq!(
            substitute_path_params("/repos/{owner}/{repo}/issues", &args).unwrap(),
            "/repos/acme/widgets/issues"
        );
    }

    #[test]
    fn test_substitute_path_params_no_placeholders() {
        let args = path_args(&[]);
        assert_eq!(substitute_path_params("/health", &args).unwrap(), "/health");
    }
}