use crate::bot::Bot;
use crate::error::Error;
use crate::models::{parse_game_state, Action};
use crate::retry::RetryPolicy;
use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};
use serde_json::{json, Value};
use std::panic::AssertUnwindSafe;
use tokio_tungstenite::{
connect_async,
tungstenite::{client::IntoClientRequest, http::header::USER_AGENT, Error as WsError, Message},
};
/// Protocol versions this client advertises in its `hello` message.
pub const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &["1.0"];
/// Client name reported in `authenticate`/`hello` when `RunBotOptions::client_name` is unset.
const DEFAULT_CLIENT_NAME: &str = "chipzen-sdk-rust";
/// Client version (this crate's Cargo package version, fixed at compile time) reported
/// when `RunBotOptions::client_version` is unset.
const DEFAULT_CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
pub fn default_user_agent() -> String {
format!("chipzen-sdk-rust/{DEFAULT_CLIENT_VERSION}")
}
/// Caller-tunable settings for [`run_bot`].
#[derive(Debug, Clone)]
pub struct RunBotOptions {
    /// Authentication token; sent as `token` in `authenticate` when non-empty
    /// (it takes precedence over `ticket`).
    pub token: Option<String>,
    /// Alternative credential, sent as `ticket` only when no non-empty token is set.
    pub ticket: Option<String>,
    /// Match identifier; when `None` it is taken from a `/ws/match/<id>` URL segment.
    pub match_id: Option<String>,
    /// Reported client name; defaults to `chipzen-sdk-rust`.
    pub client_name: Option<String>,
    /// Reported client version; defaults to this crate's package version.
    pub client_version: Option<String>,
    /// Reconnect budget and backoff used by [`run_bot`].
    pub retry_policy: RetryPolicy,
    /// When `true`, a panic in `Bot::decide` is replaced by a fallback action
    /// instead of ending the run with an error.
    pub safe_mode: bool,
    /// `User-Agent` for the WebSocket handshake; defaults to [`default_user_agent`].
    pub user_agent: Option<String>,
}

impl Default for RunBotOptions {
    /// Every optional field unset, the default retry policy, and `safe_mode` enabled.
    fn default() -> Self {
        let retry_policy = RetryPolicy::default();
        Self {
            safe_mode: true,
            retry_policy,
            token: None,
            ticket: None,
            match_id: None,
            user_agent: None,
            client_name: None,
            client_version: None,
        }
    }
}
/// Resolved per-session settings consumed by `_run_session`.
#[derive(Debug, Clone)]
pub struct SessionContext {
    /// Match identifier included in every outbound protocol message.
    pub match_id: String,
    /// Token credential; a non-empty value is preferred over `ticket`.
    pub token: Option<String>,
    /// Ticket credential, used only when no non-empty token is present.
    pub ticket: Option<String>,
    /// Client name reported in `authenticate` and `hello`.
    pub client_name: String,
    /// Client version reported in `authenticate` and `hello`.
    pub client_version: String,
    /// When `true`, bot panics in `decide` fall back to a safe action.
    pub safe_mode: bool,
}

