use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::{debug, info, instrument, warn};
use crate::binary_protocol::{ClientMessage, PayloadType};
use crate::errors::{Error, ProtocolError, Result};
use crate::protocol::MessageType;
/// Kind of client action the agent may request during the handshake.
///
/// Serialized by variant name, except `KmsEncryption`, which uses the wire
/// name `KMSEncryption`. There is no catch-all variant, so any other string is
/// a deserialization error.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ActionType {
    /// Negotiate KMS-based encryption for the session.
    #[serde(rename = "KMSEncryption")]
    KmsEncryption,
    /// Select the session type (wire name identical to the variant name).
    SessionType,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum ActionStatus {
Success = 1,
Failed = 2,
Unsupported = 3,
}
impl serde::Serialize for ActionStatus {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_u32(*self as u32)
}
}
impl<'de> serde::Deserialize<'de> for ActionStatus {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = u32::deserialize(deserializer)?;
match value {
1 => Ok(ActionStatus::Success),
2 => Ok(ActionStatus::Failed),
3 => Ok(ActionStatus::Unsupported),
_ => Err(serde::de::Error::custom(format!(
"invalid ActionStatus: {}",
value
))),
}
}
}
/// Session types that a `SessionType` action can select.
///
/// Wire names are `Standard_Stream`, `InteractiveCommands` and `Port`; the
/// last two are simply the variant names. Defaults to `StandardStream`.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SessionTypeValue {
    /// Interactive shell-style stream (wire name `Standard_Stream`).
    #[default]
    #[serde(rename = "Standard_Stream")]
    StandardStream,
    /// Interactive commands session.
    InteractiveCommands,
    /// Port-forwarding session.
    Port,
}
/// `ActionParameters` of a `KMSEncryption` action, as sent by the agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KmsEncryptionRequest {
    /// KMS key identifier supplied by the agent (wire name `KMSKeyId`, which
    /// `rename_all` cannot produce, hence the explicit rename).
    #[serde(rename = "KMSKeyId")]
    pub kms_key_id: String,
}
/// Result payload for a successful `KMSEncryption` action.
///
/// NOTE(review): serde encodes `Vec<u8>` as a JSON array of numbers. If the
/// peer encodes byte fields the way Go's `encoding/json` does (base64
/// strings), these fields will not interoperate; confirm before relying on
/// this type (it is not used elsewhere in this module yet).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KmsEncryptionResponse {
    /// Encrypted data key (wire name `KMSCipherTextKey`).
    #[serde(rename = "KMSCipherTextKey")]
    pub kms_cipher_text_key: Vec<u8>,
    /// Optional hash of the encrypted key (wire name `KMSCipherTextHash`),
    /// omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "KMSCipherTextHash")]
    pub kms_cipher_text_hash: Option<Vec<u8>>,
}
/// `ActionParameters` of a `SessionType` action, as sent by the agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct SessionTypeRequest {
    /// Requested session type as a raw wire string (`SessionType`).
    pub session_type: String,
    /// Free-form session properties (`Properties`); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub properties: Option<serde_json::Value>,
}
/// One action the agent asks the client to perform during the handshake.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct RequestedClientAction {
    /// Which action is requested (`ActionType`).
    pub action_type: ActionType,
    /// Action-specific parameters (`ActionParameters`), decoded later according
    /// to `action_type`.
    pub action_parameters: serde_json::Value,
}
/// Handshake request payload sent by the agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct HandshakeRequest {
    /// Version string of the agent (`AgentVersion`).
    pub agent_version: String,
    /// Actions the agent wants the client to process (`RequestedClientActions`).
    pub requested_client_actions: Vec<RequestedClientAction>,
}
/// Outcome of one requested client action, reported in the handshake response.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct ProcessedClientAction {
    /// The action this outcome refers to (`ActionType`).
    pub action_type: ActionType,
    /// Numeric outcome code (`ActionStatus`).
    pub action_status: ActionStatus,
    /// Optional action-specific result (`ActionResult`); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub action_result: Option<serde_json::Value>,
    /// Human-readable error text (`Error`); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
