use bytes::{Bytes, BytesMut};
use snap7_client::proto::{
cotp::CotpPdu,
s7::{
header::{PduType, S7Header},
negotiate::{NegotiateRequest, NegotiateResponse},
},
tpkt::TpktFrame,
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::error::{Error, Result};
/// Largest S7 PDU length this server agrees to; the client's requested
/// `pdu_length` is clamped down to this value during negotiation.
const MAX_PDU_SIZE: u16 = 480;
/// Value written into the `param_len` field of the negotiate reply header.
const NEGOTIATE_PARAM_LEN: u16 = 8;
/// Run the server side of the connection handshake on `transport`.
///
/// Sequence:
/// 1. Receive a COTP Connection Request and answer with a Connection Confirm
///    whose destination reference echoes the client's source reference.
/// 2. Receive the S7 negotiate job and answer with an `AckData` reply carrying
///    the negotiated PDU length: the client's request clamped to
///    [`MAX_PDU_SIZE`]. The `max_amq_*` values are echoed back unchanged.
///
/// Returns the negotiated PDU length on success.
///
/// # Errors
///
/// Returns [`Error::NegotiationFailed`] in three cases:
/// - the first TPDU is not a Connection Request;
/// - the S7 header of the negotiate message is not a `Job`;
/// - the negotiated PDU length is too small to carry even the negotiate reply.
///
/// Transport and decoding errors are propagated.
pub async fn server_handshake<T>(mut transport: T) -> Result<u16>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    // --- COTP connection setup ---
    let src_ref = match recv_tpkt_cotp(&mut transport).await? {
        CotpPdu::ConnectRequest { src_ref, .. } => src_ref,
        _ => return Err(Error::NegotiationFailed),
    };
    let cc = CotpPdu::ConnectConfirm {
        dst_ref: src_ref,
        src_ref: 0x0001,
    };
    send_tpkt_cotp(&mut transport, &cc).await?;

    // --- S7 PDU negotiation ---
    let mut payload = recv_cotp_data(&mut transport).await?;
    let req_header = S7Header::decode(&mut payload)?;
    if req_header.pdu_type != PduType::Job {
        return Err(Error::NegotiationFailed);
    }
    let neg_req = NegotiateRequest::decode(&mut payload)?;
    let negotiated = neg_req.pdu_length.min(MAX_PDU_SIZE);

    let resp_header = S7Header {
        pdu_type: PduType::AckData,
        reserved: 0,
        pdu_ref: req_header.pdu_ref,
        param_len: NEGOTIATE_PARAM_LEN,
        data_len: 0,
        error_class: Some(0),
        error_code: Some(0),
    };
    let neg_resp = NegotiateResponse {
        max_amq_calling: neg_req.max_amq_calling,
        max_amq_called: neg_req.max_amq_called,
        pdu_length: negotiated,
    };
    let mut s7_buf = BytesMut::new();
    resp_header.encode(&mut s7_buf);
    neg_resp.encode(&mut s7_buf);

    // A PDU length that cannot even hold this reply (e.g. a client requesting 0)
    // would hand the caller an unusable session size; refuse it instead of
    // acknowledging it.
    if usize::from(negotiated) < s7_buf.len() {
        return Err(Error::NegotiationFailed);
    }

    send_cotp_data(&mut transport, s7_buf.freeze()).await?;
    Ok(negotiated)
}
pub(crate) async fn recv_tpkt_cotp<T: AsyncRead + Unpin>(transport: &mut T) -> Result<CotpPdu> {
let mut header = [0u8; 4];
transport.read_exact(&mut header).await?;
if header[0] != 0x03 {
return Err(Error::NegotiationFailed);
}
let total = u16::from_be_bytes([header[2], header[3]]) as usize;
if total < 4 {
return Err(Error::NegotiationFailed);
}
let payload_len = total - 4;
let mut payload = vec![0u8; payload_len];
transport.read_exact(&mut payload).await?;
let mut b = Bytes::from(payload);
CotpPdu::decode(&mut b).map_err(Error::Proto)
}
pub(crate) async fn recv_cotp_data<T: AsyncRead + Unpin>(transport: &mut T) -> Result<Bytes> {
let pdu = recv_tpkt_cotp(transport).await?;
match pdu {
CotpPdu::Data { payload, .. } => Ok(payload),
_ => Err(Error::NegotiationFailed),
}
}
/// Encode `pdu` as COTP, wrap it in a TPKT frame, write it to `transport`,
/// and flush.
///
/// The flush matters when `transport` buffers writes (for example a
/// `tokio::io::BufWriter`). Without it, a handshake reply can stay in the
/// buffer while the peer waits for it, stalling the exchange.
///
/// # Errors
///
/// Propagates the error from `TpktFrame::encode` and any I/O error from the
/// write or flush.
pub(crate) async fn send_tpkt_cotp<T: AsyncWrite + Unpin>(
    transport: &mut T,
    pdu: &CotpPdu,
) -> Result<()> {
    let mut cotp_buf = BytesMut::new();
    pdu.encode(&mut cotp_buf);
    let tpkt = TpktFrame {
        payload: cotp_buf.freeze(),
    };
    let mut buf = BytesMut::new();
    tpkt.encode(&mut buf)?;
    transport.write_all(&buf).await?;
    transport.flush().await?;
    Ok(())
}
/// Send `payload` as a single COTP DT TPDU (`tpdu_nr` 0, marked `last`)
/// inside a TPKT frame.
///
/// # Errors
///
/// Same as [`send_tpkt_cotp`].
pub(crate) async fn send_cotp_data<T: AsyncWrite + Unpin>(
    transport: &mut T,
    payload: Bytes,
) -> Result<()> {
    send_tpkt_cotp(
        transport,
        &CotpPdu::Data {
            payload,
            last: true,
            tpdu_nr: 0,
        },
    )
    .await
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use snap7_client::proto::{
        cotp::CotpPdu,
        s7::{
            header::{PduType, S7Header},
            negotiate::NegotiateRequest,
        },
        tpkt::TpktFrame,
    };
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Wrap `cotp` in a TPKT frame and write it to `writer`.
    async fn write_tpkt_cotp(writer: &mut (impl tokio::io::AsyncWrite + Unpin), cotp: &CotpPdu) {
        let mut cotp_buf = BytesMut::new();
        cotp.encode(&mut cotp_buf);
        let tpkt = TpktFrame {
            payload: cotp_buf.freeze(),
        };
        let mut buf = BytesMut::new();
        tpkt.encode(&mut buf).unwrap();
        writer.write_all(&buf).await.unwrap();
    }

    /// Read one full TPKT frame from `reader` and decode its body as COTP.
    async fn read_tpkt_cotp(reader: &mut (impl tokio::io::AsyncRead + Unpin)) -> CotpPdu {
        let mut hdr = [0u8; 4];
        reader.read_exact(&mut hdr).await.unwrap();
        let total = u16::from_be_bytes([hdr[2], hdr[3]]) as usize;
        let mut body = vec![0u8; total - 4];
        reader.read_exact(&mut body).await.unwrap();
        CotpPdu::decode(&mut Bytes::from(body)).unwrap()
    }

    /// Send an S7 negotiate job (pdu_ref 1) requesting `pdu_length`.
    async fn write_negotiate_request(
        writer: &mut (impl tokio::io::AsyncWrite + Unpin),
        pdu_length: u16,
    ) {
        let header = S7Header {
            pdu_type: PduType::Job,
            reserved: 0,
            pdu_ref: 1,
            param_len: 8,
            data_len: 0,
            error_class: None,
            error_code: None,
        };
        let req = NegotiateRequest {
            max_amq_calling: 1,
            max_amq_called: 1,
            pdu_length,
        };
        let mut s7_buf = BytesMut::new();
        header.encode(&mut s7_buf);
        req.encode(&mut s7_buf);
        let dt = CotpPdu::Data {
            tpdu_nr: 0,
            last: true,
            payload: s7_buf.freeze(),
        };
        write_tpkt_cotp(writer, &dt).await;
    }

    /// Play the client side of the handshake, requesting `pdu_length`.
    ///
    /// Checks both server replies: the Connection Confirm, and the S7 AckData
    /// that must echo the request's pdu_ref.
    async fn run_client<S>(mut io: S, pdu_length: u16)
    where
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    {
        let cr = CotpPdu::ConnectRequest {
            dst_ref: 0x0000,
            src_ref: 0x0001,
            rack: 0,
            slot: 2,
        };
        write_tpkt_cotp(&mut io, &cr).await;

        let cc = read_tpkt_cotp(&mut io).await;
        assert!(
            matches!(cc, CotpPdu::ConnectConfirm { .. }),
            "expected ConnectConfirm, got {cc:?}"
        );

        write_negotiate_request(&mut io, pdu_length).await;

        let mut payload = match read_tpkt_cotp(&mut io).await {
            CotpPdu::Data { payload, .. } => payload,
            other => panic!("expected COTP Data carrying the negotiate reply, got {other:?}"),
        };
        let header = S7Header::decode(&mut payload).unwrap();
        assert!(header.pdu_type == PduType::AckData, "negotiate reply is not AckData");
        assert_eq!(header.pdu_ref, 1);
        assert_eq!(header.param_len, 8);
    }

    /// Run a full handshake where the client requests `requested`.
    ///
    /// Returns the PDU length the server reports. A server error is reported
    /// before the client task's result, so the root cause is not hidden by the
    /// client panicking on an early EOF.
    async fn negotiate_with(requested: u16) -> u16 {
        let (server_io, client_io) = tokio::io::duplex(4096);
        let client_task = tokio::spawn(run_client(client_io, requested));
        let result = server_handshake(server_io).await;
        let client_outcome = client_task.await;
        let negotiated = result.expect("server_handshake returned error");
        client_outcome.unwrap();
        negotiated
    }

    #[tokio::test]
    async fn handshake_completes_with_valid_client() {
        assert_eq!(negotiate_with(480).await, 480);
    }

    #[tokio::test]
    async fn handshake_clamps_pdu_length_to_server_max() {
        assert_eq!(negotiate_with(960).await, MAX_PDU_SIZE);
    }

    #[tokio::test]
    async fn handshake_keeps_smaller_client_pdu_length() {
        assert_eq!(negotiate_with(240).await, 240);
    }

    #[tokio::test]
    async fn handshake_fails_on_non_cr() {
        let (server_io, mut client_io) = tokio::io::duplex(4096);
        tokio::spawn(async move {
            let dt = CotpPdu::Data {
                tpdu_nr: 0,
                last: true,
                payload: Bytes::from_static(b"oops"),
            };
            write_tpkt_cotp(&mut client_io, &dt).await;
        });
        let result = server_handshake(server_io).await;
        assert!(result.is_err(), "expected error, got: {result:?}");
    }
}