use std::io::{Read, Write};
use std::time::Instant;
use prost::Message;
use crate::broker::protocol::{
read_frame, validate_frame_envelope, write_frame, Frame, FrameKind, FrameValidationError,
HandoffAck, HandoffOffer, PayloadEncoding, PROTOCOL_VERSION,
};
use crate::broker::server::handoff::handoff_token::HandoffToken;
use crate::broker::server::handoff::orchestrate::{HandoffDelivery, HandoffDeliveryError};
use crate::broker::server::handoff::windows::WindowsHandleValue;
pub use crate::broker::protocol::registry::HANDOFF_PAYLOAD_PROTOCOL;
pub fn handoff_offer_frame(offer: &HandoffOffer) -> Frame {
let mut payload = Vec::with_capacity(64);
offer.encode(&mut payload).expect(
"prost encoding HandoffOffer into Vec cannot fail because Vec writes are infallible",
);
Frame {
envelope_version: PROTOCOL_VERSION,
kind: FrameKind::Request as i32,
payload_protocol: HANDOFF_PAYLOAD_PROTOCOL,
payload,
request_id: offer.correlation_id,
payload_encoding: PayloadEncoding::None as i32,
deadline_unix_ms: 0,
traceparent: String::new(),
tracestate: String::new(),
}
}
pub fn handoff_ack_frame(ack: &HandoffAck) -> Frame {
let mut payload = Vec::with_capacity(64);
ack.encode(&mut payload)
.expect("prost encoding HandoffAck into Vec cannot fail because Vec writes are infallible");
Frame {
envelope_version: PROTOCOL_VERSION,
kind: FrameKind::Response as i32,
payload_protocol: HANDOFF_PAYLOAD_PROTOCOL,
payload,
request_id: ack.correlation_id,
payload_encoding: PayloadEncoding::None as i32,
deadline_unix_ms: 0,
traceparent: String::new(),
tracestate: String::new(),
}
}
pub fn handoff_ready_frame(ack: &HandoffAck) -> Frame {
let mut payload = Vec::with_capacity(64);
ack.encode(&mut payload)
.expect("prost encoding HandoffAck into Vec cannot fail because Vec writes are infallible");
Frame {
envelope_version: PROTOCOL_VERSION,
kind: FrameKind::Event as i32,
payload_protocol: HANDOFF_PAYLOAD_PROTOCOL,
payload,
request_id: ack.correlation_id,
payload_encoding: PayloadEncoding::None as i32,
deadline_unix_ms: 0,
traceparent: String::new(),
tracestate: String::new(),
}
}
/// Checks that `frame` is a handoff frame of `expected_kind`.
///
/// The envelope checks themselves are performed by [`validate_frame_envelope`] against
/// [`HANDOFF_PAYLOAD_PROTOCOL`]. This function only translates a failure into a short static
/// description suitable for embedding in an error detail.
///
/// Note that a kind mismatch is reported as `"kind is not RESPONSE"` for every
/// `expected_kind` other than `Request` and `Event`.
///
/// # Errors
///
/// Returns a static description of the [`FrameValidationError`] reported by
/// [`validate_frame_envelope`].
pub fn validate_handoff_frame(frame: &Frame, expected_kind: FrameKind) -> Result<(), &'static str> {
    /// Describes a kind mismatch in terms of the kind that was expected.
    fn kind_mismatch(expected: FrameKind) -> &'static str {
        match expected {
            FrameKind::Request => "kind is not REQUEST",
            FrameKind::Event => "kind is not EVENT",
            _ => "kind is not RESPONSE",
        }
    }

    match validate_frame_envelope(frame, expected_kind, HANDOFF_PAYLOAD_PROTOCOL) {
        Ok(()) => Ok(()),
        Err(FrameValidationError::EnvelopeVersion { .. }) => Err("envelope_version is not v1"),
        Err(FrameValidationError::Kind { .. }) => Err(kind_mismatch(expected_kind)),
        Err(FrameValidationError::PayloadProtocol { .. }) => Err("payload_protocol is not handoff"),
        Err(FrameValidationError::PayloadEncoding { .. }) => Err("payload is compressed"),
    }
}
/// A handoff delivery that exchanges handoff frames with a backend over the byte stream `S`.
///
/// It carries the service name and correlation id that go into the outgoing offer and are
/// checked against the incoming acknowledgement.
#[derive(Debug)]
pub struct WireHandoffDelivery<S> {
    /// Transport the handoff frames are written to and read from.
    stream: S,
    /// Service name placed in the outgoing offer.
    service_name: String,
    /// Correlation id placed in the offer and expected back in the acknowledgement.
    correlation_id: u64,
}

