use crate::bot::Bot;
use crate::client::{
_run_session, default_user_agent, MessageReader, MessageWriter, SessionContext,
};
use crate::config::{load_chipzen_config, resolve_token, ChipzenConfig};
use crate::connect::{connect_to_chipzen, EnvName};
use crate::error::Error;
use crate::retry::RetryPolicy;
use async_trait::async_trait;
use serde_json::{json, Value};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::task::JoinSet;
/// WebSocket subprotocol name offered ahead of the bot token on gateway connections.
pub const BOT_TOKEN_SUBPROTOCOL: &str = "chipzen-bot-token";

/// Lobby receive timeout in milliseconds. When it elapses, the lobby loop re-checks its
/// stop flag. Held in an atomic so it can be overridden at runtime.
static LOBBY_RECV_TIMEOUT_MS: AtomicU64 = AtomicU64::new(2_000);

/// Current lobby receive timeout.
fn lobby_recv_timeout() -> Duration {
    let millis = LOBBY_RECV_TIMEOUT_MS.load(Ordering::Relaxed);
    Duration::from_millis(millis)
}

/// Override the lobby receive timeout, in milliseconds.
///
/// Hidden from docs and underscore-prefixed. Presumably a hook for tests that need a
/// fast-polling lobby loop.
#[doc(hidden)]
pub fn _set_lobby_recv_timeout_ms(ms: u64) {
    LOBBY_RECV_TIMEOUT_MS.store(ms, Ordering::Relaxed);
}

/// How long shutdown waits for in-flight match tasks before aborting them.
const MATCH_DRAIN_GRACE: Duration = Duration::from_secs(5);

/// Client name used when `RunExternalOptions::client_name` is not set.
const DEFAULT_CLIENT_NAME: &str = "chipzen-sdk-rust";

