#![allow(dead_code)]
use super::{AtpSession, TransferId, TransferProgress};
use crate::channel::mpsc;
use crate::cx::Cx;
use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::net::atp::protocol::{AtpError, AtpOutcome, PlatformError, ProtocolError};
// Unwraps a `Result`, or returns early from the enclosing function with
// `AtpOutcome::Err(mapper(err))`.
//
// `$error_mapper` is any callable that takes the error value (in this file
// always a closure that ignores it and builds an `AtpError`).
macro_rules! try_atp {
    ($expr:expr, $error_mapper:expr) => {
        match $expr {
            Err(err) => return AtpOutcome::Err(($error_mapper)(err)),
            Ok(value) => value,
        }
    };
}
use crate::obligation::graded::{GradedObligation, Resolution};
use futures_lite::Stream;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use std::task::{Context, Poll};
/// Write half of an ATP transfer stream.
///
/// Data is pushed as [`StreamChunk`]s over `data_tx`; progress updates for the
/// transfer arrive on `progress_rx`. Implements [`AsyncWrite`] and yields
/// [`TransferProgress`] values as a [`Stream`].
#[derive(Debug)]
pub struct AtpWriter {
    /// Identifier of the transfer this writer feeds.
    transfer_id: TransferId,
    /// Outbound channel carrying data chunks.
    data_tx: mpsc::Sender<StreamChunk>,
    /// Inbound progress updates for this transfer.
    progress_rx: mpsc::Receiver<TransferProgress>,
    /// Signalled (best-effort `try_send`) on close or drop; `None` once used.
    cancel_tx: Option<mpsc::Sender<()>>,
    /// Committed by `close`, aborted on drop if still held; `None` once resolved.
    obligation: Option<GradedObligation>,
    /// Stream settings. NOTE(review): only `chunk_size` is read in this module.
    config: StreamConfig,
    /// Current lifecycle state; gates which operations are accepted.
    state: WriterState,
}
/// Read half of an ATP transfer stream.
///
/// Data arrives as [`StreamChunk`]s on `data_rx`; progress updates for the
/// transfer arrive on `progress_rx`. Implements [`AsyncRead`] and yields
/// [`TransferProgress`] values as a [`Stream`].
#[derive(Debug)]
pub struct AtpReader {
    /// Identifier of the transfer this reader consumes.
    transfer_id: TransferId,
    /// Inbound channel carrying data chunks.
    data_rx: mpsc::Receiver<StreamChunk>,
    /// Inbound progress updates for this transfer.
    progress_rx: mpsc::Receiver<TransferProgress>,
    /// Signalled (best-effort `try_send`) on close or drop; `None` once used.
    cancel_tx: Option<mpsc::Sender<()>>,
    /// Committed by `close`, aborted on drop if still held; `None` once resolved.
    obligation: Option<GradedObligation>,
    /// Stream settings. NOTE(review): not read by the reader code in this module.
    config: StreamConfig,
    /// Current lifecycle state, including any bytes buffered from a partial read.
    state: ReaderState,
}
/// Tuning knobs for an ATP stream; `Default` provides the stock values.
///
/// NOTE(review): within this module only `chunk_size` is consumed (by
/// `AtpWriter::poll_write`); the remaining fields are carried but not applied here.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct StreamConfig {
    /// Buffer size (default 64 * 1024, presumably bytes — confirm with the consumer).
    pub buffer_size: usize,
    /// Upper bound on bytes placed in one chunk by `AtpWriter::poll_write` (default 8 * 1024).
    pub chunk_size: usize,
    /// Presumably toggles payload compression (default `true`).
    pub enable_compression: bool,
    /// Presumably toggles repair/recovery of damaged chunks (default `false`).
    pub enable_repair: bool,
    /// Backpressure threshold (default 256 * 1024, presumably bytes).
    pub backpressure_threshold: usize,
    /// Per-chunk timeout, in milliseconds per the `_ms` suffix (default 5000).
    pub chunk_timeout_ms: u64,
}
impl Default for StreamConfig {
    /// Stock configuration: 64 KiB buffer, 8 KiB chunks, 256 KiB backpressure
    /// threshold, compression on, repair off, 5000 ms chunk timeout.
    fn default() -> Self {
        const KIB: usize = 1024;
        Self {
            buffer_size: 64 * KIB,
            chunk_size: 8 * KIB,
            backpressure_threshold: 256 * KIB,
            enable_compression: true,
            enable_repair: false,
            chunk_timeout_ms: 5_000,
        }
    }
}
/// One unit of stream payload together with its framing metadata.
#[derive(Debug, Clone)]
pub struct StreamChunk {
    /// Payload bytes.
    pub data: Vec<u8>,
    /// Sequence number. NOTE(review): the writer paths in this module always pass 0.
    pub sequence: u64,
    /// Marks the last chunk of the stream; readers in this module treat it as end of stream.
    pub is_final: bool,
    /// CRC-32 of `data` (via `crc32fast::hash`), set by `new` and checked by `verify`.
    pub checksum: u32,
}
impl StreamChunk {
    /// Builds a chunk and stamps it with the CRC-32 of `data`.
    #[must_use]
    pub fn new(data: Vec<u8>, sequence: u64, is_final: bool) -> Self {
        // Struct-literal fields are evaluated in source order, so the hash
        // borrows `data` before it is moved into the chunk.
        Self {
            checksum: crc32fast::hash(&data),
            data,
            sequence,
            is_final,
        }
    }

