use crate::ddp::{DdpHandle, DdpSocket};
use std::collections::HashMap;
use std::io;
use tailtalk_packets::{
atp::{AtpFunction, AtpPacket},
ddp::{DdpPacket, DdpProtocolType},
};
use tokio::sync::{mpsc, oneshot};
/// Largest data payload, in bytes, that this module places after the ATP
/// header in a single packet. Requests carrying more are rejected and
/// responses are split into chunks of at most this size.
pub const ATP_MAX_DATA_PER_PACKET: usize = 578;
/// One-shot channel on which the outcome of a request is delivered: either
/// the reassembled response data together with the 4 user bytes, or an
/// I/O error (for example after the retransmit limit is reached).
type AtpResponseChannel = oneshot::Sender<Result<(Vec<u8>, [u8; 4]), io::Error>>;
/// Bookkeeping for one outgoing request that is still waiting for its
/// complete response.
pub struct PendingRequestState {
/// Channel on which the final result (data or error) is reported.
pub chan: AtpResponseChannel,
/// When true, a Release packet is sent once the response is complete.
pub xo: bool,
/// Response payloads received so far, keyed by packet sequence number.
pub received_packets: std::collections::BTreeMap<u8, Vec<u8>>,
/// User bytes taken from the first response packet that arrived.
pub user_bytes: Option<[u8; 4]>,
/// Sequence number of the response packet that carried the EOM flag, if seen.
pub eom_seq: Option<u8>,
/// Serialized request (ATP header plus data), kept for retransmission.
pub raw_packet: Vec<u8>,
/// Address the request was sent to; retransmits go to the same address.
pub destination: AtpAddress,
/// Number of retransmit ticks this transaction has been through.
pub retry_count: u8,
}
/// Outstanding outgoing requests, keyed by transaction ID.
type AtpTransactionMap = HashMap<u16, PendingRequestState>;
/// Network, node and socket numbers identifying an ATP peer. The actor
/// converts this into a DDP address whenever it sends a packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AtpAddress {
/// AppleTalk network number.
pub network_number: u16,
/// Node number within the network.
pub node_number: u8,
/// Socket number on the node.
pub socket_number: u8,
}
/// Command payload asking the actor to send a request and report its outcome.
#[derive(Debug)]
pub struct AtpSendRequest {
/// Destination of the request.
pub address: AtpAddress,
/// User bytes placed in the ATP header.
pub user_bytes: [u8; 4],
/// Request data; must not exceed `ATP_MAX_DATA_PER_PACKET` bytes.
pub data: Vec<u8>,
/// Channel that receives the response or an error.
pub chan: AtpResponseChannel,
}
/// Contents of one response packet: its data and the user bytes for its header.
#[derive(Debug)]
pub struct AtpResponse {
/// Data carried after the ATP header.
pub data: Vec<u8>,
/// User bytes placed in this packet's ATP header.
pub user_bytes: [u8; 4],
}
/// Command payload asking the actor to send the packets of a response.
#[derive(Debug)]
pub struct AtpSendResponse {
/// Peer that issued the request.
pub destination: AtpAddress,
/// Transaction ID of the request being answered.
pub tid: u16,
/// Response packets in order; the index in this list becomes the packet's
/// sequence number and the last one is sent with the EOM flag.
pub packets: Vec<AtpResponse>,
}
/// Command payload asking the actor to send a Release packet for a transaction.
#[derive(Debug)]
pub struct AtpSendRelease {
/// Peer the Release is sent to.
pub destination: AtpAddress,
/// Transaction ID being released.
pub tid: u16,
}
/// Command payload asking the actor to send a data-less request with the XO
/// flag cleared; the actor does not track a response for it.
/// NOTE(review): "ALO" presumably stands for at-least-once — confirm.
#[derive(Debug)]
pub struct AtpSendAlo {
/// Destination of the request.
pub address: AtpAddress,
/// User bytes placed in the ATP header.
pub user_bytes: [u8; 4],
}
/// Commands accepted by the ATP actor over its command channel.
pub enum AtpCommand {
/// Send a request and track its response.
SendRequest(AtpSendRequest),
/// Send the packets of a response to a received request.
SendResponse(AtpSendResponse),
/// Send a Release packet for a transaction.
SendRelease(AtpSendRelease),
/// Send an untracked, data-less request with the XO flag cleared.
SendAlo(AtpSendAlo),
}
/// A request received from a peer, handed to the application together with
/// the means to answer it.
pub struct AtpReceivedRequest {
/// Transaction ID from the request's ATP header.
pub transaction_id: u16,
/// DDP source address of the request; responses are sent back here.
pub source: AtpAddress,
/// User bytes from the request's ATP header.
pub user_bytes: [u8; 4],
/// Request payload following the ATP header (empty if there was none).
pub data: Vec<u8>,
/// Clone of the actor's command sender, used to queue the response.
pub response_sender: mpsc::Sender<AtpCommand>,
/// Present only for XO requests; signalled when a Release packet with the
/// same source and transaction ID is received.
pub release_rx: Option<oneshot::Receiver<()>>,
/// Bitmap byte from the request header, used to limit the response size.
pub bitmap: u8,
}
impl AtpReceivedRequest {
    /// Number of response packets the requester's bitmap allows (1 to 8).
    ///
    /// A bitmap of zero is treated as `0xFF`, i.e. all eight packets.
    fn max_response_packets(&self) -> usize {
        let effective_bitmap = if self.bitmap == 0x00 { 0xFF } else { self.bitmap };
        (effective_bitmap.count_ones() as usize).clamp(1, 8)
    }