/// Build the subprotocol offer that carries the bot token: `[BOT_TOKEN_SUBPROTOCOL, token]`.
///
/// The built-in transport joins these with `", "` into the `Sec-WebSocket-Protocol` header.
pub fn bot_token_subprotocols(token: &str) -> Vec<String> {
    [BOT_TOKEN_SUBPROTOCOL, token]
        .iter()
        .map(|&proto| proto.to_owned())
        .collect()
}
/// Reduce `url` to its origin (`scheme://authority`), dropping any path, query or fragment.
///
/// Input without a `://` separator is treated as a bare authority and given the `wss`
/// scheme.
fn normalise_base(url: &str) -> String {
    let (scheme, remainder) = url.split_once("://").unwrap_or(("wss", url));
    // `split` always yields at least one (possibly empty) piece: everything before the
    // first path/query/fragment delimiter.
    let authority = remainder
        .split(['/', '?', '#'])
        .next()
        .unwrap_or_default();
    format!("{scheme}://{authority}")
}
/// Split `url` into `(scheme, authority)`, ignoring everything after the authority.
///
/// A missing `://` yields the default scheme `wss`. In that case the input up to the first
/// `/`, `?` or `#` is taken as the authority.
fn split_origin(url: &str) -> (&str, &str) {
    let (scheme, remainder) = url.split_once("://").unwrap_or(("wss", url));
    let authority = remainder
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .map_or(remainder, |cut| &remainder[..cut]);
    (scheme, authority)
}
/// Resolve the gateway URL from a lobby `matched` message into an absolute URL.
///
/// `gateway_ws_path` is either an absolute `ws://`/`wss://` URL or a value appended to the
/// lobby origin (normally a path such as `/ws/match/...`). The bot token is sent to the
/// resulting URL, so two rules apply to both forms:
/// - the URL must keep the lobby's exact authority (`userinfo@host:port`);
/// - the URL must not downgrade a `wss` lobby to `ws`.
///
/// # Errors
/// Returns [`Error::UntrustedGateway`] when either rule is broken. This includes relative
/// values such as `"@evil.example/ws"`, `"evil.example/ws"` or `":8443/ws"`. Naively
/// appended to the lobby base, those would change the host (or port) the token is sent to.
pub fn resolve_gateway_url(lobby_url: &str, gateway_ws_path: &str) -> Result<String, Error> {
    let lobby_base = normalise_base(lobby_url);
    let (lobby_scheme, lobby_authority) = split_origin(&lobby_base);

    let is_absolute =
        gateway_ws_path.starts_with("ws://") || gateway_ws_path.starts_with("wss://");
    let candidate = if is_absolute {
        gateway_ws_path.to_string()
    } else {
        format!("{lobby_base}{gateway_ws_path}")
    };

    // Re-parse the final URL instead of trusting the relative form. Anything that does not
    // start with '/', '?' or '#' would otherwise be spliced into the authority.
    let (gw_scheme, gw_authority) = split_origin(&candidate);
    let downgrade = lobby_scheme == "wss" && gw_scheme != "wss";
    if gw_authority != lobby_authority || downgrade {
        return Err(Error::UntrustedGateway(format!(
            "{gateway_ws_path:?}: cross-origin or insecure relative to lobby \
             {lobby_scheme}://{lobby_authority} (the bot token must not be sent to a \
             different host or in cleartext)"
        )));
    }
    Ok(candidate)
}
/// Parse a lobby frame as JSON.
///
/// Invalid JSON and any non-object value both become an empty object, so callers can
/// probe fields without handling parse errors.
fn loads(raw: &str) -> Value {
    serde_json::from_str::<Value>(raw)
        .ok()
        .filter(Value::is_object)
        .unwrap_or_else(|| json!({}))
}
/// Outcome of one match played by the external runner.
#[derive(Debug, Clone)]
pub struct MatchResult {
    /// Match id. Taken from the end payload's `match_id` when present; otherwise the id
    /// from the lobby's `matched` message, which may be an empty string.
    pub match_id: Option<String>,
    /// End-of-match payload, or `None` if the match finished without one (for example,
    /// gateway reconnect attempts were exhausted).
    pub end: Option<Value>,
}
/// Full set of options for `run_external_bot` / `run_external_with_transport`.
///
/// Every field is optional. Unset fields fall back to the loaded `chipzen.toml`
/// configuration or to built-in defaults.
#[derive(Debug, Clone, Default)]
pub struct RunExternalOptions {
    /// Bot id used to locate the lobby via `connect_to_chipzen` when `url` is unset.
    /// Falls back to the config's `bot_id`.
    pub bot_id: Option<String>,
    /// Environment passed to `connect_to_chipzen`. Not used when `url` is set.
    pub env: Option<EnvName>,
    /// Explicit lobby URL. When set, `bot_id`/`env` are not used to look up the lobby.
    pub url: Option<String>,
    /// External-API token (`cz_extbot_...`). When `None`, the token comes from
    /// configuration. A run that ends up with no non-empty token fails with
    /// `Error::Protocol`.
    pub token: Option<String>,
    /// Pre-loaded configuration. When `None`, `load_chipzen_config(None)` is used.
    pub config: Option<ChipzenConfig>,
    /// Reconnect/backoff policy. With `url` set, `None` means `RetryPolicy::default()`.
    /// Otherwise the value is handed to `connect_to_chipzen`.
    pub retry_policy: Option<RetryPolicy>,
    /// Passed to match sessions as `SessionContext::client_name`
    /// (default `"chipzen-sdk-rust"`).
    pub client_name: Option<String>,
    /// Passed to match sessions as `SessionContext::client_version`
    /// (default: this crate's `CARGO_PKG_VERSION`).
    pub client_version: Option<String>,
    /// Passed to match sessions as `SessionContext::safe_mode`. Defaults to `true`.
    pub safe_mode: Option<bool>,
    /// Stop once this many matches have finished. `None` means no limit.
    pub max_matches: Option<u64>,
    /// User agent handed to the transport. The built-in WebSocket transport sends it as
    /// the `User-Agent` header. Defaults to `default_user_agent()`.
    pub user_agent: Option<String>,
}
/// Arguments for `run_external_cli`: a trimmed-down subset of `RunExternalOptions`.
///
/// `RunExternalArgs::new()` and `Default::default()` produce the same value, with
/// `safe_mode` enabled. This matches the default used when `RunExternalOptions::safe_mode`
/// is `None`. Previously, a derived `Default` silently turned safe mode off for
/// `..Default::default()` struct-update construction.
#[derive(Debug, Clone)]
pub struct RunExternalArgs {
    /// Environment used to locate the lobby.
    pub env: Option<EnvName>,
    /// External-API token. When `None`, it is taken from configuration.
    pub token: Option<String>,
    /// Bot id used to locate the lobby. When `None`, the configured bot id is used.
    pub bot_id: Option<String>,
    /// Stop once this many matches have finished. `None` means no limit.
    pub max_matches: Option<u64>,
    /// Forwarded to match sessions as `SessionContext::safe_mode` (default `true`).
    pub safe_mode: bool,
}

