1use std::time::Duration;
4
5use acdp_primitives::error::AcdpError;
6use acdp_primitives::limits::{
7 CONNECT_TIMEOUT, MAX_CONTEXT_BYTES, MAX_METADATA_BYTES, MAX_REDIRECTS, REQUEST_TIMEOUT,
8};
9use acdp_safe_http::SsrfPolicy;
10use acdp_types::{
11 body::FullContext,
12 capabilities::CapabilitiesDocument,
13 primitives::{CtxId, LineageId},
14 publish::{PublishRequest, PublishResponse, WireError},
15 search::{SearchParams, SearchResponse},
16};
17use chrono::{DateTime, Utc};
18use reqwest::{redirect, Client};
19
20#[derive(Clone)]
27pub struct RegistryClient {
28 base: String,
29 http: Client,
30}
31
32#[derive(Debug, Clone, Default)]
38pub struct RetrievalMetadata {
39 pub etag: Option<String>,
41 pub cache_control: Option<String>,
43 pub last_modified: Option<DateTime<Utc>>,
45}
46
47impl RegistryClient {
48 pub fn authority(&self) -> Option<String> {
52 url::Url::parse(&self.base).ok().and_then(|u| {
53 let host = u.host_str()?.to_string();
54 Some(match u.port() {
55 Some(p) => format!("{host}:{p}"),
56 None => host,
57 })
58 })
59 }
60
61 pub fn new(base_url: &str) -> Result<Self, AcdpError> {
67 Self::build(base_url, None, None, SsrfPolicy::default())
68 }
69
70 pub fn with_root_cert_pem(base_url: &str, pem: &[u8]) -> Result<Self, AcdpError> {
78 Self::build(base_url, Some(pem), None, SsrfPolicy::allow_test_loopback())
82 }
83
84 #[doc(hidden)]
93 pub fn with_test_transport(base_url: &str) -> Result<Self, AcdpError> {
94 let policy = SsrfPolicy {
95 reject_ip_literals: false,
96 allow_http: true,
97 allow_loopback_resolved: true,
98 };
99 Self::build(base_url, None, None, policy)
100 }
101
102 #[doc(hidden)]
114 pub fn with_test_endpoint(
115 base_url: &str,
116 target: std::net::SocketAddr,
117 pem: &[u8],
118 ) -> Result<Self, AcdpError> {
119 Self::build(
122 base_url,
123 Some(pem),
124 Some(target),
125 SsrfPolicy::allow_test_loopback(),
126 )
127 }
128
129 fn build(
130 base_url: &str,
131 extra_root_pem: Option<&[u8]>,
132 resolve_target: Option<std::net::SocketAddr>,
133 policy_ssrf: SsrfPolicy,
134 ) -> Result<Self, AcdpError> {
135 let base = base_url.trim_end_matches('/').to_string();
136 policy_ssrf.check_url(&base)?;
143 let original_authority = url::Url::parse(&base)
144 .ok()
145 .and_then(|u| u.host_str().map(str::to_string));
146
147 let policy = redirect::Policy::custom(move |attempt| {
148 if attempt.previous().len() >= MAX_REDIRECTS {
149 return attempt.error(format!(
150 "exceeded {MAX_REDIRECTS} redirects per RFC-ACDP-0006 §7.5"
151 ));
152 }
153 let cross = attempt
156 .previous()
157 .first()
158 .filter(|orig| !acdp_safe_http::same_fetch_authority(orig, attempt.url()))
159 .map(|orig| (orig.to_string(), attempt.url().to_string()));
160 if let Some((from, to)) = cross {
161 return attempt.error(format!(
162 "cross-authority redirect rejected ({from} -> {to})"
163 ));
164 }
165 attempt.follow()
166 });
167
168 let mut builder = Client::builder()
169 .use_rustls_tls()
170 .connect_timeout(CONNECT_TIMEOUT)
171 .timeout(REQUEST_TIMEOUT)
172 .redirect(policy)
173 .dns_resolver(acdp_safe_http::SafeDnsResolver::arc(policy_ssrf));
177
178 if let Some(pem) = extra_root_pem {
179 let cert = reqwest::Certificate::from_pem(pem)
180 .map_err(|e| AcdpError::Http(format!("invalid root cert PEM: {e}")))?;
181 builder = builder.add_root_certificate(cert);
182 }
183
184 if let (Some(target), Some(host)) = (resolve_target, original_authority) {
185 builder = builder.resolve(&host, target);
186 }
187
188 let http = builder
189 .build()
190 .map_err(|e| AcdpError::Http(e.to_string()))?;
191
192 Ok(Self { base, http })
193 }
194
195 pub async fn new_pinned(base_url: &str, policy: &SsrfPolicy) -> Result<Self, AcdpError> {
209 let base = base_url.trim_end_matches('/').to_string();
210 let parsed = url::Url::parse(&base)
211 .map_err(|e| AcdpError::SchemaViolation(format!("invalid base URL: {e}")))?;
212 policy.check_url(&base)?;
214 let host = parsed
215 .host_str()
216 .ok_or_else(|| AcdpError::SchemaViolation(format!("base URL has no host: {base}")))?
217 .to_string();
218 let port = parsed
219 .port_or_known_default()
220 .unwrap_or(if parsed.scheme() == "http" { 80 } else { 443 });
221
222 let pinned = policy.pin_resolved_ip(&host, port).await?;
223
224 let policy_redirect = redirect::Policy::custom(move |attempt| {
225 if attempt.previous().len() >= MAX_REDIRECTS {
226 return attempt.error(format!(
227 "exceeded {MAX_REDIRECTS} redirects per RFC-ACDP-0006 §7.5"
228 ));
229 }
230 let cross = attempt
233 .previous()
234 .first()
235 .filter(|orig| !acdp_safe_http::same_fetch_authority(orig, attempt.url()))
236 .map(|orig| (orig.to_string(), attempt.url().to_string()));
237 if let Some((from, to)) = cross {
238 return attempt.error(format!(
239 "cross-authority redirect rejected ({from} -> {to})"
240 ));
241 }
242 attempt.follow()
243 });
244
245 let http = Client::builder()
246 .use_rustls_tls()
247 .connect_timeout(CONNECT_TIMEOUT)
248 .timeout(REQUEST_TIMEOUT)
249 .redirect(policy_redirect)
250 .resolve(&host, pinned)
251 .build()
252 .map_err(|e| AcdpError::Http(e.to_string()))?;
253
254 Ok(Self { base, http })
255 }
256
257 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
265 pub async fn capabilities(&self) -> Result<CapabilitiesDocument, AcdpError> {
266 Ok(self.capabilities_with_ttl().await?.0)
267 }
268
269 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
278 pub async fn capabilities_with_ttl(
279 &self,
280 ) -> Result<(CapabilitiesDocument, std::time::Duration), AcdpError> {
281 let url = format!("{}/.well-known/acdp.json", self.base);
282 let resp = self.http.get(&url).send().await?;
283 let ttl = cache_ttl_from_response(&resp);
284 let caps: CapabilitiesDocument = self.parse_success(resp, MAX_METADATA_BYTES).await?;
285 acdp_validation::validate_capabilities(&caps)?;
286 Ok((caps, ttl))
287 }
288
289 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, req)))]
293 pub async fn publish(&self, req: &PublishRequest) -> Result<PublishResponse, AcdpError> {
294 let url = format!("{}/contexts", self.base);
295 let resp = self
296 .http
297 .post(&url)
298 .header("Content-Type", "application/acdp+json")
299 .json(req)
300 .send()
301 .await?;
302 self.parse_success(resp, MAX_METADATA_BYTES).await
303 }
304
305 pub async fn publish_idempotent(
307 &self,
308 req: &PublishRequest,
309 idempotency_key: &str,
310 ) -> Result<PublishResponse, AcdpError> {
311 let url = format!("{}/contexts", self.base);
312 let resp = self
313 .http
314 .post(&url)
315 .header("Content-Type", "application/acdp+json")
316 .header("Idempotency-Key", idempotency_key)
317 .json(req)
318 .send()
319 .await?;
320 self.parse_success(resp, MAX_METADATA_BYTES).await
321 }
322
323 pub async fn publish_with_retry(
330 &self,
331 req: &PublishRequest,
332 idempotency_key: &str,
333 max_attempts: u32,
334 ) -> Result<PublishResponse, AcdpError> {
335 let attempts = max_attempts.max(1);
336 let mut last_err: Option<AcdpError> = None;
337 for attempt in 0..attempts {
338 match self.publish_idempotent(req, idempotency_key).await {
339 Ok(resp) => return Ok(resp),
340 Err(e) if e.is_transient() && attempt + 1 < attempts => {
341 let backoff_ms = 250u64 * (1 << attempt.min(3));
342 last_err = Some(e);
343 #[cfg(feature = "tracing")]
344 tracing::debug!(
345 attempt = attempt + 1,
346 backoff_ms,
347 "publish transient failure; retrying"
348 );
349 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
350 }
351 Err(e) => return Err(e),
352 }
353 }
354 Err(last_err
355 .unwrap_or_else(|| AcdpError::Http("publish_with_retry exhausted attempts".into())))
356 }
357
358 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), fields(ctx_id = %ctx_id)))]
364 pub async fn retrieve(&self, ctx_id: &CtxId) -> Result<FullContext, AcdpError> {
365 let encoded = urlencoding::encode(ctx_id.as_str());
366 let url = format!("{}/contexts/{}", self.base, encoded);
367 let resp = self.http.get(&url).send().await?;
368 self.parse_success(resp, MAX_CONTEXT_BYTES).await
369 }
370
371 pub async fn retrieve_with_metadata(
373 &self,
374 ctx_id: &CtxId,
375 ) -> Result<(FullContext, RetrievalMetadata), AcdpError> {
376 let encoded = urlencoding::encode(ctx_id.as_str());
377 let url = format!("{}/contexts/{}", self.base, encoded);
378 let resp = self.http.get(&url).send().await?;
379 let metadata = parse_retrieval_metadata(&resp);
380 let body = self.parse_success(resp, MAX_CONTEXT_BYTES).await?;
381 Ok((body, metadata))
382 }
383
384 pub async fn retrieve_if_none_match(
389 &self,
390 ctx_id: &CtxId,
391 etag: &str,
392 ) -> Result<Option<(FullContext, RetrievalMetadata)>, AcdpError> {
393 let encoded = urlencoding::encode(ctx_id.as_str());
394 let url = format!("{}/contexts/{}", self.base, encoded);
395 let resp = self
396 .http
397 .get(&url)
398 .header("If-None-Match", etag)
399 .send()
400 .await?;
401 if resp.status() == reqwest::StatusCode::NOT_MODIFIED {
402 return Ok(None);
403 }
404 let metadata = parse_retrieval_metadata(&resp);
405 let body = self.parse_success(resp, MAX_CONTEXT_BYTES).await?;
406 Ok(Some((body, metadata)))
407 }
408
409 pub async fn retrieve_body(&self, ctx_id: &CtxId) -> Result<acdp_types::body::Body, AcdpError> {
411 let encoded = urlencoding::encode(ctx_id.as_str());
412 let url = format!("{}/contexts/{}/body", self.base, encoded);
413 let resp = self.http.get(&url).send().await?;
414 self.parse_success(resp, MAX_CONTEXT_BYTES).await
415 }
416
417 pub async fn lineage(&self, lineage_id: &LineageId) -> Result<Vec<FullContext>, AcdpError> {
421 let encoded = urlencoding::encode(lineage_id.as_str());
422 let url = format!("{}/lineages/{}", self.base, encoded);
423 let resp = self.http.get(&url).send().await?;
424 self.parse_success::<serde_json::Value>(resp, MAX_CONTEXT_BYTES)
425 .await
426 .and_then(|v| {
427 serde_json::from_value(v).map_err(|e| AcdpError::Serialization(e.to_string()))
428 })
429 }
430
431 pub async fn current(&self, lineage_id: &LineageId) -> Result<FullContext, AcdpError> {
433 let encoded = urlencoding::encode(lineage_id.as_str());
434 let url = format!("{}/lineages/{}/current", self.base, encoded);
435 let resp = self.http.get(&url).send().await?;
436 self.parse_success(resp, MAX_CONTEXT_BYTES).await
437 }
438
439 pub async fn search(&self, params: &SearchParams) -> Result<SearchResponse, AcdpError> {
446 let url = format!("{}/contexts/search", self.base);
447 let resp = self.http.get(&url).query(params).send().await?;
448 self.parse_success(resp, MAX_METADATA_BYTES).await
449 }
450
451 pub fn search_builder(&self) -> RegistrySearch<'_> {
467 RegistrySearch::new(self)
468 }
469}
470
471pub struct RegistrySearch<'a> {
474 client: &'a RegistryClient,
475 inner: acdp_types::search::SearchParamsBuilder,
476}
477
478impl<'a> RegistrySearch<'a> {
479 fn new(client: &'a RegistryClient) -> Self {
480 Self {
481 client,
482 inner: acdp_types::search::SearchParamsBuilder::new(),
483 }
484 }
485
486 pub async fn send(self) -> Result<SearchResponse, AcdpError> {
488 let params = self.inner.build();
489 self.client.search(¶ms).await
490 }
491 pub fn q(mut self, q: impl Into<String>) -> Self {
493 self.inner = self.inner.q(q);
494 self
495 }
496 pub fn context_type(mut self, t: impl Into<String>) -> Self {
498 self.inner = self.inner.context_type(t);
499 self
500 }
501 pub fn domain(mut self, d: impl Into<String>) -> Self {
503 self.inner = self.inner.domain(d);
504 self
505 }
506 pub fn tag(mut self, t: impl Into<String>) -> Self {
508 self.inner = self.inner.tag(t);
509 self
510 }
511 pub fn agent_id(mut self, a: impl Into<String>) -> Self {
513 self.inner = self.inner.agent_id(a);
514 self
515 }
516 pub fn derived_from(mut self, c: &acdp_types::CtxId) -> Self {
518 self.inner = self.inner.derived_from_ctx_id(c);
519 self
520 }
521 pub fn created_after(mut self, dt: chrono::DateTime<chrono::Utc>) -> Self {
523 self.inner = self.inner.created_after(dt);
524 self
525 }
526 pub fn created_before(mut self, dt: chrono::DateTime<chrono::Utc>) -> Self {
528 self.inner = self.inner.created_before(dt);
529 self
530 }
531 pub fn status(mut self, s: impl Into<String>) -> Self {
533 self.inner = self.inner.status(s);
534 self
535 }
536 pub fn limit(mut self, l: u32) -> Self {
538 self.inner = self.inner.limit(l);
539 self
540 }
541 pub fn cursor(mut self, c: impl Into<String>) -> Self {
543 self.inner = self.inner.cursor(c);
544 self
545 }
546}
547
548impl RegistryClient {
551 async fn parse_success<T: serde::de::DeserializeOwned>(
552 &self,
553 resp: reqwest::Response,
554 max_bytes: usize,
555 ) -> Result<T, AcdpError> {
556 if resp.status().is_success() {
557 let bytes = read_body_capped(resp, max_bytes).await?;
558 serde_json::from_slice(&bytes).map_err(|e| AcdpError::Serialization(e.to_string()))
559 } else {
560 let bytes = match read_body_capped(resp, MAX_METADATA_BYTES).await {
563 Ok(b) => b,
564 Err(_) => {
565 return Err(AcdpError::from_wire_error(WireError {
566 error: acdp_types::publish::WireErrorBody {
567 code: "unknown".into(),
568 message: "could not read registry error response".into(),
569 details: None,
570 },
571 }));
572 }
573 };
574 let wire: WireError = serde_json::from_slice(&bytes).unwrap_or_else(|_| WireError {
575 error: acdp_types::publish::WireErrorBody {
576 code: "unknown".into(),
577 message: "could not parse registry error response".into(),
578 details: None,
579 },
580 });
581 Err(AcdpError::from_wire_error(wire))
582 }
583 }
584}
585
586fn cache_ttl_from_response(resp: &reqwest::Response) -> std::time::Duration {
594 const MAX_CAPS_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(3600);
595 const DEFAULT_CAPS_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(300);
596
597 let Some(cc) = resp
598 .headers()
599 .get(reqwest::header::CACHE_CONTROL)
600 .and_then(|v| v.to_str().ok())
601 else {
602 return DEFAULT_CAPS_CACHE_TTL;
603 };
604 for directive in cc.split(',') {
605 let directive = directive.trim();
606 if let Some(value) = directive
607 .strip_prefix("max-age=")
608 .or_else(|| directive.strip_prefix("s-maxage="))
609 {
610 if let Ok(secs) = value.parse::<u64>() {
611 return std::time::Duration::from_secs(secs).min(MAX_CAPS_CACHE_TTL);
612 }
613 }
614 }
615 DEFAULT_CAPS_CACHE_TTL
616}
617
618fn parse_retrieval_metadata(resp: &reqwest::Response) -> RetrievalMetadata {
619 let headers = resp.headers();
620 let etag = headers
621 .get(reqwest::header::ETAG)
622 .and_then(|v| v.to_str().ok())
623 .map(|s| s.to_string());
624 let cache_control = headers
625 .get(reqwest::header::CACHE_CONTROL)
626 .and_then(|v| v.to_str().ok())
627 .map(|s| s.to_string());
628 let last_modified = headers
629 .get(reqwest::header::LAST_MODIFIED)
630 .and_then(|v| v.to_str().ok())
631 .and_then(|s| {
632 DateTime::parse_from_rfc2822(s)
633 .ok()
634 .map(|dt| dt.with_timezone(&Utc))
635 });
636 RetrievalMetadata {
637 etag,
638 cache_control,
639 last_modified,
640 }
641}
642
643async fn read_body_capped(
646 mut resp: reqwest::Response,
647 max_bytes: usize,
648) -> Result<Vec<u8>, AcdpError> {
649 if let Some(len) = resp.content_length() {
650 if len as usize > max_bytes {
651 return Err(AcdpError::PayloadTooLarge(format!(
652 "response Content-Length {len} exceeds cap {max_bytes}"
653 )));
654 }
655 }
656 let mut buf = Vec::with_capacity(8 * 1024);
657 while let Some(chunk) = resp
658 .chunk()
659 .await
660 .map_err(|e| AcdpError::Http(e.to_string()))?
661 {
662 if buf.len() + chunk.len() > max_bytes {
663 return Err(AcdpError::PayloadTooLarge(format!(
664 "response body exceeded {max_bytes} bytes"
665 )));
666 }
667 buf.extend_from_slice(&chunk);
668 }
669 Ok(buf)
670}