Skip to main content

chipzen_bot/
client.rs

1//! WebSocket client for the Chipzen two-layer protocol.
2//!
3//! The user-facing surface is [`run_bot`]. Internals (the session
4//! loop, helper extractors, the `MessageReader`/`MessageWriter`
5//! traits) are exported with an underscore prefix or behind
6//! `#[doc(hidden)]` for the conformance harness — they are not part
7//! of the supported API.
8
9use crate::bot::Bot;
10use crate::error::Error;
11use crate::models::{parse_game_state, Action};
12use crate::retry::RetryPolicy;
13use async_trait::async_trait;
14use futures_util::{SinkExt, StreamExt};
15use serde_json::{json, Value};
16use std::panic::AssertUnwindSafe;
17use tokio_tungstenite::{
18    connect_async,
19    tungstenite::{client::IntoClientRequest, http::header::USER_AGENT, Error as WsError, Message},
20};
21
/// Protocol versions this client claims to support in the handshake.
/// Sent verbatim as `supported_versions` in the client `hello` message.
pub const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &["1.0"];

/// Client name reported to the server when `RunBotOptions::client_name`
/// is `None`.
const DEFAULT_CLIENT_NAME: &str = "chipzen-sdk-rust";
/// Client version reported to the server when
/// `RunBotOptions::client_version` is `None`; also the version component
/// of the default `User-Agent`. Taken from `CARGO_PKG_VERSION` at compile
/// time.
const DEFAULT_CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
27
28/// Default `User-Agent` header sent on the WebSocket handshake. A
29/// non-default UA also clears the platform's Cloudflare bot-fight rule
30/// (chipzen-ai/chipzen-sdk#46).
31pub fn default_user_agent() -> String {
32    format!("chipzen-sdk-rust/{DEFAULT_CLIENT_VERSION}")
33}
34
/// Optional knobs for [`run_bot`]. Defaults match the platform's
/// expectations; see the `Default` impl for the exact values.
#[derive(Debug, Clone)]
pub struct RunBotOptions {
    /// Bot API token. Required for the `/bot` endpoint; empty is fine
    /// for local dev. When both a non-empty `token` and a `ticket` are
    /// set, the token is the one sent in the `authenticate` message.
    pub token: Option<String>,
    /// Single-use ticket alternative to `token` (competitive
    /// endpoints). Only sent when `token` is `None` or empty.
    pub ticket: Option<String>,
    /// Match UUID. Auto-extracted from the URL if `None`.
    pub match_id: Option<String>,
    /// Client software name sent in the `authenticate` and `hello`
    /// handshake messages. Defaults to `chipzen-sdk-rust`.
    pub client_name: Option<String>,
    /// Client software version sent in the `authenticate` and `hello`
    /// handshake messages. Defaults to the crate version
    /// (chipzen-ai/chipzen-sdk#41).
    pub client_version: Option<String>,
    /// Reconnect-pacing policy (attempt cap + exponential backoff).
    pub retry_policy: RetryPolicy,
    /// When `true` (default), a panic in `decide()` is caught and folded —
    /// a transient bug won't forfeit a competitive match. Set `false` for
    /// dev/eval so the first panic propagates as [`Error::BotDecision`] and
    /// exits non-zero (chipzen-ai/chipzen-sdk#52).
    pub safe_mode: bool,
    /// Override the WS `User-Agent` header. Defaults to
    /// `chipzen-sdk-rust/<version>` (chipzen-ai/chipzen-sdk#46).
    pub user_agent: Option<String>,
}
63
64impl Default for RunBotOptions {
65    fn default() -> Self {
66        Self {
67            token: None,
68            ticket: None,
69            match_id: None,
70            client_name: None,
71            client_version: None,
72            retry_policy: RetryPolicy::default(),
73            safe_mode: true,
74            user_agent: None,
75        }
76    }
77}
78
/// Opaque per-session bag the session loop threads through. Public so
/// the conformance harness can construct one.
#[derive(Debug, Clone)]
pub struct SessionContext {
    /// Match identifier echoed in every outbound envelope.
    pub match_id: String,
    /// Bot API token; sent in `authenticate` when present and non-empty.
    pub token: Option<String>,
    /// Single-use ticket; sent in `authenticate` only when no non-empty
    /// token is available.
    pub ticket: Option<String>,
    /// Client software name reported during the handshake.
    pub client_name: String,
    /// Client software version reported during the handshake.
    pub client_version: String,
    /// When `false`, a panicking `decide()` surfaces as
    /// [`Error::BotDecision`] instead of being folded to a safe action.
    pub safe_mode: bool,
}
92
93impl SessionContext {
94    /// A context with `safe_mode` on — the common case. Lets existing
95    /// callers (tests, conformance harness) construct a context with the
96    /// historical field set without naming `safe_mode` explicitly.
97    pub fn new(
98        match_id: String,
99        token: Option<String>,
100        ticket: Option<String>,
101        client_name: String,
102        client_version: String,
103    ) -> Self {
104        Self {
105            match_id,
106            token,
107            ticket,
108            client_name,
109            client_version,
110            safe_mode: true,
111        }
112    }
113}
114
115/// Connect a bot to the Chipzen server and play until the match ends.
116///
117/// Returns the `match_end` payload on a clean finish, or `None` if the
118/// connection closed without a clean `match_end` after exhausting the
119/// retry budget. Returns [`Error::RetriesExhausted`] if the connection
120/// cannot be established after `retry_policy.max_reconnect_attempts`
121/// attempts, or [`Error::BotDecision`] if `decide()` panics under
122/// `safe_mode = false`.
123pub async fn run_bot<B: Bot>(
124    url: &str,
125    mut bot: B,
126    options: RunBotOptions,
127) -> Result<Option<Value>, Error> {
128    let match_id = options
129        .match_id
130        .clone()
131        .unwrap_or_else(|| _extract_match_id(url));
132    let client_version = options
133        .client_version
134        .clone()
135        .unwrap_or_else(|| DEFAULT_CLIENT_VERSION.to_string());
136    let user_agent = options
137        .user_agent
138        .clone()
139        .unwrap_or_else(default_user_agent);
140    let ctx = SessionContext {
141        match_id,
142        token: options.token.clone(),
143        ticket: options.ticket.clone(),
144        client_name: options
145            .client_name
146            .clone()
147            .unwrap_or_else(|| DEFAULT_CLIENT_NAME.to_string()),
148        client_version,
149        safe_mode: options.safe_mode,
150    };
151
152    let policy = options.retry_policy;
153    let max_attempts = policy.max_reconnect_attempts;
154
155    let mut retries: u32 = 0;
156    loop {
157        let result: Result<Option<Value>, Error> = async {
158            let request = build_handshake_request(url, &user_agent)?;
159            let (ws_stream, _) = connect_async(request).await?;
160            let (mut write_half, mut read_half) = ws_stream.split();
161            let mut reader = WsReader {
162                inner: &mut read_half,
163            };
164            let mut writer = WsWriter {
165                inner: &mut write_half,
166            };
167            _run_session(&mut reader, &mut writer, &mut bot, &ctx).await
168        }
169        .await;
170
171        match result {
172            Ok(end) => return Ok(end),
173            // A deterministic bot bug (safe_mode off) — terminal, not a
174            // transient disconnect. Do not reconnect-retry; propagate so the
175            // process exits non-zero.
176            Err(e @ Error::BotDecision(_)) => return Err(e),
177            Err(err) => {
178                retries += 1;
179                if retries > max_attempts {
180                    return Err(Error::RetriesExhausted {
181                        attempts: retries,
182                        last_error: err.to_string(),
183                    });
184                }
185                let backoff_ms = policy.backoff_ms(retries);
186                tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
187            }
188        }
189    }
190}
191
192/// Build the tungstenite handshake request for `url`, attaching the
193/// `User-Agent` header. tokio-tungstenite does not set a UA by default;
194/// some hosts (Cloudflare bot-fight) reject the empty UA.
195fn build_handshake_request(
196    url: &str,
197    user_agent: &str,
198) -> Result<tokio_tungstenite::tungstenite::handshake::client::Request, Error> {
199    let mut request = url.into_client_request().map_err(Error::from)?;
200    if let Ok(value) = user_agent.parse() {
201        request.headers_mut().insert(USER_AGENT, value);
202    }
203    Ok(request)
204}
205
206// ---------------------------------------------------------------------------
207// Session loop — internal, exposed for the conformance harness
208// ---------------------------------------------------------------------------
209
/// Pull-based async source of inbound messages. The real impl wraps a
/// `tokio_tungstenite::WebSocketStream`; the conformance harness
/// provides a scripted impl.
#[async_trait]
pub trait MessageReader: Send {
    /// Returns the next message as a UTF-8 string, or `None` if the
    /// underlying transport has closed cleanly. Errors should be
    /// surfaced through [`Error`] rather than this return type so the
    /// session loop can decide whether to retry.
    async fn next(&mut self) -> Result<Option<String>, Error>;
}
221
/// Push-based async sender for outbound messages.
#[async_trait]
pub trait MessageWriter: Send {
    /// Send one already-serialized message (`payload`) to the peer,
    /// reporting transport failures as [`Error`].
    async fn send(&mut self, payload: String) -> Result<(), Error>;
}
227
// Boxed trait objects forward to their inner impl so callers (e.g. the
// external-API transport, which returns `Box<dyn ...>`) can hand them to
// the generic [`_run_session`] without re-wrapping.
#[async_trait]
impl MessageReader for Box<dyn MessageReader> {
    /// Delegates to the boxed reader's `next`.
    async fn next(&mut self) -> Result<Option<String>, Error> {
        // `**self` reaches the inner `dyn MessageReader`, so this calls the
        // boxed implementation rather than recursing into this one.
        (**self).next().await
    }
}
237
// Forwarding impl so a `Box<dyn MessageWriter>` can be passed wherever a
// generic `W: MessageWriter` is expected.
#[async_trait]
impl MessageWriter for Box<dyn MessageWriter> {
    /// Delegates to the boxed writer's `send`.
    async fn send(&mut self, payload: String) -> Result<(), Error> {
        // `**self` reaches the inner `dyn MessageWriter`, so this calls the
        // boxed implementation rather than recursing into this one.
        (**self).send(payload).await
    }
}
244
245/// Drive a single connected session: handshake + message loop until
246/// `match_end`. Public-but-hidden so the conformance harness (and the
247/// external-API gateway leg) can reuse it against any transport.
248///
249/// Returns the `match_end` payload (the full envelope) on a clean
250/// finish, or `None` if the socket closed without a `match_end` (a drop
251/// the caller may reconnect through). Returns [`Error::BotDecision`] if
252/// `decide()` panics under `ctx.safe_mode = false`.
253pub async fn _run_session<R, W, B>(
254    reader: &mut R,
255    writer: &mut W,
256    bot: &mut B,
257    ctx: &SessionContext,
258) -> Result<Option<Value>, Error>
259where
260    R: MessageReader,
261    W: MessageWriter,
262    B: Bot,
263{
264    // --- Layer 1 handshake ----------------------------------------------------
265    let mut auth = json!({
266        "type": "authenticate",
267        "match_id": ctx.match_id,
268        "client_name": ctx.client_name,
269        "client_version": ctx.client_version,
270    });
271    if let Some(t) = ctx.token.as_deref().filter(|s| !s.is_empty()) {
272        auth["token"] = Value::String(t.to_string());
273    } else if let Some(t) = ctx.ticket.as_deref().filter(|s| !s.is_empty()) {
274        auth["ticket"] = Value::String(t.to_string());
275    } else {
276        auth["token"] = Value::String(String::new());
277    }
278    writer.send(auth.to_string()).await?;
279
280    let hello_raw = reader.next().await?.ok_or(Error::ConnectionClosed {
281        context: "server hello",
282    })?;
283    let hello: Value = serde_json::from_str(&hello_raw)?;
284    if hello.get("type").and_then(|v| v.as_str()) != Some("hello") {
285        return Err(Error::Protocol(format!(
286            "expected server hello, got {:?}",
287            hello.get("type")
288        )));
289    }
290
291    let client_hello = json!({
292        "type": "hello",
293        "match_id": ctx.match_id,
294        "supported_versions": SUPPORTED_PROTOCOL_VERSIONS,
295        "client_name": ctx.client_name,
296        "client_version": ctx.client_version,
297    });
298    writer.send(client_hello.to_string()).await?;
299
300    // --- Message loop ---------------------------------------------------------
301    let mut last_seq: i64 = 0;
302    while let Some(raw) = reader.next().await? {
303        let msg: Value = match serde_json::from_str(&raw) {
304            Ok(v) => v,
305            // Malformed envelope — log + continue. Real production
306            // deployments never emit invalid JSON; this is for
307            // adversarial-input robustness.
308            Err(_) => continue,
309        };
310
311        if let Some(seq) = msg.get("seq").and_then(Value::as_i64) {
312            if seq <= last_seq {
313                continue; // sequence regression / retransmit
314            }
315            last_seq = seq;
316        }
317
318        let mtype = msg.get("type").and_then(|v| v.as_str()).unwrap_or("");
319        match mtype {
320            "ping" => {
321                let pong = json!({ "type": "pong", "match_id": ctx.match_id });
322                writer.send(pong.to_string()).await?;
323            }
324            "match_start" => bot.on_match_start(&msg),
325            "round_start" => bot.on_round_start(&msg),
326            "phase_change" => bot.on_phase_change(&msg),
327            "turn_result" => bot.on_turn_result(&msg),
328            "round_result" => bot.on_round_result(&msg),
329            "turn_request" => {
330                let request_id = msg
331                    .get("request_id")
332                    .and_then(|v| v.as_str())
333                    .unwrap_or("")
334                    .to_string();
335                let state = parse_game_state(&msg);
336                let (action, latency_ms) = decide_timed(bot, &state, &msg, ctx.safe_mode)?;
337                send_turn_action(writer, &ctx.match_id, &request_id, action).await?;
338                bot.on_decision_latency(latency_ms);
339            }
340            "action_rejected" => {
341                let request_id = msg
342                    .get("request_id")
343                    .and_then(|v| v.as_str())
344                    .unwrap_or("")
345                    .to_string();
346                let valid_actions: Vec<String> = msg
347                    .get("valid_actions")
348                    .and_then(|v| v.as_array())
349                    .map(|arr| {
350                        arr.iter()
351                            .filter_map(|v| v.as_str().map(String::from))
352                            .collect()
353                    })
354                    .unwrap_or_else(|| vec!["fold".to_string()]);
355                let fallback = _safe_fallback_action(&valid_actions);
356                send_turn_action(writer, &ctx.match_id, &request_id, fallback).await?;
357            }
358            "reconnected" => {
359                // Mid-session after a reconnect: replay the pending request as
360                // if it were a fresh turn_request so the bot acts on it.
361                if let Some(pending) = msg.get("pending_request") {
362                    if pending.get("type").and_then(|v| v.as_str()) == Some("turn_request") {
363                        let request_id = pending
364                            .get("request_id")
365                            .and_then(|v| v.as_str())
366                            .unwrap_or("")
367                            .to_string();
368                        let state = parse_game_state(pending);
369                        let (action, latency_ms) =
370                            decide_timed(bot, &state, pending, ctx.safe_mode)?;
371                        send_turn_action(writer, &ctx.match_id, &request_id, action).await?;
372                        bot.on_decision_latency(latency_ms);
373                    }
374                }
375            }
376            "match_end" => {
377                let results = msg.get("results").cloned().unwrap_or_else(|| msg.clone());
378                bot.on_match_end(&results);
379                return Ok(Some(msg));
380            }
381            "error" => {
382                // Non-fatal — production deployments log; the SDK stays
383                // quiet so user code controls the logging surface.
384            }
385            _ => {
386                // Forward-compat: silently ignore unknown message types.
387            }
388        }
389    }
390
391    // Stream closed without match_end. Caller decides whether to retry.
392    Ok(None)
393}
394
395/// Send a `turn_action` envelope echoing `request_id`.
396async fn send_turn_action<W: MessageWriter>(
397    writer: &mut W,
398    match_id: &str,
399    request_id: &str,
400    action: Action,
401) -> Result<(), Error> {
402    let (action_str, params) = action.to_wire();
403    let payload = json!({
404        "type": "turn_action",
405        "match_id": match_id,
406        "request_id": request_id,
407        "action": action_str,
408        "params": params,
409    });
410    writer.send(payload.to_string()).await
411}
412
413/// Run `bot.decide(state)` with timing + safe_mode handling, returning
414/// `(action, latency_ms)`.
415///
416/// A panic in `decide()` is caught. Under `safe_mode` it is folded to a
417/// safe-fallback action (so a transient bug doesn't forfeit a competitive
418/// match); under `safe_mode = false` it surfaces as [`Error::BotDecision`]
419/// so the caller treats it as terminal. If `decide()` returns an action
420/// that isn't legal for the current `valid_actions`, the safe fallback is
421/// substituted regardless of `safe_mode` (mirrors the existing Rust + JS
422/// behavior).
423fn decide_timed<B: Bot>(
424    bot: &mut B,
425    state: &crate::models::GameState,
426    msg: &Value,
427    safe_mode: bool,
428) -> Result<(Action, f64), Error> {
429    let start = std::time::Instant::now();
430    let outcome = std::panic::catch_unwind(AssertUnwindSafe(|| bot.decide(state)));
431    let latency_ms = start.elapsed().as_secs_f64() * 1000.0;
432
433    let action = match outcome {
434        Ok(action) => action,
435        Err(payload) => {
436            let detail = panic_message(payload.as_ref());
437            if !safe_mode {
438                return Err(Error::BotDecision(detail));
439            }
440            // safe_mode: fold the panic into a safe action so the match
441            // continues. valid_actions drives check-vs-fold.
442            _safe_fallback_action(&state.valid_actions)
443        }
444    };
445
446    if action_is_legal(&action, &state.valid_actions) {
447        Ok((action, latency_ms))
448    } else {
449        let valid = msg
450            .get("valid_actions")
451            .and_then(|v| v.as_array())
452            .map(|arr| {
453                arr.iter()
454                    .filter_map(|v| v.as_str().map(String::from))
455                    .collect::<Vec<_>>()
456            })
457            .unwrap_or_else(|| state.valid_actions.clone());
458        Ok((_safe_fallback_action(&valid), latency_ms))
459    }
460}
461
/// Recover a human-readable message from a `catch_unwind` payload.
///
/// Handles the two payload types `panic!` commonly produces — `&str` and
/// `String` — and falls back to a generic note for anything else.
fn panic_message(payload: &(dyn std::any::Any + Send)) -> String {
    payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| String::from("decide() panicked"))
}
473
474fn action_is_legal(action: &Action, valid: &[String]) -> bool {
475    let needed = action.kind().as_str();
476    valid.iter().any(|v| v == needed)
477}
478
479/// Pick a safe action from the legal set: prefer `check`, fall back
480/// to `fold`. Public-but-hidden so the conformance harness can use
481/// the same logic when it doesn't have a real bot to drive.
482pub fn _safe_fallback_action(valid_actions: &[String]) -> Action {
483    if valid_actions.iter().any(|a| a == "check") {
484        Action::Check
485    } else {
486        Action::Fold
487    }
488}
489
/// Extract the match id from a Chipzen WebSocket URL of the shape
/// `.../ws/match/<match_id>/...`.
///
/// The id is whatever follows the first `/ws/match/` up to the next `/`,
/// `?` or `#` (or the end of the string). No format is enforced —
/// server-side ids may be UUIDs, shortened hashes, or namespaced strings
/// like `m_abc_123`. Returns an empty string when the URL does not
/// contain `/ws/match/`.
pub fn _extract_match_id(url: &str) -> String {
    match url.split_once("/ws/match/") {
        // `split` always yields at least one (possibly empty) piece, so the
        // default is never actually used.
        Some((_, tail)) => tail
            .split(['/', '?', '#'])
            .next()
            .unwrap_or_default()
            .to_owned(),
        None => String::new(),
    }
}
504
505// ---------------------------------------------------------------------------
506// Real WebSocket adapters — bridge tokio-tungstenite to the trait surface
507// ---------------------------------------------------------------------------
508
509struct WsReader<'a, S>
510where
511    S: StreamExt<Item = Result<Message, WsError>> + Unpin,
512{
513    inner: &'a mut S,
514}
515
516#[async_trait]
517impl<'a, S> MessageReader for WsReader<'a, S>
518where
519    S: StreamExt<Item = Result<Message, WsError>> + Unpin + Send,
520{
521    async fn next(&mut self) -> Result<Option<String>, Error> {
522        loop {
523            match self.inner.next().await {
524                Some(Ok(Message::Text(t))) => return Ok(Some(t.to_string())),
525                Some(Ok(Message::Ping(_))) => {
526                    // Tungstenite auto-replies to control pings, but
527                    // some peers send Ping as a Text-style heartbeat.
528                    // Surface it so the session loop's `ping` handler
529                    // can respond if needed.
530                    continue;
531                }
532                Some(Ok(Message::Close(_))) | None => return Ok(None),
533                Some(Ok(_)) => continue, // binary, pong, etc.
534                Some(Err(e)) => return Err(Error::from(e)),
535            }
536        }
537    }
538}
539
540struct WsWriter<'a, S>
541where
542    S: SinkExt<Message, Error = WsError> + Unpin,
543{
544    inner: &'a mut S,
545}
546
547#[async_trait]
548impl<'a, S> MessageWriter for WsWriter<'a, S>
549where
550    S: SinkExt<Message, Error = WsError> + Unpin + Send,
551{
552    async fn send(&mut self, payload: String) -> Result<(), Error> {
553        self.inner
554            .send(Message::Text(payload))
555            .await
556            .map_err(Error::from)
557    }
558}