impl Default for RunExternalArgs {
    fn default() -> Self {
        Self {
            env: None,
            token: None,
            bot_id: None,
            max_matches: None,
            safe_mode: true,
        }
    }
}

impl RunExternalArgs {
    /// Create arguments with every option unset and `safe_mode` enabled.
    pub fn new() -> Self {
        Self::default()
    }
}
pub async fn run_external_cli<B, F>(
factory: F,
args: RunExternalArgs,
) -> Result<Vec<MatchResult>, Error>
where
B: Bot,
F: Fn() -> B + Send + Sync + 'static,
{
let options = RunExternalOptions {
bot_id: args.bot_id,
env: args.env,
token: args.token,
max_matches: args.max_matches,
safe_mode: Some(args.safe_mode),
..Default::default()
};
run_external_bot(factory, options).await
}
/// How the runner opens lobby and match-gateway connections.
///
/// `run_external_with_transport` accepts any implementation. `run_external_bot` uses the
/// built-in tokio-tungstenite transport.
#[async_trait]
pub trait LobbyTransport: Send + Sync + 'static {
    /// Open a lobby connection to `url`, identifying with `user_agent`.
    ///
    /// The runner authenticates afterwards by sending an `authenticate` message over the
    /// returned writer, so no credentials are passed here.
    async fn connect_lobby(
        &self,
        url: &str,
        user_agent: &str,
    ) -> Result<(Box<dyn MessageReader>, Box<dyn MessageWriter>), Error>;
    /// Open a match-gateway connection to `url`, authenticated with the bot `token`.
    ///
    /// The built-in transport offers the token as a WebSocket subprotocol. The runner
    /// treats connection errors as retryable and reconnects with the session's retry policy.
    async fn connect_gateway(
        &self,
        url: &str,
        token: &str,
        user_agent: &str,
    ) -> Result<(Box<dyn MessageReader>, Box<dyn MessageWriter>), Error>;
}
/// Default `LobbyTransport`, backed by the tokio-tungstenite helpers in `ws_transport`.
struct WsTransport;

