rivetkit-core 2.3.0-rc.12

Core runtime primitives for RivetKit actor hosts
use std::collections::HashMap;
use std::time::Duration;

use super::{
	FrameworkHttpRoute, HttpResponseEncoding, RegistryHttpRoute,
	authorization_bearer_token, authorization_bearer_token_map, framework_action_error_response,
	framework_anyhow_error_response_with_actor, inspector_error_status,
	message_boundary_error_response, message_boundary_error_response_with_actor,
	normalize_actor_request_path, request_encoding, with_action_dispatch_timeout,
	with_framework_action_timeout, workflow_dispatch_result,
};
use crate::actor::action::ActionDispatchError;
use crate::error::ActorLifecycle as ActorLifecycleError;
use http::StatusCode;
use rivet_error::{ActorSpecifier, MacroMarker, RivetError, RivetErrorSchema};
use serde_json::json;
use vbare::OwnedVersionedData;

// Test-local error with group `message` and code `incoming_too_long`.
// Its `.build()` output is fed to the message-boundary response helpers in the
// tests below to check how the error is rendered (status, headers, body).
#[derive(RivetError)]
#[error("message", "incoming_too_long", "Incoming message too long")]
struct IncomingMessageTooLong;

// Test-local error with group `message` and code `outgoing_too_long`.
// Used by the BARE-encoding test below to check that group/code/message
// round-trip through the versioned `HttpResponseError` encoding.
#[derive(RivetError)]
#[error("message", "outgoing_too_long", "Outgoing message too long")]
struct OutgoingMessageTooLong;

#[test]
fn request_prefix_detection_matches_normalization() {
	// Each case is (raw request path, expected normalized path). These cases pin
	// that `/request` is stripped before `/`, `?`, or end of string, while the
	// look-alike `/requestfoo` is returned unchanged.
	let normalization_cases = [
		("/request", "/"),
		("/request/", "/"),
		("/request/users", "/users"),
		("/request?foo=bar", "?foo=bar"),
		("/requestfoo", "/requestfoo"),
	];
	for &(raw, expected) in normalization_cases.iter() {
		assert_eq!(normalize_actor_request_path(raw), expected);
	}

	// Each of these (raw, normalized) pairs, including the `/requestfoo`
	// look-alike, must classify as a user raw request.
	let user_route_cases = [
		("/request", "/"),
		("/request/users", "/users"),
		("/request?foo=bar", "?foo=bar"),
		("/requestfoo", "/requestfoo"),
	];
	for &(raw, normalized) in user_route_cases.iter() {
		let route = RegistryHttpRoute::from_paths(raw, normalized).expect("route should decode");
		assert!(matches!(route, RegistryHttpRoute::UserRawRequest));
	}
}

#[test]
fn classifier_keeps_framework_and_user_routing_separate() {
	// An unprefixed `/action/<name>` path is claimed by the framework.
	let framework = RegistryHttpRoute::from_paths("/action/increment", "/action/increment")
		.expect("route should decode");
	assert!(matches!(
		framework,
		RegistryHttpRoute::Framework(FrameworkHttpRoute::Action(name)) if name == "increment"
	));

	// Both of these stay user-owned: the `/request`-prefixed path does so even
	// though it normalizes to a framework-looking `/action/increment`, and
	// `/custom` matches no framework route at all.
	let user_cases = [
		("/request/action/increment", "/action/increment"),
		("/custom", "/custom"),
	];
	for &(raw, normalized) in user_cases.iter() {
		let user = RegistryHttpRoute::from_paths(raw, normalized).expect("route should decode");
		assert!(matches!(user, RegistryHttpRoute::UserRawRequest));
	}
}

#[test]
fn workflow_dispatch_result_marks_handled_workflow_as_enabled() {
	// A successful dispatch reports the workflow as enabled (`true`) and passes
	// through whatever payload (or lack of one) it produced.
	let enabled_with_output = workflow_dispatch_result(Ok(Some(vec![1, 2, 3])))
		.expect("workflow dispatch should succeed");
	assert_eq!(enabled_with_output, (true, Some(vec![1, 2, 3])));

	let enabled_without_output =
		workflow_dispatch_result(Ok(None)).expect("workflow dispatch should succeed");
	assert_eq!(enabled_without_output, (true, None));
}

#[test]
fn workflow_dispatch_result_treats_dropped_reply_as_disabled() {
	// A `DroppedReply` lifecycle error is expected to map to `(false, None)`
	// (workflow disabled) instead of being returned as an error.
	let dropped = ActorLifecycleError::DroppedReply.build();
	let outcome = workflow_dispatch_result(Err(dropped))
		.expect("dropped reply should map to workflow disabled");
	assert_eq!(outcome, (false, None));
}

