1use std::io::{ErrorKind, Read};
16use std::path::Path;
17use std::sync::Arc;
18use std::time::Duration;
19
20use bytes::Bytes;
21use futures_util::StreamExt;
22use reqwest::header::{ACCEPT, CONTENT_LENGTH, CONTENT_TYPE};
23use reqwest::{Method, RequestBuilder};
24use secrecy::{ExposeSecret, SecretString};
25use serde::de::DeserializeOwned;
26use serde::Serialize;
27use tokio::fs::File;
28use tokio_stream::wrappers::ReceiverStream;
29use tokio_util::io::ReaderStream;
30use url::Url;
31
32use crate::endpoint::Endpoint;
33use crate::error::ZenodoError;
34use crate::ids::{BucketUrl, DepositionFileId, DepositionId};
35use crate::metadata::DepositMetadataUpdate;
36use crate::model::{BucketObject, Deposition, DepositionFile};
37use crate::poll::PollOptions;
38use crate::progress::TransferProgress;
39
40#[derive(Clone)]
42pub struct Auth {
43 pub token: SecretString,
45}
46
47impl Auth {
48 pub const TOKEN_ENV_VAR: &'static str = "ZENODO_TOKEN";
50
51 pub const SANDBOX_TOKEN_ENV_VAR: &'static str = "ZENODO_SANDBOX_TOKEN";
53
54 #[must_use]
66 pub fn new(token: impl Into<String>) -> Self {
67 Self {
68 token: SecretString::from(token.into()),
69 }
70 }
71
72 pub fn from_env() -> Result<Self, ZenodoError> {
78 Self::from_env_var(Self::TOKEN_ENV_VAR)
79 }
80
81 pub fn from_sandbox_env() -> Result<Self, ZenodoError> {
87 Self::from_env_var(Self::SANDBOX_TOKEN_ENV_VAR)
88 }
89
90 pub fn from_env_var(name: &str) -> Result<Self, ZenodoError> {
106 let token = std::env::var(name).map_err(|source| ZenodoError::EnvVar {
107 name: name.to_owned(),
108 source,
109 })?;
110 Ok(Self::new(token))
111 }
112}
113
114impl From<SecretString> for Auth {
115 fn from(token: SecretString) -> Self {
116 Self { token }
117 }
118}
119
120impl std::fmt::Debug for Auth {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("Auth")
123 .field("token", &"<redacted>")
124 .finish()
125 }
126}
127
128#[derive(Clone, Debug)]
130pub struct ZenodoClientBuilder {
131 auth: Auth,
132 endpoint: Endpoint,
133 poll: PollOptions,
134 user_agent: Option<String>,
135 request_timeout: Option<Duration>,
136 connect_timeout: Option<Duration>,
137}
138
139impl ZenodoClientBuilder {
140 #[must_use]
155 pub fn endpoint(mut self, endpoint: Endpoint) -> Self {
156 self.endpoint = endpoint;
157 self
158 }
159
160 #[must_use]
175 pub fn sandbox(mut self) -> Self {
176 self.endpoint = Endpoint::Sandbox;
177 self
178 }
179
180 #[must_use]
194 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
195 self.user_agent = Some(user_agent.into());
196 self
197 }
198
199 #[must_use]
215 pub fn request_timeout(mut self, timeout: Duration) -> Self {
216 self.request_timeout = Some(timeout);
217 self
218 }
219
220 #[must_use]
236 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
237 self.connect_timeout = Some(timeout);
238 self
239 }
240
241 #[must_use]
261 pub fn poll_options(mut self, poll: PollOptions) -> Self {
262 self.poll = poll;
263 self
264 }
265
266 pub fn build(self) -> Result<ZenodoClient, ZenodoError> {
272 ensure_rustls_provider();
273
274 let user_agent = self
275 .user_agent
276 .unwrap_or_else(|| format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")));
277
278 let mut inner = reqwest::Client::builder().user_agent(&user_agent);
279 if let Some(timeout) = self.request_timeout {
280 inner = inner.timeout(timeout);
281 }
282 if let Some(timeout) = self.connect_timeout {
283 inner = inner.connect_timeout(timeout);
284 }
285 let inner = inner.build()?;
286
287 Ok(ZenodoClient {
288 inner,
289 auth: self.auth,
290 endpoint: self.endpoint,
291 poll: self.poll,
292 request_timeout: self.request_timeout,
293 connect_timeout: self.connect_timeout,
294 })
295 }
296}
297
/// Installs rustls' `ring` crypto provider as the process default, at most once.
///
/// Any error from `install_default` is deliberately discarded. Presumably a
/// provider may already have been installed by the application; confirm
/// against the rustls `CryptoProvider::install_default` docs.
#[cfg(feature = "rustls-ring-tls")]
pub(crate) fn ensure_rustls_provider() {
    fn install_ring_default() {
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    static INSTALLED: std::sync::OnceLock<()> = std::sync::OnceLock::new();
    INSTALLED.get_or_init(install_ring_default);
}

/// No-op when the `rustls-ring-tls` feature is disabled.
#[cfg(not(feature = "rustls-ring-tls"))]
pub(crate) fn ensure_rustls_provider() {}
308
309#[derive(Clone, Debug)]
311pub struct ZenodoClient {
312 pub(crate) inner: reqwest::Client,
313 pub(crate) auth: Auth,
314 pub(crate) endpoint: Endpoint,
315 pub(crate) poll: PollOptions,
316 pub(crate) request_timeout: Option<Duration>,
317 pub(crate) connect_timeout: Option<Duration>,
318}
319
320impl ZenodoClient {
321 #[must_use]
336 pub fn builder(auth: Auth) -> ZenodoClientBuilder {
337 ZenodoClientBuilder {
338 auth,
339 endpoint: Endpoint::default(),
340 poll: PollOptions::default(),
341 user_agent: None,
342 request_timeout: None,
343 connect_timeout: None,
344 }
345 }
346
347 pub fn new(auth: Auth) -> Result<Self, ZenodoError> {
353 Self::builder(auth).build()
354 }
355
356 pub fn with_token(token: impl Into<String>) -> Result<Self, ZenodoError> {
372 Self::new(Auth::new(token))
373 }
374
375 pub fn from_env() -> Result<Self, ZenodoError> {
382 Self::new(Auth::from_env()?)
383 }
384
385 pub fn from_sandbox_env() -> Result<Self, ZenodoError> {
392 Self::builder(Auth::from_sandbox_env()?).sandbox().build()
393 }
394
395 #[must_use]
397 pub fn endpoint(&self) -> &Endpoint {
398 &self.endpoint
399 }
400
401 #[must_use]
403 pub fn poll_options(&self) -> &PollOptions {
404 &self.poll
405 }
406
407 #[must_use]
409 pub fn request_timeout(&self) -> Option<Duration> {
410 self.request_timeout
411 }
412
413 #[must_use]
415 pub fn connect_timeout(&self) -> Option<Duration> {
416 self.connect_timeout
417 }
418
419 pub(crate) fn request(
420 &self,
421 method: Method,
422 path: &str,
423 ) -> Result<RequestBuilder, ZenodoError> {
424 let url = self.endpoint.base_url()?.join(path)?;
425 self.request_url(method, url)
426 }
427
428 pub(crate) fn request_url(
429 &self,
430 method: Method,
431 url: Url,
432 ) -> Result<RequestBuilder, ZenodoError> {
433 if !self.is_trusted_url(&url)? {
434 return Err(ZenodoError::InvalidState(format!(
435 "refusing authenticated API request to different origin: {url}"
436 )));
437 }
438
439 Ok(self
440 .inner
441 .request(method, url)
442 .bearer_auth(self.auth.token.expose_secret())
443 .header(ACCEPT, "application/json"))
444 }
445
446 pub(crate) fn download_request_url(
447 &self,
448 method: Method,
449 url: Url,
450 ) -> Result<RequestBuilder, ZenodoError> {
451 let trusted = self.is_trusted_url(&url)?;
452 let mut request = self.inner.request(method, url);
453 if trusted {
454 request = request.bearer_auth(self.auth.token.expose_secret());
455 }
456
457 Ok(request)
458 }
459
460 fn is_trusted_url(&self, url: &Url) -> Result<bool, ZenodoError> {
461 Ok(self.endpoint.base_url()?.origin() == url.origin())
462 }
463
464 pub(crate) async fn execute_json<T>(&self, request: RequestBuilder) -> Result<T, ZenodoError>
465 where
466 T: DeserializeOwned,
467 {
468 let response = request.send().await?;
469 if !response.status().is_success() {
470 return Err(ZenodoError::from_response(response).await);
471 }
472
473 let bytes = response.bytes().await?;
474 Ok(serde_json::from_slice(&bytes)?)
475 }
476
477 pub(crate) async fn execute_json_or_else<T, F, Fut>(
478 &self,
479 request: RequestBuilder,
480 on_empty: F,
481 ) -> Result<T, ZenodoError>
482 where
483 T: DeserializeOwned,
484 F: FnOnce() -> Fut,
485 Fut: std::future::Future<Output = Result<T, ZenodoError>>,
486 {
487 let response = request.send().await?;
488 if !response.status().is_success() {
489 return Err(ZenodoError::from_response(response).await);
490 }
491
492 let bytes = response.bytes().await?;
493 if bytes.is_empty() {
494 return on_empty().await;
495 }
496
497 Ok(serde_json::from_slice(&bytes)?)
498 }
499
500 pub(crate) async fn execute_unit(&self, request: RequestBuilder) -> Result<(), ZenodoError> {
501 let response = request.send().await?;
502 if !response.status().is_success() {
503 return Err(ZenodoError::from_response(response).await);
504 }
505
506 Ok(())
507 }
508
509 pub(crate) async fn execute_response(
510 &self,
511 request: RequestBuilder,
512 ) -> Result<reqwest::Response, ZenodoError> {
513 let response = request.send().await?;
514 if !response.status().is_success() {
515 return Err(ZenodoError::from_response(response).await);
516 }
517
518 Ok(response)
519 }
520
521 pub(crate) async fn get_deposition_by_url(&self, url: &Url) -> Result<Deposition, ZenodoError> {
522 self.execute_json(self.request_url(Method::GET, url.clone())?)
523 .await
524 }
525
526 pub(crate) async fn get_record_by_url(
527 &self,
528 url: &Url,
529 ) -> Result<crate::model::Record, ZenodoError> {
530 self.execute_json(self.request_url(Method::GET, url.clone())?)
531 .await
532 }
533
534 pub async fn create_deposition(&self) -> Result<Deposition, ZenodoError> {
541 self.execute_json(
542 self.request(Method::POST, "deposit/depositions")?
543 .json(&serde_json::json!({})),
544 )
545 .await
546 }
547
548 pub async fn get_deposition(&self, id: DepositionId) -> Result<Deposition, ZenodoError> {
555 self.execute_json(self.request(Method::GET, &format!("deposit/depositions/{id}"))?)
556 .await
557 }
558
559 pub async fn update_metadata(
565 &self,
566 id: DepositionId,
567 metadata: &DepositMetadataUpdate,
568 ) -> Result<Deposition, ZenodoError> {
569 #[derive(Serialize)]
570 struct Payload<'a> {
571 metadata: &'a DepositMetadataUpdate,
572 }
573
574 self.execute_json(
575 self.request(Method::PUT, &format!("deposit/depositions/{id}"))?
576 .json(&Payload { metadata }),
577 )
578 .await
579 }
580
581 pub async fn list_files(&self, id: DepositionId) -> Result<Vec<DepositionFile>, ZenodoError> {
588 self.execute_json(self.request(Method::GET, &format!("deposit/depositions/{id}/files"))?)
589 .await
590 }
591
592 pub async fn delete_file(
598 &self,
599 id: DepositionId,
600 file_id: DepositionFileId,
601 ) -> Result<(), ZenodoError> {
602 self.execute_unit(self.request(
603 Method::DELETE,
604 &format!("deposit/depositions/{id}/files/{file_id}"),
605 )?)
606 .await
607 }
608
609 pub async fn upload_path(
616 &self,
617 bucket: &BucketUrl,
618 filename: &str,
619 path: &Path,
620 ) -> Result<BucketObject, ZenodoError> {
621 self.upload_path_with_progress(bucket, filename, path, ())
622 .await
623 }
624
625 pub async fn upload_path_with_progress<P>(
635 &self,
636 bucket: &BucketUrl,
637 filename: &str,
638 path: &Path,
639 progress: P,
640 ) -> Result<BucketObject, ZenodoError>
641 where
642 P: TransferProgress + 'static,
643 {
644 self.upload_path_with_content_type_and_progress(
645 bucket,
646 filename,
647 path,
648 mime::APPLICATION_OCTET_STREAM,
649 progress,
650 )
651 .await
652 }
653
654 pub(crate) async fn upload_path_with_content_type_and_progress<P>(
655 &self,
656 bucket: &BucketUrl,
657 filename: &str,
658 path: &Path,
659 content_type: mime::Mime,
660 progress: P,
661 ) -> Result<BucketObject, ZenodoError>
662 where
663 P: TransferProgress + 'static,
664 {
665 let file = File::open(path).await?;
666 let length = file.metadata().await?.len();
667 let progress = Arc::new(progress);
668 progress.begin(Some(length));
669 let body_progress = Arc::clone(&progress);
670 let body = reqwest::Body::wrap_stream(ReaderStream::new(file).map(move |item| {
671 if let Ok(bytes) = &item {
672 body_progress.advance(bytes.len() as u64);
673 }
674 item
675 }));
676
677 let uploaded = self
678 .execute_json(
679 self.request_url(Method::PUT, bucket_upload_url(bucket, filename)?)?
680 .header(CONTENT_LENGTH, length)
681 .header(CONTENT_TYPE, content_type.as_ref())
682 .body(body),
683 )
684 .await?;
685 progress.finish();
686 Ok(uploaded)
687 }
688
689 pub async fn upload_reader<R>(
698 &self,
699 bucket: &BucketUrl,
700 filename: &str,
701 reader: R,
702 content_length: u64,
703 content_type: mime::Mime,
704 ) -> Result<BucketObject, ZenodoError>
705 where
706 R: Read + Send + 'static,
707 {
708 self.upload_reader_with_progress(bucket, filename, reader, content_length, content_type, ())
709 .await
710 }
711
712 pub async fn upload_reader_with_progress<R, P>(
723 &self,
724 bucket: &BucketUrl,
725 filename: &str,
726 reader: R,
727 content_length: u64,
728 content_type: mime::Mime,
729 progress: P,
730 ) -> Result<BucketObject, ZenodoError>
731 where
732 R: Read + Send + 'static,
733 P: TransferProgress + 'static,
734 {
735 let progress = Arc::new(progress);
736 progress.begin(Some(content_length));
737 let body = sized_body_from_reader(reader, content_length, Arc::clone(&progress));
738
739 let uploaded = self
740 .execute_json(
741 self.request_url(Method::PUT, bucket_upload_url(bucket, filename)?)?
742 .header(CONTENT_LENGTH, content_length)
743 .header(CONTENT_TYPE, content_type.as_ref())
744 .body(body),
745 )
746 .await?;
747 progress.finish();
748 Ok(uploaded)
749 }
750
751 pub async fn publish(&self, id: DepositionId) -> Result<Deposition, ZenodoError> {
758 self.execute_json_or_else(
759 self.request(
760 Method::POST,
761 &format!("deposit/depositions/{id}/actions/publish"),
762 )?,
763 || async move { self.get_deposition(id).await },
764 )
765 .await
766 }
767
768 pub async fn edit(&self, id: DepositionId) -> Result<Deposition, ZenodoError> {
775 self.execute_json_or_else(
776 self.request(
777 Method::POST,
778 &format!("deposit/depositions/{id}/actions/edit"),
779 )?,
780 || async move { self.get_deposition(id).await },
781 )
782 .await
783 }
784
785 pub async fn discard(&self, id: DepositionId) -> Result<Deposition, ZenodoError> {
792 self.execute_json_or_else(
793 self.request(
794 Method::POST,
795 &format!("deposit/depositions/{id}/actions/discard"),
796 )?,
797 || async move { self.get_deposition(id).await },
798 )
799 .await
800 }
801
802 pub async fn new_version(&self, id: DepositionId) -> Result<Deposition, ZenodoError> {
809 self.execute_json_or_else(
810 self.request(
811 Method::POST,
812 &format!("deposit/depositions/{id}/actions/newversion"),
813 )?,
814 || async move { self.get_deposition(id).await },
815 )
816 .await
817 }
818}
819
820fn bucket_upload_url(bucket: &BucketUrl, filename: &str) -> Result<Url, ZenodoError> {
821 let mut url = bucket.0.clone();
822 let mut segments = url.path_segments_mut().map_err(|()| {
823 ZenodoError::InvalidState("bucket URL cannot accept filename segments".to_owned())
824 })?;
825 segments.pop_if_empty();
826 segments.push(filename);
827 drop(segments);
828 Ok(url)
829}
830
831fn sized_body_from_reader<R, P>(reader: R, content_length: u64, progress: Arc<P>) -> reqwest::Body
832where
833 R: Read + Send + 'static,
834 P: TransferProgress + 'static,
835{
836 let (tx, rx) = tokio::sync::mpsc::channel::<Result<Bytes, std::io::Error>>(8);
837
838 tokio::task::spawn_blocking(move || {
839 let mut reader = reader;
840 let mut remaining = content_length;
841
842 while remaining > 0 {
843 let mut buf = vec![0_u8; remaining.min(64 * 1024) as usize];
844 match reader.read(&mut buf) {
845 Ok(0) => {
846 let _ = tx.blocking_send(Err(std::io::Error::new(
847 ErrorKind::UnexpectedEof,
848 "reader ended before declared content_length bytes were produced",
849 )));
850 return;
851 }
852 Ok(read) => {
853 buf.truncate(read);
854 remaining -= read as u64;
855 if tx.blocking_send(Ok(Bytes::from(buf))).is_err() {
856 return;
857 }
858 progress.advance(read as u64);
859 }
860 Err(error) => {
861 let _ = tx.blocking_send(Err(error));
862 return;
863 }
864 }
865 }
866 });
867
868 reqwest::Body::wrap_stream(ReceiverStream::new(rx))
869}
870
#[cfg(test)]
mod tests {
    use std::env::VarError;
    use std::io::{self, Cursor, Read};
    use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
    use std::time::Duration;