#[async_trait]
impl LobbyTransport for WsTransport {
    // Lobby sockets carry no token subprotocol; authentication happens in-band.
    async fn connect_lobby(
        &self,
        url: &str,
        user_agent: &str,
    ) -> Result<(Box<dyn MessageReader>, Box<dyn MessageWriter>), Error> {
        ws_transport::connect(url, user_agent, None).await
    }
    // Gateway sockets offer the bot token via `Sec-WebSocket-Protocol`.
    async fn connect_gateway(
        &self,
        url: &str,
        token: &str,
        user_agent: &str,
    ) -> Result<(Box<dyn MessageReader>, Box<dyn MessageWriter>), Error> {
        ws_transport::connect(url, user_agent, Some(token)).await
    }
}
pub async fn run_external_bot<B, F>(
factory: F,
options: RunExternalOptions,
) -> Result<Vec<MatchResult>, Error>
where
B: Bot,
F: Fn() -> B + Send + Sync + 'static,
{
run_external_with_transport(factory, options, Arc::new(WsTransport)).await
}
/// Run an external bot against the Chipzen lobby using a caller-supplied transport.
///
/// Setup: resolves the lobby URL, retry policy and token from `options`. Missing pieces
/// come from `chipzen.toml` via `load_chipzen_config`, or from `connect_to_chipzen` when
/// only a bot id is known.
///
/// Lobby loop: connects to the lobby repeatedly. Every `matched` message spawns a match
/// task using a fresh bot from `factory()`. The loop ends when any of these happens:
/// - the lobby evicts the bot;
/// - the stop flag is raised (`max_matches` reached, or a `BotDecision` error);
/// - reconnect attempts are exhausted.
///
/// Shutdown: in-flight matches get `MATCH_DRAIN_GRACE` to finish before they are aborted.
///
/// Returns the results recorded for every match that finished.
///
/// # Errors
/// - Errors from config loading or `connect_to_chipzen` are propagated.
/// - `Error::Protocol` when neither a URL nor a bot id is available, or no non-empty token
///   can be resolved.
/// - The first `Error::BotDecision` raised by any match.
/// - The last lobby error, if no lobby session ever completed before retries ran out.
pub async fn run_external_with_transport<B, F, T>(
    factory: F,
    options: RunExternalOptions,
    transport: Arc<T>,
) -> Result<Vec<MatchResult>, Error>
where
    B: Bot,
    F: Fn() -> B + Send + Sync + 'static,
    T: LobbyTransport,
{
    let config = match options.config {
        Some(c) => Some(c),
        None => load_chipzen_config(None)?,
    };
    // An explicit URL bypasses lobby discovery; otherwise derive URL, policy and token
    // from the bot id.
    let (lobby_url, policy, resolved_token) = if let Some(url) = options.url {
        let policy = options.retry_policy.unwrap_or_default();
        let token = resolve_token(options.token.as_deref(), config.as_ref());
        (url, policy, token)
    } else {
        let bot_id = options
            .bot_id
            .clone()
            .or_else(|| config.as_ref().and_then(|c| c.bot_id.clone()));
        let Some(bot_id) = bot_id.filter(|s| !s.is_empty()) else {
            return Err(Error::Protocol(
                "run_external_bot() needs a lobby URL. Set url, or bot_id (or \
                [external_api].bot_id / url in chipzen.toml)."
                    .to_string(),
            ));
        };
        let conn = connect_to_chipzen(&bot_id, options.env, options.retry_policy, config.clone())?;
        // An explicitly passed token wins over the one resolved by `connect_to_chipzen`.
        let token = match options.token.as_deref() {
            Some(t) => Some(t.to_string()),
            None => conn.token.clone(),
        };
        (conn.url, conn.retry_policy, token)
    };
    let Some(resolved_token) = resolved_token.filter(|t| !t.is_empty()) else {
        return Err(Error::Protocol(
            "run_external_bot() requires an external-API token (cz_extbot_...). Pass \
            token, or set [external_api].token in chipzen.toml."
                .to_string(),
        ));
    };
    let client_version = options
        .client_version
        .clone()
        .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
    let user_agent = options
        .user_agent
        .clone()
        .unwrap_or_else(default_user_agent);
    let client_name = options
        .client_name
        .clone()
        .unwrap_or_else(|| DEFAULT_CLIENT_NAME.to_string());
    let safe_mode = options.safe_mode.unwrap_or(true);
    // Immutable settings shared by the lobby loop and every match task.
    let session = Arc::new(SessionParams {
        token: resolved_token,
        client_name,
        client_version,
        safe_mode,
        user_agent,
        policy,
        max_matches: options.max_matches,
    });
    // Shared state written by match tasks (see `record_match_outcome`).
    let results: Arc<Mutex<Vec<MatchResult>>> = Arc::new(Mutex::new(Vec::new()));
    let completed = Arc::new(AtomicU64::new(0));
    let stop = Arc::new(AtomicBool::new(false));
    let fatal: Arc<Mutex<Option<Error>>> = Arc::new(Mutex::new(None));
    let mut match_tasks: JoinSet<()> = JoinSet::new();
    let mut consecutive_failures: u32 = 0;
    let mut ever_connected = false;
    let mut giveup: Option<Error> = None;
    while !stop.load(Ordering::SeqCst) {
        let run = run_lobby_once(
            &transport,
            &lobby_url,
            &factory,
            &session,
            &results,
            &completed,
            &stop,
            &fatal,
            &mut match_tasks,
        )
        .await;
        match run {
            Ok(status) => {
                // A lobby session ran (connected and sent `authenticate`): reset the streak.
                ever_connected = true;
                consecutive_failures = 0;
                if matches!(status, LobbyStatus::Stopped | LobbyStatus::Evicted)
                    || fatal.lock().unwrap().is_some()
                {
                    break;
                }
                // Any other close counts as one failure so the reconnect is backed off.
                // NOTE(review): the reset above means a lobby that accepts and then
                // immediately closes (e.g. rejects the token) is retried indefinitely
                // unless `max_reconnect_attempts` is 0. Confirm this is intended.
                consecutive_failures += 1;
                if consecutive_failures > session.policy.max_reconnect_attempts {
                    break;
                }
                let delay = session.policy.backoff_ms(consecutive_failures);
                tokio::time::sleep(Duration::from_millis(delay)).await;
            }
            Err(exc) => {
                // Connecting failed, or a send on the lobby socket failed.
                consecutive_failures += 1;
                if consecutive_failures > session.policy.max_reconnect_attempts {
                    // Only surface the error if the lobby was never reached; otherwise
                    // return whatever results were collected.
                    if !ever_connected {
                        giveup = Some(exc);
                    }
                    break;
                }
                let delay = session.policy.backoff_ms(consecutive_failures);
                tokio::time::sleep(Duration::from_millis(delay)).await;
            }
        }
    }
    drain_and_cancel(&mut match_tasks).await;
    // A fatal match error takes precedence over a lobby give-up.
    if let Some(err) = fatal.lock().unwrap().take() {
        return Err(err);
    }
    if let Some(err) = giveup {
        return Err(err);
    }
    let out = std::mem::take(&mut *results.lock().unwrap());
    Ok(out)
}
/// Per-run settings shared (via `Arc`) between the lobby loop and every match task.
struct SessionParams {
    /// Resolved external-API token. Sent in the lobby `authenticate` message and on every
    /// gateway connection.
    token: String,
    /// Copied into each match's `SessionContext`.
    client_name: String,
    /// Copied into each match's `SessionContext`.
    client_version: String,
    /// Copied into each match's `SessionContext`.
    safe_mode: bool,
    /// Passed to the transport for lobby and gateway connections.
    user_agent: String,
    /// Backoff used both for lobby reconnects and for per-match gateway reconnects.
    policy: RetryPolicy,
    /// Raise the stop flag once this many matches have completed.
    max_matches: Option<u64>,
}
/// Why a single lobby connection ended without an error.
enum LobbyStatus {
    /// The shared stop flag was raised (`max_matches` reached or a fatal match error).
    /// The runner does not reconnect.
    Stopped,
    /// The lobby sent an `evict` message. The runner does not reconnect.
    Evicted,
    /// The lobby socket closed or failed to read. The runner reconnects with backoff.
    Closed,
}
/// Run one lobby connection until it closes, evicts the bot, or a stop is requested.
///
/// Connects via `transport` and sends an `authenticate` message carrying the bot token.
/// It then handles lobby messages:
/// - `ping`: replies with `pong`.
/// - `matched`: resolves the gateway URL against the lobby origin and spawns a match task
///   into `match_tasks`. Matches whose gateway URL is rejected are skipped.
/// - `evict`: returns `LobbyStatus::Evicted`.
/// - Anything else, including `hello` and frames that are not JSON objects, is ignored.
///
/// Each receive is bounded by `lobby_recv_timeout()` so `stop` is re-checked even while
/// the lobby is silent.
///
/// # Errors
/// Returns the transport error if connecting or sending fails. A read error or end of
/// stream is reported as `Ok(LobbyStatus::Closed)` instead.
// All shared runner state is threaded through individually instead of being bundled.
#[allow(clippy::too_many_arguments)]
async fn run_lobby_once<B, F, T>(
    transport: &Arc<T>,
    lobby_url: &str,
    factory: &F,
    session: &Arc<SessionParams>,
    results: &Arc<Mutex<Vec<MatchResult>>>,
    completed: &Arc<AtomicU64>,
    stop: &Arc<AtomicBool>,
    fatal: &Arc<Mutex<Option<Error>>>,
    match_tasks: &mut JoinSet<()>,
) -> Result<LobbyStatus, Error>
where
    B: Bot,
    F: Fn() -> B + Send + Sync + 'static,
    T: LobbyTransport,
{
    let (mut reader, mut writer) = transport
        .connect_lobby(lobby_url, &session.user_agent)
        .await?;
    // The lobby connection carries no credentials; authenticate in-band.
    writer
        .send(json!({ "type": "authenticate", "token": session.token }).to_string())
        .await?;
    while !stop.load(Ordering::SeqCst) {
        // The timeout only exists so `stop` is re-checked while the lobby is quiet.
        // NOTE(review): when it elapses, the pending `reader.next()` future is dropped, so
        // readers must be cancel-safe. Confirm this for custom `MessageReader` impls.
        let recv = tokio::time::timeout(lobby_recv_timeout(), reader.next()).await;
        let raw = match recv {
            Err(_elapsed) => continue,
            Ok(Ok(Some(raw))) => raw,
            Ok(Ok(None)) => return Ok(LobbyStatus::Closed),
            Ok(Err(_)) => return Ok(LobbyStatus::Closed),
        };
        let msg = loads(&raw);
        match msg.get("type").and_then(|v| v.as_str()) {
            Some("ping") => {
                writer.send(json!({ "type": "pong" }).to_string()).await?;
            }
            Some("hello") => {
                // Acknowledged but otherwise ignored.
            }
            Some("matched") => {
                let gateway_path = msg
                    .get("gateway_ws_url")
                    .and_then(|v| v.as_str())
                    .unwrap_or("");
                let match_id = msg
                    .get("match_id")
                    .and_then(|v| v.as_str())
                    .unwrap_or("")
                    .to_string();
                // Never send the token to an untrusted gateway; the match is skipped
                // silently.
                let gateway_url = match resolve_gateway_url(lobby_url, gateway_path) {
                    Ok(url) => url,
                    Err(_) => continue,
                };
                // Each match gets its own bot instance and task. Outcomes flow back
                // through the shared `results`/`completed`/`stop`/`fatal` state.
                // NOTE(review): finished tasks stay in `match_tasks` until
                // `drain_and_cancel` joins them; nothing reaps them while this loop runs.
                let bot = factory();
                let transport = Arc::clone(transport);
                let session = Arc::clone(session);
                let results = Arc::clone(results);
                let completed = Arc::clone(completed);
                let stop = Arc::clone(stop);
                let fatal = Arc::clone(fatal);
                match_tasks.spawn(async move {
                    let outcome =
                        play_one_match(&*transport, &gateway_url, &match_id, bot, &session).await;
                    record_match_outcome(
                        outcome, &match_id, &session, &results, &completed, &stop, &fatal,
                    );
                });
            }
            Some("evict") => return Ok(LobbyStatus::Evicted),
            _ => {
                // Unknown or malformed message types are ignored.
            }
        }
    }
    // The loop only exits once `stop` is set, so this normally reports `Stopped`.
    if stop.load(Ordering::SeqCst) {
        Ok(LobbyStatus::Stopped)
    } else {
        Ok(LobbyStatus::Closed)
    }
}
fn record_match_outcome(
outcome: Result<Option<Value>, Error>,
match_id: &str,
session: &SessionParams,
results: &Mutex<Vec<MatchResult>>,
completed: &AtomicU64,
stop: &AtomicBool,
fatal: &Mutex<Option<Error>>,
) {
match outcome {
Err(err @ Error::BotDecision(_)) => {
let mut slot = fatal.lock().unwrap();
if slot.is_none() {
*slot = Some(err);
}
stop.store(true, Ordering::SeqCst);
return;
}
Err(_other) => {
results.lock().unwrap().push(MatchResult {
match_id: Some(match_id.to_string()),
end: None,
});
}
Ok(end) => {
let match_id = end
.as_ref()
.and_then(|e| e.get("match_id"))
.and_then(|v| v.as_str())
.map(String::from)
.or_else(|| Some(match_id.to_string()));
completed.fetch_add(1, Ordering::SeqCst);
results.lock().unwrap().push(MatchResult { match_id, end });
}
}
if let Some(max) = session.max_matches {
if completed.load(Ordering::SeqCst) >= max {
stop.store(true, Ordering::SeqCst);
}
}
}
/// Play one match on the gateway, reconnecting according to the session's retry policy.
///
/// Each attempt opens a gateway connection authenticated with the bot token and drives it
/// with `_run_session`.
///
/// Returns:
/// - `Ok(Some(end))` as soon as a session yields an end payload;
/// - `Ok(None)` once `max_reconnect_attempts` retries are used up without one.
///
/// # Errors
/// `Error::BotDecision` is returned immediately, whether it comes from the session or from
/// the transport. Any other error counts as a dropped connection and is retried.
/// `LobbyTransport` is a public trait, so a custom transport could return `BotDecision`
/// from `connect_gateway`. That case used to hit `unreachable!`, and the panic silently
/// discarded the match inside its task.
async fn play_one_match<B, T>(
    transport: &T,
    gateway_url: &str,
    match_id: &str,
    mut bot: B,
    session: &SessionParams,
) -> Result<Option<Value>, Error>
where
    B: Bot,
    T: LobbyTransport,
{
    // The token travels with the gateway connection, not inside the session context.
    let ctx = SessionContext {
        match_id: match_id.to_string(),
        token: None,
        ticket: None,
        client_name: session.client_name.clone(),
        client_version: session.client_version.clone(),
        safe_mode: session.safe_mode,
    };
    let mut attempt: u32 = 0;
    loop {
        let connect = transport
            .connect_gateway(gateway_url, &session.token, &session.user_agent)
            .await;
        match connect {
            Ok((mut reader, mut writer)) => {
                match _run_session(&mut reader, &mut writer, &mut bot, &ctx).await {
                    Ok(Some(end)) => return Ok(Some(end)),
                    Err(e @ Error::BotDecision(_)) => return Err(e),
                    // Session ended without a result, or failed: reconnect below.
                    Ok(None) | Err(_) => {}
                }
            }
            // Treat a bot-decision failure as fatal wherever it comes from.
            Err(e @ Error::BotDecision(_)) => return Err(e),
            Err(_) => {}
        }
        attempt += 1;
        if attempt > session.policy.max_reconnect_attempts {
            return Ok(None);
        }
        let delay = session.policy.backoff_ms(attempt);
        tokio::time::sleep(Duration::from_millis(delay)).await;
    }
}
/// Wait for in-flight match tasks, aborting any still running after `MATCH_DRAIN_GRACE`.
///
/// Join results are discarded; each match task records its own outcome before finishing.
async fn drain_and_cancel(match_tasks: &mut JoinSet<()>) {
    if !match_tasks.is_empty() {
        let all_joined = tokio::time::timeout(MATCH_DRAIN_GRACE, async {
            while match_tasks.join_next().await.is_some() {}
        })
        .await;
        if all_joined.is_err() {
            // Grace period expired: cancel the stragglers, then reap them so none outlive
            // this call.
            match_tasks.abort_all();
            while match_tasks.join_next().await.is_some() {}
        }
    }
}
mod ws_transport {
    //! Built-in tokio-tungstenite WebSocket transport used by `WsTransport`.