/// Handshake response payload sent back to the agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct HandshakeResponse {
    /// Version string of this client (`ClientVersion`).
    pub client_version: String,
    /// One entry per requested action (`ProcessedClientActions`).
    pub processed_client_actions: Vec<ProcessedClientAction>,
    /// Top-level error messages (`Errors`): omitted from the JSON when empty,
    /// and treated as empty when absent on input.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub errors: Vec<String>,
}
/// Final handshake message sent by the agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct HandshakeComplete {
    /// Time the handshake took (`HandshakeTimeToComplete`); the handler in this
    /// module interprets it as nanoseconds.
    pub handshake_time_to_complete: i64,
    /// Optional message for the user (`CustomerMessage`); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub customer_message: Option<String>,
}
/// Encryption challenge payload (`Challenge`).
///
/// NOTE(review): `Vec<u8>` is serialized by serde as an array of numbers; if
/// the peer sends byte fields as base64 strings (Go `encoding/json` style),
/// deserialization will fail. Confirm the wire encoding.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct EncryptionChallengeRequest {
    /// Raw challenge bytes.
    pub challenge: Vec<u8>,
}
/// Reply to an encryption challenge (`Challenge`).
///
/// NOTE(review): same byte-encoding caveat as the request type: serde emits
/// `Vec<u8>` as a number array, not a base64 string. Confirm the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct EncryptionChallengeResponse {
    /// Raw challenge bytes echoed back.
    pub challenge: Vec<u8>,
}
/// Position of the handshake handler in the handshake sequence.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HandshakeState {
    /// Waiting for the agent's handshake request.
    AwaitingRequest,
    /// A handshake request is being processed.
    Processing,
    /// The response has been produced; waiting for the completion message.
    AwaitingComplete,
    /// The completion message was accepted.
    Completed,
    /// Failure state; no code path in this module currently enters it.
    Failed,
}
/// Client-side settings for the handshake.
#[derive(Debug, Clone)]
pub struct HandshakeConfig {
    /// Version string reported to the agent as `ClientVersion`.
    pub client_version: String,
    /// Whether to process `KMSEncryption` actions; when `false` they are
    /// answered with `ActionStatus::Unsupported`.
    pub support_kms: bool,
    /// Session types this client supports.
    pub supported_session_types: Vec<SessionTypeValue>,
    /// Handshake timeout.
    /// NOTE(review): not read anywhere in this module; presumably enforced by
    /// the caller. Confirm.
    pub timeout: Duration,
}
impl Default for HandshakeConfig {
    /// Crate version as the client version, KMS disabled, all three session
    /// types supported, and a 30-second timeout.
    fn default() -> Self {
        let supported_session_types = vec![
            SessionTypeValue::StandardStream,
            SessionTypeValue::InteractiveCommands,
            SessionTypeValue::Port,
        ];
        Self {
            client_version: String::from(env!("CARGO_PKG_VERSION")),
            support_kms: false,
            supported_session_types,
            timeout: Duration::from_secs(30),
        }
    }
}
/// State machine for the client side of the session handshake.
///
/// Derives `Debug` (all fields already implement it) so it can be logged or
/// embedded in other `Debug` types.
#[derive(Debug)]
pub struct HandshakeHandler {
    /// Settings supplied at construction.
    config: HandshakeConfig,
    /// Current position in the handshake sequence.
    state: HandshakeState,
    /// Session type accepted from the agent's `SessionType` action, if any.
    negotiated_session_type: Option<SessionTypeValue>,
    /// Agent version copied from the processed handshake request, if any.
    agent_version: Option<String>,
    /// Whether KMS encryption was negotiated.
    kms_enabled: bool,
}
impl HandshakeHandler {
pub fn new(config: HandshakeConfig) -> Self {
Self {
config,
state: HandshakeState::AwaitingRequest,
negotiated_session_type: None,
agent_version: None,
kms_enabled: false,
}
}
pub fn state(&self) -> HandshakeState {
self.state
}
pub fn session_type(&self) -> Option<&SessionTypeValue> {
self.negotiated_session_type.as_ref()
}
pub fn agent_version(&self) -> Option<&str> {
self.agent_version.as_deref()
}
pub fn is_kms_enabled(&self) -> bool {
self.kms_enabled
}
#[instrument(skip(self, request))]
pub fn process_request(
&mut self,
request: HandshakeRequest,
) -> Result<Option<HandshakeResponse>> {
if self.state == HandshakeState::AwaitingComplete || self.state == HandshakeState::Completed
{
debug!(state = ?self.state, "Ignoring duplicate HandshakeRequest");
return Ok(None);
}
if self.state != HandshakeState::AwaitingRequest {
return Err(Error::Protocol(ProtocolError::InvalidMessage(format!(
"Invalid state for handshake request: {:?}",
self.state
))));
}
self.state = HandshakeState::Processing;
self.agent_version = Some(request.agent_version.clone());
info!(
agent_version = %request.agent_version,
actions = request.requested_client_actions.len(),
"Processing handshake request"
);
let mut processed_actions = Vec::new();
let mut errors = Vec::new();
for action in &request.requested_client_actions {
let processed = self.process_action(action);
if processed.action_status == ActionStatus::Failed {
if let Some(ref err) = processed.error {
errors.push(err.clone());
}
}
processed_actions.push(processed);
}
let response = HandshakeResponse {
client_version: self.config.client_version.clone(),
processed_client_actions: processed_actions,
errors,
};
self.state = HandshakeState::AwaitingComplete;
debug!("Handshake response prepared");
Ok(Some(response))
}
fn process_action(&mut self, action: &RequestedClientAction) -> ProcessedClientAction {
match action.action_type {
ActionType::KmsEncryption => self.process_kms_action(action),
ActionType::SessionType => self.process_session_type_action(action),
}
}
fn process_kms_action(&mut self, action: &RequestedClientAction) -> ProcessedClientAction {
if !self.config.support_kms {
debug!("KMS encryption not supported, marking as unsupported");
return ProcessedClientAction {
action_type: ActionType::KmsEncryption,
action_status: ActionStatus::Unsupported,
action_result: None,
error: Some("KMS encryption not supported by this client".to_string()),
};
}
let kms_request: std::result::Result<KmsEncryptionRequest, _> =
serde_json::from_value(action.action_parameters.clone());
match kms_request {
Ok(req) => {
info!(kms_key_id = %req.kms_key_id, "KMS encryption requested");
warn!("KMS key generation not implemented yet");
ProcessedClientAction {
action_type: ActionType::KmsEncryption,
action_status: ActionStatus::Unsupported,
action_result: None,
error: Some("KMS key generation not implemented".to_string()),
}
}
Err(e) => ProcessedClientAction {
action_type: ActionType::KmsEncryption,
action_status: ActionStatus::Failed,
action_result: None,
error: Some(format!("Failed to parse KMS request: {}", e)),
},
}
}
fn process_session_type_action(
&mut self,
action: &RequestedClientAction,
) -> ProcessedClientAction {
let session_req: std::result::Result<SessionTypeRequest, _> =
serde_json::from_value(action.action_parameters.clone());
match session_req {
Ok(req) => {
info!(session_type = %req.session_type, "Session type requested");
let session_type = match req.session_type.as_str() {
"Standard_Stream" => SessionTypeValue::StandardStream,
"InteractiveCommands" => SessionTypeValue::InteractiveCommands,
"Port" => SessionTypeValue::Port,
other => {
warn!(
session_type = other,
"Unknown session type, defaulting to StandardStream"
);
SessionTypeValue::StandardStream
}
};
self.negotiated_session_type = Some(session_type);
ProcessedClientAction {
action_type: ActionType::SessionType,
action_status: ActionStatus::Success,
action_result: None,
error: None,
}
}
Err(e) => {
warn!(error = %e, "Failed to parse session type request");
ProcessedClientAction {
action_type: ActionType::SessionType,
action_status: ActionStatus::Failed,
action_result: None,
error: Some(format!("Failed to parse session type: {}", e)),
}
}
}
}
#[instrument(skip(self, complete))]
pub fn process_complete(&mut self, complete: HandshakeComplete) -> Result<()> {
if self.state != HandshakeState::AwaitingComplete {
return Err(Error::Protocol(ProtocolError::InvalidMessage(format!(
"Invalid state for handshake complete: {:?}",
self.state
))));
}
let duration = Duration::from_nanos(complete.handshake_time_to_complete as u64);
info!(
duration_ms = duration.as_millis(),
customer_message = ?complete.customer_message,
"Handshake completed"
);
if let Some(msg) = &complete.customer_message {
info!(message = %msg, "Customer message received");
}
self.state = HandshakeState::Completed;
Ok(())
}
pub fn response_to_message(
&self,
response: &HandshakeResponse,
sequence_number: i64,
) -> Result<ClientMessage> {
let payload = serde_json::to_vec(response).map_err(|e| {
Error::Protocol(ProtocolError::Framing(format!(
"Failed to serialize response: {}",
e
)))
})?;
Ok(ClientMessage::new(
MessageType::InputStreamData,
sequence_number,
PayloadType::HandshakeResponse,
Bytes::from(payload),
))
}
pub fn parse_request(message: &ClientMessage) -> Result<HandshakeRequest> {
if message.payload_type != PayloadType::HandshakeRequest {
return Err(Error::Protocol(ProtocolError::InvalidMessage(format!(
"Expected HandshakeRequest, got {:?}",
message.payload_type
))));
}
serde_json::from_slice(&message.payload).map_err(|e| {
Error::Protocol(ProtocolError::Framing(format!(
"Failed to parse HandshakeRequest: {}",
e
)))
})
}
pub fn parse_complete(message: &ClientMessage) -> Result<HandshakeComplete> {
if message.payload_type != PayloadType::HandshakeComplete {
return Err(Error::Protocol(ProtocolError::InvalidMessage(format!(
"Expected HandshakeComplete, got {:?}",
message.payload_type
))));
}
serde_json::from_slice(&message.payload).map_err(|e| {
Error::Protocol(ProtocolError::Framing(format!(
"Failed to parse HandshakeComplete: {}",
e
)))
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `HandshakeRequest` carrying a single `SessionType` action.
    fn session_type_request(session_type: &str) -> HandshakeRequest {
        HandshakeRequest {
            agent_version: "3.0.0".to_string(),
            requested_client_actions: vec![RequestedClientAction {
                action_type: ActionType::SessionType,
                action_parameters: serde_json::json!({ "SessionType": session_type }),
            }],
        }
    }

    #[test]
    fn test_handshake_request_parsing() {
        let json = r#"{
            "AgentVersion": "3.0.0",
            "RequestedClientActions": [
                {
                    "ActionType": "SessionType",
                    "ActionParameters": {
                        "SessionType": "Standard_Stream"
                    }
                }
            ]
        }"#;
        let request: HandshakeRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.agent_version, "3.0.0");
        assert_eq!(request.requested_client_actions.len(), 1);
        assert_eq!(
            request.requested_client_actions[0].action_type,
            ActionType::SessionType
        );
    }

    #[test]
    fn test_handshake_response_serialization() {
        let response = HandshakeResponse {
            client_version: "0.1.0".to_string(),
            processed_client_actions: vec![ProcessedClientAction {
                action_type: ActionType::SessionType,
                action_status: ActionStatus::Success,
                action_result: None,
                error: None,
            }],
            errors: vec![],
        };
        let value = serde_json::to_value(&response).unwrap();
        // ActionStatus is a bare number; None fields and empty Errors are omitted.
        assert_eq!(
            value,
            serde_json::json!({
                "ClientVersion": "0.1.0",
                "ProcessedClientActions": [
                    { "ActionType": "SessionType", "ActionStatus": 1 }
                ]
            })
        );
    }

    #[test]
    fn test_action_status_wire_values() {
        let cases = [
            (ActionStatus::Success, 1u32),
            (ActionStatus::Failed, 2),
            (ActionStatus::Unsupported, 3),
        ];
        for (status, code) in cases {
            assert_eq!(serde_json::to_value(status).unwrap(), serde_json::json!(code));
            let decoded: ActionStatus = serde_json::from_value(serde_json::json!(code)).unwrap();
            assert_eq!(decoded, status);
        }
        assert!(serde_json::from_value::<ActionStatus>(serde_json::json!(0)).is_err());
        assert!(serde_json::from_value::<ActionStatus>(serde_json::json!(4)).is_err());
    }

    #[test]
    fn test_session_type_value_wire_names() {
        assert_eq!(
            serde_json::to_value(SessionTypeValue::StandardStream).unwrap(),
            serde_json::json!("Standard_Stream")
        );
        assert_eq!(
            serde_json::to_value(SessionTypeValue::InteractiveCommands).unwrap(),
            serde_json::json!("InteractiveCommands")
        );
        assert_eq!(
            serde_json::to_value(SessionTypeValue::Port).unwrap(),
            serde_json::json!("Port")
        );
    }

    #[test]
    fn test_handshake_complete_parsing() {
        let json = r#"{
            "HandshakeTimeToComplete": 1000000000,
            "CustomerMessage": "Welcome to SSM"
        }"#;
        let complete: HandshakeComplete = serde_json::from_str(json).unwrap();
        assert_eq!(complete.handshake_time_to_complete, 1_000_000_000);
        assert_eq!(
            complete.customer_message,
            Some("Welcome to SSM".to_string())
        );
    }