#[test]
fn workflow_dispatch_result_preserves_non_dropped_reply_errors() {
	// Any lifecycle error other than `DroppedReply` must come back unchanged,
	// still identifiable as `actor.destroying`.
	let destroying = ActorLifecycleError::Destroying.build();
	let preserved = workflow_dispatch_result(Err(destroying))
		.expect_err("non-dropped reply errors should be preserved");

	let structured = rivet_error::RivetError::extract(&preserved);
	assert_eq!(structured.group(), "actor");
	assert_eq!(structured.code(), "destroying");
}

#[test]
fn inspector_error_status_maps_action_timeout_to_408() {
	// `actor.action_timed_out` surfaces to inspector clients as HTTP 408.
	let status = inspector_error_status("actor", "action_timed_out");
	assert_eq!(status, StatusCode::REQUEST_TIMEOUT);
}

#[test]
fn authorization_bearer_token_accepts_case_insensitive_scheme_and_whitespace() {
	// Typed header map: lower-case scheme followed by several spaces.
	let mut header_map = http::HeaderMap::new();
	header_map.insert(
		http::header::AUTHORIZATION,
		"bearer   test-token".parse().unwrap(),
	);
	assert_eq!(authorization_bearer_token(&header_map), Some("test-token"));

	// Plain string map: upper-case scheme separated from the token by a tab.
	let mut string_map = HashMap::new();
	string_map.insert(
		http::header::AUTHORIZATION.as_str().to_owned(),
		"BEARER\ttest-token".to_owned(),
	);
	assert_eq!(authorization_bearer_token_map(&string_map), Some("test-token"));
}

#[tokio::test]
async fn action_dispatch_timeout_returns_structured_error() {
	// The wrapped future sleeps for a minute, far beyond the 1ms dispatch budget,
	// so the timeout path is the only way this can complete quickly.
	let slow_action = async {
		tokio::time::sleep(Duration::from_secs(60)).await;
		Ok::<Vec<u8>, ActionDispatchError>(Vec::new())
	};

	let timed_out = with_action_dispatch_timeout(Duration::from_millis(1), slow_action)
		.await
		.expect_err("timeout should return an action dispatch error");

	// The timeout is reported as a structured `actor.action_timed_out` error.
	assert_eq!(timed_out.group, "actor");
	assert_eq!(timed_out.code, "action_timed_out");
	assert_eq!(timed_out.message, "Action timed out");
}

#[tokio::test]
async fn framework_action_timeout_returns_structured_error() {
	// Work that outlives the 1ms budget by a wide margin.
	let slow_action = async {
		tokio::time::sleep(Duration::from_secs(60)).await;
		Ok::<(), anyhow::Error>(())
	};

	let failure = with_framework_action_timeout(Duration::from_millis(1), slow_action)
		.await
		.expect_err("timeout should return a framework error");

	// The anyhow error must carry the structured `actor.action_timed_out` details.
	let structured = rivet_error::RivetError::extract(&failure);
	assert_eq!(structured.group(), "actor");
	assert_eq!(structured.code(), "action_timed_out");
	assert_eq!(structured.message(), "Action timed out");
}

#[test]
fn framework_action_error_response_maps_timeout_to_408() {
	// A timed-out action with no metadata and no actor attached.
	let timeout_error = ActionDispatchError {
		group: "actor".to_owned(),
		code: "action_timed_out".to_owned(),
		message: "Action timed out".to_owned(),
		metadata: None,
		actor: None,
	};
	// The JSON body echoes group/code/message verbatim.
	let expected_body = serde_json::to_vec(&json!({
		"group": "actor",
		"code": "action_timed_out",
		"message": "Action timed out",
	}))
	.expect("json body should encode");

	let response = framework_action_error_response(HttpResponseEncoding::Json, timeout_error, None)
		.expect("timeout error response should serialize");

	assert_eq!(response.status, StatusCode::REQUEST_TIMEOUT.as_u16());
	assert_eq!(response.body, Some(expected_body));
}

#[test]
fn framework_action_error_response_sanitizes_internal_message() {
	// An internal error whose message and metadata both carry implementation detail.
	let leaky_error = ActionDispatchError {
		group: "rivetkit".to_owned(),
		code: "internal_error".to_owned(),
		message: "plain failure".to_owned(),
		metadata: Some(json!({ "error": "plain failure" })),
		actor: None,
	};
	// Only a generic message survives, and no `metadata` key appears in the body.
	let sanitized_body = serde_json::to_vec(&json!({
		"group": "rivetkit",
		"code": "internal_error",
		"message": "An internal error occurred",
	}))
	.expect("json body should encode");

	let response = framework_action_error_response(HttpResponseEncoding::Json, leaky_error, None)
		.expect("internal error response should serialize");

	assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR.as_u16());
	assert_eq!(response.body, Some(sanitized_body));
}