    use super::{bot_token_subprotocols, MessageReader, MessageWriter};
    use crate::error::Error;
    use async_trait::async_trait;
    use futures_util::stream::{SplitSink, SplitStream};
    use futures_util::{SinkExt, StreamExt};
    use tokio::net::TcpStream;
    use tokio_tungstenite::{
        connect_async,
        tungstenite::client::IntoClientRequest,
        tungstenite::handshake::client::generate_key,
        tungstenite::http::header::{
            CONNECTION, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE,
            USER_AGENT,
        },
        tungstenite::Message,
        MaybeTlsStream, WebSocketStream,
    };

    /// Socket type returned by `connect_async` (plain TCP or TLS).
    type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;

    /// Open a WebSocket to `url` and split it into a boxed reader/writer pair.
    ///
    /// `user_agent` is sent as the `User-Agent` header, and is skipped if it is not a
    /// valid header value. When `token` is given:
    /// - it is offered through `Sec-WebSocket-Protocol` as the joined
    ///   `bot_token_subprotocols` list;
    /// - the standard upgrade headers are filled in if the request lacks them.
    ///
    /// # Errors
    /// Invalid URLs and handshake or transport failures are converted into [`Error`].
    pub async fn connect(
        url: &str,
        user_agent: &str,
        token: Option<&str>,
    ) -> Result<(Box<dyn MessageReader>, Box<dyn MessageWriter>), Error> {
        let mut request = url.into_client_request().map_err(Error::from)?;
        let headers = request.headers_mut();
        if let Ok(agent) = user_agent.parse() {
            headers.insert(USER_AGENT, agent);
        }
        if let Some(secret) = token {
            // NOTE(review): an offer that is not a valid header value is dropped silently,
            // so the handshake then goes out without credentials.
            if let Ok(offer) = bot_token_subprotocols(secret).join(", ").parse() {
                headers.insert(SEC_WEBSOCKET_PROTOCOL, offer);
            }
            // `entry().or_insert_with` leaves any header the request already carries
            // untouched.
            headers
                .entry(SEC_WEBSOCKET_VERSION)
                .or_insert_with(|| "13".parse().expect("static header"));
            headers
                .entry(SEC_WEBSOCKET_KEY)
                .or_insert_with(|| generate_key().parse().expect("generated key"));
            headers
                .entry(CONNECTION)
                .or_insert_with(|| "Upgrade".parse().expect("static header"));
            headers
                .entry(UPGRADE)
                .or_insert_with(|| "websocket".parse().expect("static header"));
        }
        let (socket, _response) = connect_async(request).await?;
        let (sink, stream) = socket.split();
        let reader: Box<dyn MessageReader> = Box::new(OwnedWsReader { inner: stream });
        let writer: Box<dyn MessageWriter> = Box::new(OwnedWsWriter { inner: sink });
        Ok((reader, writer))
    }

    /// Receiving half of a split WebSocket; yields text frames only.
    struct OwnedWsReader {
        inner: SplitStream<WsStream>,
    }

    #[async_trait]
    impl MessageReader for OwnedWsReader {
        /// Next text frame. `Ok(None)` on a close frame or end of stream. Other frame kinds
        /// (ping, pong, binary, raw) are skipped.
        async fn next(&mut self) -> Result<Option<String>, Error> {
            while let Some(frame) = self.inner.next().await {
                match frame.map_err(Error::from)? {
                    Message::Text(text) => return Ok(Some(text.to_string())),
                    Message::Close(_) => return Ok(None),
                    _ => {}
                }
            }
            Ok(None)
        }
    }

    /// Sending half of a split WebSocket; every payload goes out as a text frame.
    struct OwnedWsWriter {
        inner: SplitSink<WsStream, Message>,
    }

    #[async_trait]
    impl MessageWriter for OwnedWsWriter {
        async fn send(&mut self, payload: String) -> Result<(), Error> {
            let frame = Message::Text(payload);
            self.inner.send(frame).await.map_err(Error::from)
        }
    }
}