1use crate::types::*;
2use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
3use serde_json::Value;
4use std::collections::HashMap;
5use std::time::Duration;
6
/// Crate version string, captured from the `CARGO_PKG_VERSION` environment
/// variable at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
8
9impl RuntimeConfig {
10 pub fn new(download_dir: String) -> Self {
11 let mut headers_for_any_hosts = HashMap::new();
12 headers_for_any_hosts.insert(
13 "User-Agent".to_string(),
14 Value::String(format!("afhttp/{VERSION}")),
15 );
16 RuntimeConfig {
17 response_save_dir: download_dir,
18 response_save_above_bytes: 10_485_760, request_concurrency_limit: 0, timeout_connect_s: 10,
21 pool_idle_timeout_s: 90,
22 retry_base_delay_ms: 100,
23 proxy: None,
24 tls: TlsConfig {
25 insecure: false,
26 cacert_pem: None,
27 cacert_file: None,
28 cert_pem: None,
29 cert_file: None,
30 key_pem_secret: None,
31 key_file: None,
32 },
33 log: vec![],
34 defaults: RequestDefaults {
35 headers_for_any_hosts,
36 timeout_idle_s: 30,
37 retry: 0,
38 response_redirect: 10,
39 response_parse_json: true,
40 response_decompress: true,
41 response_save_resume: false,
42 retry_on_status: vec![],
43 },
44 host_defaults: HashMap::new(),
45 }
46 }
47
48 pub fn apply_update(&mut self, patch: ConfigPatch) -> bool {
50 let mut needs_rebuild = false;
51
52 if let Some(v) = patch.response_save_dir {
53 self.response_save_dir = v;
54 }
55 if let Some(v) = patch.response_save_above_bytes {
56 self.response_save_above_bytes = v;
57 }
58 if let Some(v) = patch.request_concurrency_limit {
59 self.request_concurrency_limit = v;
60 }
61 if let Some(v) = patch.timeout_connect_s {
62 if v != self.timeout_connect_s {
63 needs_rebuild = true;
64 }
65 self.timeout_connect_s = v;
66 }
67 if let Some(v) = patch.pool_idle_timeout_s {
68 if v != self.pool_idle_timeout_s {
69 needs_rebuild = true;
70 }
71 self.pool_idle_timeout_s = v;
72 }
73 if let Some(v) = patch.retry_base_delay_ms {
74 self.retry_base_delay_ms = v;
75 }
76 if let Some(v) = patch.proxy {
77 if Some(&v) != self.proxy.as_ref() {
78 needs_rebuild = true;
79 }
80 self.proxy = Some(v);
81 }
82
83 if let Some(tls_patch) = patch.tls {
84 if let Some(v) = tls_patch.insecure {
85 if v != self.tls.insecure {
86 needs_rebuild = true;
87 }
88 self.tls.insecure = v;
89 }
90 if let Some(v) = tls_patch.cacert_pem {
93 needs_rebuild = true;
94 self.tls.cacert_pem = Some(v);
95 self.tls.cacert_file = None;
96 } else if let Some(v) = tls_patch.cacert_file {
97 needs_rebuild = true;
98 self.tls.cacert_file = Some(v);
99 self.tls.cacert_pem = None;
100 }
101 if let Some(v) = tls_patch.cert_pem {
102 needs_rebuild = true;
103 self.tls.cert_pem = Some(v);
104 self.tls.cert_file = None;
105 } else if let Some(v) = tls_patch.cert_file {
106 needs_rebuild = true;
107 self.tls.cert_file = Some(v);
108 self.tls.cert_pem = None;
109 }
110 if let Some(v) = tls_patch.key_pem_secret {
111 needs_rebuild = true;
112 self.tls.key_pem_secret = Some(v);
113 self.tls.key_file = None;
114 } else if let Some(v) = tls_patch.key_file {
115 needs_rebuild = true;
116 self.tls.key_file = Some(v);
117 self.tls.key_pem_secret = None;
118 }
119 }
120
121 if let Some(v) = patch.log {
122 self.log = v;
123 }
124
125 if let Some(d) = patch.defaults {
126 if let Some(new_headers) = d.headers_for_any_hosts {
128 for (k, v) in new_headers {
129 if v.is_null() {
130 self.defaults.headers_for_any_hosts.remove(&k);
131 } else {
132 self.defaults.headers_for_any_hosts.insert(k, v);
133 }
134 }
135 }
136 if let Some(v) = d.timeout_idle_s {
137 self.defaults.timeout_idle_s = v;
138 }
139 if let Some(v) = d.retry {
140 self.defaults.retry = v;
141 }
142 if let Some(v) = d.response_redirect {
143 self.defaults.response_redirect = v;
144 }
145 if let Some(v) = d.response_parse_json {
146 self.defaults.response_parse_json = v;
147 }
148 if let Some(v) = d.response_decompress {
149 self.defaults.response_decompress = v;
150 }
151 if let Some(v) = d.response_save_resume {
152 self.defaults.response_save_resume = v;
153 }
154 if let Some(v) = d.retry_on_status {
155 self.defaults.retry_on_status = v;
156 }
157 }
158
159 if let Some(hd) = patch.host_defaults {
161 for (host, partial) in hd {
162 let entry = self.host_defaults.entry(host).or_default();
163 if let Some(new_headers) = partial.headers {
164 for (k, v) in new_headers {
165 if v.is_null() {
166 entry.headers.remove(&k);
167 } else {
168 entry.headers.insert(k, v);
169 }
170 }
171 }
172 }
173 }
174
175 needs_rebuild
176 }
177
178 pub fn build_client(&self) -> Result<reqwest::Client, String> {
180 build_client_inner(self, None)
181 }
182
183 pub fn build_client_for_request(
186 &self,
187 tls_override: &TlsConfigPartial,
188 ) -> Result<reqwest::Client, String> {
189 build_client_inner(self, Some(tls_override))
190 }
191
192 pub fn resolve(&self, options: &RequestOptions) -> ResolvedOptions {
194 let chunked_delimiter = if options.chunked {
195 match &options.chunked_delimiter {
196 Value::String(s) => Some(s.clone()),
197 Value::Null => None, _ => Some("\n".to_string()),
199 }
200 } else {
201 None
202 };
203
204 ResolvedOptions {
205 timeout_idle_s: options
206 .timeout_idle_s
207 .unwrap_or(self.defaults.timeout_idle_s),
208 retry: options.retry.unwrap_or(self.defaults.retry),
209 response_redirect: options
210 .response_redirect
211 .unwrap_or(self.defaults.response_redirect),
212 response_parse_json: options
213 .response_parse_json
214 .unwrap_or(self.defaults.response_parse_json),
215 response_decompress: options
216 .response_decompress
217 .unwrap_or(self.defaults.response_decompress),
218 response_save_resume: options
219 .response_save_resume
220 .unwrap_or(self.defaults.response_save_resume),
221 chunked: options.chunked,
222 chunked_delimiter,
223 response_save_file: options.response_save_file.clone(),
224 progress_bytes: options.progress_bytes.unwrap_or(0),
225 progress_ms: options.progress_ms.unwrap_or(10000),
226 response_save_above_bytes: self.response_save_above_bytes,
227 retry_base_delay_ms: self.retry_base_delay_ms,
228 retry_on_status: options
229 .retry_on_status
230 .clone()
231 .unwrap_or_else(|| self.defaults.retry_on_status.clone()),
232 response_max_bytes: options.response_max_bytes,
233 }
234 }
235
236 pub fn merged_headers(
240 &self,
241 request_headers: &HashMap<String, Value>,
242 host: Option<&str>,
243 ) -> Result<HeaderMap, String> {
244 let mut merged: HashMap<String, Value> = self.defaults.headers_for_any_hosts.clone();
245
246 if let Some(host) = host {
248 if let Some(hd) = self.host_defaults.get(host) {
249 for (k, v) in &hd.headers {
250 if v.is_null() {
251 merged.remove(k);
252 } else {
253 merged.insert(k.clone(), v.clone());
254 }
255 }
256 }
257 }
258
259 for (k, v) in request_headers {
261 if v.is_null() {
262 merged.remove(k);
263 } else {
264 merged.insert(k.clone(), v.clone());
265 }
266 }
267
268 let mut header_map = HeaderMap::new();
269 for (k, v) in &merged {
270 let name = HeaderName::from_bytes(k.as_bytes())
271 .map_err(|e| format!("invalid header name '{k}': {e}"))?;
272 let val_str = match v {
273 Value::String(s) => s.clone(),
274 other => other.to_string(),
275 };
276 let value = HeaderValue::from_str(&val_str)
277 .map_err(|e| format!("invalid header value for '{k}': {e}"))?;
278 header_map.insert(name, value);
279 }
280 Ok(header_map)
281 }
282}
283
/// Resolves PEM bytes from either an inline string or a file on disk.
///
/// The inline text takes precedence: when it is present the file path is not
/// read at all. With neither source the result is `Ok(None)`.
///
/// # Errors
/// Returns `read '<path>': <io error>` when the file cannot be read.
fn load_pem(
    inline: Option<&String>,
    file_path: Option<&String>,
) -> Result<Option<Vec<u8>>, String> {
    match (inline, file_path) {
        (Some(text), _) => Ok(Some(text.as_bytes().to_vec())),
        (None, Some(path)) => std::fs::read(path)
            .map(Some)
            .map_err(|e| format!("read '{path}': {e}")),
        (None, None) => Ok(None),
    }
}
303
304fn build_client_inner(
308 cfg: &RuntimeConfig,
309 tls_override: Option<&TlsConfigPartial>,
310) -> Result<reqwest::Client, String> {
311 let mut builder = reqwest::Client::builder()
312 .connect_timeout(Duration::from_secs(cfg.timeout_connect_s))
313 .pool_idle_timeout(Duration::from_secs(cfg.pool_idle_timeout_s))
314 .pool_max_idle_per_host(10)
315 .redirect(reqwest::redirect::Policy::none());
317
318 let insecure = tls_override
320 .and_then(|o| o.insecure)
321 .unwrap_or(cfg.tls.insecure);
322 if insecure {
323 builder = builder.danger_accept_invalid_certs(true);
324 }
325
326 let ca_pem = if let Some(ov) = tls_override {
329 if ov.cacert_pem.is_some() || ov.cacert_file.is_some() {
330 load_pem(ov.cacert_pem.as_ref(), ov.cacert_file.as_ref())?
331 } else {
332 load_pem(cfg.tls.cacert_pem.as_ref(), cfg.tls.cacert_file.as_ref())?
333 }
334 } else {
335 load_pem(cfg.tls.cacert_pem.as_ref(), cfg.tls.cacert_file.as_ref())?
336 };
337 if let Some(pem) = ca_pem {
338 let cert =
339 reqwest::Certificate::from_pem(&pem).map_err(|e| format!("parse cacert: {e}"))?;
340 builder = builder.add_root_certificate(cert);
341 }
342
343 let cert_pem = if let Some(ov) = tls_override {
345 if ov.cert_pem.is_some() || ov.cert_file.is_some() {
346 load_pem(ov.cert_pem.as_ref(), ov.cert_file.as_ref())?
347 } else {
348 load_pem(cfg.tls.cert_pem.as_ref(), cfg.tls.cert_file.as_ref())?
349 }
350 } else {
351 load_pem(cfg.tls.cert_pem.as_ref(), cfg.tls.cert_file.as_ref())?
352 };
353 let key_pem_secret = if let Some(ov) = tls_override {
354 if ov.key_pem_secret.is_some() || ov.key_file.is_some() {
355 load_pem(ov.key_pem_secret.as_ref(), ov.key_file.as_ref())?
356 } else {
357 load_pem(cfg.tls.key_pem_secret.as_ref(), cfg.tls.key_file.as_ref())?
358 }
359 } else {
360 load_pem(cfg.tls.key_pem_secret.as_ref(), cfg.tls.key_file.as_ref())?
361 };
362
363 if let Some(cert_bytes) = cert_pem {
364 let mut bundle = cert_bytes.clone();
366 bundle.push(b'\n');
367 if let Some(key_bytes) = key_pem_secret {
368 bundle.extend_from_slice(&key_bytes);
369 } else {
370 bundle.extend_from_slice(&cert_bytes);
372 }
373 let identity = reqwest::Identity::from_pem(&bundle)
374 .map_err(|e| format!("parse client identity: {e}"))?;
375 builder = builder.identity(identity);
376 }
377
378 if let Some(ref proxy_url) = cfg.proxy {
380 let proxy = reqwest::Proxy::all(proxy_url).map_err(|e| format!("invalid proxy: {e}"))?;
381 builder = builder.proxy(proxy);
382 }
383
384 builder.build().map_err(|e| format!("build client: {e}"))
385}
386
387pub fn response_headers_to_map(
396 headers: &reqwest::header::HeaderMap,
397) -> Result<HashMap<String, Value>, String> {
398 let mut map: HashMap<String, Vec<String>> = HashMap::new();
399 for (name, value) in headers.iter() {
400 let key = name.as_str().to_string();
401 let val = value
402 .to_str()
403 .map_err(|_| format!("server sent non-ASCII bytes in header '{key}'"))?;
404 map.entry(key).or_default().push(val.to_string());
405 }
406 Ok(map
407 .into_iter()
408 .map(|(k, mut v)| {
409 if v.len() == 1 {
410 (k, Value::String(v.swap_remove(0)))
411 } else {
412 (k, Value::Array(v.into_iter().map(Value::String).collect()))
413 }
414 })
415 .collect())
416}
417
418pub fn parse_content_length(headers: &HashMap<String, Value>) -> Option<u64> {
420 headers
421 .get("content-length")
422 .and_then(|v| v.as_str())
423 .and_then(|s| s.parse::<u64>().ok())
424}
425
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use reqwest::header::{HeaderValue, CONTENT_LENGTH, SET_COOKIE};

    /// Save directory used by every configuration built in these tests.
    const TEST_DIR: &str = "/tmp/afhttp-test";

    /// Fresh configuration with the built-in defaults.
    fn test_config() -> RuntimeConfig {
        RuntimeConfig::new(TEST_DIR.to_string())
    }

    /// Shorthand for a JSON string value.
    fn text(s: &str) -> Value {
        Value::String(s.to_string())
    }

    /// Temp-dir path whose file name embeds `name` and the current time in
    /// nanoseconds (or `0` if the clock is before the Unix epoch).
    fn tmp_file_path(name: &str) -> String {
        let nanos = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_nanos(),
            Err(_) => 0,
        };
        let file_name = format!("afhttp-{name}-{nanos}.tmp");
        std::env::temp_dir()
            .join(file_name)
            .to_string_lossy()
            .into_owned()
    }

    /// `new` stores the save directory and seeds the default User-Agent.
    #[test]
    fn runtime_config_new_has_defaults() {
        let cfg = test_config();
        let expected_agent = text(&format!("afhttp/{VERSION}"));

        assert_eq!(cfg.response_save_dir, "/tmp/afhttp-test");
        assert_eq!(
            cfg.defaults.headers_for_any_hosts.get("User-Agent"),
            Some(&expected_agent)
        );
        assert_eq!(cfg.defaults.timeout_idle_s, 30);
        assert!(cfg.host_defaults.is_empty());
    }

    /// A patch touching client-level settings is merged and reports a rebuild.
    #[test]
    fn apply_update_merges_and_marks_rebuild() {
        let mut cfg = test_config();

        let defaults_headers = HashMap::from([
            ("X-One".to_string(), text("1")),
            ("User-Agent".to_string(), Value::Null),
        ]);
        let host_patch = HostDefaultsPartial {
            headers: Some(
                [("X-Host".to_string(), text("yes"))]
                    .into_iter()
                    .collect(),
            ),
        };
        let host_defaults = HashMap::from([("example.com".to_string(), host_patch)]);

        let tls = TlsConfigPartial {
            insecure: Some(true),
            cacert_file: Some("/tmp/ca.pem".to_string()),
            cert_file: Some("/tmp/cert.pem".to_string()),
            key_file: Some("/tmp/key.pem".to_string()),
            ..TlsConfigPartial::default()
        };
        let defaults = RequestDefaultsPartial {
            headers_for_any_hosts: Some(defaults_headers),
            timeout_idle_s: Some(9),
            retry_on_status: Some(vec![429, 503]),
            ..RequestDefaultsPartial::default()
        };
        let patch = ConfigPatch {
            timeout_connect_s: Some(11),
            pool_idle_timeout_s: Some(22),
            proxy: Some("http://127.0.0.1:8080".to_string()),
            defaults: Some(defaults),
            host_defaults: Some(host_defaults),
            tls: Some(tls),
            ..ConfigPatch::default()
        };

        assert!(cfg.apply_update(patch));

        assert_eq!(cfg.timeout_connect_s, 11);
        assert_eq!(cfg.pool_idle_timeout_s, 22);
        assert_eq!(cfg.proxy.as_deref(), Some("http://127.0.0.1:8080"));
        assert_eq!(cfg.defaults.timeout_idle_s, 9);
        assert_eq!(cfg.defaults.retry_on_status, vec![429, 503]);

        let any_host = &cfg.defaults.headers_for_any_hosts;
        assert_eq!(any_host.get("X-One"), Some(&Value::String("1".into())));
        assert!(!any_host.contains_key("User-Agent"));

        let host_header = cfg
            .host_defaults
            .get("example.com")
            .and_then(|h| h.headers.get("X-Host"));
        assert_eq!(host_header, Some(&Value::String("yes".into())));

        assert!(cfg.tls.insecure);
        assert_eq!(cfg.tls.cacert_file.as_deref(), Some("/tmp/ca.pem"));
        assert_eq!(cfg.tls.cert_file.as_deref(), Some("/tmp/cert.pem"));
        assert_eq!(cfg.tls.key_file.as_deref(), Some("/tmp/key.pem"));
    }

    /// Supplying inline PEM values clears the previously configured files.
    #[test]
    fn apply_update_inline_tls_clears_file_variants() {
        let mut cfg = test_config();
        cfg.tls.cacert_file = Some("a".to_string());
        cfg.tls.cert_file = Some("b".to_string());
        cfg.tls.key_file = Some("c".to_string());

        let inline_tls = TlsConfigPartial {
            cacert_pem: Some("CA".to_string()),
            cert_pem: Some("CERT".to_string()),
            key_pem_secret: Some("KEY".to_string()),
            ..TlsConfigPartial::default()
        };
        let _ = cfg.apply_update(ConfigPatch {
            tls: Some(inline_tls),
            ..ConfigPatch::default()
        });

        let tls = &cfg.tls;
        assert_eq!(tls.cacert_pem.as_deref(), Some("CA"));
        assert!(tls.cacert_file.is_none());
        assert_eq!(tls.cert_pem.as_deref(), Some("CERT"));
        assert!(tls.cert_file.is_none());
        assert_eq!(tls.key_pem_secret.as_deref(), Some("KEY"));
        assert!(tls.key_file.is_none());
    }

    /// Unset request options fall back to defaults; set ones pass through.
    #[test]
    fn resolve_merges_defaults_and_request_options() {
        let mut cfg = test_config();
        cfg.response_save_above_bytes = 123;
        cfg.retry_base_delay_ms = 456;
        {
            let d = &mut cfg.defaults;
            d.timeout_idle_s = 31;
            d.retry = 2;
            d.response_redirect = 7;
            d.response_parse_json = false;
            d.response_decompress = false;
            d.response_save_resume = true;
            d.retry_on_status = vec![500];
        }

        let resolved = cfg.resolve(&RequestOptions {
            chunked: true,
            chunked_delimiter: Value::Null,
            progress_bytes: Some(5),
            progress_ms: Some(6),
            response_max_bytes: Some(7),
            ..RequestOptions::default()
        });

        assert_eq!(resolved.timeout_idle_s, 31);
        assert_eq!(resolved.retry, 2);
        assert_eq!(resolved.response_redirect, 7);
        assert!(!resolved.response_parse_json);
        assert!(!resolved.response_decompress);
        assert!(resolved.response_save_resume);
        assert!(resolved.chunked);
        assert!(resolved.chunked_delimiter.is_none());
        assert_eq!(resolved.progress_bytes, 5);
        assert_eq!(resolved.progress_ms, 6);
        assert_eq!(resolved.response_save_above_bytes, 123);
        assert_eq!(resolved.retry_base_delay_ms, 456);
        assert_eq!(resolved.retry_on_status, vec![500]);
        assert_eq!(resolved.response_max_bytes, Some(7));
    }

    /// Later layers override earlier ones and `null` removes a header.
    #[test]
    fn merged_headers_applies_layers_and_null_removal() {
        let mut cfg = test_config();
        cfg.defaults
            .headers_for_any_hosts
            .insert("X-Default".to_string(), text("default"));

        let host_layer = HostDefaults {
            headers: [
                ("X-Host".to_string(), text("host")),
                ("X-Default".to_string(), Value::Null),
            ]
            .into_iter()
            .collect(),
        };
        cfg.host_defaults
            .insert("api.example.com".to_string(), host_layer);

        let req_headers: HashMap<String, Value> = HashMap::from([
            ("X-Req".to_string(), text("req")),
            ("X-Host".to_string(), Value::Null),
        ]);

        let merged = cfg
            .merged_headers(&req_headers, Some("api.example.com"))
            .expect("merged headers");

        let x_req = merged.get("x-req").and_then(|v| v.to_str().ok());
        assert_eq!(x_req, Some("req"));
        assert!(merged.get("x-host").is_none());
        assert!(merged.get("x-default").is_none());
    }

    /// Malformed header names and values surface as errors.
    #[test]
    fn merged_headers_rejects_invalid_names_or_values() {
        let cfg = test_config();

        let bad_name: HashMap<String, Value> =
            HashMap::from([("bad name".to_string(), Value::String("x".into()))]);
        assert!(cfg.merged_headers(&bad_name, None).is_err());

        let bad_value: HashMap<String, Value> =
            HashMap::from([("X".to_string(), Value::String("bad\nvalue".into()))]);
        assert!(cfg.merged_headers(&bad_value, None).is_err());
    }

    /// Inline PEM beats the file; the file is used when inline is absent.
    #[test]
    fn load_pem_prefers_inline_then_file() {
        let file = tmp_file_path("pem");
        std::fs::write(&file, b"FILE").expect("write");
        let inline = "INLINE".to_string();

        let from_inline = load_pem(Some(&inline), Some(&file)).expect("inline pem");
        let from_file = load_pem(None, Some(&file)).expect("file pem");
        let none = load_pem(None, None).expect("none");

        assert_eq!(from_inline, Some(b"INLINE".to_vec()));
        assert_eq!(from_file, Some(b"FILE".to_vec()));
        assert_eq!(none, None);

        let _ = std::fs::remove_file(file);
    }

    /// The default configuration builds; an unparsable proxy URL does not.
    #[test]
    fn build_client_basics_and_bad_cert_error() {
        let mut cfg = test_config();
        assert!(cfg.build_client().is_ok());

        cfg.proxy = Some("not a valid proxy".to_string());
        let err = cfg
            .build_client()
            .expect_err("should fail on invalid proxy");
        assert!(err.contains("invalid proxy"));
    }

    /// Repeated headers become arrays and content-length parses as a number.
    #[test]
    fn response_headers_map_and_content_length() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("42"));
        for cookie in ["a=1", "b=2"] {
            headers.append(SET_COOKIE, HeaderValue::from_static(cookie));
        }

        let map = response_headers_to_map(&headers).expect("headers");

        assert_eq!(parse_content_length(&map), Some(42));
        let expected_cookies = Value::Array(vec![
            Value::String("a=1".to_string()),
            Value::String("b=2".to_string()),
        ]);
        assert_eq!(map.get("set-cookie"), Some(&expected_cookies));
    }

    /// A header value that is not valid text is reported as an error.
    #[test]
    fn response_headers_map_rejects_non_ascii() {
        let bad = HeaderValue::from_bytes(&[0xFF]).expect("header bytes");
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("x-bad", bad);

        assert!(response_headers_to_map(&headers).is_err());
    }
}