    /// Maximum number of response bytes that fit in the packets allowed by
    /// the requester's bitmap, at `ATP_MAX_DATA_PER_PACKET` bytes per packet.
    pub fn max_response_bytes(&self) -> usize {
        self.max_response_packets() * ATP_MAX_DATA_PER_PACKET
    }

    /// Queues `data` as the response, split into packets of at most
    /// `ATP_MAX_DATA_PER_PACKET` bytes, each carrying `user_bytes`.
    ///
    /// Data beyond [`max_response_bytes`](Self::max_response_bytes) is
    /// dropped with a warning. Empty data still produces one empty packet.
    ///
    /// # Errors
    /// Returns an error if the actor's command channel is closed.
    pub async fn send_response(&self, data: Vec<u8>, user_bytes: [u8; 4]) -> Result<(), io::Error> {
        self.send_response_chunked(data, user_bytes, ATP_MAX_DATA_PER_PACKET)
            .await
    }

    /// Queues `data` as the response, split into packets of at most
    /// `chunk_size` bytes, each carrying `user_bytes`.
    ///
    /// At most as many packets as the requester's bitmap allows are produced;
    /// any excess data is dropped with a warning. Empty data still produces
    /// one empty packet. Chunks that, together with the ATP header, do not fit
    /// the actor's 600-byte send buffer are dropped by the actor with only an
    /// error log, so `chunk_size` should not exceed `ATP_MAX_DATA_PER_PACKET`.
    ///
    /// # Panics
    /// Panics if `chunk_size` is zero.
    ///
    /// # Errors
    /// Returns an error if the actor's command channel is closed.
    pub async fn send_response_chunked(
        &self,
        data: Vec<u8>,
        user_bytes: [u8; 4],
        chunk_size: usize,
    ) -> Result<(), io::Error> {
        assert!(chunk_size > 0, "chunk_size must be positive");
        let max_packets = self.max_response_packets();
        // Saturate so an absurdly large chunk_size cannot overflow the product.
        let max_data = max_packets.saturating_mul(chunk_size);
        if data.len() > max_data {
            tracing::warn!(
                "ATP response truncated: {} bytes requested but client bitmap 0x{:02x} only allows {} bytes ({} packets)",
                data.len(),
                self.bitmap,
                max_data,
                max_packets
            );
        }
        let sendable = &data[..data.len().min(max_data)];
        let mut packets: Vec<AtpResponse> = sendable
            .chunks(chunk_size)
            .map(|chunk| AtpResponse {
                data: chunk.to_vec(),
                user_bytes,
            })
            .collect();
        // An empty response must still be answered with one (empty) packet.
        if packets.is_empty() {
            packets.push(AtpResponse {
                data: Vec::new(),
                user_bytes,
            });
        }
        self.send_response_internal(packets).await
    }

