use async_trait::async_trait;
use prost::Message;
use steam_enums::EMsg;
use crate::error::SteamError;
/// Pair of identifiers reported by a [`MessageSender`] for its current session.
///
/// The type is a small `Copy` value. It implements `PartialEq`, `Eq` and `Hash`
/// so that whole values can be compared with `assert_eq!` or used as map/set
/// keys instead of comparing field by field.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SessionInfo {
    /// Session identifier.
    pub session_id: i32,
    /// Steam identifier.
    pub steam_id: u64,
}

impl SessionInfo {
    /// Builds a `SessionInfo` from its two identifiers.
    ///
    /// This is a `const fn`, so it can also be used to initialise constants
    /// and statics.
    pub const fn new(session_id: i32, steam_id: u64) -> Self {
        Self { session_id, steam_id }
    }
}
/// Record of one outgoing message as captured by [`MockMessageSender`].
#[derive(Debug, Clone)]
pub struct SentMessage {
    /// Message type the payload was sent under.
    pub msg_type: EMsg,
    /// Encoded payload bytes (the mock stores `body.encode_to_vec()` here).
    pub body: Vec<u8>,
    /// Job id attached to the message, if any.
    pub job_id: Option<u64>,
    /// Name of the service method; the mock sets it only for messages
    /// recorded through `send_service_method`.
    pub service_method: Option<String>,
}
/// Interface for sending protobuf messages on behalf of a session.
///
/// The trait exists so that code which sends messages can be exercised against
/// a stand-in such as [`MockMessageSender`].
#[async_trait]
pub trait MessageSender: Send {
    /// Reports whether this sender is currently logged in.
    fn is_logged_in(&self) -> bool;

    /// Returns the session identifiers associated with this sender.
    fn session_info(&self) -> SessionInfo;

    /// Sends `body` as a message of type `msg_type`.
    ///
    /// # Errors
    ///
    /// Returns a [`SteamError`] when the implementation cannot send the message.
    async fn send_message<T>(&mut self, msg_type: EMsg, body: &T) -> Result<(), SteamError>
    where
        T: Message + Send + Sync;

    /// Sends `body` as a call to the service method named `method`.
    ///
    /// # Errors
    ///
    /// Returns a [`SteamError`] when the implementation cannot send the call.
    async fn send_service_method<T>(&mut self, method: &str, body: &T) -> Result<(), SteamError>
    where
        T: Message + Send + Sync;
}
/// In-memory [`MessageSender`] that records messages instead of sending them.
///
/// All fields are public so tests can inspect or preset them directly.
#[derive(Debug, Default)]
pub struct MockMessageSender {
    /// Value returned by `is_logged_in`.
    pub logged_in: bool,
    /// Value returned by `session_info`.
    pub session_info: SessionInfo,
    /// Every message accepted so far, in the order it was sent.
    pub sent_messages: Vec<SentMessage>,
    /// Error handed back by the next send call; it is taken out (left as
    /// `None`) when returned, so it affects a single call only.
    pub next_error: Option<SteamError>,
    /// Counter incremented before each message is recorded; the incremented
    /// value becomes that message's job id.
    pub current_job_id: u64,
}
impl MockMessageSender {
    /// Creates a mock in its default state: logged out, default session info,
    /// nothing recorded and no pending error.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates a mock that reports itself as logged in with the given session
    /// identifiers; every other field starts at its default.
    pub fn new_logged_in(session_id: i32, steam_id: u64) -> Self {
        Self {
            logged_in: true,
            session_info: SessionInfo::new(session_id, steam_id),
            ..Default::default()
        }
    }

    /// Overrides the value returned by `is_logged_in`.
    pub fn set_logged_in(&mut self, logged_in: bool) {
        self.logged_in = logged_in;
    }

    /// Arms the mock so that the next send call fails with `error`.
    pub fn set_next_error(&mut self, error: SteamError) {
        self.next_error = Some(error);
    }

    /// Forgets all recorded messages.
    ///
    /// Only the message list is emptied; the job-id counter and any pending
    /// error are left as they are.
    pub fn clear(&mut self) {
        self.sent_messages.clear();
    }

    /// Returns the most recently recorded message, or `None` if there is none.
    pub fn last_sent(&self) -> Option<&SentMessage> {
        self.sent_messages.last()
    }

    /// Decodes the body of the most recently recorded message as a `T`.
    ///
    /// # Errors
    ///
    /// * `SteamError::Other` when no message has been recorded.
    /// * `SteamError::ProtocolError` when the stored bytes do not decode as `T`.
    pub fn decode_last_message<T: Message + Default>(&self) -> Result<T, SteamError> {
        self.last_sent()
            .ok_or_else(|| SteamError::Other("No messages sent".into()))
            .and_then(|sent| {
                T::decode(sent.body.as_slice())
                    .map_err(|err| SteamError::ProtocolError(format!("Failed to decode: {}", err)))
            })
    }