impl SessionContext {
    /// Creates a context from explicit values, with `safe_mode` switched on.
    pub fn new(
        match_id: String,
        token: Option<String>,
        ticket: Option<String>,
        client_name: String,
        client_version: String,
    ) -> Self {
        let safe_mode = true;
        Self {
            safe_mode,
            client_version,
            client_name,
            ticket,
            token,
            match_id,
        }
    }
}
pub async fn run_bot<B: Bot>(
url: &str,
mut bot: B,
options: RunBotOptions,
) -> Result<Option<Value>, Error> {
let match_id = options
.match_id
.clone()
.unwrap_or_else(|| _extract_match_id(url));
let client_version = options
.client_version
.clone()
.unwrap_or_else(|| DEFAULT_CLIENT_VERSION.to_string());
let user_agent = options
.user_agent
.clone()
.unwrap_or_else(default_user_agent);
let ctx = SessionContext {
match_id,
token: options.token.clone(),
ticket: options.ticket.clone(),
client_name: options
.client_name
.clone()
.unwrap_or_else(|| DEFAULT_CLIENT_NAME.to_string()),
client_version,
safe_mode: options.safe_mode,
};
let policy = options.retry_policy;
let max_attempts = policy.max_reconnect_attempts;
let mut retries: u32 = 0;
loop {
let result: Result<Option<Value>, Error> = async {
let request = build_handshake_request(url, &user_agent)?;
let (ws_stream, _) = connect_async(request).await?;
let (mut write_half, mut read_half) = ws_stream.split();
let mut reader = WsReader {
inner: &mut read_half,
};
let mut writer = WsWriter {
inner: &mut write_half,
};
_run_session(&mut reader, &mut writer, &mut bot, &ctx).await
}
.await;
match result {
Ok(end) => return Ok(end),
Err(e @ Error::BotDecision(_)) => return Err(e),
Err(err) => {
retries += 1;
if retries > max_attempts {
return Err(Error::RetriesExhausted {
attempts: retries,
last_error: err.to_string(),
});
}
let backoff_ms = policy.backoff_ms(retries);
tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
}
}
}
}
fn build_handshake_request(
url: &str,
user_agent: &str,
) -> Result<tokio_tungstenite::tungstenite::handshake::client::Request, Error> {
let mut request = url.into_client_request().map_err(Error::from)?;
if let Ok(value) = user_agent.parse() {
request.headers_mut().insert(USER_AGENT, value);
}
Ok(request)
}
/// Source of inbound text messages for a protocol session.
///
/// `_run_session` treats `Ok(None)` as the peer having closed the stream.
#[async_trait]
pub trait MessageReader: Send {
    /// Returns the next text message, `Ok(None)` at end of stream, or a transport error.
    async fn next(&mut self) -> Result<Option<String>, Error>;
}
/// Sink for outbound text messages of a protocol session.
#[async_trait]
pub trait MessageWriter: Send {
    /// Sends one serialized message, returning any transport error.
    async fn send(&mut self, payload: String) -> Result<(), Error>;
}
/// Lets a boxed reader trait object be used wherever a `MessageReader` is expected.
#[async_trait]
impl MessageReader for Box<dyn MessageReader> {
    async fn next(&mut self) -> Result<Option<String>, Error> {
        // Deref through the box (`&mut Box` -> `dyn`) so the call dispatches to the
        // inner reader rather than recursing into this impl.
        let inner = &mut **self;
        inner.next().await
    }
}

