1use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7use url::Url;
8use uuid::Uuid;
9
10use crate::error::TlogError;
14use ans_types::Badge;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
25#[serde(rename_all = "camelCase")]
26#[non_exhaustive]
27pub struct AuditResponse {
28 pub records: Vec<Badge>,
31}
32
33#[async_trait]
35pub trait TransparencyLogClient: Send + Sync {
36 async fn fetch_badge(&self, url: &str) -> Result<Badge, TlogError>;
38
39 async fn fetch_badge_by_id(&self, agent_id: Uuid) -> Result<Badge, TlogError>;
41
42 async fn fetch_audit(
47 &self,
48 agent_id: Uuid,
49 limit: Option<u32>,
50 offset: Option<u32>,
51 ) -> Result<AuditResponse, TlogError>;
52}
53
54#[derive(Debug)]
56pub struct HttpTransparencyLogClient {
57 client: Client,
58 base_url: Option<Url>,
59 timeout: Duration,
60 extra_headers: Vec<(String, String)>,
61}
62
63impl HttpTransparencyLogClient {
64 pub fn new() -> Self {
66 Self {
67 client: Client::new(),
68 base_url: None,
69 timeout: Duration::from_secs(30),
70 extra_headers: Vec::new(),
71 }
72 }
73
74 pub fn with_base_url(base_url: impl AsRef<str>) -> Result<Self, TlogError> {
80 let parsed =
81 Url::parse(base_url.as_ref()).map_err(|e| TlogError::InvalidUrl(e.to_string()))?;
82 Ok(Self {
83 client: Client::new(),
84 base_url: Some(parsed),
85 timeout: Duration::from_secs(30),
86 extra_headers: Vec::new(),
87 })
88 }
89
90 pub fn with_timeout(mut self, timeout: Duration) -> Self {
92 self.timeout = timeout;
93 self
94 }
95
96 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
100 self.extra_headers.push((name.into(), value.into()));
101 self
102 }
103
104 pub fn with_headers(
108 mut self,
109 headers: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
110 ) -> Self {
111 self.extra_headers
112 .extend(headers.into_iter().map(|(n, v)| (n.into(), v.into())));
113 self
114 }
115
116 fn build_headers(&self) -> Result<reqwest::header::HeaderMap, TlogError> {
118 let mut map = reqwest::header::HeaderMap::new();
119 for (name, value) in &self.extra_headers {
120 let header_name =
121 reqwest::header::HeaderName::from_bytes(name.as_bytes()).map_err(|e| {
122 TlogError::InvalidHeader(format!("invalid header name '{name}': {e}"))
123 })?;
124 let header_value = reqwest::header::HeaderValue::from_str(value).map_err(|e| {
125 TlogError::InvalidHeader(format!("invalid header value for '{name}': {e}"))
126 })?;
127 map.insert(header_name, header_value);
128 }
129 Ok(map)
130 }
131}
132
133impl Default for HttpTransparencyLogClient {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
139#[async_trait]
140impl TransparencyLogClient for HttpTransparencyLogClient {
141 async fn fetch_badge(&self, url: &str) -> Result<Badge, TlogError> {
142 tracing::debug!(url = %url, "Fetching badge from transparency log");
143
144 let headers = self.build_headers()?;
145 let mut req = self.client.get(url).header(
146 reqwest::header::USER_AGENT,
147 format!("ans-verify/{}", env!("CARGO_PKG_VERSION")),
148 );
149 for (name, value) in &headers {
150 req = req.header(name, value);
151 }
152 let response = req
153 .timeout(self.timeout)
154 .send()
155 .await
156 .map_err(crate::error::HttpError::from)?;
157
158 let status = response.status();
159
160 if status == reqwest::StatusCode::NOT_FOUND {
161 return Err(TlogError::NotFound {
162 url: url.to_string(),
163 });
164 }
165
166 if status.is_server_error() {
167 return Err(TlogError::ServiceUnavailable);
168 }
169
170 if !status.is_success() {
171 return Err(TlogError::InvalidResponse(format!(
172 "Unexpected status code: {status}"
173 )));
174 }
175
176 let badge: Badge = response
177 .json()
178 .await
179 .map_err(|e| TlogError::InvalidResponse(format!("Failed to parse badge JSON: {e}")))?;
180
181 tracing::debug!(
182 agent_id = %badge.agent_id(),
183 status = ?badge.status,
184 "Successfully fetched badge"
185 );
186
187 Ok(badge)
188 }
189
190 async fn fetch_badge_by_id(&self, agent_id: Uuid) -> Result<Badge, TlogError> {
191 let base_url = self.base_url.as_ref().ok_or_else(|| {
192 TlogError::InvalidUrl("No base URL configured for agent ID lookups".to_string())
193 })?;
194
195 let url = base_url
196 .join(&format!("v1/agents/{agent_id}"))
197 .map_err(|e| TlogError::InvalidUrl(e.to_string()))?;
198
199 self.fetch_badge(url.as_str()).await
200 }
201
202 async fn fetch_audit(
203 &self,
204 agent_id: Uuid,
205 limit: Option<u32>,
206 offset: Option<u32>,
207 ) -> Result<AuditResponse, TlogError> {
208 let base_url = self.base_url.as_ref().ok_or_else(|| {
209 TlogError::InvalidUrl("No base URL configured for audit lookups".to_string())
210 })?;
211
212 let mut url = base_url
213 .join(&format!("v1/agents/{agent_id}/audit"))
214 .map_err(|e| TlogError::InvalidUrl(e.to_string()))?;
215
216 {
218 let mut query = url.query_pairs_mut();
219 if let Some(l) = limit {
220 query.append_pair("limit", &l.to_string());
221 }
222 if let Some(o) = offset {
223 query.append_pair("offset", &o.to_string());
224 }
225 }
226
227 tracing::debug!(url = %url, "Fetching audit trail from transparency log");
228
229 let headers = self.build_headers()?;
230 let mut req = self.client.get(url.as_str()).header(
231 reqwest::header::USER_AGENT,
232 format!("ans-verify/{}", env!("CARGO_PKG_VERSION")),
233 );
234 for (name, value) in &headers {
235 req = req.header(name, value);
236 }
237 let response = req
238 .timeout(self.timeout)
239 .send()
240 .await
241 .map_err(crate::error::HttpError::from)?;
242
243 let status = response.status();
244
245 if status == reqwest::StatusCode::NOT_FOUND {
246 return Err(TlogError::NotFound {
247 url: url.to_string(),
248 });
249 }
250
251 if status.is_server_error() {
252 return Err(TlogError::ServiceUnavailable);
253 }
254
255 if !status.is_success() {
256 return Err(TlogError::InvalidResponse(format!(
257 "Unexpected status code: {status}"
258 )));
259 }
260
261 let audit: AuditResponse = response.json().await.map_err(|e| {
262 TlogError::InvalidResponse(format!("Failed to parse audit response JSON: {e}"))
263 })?;
264
265 tracing::debug!(
266 agent_id = %agent_id,
267 record_count = audit.records.len(),
268 "Successfully fetched audit trail"
269 );
270
271 Ok(audit)
272 }
273}
274
/// In-memory stand-in for a transparency log client, available in tests and
/// behind the `test-support` feature.
///
/// Responses are looked up by the exact URL string passed to `fetch_badge`.
#[cfg(any(test, feature = "test-support"))]
#[derive(Default, Debug)]
pub struct MockTransparencyLogClient {
    /// Canned badges keyed by URL.
    badges: std::collections::HashMap<String, Badge>,
    /// Canned errors keyed by URL.
    errors: std::collections::HashMap<String, TlogError>,
}
282
#[cfg(any(test, feature = "test-support"))]
impl MockTransparencyLogClient {
    /// Creates a mock with no canned badges and no canned errors.
    pub fn new() -> Self {
        Self {
            badges: Default::default(),
            errors: Default::default(),
        }
    }

    /// Registers `badge` as the response for `url`, replacing any badge
    /// previously registered for the same URL.
    pub fn with_badge(self, url: &str, badge: Badge) -> Self {
        let Self { mut badges, errors } = self;
        badges.insert(url.to_owned(), badge);
        Self { badges, errors }
    }

    /// Registers `error` as the failure for `url`, replacing any error
    /// previously registered for the same URL.
    pub fn with_error(self, url: &str, error: TlogError) -> Self {
        let Self { badges, mut errors } = self;
        errors.insert(url.to_owned(), error);
        Self { badges, errors }
    }
}
302
#[cfg(any(test, feature = "test-support"))]
#[async_trait]
impl TransparencyLogClient for MockTransparencyLogClient {
    /// Returns the canned response registered for `url`.
    ///
    /// A registered error takes precedence over a registered badge; a URL
    /// with neither yields [`TlogError::NotFound`].
    async fn fetch_badge(&self, url: &str) -> Result<Badge, TlogError> {
        if let Some(stored) = self.errors.get(url) {
            return Err(replay_mock_error(stored));
        }

        match self.badges.get(url) {
            Some(badge) => Ok(badge.clone()),
            None => Err(TlogError::NotFound {
                url: url.to_owned(),
            }),
        }
    }

    /// Not supported by the mock; always fails with [`TlogError::InvalidUrl`].
    async fn fetch_badge_by_id(&self, _agent_id: Uuid) -> Result<Badge, TlogError> {
        let reason = "Mock client does not support fetch_badge_by_id".to_string();
        Err(TlogError::InvalidUrl(reason))
    }

    /// Not supported by the mock; always fails with [`TlogError::InvalidUrl`].
    async fn fetch_audit(
        &self,
        _agent_id: Uuid,
        _limit: Option<u32>,
        _offset: Option<u32>,
    ) -> Result<AuditResponse, TlogError> {
        let reason = "Mock client does not support fetch_audit".to_string();
        Err(TlogError::InvalidUrl(reason))
    }
}

/// Builds a fresh error equivalent to a stored one so the mock can return it
/// any number of times.
///
/// Every variant is rebuilt field by field, except `HttpError`, which is
/// reported as an `InvalidResponse` carrying the original error's message.
#[cfg(any(test, feature = "test-support"))]
fn replay_mock_error(stored: &TlogError) -> TlogError {
    match stored {
        TlogError::NotFound { url } => TlogError::NotFound { url: url.clone() },
        TlogError::ServiceUnavailable => TlogError::ServiceUnavailable,
        TlogError::InvalidResponse(msg) => TlogError::InvalidResponse(msg.clone()),
        TlogError::InvalidUrl(msg) => TlogError::InvalidUrl(msg.clone()),
        TlogError::HttpError(e) => TlogError::InvalidResponse(format!("HTTP error: {e}")),
        TlogError::InvalidHeader(msg) => TlogError::InvalidHeader(msg.clone()),
        TlogError::UntrustedDomain { domain, trusted } => TlogError::UntrustedDomain {
            domain: domain.clone(),
            trusted: trusted.clone(),
        },
    }
}
353
// Tests use unwrap/expect/panic freely: a failed expectation should abort the
// test immediately.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
#[cfg(test)]
mod tests {
    use super::*;
    use ans_types::*;
    use chrono::Utc;

    /// URL shared by the tests that register a canned error.
    const ERROR_URL: &str = "https://example.com/error";

    /// Builds an active V1 badge for `test.example.com` from a JSON fixture.
    fn create_test_badge() -> Badge {
        let attestations = serde_json::json!({
            "domainValidation": "ACME-DNS-01",
            "identityCert": {
                "fingerprint": "SHA256:aebdc9da0c20d6d5e4999a773839095ed050a9d7252bf212056fddc0c38f3496",
                "type": "X509-OV-CLIENT"
            },
            "serverCert": {
                "fingerprint": "SHA256:e7b64d16f42055d6faf382a43dc35b98be76aba0db145a904b590a034b33b904",
                "type": "X509-DV-SERVER"
            }
        });

        let event = serde_json::json!({
            "ansId": Uuid::new_v4().to_string(),
            "ansName": "ans://v1.0.0.test.example.com",
            "eventType": "AGENT_REGISTERED",
            "agent": { "host": "test.example.com", "name": "Test Agent", "version": "v1.0.0" },
            "attestations": attestations,
            "expiresAt": (Utc::now() + chrono::Duration::days(365)).to_rfc3339(),
            "issuedAt": Utc::now().to_rfc3339(),
            "raId": "test-ra",
            "timestamp": Utc::now().to_rfc3339()
        });

        let document = serde_json::json!({
            "status": "ACTIVE",
            "schemaVersion": "V1",
            "payload": {
                "logId": Uuid::new_v4().to_string(),
                "producer": {
                    "event": event,
                    "keyId": "test-key",
                    "signature": "test-sig"
                }
            }
        });

        serde_json::from_value(document).expect("test badge JSON should be valid")
    }

    /// Mock client that fails with `error` for [`ERROR_URL`].
    fn client_failing_with(error: TlogError) -> MockTransparencyLogClient {
        MockTransparencyLogClient::new().with_error(ERROR_URL, error)
    }

    #[tokio::test]
    async fn test_mock_client_fetch_badge() {
        let url = "https://example.com/v1/agents/test-id";
        let client = MockTransparencyLogClient::new().with_badge(url, create_test_badge());

        let fetched = client.fetch_badge(url).await.unwrap();

        assert_eq!(fetched.status, BadgeStatus::Active);
        assert_eq!(fetched.agent_host(), "test.example.com");
    }

    #[tokio::test]
    async fn test_mock_client_not_found() {
        let client = MockTransparencyLogClient::new();

        let outcome = client.fetch_badge("https://example.com/not-found").await;

        assert!(matches!(outcome, Err(TlogError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_mock_client_error() {
        let client = client_failing_with(TlogError::ServiceUnavailable);

        let outcome = client.fetch_badge(ERROR_URL).await;

        assert!(matches!(outcome, Err(TlogError::ServiceUnavailable)));
    }

    #[tokio::test]
    async fn test_mock_client_error_not_found() {
        let client = client_failing_with(TlogError::NotFound {
            url: ERROR_URL.to_string(),
        });

        let outcome = client.fetch_badge(ERROR_URL).await;

        assert!(matches!(outcome, Err(TlogError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_mock_client_error_invalid_response() {
        let client = client_failing_with(TlogError::InvalidResponse("Bad JSON".to_string()));

        let outcome = client.fetch_badge(ERROR_URL).await;

        assert!(matches!(outcome, Err(TlogError::InvalidResponse(_))));
    }

    #[tokio::test]
    async fn test_mock_client_fetch_badge_by_id_not_supported() {
        let client = MockTransparencyLogClient::new();

        let outcome = client.fetch_badge_by_id(Uuid::new_v4()).await;

        assert!(matches!(outcome, Err(TlogError::InvalidUrl(_))));
    }

    #[tokio::test]
    async fn test_mock_client_fetch_audit_not_supported() {
        let client = MockTransparencyLogClient::new();

        let outcome = client.fetch_audit(Uuid::new_v4(), None, None).await;

        assert!(matches!(outcome, Err(TlogError::InvalidUrl(_))));
    }

    #[test]
    fn test_http_client_new() {
        assert!(HttpTransparencyLogClient::new().base_url.is_none());
    }

    #[test]
    fn test_http_client_default() {
        let client = HttpTransparencyLogClient::default();
        assert!(client.base_url.is_none());
    }

    #[test]
    fn test_http_client_with_base_url() {
        let client =
            HttpTransparencyLogClient::with_base_url("https://transparency.example.com/").unwrap();

        let base = client
            .base_url
            .as_ref()
            .expect("base URL should be configured");
        assert_eq!(base.as_str(), "https://transparency.example.com/");
    }

    #[test]
    fn test_http_client_with_timeout() {
        let one_minute = std::time::Duration::from_secs(60);
        let client = HttpTransparencyLogClient::new().with_timeout(one_minute);
        assert_eq!(client.timeout, one_minute);
    }

    #[test]
    fn test_audit_response_serialization() {
        let original = AuditResponse {
            records: vec![create_test_badge()],
        };

        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains("records"));

        let decoded: AuditResponse = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.records.len(), 1);
    }

    #[test]
    fn test_with_header() {
        let client = HttpTransparencyLogClient::new().with_header("X-Custom", "value1");

        assert_eq!(client.extra_headers.len(), 1);
        let (name, value) = &client.extra_headers[0];
        assert_eq!(name.as_str(), "X-Custom");
        assert_eq!(value.as_str(), "value1");
    }

    #[test]
    fn test_with_headers() {
        let pairs = [("X-One", "1"), ("X-Two", "2")];
        let client = HttpTransparencyLogClient::new().with_headers(pairs);
        assert_eq!(client.extra_headers.len(), 2);
    }

    #[test]
    fn test_build_headers_valid() {
        let client = HttpTransparencyLogClient::new()
            .with_header("X-Api-Key", "abc123")
            .with_header("Authorization", "Bearer token");

        let headers = client.build_headers().unwrap();

        assert_eq!(headers.len(), 2);
        assert_eq!(headers.get("X-Api-Key").unwrap(), "abc123");
    }

    #[test]
    fn test_build_headers_invalid_name() {
        let client = HttpTransparencyLogClient::new().with_header("invalid header\nname", "value");

        assert!(matches!(
            client.build_headers(),
            Err(TlogError::InvalidHeader(_))
        ));
    }

    #[test]
    fn test_build_headers_invalid_value() {
        let client = HttpTransparencyLogClient::new().with_header("X-Custom", "val\x00ue");

        assert!(matches!(
            client.build_headers(),
            Err(TlogError::InvalidHeader(_))
        ));
    }

    #[test]
    fn test_build_headers_empty() {
        let headers = HttpTransparencyLogClient::new().build_headers().unwrap();
        assert!(headers.is_empty());
    }

    #[test]
    fn test_with_base_url_invalid() {
        let outcome = HttpTransparencyLogClient::with_base_url("not a url ://");

        assert!(matches!(outcome, Err(TlogError::InvalidUrl(_))));
    }

    #[test]
    fn test_debug_format() {
        let rendered = format!("{:?}", HttpTransparencyLogClient::new());
        assert!(rendered.contains("HttpTransparencyLogClient"));
    }
}
578}