    /// Returns the recorded messages whose type equals `msg_type`, in send order.
    pub fn messages_of_type(&self, msg_type: EMsg) -> Vec<&SentMessage> {
        self.sent_messages
            .iter()
            .filter(|sent| sent.msg_type == msg_type)
            .collect()
    }

    /// Returns the recorded messages that carry a service method name, in send order.
    pub fn service_calls(&self) -> Vec<&SentMessage> {
        self.sent_messages
            .iter()
            .filter(|sent| sent.service_method.is_some())
            .collect()
    }
}
#[async_trait]
impl MessageSender for MockMessageSender {
    /// Returns the `logged_in` flag.
    fn is_logged_in(&self) -> bool {
        self.logged_in
    }

    /// Returns a copy of the stored session info.
    fn session_info(&self) -> SessionInfo {
        self.session_info
    }

    /// Records `body` under `msg_type` with no service method name, or fails
    /// with the pending injected error if one is set.
    async fn send_message<T>(&mut self, msg_type: EMsg, body: &T) -> Result<(), SteamError>
    where
        T: Message + Send + Sync,
    {
        self.record_outgoing(msg_type, body, None)
    }

    /// Records `body` as an `EMsg::ServiceMethodCallFromClient` message tagged
    /// with `method`, or fails with the pending injected error if one is set.
    async fn send_service_method<T>(&mut self, method: &str, body: &T) -> Result<(), SteamError>
    where
        T: Message + Send + Sync,
    {
        self.record_outgoing(EMsg::ServiceMethodCallFromClient, body, Some(method))
    }
}

