1use crate::bot::Bot;
10use crate::error::Error;
11use crate::models::{parse_game_state, Action};
12use crate::retry::RetryPolicy;
13use async_trait::async_trait;
14use futures_util::{SinkExt, StreamExt};
15use serde_json::{json, Value};
16use std::panic::AssertUnwindSafe;
17use tokio_tungstenite::{
18 connect_async,
19 tungstenite::{client::IntoClientRequest, http::header::USER_AGENT, Error as WsError, Message},
20};
21
/// Protocol versions this client advertises in the `supported_versions`
/// field of its `hello` message.
pub const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &["1.0"];

/// `client_name` reported to the server when `RunBotOptions::client_name`
/// is `None`.
const DEFAULT_CLIENT_NAME: &str = "chipzen-sdk-rust";
/// `client_version` reported to the server when
/// `RunBotOptions::client_version` is `None`; this crate's Cargo package
/// version, captured at compile time.
const DEFAULT_CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
27
28pub fn default_user_agent() -> String {
32 format!("chipzen-sdk-rust/{DEFAULT_CLIENT_VERSION}")
33}
34
/// Configuration for [`run_bot`]. Each `None` field is resolved to a
/// fallback inside `run_bot`.
///
/// NOTE(review): `Debug` is derived, so `token` and `ticket` are printed
/// verbatim by `{:?}` — avoid logging this struct, or consider a redacting
/// `Debug` impl.
#[derive(Debug, Clone)]
pub struct RunBotOptions {
    /// Credential sent as `token` in the `authenticate` message; takes
    /// precedence over `ticket` when present and non-empty.
    pub token: Option<String>,
    /// Credential sent as `ticket` when `token` is absent or empty.
    pub ticket: Option<String>,
    /// Match identifier; when `None` it is taken from the URL path segment
    /// following `/ws/match/`.
    pub match_id: Option<String>,
    /// Client name reported to the server; falls back to
    /// `"chipzen-sdk-rust"`.
    pub client_name: Option<String>,
    /// Client version reported to the server; falls back to this crate's
    /// package version.
    pub client_version: Option<String>,
    /// Supplies the maximum number of reconnect attempts and the backoff
    /// delay between them.
    pub retry_policy: RetryPolicy,
    /// When `true` (the `Default`), a panic inside the bot's `decide` is
    /// replaced by a fallback action (check, else fold); when `false` the run
    /// ends with `Error::BotDecision`.
    pub safe_mode: bool,
    /// `User-Agent` header for the WebSocket handshake; falls back to
    /// [`default_user_agent`]. A value that is not a valid header value is
    /// silently omitted.
    pub user_agent: Option<String>,
}
63
64impl Default for RunBotOptions {
65 fn default() -> Self {
66 Self {
67 token: None,
68 ticket: None,
69 match_id: None,
70 client_name: None,
71 client_version: None,
72 retry_policy: RetryPolicy::default(),
73 safe_mode: true,
74 user_agent: None,
75 }
76 }
77}
78
/// Per-connection settings resolved from [`RunBotOptions`]; built once in
/// [`run_bot`] and passed to every session attempt.
///
/// NOTE(review): `Debug` is derived, so `token`/`ticket` are printed
/// verbatim by `{:?}`.
#[derive(Debug, Clone)]
pub struct SessionContext {
    /// Match identifier included in every outbound message (authenticate,
    /// hello, pong, turn_action).
    pub match_id: String,
    /// Authentication token; used when present and non-empty.
    pub token: Option<String>,
    /// Authentication ticket; used only when `token` is absent or empty.
    pub ticket: Option<String>,
    /// Name reported in the `authenticate` and `hello` messages.
    pub client_name: String,
    /// Version reported in the `authenticate` and `hello` messages.
    pub client_version: String,
    /// Whether a panicking `decide` is replaced by a fallback action instead
    /// of ending the session with an error.
    pub safe_mode: bool,
}
92
93impl SessionContext {
94 pub fn new(
98 match_id: String,
99 token: Option<String>,
100 ticket: Option<String>,
101 client_name: String,
102 client_version: String,
103 ) -> Self {
104 Self {
105 match_id,
106 token,
107 ticket,
108 client_name,
109 client_version,
110 safe_mode: true,
111 }
112 }
113}
114
115pub async fn run_bot<B: Bot>(
124 url: &str,
125 mut bot: B,
126 options: RunBotOptions,
127) -> Result<Option<Value>, Error> {
128 let match_id = options
129 .match_id
130 .clone()
131 .unwrap_or_else(|| _extract_match_id(url));
132 let client_version = options
133 .client_version
134 .clone()
135 .unwrap_or_else(|| DEFAULT_CLIENT_VERSION.to_string());
136 let user_agent = options
137 .user_agent
138 .clone()
139 .unwrap_or_else(default_user_agent);
140 let ctx = SessionContext {
141 match_id,
142 token: options.token.clone(),
143 ticket: options.ticket.clone(),
144 client_name: options
145 .client_name
146 .clone()
147 .unwrap_or_else(|| DEFAULT_CLIENT_NAME.to_string()),
148 client_version,
149 safe_mode: options.safe_mode,
150 };
151
152 let policy = options.retry_policy;
153 let max_attempts = policy.max_reconnect_attempts;
154
155 let mut retries: u32 = 0;
156 loop {
157 let result: Result<Option<Value>, Error> = async {
158 let request = build_handshake_request(url, &user_agent)?;
159 let (ws_stream, _) = connect_async(request).await?;
160 let (mut write_half, mut read_half) = ws_stream.split();
161 let mut reader = WsReader {
162 inner: &mut read_half,
163 };
164 let mut writer = WsWriter {
165 inner: &mut write_half,
166 };
167 _run_session(&mut reader, &mut writer, &mut bot, &ctx).await
168 }
169 .await;
170
171 match result {
172 Ok(end) => return Ok(end),
173 Err(e @ Error::BotDecision(_)) => return Err(e),
177 Err(err) => {
178 retries += 1;
179 if retries > max_attempts {
180 return Err(Error::RetriesExhausted {
181 attempts: retries,
182 last_error: err.to_string(),
183 });
184 }
185 let backoff_ms = policy.backoff_ms(retries);
186 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
187 }
188 }
189 }
190}
191
192fn build_handshake_request(
196 url: &str,
197 user_agent: &str,
198) -> Result<tokio_tungstenite::tungstenite::handshake::client::Request, Error> {
199 let mut request = url.into_client_request().map_err(Error::from)?;
200 if let Ok(value) = user_agent.parse() {
201 request.headers_mut().insert(USER_AGENT, value);
202 }
203 Ok(request)
204}
205
/// Source of inbound text messages for a session.
///
/// Abstracts the transport so [`_run_session`] can run over the WebSocket
/// adapter used by `run_bot` or over any other implementation.
#[async_trait]
pub trait MessageReader: Send {
    /// Returns the next text message, or `Ok(None)` when the stream has
    /// ended. `_run_session` stops reading on `None`.
    async fn next(&mut self) -> Result<Option<String>, Error>;
}
221
/// Sink for outbound text messages of a session.
#[async_trait]
pub trait MessageWriter: Send {
    /// Sends one text payload to the peer. `_run_session` passes serialized
    /// JSON objects.
    async fn send(&mut self, payload: String) -> Result<(), Error>;
}
227
228#[async_trait]
232impl MessageReader for Box<dyn MessageReader> {
233 async fn next(&mut self) -> Result<Option<String>, Error> {
234 (**self).next().await
235 }
236}
237
238#[async_trait]
239impl MessageWriter for Box<dyn MessageWriter> {
240 async fn send(&mut self, payload: String) -> Result<(), Error> {
241 (**self).send(payload).await
242 }
243}
244
245pub async fn _run_session<R, W, B>(
254 reader: &mut R,
255 writer: &mut W,
256 bot: &mut B,
257 ctx: &SessionContext,
258) -> Result<Option<Value>, Error>
259where
260 R: MessageReader,
261 W: MessageWriter,
262 B: Bot,
263{
264 let mut auth = json!({
266 "type": "authenticate",
267 "match_id": ctx.match_id,
268 "client_name": ctx.client_name,
269 "client_version": ctx.client_version,
270 });
271 if let Some(t) = ctx.token.as_deref().filter(|s| !s.is_empty()) {
272 auth["token"] = Value::String(t.to_string());
273 } else if let Some(t) = ctx.ticket.as_deref().filter(|s| !s.is_empty()) {
274 auth["ticket"] = Value::String(t.to_string());
275 } else {
276 auth["token"] = Value::String(String::new());
277 }
278 writer.send(auth.to_string()).await?;
279
280 let hello_raw = reader.next().await?.ok_or(Error::ConnectionClosed {
281 context: "server hello",
282 })?;
283 let hello: Value = serde_json::from_str(&hello_raw)?;
284 if hello.get("type").and_then(|v| v.as_str()) != Some("hello") {
285 return Err(Error::Protocol(format!(
286 "expected server hello, got {:?}",
287 hello.get("type")
288 )));
289 }
290
291 let client_hello = json!({
292 "type": "hello",
293 "match_id": ctx.match_id,
294 "supported_versions": SUPPORTED_PROTOCOL_VERSIONS,
295 "client_name": ctx.client_name,
296 "client_version": ctx.client_version,
297 });
298 writer.send(client_hello.to_string()).await?;
299
300 let mut last_seq: i64 = 0;
302 while let Some(raw) = reader.next().await? {
303 let msg: Value = match serde_json::from_str(&raw) {
304 Ok(v) => v,
305 Err(_) => continue,
309 };
310
311 if let Some(seq) = msg.get("seq").and_then(Value::as_i64) {
312 if seq <= last_seq {
313 continue; }
315 last_seq = seq;
316 }
317
318 let mtype = msg.get("type").and_then(|v| v.as_str()).unwrap_or("");
319 match mtype {
320 "ping" => {
321 let pong = json!({ "type": "pong", "match_id": ctx.match_id });
322 writer.send(pong.to_string()).await?;
323 }
324 "match_start" => bot.on_match_start(&msg),
325 "round_start" => bot.on_round_start(&msg),
326 "phase_change" => bot.on_phase_change(&msg),
327 "turn_result" => bot.on_turn_result(&msg),
328 "round_result" => bot.on_round_result(&msg),
329 "turn_request" => {
330 let request_id = msg
331 .get("request_id")
332 .and_then(|v| v.as_str())
333 .unwrap_or("")
334 .to_string();
335 let state = parse_game_state(&msg);
336 let (action, latency_ms) = decide_timed(bot, &state, &msg, ctx.safe_mode)?;
337 send_turn_action(writer, &ctx.match_id, &request_id, action).await?;
338 bot.on_decision_latency(latency_ms);
339 }
340 "action_rejected" => {
341 let request_id = msg
342 .get("request_id")
343 .and_then(|v| v.as_str())
344 .unwrap_or("")
345 .to_string();
346 let valid_actions: Vec<String> = msg
347 .get("valid_actions")
348 .and_then(|v| v.as_array())
349 .map(|arr| {
350 arr.iter()
351 .filter_map(|v| v.as_str().map(String::from))
352 .collect()
353 })
354 .unwrap_or_else(|| vec!["fold".to_string()]);
355 let fallback = _safe_fallback_action(&valid_actions);
356 send_turn_action(writer, &ctx.match_id, &request_id, fallback).await?;
357 }
358 "reconnected" => {
359 if let Some(pending) = msg.get("pending_request") {
362 if pending.get("type").and_then(|v| v.as_str()) == Some("turn_request") {
363 let request_id = pending
364 .get("request_id")
365 .and_then(|v| v.as_str())
366 .unwrap_or("")
367 .to_string();
368 let state = parse_game_state(pending);
369 let (action, latency_ms) =
370 decide_timed(bot, &state, pending, ctx.safe_mode)?;
371 send_turn_action(writer, &ctx.match_id, &request_id, action).await?;
372 bot.on_decision_latency(latency_ms);
373 }
374 }
375 }
376 "match_end" => {
377 let results = msg.get("results").cloned().unwrap_or_else(|| msg.clone());
378 bot.on_match_end(&results);
379 return Ok(Some(msg));
380 }
381 "error" => {
382 }
385 _ => {
386 }
388 }
389 }
390
391 Ok(None)
393}
394
395async fn send_turn_action<W: MessageWriter>(
397 writer: &mut W,
398 match_id: &str,
399 request_id: &str,
400 action: Action,
401) -> Result<(), Error> {
402 let (action_str, params) = action.to_wire();
403 let payload = json!({
404 "type": "turn_action",
405 "match_id": match_id,
406 "request_id": request_id,
407 "action": action_str,
408 "params": params,
409 });
410 writer.send(payload.to_string()).await
411}
412
413fn decide_timed<B: Bot>(
424 bot: &mut B,
425 state: &crate::models::GameState,
426 msg: &Value,
427 safe_mode: bool,
428) -> Result<(Action, f64), Error> {
429 let start = std::time::Instant::now();
430 let outcome = std::panic::catch_unwind(AssertUnwindSafe(|| bot.decide(state)));
431 let latency_ms = start.elapsed().as_secs_f64() * 1000.0;
432
433 let action = match outcome {
434 Ok(action) => action,
435 Err(payload) => {
436 let detail = panic_message(payload.as_ref());
437 if !safe_mode {
438 return Err(Error::BotDecision(detail));
439 }
440 _safe_fallback_action(&state.valid_actions)
443 }
444 };
445
446 if action_is_legal(&action, &state.valid_actions) {
447 Ok((action, latency_ms))
448 } else {
449 let valid = msg
450 .get("valid_actions")
451 .and_then(|v| v.as_array())
452 .map(|arr| {
453 arr.iter()
454 .filter_map(|v| v.as_str().map(String::from))
455 .collect::<Vec<_>>()
456 })
457 .unwrap_or_else(|| state.valid_actions.clone());
458 Ok((_safe_fallback_action(&valid), latency_ms))
459 }
460}
461
/// Extracts a readable message from a panic payload.
///
/// Payloads are typically `&str` or `String`; any other payload type gets a
/// generic description.
fn panic_message(payload: &(dyn std::any::Any + Send)) -> String {
    payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "decide() panicked".to_owned())
}
473
474fn action_is_legal(action: &Action, valid: &[String]) -> bool {
475 let needed = action.kind().as_str();
476 valid.iter().any(|v| v == needed)
477}
478
479pub fn _safe_fallback_action(valid_actions: &[String]) -> Action {
483 if valid_actions.iter().any(|a| a == "check") {
484 Action::Check
485 } else {
486 Action::Fold
487 }
488}
489
/// Extracts the match id from a URL shaped like
/// `.../ws/match/<id>[/...][?...][#...]`.
///
/// The id is the text after the first `/ws/match/`, up to the next `/`, `?`
/// or `#`. Returns an empty string when the marker is absent.
pub fn _extract_match_id(url: &str) -> String {
    let Some((_, tail)) = url.split_once("/ws/match/") else {
        return String::new();
    };
    tail.split(['/', '?', '#'])
        .next()
        .unwrap_or_default()
        .to_owned()
}
504
/// Adapts a borrowed WebSocket message stream to [`MessageReader`].
///
/// The stream's trait bounds live on the `MessageReader` impl only, where
/// they are actually needed. The struct just holds the borrow.
struct WsReader<'a, S> {
    inner: &'a mut S,
}
515
516#[async_trait]
517impl<'a, S> MessageReader for WsReader<'a, S>
518where
519 S: StreamExt<Item = Result<Message, WsError>> + Unpin + Send,
520{
521 async fn next(&mut self) -> Result<Option<String>, Error> {
522 loop {
523 match self.inner.next().await {
524 Some(Ok(Message::Text(t))) => return Ok(Some(t.to_string())),
525 Some(Ok(Message::Ping(_))) => {
526 continue;
531 }
532 Some(Ok(Message::Close(_))) | None => return Ok(None),
533 Some(Ok(_)) => continue, Some(Err(e)) => return Err(Error::from(e)),
535 }
536 }
537 }
538}
539
/// Adapts a borrowed WebSocket message sink to [`MessageWriter`].
///
/// The sink's trait bounds live on the `MessageWriter` impl only, where
/// they are actually needed. The struct just holds the borrow.
struct WsWriter<'a, S> {
    inner: &'a mut S,
}
546
547#[async_trait]
548impl<'a, S> MessageWriter for WsWriter<'a, S>
549where
550 S: SinkExt<Message, Error = WsError> + Unpin + Send,
551{
552 async fn send(&mut self, payload: String) -> Result<(), Error> {
553 self.inner
554 .send(Message::Text(payload))
555 .await
556 .map_err(Error::from)
557 }
558}