#[cfg(test)]
mod tests {
    use anyhow::Result;
    use google_cloud_auth::credentials::mds::Builder;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::{TcpListener, TcpStream};

    /// Upper bound on how many `source()` links [`as_inner`] follows, so a
    /// (buggy) cyclic error chain cannot make it loop forever.
    const MAX_SOURCE_DEPTH: usize = 32;

    /// A status line cut off in the middle of the reason phrase. The server
    /// closes the connection right after sending it, so the client never sees
    /// a complete response.
    const TRUNCATED_RESPONSE: &[u8] = b"HTTP/1.1 200 O";

    /// A truncated response must surface as a transient error whose source
    /// chain contains a hyper "incomplete message" error.
    #[tokio::test]
    async fn hyper_incomplete_message() -> Result<()> {
        let endpoint = spawn_incomplete_message_server().await;
        let creds = Builder::default()
            .with_endpoint(endpoint)
            .build_access_token_credentials()?;
        let token = creds.access_token().await;
        let err = token.expect_err("token request should fail");
        assert!(err.is_transient());
        let hyper_err = as_inner::<hyper::Error, _>(&err).expect("should contain a hyper error");
        assert!(hyper_err.is_incomplete_message(), "{hyper_err:?}");
        Ok(())
    }

    /// A connection reset by the server must surface as a transient error
    /// whose source chain contains an `io::Error` of kind `ConnectionReset`.
    #[tokio::test]
    async fn io_connection_reset() -> Result<()> {
        let endpoint = spawn_connection_reset_server().await;
        let creds = Builder::default()
            .with_endpoint(endpoint)
            .build_access_token_credentials()?;
        let token = creds.access_token().await;
        let err = token.expect_err("token request should fail");
        assert!(err.is_transient());
        let io_err = as_inner::<std::io::Error, _>(&err).expect("should contain an io error");
        assert_eq!(
            io_err.kind(),
            std::io::ErrorKind::ConnectionReset,
            "{io_err:?}"
        );
        Ok(())
    }

    /// Starts a server that answers every request with [`TRUNCATED_RESPONSE`]
    /// and then closes the connection. Returns the server's base URL.
    async fn spawn_incomplete_message_server() -> String {
        spawn_server(|mut socket: TcpStream| async move {
            drain_request_head(&mut socket).await;
            let _ = socket.write_all(TRUNCATED_RESPONSE).await;
            let _ = socket.flush().await;
        })
        .await
    }

    /// Starts a server that reads each request and then drops the connection
    /// with a zero linger timeout, which aborts it with RST instead of a
    /// graceful FIN. Returns the server's base URL.
    async fn spawn_connection_reset_server() -> String {
        spawn_server(|mut socket: TcpStream| async move {
            drain_request_head(&mut socket).await;
            let _ = socket.set_zero_linger();
            // `socket` is dropped at the end of this block, triggering the reset.
        })
        .await
    }

    /// Binds a listener on an ephemeral `127.0.0.1` port and serves every
    /// accepted connection with `handler` on its own task.
    ///
    /// Returns the server's base URL, `http://127.0.0.1:<port>`.
    ///
    /// # Panics
    /// Panics if the listener cannot be bound or its local address cannot be
    /// read.
    async fn spawn_server<F, Fut>(handler: F) -> String
    where
        F: Fn(TcpStream) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = ()> + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind TCP listener");
        let addr = listener.local_addr().expect("Failed to get local address");
        tokio::spawn(async move {
            loop {
                // Failed accepts are ignored so one bad connection does not
                // stop the server from handling later connection attempts.
                if let Ok((socket, _)) = listener.accept().await {
                    tokio::spawn(handler(socket));
                }
            }
        });
        format!("http://{addr}")
    }

    /// Reads from `socket` until the blank line (`\r\n\r\n`) that ends the
    /// HTTP request head. Also stops when the peer closes the connection or a
    /// read fails.
    ///
    /// Consuming the whole request matters. On common TCP stacks (e.g.
    /// Linux), closing a socket whose receive buffer still holds unread data
    /// sends RST instead of FIN. A single short read could therefore turn the
    /// "incomplete message" scenario into a connection reset.
    ///
    /// NOTE(review): assumes the request has no body (the token request is
    /// presumably a `GET`); confirm if the request shape changes.
    async fn drain_request_head(socket: &mut TcpStream) {
        const HEAD_END: &[u8] = b"\r\n\r\n";
        let mut head: Vec<u8> = Vec::new();
        let mut buf = [0u8; 1024];
        while !head.windows(HEAD_END.len()).any(|w| w == HEAD_END) {
            match socket.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => head.extend_from_slice(&buf[..n]),
            }
        }
    }

    /// Walks the `source()` chain of `error` and returns the first link that
    /// downcasts to `T`.
    ///
    /// `error` itself is not checked; the walk starts at `error.source()`. At
    /// most [`MAX_SOURCE_DEPTH`] links are inspected. Returns `None` if no
    /// inspected link is a `T`.
    fn as_inner<T, E>(error: &E) -> Option<&T>
    where
        T: std::error::Error + 'static,
        E: std::error::Error,
    {
        std::iter::successors(error.source(), |&e| e.source())
            .take(MAX_SOURCE_DEPTH)
            .find_map(|e| e.downcast_ref::<T>())
    }
}