    #[test]
    fn test_handshake_state_machine() {
        let mut handler = HandshakeHandler::new(HandshakeConfig::default());
        assert_eq!(handler.state(), HandshakeState::AwaitingRequest);
        let response = handler
            .process_request(session_type_request("Standard_Stream"))
            .unwrap()
            .expect("Should return response");
        assert_eq!(handler.state(), HandshakeState::AwaitingComplete);
        assert_eq!(handler.agent_version(), Some("3.0.0"));
        assert_eq!(response.processed_client_actions.len(), 1);
        assert_eq!(
            response.processed_client_actions[0].action_status,
            ActionStatus::Success
        );
        let complete = HandshakeComplete {
            handshake_time_to_complete: 500_000_000,
            customer_message: Some("Session ready".to_string()),
        };
        handler.process_complete(complete).unwrap();
        assert_eq!(handler.state(), HandshakeState::Completed);
        assert_eq!(handler.session_type(), Some(&SessionTypeValue::StandardStream));
    }

    #[test]
    fn test_complete_before_request_is_rejected() {
        let mut handler = HandshakeHandler::new(HandshakeConfig::default());
        let complete = HandshakeComplete {
            handshake_time_to_complete: 0,
            customer_message: None,
        };
        assert!(handler.process_complete(complete).is_err());
        assert_eq!(handler.state(), HandshakeState::AwaitingRequest);
    }

