rivetkit-core 2.3.0-rc.12

Core runtime primitives for RivetKit actor hosts
use std::net::IpAddr;

use anyhow::{Context, Result};
use reqwest::{Client, Url};
use serde::Deserialize;
use serde_json::{Map as JsonMap, json};

use super::ServeConfig;

/// Response body of the engine's `datacenters` endpoint, as decoded by `get_datacenters`.
///
/// Only the fields this module reads are modelled; serde ignores any other fields
/// in the payload by default.
#[derive(Debug, Deserialize)]
struct DatacentersResponse {
	/// Every datacenter the engine reports; each one receives a runner config entry.
	datacenters: Vec<Datacenter>,
}

/// A single datacenter entry from the engine's `datacenters` response.
#[derive(Debug, Deserialize)]
struct Datacenter {
	/// Datacenter name; used as the key in the runner config's `datacenters` map.
	name: String,
}

pub(super) async fn ensure_local_normal_runner_config(config: &ServeConfig) -> Result<()> {
	if !is_local_engine_endpoint(&config.endpoint) {
		return Ok(());
	}

	let client = Client::builder()
		.build()
		.context("build reqwest client for runner config")?;
	let datacenters = get_datacenters(&client, config).await?;
	let mut runner_datacenters = JsonMap::new();

	for datacenter in datacenters.datacenters {
		runner_datacenters.insert(
			datacenter.name,
			json!({
				"normal": {},
				"drain_on_version_upgrade": true,
			}),
		);
	}

	let url = engine_api_url(
		&config.endpoint,
		&["runner-configs", config.pool_name.as_str()],
		&config.namespace,
	)?;
	let body = json!({
		"datacenters": runner_datacenters,
	});

	let response = apply_auth(client.put(url), config)
		.json(&body)
		.send()
		.await
		.context("upsert local runner config")?;
	let status = response.status();
	if !status.is_success() {
		let response_body = response
			.text()
			.await
			.context("read failed runner config response body")?;
		anyhow::bail!(
			"failed to upsert local runner config `{}`: {} {}",
			config.pool_name,
			status,
			response_body
		);
	}

	tracing::debug!(
		namespace = %config.namespace,
		pool_name = %config.pool_name,
		"ensured local normal runner config"
	);

	Ok(())
}

/// Fetches the datacenters known to the engine at `config.endpoint`.
///
/// Sends a `GET` to `datacenters?namespace={namespace}`, authenticated with
/// `config.token` when one is set, and decodes the JSON body.
///
/// # Errors
///
/// Returns an error if the URL cannot be built, the request cannot be sent, the
/// engine responds with a non-success status (the error then carries the status and,
/// when it can be read, the response body), or the body is not a valid
/// `DatacentersResponse`.
async fn get_datacenters(client: &Client, config: &ServeConfig) -> Result<DatacentersResponse> {
	let url = engine_api_url(&config.endpoint, &["datacenters"], &config.namespace)?;
	let response = apply_auth(client.get(url), config)
		.send()
		.await
		.context("get local datacenters")?;
	let status = response.status();
	if !status.is_success() {
		// Reading the body is best-effort: failing to read it must not hide the
		// HTTP status, which is the actual error being reported.
		let response_body = response
			.text()
			.await
			.unwrap_or_else(|err| format!("<read failed datacenters response body: {err}>"));
		anyhow::bail!(
			"failed to get local datacenters for runner config: {} {}",
			status,
			response_body
		);
	}

	response
		.json::<DatacentersResponse>()
		.await
		.context("decode datacenters response")
}

/// Adds `config.token` to `request` as bearer authentication, if a token is configured.
///
/// When no token is set, the request builder is returned unchanged.
fn apply_auth(request: reqwest::RequestBuilder, config: &ServeConfig) -> reqwest::RequestBuilder {
	if let Some(bearer) = config.token.as_deref() {
		request.bearer_auth(bearer)
	} else {
		request
	}
}

/// Builds an engine API URL from `endpoint` by replacing its path and query string.
///
/// The existing path is discarded, and each element of `path` is appended as one
/// percent-encoded path segment. The existing query is discarded and replaced with
/// a single `namespace=<namespace>` pair.
///
/// # Errors
///
/// Returns an error if `endpoint` does not parse as a URL, or if it is a
/// cannot-be-a-base URL (one without a hierarchical path, e.g. `mailto:`).
fn engine_api_url(endpoint: &str, path: &[&str], namespace: &str) -> Result<Url> {
	let mut api_url =
		Url::parse(endpoint).with_context(|| format!("parse engine endpoint `{endpoint}`"))?;
	api_url.set_path("");

	// Scope the segment editor so its borrow of `api_url` ends before the query is rewritten.
	{
		let mut segments = api_url
			.path_segments_mut()
			.map_err(|()| anyhow::anyhow!("engine endpoint cannot be a base URL: {endpoint}"))?;
		for segment in path {
			segments.push(segment);
		}
	}

	api_url.query_pairs_mut().clear().append_pair("namespace", namespace);
	Ok(api_url)
}

/// Returns `true` when `endpoint` points at an engine running on this machine.
///
/// A host counts as local when it is one of the following:
/// - `localhost` or any `*.localhost` name;
/// - a loopback IP address (`127.0.0.0/8`, `::1`);
/// - an unspecified IP address (`0.0.0.0`, `::`).
///
/// IPv6 literals are accepted in their URL form, e.g. `http://[::1]:6420`. Endpoints
/// that fail to parse, or URLs without a host, are treated as non-local.
fn is_local_engine_endpoint(endpoint: &str) -> bool {
	let Ok(url) = Url::parse(endpoint) else {
		return false;
	};
	let Some(host) = url.host_str() else {
		return false;
	};

	if host == "localhost" || host.ends_with(".localhost") {
		return true;
	}

	// `host_str` serializes IPv6 literals with surrounding brackets (e.g. `[::1]`),
	// which `IpAddr::from_str` rejects, so strip them before parsing.
	let ip_literal = host
		.strip_prefix('[')
		.and_then(|inner| inner.strip_suffix(']'))
		.unwrap_or(host);

	ip_literal
		.parse::<IpAddr>()
		.map(|ip| ip.is_loopback() || ip.is_unspecified())
		.unwrap_or(false)
}