    /// Hands the prepared packets to the actor as a `SendResponse` command
    /// addressed to the request's source and transaction ID.
    async fn send_response_internal(&self, packets: Vec<AtpResponse>) -> Result<(), io::Error> {
        if packets.len() > 8 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "cannot send more than 8 response packets",
            ));
        }
        let cmd = AtpCommand::SendResponse(AtpSendResponse {
            destination: self.source,
            tid: self.transaction_id,
            packets,
        });
        self.response_sender
            .send(cmd)
            .await
            .map_err(io::Error::other)
    }
}
/// Cloneable handle for issuing outgoing requests through the ATP actor.
#[derive(Clone, Debug)]
pub struct AtpRequestor {
/// Sender side of the actor's command channel.
pub cmd_tx: mpsc::Sender<AtpCommand>,
/// Socket number the actor's DDP socket is bound to.
pub socket_number: u8,
}
impl AtpRequestor {
    /// Asks the actor to send a data-less, untracked request carrying
    /// `user_bytes` to `address`.
    ///
    /// # Errors
    /// Returns an error if the command could not be queued to the actor.
    pub async fn send_alo(
        &self,
        address: AtpAddress,
        user_bytes: [u8; 4],
    ) -> Result<(), io::Error> {
        let alo = AtpSendAlo { address, user_bytes };
        match self.cmd_tx.send(AtpCommand::SendAlo(alo)).await {
            Ok(()) => Ok(()),
            Err(send_err) => Err(io::Error::other(send_err)),
        }
    }