    use super::{bucket_upload_url, Auth, ZenodoClient};
    use crate::ids::BucketUrl;
    use crate::{Endpoint, PollOptions, RecordId, ZenodoError};
    use axum::extract::State;
    use axum::http::StatusCode;
    use axum::routing::get;
    use axum::{Json, Router};
    use http_body_util::BodyExt;
    use reqwest::Method;
    use secrecy::{ExposeSecret, SecretString};
    use serde_json::json;
    use tokio::net::TcpListener;
    use url::Url;

    /// Serializes tests that read or mutate process environment variables.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, recovering it if a previous holder panicked.
    ///
    /// Without this, one failing environment test poisons the mutex and every
    /// later environment test fails with an unrelated `PoisonError`, hiding
    /// the real failure. The guarded value is `()`, so poisoning cannot leave
    /// inconsistent state behind.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Sets or removes an environment variable for the guard's lifetime.
    ///
    /// The previous value is restored on drop, which also runs while
    /// unwinding from a failed assertion.
    struct EnvVarGuard {
        name: &'static str,
        previous: Option<String>,
    }

    impl EnvVarGuard {
        fn set(name: &'static str, value: Option<&str>) -> Self {
            let previous = std::env::var(name).ok();
            match value {
                Some(value) => std::env::set_var(name, value),
                None => std::env::remove_var(name),
            }
            Self { name, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match &self.previous {
                Some(value) => std::env::set_var(self.name, value),
                None => std::env::remove_var(self.name),
            }
        }
    }