#[test]
fn message_boundary_error_response_defaults_to_json() {
	// NOTE(review): the JSON encoding is passed explicitly below, so this test
	// verifies the JSON rendering (status, content type, body), not a fallback or
	// default encoding choice.
	let expected_body = serde_json::to_vec(&json!({
		"group": "message",
		"code": "incoming_too_long",
		"message": "Incoming message too long",
	}))
	.expect("json body should encode");

	let response = message_boundary_error_response(
		HttpResponseEncoding::Json,
		StatusCode::BAD_REQUEST,
		IncomingMessageTooLong.build(),
	)
	.expect("json response should serialize");
	let content_type = response.headers.get(http::header::CONTENT_TYPE.as_str());

	assert_eq!(response.status, StatusCode::BAD_REQUEST.as_u16());
	assert_eq!(content_type, Some(&"application/json".to_owned()));
	assert_eq!(response.body, Some(expected_body));
}

#[test]
fn framework_anyhow_error_response_sanitizes_internal_message() {
	// Hand-built `rivetkit.internal_error` schema whose default message
	// ("plain failure") must not be exposed in the response body.
	static INTERNAL_ERROR: RivetErrorSchema = RivetErrorSchema {
		group: "rivetkit",
		code: "internal_error",
		default_message: "plain failure",
		meta_type: None,
		_macro_marker: MacroMarker { _private: () },
	};

	let expected_body = serde_json::to_vec(&json!({
		"group": "rivetkit",
		"code": "internal_error",
		"message": "An internal error occurred",
	}))
	.expect("json body should encode");

	let response = framework_anyhow_error_response_with_actor(
		HttpResponseEncoding::Json,
		INTERNAL_ERROR.build(),
		None,
	)
	.expect("internal error response should serialize");

	assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR.as_u16());
	assert_eq!(response.body, Some(expected_body));
}

#[test]
fn request_encoding_reads_cbor_header() {
	// A request advertising `x-rivet-encoding: cbor` selects the CBOR encoding.
	let headers = {
		let mut map = http::HeaderMap::new();
		map.insert("x-rivet-encoding", "cbor".parse().unwrap());
		map
	};

	let encoding = request_encoding(&headers);
	assert_eq!(encoding, HttpResponseEncoding::Cbor);
}

#[test]
fn message_boundary_error_response_with_actor_sets_body_and_headers() {
	let actor = ActorSpecifier::new("actor-http", 9).with_key("user/1");
	let response = message_boundary_error_response_with_actor(
		HttpResponseEncoding::Json,
		StatusCode::BAD_REQUEST,
		IncomingMessageTooLong.build(),
		Some(&actor),
	)
	.expect("json response should serialize");

	// The caller-supplied status must be propagated when an actor is attached,
	// matching the actor-less variant's behavior.
	assert_eq!(response.status, StatusCode::BAD_REQUEST.as_u16());

	// Actor identity (id, generation, key) is echoed in dedicated response headers...
	assert_eq!(
		response.headers.get("x-rivet-actor"),
		Some(&"actor-http".to_owned())
	);
	assert_eq!(
		response.headers.get("x-rivet-actor-generation"),
		Some(&"9".to_owned())
	);
	assert_eq!(
		response.headers.get("x-rivet-actor-key"),
		Some(&"user/1".to_owned())
	);

	// ...and embedded in the JSON error body under a camelCase `actor` object.
	assert_eq!(
		response.body,
		Some(
			serde_json::to_vec(&json!({
				"group": "message",
				"code": "incoming_too_long",
				"message": "Incoming message too long",
				"actor": {
					"actorId": "actor-http",
					"generation": 9,
					"key": "user/1",
				},
			}))
			.expect("json body should encode")
		)
	);
}

#[test]
fn message_boundary_error_response_serializes_bare_v3() {
	let response = message_boundary_error_response(
		HttpResponseEncoding::Bare,
		StatusCode::BAD_REQUEST,
		OutgoingMessageTooLong.build(),
	)
	.expect("bare response should serialize");

	// The status must be propagated regardless of body encoding.
	assert_eq!(response.status, StatusCode::BAD_REQUEST.as_u16());
	assert_eq!(
		response.headers.get(http::header::CONTENT_TYPE.as_str()),
		Some(&"application/octet-stream".to_owned())
	);

	// The body is a versioned BARE `HttpResponseError` with an embedded version
	// prefix; decoding it must recover the original group/code/message and no metadata.
	let body = response.body.expect("bare response should include body");
	let decoded =
		<rivetkit_client_protocol::versioned::HttpResponseError as OwnedVersionedData>::deserialize_with_embedded_version(&body)
			.expect("bare error should decode");
	assert_eq!(decoded.group, "message");
	assert_eq!(decoded.code, "outgoing_too_long");
	assert_eq!(decoded.message, "Outgoing message too long");
	assert_eq!(decoded.metadata, None);
}