use super::active_read::ActiveRead;
use super::range_reader::RangeReader;
use crate::model::Object;
use crate::model_ext::{ReadRange, RequestedRange};
use crate::read_object::ReadObjectResponse;
use crate::stub::ObjectDescriptor;
use crate::{Error, Result};
use http::HeaderMap;
use std::sync::Arc;
use tokio::sync::mpsc::Sender;
/// Transport that backs an [`ObjectDescriptor`] with a bidi streaming read.
///
/// Created by [`ObjectDescriptorTransport::new`]. It keeps the object metadata and
/// headers from the initial exchange, plus the sending half of the queue that hands
/// new range reads to the background worker.
#[derive(Debug)]
pub struct ObjectDescriptorTransport {
    /// Object metadata taken from the first response; each range reader gets a clone
    /// of this `Arc`.
    object: Arc<Object>,
    /// Headers returned by the connector when the stream was opened.
    headers: HeaderMap,
    /// Sender for new range reads; the spawned worker owns the matching receiver.
    tx: Sender<ActiveRead>,
}
impl ObjectDescriptorTransport {
    /// Opens a bidi read session and starts the background worker.
    ///
    /// Every entry of `ranges` is sent with the initial request. The range at index
    /// `i` is tagged with read id `i`. On success, returns the transport together
    /// with one [`ReadObjectResponse`] per entry of `ranges`, in the same order.
    ///
    /// # Errors
    ///
    /// Returns an error in three cases:
    /// - the connector fails to open the stream;
    /// - the first response carries no object metadata (a deserialization error);
    /// - the worker fails to process that first response (reported as an I/O error).
    pub async fn new<T>(
        mut connector: super::connector::Connector<T>,
        ranges: Vec<ReadRange>,
    ) -> Result<(Self, Vec<ReadObjectResponse>)>
    where
        T: super::Client + Clone + Sync,
        <T as super::Client>::Stream: super::TonicStreaming + Send + Sync,
    {
        use gaxi::prost::FromProto;

        // Queue for range reads added after construction; `rx` is handed to the worker.
        let (tx, rx) = tokio::sync::mpsc::channel(100);

        let requested: Vec<RequestedRange> = ranges.into_iter().map(|range| range.0).collect();
        let mut proto_ranges = Vec::with_capacity(requested.len());
        for (read_id, range) in requested.iter().enumerate() {
            proto_ranges.push(range.as_proto(read_id as i64));
        }

        let (mut initial, headers, connection) = connector.connect(proto_ranges).await?;
        let metadata = initial.metadata.take().ok_or_else(|| {
            Error::deser("initial response in bidi read must contain object metadata")
        })?;
        let object: Arc<Object> = Arc::new(
            FromProto::cnv(metadata).expect("transforming from proto Object never fails"),
        );

        let (active, readers) = Self::map_ranges(requested, &tx, &object);
        let mut worker = super::worker::Worker::new(connector, active);
        // Let the worker process the first response before it starts consuming the stream.
        worker
            .handle_response_success(initial)
            .await
            .map_err(Error::io)?;

        // The worker runs detached: dropping a tokio JoinHandle does not cancel the task.
        let _worker_task = tokio::spawn(worker.run(connection, rx));

        let transport = Self {
            object,
            headers,
            tx,
        };
        Ok((transport, readers))
    }