    #[test]
    fn test_unknown_session_type_defaults_to_standard_stream() {
        let mut handler = HandshakeHandler::new(HandshakeConfig::default());
        let response = handler
            .process_request(session_type_request("SomethingNew"))
            .unwrap()
            .expect("Should return response");
        assert_eq!(
            response.processed_client_actions[0].action_status,
            ActionStatus::Success
        );
        assert_eq!(handler.session_type(), Some(&SessionTypeValue::StandardStream));
    }

    #[test]
    fn test_kms_unsupported() {
        let config = HandshakeConfig {
            support_kms: false,
            ..Default::default()
        };
        let mut handler = HandshakeHandler::new(config);
        let request = HandshakeRequest {
            agent_version: "3.0.0".to_string(),
            requested_client_actions: vec![RequestedClientAction {
                action_type: ActionType::KmsEncryption,
                action_parameters: serde_json::json!({
                    "KMSKeyId": "arn:aws:kms:us-east-1:123456789:key/abc"
                }),
            }],
        };
        let response = handler
            .process_request(request)
            .unwrap()
            .expect("Should return response");
        assert_eq!(
            response.processed_client_actions[0].action_status,
            ActionStatus::Unsupported
        );
        // Unsupported outcomes are not copied into the top-level error list.
        assert!(response.errors.is_empty());
        assert!(!handler.is_kms_enabled());
    }