    /// Sends a request to `address` and waits for the reassembled response.
    ///
    /// On success returns the response data and the response's user bytes.
    ///
    /// # Errors
    /// Returns an error if the command could not be queued, if the actor
    /// dropped the reply channel, or if the actor reported a failure for the
    /// transaction.
    pub async fn send_request(
        &self,
        address: AtpAddress,
        user_bytes: [u8; 4],
        data: Vec<u8>,
    ) -> Result<(Vec<u8>, [u8; 4]), io::Error> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let request = AtpSendRequest {
            address,
            user_bytes,
            data,
            chan: reply_tx,
        };
        self.cmd_tx
            .send(AtpCommand::SendRequest(request))
            .await
            .map_err(io::Error::other)?;
        match reply_rx.await {
            Ok(outcome) => outcome,
            Err(dropped) => Err(io::Error::other(dropped)),
        }
    }
}
/// Receiving handle for requests that peers send to the actor's socket.
#[derive(Debug)]
pub struct AtpResponder {
/// Queue of received requests filled by the actor.
pub incoming_rx: mpsc::Receiver<AtpReceivedRequest>,
}
impl AtpResponder {
    /// Waits for the next received request.
    ///
    /// Yields whatever the underlying channel's `recv` yields, so `None`
    /// means the channel has no more requests to deliver.
    pub async fn next(&mut self) -> Option<AtpReceivedRequest> {
        let Self { incoming_rx } = self;
        incoming_rx.recv().await
    }
}
/// State owned by the ATP actor task.
pub struct Atp {
/// DDP socket used for all sends and receives.
sock: DdpSocket,
/// Commands arriving from requestors and from received-request handles.
request_recv: mpsc::Receiver<AtpCommand>,
/// Queue on which received requests are handed to the responder.
incoming_req_tx: mpsc::Sender<AtpReceivedRequest>,
/// Command sender cloned into every received request so it can respond.
cmd_tx: mpsc::Sender<AtpCommand>,
/// Outgoing requests still awaiting a complete response, by transaction ID.
pending_transactions: AtpTransactionMap,
/// Release notifiers for received XO requests, keyed by (source, transaction ID).
pending_releases: HashMap<(AtpAddress, u16), oneshot::Sender<()>>,
/// Next transaction ID to try for an outgoing request.
next_tid: u16,
}
impl Atp {
/// Creates an ATP socket on `ddp`, starts the actor on a new tokio task and
/// returns the bound socket number together with the requestor and
/// responder handles.
///
/// `socket_number` is passed through to the DDP layer when creating the
/// socket.
///
/// # Panics
/// Panics if the DDP socket cannot be created.
pub async fn spawn(
    ddp: &DdpHandle,
    socket_number: Option<u8>,
) -> (u8, AtpRequestor, AtpResponder) {
    let sock = ddp
        .new_sock(DdpProtocolType::Atp, socket_number)
        .await
        .expect("failed to create ATP sock");
    let bound_socket = sock.socket_num();

    let (cmd_tx, request_recv) = mpsc::channel(100);
    let (incoming_req_tx, incoming_rx) = mpsc::channel(32);

    // Build the caller-facing handles before the actor takes ownership of
    // its halves of the channels.
    let requestor = AtpRequestor {
        cmd_tx: cmd_tx.clone(),
        socket_number: bound_socket,
    };
    let responder = AtpResponder { incoming_rx };

    let actor = Atp {
        sock,
        request_recv,
        incoming_req_tx,
        cmd_tx,
        pending_transactions: HashMap::new(),
        pending_releases: HashMap::new(),
        next_tid: 1,
    };
    tokio::spawn(async move {
        tracing::debug!("ATP actor starting");
        actor.run().await;
        tracing::debug!("ATP actor stopped");
    });

    (bound_socket, requestor, responder)
}
/// Actor main loop: services received packets, queued commands and the
/// periodic retransmit timer until the socket fails or the command channel
/// closes.
async fn run(mut self) {
    let mut retry_interval = tokio::time::interval(tokio::time::Duration::from_secs(2));
    // The handlers below are awaited inline, so a slow send can make the loop
    // miss ticks. With the default burst behaviour the missed ticks would then
    // fire back-to-back and burn several retries at once; `Delay` keeps
    // retransmits at least one full period apart.
    retry_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    // The first tick completes immediately; consume it so the first
    // retransmit pass happens one full period from now.
    retry_interval.tick().await;
    loop {
        tokio::select! {
            sock_recv = self.sock.recv() => {
                match sock_recv {
                    Ok(mut pkt) => {
                        self.handle_packet(pkt.headers, &mut pkt.payload).await;
                    },
                    Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                        tracing::debug!("ATP socket closed, shutting down");
                        break;
                    },
                    Err(e) => {
                        tracing::error!("ATP socket error: {}", e);
                        break;
                    },
                }
            },
            req = self.request_recv.recv() => {
                if let Some(command) = req {
                    match command {
                        AtpCommand::SendRequest(req) => self.handle_send_request(req).await,
                        AtpCommand::SendResponse(resp) => self.handle_send_response(resp).await,
                        AtpCommand::SendRelease(rel) => self.handle_send_release(rel).await,
                        AtpCommand::SendAlo(alo) => self.handle_send_alo(alo).await,
                    }
                } else {
                    tracing::info!("ATP command channel closed");
                    break;
                }
            }
            _ = retry_interval.tick() => {
                self.retransmit_pending().await;
            }
        }
    }
}
async fn retransmit_pending(&mut self) {
if self.pending_transactions.is_empty() {
return;
}
const MAX_RETRIES: u8 = 8;
let mut to_evict: Vec<u16> = Vec::new();
let mut to_resend: Vec<(u16, Vec<u8>, AtpAddress)> = Vec::new();
for (tid, state) in &mut self.pending_transactions {
state.retry_count += 1;
if state.retry_count > MAX_RETRIES {
to_evict.push(*tid);
} else {
to_resend.push((*tid, state.raw_packet.clone(), state.destination));
}
}
for tid in to_evict {
if let Some(state) = self.pending_transactions.remove(&tid) {
tracing::warn!("ATP: TID {} got no response after {} retransmits, giving up", tid, MAX_RETRIES);
let _ = state.chan.send(Err(io::Error::new(
io::ErrorKind::TimedOut,
"ATP: no response after maximum retransmits",
)));
}
}
for (tid, packet, dest_addr) in to_resend {
let dest = crate::ddp::DdpAddress::new(
tailtalk_packets::aarp::AppleTalkAddress {
network_number: dest_addr.network_number,
node_number: dest_addr.node_number,
},
dest_addr.socket_number,
);
if let Err(e) = self.sock.send_to(&packet, dest).await {
tracing::warn!("ATP retransmit failed for TID {}: {}", tid, e);
} else {
tracing::debug!("ATP retransmitting TID {}", tid);
}
}
}
/// Sends an XO request and registers it as a pending transaction.
///
/// The request is rejected through `req.chan` when the data exceeds
/// `ATP_MAX_DATA_PER_PACKET`, when every transaction ID is already in use,
/// or when the socket send fails. Otherwise the serialized packet is kept so
/// it can be retransmitted until a complete response arrives.
async fn handle_send_request(&mut self, req: AtpSendRequest) {
    // Validate before touching any state so an invalid request does not
    // consume a transaction ID.
    if req.data.len() > ATP_MAX_DATA_PER_PACKET {
        tracing::error!(
            "ATP request data too large: {} (max {})",
            req.data.len(),
            ATP_MAX_DATA_PER_PACKET
        );
        let _ = req.chan.send(Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("data too large (max {})", ATP_MAX_DATA_PER_PACKET),
        )));
        return;
    }

    // Pick the next transaction ID that is not already pending, giving up
    // after one full trip around the 16-bit ID space.
    let tid = {
        let start = self.next_tid;
        loop {
            let candidate = self.next_tid;
            self.next_tid = self.next_tid.wrapping_add(1);
            if !self.pending_transactions.contains_key(&candidate) {
                break candidate;
            }
            if self.next_tid == start {
                let _ = req.chan.send(Err(io::Error::new(
                    io::ErrorKind::WouldBlock,
                    "ATP: all transaction IDs in use",
                )));
                return;
            }
        }
    };

    let packet = AtpPacket {
        function: AtpFunction::Request,
        xo: true,
        eom: false,
        sts: false,
        // Ask for all eight response packets.
        bitmap_seq_num: 0xff,
        tid,
        user_bytes: req.user_bytes,
    };
    let mut buf = [0u8; 600];
    let header_len = packet
        .to_bytes(&mut buf)
        .expect("failed to serialize ATP header");

    // Assemble header + data once; the same bytes are sent now and kept for
    // retransmission.
    let mut raw_packet = Vec::with_capacity(header_len + req.data.len());
    raw_packet.extend_from_slice(&buf[..header_len]);
    raw_packet.extend_from_slice(&req.data);

    let dest = crate::ddp::DdpAddress::new(
        tailtalk_packets::aarp::AppleTalkAddress {
            network_number: req.address.network_number,
            node_number: req.address.node_number,
        },
        req.address.socket_number,
    );
    if let Err(e) = self.sock.send_to(&raw_packet[..], dest).await {
        let _ = req.chan.send(Err(io::Error::other(e)));
        return;
    }

    self.pending_transactions.insert(
        tid,
        PendingRequestState {
            chan: req.chan,
            xo: true,
            received_packets: std::collections::BTreeMap::new(),
            user_bytes: None,
            eom_seq: None,
            raw_packet,
            destination: req.address,
            retry_count: 0,
        },
    );
}
async fn handle_send_response(&mut self, resp: AtpSendResponse) {
for (i, node) in resp.packets.iter().enumerate() {
let packet = AtpPacket {
function: AtpFunction::Response,
xo: false, eom: i == resp.packets.len() - 1, sts: false,
bitmap_seq_num: i as u8,
tid: resp.tid,
user_bytes: node.user_bytes,
};
let mut buf = [0u8; 600];
let header_len = packet
.to_bytes(&mut buf)
.expect("failed to serialize ATP response header");
let total_len = header_len + node.data.len();
if total_len > buf.len() {
tracing::error!("Response chunk too large: {}", node.data.len());
continue;
}
buf[header_len..total_len].copy_from_slice(&node.data);
let dest = crate::ddp::DdpAddress::new(
tailtalk_packets::aarp::AppleTalkAddress {
network_number: resp.destination.network_number,
node_number: resp.destination.node_number,
},
resp.destination.socket_number,
);
if let Err(e) = self.sock.send_to(&buf[..total_len], dest).await {
tracing::error!("Failed to send ATP response packet {}: {}", i, e);
}
}
}
/// Sends a data-less request with the XO flag cleared and does not track a
/// response for it.
///
/// The transaction ID is chosen so that it does not collide with any pending
/// tracked request; otherwise a reply to this packet would be attributed to
/// that other transaction. If every ID is in use the packet is not sent.
async fn handle_send_alo(&mut self, alo: AtpSendAlo) {
    let tid = {
        let start = self.next_tid;
        loop {
            let candidate = self.next_tid;
            self.next_tid = self.next_tid.wrapping_add(1);
            if !self.pending_transactions.contains_key(&candidate) {
                break candidate;
            }
            if self.next_tid == start {
                tracing::warn!("ATP: all transaction IDs in use, dropping ALO packet");
                return;
            }
        }
    };
    let packet = AtpPacket {
        function: AtpFunction::Request,
        xo: false,
        eom: false,
        sts: false,
        bitmap_seq_num: 0xff,
        tid,
        user_bytes: alo.user_bytes,
    };
    let mut buf = [0u8; 600];
    let header_len = packet
        .to_bytes(&mut buf)
        .expect("failed to serialize ATP ALO header");
    let dest = crate::ddp::DdpAddress::new(
        tailtalk_packets::aarp::AppleTalkAddress {
            network_number: alo.address.network_number,
            node_number: alo.address.node_number,
        },
        alo.address.socket_number,
    );
    // Only the header is sent; this packet carries no data.
    if let Err(e) = self.sock.send_to(&buf[..header_len], dest).await {
        tracing::warn!("Failed to send ATP ALO packet: {}", e);
    }
}
async fn handle_send_release(&mut self, rel: AtpSendRelease) {
let packet = AtpPacket {
function: AtpFunction::Release,
xo: false,
eom: false,
sts: false,
bitmap_seq_num: 0,
tid: rel.tid,
user_bytes: [0; 4],
};
tracing::debug!(
"ATP Sending Release to {:?} tid={}",
rel.destination,
rel.tid
);
let mut buf = [0u8; 600];
let header_len = packet
.to_bytes(&mut buf)
.expect("failed to serialize ATP release header");
let dest = crate::ddp::DdpAddress::new(
tailtalk_packets::aarp::AppleTalkAddress {
network_number: rel.destination.network_number,
node_number: rel.destination.node_number,
},
rel.destination.socket_number,
);
if let Err(e) = self.sock.send_to(&buf[..header_len], dest).await {
tracing::error!("Failed to send ATP Release: {}", e);
}
}
/// Parses one received payload as an ATP packet and dispatches on its
/// function code.
///
/// * Request: queued to the responder; for XO requests a release notifier is
///   registered under (source, transaction ID).
/// * Response: stored under its sequence number in the matching pending
///   transaction; once every packet up to the EOM packet (or all eight) is
///   present, the data is reassembled and delivered, and a Release is sent
///   for XO transactions.
/// * Release: signals the release notifier registered for the request.
///
/// Packets that fail to parse are logged and dropped.
async fn handle_packet(&mut self, ddp: DdpPacket, payload: &mut [u8]) {
    // A response consists of at most eight packets, numbered 0 through 7.
    const MAX_RESPONSE_PACKETS: u8 = 8;

    let packet = match AtpPacket::parse(payload) {
        Ok(p) => p,
        Err(e) => {
            tracing::warn!("Failed to parse ATP packet: {}", e);
            return;
        }
    };
    let from = AtpAddress {
        network_number: ddp.src_network_num,
        node_number: ddp.src_node_id,
        socket_number: ddp.src_sock_num,
    };
    match packet.function {
        AtpFunction::Request => {
            let request_data = if payload.len() > AtpPacket::HEADER_LEN {
                payload[AtpPacket::HEADER_LEN..].to_vec()
            } else {
                Vec::new()
            };
            // NOTE(review): a retransmitted XO request replaces the earlier
            // notifier here and is delivered to the application again.
            let release_rx = if packet.xo {
                let (tx, rx) = oneshot::channel();
                self.pending_releases.insert((from, packet.tid), tx);
                Some(rx)
            } else {
                None
            };
            let req = AtpReceivedRequest {
                transaction_id: packet.tid,
                source: from,
                user_bytes: packet.user_bytes,
                data: request_data,
                response_sender: self.cmd_tx.clone(),
                release_rx,
                bitmap: packet.bitmap_seq_num,
            };
            if let Err(e) = self.incoming_req_tx.try_send(req) {
                tracing::warn!("Dropping incoming ATP request (queue full): {}", e);
                // Nobody will ever wait on the receiver of a dropped request,
                // so do not leave its notifier behind in the map.
                if packet.xo {
                    self.pending_releases.remove(&(from, packet.tid));
                }
            }
        }
        AtpFunction::Response => {
            // Responses for unknown transaction IDs are ignored.
            let Some(state) = self.pending_transactions.get_mut(&packet.tid) else {
                return;
            };
            if payload.len() < AtpPacket::HEADER_LEN {
                tracing::warn!("ATP Response payload too short");
                return;
            }
            let seq = packet.bitmap_seq_num;
            if seq >= MAX_RESPONSE_PACKETS {
                // Out-of-range sequence numbers would let the "all packets
                // received" checks pass with gaps and overflow `eom + 1`.
                tracing::warn!(
                    "ATP Response for TID {} has invalid sequence number {}, dropping",
                    packet.tid,
                    seq
                );
                return;
            }

            state
                .received_packets
                .insert(seq, payload[AtpPacket::HEADER_LEN..].to_vec());
            if state.user_bytes.is_none() {
                state.user_bytes = Some(packet.user_bytes);
            }
            if packet.eom {
                state.eom_seq = Some(seq);
            }

            // Without an EOM packet the response is complete only when all
            // eight packets have arrived.
            let expected_count = state
                .eom_seq
                .map_or(MAX_RESPONSE_PACKETS, |eom| eom + 1);
            let is_complete =
                (0..expected_count).all(|i| state.received_packets.contains_key(&i));
            if !is_complete {
                return;
            }

            let Some(mut finished) = self.pending_transactions.remove(&packet.tid) else {
                return;
            };
            let mut full_data = Vec::new();
            for i in 0..expected_count {
                if let Some(part) = finished.received_packets.remove(&i) {
                    full_data.extend_from_slice(&part);
                }
            }
            let user_bytes = finished.user_bytes.unwrap_or([0; 4]);
            let send_release = finished.xo;
            // The requester may have stopped waiting; that is not an error.
            let _ = finished.chan.send(Ok((full_data, user_bytes)));
            if send_release {
                let rel = AtpSendRelease {
                    destination: from,
                    tid: packet.tid,
                };
                self.handle_send_release(rel).await;
            }
        }
        AtpFunction::Release => {
            tracing::debug!(
                "Received ATP Release packet from {:?} tid={}",
                from,
                packet.tid
            );
            if let Some(chan) = self.pending_releases.remove(&(from, packet.tid)) {
                let _ = chan.send(());
            }
        }
    }
}
}