    /// Pairs each requested range with its own data channel.
    ///
    /// For each range, returns the worker-side [`ActiveRead`] (which holds the
    /// sender) and the caller-side reader (which holds the receiver, a clone of
    /// `object`, and a clone of `requests`). Both output vectors keep the order
    /// of `ranges`.
    fn map_ranges(
        ranges: Vec<RequestedRange>,
        requests: &Sender<ActiveRead>,
        object: &Arc<Object>,
    ) -> (Vec<ActiveRead>, Vec<ReadObjectResponse>) {
        let mut active = Vec::with_capacity(ranges.len());
        let mut readers = Vec::with_capacity(ranges.len());
        for range in ranges {
            let (data_tx, data_rx) = tokio::sync::mpsc::channel(100);
            active.push(ActiveRead::new(data_tx, range));
            let reader = RangeReader::new(data_rx, Arc::clone(object), requests.clone());
            readers.push(ReadObjectResponse::new(Box::new(reader)));
        }
        (active, readers)
    }
}
impl ObjectDescriptor for ObjectDescriptorTransport {
fn object(&self) -> Object {
self.object.as_ref().clone()
}
async fn read_range(&self, range: ReadRange) -> ReadObjectResponse {
let (tx, rx) = tokio::sync::mpsc::channel(100);
let range = ActiveRead::new(tx, range.0);
self.tx
.send(range)
.await
.expect("worker never exits while ObjectDescriptor is live");
ReadObjectResponse::new(Box::new(RangeReader::new(
rx,
self.object.clone(),
self.tx.clone(),
)))
}
fn headers(&self) -> HeaderMap {
self.headers.clone()
}
}
#[cfg(test)]
mod tests {
    use super::super::mocks::{MockTestClient, mock_connector};
    use super::*;
    use crate::error::ReadError;
    use crate::google::storage::v2::{
        BidiReadHandle, BidiReadObjectResponse, ChecksummedData, Object as ProtoObject,
        ObjectRangeData, ReadRange as ProtoRange,
    };
    use crate::storage::bidi::tests::{permanent_error, proto_range};
    use gaxi::grpc::tonic::{Response as TonicResponse, Result as TonicResult, Status};

    /// Proto object metadata returned by the mocked service.
    fn proto_object() -> ProtoObject {
        ProtoObject {
            bucket: "projects/_/buckets/test-bucket".into(),
            name: "test-object".into(),
            generation: 123456,
            ..ProtoObject::default()
        }
    }

    /// The model object expected after converting `proto_object()`.
    fn want_object() -> Object {
        Object::new()
            .set_bucket("projects/_/buckets/test-bucket")
            .set_name("test-object")
            .set_generation(123456)
    }

    /// First response of a bidi read: the given metadata plus a fixed read handle.
    fn initial_response(metadata: Option<ProtoObject>) -> BidiReadObjectResponse {
        BidiReadObjectResponse {
            metadata,
            read_handle: Some(BidiReadHandle {
                handle: bytes::Bytes::from_static(b"test-read-handle"),
            }),
            ..BidiReadObjectResponse::default()
        }
    }

    /// Asserts that `err` wraps a `ReadError::UnrecoverableBidiReadInterrupt` carrying a status.
    fn assert_unrecoverable_interrupt(err: &Error) {
        use std::error::Error as _;
        let source = err.source().and_then(|e| e.downcast_ref::<ReadError>());
        assert!(
            matches!(
                source,
                Some(ReadError::UnrecoverableBidiReadInterrupt(e)) if e.status().is_some()
            ),
            "{err:?}"
        );
    }