    #[test]
    fn test_duplicate_request_ignored() {
        let mut handler = HandshakeHandler::new(HandshakeConfig::default());
        let request = session_type_request("Standard_Stream");
        let response = handler.process_request(request.clone()).unwrap();
        assert!(response.is_some());
        assert_eq!(handler.state(), HandshakeState::AwaitingComplete);
        let duplicate_response = handler.process_request(request).unwrap();
        assert!(duplicate_response.is_none());
        assert_eq!(handler.state(), HandshakeState::AwaitingComplete);
    }

    #[test]
    fn test_parse_request_checks_payload_type() {
        let payload = serde_json::to_vec(&session_type_request("Port")).unwrap();
        let message = ClientMessage::new(
            MessageType::InputStreamData,
            1,
            PayloadType::HandshakeRequest,
            Bytes::from(payload),
        );
        let parsed = HandshakeHandler::parse_request(&message).unwrap();
        assert_eq!(parsed.agent_version, "3.0.0");
        assert_eq!(parsed.requested_client_actions.len(), 1);
        assert!(HandshakeHandler::parse_complete(&message).is_err());

        let handler = HandshakeHandler::new(HandshakeConfig::default());
        let response = HandshakeResponse {
            client_version: "0.1.0".to_string(),
            processed_client_actions: vec![],
            errors: vec![],
        };
        let outgoing = handler.response_to_message(&response, 0).unwrap();
        assert!(HandshakeHandler::parse_request(&outgoing).is_err());
    }
}