/// Lets a boxed writer trait object be used wherever a `MessageWriter` is expected.
#[async_trait]
impl MessageWriter for Box<dyn MessageWriter> {
    async fn send(&mut self, payload: String) -> Result<(), Error> {
        // Same double deref as the reader impl, to reach the boxed writer.
        let inner = &mut **self;
        inner.send(payload).await
    }
}
pub async fn _run_session<R, W, B>(
reader: &mut R,
writer: &mut W,
bot: &mut B,
ctx: &SessionContext,
) -> Result<Option<Value>, Error>
where
R: MessageReader,
W: MessageWriter,
B: Bot,
{
let mut auth = json!({
"type": "authenticate",
"match_id": ctx.match_id,
"client_name": ctx.client_name,
"client_version": ctx.client_version,
});
if let Some(t) = ctx.token.as_deref().filter(|s| !s.is_empty()) {
auth["token"] = Value::String(t.to_string());
} else if let Some(t) = ctx.ticket.as_deref().filter(|s| !s.is_empty()) {
auth["ticket"] = Value::String(t.to_string());
} else {
auth["token"] = Value::String(String::new());
}
writer.send(auth.to_string()).await?;
let hello_raw = reader.next().await?.ok_or(Error::ConnectionClosed {
context: "server hello",
})?;
let hello: Value = serde_json::from_str(&hello_raw)?;
if hello.get("type").and_then(|v| v.as_str()) != Some("hello") {
return Err(Error::Protocol(format!(
"expected server hello, got {:?}",
hello.get("type")
)));
}
let client_hello = json!({
"type": "hello",
"match_id": ctx.match_id,
"supported_versions": SUPPORTED_PROTOCOL_VERSIONS,
"client_name": ctx.client_name,
"client_version": ctx.client_version,
});
writer.send(client_hello.to_string()).await?;
let mut last_seq: i64 = 0;
while let Some(raw) = reader.next().await? {
let msg: Value = match serde_json::from_str(&raw) {
Ok(v) => v,
Err(_) => continue,
};
if let Some(seq) = msg.get("seq").and_then(Value::as_i64) {
if seq <= last_seq {
continue; }
last_seq = seq;
}
let mtype = msg.get("type").and_then(|v| v.as_str()).unwrap_or("");
match mtype {
"ping" => {
let pong = json!({ "type": "pong", "match_id": ctx.match_id });
writer.send(pong.to_string()).await?;
}
"match_start" => bot.on_match_start(&msg),
"round_start" => bot.on_round_start(&msg),
"phase_change" => bot.on_phase_change(&msg),
"turn_result" => bot.on_turn_result(&msg),
"round_result" => bot.on_round_result(&msg),
"turn_request" => {
let request_id = msg
.get("request_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let state = parse_game_state(&msg);
let (action, latency_ms) = decide_timed(bot, &state, &msg, ctx.safe_mode)?;
send_turn_action(writer, &ctx.match_id, &request_id, action).await?;
bot.on_decision_latency(latency_ms);
}
"action_rejected" => {
let request_id = msg
.get("request_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let valid_actions: Vec<String> = msg
.get("valid_actions")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_else(|| vec!["fold".to_string()]);
let fallback = _safe_fallback_action(&valid_actions);
send_turn_action(writer, &ctx.match_id, &request_id, fallback).await?;
}
"reconnected" => {
if let Some(pending) = msg.get("pending_request") {
if pending.get("type").and_then(|v| v.as_str()) == Some("turn_request") {
let request_id = pending
.get("request_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let state = parse_game_state(pending);
let (action, latency_ms) =
decide_timed(bot, &state, pending, ctx.safe_mode)?;
send_turn_action(writer, &ctx.match_id, &request_id, action).await?;
bot.on_decision_latency(latency_ms);
}
}
}
"match_end" => {
let results = msg.get("results").cloned().unwrap_or_else(|| msg.clone());
bot.on_match_end(&results);
return Ok(Some(msg));
}
"error" => {
}
_ => {
}
}
}
Ok(None)
}
/// Sends a `turn_action` frame answering `request_id` with `action`.
///
/// The action is split into its wire form via `Action::to_wire`. Any error reported by
/// `writer` is returned unchanged.
async fn send_turn_action<W: MessageWriter>(
    writer: &mut W,
    match_id: &str,
    request_id: &str,
    action: Action,
) -> Result<(), Error> {
    let (wire_action, wire_params) = action.to_wire();
    let frame = json!({
        "type": "turn_action",
        "match_id": match_id,
        "request_id": request_id,
        "action": wire_action,
        "params": wire_params,
    })
    .to_string();
    writer.send(frame).await
}
fn decide_timed<B: Bot>(
bot: &mut B,
state: &crate::models::GameState,
msg: &Value,
safe_mode: bool,
) -> Result<(Action, f64), Error> {
let start = std::time::Instant::now();
let outcome = std::panic::catch_unwind(AssertUnwindSafe(|| bot.decide(state)));
let latency_ms = start.elapsed().as_secs_f64() * 1000.0;
let action = match outcome {
Ok(action) => action,
Err(payload) => {
let detail = panic_message(payload.as_ref());
if !safe_mode {
return Err(Error::BotDecision(detail));
}
_safe_fallback_action(&state.valid_actions)
}
};
if action_is_legal(&action, &state.valid_actions) {
Ok((action, latency_ms))
} else {
let valid = msg
.get("valid_actions")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect::<Vec<_>>()
})
.unwrap_or_else(|| state.valid_actions.clone());
Ok((_safe_fallback_action(&valid), latency_ms))
}
}
/// Extracts a readable message from a `catch_unwind` panic payload.
///
/// Handles the two payload types `panic!` produces (`&str` and `String`). Any other
/// payload yields the generic text `"decide() panicked"`.
fn panic_message(payload: &(dyn std::any::Any + Send)) -> String {
    payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "decide() panicked".to_string())
}
/// Returns `true` when the name of `action`'s kind appears in `valid`.
fn action_is_legal(action: &Action, valid: &[String]) -> bool {
    let kind = action.kind();
    let wanted = kind.as_str();
    valid.iter().map(String::as_str).any(|offered| offered == wanted)
}
/// Chooses a conservative fallback: `Check` when `"check"` is offered, otherwise `Fold`.
///
/// `Fold` is returned even when `"fold"` itself is not in `valid_actions`.
pub fn _safe_fallback_action(valid_actions: &[String]) -> Action {
    let check_offered = valid_actions
        .iter()
        .map(String::as_str)
        .any(|name| name == "check");
    if check_offered {
        return Action::Check;
    }
    Action::Fold
}
/// Pulls the match id out of a URL containing `/ws/match/<id>`.
///
/// The id runs from just after the first `/ws/match/` up to the next `/`, `?` or `#`
/// (or the end of the string). Returns an empty string when the marker is absent.
pub fn _extract_match_id(url: &str) -> String {
    const MARKER: &str = "/ws/match/";
    url.split_once(MARKER)
        .map(|(_, tail)| tail.split(['/', '?', '#']).next().unwrap_or(""))
        .unwrap_or("")
        .to_owned()
}
/// Adapts a mutably borrowed stream (in `run_bot`, the read half of the split
/// WebSocket) to [`MessageReader`].
///
/// The struct carries no trait bounds of its own; the stream requirements are stated
/// on the `MessageReader` impl, per the Rust API guideline of keeping bounds off type
/// definitions.
struct WsReader<'a, S> {
    inner: &'a mut S,
}
#[async_trait]
impl<'a, S> MessageReader for WsReader<'a, S>
where
    S: StreamExt<Item = Result<Message, WsError>> + Unpin + Send,
{
    /// Returns the payload of the next text frame, skipping every non-text frame.
    ///
    /// A Close frame or the end of the underlying stream yields `Ok(None)`. A
    /// transport error is converted into [`Error`].
    async fn next(&mut self) -> Result<Option<String>, Error> {
        while let Some(frame) = self.inner.next().await {
            match frame? {
                Message::Text(text) => return Ok(Some(text.to_string())),
                Message::Close(_) => return Ok(None),
                // Ping, Pong, Binary and raw frames carry no protocol message; nothing
                // is sent back from here, so just keep reading.
                _ => {}
            }
        }
        Ok(None)
    }
}
/// Adapts a mutably borrowed sink (in `run_bot`, the write half of the split
/// WebSocket) to [`MessageWriter`].
///
/// The struct carries no trait bounds of its own; the sink requirements are stated on
/// the `MessageWriter` impl, per the Rust API guideline of keeping bounds off type
/// definitions.
struct WsWriter<'a, S> {
    inner: &'a mut S,
}
#[async_trait]
impl<'a, S> MessageWriter for WsWriter<'a, S>
where
    S: SinkExt<Message, Error = WsError> + Unpin + Send,
{
    /// Sends `payload` as a single text frame, converting transport errors into [`Error`].
    async fn send(&mut self, payload: String) -> Result<(), Error> {
        let frame = Message::Text(payload);
        self.inner.send(frame).await?;
        Ok(())
    }
}