    /// Returns `true` when the stored checksum still matches the payload.
    #[must_use]
    pub fn verify(&self) -> bool {
        let recomputed = crc32fast::hash(&self.data);
        recomputed == self.checksum
    }

    /// Number of payload bytes carried by this chunk.
    #[must_use]
    pub fn size(&self) -> usize {
        let payload = &self.data;
        payload.len()
    }
}
/// Lifecycle state of an [`AtpWriter`].
#[derive(Debug, Clone)]
pub enum WriterState {
    /// Idle and accepting writes.
    Ready,
    /// Set by `write_chunk` while sending; writes are still accepted in this state.
    Writing,
    /// Set by `close` before it sends the final chunk; new writes are rejected.
    Flushing,
    /// Set once the writer has been closed or shut down; writes are rejected.
    Closed,
    /// Failed with the given message. NOTE(review): not assigned anywhere in this module.
    Error(String),
}
/// Lifecycle state of an [`AtpReader`].
#[derive(Debug, Clone)]
pub enum ReaderState {
    /// Idle; the next read pulls from the data channel.
    Ready,
    /// Set at the start of `read_chunk` while it pulls from the data channel.
    Reading,
    /// Unconsumed tail of a chunk that did not fit in the caller's buffer.
    Buffering(Vec<u8>),
    /// End of stream: final chunk seen, channel disconnected/cancelled, or `close` called.
    Closed,
    /// Failed with the given message. NOTE(review): not assigned anywhere in this module.
    Error(String),
}
impl AtpSession {
pub async fn create_writer(&self, cx: &Cx, config: StreamConfig) -> AtpOutcome<AtpWriter> {
try_atp!(cx.checkpoint(), |_| AtpError::Platform(
PlatformError::OperatingSystemError
));
let _ = config;
AtpOutcome::Err(AtpError::Protocol(ProtocolError::SessionStateMismatch))
}
pub async fn create_reader(&self, cx: &Cx, config: StreamConfig) -> AtpOutcome<AtpReader> {
try_atp!(cx.checkpoint(), |_| AtpError::Platform(
PlatformError::OperatingSystemError
));
let _ = config;
AtpOutcome::Err(AtpError::Protocol(ProtocolError::SessionStateMismatch))
}
}
impl AtpWriter {
    /// Returns the identifier of the transfer this writer feeds.
    #[must_use]
    pub const fn transfer_id(&self) -> &TransferId {
        &self.transfer_id
    }

    /// Returns the writer's current lifecycle state.
    #[must_use]
    pub const fn state(&self) -> &WriterState {
        &self.state
    }

    /// Waits for the next progress update for this transfer.
    ///
    /// Returns `None` once `recv` on the progress channel reports an error.
    pub async fn next_progress(&mut self, cx: &Cx) -> Option<TransferProgress> {
        self.progress_rx.recv(cx).await.ok()
    }

    /// Finishes the stream.
    ///
    /// Sends an empty final chunk, signals `cancel_tx` and commits the obligation.
    /// Returns `Ok(())` immediately if the writer is already `Closed`.
    ///
    /// # Errors
    /// `PlatformError::OperatingSystemError` if the final chunk cannot be queued.
    /// In that case the writer stays in `Flushing`, so `close` may be retried.
    /// An obligation that is never committed is aborted on drop.
    pub async fn close(&mut self) -> AtpOutcome<()> {
        if matches!(self.state, WriterState::Closed) {
            return AtpOutcome::ok(());
        }
        self.state = WriterState::Flushing;
        let final_chunk = StreamChunk::new(Vec::new(), 0, true);
        try_atp!(self.data_tx.try_send(final_chunk), |_| AtpError::Platform(
            PlatformError::OperatingSystemError
        ));
        if let Some(cancel_tx) = self.cancel_tx.take() {
            // Best-effort signal; a full or closed channel is ignored.
            let _ = cancel_tx.try_send(());
        }
        if let Some(obligation) = self.obligation.take() {
            let _ = obligation.resolve(Resolution::Commit);
        }
        self.state = WriterState::Closed;
        AtpOutcome::ok(())
    }