    #[test]
    fn bucket_upload_preserves_path_and_encodes_filename() {
        let bucket = BucketUrl(Url::parse("https://zenodo.org/api/files/bucket-id").unwrap());
        let url = bucket_upload_url(&bucket, "artifact v1.tar.gz").unwrap();
        assert_eq!(
            url.as_str(),
            "https://zenodo.org/api/files/bucket-id/artifact%20v1.tar.gz"
        );
    }

    #[test]
    fn auth_debug_redacts_tokens_and_builders_preserve_configuration() {
        let auth = Auth::from(SecretString::from("secret"));
        assert!(format!("{auth:?}").contains("<redacted>"));
        assert_eq!(auth.token.expose_secret(), "secret");

        let poll = PollOptions {
            max_wait: Duration::from_secs(3),
            initial_delay: Duration::from_millis(2),
            max_delay: Duration::from_millis(4),
        };
        let endpoint = Endpoint::Custom(Url::parse("http://localhost:9999/api/").unwrap());
        let client = ZenodoClient::builder(Auth::new("token"))
            .endpoint(endpoint.clone())
            .user_agent("custom-agent/1.0")
            .request_timeout(Duration::from_secs(7))
            .connect_timeout(Duration::from_secs(2))
            .poll_options(poll.clone())
            .build()
            .unwrap();

        assert_eq!(client.endpoint(), &endpoint);
        assert_eq!(client.poll_options(), &poll);
        assert_eq!(client.request_timeout(), Some(Duration::from_secs(7)));
        assert_eq!(client.connect_timeout(), Some(Duration::from_secs(2)));
        assert!(matches!(
            ZenodoClient::builder(Auth::new("token"))
                .sandbox()
                .build()
                .unwrap()
                .endpoint(),
            Endpoint::Sandbox
        ));
        assert!(ZenodoClient::new(Auth::new("token")).is_ok());
        assert!(ZenodoClient::with_token("token").is_ok());
    }