impl MockMessageSender {
    /// Shared implementation of both send methods.
    ///
    /// A pending injected error is consumed and returned first; in that case
    /// nothing is recorded and the job-id counter is not advanced. Otherwise
    /// the counter is bumped and the encoded body is appended to
    /// `sent_messages` with the new job id.
    fn record_outgoing<T: Message>(
        &mut self,
        msg_type: EMsg,
        body: &T,
        service_method: Option<&str>,
    ) -> Result<(), SteamError> {
        if let Some(pending) = self.next_error.take() {
            return Err(pending);
        }

        self.current_job_id += 1;
        let job_id = self.current_job_id;
        self.sent_messages.push(SentMessage {
            msg_type,
            body: body.encode_to_vec(),
            job_id: Some(job_id),
            service_method: service_method.map(str::to_owned),
        });
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use steam_protos::CMsgClientChangeStatus;

    use super::*;

    // Identifiers used by most fixtures below.
    const SESSION_ID: i32 = 12345;
    const STEAM_ID: u64 = 76561198012345678;

    /// Builds the logged-in mock shared by most tests.
    fn logged_in_mock() -> MockMessageSender {
        MockMessageSender::new_logged_in(SESSION_ID, STEAM_ID)
    }

    #[tokio::test]
    async fn test_mock_sender_records_messages() {
        let mut mock = logged_in_mock();
        let body = CMsgClientChangeStatus { persona_state: Some(1), ..Default::default() };
        mock.send_message(EMsg::ClientChangeStatus, &body).await.expect("test should not fail");

        assert_eq!(mock.sent_messages.len(), 1);
        assert_eq!(mock.sent_messages[0].msg_type, EMsg::ClientChangeStatus);
        assert!(mock.sent_messages[0].service_method.is_none());
    }

    #[tokio::test]
    async fn test_mock_sender_assigns_sequential_job_ids() {
        let mut mock = logged_in_mock();
        mock.send_message(EMsg::ClientChangeStatus, &CMsgClientChangeStatus::default()).await.expect("test should not fail");
        mock.send_service_method("Player.GetNicknameList#1", &steam_protos::CPlayerGetNicknameListRequest {}).await.expect("test should not fail");

        // The counter is shared by both send paths and starts at zero.
        assert_eq!(mock.sent_messages[0].job_id, Some(1));
        assert_eq!(mock.sent_messages[1].job_id, Some(2));
        assert_eq!(mock.current_job_id, 2);
    }

    #[tokio::test]
    async fn test_mock_sender_decode_last_message() {
        let mut mock = logged_in_mock();
        let body = CMsgClientChangeStatus { persona_state: Some(3), player_name: Some("TestPlayer".to_string()), ..Default::default() };
        mock.send_message(EMsg::ClientChangeStatus, &body).await.expect("test should not fail");

        let decoded: CMsgClientChangeStatus = mock.decode_last_message().expect("test should not fail");
        assert_eq!(decoded.persona_state, Some(3));
        assert_eq!(decoded.player_name, Some("TestPlayer".to_string()));
    }

    #[test]
    fn test_mock_sender_decode_last_message_without_messages() {
        let mock = logged_in_mock();
        let result = mock.decode_last_message::<CMsgClientChangeStatus>();
        assert!(matches!(result, Err(SteamError::Other(_))));
    }

    #[tokio::test]
    async fn test_mock_sender_service_method() {
        let mut mock = logged_in_mock();
        let body = steam_protos::CPlayerGetNicknameListRequest {};
        mock.send_service_method("Player.GetNicknameList#1", &body).await.expect("test should not fail");

        assert_eq!(mock.sent_messages.len(), 1);
        assert_eq!(mock.sent_messages[0].service_method, Some("Player.GetNicknameList#1".to_string()));
        assert_eq!(mock.sent_messages[0].msg_type, EMsg::ServiceMethodCallFromClient);
    }

    #[tokio::test]
    async fn test_mock_sender_service_calls() {
        let mut mock = logged_in_mock();
        mock.send_message(EMsg::ClientChangeStatus, &CMsgClientChangeStatus::default()).await.expect("test should not fail");
        mock.send_service_method("Player.GetNicknameList#1", &steam_protos::CPlayerGetNicknameListRequest {}).await.expect("test should not fail");

        // Only the message sent through `send_service_method` carries a method name.
        let calls = mock.service_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].service_method.as_deref(), Some("Player.GetNicknameList#1"));
    }

    #[tokio::test]
    async fn test_mock_sender_error_injection() {
        let mut mock = logged_in_mock();
        mock.set_next_error(SteamError::NotConnected);
        let body = CMsgClientChangeStatus::default();

        let result = mock.send_message(EMsg::ClientChangeStatus, &body).await;
        assert!(matches!(result, Err(SteamError::NotConnected)));
        assert!(mock.sent_messages.is_empty());

        // The injected error is one-shot: the next send succeeds, and the
        // failed attempt did not consume a job id.
        mock.send_message(EMsg::ClientChangeStatus, &body).await.expect("test should not fail");
        assert_eq!(mock.sent_messages.len(), 1);
        assert_eq!(mock.sent_messages[0].job_id, Some(1));
    }

    #[tokio::test]
    async fn test_mock_sender_error_injection_service_method() {
        let mut mock = logged_in_mock();
        mock.set_next_error(SteamError::NotConnected);

        let result = mock.send_service_method("Player.GetNicknameList#1", &steam_protos::CPlayerGetNicknameListRequest {}).await;
        assert!(matches!(result, Err(SteamError::NotConnected)));
        assert!(mock.sent_messages.is_empty());
        assert!(mock.next_error.is_none());
    }

    // Nothing below awaits, so plain `#[test]` avoids starting a runtime.
    #[test]
    fn test_mock_sender_is_logged_in() {
        let mock_out = MockMessageSender::new();
        assert!(!mock_out.is_logged_in());

        let mut mock_in = MockMessageSender::new_logged_in(123, 456);
        assert!(mock_in.is_logged_in());

        mock_in.set_logged_in(false);
        assert!(!mock_in.is_logged_in());
    }

    #[test]
    fn test_mock_sender_session_info() {
        let mock = logged_in_mock();
        let info = mock.session_info();
        assert_eq!(info.session_id, SESSION_ID);
        assert_eq!(info.steam_id, STEAM_ID);
    }

    #[tokio::test]
    async fn test_mock_sender_messages_of_type() {
        let mut mock = logged_in_mock();
        mock.send_message(EMsg::ClientChangeStatus, &CMsgClientChangeStatus::default()).await.expect("test should not fail");
        mock.send_message(EMsg::ClientHeartBeat, &steam_protos::CMsgClientHeartBeat::default()).await.expect("test should not fail");
        mock.send_message(EMsg::ClientChangeStatus, &CMsgClientChangeStatus::default()).await.expect("test should not fail");

        let status_msgs = mock.messages_of_type(EMsg::ClientChangeStatus);
        assert_eq!(status_msgs.len(), 2);
        let heartbeat_msgs = mock.messages_of_type(EMsg::ClientHeartBeat);
        assert_eq!(heartbeat_msgs.len(), 1);
    }

    #[tokio::test]
    async fn test_mock_sender_clear() {
        let mut mock = logged_in_mock();
        mock.send_message(EMsg::ClientChangeStatus, &CMsgClientChangeStatus::default()).await.expect("test should not fail");
        assert_eq!(mock.sent_messages.len(), 1);

        mock.clear();
        assert!(mock.sent_messages.is_empty());
        assert!(mock.last_sent().is_none());
    }
}