    /// Queues `data` as one non-final chunk without waiting.
    ///
    /// Empty `data` is accepted and ignored. An empty non-final chunk would reach
    /// the reader's `poll_read` as a zero-byte read, which callers treat as EOF.
    ///
    /// # Errors
    /// `PlatformError::OperatingSystemError` in two cases:
    /// - the writer is not `Ready`/`Writing`;
    /// - the chunk cannot be queued (channel full, disconnected or cancelled).
    ///
    /// After a failed send the writer is back in `Ready`, so the call may be retried.
    pub async fn write_chunk(&mut self, data: Vec<u8>) -> AtpOutcome<()> {
        if !matches!(self.state, WriterState::Ready | WriterState::Writing) {
            return AtpOutcome::Err(AtpError::Platform(PlatformError::OperatingSystemError));
        }
        if data.is_empty() {
            return AtpOutcome::ok(());
        }
        self.state = WriterState::Writing;
        // NOTE(review): every chunk is tagged with sequence 0; no counter is kept.
        let sent = self.data_tx.try_send(StreamChunk::new(data, 0, false));
        // Leave `Writing` regardless of the outcome so the state never sticks.
        self.state = WriterState::Ready;
        try_atp!(sent, |_| AtpError::Platform(
            PlatformError::OperatingSystemError
        ));
        AtpOutcome::ok(())
    }
}
impl AtpReader {
#[must_use]
pub const fn transfer_id(&self) -> &TransferId {
&self.transfer_id
}
#[must_use]
pub const fn state(&self) -> &ReaderState {
&self.state
}
pub async fn next_progress(&mut self) -> Option<TransferProgress> {
self.progress_rx.try_recv().ok()
}
pub async fn read_chunk(&mut self) -> AtpOutcome<Option<StreamChunk>> {
if matches!(self.state, ReaderState::Closed | ReaderState::Error(_)) {
return AtpOutcome::ok(None);
}
self.state = ReaderState::Reading;
match self.data_rx.try_recv() {
Ok(chunk) => {
if chunk.is_final {
self.state = ReaderState::Closed;
} else {
self.state = ReaderState::Ready;
}
AtpOutcome::ok(Some(chunk))
}
Err(mpsc::RecvError::Empty) => {
self.state = ReaderState::Ready;
AtpOutcome::ok(None)
}
Err(mpsc::RecvError::Disconnected | mpsc::RecvError::Cancelled) => {
self.state = ReaderState::Closed;
AtpOutcome::ok(None)
}
}
}
pub async fn read_buffer(&mut self, buf: &mut [u8]) -> AtpOutcome<usize> {
let mut bytes_read = 0;
while bytes_read < buf.len() {
if let ReaderState::Buffering(buffered_data) = &mut self.state {
let to_copy = std::cmp::min(buffered_data.len(), buf.len() - bytes_read);
buf[bytes_read..bytes_read + to_copy].copy_from_slice(&buffered_data[..to_copy]);
buffered_data.drain(..to_copy);
bytes_read += to_copy;
if buffered_data.is_empty() {
self.state = ReaderState::Ready;
}
if bytes_read == buf.len() {
break;
}
}
let chunk_outcome = self.read_chunk().await;
let chunk_option = match chunk_outcome {
AtpOutcome::Ok(v) => v,
AtpOutcome::Err(e) => return AtpOutcome::Err(e),
AtpOutcome::Cancelled(r) => return AtpOutcome::Cancelled(r),
AtpOutcome::Panicked(p) => return AtpOutcome::Panicked(p),
};
match chunk_option {
Some(chunk) => {
let to_copy = std::cmp::min(chunk.data.len(), buf.len() - bytes_read);
buf[bytes_read..bytes_read + to_copy].copy_from_slice(&chunk.data[..to_copy]);
bytes_read += to_copy;
if to_copy < chunk.data.len() {
self.state = ReaderState::Buffering(chunk.data[to_copy..].to_vec());
}
}
None => break, }
}
AtpOutcome::ok(bytes_read)
}
pub async fn close(&mut self) -> AtpOutcome<()> {
if matches!(self.state, ReaderState::Closed) {
return AtpOutcome::ok(());
}
if let Some(cancel_tx) = self.cancel_tx.take() {
let _ = cancel_tx.try_send(()); }
if let Some(obligation) = self.obligation.take() {
let _ = obligation.resolve(Resolution::Commit);
}
self.state = ReaderState::Closed;
AtpOutcome::ok(())
}
}
impl Drop for AtpWriter {
    /// Best-effort cleanup: signals `cancel_tx` and aborts any obligation that
    /// was never resolved.
    fn drop(&mut self) {
        let cancel = self.cancel_tx.take();
        let unresolved = self.obligation.take();
        if let Some(tx) = cancel {
            // Ignore a full or closed channel; there is no one left to report to.
            let _ = tx.try_send(());
        }
        if let Some(pending) = unresolved {
            let _ = pending.resolve(Resolution::Abort);
        }
    }
}
impl Drop for AtpReader {
    /// Best-effort cleanup: signals `cancel_tx` and aborts any obligation that
    /// was never resolved.
    fn drop(&mut self) {
        let cancel = self.cancel_tx.take();
        let unresolved = self.obligation.take();
        if let Some(tx) = cancel {
            // Ignore a full or closed channel; there is no one left to report to.
            let _ = tx.try_send(());
        }
        if let Some(pending) = unresolved {
            let _ = pending.resolve(Resolution::Abort);
        }
    }
}
impl AsyncWrite for AtpWriter {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
if !matches!(self.state, WriterState::Ready | WriterState::Writing) {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"Writer is not ready",
)));
}
let chunk_size = std::cmp::min(buf.len(), self.config.chunk_size);
let data = buf[..chunk_size].to_vec();
match self.data_tx.try_send(StreamChunk::new(data, 0, false)) {
Ok(()) => Poll::Ready(Ok(chunk_size)),
Err(mpsc::SendError::Full(_)) => {
cx.waker().wake_by_ref();
Poll::Pending
}
Err(mpsc::SendError::Disconnected(_)) => Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"Channel closed",
))),
Err(mpsc::SendError::Cancelled(_)) => Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Interrupted,
"Channel cancelled",
))),
}
}
fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.state = WriterState::Ready;
Poll::Ready(Ok(()))
}
fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.state = WriterState::Closed;
Poll::Ready(Ok(()))
}
}
impl AsyncRead for AtpReader {
    /// Copies the next available bytes into `buf`.
    ///
    /// Bytes are served in this order:
    /// 1. any bytes buffered from a previous partial read;
    /// 2. otherwise, the next queued chunk.
    ///
    /// Returning `Ready(Ok(()))` with nothing written signals EOF. That happens
    /// once the reader is `Closed`/`Error`, a final chunk has been fully delivered,
    /// or the channel is disconnected/cancelled.
    ///
    /// Empty non-final chunks are skipped, because returning them would look like
    /// EOF to the caller. Surplus bytes of a chunk that does not fit are kept in
    /// `ReaderState::Buffering`.
    ///
    /// NOTE(review): when nothing is queued this wakes itself immediately and
    /// returns `Pending`, i.e. it busy-polls rather than waiting on the channel.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        if matches!(self.state, ReaderState::Closed | ReaderState::Error(_)) {
            return Poll::Ready(Ok(()));
        }
        if let ReaderState::Buffering(buffered_data) = &mut self.state {
            let to_copy = std::cmp::min(buffered_data.len(), buf.remaining());
            buf.put_slice(&buffered_data[..to_copy]);
            buffered_data.drain(..to_copy);
            if buffered_data.is_empty() {
                self.state = ReaderState::Ready;
            }
            return Poll::Ready(Ok(()));
        }
        loop {
            match self.data_rx.try_recv() {
                Ok(chunk) if chunk.data.is_empty() && !chunk.is_final => {
                    // Carries no bytes and does not end the stream: skip it.
                    continue;
                }
                Ok(chunk) => {
                    let to_copy = std::cmp::min(chunk.data.len(), buf.remaining());
                    buf.put_slice(&chunk.data[..to_copy]);
                    if to_copy < chunk.data.len() {
                        self.state = ReaderState::Buffering(chunk.data[to_copy..].to_vec());
                    } else if chunk.is_final {
                        self.state = ReaderState::Closed;
                    }
                    return Poll::Ready(Ok(()));
                }
                Err(mpsc::RecvError::Empty) => {
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                }
                Err(mpsc::RecvError::Disconnected | mpsc::RecvError::Cancelled) => {
                    self.state = ReaderState::Closed;
                    return Poll::Ready(Ok(()));
                }
            }
        }
    }
}
impl Stream for AtpWriter {
    type Item = TransferProgress;