    #[test]
    fn env_helpers_read_expected_token_variables() {
        let _lock = lock_env();
        let _prod_guard = EnvVarGuard::set(Auth::TOKEN_ENV_VAR, Some("prod-token"));
        let _sandbox_guard = EnvVarGuard::set(Auth::SANDBOX_TOKEN_ENV_VAR, Some("sandbox-token"));
        let _custom_guard = EnvVarGuard::set("CUSTOM_ZENODO_TOKEN", Some("custom-token"));

        assert_eq!(
            Auth::from_env().unwrap().token.expose_secret(),
            "prod-token"
        );
        assert_eq!(
            Auth::from_sandbox_env().unwrap().token.expose_secret(),
            "sandbox-token"
        );
        assert_eq!(
            Auth::from_env_var("CUSTOM_ZENODO_TOKEN")
                .unwrap()
                .token
                .expose_secret(),
            "custom-token"
        );
        assert!(matches!(
            ZenodoClient::from_sandbox_env().unwrap().endpoint(),
            Endpoint::Sandbox
        ));
        assert!(matches!(
            ZenodoClient::from_env().unwrap().endpoint(),
            Endpoint::Production
        ));
    }

    #[test]
    fn env_helpers_report_missing_variables() {
        let _lock = lock_env();
        let _prod_guard = EnvVarGuard::set(Auth::TOKEN_ENV_VAR, None);
        let _sandbox_guard = EnvVarGuard::set(Auth::SANDBOX_TOKEN_ENV_VAR, None);

        match Auth::from_env().unwrap_err() {
            ZenodoError::EnvVar { name, source } => {
                assert_eq!(name, Auth::TOKEN_ENV_VAR);
                assert!(matches!(source, VarError::NotPresent));
            }
            other => panic!("unexpected error: {other:?}"),
        }

        match ZenodoClient::from_sandbox_env().unwrap_err() {
            ZenodoError::EnvVar { name, source } => {
                assert_eq!(name, Auth::SANDBOX_TOKEN_ENV_VAR);
                assert!(matches!(source, VarError::NotPresent));
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn bucket_upload_rejects_urls_without_path_segments() {
        let bucket = BucketUrl(Url::parse("mailto:test@example.com").unwrap());
        let error = bucket_upload_url(&bucket, "artifact.bin").unwrap_err();
        assert!(matches!(error, crate::ZenodoError::InvalidState(_)));
    }

    #[test]
    fn request_url_rejects_cross_origin_api_requests() {
        let client = ZenodoClient::builder(Auth::new("token"))
            .endpoint(Endpoint::Custom(
                Url::parse("http://localhost:1234/api/").unwrap(),
            ))
            .build()
            .unwrap();

        let error = client
            .request_url(
                Method::GET,
                Url::parse("http://example.com/api/records/1").unwrap(),
            )
            .unwrap_err();
        assert!(matches!(error, ZenodoError::InvalidState(_)));
    }

    #[tokio::test]
    async fn sized_body_from_reader_reports_short_reads() {
        let body =
            super::sized_body_from_reader(Cursor::new(b"ab".to_vec()), 5, std::sync::Arc::new(()));
        let error = body.collect().await.unwrap_err();
        assert!(error.is_body());
    }

    /// Reader whose every `read` fails with a non-retryable error.
    struct BrokenReader;

    impl Read for BrokenReader {
        fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
            Err(io::Error::other("boom"))
        }
    }

    #[tokio::test]
    async fn sized_body_from_reader_reports_reader_errors() {
        let body = super::sized_body_from_reader(BrokenReader, 5, std::sync::Arc::new(()));
        let error = body.collect().await.unwrap_err();
        assert!(error.is_body());
    }

    #[tokio::test]
    async fn sized_body_from_reader_tolerates_dropped_receiver() {
        let body =
            super::sized_body_from_reader(Cursor::new(b"abc".to_vec()), 3, std::sync::Arc::new(()));
        drop(body);
        tokio::time::sleep(Duration::from_millis(10)).await;
    }

    #[tokio::test]
    async fn request_timeout_is_enforced_for_http_calls() {
        #[derive(Clone)]
        struct DelayState {
            delay: Duration,
        }

        async fn delayed_record(
            State(state): State<Arc<DelayState>>,
        ) -> (StatusCode, Json<serde_json::Value>) {
            tokio::time::sleep(state.delay).await;
            (
                StatusCode::OK,
                Json(json!({
                    "id": 1,
                    "recid": 1,
                    "metadata": { "title": "slow" },
                    "files": [],
                    "links": {}
                })),
            )
        }

        let state = Arc::new(DelayState {
            delay: Duration::from_millis(50),
        });
        let app = Router::new()
            .route("/api/records/1", get(delayed_record))
            .with_state(state);
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });

        let client = ZenodoClient::builder(Auth::new("token"))
            .endpoint(Endpoint::Custom(
                Url::parse(&format!("http://{addr}/api/")).unwrap(),
            ))
            .request_timeout(Duration::from_millis(10))
            .build()
            .unwrap();

        let error = client.get_record(RecordId(1)).await.unwrap_err();
        match error {
            ZenodoError::Transport(source) => assert!(source.is_timeout()),
            other => panic!("unexpected error: {other:?}"),
        }

        server.abort();
    }
}