impl<S> WireHandoffDelivery<S> {
    /// Creates a delivery over `stream` for `service_name`, tagged with `correlation_id`.
    pub fn new(stream: S, service_name: impl Into<String>, correlation_id: u64) -> Self {
        let service_name: String = service_name.into();
        Self {
            correlation_id,
            service_name,
            stream,
        }
    }

    /// Returns the correlation id this delivery was created with.
    pub fn correlation_id(&self) -> u64 {
        let Self { correlation_id, .. } = self;
        *correlation_id
    }

    /// Borrows the underlying stream.
    pub fn stream(&self) -> &S {
        let Self { stream, .. } = self;
        stream
    }

    /// Consumes the delivery and hands back the underlying stream.
    pub fn into_stream(self) -> S {
        let Self { stream, .. } = self;
        stream
    }
}
impl<S: Read + Write> HandoffDelivery for WireHandoffDelivery<S> {
fn deliver(
&mut self,
handle: WindowsHandleValue,
token: &HandoffToken,
) -> Result<(), HandoffDeliveryError> {
let offer = HandoffOffer {
handle_value: handle.get() as u64,
token: token.as_bytes().to_vec(),
service_name: self.service_name.clone(),
correlation_id: self.correlation_id,
};
let frame = handoff_offer_frame(&offer);
let mut bytes = Vec::with_capacity(64);
frame
.encode(&mut bytes)
.expect("prost encoding Frame into Vec cannot fail because Vec writes are infallible");
write_frame(&mut self.stream, &bytes).map_err(|error| {
HandoffDeliveryError::DeliveryFailed {
detail: format!("failed to write HandoffOffer frame: {error}"),
}
})?;
Ok(())
}
fn await_backend_ack(
&mut self,
token: &HandoffToken,
deadline: Instant,
) -> Result<Instant, HandoffDeliveryError> {
let bytes = read_frame(&mut self.stream).map_err(|error| {
ack_not_observed(format!("failed to read HandoffAck frame: {error}"))
})?;
let observed_at = Instant::now();
let frame = Frame::decode(bytes.as_slice()).map_err(|error| {
ack_not_observed(format!("failed to decode HandoffAck Frame: {error}"))
})?;
validate_handoff_frame(&frame, FrameKind::Response)
.map_err(|detail| ack_not_observed(format!("unexpected HandoffAck frame: {detail}")))?;
if frame.request_id != self.correlation_id {
return Err(ack_not_observed(format!(
"HandoffAck frame request_id {} does not match correlation id {}",
frame.request_id, self.correlation_id
)));
}
let ack = HandoffAck::decode(frame.payload.as_slice()).map_err(|error| {
ack_not_observed(format!("failed to decode HandoffAck payload: {error}"))
})?;
if ack.correlation_id != self.correlation_id {
return Err(ack_not_observed(format!(
"HandoffAck correlation id {} does not match offer correlation id {}",
ack.correlation_id, self.correlation_id
)));
}
if ack.token != token.as_bytes() {
return Err(ack_not_observed(
"HandoffAck token echo does not match the offered token".to_string(),
));
}
if !ack.accepted {
return Err(ack_not_observed(format!(
"backend refused the handoff: {}",
if ack.error_detail.is_empty() {
"no detail provided"
} else {
ack.error_detail.as_str()
}
)));
}
if observed_at > deadline {
return Err(ack_not_observed(
"backend HandoffAck arrived after the ACK deadline".to_string(),
));
}
Ok(observed_at)
}
}
/// Builds the [`HandoffDeliveryError::AckNotObserved`] error carrying `reason` as its detail.
fn ack_not_observed(reason: String) -> HandoffDeliveryError {
    HandoffDeliveryError::AckNotObserved { detail: reason }
}