use crate::{
common::{ControlType, EnergyManagementRole, Handshake, Message, ReceptionStatus, ReceptionStatusValues, ResourceManagerDetails},
transport::S2Transport,
};
use semver::VersionReq;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum ConnectionError<T: std::error::Error> {
#[error("an error occurred in the underlying transport: {0}")]
TransportError(#[source] T),
#[error("a situation occurred that is in violation of the S2 specification: {0}")]
ProtocolError(#[from] ProtocolError),
}
#[derive(Error, Debug)]
pub enum ProtocolError {
#[error("error parsing the requested S2 version into a valid semver version: {0}")]
S2VersionParseError(#[from] semver::Error),
#[error("the CEM requested an incompatible S2 version: {requested:?} was requested, {supported:?} is supported")]
IncompatibleS2Version {
supported: semver::Version,
requested: VersionReq,
},
#[error(
"received an unexpected message, or an expected message at an unexpected moment, during the S2 handshaking process: {message:?} ({} Handshake, {} HandshakeResponse)",
if *handshake_received { "already received" } else { "not yet received" },
if *handshake_response_received { "already received" } else { "not yet received" },
)]
InvalidHandshakeOrder {
message: Message,
handshake_received: bool,
handshake_response_received: bool,
},
}
/// An S2 protocol connection layered on top of a transport implementing [`S2Transport`].
#[derive(Clone, Debug)]
pub struct S2Connection<T>
where
    T: S2Transport,
{
    /// Transport used to send and receive S2 messages.
    transport: T,
}
impl<T: S2Transport> S2Connection<T> {
pub fn new(transport: T) -> S2Connection<T> {
Self { transport }
}
pub async fn initialize_as_rm(
&mut self,
rm_details: ResourceManagerDetails,
) -> Result<ControlType, ConnectionError<T::TransportError>> {
let handshake = Handshake::builder()
.role(EnergyManagementRole::Rm)
.supported_protocol_versions(vec![crate::s2_schema_version().to_string()])
.build();
self.send_message(handshake).await?;
let mut need_handshake = true;
let mut need_handshake_response = true;
loop {
let message = self.receive_message().await?;
match message.get_message() {
Message::Handshake(..) if need_handshake => {
need_handshake = false;
}
Message::HandshakeResponse(handshake_response) if need_handshake_response && !need_handshake => {
need_handshake_response = false;
let requested_version =
VersionReq::parse(&handshake_response.selected_protocol_version).map_err(ProtocolError::S2VersionParseError)?;
if !requested_version.matches(&crate::s2_schema_version()) {
let error_msg = format!(
"CEM requested an incompatible version of S2: requested {}, which is not compatible with {}",
requested_version,
crate::s2_schema_version()
);
tracing::warn!("{error_msg:?}");
message.error(ReceptionStatusValues::InvalidContent, &error_msg).await?;
return Err(ProtocolError::IncompatibleS2Version {
supported: crate::s2_schema_version(),
requested: requested_version.clone(),
}
.into());
}
message.confirm().await?;
self.send_message(rm_details.clone()).await?;
continue;
}
Message::SelectControlType(select_control_type) if !need_handshake && !need_handshake_response => {
tracing::info!("Control type selected by CEM: {:?}", select_control_type.control_type);
let control_type = select_control_type.control_type;
message.confirm().await?;
return Ok(control_type);
}
other_message => {
let diagnostic = format!("Did not expect message at this point in the handshake process: {:?}", other_message);
let message = message.error(ReceptionStatusValues::InvalidContent, &diagnostic).await?;
return Err(ProtocolError::InvalidHandshakeOrder {
message,
handshake_received: !need_handshake,
handshake_response_received: !need_handshake_response,
}
.into());
}
}
message.confirm().await?;
}
}
pub async fn send_message(&mut self, message: impl Into<Message>) -> Result<(), ConnectionError<T::TransportError>> {
self.transport.send(message.into()).await.map_err(ConnectionError::TransportError)?;
Ok(())
}
pub async fn receive_message<'connection>(
&'connection mut self,
) -> Result<UnconfirmedMessage<'connection, T>, ConnectionError<T::TransportError>> {
let message = self.transport.receive().await.map_err(ConnectionError::TransportError)?;
tracing::trace!("Received S2 message: {message:?}");
Ok(UnconfirmedMessage::new(message, self))
}
pub async fn receive_and_confirm(&mut self) -> Result<Message, ConnectionError<T::TransportError>> {
self.receive_message().await?.confirm().await
}
pub async fn disconnect(self) {
self.transport.disconnect().await
}
}
/// A received S2 message that has not yet been answered with a `ReceptionStatus`.
///
/// Obtained from [`S2Connection::receive_message`]. It holds a mutable borrow of the connection so
/// that a reply can be sent for the message. Consume it in exactly one of these ways:
///
/// - [`confirm`](Self::confirm): sends a `ReceptionStatus` with status `Ok` and returns the message;
/// - [`error`](Self::error): sends a `ReceptionStatus` with the given status and diagnostic label,
///   and returns the message;
/// - [`into_inner`](Self::into_inner): returns the message without sending anything.
///
/// `confirm` and `error` send nothing when the message is itself a `ReceptionStatus` or has no
/// message id. Use [`get_message`](Self::get_message) to inspect the message before deciding.
///
/// # Panics
///
/// Dropping an `UnconfirmedMessage` without consuming it panics, unless the thread is already
/// panicking.
#[must_use = "an `UnconfirmedMessage` must be consumed with `confirm`, `error` or `into_inner`; dropping it panics"]
pub struct UnconfirmedMessage<'conn, T: S2Transport> {
    /// The received message; `None` once one of the consuming methods has taken it.
    message: Option<Message>,
    /// Connection used to send the `ReceptionStatus` reply.
    connection: &'conn mut S2Connection<T>,
}
impl<'conn, T: S2Transport> UnconfirmedMessage<'conn, T> {
fn new(message: Message, connection: &'conn mut S2Connection<T>) -> UnconfirmedMessage<'conn, T> {
Self {
message: Some(message),
connection,
}
}
pub async fn confirm(mut self) -> Result<Message, ConnectionError<T::TransportError>> {
let message = self
.message
.take()
.expect("No message contained in UnconfirmedMessage; this is a bug in s2energy and should be reported");
if matches!(message, Message::ReceptionStatus(..)) {
return Ok(message);
}
let Some(message_id) = message.id() else { return Ok(message) };
self.connection
.send_message(
ReceptionStatus::builder()
.status(ReceptionStatusValues::Ok)
.subject_message_id(message_id)
.build(),
)
.await?;
Ok(message)
}
pub async fn error(
mut self,
status: ReceptionStatusValues,
diagnostic_message: &str,
) -> Result<Message, ConnectionError<T::TransportError>> {
let message = self
.message
.take()
.expect("No message contained in UnconfirmedMessage; this is a bug in s2energy and should be reported");
if matches!(message, Message::ReceptionStatus(..)) {
return Ok(message);
}
let Some(message_id) = message.id() else { return Ok(message) };
tracing::warn!("Sending reception status {status:?} in response to message {message_id:?}");
self.connection
.send_message(
ReceptionStatus::builder()
.diagnostic_label(diagnostic_message.to_string())
.status(status)
.subject_message_id(message_id)
.build(),
)
.await?;
Ok(message)
}
pub fn get_message(&self) -> &Message {
self.message
.as_ref()
.expect("No message contained in UnconfirmedMessage; this is a bug in s2energy and should be reported")
}
pub fn into_inner(mut self) -> Message {
self.message
.take()
.expect("No message contained in UnconfirmedMessage; this is a bug in s2energy and should be reported")
}
}
/// Enforces the `UnconfirmedMessage` usage contract: a message that was neither confirmed,
/// answered with `error`, nor unwrapped with `into_inner` is a programming error.
impl<'conn, T: S2Transport> Drop for UnconfirmedMessage<'conn, T> {
    fn drop(&mut self) {
        // Skip the check while already unwinding: panicking during a panic aborts the process.
        if !std::thread::panicking() && self.message.is_some() {
            // The message names the real consuming methods (`error`, not the non-existent `bad_status`).
            panic!(
                "Dropped an `UnconfirmedMessage` without calling `confirm`, `error` or `into_inner`. Please refer to the `UnconfirmedMessage` documentation for proper usage."
            );
        }
    }
}