    #[tokio::test]
    async fn success() -> anyhow::Result<()> {
        const LEN: i64 = 42;
        let (connect_tx, connect_rx) =
            tokio::sync::mpsc::channel::<TonicResult<BidiReadObjectResponse>>(8);
        connect_tx
            .send(Ok(initial_response(Some(proto_object()))))
            .await?;
        let connect_stream = TonicResponse::from(connect_rx);

        // Capture the request stream so the test can inspect what the transport sends.
        let receivers = Arc::new(std::sync::Mutex::new(Vec::new()));
        let save = receivers.clone();
        let mut mock = MockTestClient::new();
        mock.expect_start().return_once(move |_, _, rx, _, _, _| {
            save.lock().expect("never poisoned").push(rx);
            Ok(Ok(connect_stream))
        });
        let connector = mock_connector(mock);
        let (transport, _) = ObjectDescriptorTransport::new(connector, Vec::new()).await?;
        assert_eq!(transport.object(), want_object(), "{transport:?}");

        let mut connect_rx = {
            let mut guard = receivers.lock().expect("never poisoned");
            let rx = guard.pop().expect("at least one receiver");
            assert!(guard.is_empty(), "{receivers:?}");
            rx
        };
        let request = connect_rx.recv().await.expect("the initial request");
        assert!(request.read_object_spec.is_some(), "{request:?}");

        let mut reader = transport.read_range(ReadRange::segment(100, 200)).await;
        let request = connect_rx.recv().await.expect("the read request");
        let range_request = request.read_ranges.first();
        assert_eq!(range_request, Some(&proto_range(100, 200)), "{request:?}");

        let content = bytes::Bytes::from_owner("x".repeat(LEN as usize));
        let response = BidiReadObjectResponse {
            object_data_ranges: vec![ObjectRangeData {
                checksummed_data: Some(ChecksummedData {
                    content: content.clone(),
                    ..ChecksummedData::default()
                }),
                read_range: Some(ProtoRange {
                    read_offset: 100,
                    read_length: LEN,
                    read_id: 0,
                }),
                range_end: true,
            }],
            ..BidiReadObjectResponse::default()
        };
        connect_tx.send(Ok(response)).await?;

        let got = reader.next().await.transpose()?;
        assert_eq!(got, Some(content));
        let got = reader.next().await.transpose()?;
        assert!(got.is_none(), "{got:?}");
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn read_range_error() -> anyhow::Result<()> {
        let (connect_tx, connect_rx) =
            tokio::sync::mpsc::channel::<TonicResult<BidiReadObjectResponse>>(8);
        connect_tx
            .send(Ok(initial_response(Some(proto_object()))))
            .await?;
        let connect_stream = TonicResponse::from(connect_rx);
        let mut mock = MockTestClient::new();
        mock.expect_start()
            .times(1)
            .return_once(move |_, _, _, _, _, _| Ok(Ok(connect_stream)));
        let connector = mock_connector(mock);
        let (transport, _) = ObjectDescriptorTransport::new(connector, Vec::new()).await?;
        assert_eq!(transport.object(), want_object(), "{transport:?}");

        // A permanent stream error terminates the in-flight range with an error, then EOF.
        let mut existing = transport.read_range(ReadRange::segment(100, 200)).await;
        connect_tx
            .send(Err(Status::permission_denied("uh-oh")))
            .await?;
        let err = existing.next().await.transpose().unwrap_err();
        assert_unrecoverable_interrupt(&err);
        let got = existing.next().await;
        assert!(got.is_none(), "{got:?}");

        // Ranges requested after the failure report the same error.
        drop(connect_tx);
        let mut reader = transport.read_range(ReadRange::segment(100, 200)).await;
        let err = reader.next().await.transpose().unwrap_err();
        assert_unrecoverable_interrupt(&err);
        let got = reader.next().await;
        assert!(got.is_none(), "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn connect_error() -> anyhow::Result<()> {
        let mut mock = MockTestClient::new();
        mock.expect_start()
            .return_once(move |_, _, _, _, _, _| Err(permanent_error()));
        let connector = mock_connector(mock);
        let err = ObjectDescriptorTransport::new(connector, Vec::new())
            .await
            .unwrap_err();
        assert_eq!(err.status(), permanent_error().status(), "{err:?}");
        Ok(())
    }

    #[tokio::test]
    async fn deser_error() -> anyhow::Result<()> {
        let (connect_tx, connect_rx) =
            tokio::sync::mpsc::channel::<TonicResult<BidiReadObjectResponse>>(8);
        // The first response lacks object metadata, which must be rejected.
        connect_tx.send(Ok(initial_response(None))).await?;
        let connect_stream = TonicResponse::from(connect_rx);
        let mut mock = MockTestClient::new();
        mock.expect_start()
            .return_once(move |_, _, _, _, _, _| Ok(Ok(connect_stream)));
        let connector = mock_connector(mock);
        let err = ObjectDescriptorTransport::new(connector, Vec::new())
            .await
            .unwrap_err();
        assert!(err.is_deserialization(), "{err:?}");
        Ok(())
    }
}