1use reqwest::Client;
2use serde_json::Value;
3use std::collections::HashMap;
4use std::sync::Mutex;
5use std::time::{Duration, Instant};
6use thiserror::Error;
7
8use crate::core::auth_generator::{self, AuthCache, GenContext};
9use crate::core::keyring::Keyring;
10use crate::core::manifest::{AuthType, HttpMethod, Provider, Tool};
11
12#[derive(Error, Debug)]
13pub enum HttpError {
14 #[error("API key '{0}' not found in keyring")]
15 MissingKey(String),
16 #[error("HTTP request failed: {0}")]
17 Request(#[from] reqwest::Error),
18 #[error("API error ({status}): {body}")]
19 ApiError { status: u16, body: String },
20 #[error("Failed to parse response as JSON: {0}")]
21 ParseError(String),
22 #[error("OAuth2 token exchange failed: {0}")]
23 Oauth2Error(String),
24 #[error("Invalid path parameter '{key}': value '{value}' contains forbidden characters")]
25 InvalidPathParam { key: String, value: String },
26 #[error("Header '{0}' is not allowed as a user-supplied parameter")]
27 DeniedHeader(String),
28 #[error("SSRF protection: URL '{0}' targets a private/internal network address")]
29 SsrfBlocked(String),
30 #[error("OAuth2 token URL must use HTTPS: '{0}'")]
31 InsecureTokenUrl(String),
32}
33
/// Process-wide OAuth2 access-token cache: provider name -> (token, expiry instant).
static OAUTH2_CACHE: std::sync::LazyLock<Mutex<HashMap<String, (String, Instant)>>> =
    std::sync::LazyLock::new(Default::default);

/// Whole-request timeout, in seconds, applied to tool API calls.
const DEFAULT_TIMEOUT_SECS: u64 = 60;
39
40pub fn validate_url_not_private(url: &str) -> Result<(), HttpError> {
48 let mode = std::env::var("ATI_SSRF_PROTECTION").unwrap_or_default();
49 let enforce = mode == "1" || mode.eq_ignore_ascii_case("true");
50 let warn_only = mode.eq_ignore_ascii_case("warn");
51
52 if !enforce && !warn_only {
53 return Ok(());
54 }
55 let host = url
57 .strip_prefix("http://")
58 .or_else(|| url.strip_prefix("https://"))
59 .unwrap_or(url)
60 .split('/')
61 .next()
62 .unwrap_or("")
63 .split(':')
64 .next()
65 .unwrap_or("");
66
67 if host.is_empty() {
68 return Ok(());
69 }
70
71 let mut is_private = false;
72
73 if let Ok(ip) = host.parse::<std::net::Ipv4Addr>() {
75 if ip.is_loopback() || ip.is_private() || ip.is_link_local() || ip.is_unspecified() || (ip.octets()[0] == 100 && ip.octets()[1] >= 64 && ip.octets()[1] <= 127)
80 {
82 is_private = true;
83 }
84 }
85
86 let host_lower = host.to_lowercase();
88 if host_lower == "localhost"
89 || host_lower == "metadata.google.internal"
90 || host_lower.ends_with(".internal")
91 || host_lower.ends_with(".local")
92 {
93 is_private = true;
94 }
95
96 if is_private {
97 if warn_only {
98 eprintln!("Warning: SSRF protection — URL targets private address: {url}");
99 return Ok(());
100 }
101 return Err(HttpError::SsrfBlocked(url.to_string()));
102 }
103
104 Ok(())
105}
106
107const DENIED_HEADERS: &[&str] = &[
110 "authorization",
111 "host",
112 "cookie",
113 "set-cookie",
114 "content-type",
115 "content-length",
116 "transfer-encoding",
117 "connection",
118 "proxy-authorization",
119 "x-forwarded-for",
120 "x-forwarded-host",
121 "x-forwarded-proto",
122 "x-real-ip",
123];
124
125pub fn validate_headers(
127 headers: &HashMap<String, String>,
128 provider_auth_header: Option<&str>,
129) -> Result<(), HttpError> {
130 for key in headers.keys() {
131 let lower = key.to_lowercase();
132 if DENIED_HEADERS.contains(&lower.as_str()) {
133 return Err(HttpError::DeniedHeader(key.clone()));
134 }
135 if let Some(auth_header) = provider_auth_header {
136 if lower == auth_header.to_lowercase() {
137 return Err(HttpError::DeniedHeader(key.clone()));
138 }
139 }
140 }
141 Ok(())
142}
143
144fn merge_defaults(tool: &Tool, args: &HashMap<String, Value>) -> HashMap<String, Value> {
146 let mut merged = args.clone();
147 if let Some(schema) = &tool.input_schema {
148 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
149 for (key, prop_def) in props {
150 if !merged.contains_key(key) {
151 if let Some(default_val) = prop_def.get("default") {
152 let dominated = match default_val {
155 Value::Array(a) => a.is_empty(),
156 Value::Object(o) => o.is_empty(),
157 Value::Null => true,
158 _ => false,
159 };
160 if !dominated {
161 merged.insert(key.clone(), default_val.clone());
162 }
163 }
164 }
165 }
166 }
167 }
168 merged
169}
170
/// How an array-valued query parameter is serialized. The variant names follow the
/// Swagger/OpenAPI 2 `collectionFormat` values.
#[derive(Clone, Copy, Debug, PartialEq)]
enum CollectionFormat {
    /// One `key=value` pair per element (`key=a&key=b`).
    Multi,
    /// Single pair, elements joined with `,`.
    Csv,
    /// Single pair, elements joined with a space.
    Ssv,
    /// Single pair, elements joined with `|`.
    Pipes,
}
183
/// Encoding used for body parameters of POST/PUT requests.
#[derive(Clone, Copy, Debug, PartialEq)]
enum BodyEncoding {
    /// JSON object body (the default).
    Json,
    /// `application/x-www-form-urlencoded` pairs, selected by `x-ati-body-encoding: "form"`.
    Form,
}
190
191struct ClassifiedParams {
193 path: HashMap<String, String>,
194 query: HashMap<String, String>,
195 query_arrays: HashMap<String, (Vec<String>, CollectionFormat)>,
196 header: HashMap<String, String>,
197 body: HashMap<String, Value>,
198 body_encoding: BodyEncoding,
199}
200
201fn classify_params(tool: &Tool, args: &HashMap<String, Value>) -> Option<ClassifiedParams> {
204 let schema = tool.input_schema.as_ref()?;
205 let props = schema.get("properties")?.as_object()?;
206
207 let has_locations = props
209 .values()
210 .any(|p| p.get("x-ati-param-location").is_some());
211
212 if !has_locations {
213 return None;
214 }
215
216 let body_encoding = match schema.get("x-ati-body-encoding").and_then(|v| v.as_str()) {
218 Some("form") => BodyEncoding::Form,
219 _ => BodyEncoding::Json,
220 };
221
222 let mut classified = ClassifiedParams {
223 path: HashMap::new(),
224 query: HashMap::new(),
225 query_arrays: HashMap::new(),
226 header: HashMap::new(),
227 body: HashMap::new(),
228 body_encoding,
229 };
230
231 for (key, value) in args {
232 let prop_def = props.get(key);
233 let location = prop_def
234 .and_then(|p| p.get("x-ati-param-location"))
235 .and_then(|l| l.as_str())
236 .unwrap_or("body"); match location {
239 "path" => {
240 classified.path.insert(key.clone(), value_to_string(value));
241 }
242 "query" => {
243 if let Value::Array(arr) = value {
245 let cf_str = prop_def
246 .and_then(|p| p.get("x-ati-collection-format"))
247 .and_then(|v| v.as_str());
248 let cf = match cf_str {
249 Some("multi") => CollectionFormat::Multi,
250 Some("csv") => CollectionFormat::Csv,
251 Some("ssv") => CollectionFormat::Ssv,
252 Some("pipes") => CollectionFormat::Pipes,
253 _ => CollectionFormat::Multi, };
255 let values: Vec<String> = arr.iter().map(|v| value_to_string(v)).collect();
256 classified.query_arrays.insert(key.clone(), (values, cf));
257 } else {
258 classified.query.insert(key.clone(), value_to_string(value));
259 }
260 }
261 "header" => {
262 classified
263 .header
264 .insert(key.clone(), value_to_string(value));
265 }
266 _ => {
267 classified.body.insert(key.clone(), value.clone());
268 }
269 }
270 }
271
272 Some(classified)
273}
274
275fn substitute_path_params(
279 endpoint: &str,
280 path_args: &HashMap<String, String>,
281) -> Result<String, HttpError> {
282 let mut result = endpoint.to_string();
283 for (key, value) in path_args {
284 if value.contains("..")
285 || value.contains('\\')
286 || value.contains('?')
287 || value.contains('#')
288 || value.contains('\0')
289 {
290 return Err(HttpError::InvalidPathParam {
291 key: key.clone(),
292 value: value.clone(),
293 });
294 }
295 let encoded = percent_encode_path_segment(value);
296 result = result.replace(&format!("{{{key}}}"), &encoded);
297 }
298 Ok(result)
299}
300
/// Percent-encodes `s` for use inside a URL path.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) and `/` pass through
/// unchanged. Every other byte of the UTF-8 encoding becomes `%XX` with uppercase hex.
fn percent_encode_path_segment(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    s.bytes()
        .fold(String::with_capacity(s.len()), |mut out, byte| {
            let passthrough =
                byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~' | b'/');
            if passthrough {
                out.push(char::from(byte));
            } else {
                out.push('%');
                out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
            out
        })
}
317
318fn value_to_string(v: &Value) -> String {
320 match v {
321 Value::String(s) => s.clone(),
322 Value::Number(n) => n.to_string(),
323 Value::Bool(b) => b.to_string(),
324 Value::Null => String::new(),
325 other => other.to_string(),
326 }
327}
328
329fn apply_query_arrays(
331 mut req: reqwest::RequestBuilder,
332 arrays: &HashMap<String, (Vec<String>, CollectionFormat)>,
333) -> reqwest::RequestBuilder {
334 for (key, (values, format)) in arrays {
335 match format {
336 CollectionFormat::Multi => {
337 for val in values {
339 req = req.query(&[(key.as_str(), val.as_str())]);
340 }
341 }
342 CollectionFormat::Csv => {
343 let joined = values.join(",");
344 req = req.query(&[(key.as_str(), joined.as_str())]);
345 }
346 CollectionFormat::Ssv => {
347 let joined = values.join(" ");
348 req = req.query(&[(key.as_str(), joined.as_str())]);
349 }
350 CollectionFormat::Pipes => {
351 let joined = values.join("|");
352 req = req.query(&[(key.as_str(), joined.as_str())]);
353 }
354 }
355 }
356 req
357}
358
359pub async fn execute_tool(
368 provider: &Provider,
369 tool: &Tool,
370 args: &HashMap<String, Value>,
371 keyring: &Keyring,
372) -> Result<Value, HttpError> {
373 execute_tool_with_gen(provider, tool, args, keyring, None, None).await
374}
375
376pub async fn execute_tool_with_gen(
378 provider: &Provider,
379 tool: &Tool,
380 args: &HashMap<String, Value>,
381 keyring: &Keyring,
382 gen_ctx: Option<&GenContext>,
383 auth_cache: Option<&AuthCache>,
384) -> Result<Value, HttpError> {
385 validate_url_not_private(&provider.base_url)?;
387
388 let client = Client::builder()
389 .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
390 .build()?;
391
392 let merged_args = merge_defaults(tool, args);
394
395 let mut request = if let Some(classified) = classify_params(tool, &merged_args) {
397 validate_headers(&classified.header, provider.auth_header_name.as_deref())?;
399
400 let resolved_endpoint = substitute_path_params(&tool.endpoint, &classified.path)?;
402 let url = format!(
403 "{}{}",
404 provider.base_url.trim_end_matches('/'),
405 resolved_endpoint
406 );
407
408 let mut req = match tool.method {
409 HttpMethod::Get | HttpMethod::Delete => {
410 let base_req = match tool.method {
411 HttpMethod::Get => client.get(&url),
412 HttpMethod::Delete => client.delete(&url),
413 _ => unreachable!(),
414 };
415 let mut r = base_req;
417 for (k, v) in &classified.query {
418 r = r.query(&[(k.as_str(), v.as_str())]);
419 }
420 r = apply_query_arrays(r, &classified.query_arrays);
421 r
422 }
423 HttpMethod::Post | HttpMethod::Put => {
424 let base_req = match tool.method {
425 HttpMethod::Post => client.post(&url),
426 HttpMethod::Put => client.put(&url),
427 _ => unreachable!(),
428 };
429 let mut r = if classified.body.is_empty() {
431 base_req
432 } else {
433 match classified.body_encoding {
434 BodyEncoding::Json => base_req.json(&classified.body),
435 BodyEncoding::Form => {
436 let pairs: Vec<(String, String)> = classified
437 .body
438 .iter()
439 .map(|(k, v)| (k.clone(), value_to_string(v)))
440 .collect();
441 base_req.form(&pairs)
442 }
443 }
444 };
445 for (k, v) in &classified.query {
447 r = r.query(&[(k.as_str(), v.as_str())]);
448 }
449 r = apply_query_arrays(r, &classified.query_arrays);
450 r
451 }
452 };
453
454 for (k, v) in &classified.header {
456 req = req.header(k.as_str(), v.as_str());
457 }
458
459 req
460 } else {
461 let url = format!(
463 "{}{}",
464 provider.base_url.trim_end_matches('/'),
465 &tool.endpoint
466 );
467
468 match tool.method {
469 HttpMethod::Get => {
470 let mut req = client.get(&url);
471 for (k, v) in &merged_args {
472 req = req.query(&[(k.as_str(), value_to_string(v))]);
473 }
474 req
475 }
476 HttpMethod::Post => client.post(&url).json(&merged_args),
477 HttpMethod::Put => client.put(&url).json(&merged_args),
478 HttpMethod::Delete => client.delete(&url).json(&merged_args),
479 }
480 };
481
482 request = inject_auth(request, provider, keyring, gen_ctx, auth_cache).await?;
484
485 for (header_name, header_value) in &provider.extra_headers {
487 request = request.header(header_name.as_str(), header_value.as_str());
488 }
489
490 let response = request.send().await?;
492 let status = response.status();
493
494 if !status.is_success() {
495 let body = response.text().await.unwrap_or_else(|_| "empty".into());
496 return Err(HttpError::ApiError {
497 status: status.as_u16(),
498 body,
499 });
500 }
501
502 let text = response.text().await?;
504 let value: Value = serde_json::from_str(&text).unwrap_or_else(|_| Value::String(text));
505
506 Ok(value)
507}
508
509async fn inject_auth(
514 request: reqwest::RequestBuilder,
515 provider: &Provider,
516 keyring: &Keyring,
517 gen_ctx: Option<&GenContext>,
518 auth_cache: Option<&AuthCache>,
519) -> Result<reqwest::RequestBuilder, HttpError> {
520 if let Some(gen) = &provider.auth_generator {
522 let default_ctx = GenContext::default();
523 let ctx = gen_ctx.unwrap_or(&default_ctx);
524 let default_cache = AuthCache::new();
525 let cache = auth_cache.unwrap_or(&default_cache);
526
527 let cred = auth_generator::generate(provider, gen, ctx, keyring, cache)
528 .await
529 .map_err(|e| HttpError::MissingKey(format!("auth_generator: {e}")))?;
530
531 let mut req = match provider.auth_type {
533 AuthType::Bearer => request.bearer_auth(&cred.value),
534 AuthType::Header => {
535 let name = provider.auth_header_name.as_deref().unwrap_or("X-Api-Key");
536 let val = match &provider.auth_value_prefix {
537 Some(pfx) => format!("{pfx}{}", cred.value),
538 None => cred.value.clone(),
539 };
540 request.header(name, val)
541 }
542 AuthType::Query => {
543 let name = provider.auth_query_name.as_deref().unwrap_or("api_key");
544 request.query(&[(name, &cred.value)])
545 }
546 _ => request,
547 };
548 for (name, value) in &cred.extra_headers {
550 req = req.header(name.as_str(), value.as_str());
551 }
552 return Ok(req);
553 }
554
555 match provider.auth_type {
556 AuthType::None => Ok(request),
557 AuthType::Bearer => {
558 let key_name = provider
559 .auth_key_name
560 .as_deref()
561 .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
562 let key_value = keyring
563 .get(key_name)
564 .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
565 Ok(request.bearer_auth(key_value))
566 }
567 AuthType::Header => {
568 let key_name = provider
569 .auth_key_name
570 .as_deref()
571 .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
572 let key_value = keyring
573 .get(key_name)
574 .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
575 let header_name = provider.auth_header_name.as_deref().unwrap_or("X-Api-Key");
576 let final_value = match &provider.auth_value_prefix {
577 Some(prefix) => format!("{}{}", prefix, key_value),
578 None => key_value.to_string(),
579 };
580 Ok(request.header(header_name, final_value))
581 }
582 AuthType::Query => {
583 let key_name = provider
584 .auth_key_name
585 .as_deref()
586 .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
587 let key_value = keyring
588 .get(key_name)
589 .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
590 let query_name = provider.auth_query_name.as_deref().unwrap_or("api_key");
591 Ok(request.query(&[(query_name, key_value)]))
592 }
593 AuthType::Basic => {
594 let key_name = provider
595 .auth_key_name
596 .as_deref()
597 .ok_or_else(|| HttpError::MissingKey("auth_key_name not set".into()))?;
598 let key_value = keyring
599 .get(key_name)
600 .ok_or_else(|| HttpError::MissingKey(key_name.into()))?;
601 Ok(request.basic_auth(key_value, None::<&str>))
602 }
603 AuthType::Oauth2 => {
604 let access_token = get_oauth2_token(provider, keyring).await?;
605 Ok(request.bearer_auth(access_token))
606 }
607 }
608}
609
610async fn get_oauth2_token(provider: &Provider, keyring: &Keyring) -> Result<String, HttpError> {
612 let cache_key = provider.name.clone();
613
614 {
616 let cache = OAUTH2_CACHE.lock().unwrap();
617 if let Some((token, expiry)) = cache.get(&cache_key) {
618 if Instant::now() + Duration::from_secs(60) < *expiry {
620 return Ok(token.clone());
621 }
622 }
623 }
624
625 let client_id_key = provider
627 .auth_key_name
628 .as_deref()
629 .ok_or_else(|| HttpError::Oauth2Error("auth_key_name not set for OAuth2".into()))?;
630 let client_id = keyring
631 .get(client_id_key)
632 .ok_or_else(|| HttpError::MissingKey(client_id_key.into()))?;
633
634 let client_secret_key = provider
635 .auth_secret_name
636 .as_deref()
637 .ok_or_else(|| HttpError::Oauth2Error("auth_secret_name not set for OAuth2".into()))?;
638 let client_secret = keyring
639 .get(client_secret_key)
640 .ok_or_else(|| HttpError::MissingKey(client_secret_key.into()))?;
641
642 let token_url = match &provider.oauth2_token_url {
643 Some(url) if url.starts_with("http") => url.clone(),
644 Some(path) => format!("{}{}", provider.base_url.trim_end_matches('/'), path),
645 None => return Err(HttpError::Oauth2Error("oauth2_token_url not set".into())),
646 };
647
648 if token_url.starts_with("http://") {
650 return Err(HttpError::InsecureTokenUrl(token_url));
651 }
652
653 let client = Client::builder().timeout(Duration::from_secs(15)).build()?;
654
655 let response = if provider.oauth2_basic_auth {
659 client
660 .post(&token_url)
661 .basic_auth(client_id, Some(client_secret))
662 .form(&[("grant_type", "client_credentials")])
663 .send()
664 .await?
665 } else {
666 client
667 .post(&token_url)
668 .form(&[
669 ("grant_type", "client_credentials"),
670 ("client_id", client_id),
671 ("client_secret", client_secret),
672 ])
673 .send()
674 .await?
675 };
676
677 if !response.status().is_success() {
678 let status = response.status().as_u16();
679 let body = response.text().await.unwrap_or_default();
680 return Err(HttpError::Oauth2Error(format!(
681 "token exchange failed ({status}): {body}"
682 )));
683 }
684
685 let body: Value = response
686 .json()
687 .await
688 .map_err(|e| HttpError::Oauth2Error(format!("failed to parse token response: {e}")))?;
689
690 let access_token = body
691 .get("access_token")
692 .and_then(|v| v.as_str())
693 .ok_or_else(|| HttpError::Oauth2Error("no access_token in response".into()))?
694 .to_string();
695
696 let expires_in = body
697 .get("expires_in")
698 .and_then(|v| v.as_u64())
699 .unwrap_or(1799);
700
701 let expiry = Instant::now() + Duration::from_secs(expires_in);
702
703 {
705 let mut cache = OAUTH2_CACHE.lock().unwrap();
706 cache.insert(cache_key, (access_token.clone(), expiry));
707 }
708
709 Ok(access_token)
710}
711
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a path-argument map from `(name, value)` pairs.
    fn path_args(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    /// Substitutes a single `(name, value)` argument into `endpoint`.
    fn substitute_one(endpoint: &str, name: &str, value: &str) -> Result<String, HttpError> {
        substitute_path_params(endpoint, &path_args(&[(name, value)]))
    }

    /// Asserts that `value` is rejected when substituted for `{id}`.
    fn assert_rejected(value: &str) {
        assert!(
            substitute_one("/resource/{id}", "id", value).is_err(),
            "expected {value:?} to be rejected"
        );
    }

    #[test]
    fn test_substitute_path_params_normal() {
        let url = substitute_one("/pet/{petId}", "petId", "123").unwrap();
        assert_eq!(url, "/pet/123");
    }

    #[test]
    fn test_substitute_path_params_rejects_dotdot() {
        assert_rejected("../admin");
    }

    #[test]
    fn test_substitute_path_params_allows_slash() {
        let url = substitute_one("/resource/{id}", "id", "fal-ai/flux/dev").unwrap();
        assert_eq!(url, "/resource/fal-ai/flux/dev");
    }

    #[test]
    fn test_substitute_path_params_rejects_backslash() {
        assert_rejected("foo\\bar");
    }

    #[test]
    fn test_substitute_path_params_rejects_question() {
        assert_rejected("foo?bar=1");
    }

    #[test]
    fn test_substitute_path_params_rejects_hash() {
        assert_rejected("foo#bar");
    }

    #[test]
    fn test_substitute_path_params_rejects_null_byte() {
        assert_rejected("foo\0bar");
    }

    #[test]
    fn test_substitute_path_params_encodes_special() {
        let url = substitute_one("/users/{name}", "name", "hello world").unwrap();
        assert_eq!(url, "/users/hello%20world");
    }

    #[test]
    fn test_substitute_path_params_preserves_unreserved() {
        let url = substitute_one("/items/{id}", "id", "abc-123_test.v2~draft").unwrap();
        assert_eq!(url, "/items/abc-123_test.v2~draft");
    }

    #[test]
    fn test_substitute_path_params_encodes_at_sign() {
        let url = substitute_one("/profile/{user}", "user", "user@domain").unwrap();
        assert_eq!(url, "/profile/user%40domain");
    }

    #[test]
    fn test_percent_encode_path_segment_empty() {
        assert_eq!(percent_encode_path_segment(""), "");
    }

    #[test]
    fn test_percent_encode_path_segment_ascii_only() {
        assert_eq!(percent_encode_path_segment("abc123"), "abc123");
    }

    #[test]
    fn test_substitute_path_params_multiple() {
        let args = path_args(&[("owner", "acme"), ("repo", "widgets")]);
        let url = substitute_path_params("/repos/{owner}/{repo}/issues", &args).unwrap();
        assert_eq!(url, "/repos/acme/widgets/issues");
    }

    #[test]
    fn test_substitute_path_params_no_placeholders() {
        let url = substitute_path_params("/health", &path_args(&[])).unwrap();
        assert_eq!(url, "/health");
    }
}