    /// Yields queued progress updates.
    ///
    /// The stream ends (`None`) once the progress channel reports `Disconnected`
    /// or `Cancelled`.
    ///
    /// NOTE(review): with nothing queued this self-wakes and returns `Pending`,
    /// i.e. it busy-polls.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let next = self.progress_rx.try_recv();
        if matches!(next, Err(mpsc::RecvError::Empty)) {
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        // Any other error terminates the stream.
        Poll::Ready(next.ok())
    }
}
impl Stream for AtpReader {
    type Item = TransferProgress;

    /// Yields queued progress updates.
    ///
    /// The stream ends (`None`) once the progress channel reports `Disconnected`
    /// or `Cancelled`.
    ///
    /// NOTE(review): with nothing queued this self-wakes and returns `Pending`,
    /// i.e. it busy-polls.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let next = self.progress_rx.try_recv();
        if matches!(next, Err(mpsc::RecvError::Empty)) {
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        // Any other error terminates the stream.
        Poll::Ready(next.ok())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::net::atp::protocol::{
        CapabilityAction, CapabilityGrant, CapabilityGrantId, CapabilityScope, PeerId,
        SessionContextKind,
    };
    use crate::net::atp::sdk::{AtpSdk, SessionConfig, SessionOptions};
    use futures_lite::future::block_on;

    /// Direct-session options for `peer` holding one Read+Write grant, scoped to
    /// direct sessions, whose id is derived from `label`.
    fn granted_direct_options(config: &SessionConfig, peer: PeerId, label: &str) -> SessionOptions {
        let options = SessionOptions::direct(peer);
        let grant = CapabilityGrant::new(
            CapabilityGrantId::from_label(label),
            peer,
            config.local_peer,
            [CapabilityAction::Read, CapabilityAction::Write],
            CapabilityScope::for_context(SessionContextKind::Direct),
        );
        options.with_grants(vec![grant])
    }

    /// Panics unless `outcome` is the fail-closed `SessionStateMismatch` error.
    fn assert_missing_stream_transport<T: std::fmt::Debug>(outcome: AtpOutcome<T>) {
        let failed_closed = matches!(
            outcome,
            AtpOutcome::Err(AtpError::Protocol(ProtocolError::SessionStateMismatch))
        );
        assert!(
            failed_closed,
            "stream setup must fail closed without transport: {outcome:?}"
        );
    }

    #[test]
    fn stream_chunk_creation() {
        let payload = b"test data".to_vec();
        let chunk = StreamChunk::new(payload.clone(), 42, false);
        assert_eq!(chunk.data, payload);
        assert_eq!(chunk.sequence, 42);
        assert!(!chunk.is_final);
        assert!(chunk.verify());

        // Corrupt the first byte; the stored checksum must no longer match.
        let mut corrupted = chunk.clone();
        corrupted.data[0] = 0xFF;
        assert!(!corrupted.verify());
    }

    /// Generates a test that opens a granted direct session and asserts that
    /// `$factory` (`create_writer` / `create_reader`) fails closed.
    ///
    /// `$complete_name` is forwarded verbatim to `crate::test_complete!`.
    macro_rules! fail_closed_setup_test {
        ($test_fn:ident, $complete_name:tt, $grant_label:tt, $factory:ident) => {
            #[test]
            fn $test_fn() {
                crate::test_utils::init_test_logging();
                let cx = crate::cx::Cx::for_testing();
                block_on(async {
                    let sdk = AtpSdk::new_in_process(SessionConfig::default());
                    let peer = PeerId::from_label("test_peer");
                    let options =
                        granted_direct_options(&SessionConfig::default(), peer, $grant_label);
                    let session = sdk.open_session(&cx, options).await.unwrap();
                    assert_missing_stream_transport(
                        session.$factory(&cx, StreamConfig::default()).await,
                    );
                });
                crate::test_complete!($complete_name);
            }
        };
    }

    fail_closed_setup_test!(
        atp_writer_creation,
        "atp_writer_creation",
        "writer-open",
        create_writer
    );
    fail_closed_setup_test!(
        atp_reader_creation,
        "atp_reader_creation",
        "reader-open",
        create_reader
    );
    fail_closed_setup_test!(
        writer_chunk_operations,
        "writer_chunk_operations",
        "writer-chunk",
        create_writer
    );
    fail_closed_setup_test!(
        reader_chunk_operations,
        "reader_chunk_operations",
        "reader-chunk",
        create_reader
    );
    fail_closed_setup_test!(
        async_write_interface,
        "async_write_interface",
        "async-write",
        create_writer
    );
    fail_closed_setup_test!(
        async_read_interface,
        "async_read_interface",
        "async-read",
        create_reader
    );
}