1use std::time::Duration;
2
3use atd_protocol::AtdError;
4#[cfg(test)]
5use tokio::io::{AsyncRead, AsyncWrite};
6use tokio::net::UnixStream;
7use tokio::sync::Mutex;
8
9use crate::ConnectOptions;
10use crate::endpoint::Endpoint;
11use atd_protocol::wire::{read_frame, write_frame};
12use atd_protocol::{Request, Response};
13
14pub struct AtdClient {
19 inner: Mutex<Pipe>,
20}
21
22enum Pipe {
23 Unix {
24 read: tokio::net::unix::OwnedReadHalf,
25 write: tokio::net::unix::OwnedWriteHalf,
26 },
27 #[cfg(test)]
29 Duplex {
30 read: Box<dyn AsyncRead + Send + Unpin>,
31 write: Box<dyn AsyncWrite + Send + Unpin>,
32 },
33}
34
35impl AtdClient {
36 pub async fn connect(endpoint: Endpoint) -> Result<Self, AtdError> {
40 Self::connect_with_options(endpoint, ConnectOptions::default()).await
41 }
42
43 pub async fn connect_with_options(
54 endpoint: Endpoint,
55 opts: ConnectOptions,
56 ) -> Result<Self, AtdError> {
57 let mut delay_ms = opts.backoff_base_ms;
58 let mut last_err: Option<AtdError> = None;
59 for attempt in 0..opts.max_attempts {
60 let attempt_fut = Self::connect_once(&endpoint);
61 let result =
62 tokio::time::timeout(Duration::from_millis(opts.connect_timeout_ms), attempt_fut)
63 .await
64 .unwrap_or_else(|_| {
65 Err(AtdError::ServerUnreachable(std::io::Error::new(
66 std::io::ErrorKind::TimedOut,
67 format!(
68 "connect attempt timed out after {}ms",
69 opts.connect_timeout_ms
70 ),
71 )))
72 });
73 match result {
74 Ok(client) => return Ok(client),
75 Err(e) if is_fatal_connect_error(&e) => return Err(e),
76 Err(e) => {
77 last_err = Some(e);
78 if attempt + 1 < opts.max_attempts {
79 let jitter_pct = jitter_factor(); let wait_ms = (delay_ms as f64 * (1.0 + jitter_pct)).max(1.0) as u64;
81 tokio::time::sleep(Duration::from_millis(wait_ms)).await;
82 delay_ms = (delay_ms.saturating_mul(2)).min(opts.backoff_cap_ms);
83 }
84 }
85 }
86 }
87 Err(last_err.expect("loop runs at least once"))
88 }
89
90 async fn connect_once(endpoint: &Endpoint) -> Result<Self, AtdError> {
91 match endpoint {
92 Endpoint::UnixSocket(path) => {
93 let stream = UnixStream::connect(path).await?;
94 let (read, write) = stream.into_split();
95 let client = AtdClient {
96 inner: Mutex::new(Pipe::Unix { read, write }),
97 };
98 client.ping().await?;
99 Ok(client)
100 }
101 }
102 }
103
104 #[cfg(test)]
105 pub(crate) fn from_duplex<R, W>(read: R, write: W) -> Self
106 where
107 R: AsyncRead + Send + Unpin + 'static,
108 W: AsyncWrite + Send + Unpin + 'static,
109 {
110 AtdClient {
111 inner: Mutex::new(Pipe::Duplex {
112 read: Box::new(read),
113 write: Box::new(write),
114 }),
115 }
116 }
117
118 pub async fn ping(&self) -> Result<(), AtdError> {
119 match self.request(&Request::Ping).await? {
120 Response::Pong => Ok(()),
121 other => Err(AtdError::ProtocolError {
122 expected: "pong".into(),
123 got: format!("{other:?}"),
124 }),
125 }
126 }
127
128 pub async fn hello(
138 &self,
139 client_id: Option<&str>,
140 requested: Vec<String>,
141 ) -> Result<Vec<String>, AtdError> {
142 self.hello_with_ucan_tokens(client_id, requested, Vec::new())
143 .await
144 }
145
146 pub async fn hello_with_ucan_tokens(
161 &self,
162 client_id: Option<&str>,
163 requested: Vec<String>,
164 ucan_tokens: Vec<String>,
165 ) -> Result<Vec<String>, AtdError> {
166 let presenting_ucan = !ucan_tokens.is_empty();
167 let req = Request::Hello {
168 client_id: client_id.map(|s| s.to_string()),
169 requested_capabilities: requested,
170 ucan_tokens,
171 };
172 match self.request(&req).await {
173 Ok(Response::HelloAck {
174 granted_capabilities,
175 ..
176 }) => Ok(granted_capabilities),
177 Ok(Response::Error { message, code, .. }) if presenting_ucan => {
178 Err(AtdError::ProtocolError {
185 expected: "hello_ack with verified UCAN".into(),
186 got: format!("server error code={code:?} message={message}"),
187 })
188 }
189 Ok(Response::Error { .. }) => Ok(vec![]),
194 Err(AtdError::ProtocolError { .. }) if !presenting_ucan => Ok(vec![]),
195 Err(AtdError::ProtocolError { .. }) => Err(AtdError::ProtocolError {
196 expected: "hello_ack with verified UCAN".into(),
197 got: "protocol error".into(),
198 }),
199 Ok(other) => Err(AtdError::ProtocolError {
200 expected: "hello_ack".into(),
201 got: format!("{other:?}"),
202 }),
203 Err(e) => Err(e),
204 }
205 }
206
207 pub(crate) async fn request(&self, req: &Request) -> Result<Response, AtdError> {
208 let mut guard = self.inner.lock().await;
209 match &mut *guard {
210 Pipe::Unix { read, write } => {
211 write_frame(write, req).await?;
212 let resp: Response = read_frame(read).await?;
213 Ok(resp)
214 }
215 #[cfg(test)]
216 Pipe::Duplex { read, write } => {
217 write_frame(write, req).await?;
218 let resp: Response = read_frame(read).await?;
219 Ok(resp)
220 }
221 }
222 }
223
224 pub async fn discover(
225 &self,
226 query: Option<&str>,
227 filter: crate::options::DiscoverFilter,
228 ) -> Result<Vec<atd_protocol::ToolSummary>, AtdError> {
229 let resp = self.request(&Request::ToolList).await?;
230 let raw = match resp {
231 Response::ToolListResponse { tools } => tools,
232 Response::Error { message, .. } => {
233 return Err(AtdError::ProtocolError {
234 expected: "tool_list".into(),
235 got: format!("error: {message}"),
236 });
237 }
238 other => {
239 return Err(AtdError::ProtocolError {
240 expected: "tool_list".into(),
241 got: format!("{other:?}"),
242 });
243 }
244 };
245
246 let arr = raw.as_array().ok_or_else(|| AtdError::ProtocolError {
247 expected: "array of tool summaries".into(),
248 got: format!("{raw}"),
249 })?;
250
251 let mut out: Vec<atd_protocol::ToolSummary> = Vec::with_capacity(arr.len());
252 for v in arr {
253 match serde_json::from_value::<atd_protocol::ToolSummary>(v.clone()) {
254 Ok(s) => out.push(s),
255 Err(_) => {
256 if let Ok(def) =
258 serde_json::from_value::<atd_protocol::ToolDefinition>(v.clone())
259 {
260 out.push(atd_protocol::ToolSummary::from(&def));
261 }
262 }
263 }
264 }
265
266 for s in &mut out {
268 if s.name.is_empty() {
269 s.name = derive_name(s);
270 }
271 if s.domain.is_empty() {
272 s.domain = derive_domain(&s.id);
273 }
274 }
275
276 if let Some(q) = query {
277 let q_lower = q.to_lowercase();
278 out.retain(|s| {
279 s.name.to_lowercase().contains(&q_lower)
280 || s.description.to_lowercase().contains(&q_lower)
281 || s.id.to_lowercase().contains(&q_lower)
282 });
283 }
284 if let Some(d) = filter.domain.as_deref() {
285 out.retain(|s| s.domain == d);
286 }
287 if let Some(v) = filter.visibility {
288 out.retain(|s| s.visibility == v);
289 }
290 if let Some(t) = filter.tier {
291 out.retain(|s| s.tier == t);
292 }
293 if let Some(n) = filter.limit {
294 out.truncate(n);
295 }
296
297 Ok(out)
298 }
299
300 pub async fn describe(&self, tool_id: &str) -> Result<atd_protocol::ToolDefinition, AtdError> {
301 let resp = self
302 .request(&Request::ToolSchema {
303 tool_id: tool_id.to_string(),
304 })
305 .await?;
306
307 match resp {
308 Response::ToolSchemaResponse { schema } => {
309 serde_json::from_value(schema).map_err(|e| AtdError::ProtocolError {
310 expected: "ToolDefinition".into(),
311 got: format!("deserialize error: {e}"),
312 })
313 }
314 Response::Error { message, .. } if message.to_lowercase().contains("not found") => {
315 Err(AtdError::ToolNotFound {
316 tool_id: tool_id.to_string(),
317 suggestions: vec![],
318 })
319 }
320 Response::Error { message, .. } => Err(AtdError::ProtocolError {
321 expected: "tool_schema".into(),
322 got: format!("error: {message}"),
323 }),
324 other => Err(AtdError::ProtocolError {
325 expected: "tool_schema".into(),
326 got: format!("{other:?}"),
327 }),
328 }
329 }
330
331 pub async fn call_page(
341 &self,
342 tool_id: &str,
343 args: serde_json::Value,
344 cursor: Option<&str>,
345 opts: crate::options::CallOptions,
346 ) -> Result<crate::options::PaginatedSdkResult, AtdError> {
347 let req = match cursor {
348 None => Request::RunTool {
349 tool_id: tool_id.to_string(),
350 args,
351 dry_run: opts.dry_run,
352 },
353 Some(c) => Request::RunToolContinue {
354 tool_id: tool_id.to_string(),
355 cursor: c.to_string(),
356 },
357 };
358 let resp = self.request(&req).await?;
359 match resp {
360 Response::ToolResultResponse {
361 result,
362 success,
363 next_cursor,
364 ..
365 } => {
366 if success {
367 Ok(crate::options::PaginatedSdkResult {
368 value: result,
369 next_cursor,
370 })
371 } else {
372 let (code, message, retryable) = extract_error(&result);
373 Err(AtdError::ToolExecutionFailed {
374 tool_id: tool_id.to_string(),
375 inner: Box::new(std::io::Error::other(format!(
376 "{code} {message} (retryable={retryable})"
377 ))),
378 })
379 }
380 }
381 Response::Error {
382 message,
383 code,
384 retryable,
385 ..
386 } => Err(AtdError::ToolExecutionFailed {
387 tool_id: tool_id.to_string(),
388 inner: Box::new(std::io::Error::other(format!(
389 "server error code={code:?} retryable={retryable:?}: {message}"
390 ))),
391 }),
392 other => Err(AtdError::ProtocolError {
393 expected: "tool_result".into(),
394 got: format!("{other:?}"),
395 }),
396 }
397 }
398
399 pub async fn call_all(
403 &self,
404 tool_id: &str,
405 args: serde_json::Value,
406 opts: crate::options::CallAllOptions,
407 ) -> Result<serde_json::Value, AtdError> {
408 let mut accumulated: Option<serde_json::Value> = None;
409 let mut bytes_total: usize = 0;
410 let mut cursor: Option<String> = None;
411 for page_idx in 0..opts.max_pages {
412 let page_args = if page_idx == 0 {
413 args.clone()
414 } else {
415 serde_json::Value::Null
416 };
417 let page = self
418 .call_page(
419 tool_id,
420 page_args,
421 cursor.as_deref(),
422 crate::options::CallOptions::default(),
423 )
424 .await?;
425 let page_bytes = serde_json::to_vec(&page.value)
426 .map(|v| v.len())
427 .unwrap_or(0);
428 bytes_total += page_bytes;
429 if bytes_total > opts.max_total_bytes {
430 return Err(AtdError::PaginationLimitExceeded {
431 pages_fetched: page_idx + 1,
432 bytes_fetched: bytes_total,
433 });
434 }
435 accumulated = Some(merge_pages(accumulated, page.value, &opts.merge_policy)?);
436 match page.next_cursor {
437 Some(c) => cursor = Some(c),
438 None => return Ok(accumulated.unwrap_or(serde_json::Value::Null)),
439 }
440 }
441 Err(AtdError::PaginationLimitExceeded {
442 pages_fetched: opts.max_pages,
443 bytes_fetched: bytes_total,
444 })
445 }
446
447 pub async fn call(
448 &self,
449 tool_id: &str,
450 args: serde_json::Value,
451 opts: crate::options::CallOptions,
452 ) -> Result<atd_protocol::ToolResult, AtdError> {
453 let resp = self
454 .request(&Request::RunTool {
455 tool_id: tool_id.to_string(),
456 args,
457 dry_run: opts.dry_run,
458 })
459 .await?;
460
461 match resp {
462 Response::ToolResultResponse {
463 tool_id: resp_tool_id,
464 result,
465 success,
466 dry_run: _,
467 next_cursor: _,
468 } => {
469 if success {
470 Ok(atd_protocol::ToolResult::Success {
476 data: result,
477 metadata: atd_protocol::ToolResultMetadata::for_tool(resp_tool_id),
478 })
479 } else {
480 let (code, message, retryable) = extract_error(&result);
481 let reason = serde_json::to_string(&result).ok();
487 Ok(atd_protocol::ToolResult::Error {
488 code,
489 message,
490 reason,
491 retryable,
492 })
493 }
494 }
495 Response::Error {
500 message: _,
501 code: Some(code),
502 details,
503 ..
504 } if code == atd_protocol::ERR_CAPABILITY_DENIED => {
505 let (required, granted) = extract_cap_denied_sets(details.as_ref());
506 Err(AtdError::CapabilityDenied {
507 tool_id: tool_id.to_string(),
508 required,
509 granted,
510 })
511 }
512 Response::Error {
513 message, retryable, ..
514 } => Err(AtdError::ToolExecutionFailed {
515 tool_id: tool_id.to_string(),
516 inner: Box::new(std::io::Error::other(format!(
517 "{message} (retryable={})",
518 retryable.unwrap_or(false)
519 ))),
520 }),
521 other => Err(AtdError::ProtocolError {
522 expected: "tool_result".into(),
523 got: format!("{other:?}"),
524 }),
525 }
526 }
527}
528
529fn derive_name(s: &atd_protocol::ToolSummary) -> String {
532 if !s.name.is_empty() {
533 s.name.clone()
534 } else if !s.description.is_empty() {
535 s.description.clone()
536 } else {
537 s.id.clone()
538 }
539}
540
/// Extracts the domain segment from a tool id shaped like `namespace:domain.action`.
///
/// Returns the text between the first `:` and the following `.` (or the end of the id when
/// there is no `.`). Ids without a `:` produce an empty string.
fn derive_domain(id: &str) -> String {
    let Some((_, tail)) = id.split_once(':') else {
        return String::new();
    };
    let domain = tail.find('.').map_or(tail, |dot| &tail[..dot]);
    domain.to_owned()
}
549
550fn extract_cap_denied_sets(details: Option<&serde_json::Value>) -> (Vec<String>, Vec<String>) {
555 let Some(d) = details else {
556 return (vec![], vec![]);
557 };
558 let to_vec = |v: &serde_json::Value| -> Vec<String> {
559 v.as_array()
560 .map(|arr| {
561 arr.iter()
562 .filter_map(|x| x.as_str().map(str::to_string))
563 .collect()
564 })
565 .unwrap_or_default()
566 };
567 let required = d.get("required").map(to_vec).unwrap_or_default();
568 let granted = d.get("granted").map(to_vec).unwrap_or_default();
569 (required, granted)
570}
571
572fn extract_error(value: &serde_json::Value) -> (String, String, bool) {
573 let code = value
574 .get("code")
575 .and_then(|v| v.as_str())
576 .unwrap_or("UNKNOWN")
577 .to_string();
578 let message = value
579 .get("message")
580 .and_then(|v| v.as_str())
581 .unwrap_or("tool call failed")
582 .to_string();
583 let retryable = value
584 .get("retryable")
585 .and_then(|v| v.as_bool())
586 .unwrap_or(false);
587 (code, message, retryable)
588}
589
590fn merge_pages(
593 accumulated: Option<serde_json::Value>,
594 page: serde_json::Value,
595 policy: &crate::options::MergePolicy,
596) -> Result<serde_json::Value, AtdError> {
597 use crate::options::MergePolicy;
598 match (accumulated, policy) {
599 (None, _) => Ok(page),
601 (Some(acc), MergePolicy::FirstPageOnly) => {
602 let _ = page;
607 Ok(acc)
608 }
609 (Some(acc), MergePolicy::ConcatArray) => match (acc, page) {
610 (serde_json::Value::Array(mut a), serde_json::Value::Array(b)) => {
611 a.extend(b);
612 Ok(serde_json::Value::Array(a))
613 }
614 _ => Err(AtdError::MergeFailed {
615 reason: "ConcatArray requires every page to be a JSON array".into(),
616 }),
617 },
618 (Some(acc), MergePolicy::ConcatField(field)) => {
619 let acc_obj = match acc {
620 serde_json::Value::Object(m) => m,
621 _ => {
622 return Err(AtdError::MergeFailed {
623 reason: format!(
624 "ConcatField({field}) requires every page to be a JSON object"
625 ),
626 });
627 }
628 };
629 let mut page_obj = match page {
630 serde_json::Value::Object(m) => m,
631 _ => {
632 return Err(AtdError::MergeFailed {
633 reason: format!("ConcatField({field}) page is not a JSON object"),
634 });
635 }
636 };
637 let acc_arr =
638 acc_obj
639 .get(field.as_str())
640 .cloned()
641 .ok_or_else(|| AtdError::MergeFailed {
642 reason: format!("ConcatField({field}): field missing in accumulator"),
643 })?;
644 let page_arr =
645 page_obj
646 .get(field.as_str())
647 .cloned()
648 .ok_or_else(|| AtdError::MergeFailed {
649 reason: format!("ConcatField({field}): field missing in page"),
650 })?;
651 let combined = match (acc_arr, page_arr) {
652 (serde_json::Value::Array(mut a), serde_json::Value::Array(b)) => {
653 a.extend(b);
654 serde_json::Value::Array(a)
655 }
656 _ => {
657 return Err(AtdError::MergeFailed {
658 reason: format!("ConcatField({field}) is not an array"),
659 });
660 }
661 };
662 page_obj.insert(field.clone(), combined);
665 Ok(serde_json::Value::Object(page_obj))
666 }
667 }
668}
669
670fn is_fatal_connect_error(err: &AtdError) -> bool {
674 matches!(
675 err,
676 AtdError::ServerUnreachable(io) if matches!(
677 io.kind(),
678 std::io::ErrorKind::NotFound | std::io::ErrorKind::PermissionDenied
679 )
680 )
681}
682
/// Returns a pseudo-random jitter fraction in `[-0.2, 0.2)` for scaling a backoff delay.
///
/// The sample is produced by hashing the clock's sub-second nanoseconds with a freshly
/// created, randomly keyed `RandomState`. Using `subsec_nanos() % 1000` directly would
/// depend on the low three decimal digits of the clock, which carry little or no
/// variation when the wall clock is coarser than one nanosecond. No cryptographic
/// quality is intended.
fn jitter_factor() -> f64 {
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.subsec_nanos())
        .unwrap_or(0);
    // Fully qualified trait calls keep this block free of extra file-level imports.
    let state = std::collections::hash_map::RandomState::new();
    let mut hasher = std::hash::BuildHasher::build_hasher(&state);
    std::hash::Hasher::write_u32(&mut hasher, nanos);
    let sample = std::hash::Hasher::finish(&hasher);
    // Map 0..=999 onto [-0.5, 0.5) and scale to ±20%.
    ((sample % 1000) as f64 / 1000.0 - 0.5) * 0.4
}
694
695#[cfg(test)]
696mod tests {
697 use super::*;
698 use tokio::io::duplex;
699
700 async fn spin_server<F>(server_end: tokio::io::DuplexStream, mut handler: F)
703 where
704 F: FnMut(Request) -> Response + Send + 'static,
705 {
706 let (mut read, mut write) = tokio::io::split(server_end);
707 tokio::spawn(async move {
708 while let Ok(req) = read_frame::<_, Request>(&mut read).await {
709 let resp = handler(req);
710 if write_frame(&mut write, &resp).await.is_err() {
711 break;
712 }
713 }
714 });
715 }
716
717 #[tokio::test]
718 async fn ping_returns_ok_when_server_sends_pong() {
719 let (client_end, server_end) = duplex(4096);
720 spin_server(server_end, |req| match req {
721 Request::Ping => Response::Pong,
722 _ => Response::Error {
723 message: "unexpected".into(),
724 code: None,
725 retryable: None,
726 details: None,
727 },
728 })
729 .await;
730
731 let (cr, cw) = tokio::io::split(client_end);
732 let client = AtdClient::from_duplex(cr, cw);
733 client.ping().await.unwrap();
734 }
735
736 #[tokio::test]
737 async fn ping_errors_when_server_sends_wrong_response() {
738 let (client_end, server_end) = duplex(4096);
739 spin_server(server_end, |_| Response::ToolListResponse {
740 tools: serde_json::json!([]),
741 })
742 .await;
743
744 let (cr, cw) = tokio::io::split(client_end);
745 let client = AtdClient::from_duplex(cr, cw);
746 let err = client.ping().await.unwrap_err();
747 assert!(matches!(err, AtdError::ProtocolError { .. }));
748 }
749
750 #[tokio::test]
751 async fn discover_projects_tool_definitions_to_summaries() {
752 let (client_end, server_end) = duplex(16_384);
753 spin_server(server_end, |req| match req {
754 Request::ToolList => Response::ToolListResponse {
755 tools: serde_json::json!([
756 {
757 "id": "anos:fs.read",
758 "name": "Read",
759 "description": "read a file",
760 "version": "0.1.0",
761 "capability": {
762 "domain": "fs",
763 "actions": ["read"],
764 "tags": ["filesystem"],
765 "intent_examples": []
766 },
767 "input_schema": {},
768 "output_schema": {},
769 "bindings": [{"protocol": "Cli", "config": {}}],
770 "safety": {"level": "Read", "dry_run": false, "side_effects": [], "data_sensitivity": null},
771 "resources": {"timeout_ms": 1000, "max_concurrent": 1, "rate_limit_per_min": null, "estimated_tokens": null},
772 "trust": {"publisher": "anos", "trust_level": "L2Tested", "signature": null},
773 "visibility": "read"
774 }
775 ]),
776 },
777 _ => unreachable!(),
778 })
779 .await;
780
781 let (cr, cw) = tokio::io::split(client_end);
782 let client = AtdClient::from_duplex(cr, cw);
783 let summaries = client
784 .discover(None, crate::options::DiscoverFilter::default())
785 .await
786 .unwrap();
787 assert_eq!(summaries.len(), 1);
788 assert_eq!(summaries[0].id, "anos:fs.read");
789 assert_eq!(summaries[0].domain, "fs");
790 }
791
792 #[tokio::test]
793 async fn discover_applies_query_and_limit_client_side() {
794 let (client_end, server_end) = duplex(16_384);
795 spin_server(server_end, |_| Response::ToolListResponse {
796 tools: serde_json::json!([
797 {"id": "anos:fs.read", "name": "Read", "description": "read a file", "domain": "fs", "tags": []},
798 {"id": "anos:fs.write", "name": "Write", "description": "write a file", "domain": "fs", "tags": []},
799 {"id": "anos:web.fetch", "name": "Fetch", "description": "download a url", "domain": "web", "tags": []}
800 ]),
801 })
802 .await;
803
804 let (cr, cw) = tokio::io::split(client_end);
805 let client = AtdClient::from_duplex(cr, cw);
806
807 let only_fs = client
808 .discover(
809 Some("fs"),
810 crate::options::DiscoverFilter {
811 limit: Some(1),
812 ..Default::default()
813 },
814 )
815 .await
816 .unwrap();
817 assert_eq!(only_fs.len(), 1);
818 assert!(only_fs[0].id.starts_with("anos:fs"));
819 }
820
821 fn tool_def_json() -> serde_json::Value {
822 serde_json::json!({
823 "id": "anos:fs.read",
824 "name": "Read",
825 "description": "read a file",
826 "version": "0.1.0",
827 "capability": {
828 "domain": "fs", "actions": ["read"], "tags": [], "intent_examples": []
829 },
830 "input_schema": {"type": "object"},
831 "output_schema": {"type": "string"},
832 "bindings": [{"protocol": "Cli", "config": {}}],
833 "safety": {"level": "Read", "dry_run": false, "side_effects": [], "data_sensitivity": null},
834 "resources": {"timeout_ms": 1000, "max_concurrent": 1, "rate_limit_per_min": null, "estimated_tokens": null},
835 "trust": {"publisher": "anos", "trust_level": "L2Tested", "signature": null},
836 "visibility": "read"
837 })
838 }
839
840 #[tokio::test]
841 async fn describe_returns_full_tool_definition() {
842 let (client_end, server_end) = duplex(16_384);
843 spin_server(server_end, |req| match req {
844 Request::ToolSchema { tool_id } => {
845 assert_eq!(tool_id, "anos:fs.read");
846 Response::ToolSchemaResponse {
847 schema: tool_def_json(),
848 }
849 }
850 _ => unreachable!(),
851 })
852 .await;
853
854 let (cr, cw) = tokio::io::split(client_end);
855 let client = AtdClient::from_duplex(cr, cw);
856 let def = client.describe("anos:fs.read").await.unwrap();
857 assert_eq!(def.id, "anos:fs.read");
858 assert_eq!(def.capability.domain, "fs");
859 }
860
861 #[tokio::test]
862 async fn describe_maps_not_found_error_to_tool_not_found() {
863 let (client_end, server_end) = duplex(4096);
864 spin_server(server_end, |_| Response::Error {
865 message: "tool not found: anos:nope".into(),
866 code: None,
867 retryable: None,
868 details: None,
869 })
870 .await;
871
872 let (cr, cw) = tokio::io::split(client_end);
873 let client = AtdClient::from_duplex(cr, cw);
874 let err = client.describe("anos:nope").await.unwrap_err();
875 assert!(matches!(err, AtdError::ToolNotFound { .. }));
876 }
877
878 #[tokio::test]
879 async fn call_success_returns_tool_result_success() {
880 let (client_end, server_end) = duplex(16_384);
881 spin_server(server_end, |req| match req {
882 Request::RunTool {
883 tool_id,
884 args,
885 dry_run,
886 } => {
887 assert_eq!(tool_id, "anos:fs.read");
888 assert_eq!(args["path"], "/tmp/x");
889 assert!(!dry_run);
890 Response::ToolResultResponse {
891 tool_id,
892 result: serde_json::json!({"content": "ok"}),
893 success: true,
894 dry_run: false,
895 next_cursor: None,
896 }
897 }
898 _ => unreachable!(),
899 })
900 .await;
901
902 let (cr, cw) = tokio::io::split(client_end);
903 let client = AtdClient::from_duplex(cr, cw);
904 let r = client
905 .call(
906 "anos:fs.read",
907 serde_json::json!({"path": "/tmp/x"}),
908 crate::options::CallOptions::default(),
909 )
910 .await
911 .unwrap();
912 assert!(r.is_success());
913 assert_eq!(r.data().unwrap()["content"], "ok");
914 }
915
916 #[tokio::test]
917 async fn call_failure_returns_tool_result_error() {
918 let (client_end, server_end) = duplex(4096);
919 spin_server(server_end, |_| Response::ToolResultResponse {
920 tool_id: "anos:fs.read".into(),
921 result: serde_json::json!({"code": "EPERM", "message": "no", "retryable": false}),
922 success: false,
923 dry_run: false,
924 next_cursor: None,
925 })
926 .await;
927
928 let (cr, cw) = tokio::io::split(client_end);
929 let client = AtdClient::from_duplex(cr, cw);
930 let r = client
931 .call(
932 "anos:fs.read",
933 serde_json::json!({}),
934 crate::options::CallOptions::default(),
935 )
936 .await
937 .unwrap();
938 match r {
939 atd_protocol::ToolResult::Error { code, .. } => assert_eq!(code, "EPERM"),
940 _ => panic!("expected error variant"),
941 }
942 }
943
944 #[tokio::test]
945 async fn call_failure_preserves_raw_payload_in_reason() {
946 let (client_end, server_end) = duplex(4096);
947 spin_server(server_end, |_| Response::ToolResultResponse {
948 tool_id: "anos:fs.read".into(),
949 result: serde_json::json!({"unexpected": {"nested": [1, 2, 3]}, "hint": "quota exceeded"}),
952 success: false,
953 dry_run: false,
954 next_cursor: None,
955 })
956 .await;
957
958 let (cr, cw) = tokio::io::split(client_end);
959 let client = AtdClient::from_duplex(cr, cw);
960 let r = client
961 .call(
962 "anos:fs.read",
963 serde_json::json!({}),
964 crate::options::CallOptions::default(),
965 )
966 .await
967 .unwrap();
968 match r {
969 atd_protocol::ToolResult::Error {
970 code,
971 message,
972 reason,
973 retryable,
974 } => {
975 assert_eq!(code, "UNKNOWN"); assert_eq!(message, "tool call failed");
977 assert!(!retryable);
978 let reason = reason.expect("reason must carry the raw payload");
979 assert!(
980 reason.contains("\"quota exceeded\""),
981 "reason should preserve hint, got: {reason}"
982 );
983 assert!(
984 reason.contains("\"unexpected\""),
985 "reason should preserve unknown keys, got: {reason}"
986 );
987 }
988 _ => panic!("expected error variant"),
989 }
990 }
991
992 #[tokio::test]
993 async fn call_forwards_dry_run_flag() {
994 let (client_end, server_end) = duplex(4096);
995 spin_server(server_end, |req| match req {
996 Request::RunTool { dry_run, .. } => {
997 assert!(dry_run);
998 Response::ToolResultResponse {
999 tool_id: "anos:fs.read".into(),
1000 result: serde_json::json!({}),
1001 success: true,
1002 dry_run: true,
1003 next_cursor: None,
1004 }
1005 }
1006 _ => unreachable!(),
1007 })
1008 .await;
1009
1010 let (cr, cw) = tokio::io::split(client_end);
1011 let client = AtdClient::from_duplex(cr, cw);
1012 client
1013 .call(
1014 "anos:fs.read",
1015 serde_json::json!({}),
1016 crate::options::CallOptions {
1017 dry_run: true,
1018 preferred_binding: None,
1019 },
1020 )
1021 .await
1022 .unwrap();
1023 }
1024
1025 #[tokio::test]
1026 async fn discover_fills_name_and_domain_from_id_when_missing() {
1027 let (client_end, server_end) = duplex(16_384);
1028 spin_server(server_end, |_| Response::ToolListResponse {
1029 tools: serde_json::json!([
1030 {"id":"anos:fs.read","description":"File Read","tier":"hot","visibility":"read","lifecycle":"Active"},
1031 {"id":"anos:web.search","description":"Web Search","tier":"hot","visibility":"read"},
1032 {"id":"host:media.convert","description":"","tier":"warm","visibility":"dangerous"}
1033 ]),
1034 })
1035 .await;
1036
1037 let (cr, cw) = tokio::io::split(client_end);
1038 let client = AtdClient::from_duplex(cr, cw);
1039 let summaries = client
1040 .discover(None, crate::options::DiscoverFilter::default())
1041 .await
1042 .unwrap();
1043 assert_eq!(summaries.len(), 3);
1044
1045 assert_eq!(summaries[0].id, "anos:fs.read");
1047 assert_eq!(summaries[0].name, "File Read");
1048 assert_eq!(summaries[0].domain, "fs");
1049
1050 assert_eq!(summaries[1].domain, "web");
1052
1053 assert_eq!(summaries[2].domain, "media");
1055 assert_eq!(summaries[2].name, "host:media.convert");
1056 }
1057
1058 #[tokio::test]
1061 async fn hello_returns_granted_subset_from_server() {
1062 let (client_end, server_end) = duplex(4096);
1063 spin_server(server_end, |req| match req {
1064 Request::Hello {
1065 client_id,
1066 requested_capabilities,
1067 ..
1068 } => {
1069 assert_eq!(client_id.as_deref(), Some("test"));
1070 assert_eq!(requested_capabilities, vec!["exec", "admin"]);
1071 Response::HelloAck {
1072 granted_capabilities: vec!["exec".into()],
1073 server_version: "atd-ref-server 0.2.0".into(),
1074 supported_tiers: vec!["hot".into(), "warm".into(), "cold".into()],
1075 }
1076 }
1077 _ => unreachable!(),
1078 })
1079 .await;
1080 let (cr, cw) = tokio::io::split(client_end);
1081 let client = AtdClient::from_duplex(cr, cw);
1082 let granted = client
1083 .hello(Some("test"), vec!["exec".into(), "admin".into()])
1084 .await
1085 .unwrap();
1086 assert_eq!(granted, vec!["exec"]);
1087 }
1088
1089 #[tokio::test]
1090 async fn hello_degrades_to_empty_caps_on_pre_sp12_server_error() {
1091 let (client_end, server_end) = duplex(4096);
1092 spin_server(server_end, |req| match req {
1093 Request::Hello { .. } => Response::Error {
1094 message: "unknown request".into(),
1095 code: None,
1096 retryable: None,
1097 details: None,
1098 },
1099 _ => unreachable!(),
1100 })
1101 .await;
1102 let (cr, cw) = tokio::io::split(client_end);
1103 let client = AtdClient::from_duplex(cr, cw);
1104 let granted = client.hello(None, vec!["exec".into()]).await.unwrap();
1105 assert!(granted.is_empty(), "pre-SP-12 server → empty grant");
1106 }
1107
1108 #[tokio::test]
1109 async fn call_surfaces_capability_denied_with_both_sets() {
1110 let (client_end, server_end) = duplex(4096);
1111 spin_server(server_end, |req| match req {
1112 Request::RunTool { .. } => Response::Error {
1113 message: "capability denied for ref:x: missing [\"exec\"]".into(),
1114 code: Some(atd_protocol::ERR_CAPABILITY_DENIED),
1115 retryable: Some(false),
1116 details: Some(serde_json::json!({
1117 "required": ["exec"],
1118 "granted": [],
1119 "missing": ["exec"],
1120 })),
1121 },
1122 _ => unreachable!(),
1123 })
1124 .await;
1125 let (cr, cw) = tokio::io::split(client_end);
1126 let client = AtdClient::from_duplex(cr, cw);
1127 let err = client
1128 .call(
1129 "ref:x",
1130 serde_json::json!({}),
1131 crate::options::CallOptions::default(),
1132 )
1133 .await
1134 .unwrap_err();
1135 match err {
1136 AtdError::CapabilityDenied {
1137 tool_id,
1138 required,
1139 granted,
1140 } => {
1141 assert_eq!(tool_id, "ref:x");
1142 assert_eq!(required, vec!["exec"]);
1143 assert!(granted.is_empty());
1144 }
1145 other => panic!("expected CapabilityDenied, got {other:?}"),
1146 }
1147 }
1148
1149 #[tokio::test]
1150 async fn call_non_capability_error_still_maps_to_tool_execution_failed() {
1151 let (client_end, server_end) = duplex(4096);
1154 spin_server(server_end, |_| Response::Error {
1155 message: "something else".into(),
1156 code: Some(500),
1157 retryable: Some(true),
1158 details: None,
1159 })
1160 .await;
1161 let (cr, cw) = tokio::io::split(client_end);
1162 let client = AtdClient::from_duplex(cr, cw);
1163 let err = client
1164 .call(
1165 "ref:x",
1166 serde_json::json!({}),
1167 crate::options::CallOptions::default(),
1168 )
1169 .await
1170 .unwrap_err();
1171 assert!(
1172 matches!(err, AtdError::ToolExecutionFailed { .. }),
1173 "non-1001 errors must still be ToolExecutionFailed, got {err:?}"
1174 );
1175 }
1176
1177 async fn spawn_immediate_close_listener() -> (
1183 std::path::PathBuf,
1184 std::sync::Arc<std::sync::atomic::AtomicU32>,
1185 ) {
1186 use std::sync::atomic::{AtomicU32, Ordering};
1187 let dir = tempfile::tempdir().unwrap();
1188 let path = dir.path().join("close.sock");
1189 let counter = std::sync::Arc::new(AtomicU32::new(0));
1190 let counter_for_task = counter.clone();
1191 let listener = tokio::net::UnixListener::bind(&path).unwrap();
1192 std::mem::forget(dir); let path_ret = path.clone();
1194 tokio::spawn(async move {
1195 while let Ok((stream, _)) = listener.accept().await {
1196 counter_for_task.fetch_add(1, Ordering::Relaxed);
1197 drop(stream); }
1199 });
1200 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1202 (path_ret, counter)
1203 }
1204
1205 #[tokio::test]
1206 async fn connect_retries_on_transient_failure() {
1207 let (path, accepts) = spawn_immediate_close_listener().await;
1208 let opts = ConnectOptions {
1209 max_attempts: 3,
1210 backoff_base_ms: 5,
1211 backoff_cap_ms: 20,
1212 connect_timeout_ms: 500,
1213 };
1214 let result = AtdClient::connect_with_options(Endpoint::unix(path), opts).await;
1215 assert!(
1216 result.is_err(),
1217 "connect should fail when listener closes streams"
1218 );
1219 let n = accepts.load(std::sync::atomic::Ordering::Relaxed);
1221 assert_eq!(n, 3, "expected 3 connect attempts, listener saw {n}");
1222 }
1223
1224 #[tokio::test]
1225 async fn connect_respects_max_attempts() {
1226 let (path, accepts) = spawn_immediate_close_listener().await;
1227 let opts = ConnectOptions {
1228 max_attempts: 5,
1229 backoff_base_ms: 5,
1230 backoff_cap_ms: 20,
1231 connect_timeout_ms: 500,
1232 };
1233 let _ = AtdClient::connect_with_options(Endpoint::unix(path), opts).await;
1234 let n = accepts.load(std::sync::atomic::Ordering::Relaxed);
1235 assert_eq!(
1236 n, 5,
1237 "max_attempts=5 should yield exactly 5 attempts, got {n}"
1238 );
1239 }
1240
1241 #[tokio::test]
1242 async fn connect_short_circuits_on_not_found() {
1243 let opts = ConnectOptions {
1246 max_attempts: 5,
1247 backoff_base_ms: 100, backoff_cap_ms: 100,
1249 connect_timeout_ms: 500,
1250 };
1251 let started = std::time::Instant::now();
1252 let result = AtdClient::connect_with_options(
1253 Endpoint::unix("/tmp/atd-sdk-test-no-such-socket-xy7q"),
1254 opts,
1255 )
1256 .await;
1257 let elapsed = started.elapsed();
1258 match result {
1259 Err(AtdError::ServerUnreachable(_)) => {}
1260 Err(other) => panic!("expected ServerUnreachable, got {other:?}"),
1261 Ok(_) => panic!("connect to nonexistent path should not succeed"),
1262 }
1263 assert!(
1264 elapsed < std::time::Duration::from_millis(80),
1265 "short-circuit should be near-instant, took {elapsed:?}"
1266 );
1267 }
1268
1269 #[test]
1274 fn connect_options_default_reads_env() {
1275 let orig = (
1277 std::env::var("ATD_CONNECT_RETRIES").ok(),
1278 std::env::var("ATD_CONNECT_BACKOFF_BASE_MS").ok(),
1279 );
1280 unsafe {
1285 std::env::set_var("ATD_CONNECT_RETRIES", "2");
1286 std::env::set_var("ATD_CONNECT_BACKOFF_BASE_MS", "123");
1287 }
1288 let opts = ConnectOptions::default();
1289 unsafe {
1291 match &orig.0 {
1292 Some(v) => std::env::set_var("ATD_CONNECT_RETRIES", v),
1293 None => std::env::remove_var("ATD_CONNECT_RETRIES"),
1294 }
1295 match &orig.1 {
1296 Some(v) => std::env::set_var("ATD_CONNECT_BACKOFF_BASE_MS", v),
1297 None => std::env::remove_var("ATD_CONNECT_BACKOFF_BASE_MS"),
1298 }
1299 }
1300 assert_eq!(opts.max_attempts, 2);
1301 assert_eq!(opts.backoff_base_ms, 123);
1302 }
1303
1304 #[test]
1305 fn is_fatal_classifies_not_found_and_permission_denied() {
1306 let nf =
1307 AtdError::ServerUnreachable(std::io::Error::new(std::io::ErrorKind::NotFound, "x"));
1308 let pd = AtdError::ServerUnreachable(std::io::Error::new(
1309 std::io::ErrorKind::PermissionDenied,
1310 "x",
1311 ));
1312 let cr = AtdError::ServerUnreachable(std::io::Error::new(
1313 std::io::ErrorKind::ConnectionRefused,
1314 "x",
1315 ));
1316 assert!(is_fatal_connect_error(&nf));
1317 assert!(is_fatal_connect_error(&pd));
1318 assert!(!is_fatal_connect_error(&cr));
1319 }
1320
1321 #[test]
1322 fn jitter_factor_stays_within_bounds() {
1323 for _ in 0..1000 {
1324 let j = jitter_factor();
1325 assert!((-0.2..=0.2).contains(&j), "jitter {j} out of ±0.2 bound");
1326 }
1327 }
1328
1329 #[tokio::test]
1332 async fn call_page_initial_sends_run_tool() {
1333 let (client_end, server_end) = duplex(4096);
1334 spin_server(server_end, |req| match req {
1335 Request::RunTool { tool_id, .. } => Response::ToolResultResponse {
1336 tool_id,
1337 result: serde_json::json!([1, 2, 3]),
1338 success: true,
1339 dry_run: false,
1340 next_cursor: Some("CURSOR_AFTER_PAGE_1".into()),
1341 },
1342 other => panic!("expected RunTool, got {other:?}"),
1343 })
1344 .await;
1345 let (cr, cw) = tokio::io::split(client_end);
1346 let client = AtdClient::from_duplex(cr, cw);
1347 let page = client
1348 .call_page(
1349 "celia:list_obs",
1350 serde_json::json!({"p": "x"}),
1351 None,
1352 crate::options::CallOptions::default(),
1353 )
1354 .await
1355 .unwrap();
1356 assert_eq!(page.value, serde_json::json!([1, 2, 3]));
1357 assert_eq!(page.next_cursor.as_deref(), Some("CURSOR_AFTER_PAGE_1"));
1358 }
1359
1360 #[tokio::test]
1361 async fn call_page_with_cursor_sends_run_tool_continue() {
1362 let (client_end, server_end) = duplex(4096);
1363 spin_server(server_end, |req| match req {
1364 Request::RunToolContinue { tool_id, cursor } => {
1365 assert_eq!(cursor, "CURSOR_X");
1366 Response::ToolResultResponse {
1367 tool_id,
1368 result: serde_json::json!([4, 5]),
1369 success: true,
1370 dry_run: false,
1371 next_cursor: None,
1372 }
1373 }
1374 other => panic!("expected RunToolContinue, got {other:?}"),
1375 })
1376 .await;
1377 let (cr, cw) = tokio::io::split(client_end);
1378 let client = AtdClient::from_duplex(cr, cw);
1379 let page = client
1380 .call_page(
1381 "celia:list_obs",
1382 serde_json::Value::Null,
1383 Some("CURSOR_X"),
1384 crate::options::CallOptions::default(),
1385 )
1386 .await
1387 .unwrap();
1388 assert_eq!(page.value, serde_json::json!([4, 5]));
1389 assert!(page.next_cursor.is_none());
1390 }
1391
1392 #[tokio::test]
1393 async fn call_all_concats_arrays_until_no_cursor() {
1394 let (client_end, server_end) = duplex(4096);
1395 spin_server(server_end, move |req| match req {
1397 Request::RunTool { tool_id, .. } => Response::ToolResultResponse {
1398 tool_id,
1399 result: serde_json::json!([1, 2]),
1400 success: true,
1401 dry_run: false,
1402 next_cursor: Some("cursor-a".into()),
1403 },
1404 Request::RunToolContinue { tool_id, cursor } => match cursor.as_str() {
1405 "cursor-a" => Response::ToolResultResponse {
1406 tool_id,
1407 result: serde_json::json!([3, 4]),
1408 success: true,
1409 dry_run: false,
1410 next_cursor: Some("cursor-b".into()),
1411 },
1412 "cursor-b" => Response::ToolResultResponse {
1413 tool_id,
1414 result: serde_json::json!([5, 6]),
1415 success: true,
1416 dry_run: false,
1417 next_cursor: None,
1418 },
1419 other => panic!("unexpected cursor: {other}"),
1420 },
1421 other => panic!("unexpected req: {other:?}"),
1422 })
1423 .await;
1424 let (cr, cw) = tokio::io::split(client_end);
1425 let client = AtdClient::from_duplex(cr, cw);
1426 let all = client
1427 .call_all(
1428 "t",
1429 serde_json::json!({}),
1430 crate::options::CallAllOptions::default(),
1431 )
1432 .await
1433 .unwrap();
1434 assert_eq!(all, serde_json::json!([1, 2, 3, 4, 5, 6]));
1435 }
1436
1437 #[tokio::test]
1438 async fn call_all_concat_field_merges_named_array() {
1439 let (client_end, server_end) = duplex(4096);
1440 spin_server(server_end, |req| match req {
1441 Request::RunTool { tool_id, .. } => Response::ToolResultResponse {
1442 tool_id,
1443 result: serde_json::json!({"patient": "p1", "obs": [{"id": 1}], "total": 4}),
1444 success: true,
1445 dry_run: false,
1446 next_cursor: Some("c1".into()),
1447 },
1448 Request::RunToolContinue { tool_id, .. } => Response::ToolResultResponse {
1449 tool_id,
1450 result: serde_json::json!({"patient": "p1", "obs": [{"id": 2}, {"id": 3}, {"id": 4}], "total": 4}),
1451 success: true,
1452 dry_run: false,
1453 next_cursor: None,
1454 },
1455 other => panic!("unexpected: {other:?}"),
1456 })
1457 .await;
1458 let (cr, cw) = tokio::io::split(client_end);
1459 let client = AtdClient::from_duplex(cr, cw);
1460 let opts = crate::options::CallAllOptions {
1461 merge_policy: crate::options::MergePolicy::ConcatField("obs".into()),
1462 ..Default::default()
1463 };
1464 let all = client
1465 .call_all("t", serde_json::json!({}), opts)
1466 .await
1467 .unwrap();
1468 assert_eq!(all["patient"], "p1");
1470 assert_eq!(all["total"], 4);
1471 assert_eq!(
1472 all["obs"],
1473 serde_json::json!([{"id":1},{"id":2},{"id":3},{"id":4}])
1474 );
1475 }
1476
1477 #[tokio::test]
1478 async fn call_all_respects_max_pages() {
1479 let (client_end, server_end) = duplex(4096);
1480 spin_server(server_end, |req| match req {
1482 Request::RunTool { tool_id, .. } => Response::ToolResultResponse {
1483 tool_id,
1484 result: serde_json::json!([0]),
1485 success: true,
1486 dry_run: false,
1487 next_cursor: Some("c".into()),
1488 },
1489 Request::RunToolContinue { tool_id, .. } => Response::ToolResultResponse {
1490 tool_id,
1491 result: serde_json::json!([0]),
1492 success: true,
1493 dry_run: false,
1494 next_cursor: Some("c".into()),
1495 },
1496 other => panic!("unexpected: {other:?}"),
1497 })
1498 .await;
1499 let (cr, cw) = tokio::io::split(client_end);
1500 let client = AtdClient::from_duplex(cr, cw);
1501 let opts = crate::options::CallAllOptions {
1502 max_pages: 3,
1503 ..Default::default()
1504 };
1505 let err = client.call_all("t", serde_json::json!({}), opts).await;
1506 match err {
1507 Err(AtdError::PaginationLimitExceeded { pages_fetched, .. }) => {
1508 assert_eq!(pages_fetched, 3);
1509 }
1510 other => panic!("expected PaginationLimitExceeded, got {other:?}"),
1511 }
1512 }
1513
1514 #[tokio::test]
1515 async fn call_all_respects_max_total_bytes() {
1516 let (client_end, server_end) = duplex(8192);
1517 spin_server(server_end, |req| {
1520 let big = serde_json::Value::Array((0..100).map(serde_json::Value::from).collect());
1521 match req {
1522 Request::RunTool { tool_id, .. } => Response::ToolResultResponse {
1523 tool_id,
1524 result: big,
1525 success: true,
1526 dry_run: false,
1527 next_cursor: Some("c".into()),
1528 },
1529 Request::RunToolContinue { tool_id, .. } => Response::ToolResultResponse {
1530 tool_id,
1531 result: big,
1532 success: true,
1533 dry_run: false,
1534 next_cursor: Some("c".into()),
1535 },
1536 other => panic!("unexpected: {other:?}"),
1537 }
1538 })
1539 .await;
1540 let (cr, cw) = tokio::io::split(client_end);
1541 let client = AtdClient::from_duplex(cr, cw);
1542 let opts = crate::options::CallAllOptions {
1543 max_total_bytes: 400, ..Default::default()
1545 };
1546 let err = client.call_all("t", serde_json::json!({}), opts).await;
1547 match err {
1548 Err(AtdError::PaginationLimitExceeded {
1549 bytes_fetched,
1550 pages_fetched: _,
1551 }) => {
1552 assert!(
1553 bytes_fetched > 400,
1554 "expected byte overflow, got {bytes_fetched}"
1555 );
1556 }
1557 other => panic!("expected PaginationLimitExceeded, got {other:?}"),
1558 }
1559 }
1560
1561 #[tokio::test]
1562 async fn call_all_single_page_returns_value_unchanged() {
1563 let (client_end, server_end) = duplex(4096);
1564 spin_server(server_end, |req| match req {
1565 Request::RunTool { tool_id, .. } => Response::ToolResultResponse {
1566 tool_id,
1567 result: serde_json::json!({"data": [1, 2, 3]}),
1568 success: true,
1569 dry_run: false,
1570 next_cursor: None,
1571 },
1572 other => panic!("unexpected: {other:?}"),
1573 })
1574 .await;
1575 let (cr, cw) = tokio::io::split(client_end);
1576 let client = AtdClient::from_duplex(cr, cw);
1577 let all = client
1578 .call_all(
1579 "t",
1580 serde_json::json!({}),
1581 crate::options::CallAllOptions::default(),
1582 )
1583 .await
1584 .unwrap();
1585 assert_eq!(all, serde_json::json!({"data": [1, 2, 3]}));
1586 }
1587
1588 #[test]
1589 fn merge_pages_concat_array_basic() {
1590 use crate::options::MergePolicy;
1591 let r = merge_pages(
1592 Some(serde_json::json!([1, 2])),
1593 serde_json::json!([3, 4]),
1594 &MergePolicy::ConcatArray,
1595 )
1596 .unwrap();
1597 assert_eq!(r, serde_json::json!([1, 2, 3, 4]));
1598 }
1599
1600 #[test]
1601 fn merge_pages_concat_array_rejects_non_array() {
1602 use crate::options::MergePolicy;
1603 let err = merge_pages(
1604 Some(serde_json::json!([1, 2])),
1605 serde_json::json!({"x": 1}),
1606 &MergePolicy::ConcatArray,
1607 )
1608 .unwrap_err();
1609 assert!(matches!(err, AtdError::MergeFailed { .. }));
1610 }
1611
1612 #[test]
1613 fn merge_pages_first_page_only_drops_subsequent() {
1614 use crate::options::MergePolicy;
1615 let r = merge_pages(
1616 Some(serde_json::json!({"first": true})),
1617 serde_json::json!({"second": true}),
1618 &MergePolicy::FirstPageOnly,
1619 )
1620 .unwrap();
1621 assert_eq!(
1622 r,
1623 serde_json::json!({"first": true}),
1624 "FirstPageOnly: accumulator wins; subsequent pages dropped"
1625 );
1626 }
1627}