use std::collections::HashMap;
use std::time::Duration;
use super::{
FrameworkHttpRoute, HttpResponseEncoding, RegistryHttpRoute,
authorization_bearer_token, authorization_bearer_token_map, framework_action_error_response,
framework_anyhow_error_response_with_actor, inspector_error_status,
message_boundary_error_response, message_boundary_error_response_with_actor,
normalize_actor_request_path, request_encoding, with_action_dispatch_timeout,
with_framework_action_timeout, workflow_dispatch_result,
};
use crate::actor::action::ActionDispatchError;
use crate::error::ActorLifecycle as ActorLifecycleError;
use http::StatusCode;
use rivet_error::{ActorSpecifier, MacroMarker, RivetError, RivetErrorSchema};
use serde_json::json;
use vbare::OwnedVersionedData;
// Test-only error fixture. The `#[error]` attribute declares group `message`,
// code `incoming_too_long`, and the default message `Incoming message too long`;
// the tests below build it with `.build()` and assert on those three values.
#[derive(RivetError)]
#[error("message", "incoming_too_long", "Incoming message too long")]
struct IncomingMessageTooLong;
// Test-only error fixture. The `#[error]` attribute declares group `message`,
// code `outgoing_too_long`, and the default message `Outgoing message too long`;
// it is built with `.build()` in the BARE-encoding response test.
#[derive(RivetError)]
#[error("message", "outgoing_too_long", "Outgoing message too long")]
struct OutgoingMessageTooLong;
/// Checks path normalization for the `/request` prefix and that the resulting
/// original/normalized pairs classify as user raw requests.
///
/// `/request`, `/request/`, `/request/users`, and `/request?foo=bar` lose the
/// prefix, while `/requestfoo` is expected to come back unchanged.
#[test]
fn request_prefix_detection_matches_normalization() {
    // (incoming path, expected normalized path)
    let normalization_cases = [
        ("/request", "/"),
        ("/request/", "/"),
        ("/request/users", "/users"),
        ("/request?foo=bar", "?foo=bar"),
        ("/requestfoo", "/requestfoo"),
    ];
    for (incoming, expected) in normalization_cases {
        assert_eq!(normalize_actor_request_path(incoming), expected);
    }

    // Each route is bound first so a failure points at a distinct line.
    let bare_prefix =
        RegistryHttpRoute::from_paths("/request", "/").expect("route should decode");
    assert!(matches!(bare_prefix, RegistryHttpRoute::UserRawRequest));

    let nested_path =
        RegistryHttpRoute::from_paths("/request/users", "/users").expect("route should decode");
    assert!(matches!(nested_path, RegistryHttpRoute::UserRawRequest));

    let query_only = RegistryHttpRoute::from_paths("/request?foo=bar", "?foo=bar")
        .expect("route should decode");
    assert!(matches!(query_only, RegistryHttpRoute::UserRawRequest));

    let lookalike = RegistryHttpRoute::from_paths("/requestfoo", "/requestfoo")
        .expect("route should decode");
    assert!(matches!(lookalike, RegistryHttpRoute::UserRawRequest));
}
/// Checks that framework routes and user routes are classified independently.
///
/// `/action/increment` must decode as the framework `Action` route named
/// `increment`, whereas the same path behind the `/request` prefix and an
/// unrelated `/custom` path must both decode as user raw requests.
#[test]
fn classifier_keeps_framework_and_user_routing_separate() {
    let framework_route = RegistryHttpRoute::from_paths("/action/increment", "/action/increment")
        .expect("route should decode");
    assert!(matches!(
        framework_route,
        RegistryHttpRoute::Framework(FrameworkHttpRoute::Action(name)) if name == "increment"
    ));

    let prefixed_route =
        RegistryHttpRoute::from_paths("/request/action/increment", "/action/increment")
            .expect("route should decode");
    assert!(matches!(prefixed_route, RegistryHttpRoute::UserRawRequest));

    let custom_route =
        RegistryHttpRoute::from_paths("/custom", "/custom").expect("route should decode");
    assert!(matches!(custom_route, RegistryHttpRoute::UserRawRequest));
}
/// Checks that a successful dispatch reply yields `true` as the first tuple
/// element, with the payload (present or absent) passed through unchanged.
#[test]
fn workflow_dispatch_result_marks_handled_workflow_as_enabled() {
    let with_payload = workflow_dispatch_result(Ok(Some(vec![1, 2, 3])))
        .expect("workflow dispatch should succeed");
    assert_eq!(with_payload, (true, Some(vec![1, 2, 3])));

    let without_payload =
        workflow_dispatch_result(Ok(None)).expect("workflow dispatch should succeed");
    assert_eq!(without_payload, (true, None));
}
/// Checks that a `DroppedReply` lifecycle error is converted into a successful
/// `(false, None)` result rather than being returned as an error.
#[test]
fn workflow_dispatch_result_treats_dropped_reply_as_disabled() {
    let outcome = workflow_dispatch_result(Err(ActorLifecycleError::DroppedReply.build()))
        .expect("dropped reply should map to workflow disabled");
    assert_eq!(outcome, (false, None));
}
/// Checks that an error other than `DroppedReply` (here `Destroying`) is
/// returned as an error that still extracts to group `actor`, code `destroying`.
#[test]
fn workflow_dispatch_result_preserves_non_dropped_reply_errors() {
    let raw_error = workflow_dispatch_result(Err(ActorLifecycleError::Destroying.build()))
        .expect_err("non-dropped reply errors should be preserved");
    let extracted = rivet_error::RivetError::extract(&raw_error);
    assert_eq!(extracted.group(), "actor");
    assert_eq!(extracted.code(), "destroying");
}
/// Checks that the `actor` / `action_timed_out` error pair maps to
/// `StatusCode::REQUEST_TIMEOUT`.
#[test]
fn inspector_error_status_maps_action_timeout_to_408() {
    let status = inspector_error_status("actor", "action_timed_out");
    assert_eq!(status, StatusCode::REQUEST_TIMEOUT);
}
/// Checks bearer-token extraction from both header representations.
///
/// The `HeaderMap` variant uses a lowercase `bearer` scheme separated by a
/// space; the string-map variant uses an uppercase `BEARER` scheme separated by
/// a tab. Both must yield `test-token`.
#[test]
fn authorization_bearer_token_accepts_case_insensitive_scheme_and_whitespace() {
    let mut typed_headers = http::HeaderMap::new();
    let lowercase_value = "bearer test-token".parse().unwrap();
    typed_headers.insert(http::header::AUTHORIZATION, lowercase_value);
    assert_eq!(
        authorization_bearer_token(&typed_headers),
        Some("test-token")
    );

    let mut string_headers = HashMap::new();
    string_headers.insert(
        http::header::AUTHORIZATION.as_str().to_owned(),
        "BEARER\ttest-token".to_owned(),
    );
    assert_eq!(
        authorization_bearer_token_map(&string_headers),
        Some("test-token")
    );
}
/// Checks that an action future outliving its dispatch timeout produces an
/// `ActionDispatchError` with group `actor`, code `action_timed_out`, and
/// message `Action timed out`.
#[tokio::test]
async fn action_dispatch_timeout_returns_structured_error() {
    // Sleeps far longer than the 1 ms budget so the timeout always wins.
    let slow_action = async {
        tokio::time::sleep(Duration::from_secs(60)).await;
        Ok::<Vec<u8>, ActionDispatchError>(Vec::new())
    };

    let timeout_error = with_action_dispatch_timeout(Duration::from_millis(1), slow_action)
        .await
        .expect_err("timeout should return an action dispatch error");

    assert_eq!(timeout_error.group, "actor");
    assert_eq!(timeout_error.code, "action_timed_out");
    assert_eq!(timeout_error.message, "Action timed out");
}
/// Checks that a framework action future outliving its timeout produces an
/// error that extracts to group `actor`, code `action_timed_out`, and message
/// `Action timed out`.
#[tokio::test]
async fn framework_action_timeout_returns_structured_error() {
    // Sleeps far longer than the 1 ms budget so the timeout always wins.
    let slow_action = async {
        tokio::time::sleep(Duration::from_secs(60)).await;
        Ok::<(), anyhow::Error>(())
    };

    let raw_error = with_framework_action_timeout(Duration::from_millis(1), slow_action)
        .await
        .expect_err("timeout should return a framework error");
    let extracted = rivet_error::RivetError::extract(&raw_error);

    assert_eq!(extracted.group(), "actor");
    assert_eq!(extracted.code(), "action_timed_out");
    assert_eq!(extracted.message(), "Action timed out");
}
/// Checks that an `actor` / `action_timed_out` dispatch error is rendered as a
/// JSON response with the `REQUEST_TIMEOUT` status and a body carrying the
/// error's group, code, and message.
#[test]
fn framework_action_error_response_maps_timeout_to_408() {
    let timeout_error = ActionDispatchError {
        group: "actor".to_owned(),
        code: "action_timed_out".to_owned(),
        message: "Action timed out".to_owned(),
        metadata: None,
        actor: None,
    };

    let response = framework_action_error_response(HttpResponseEncoding::Json, timeout_error, None)
        .expect("timeout error response should serialize");

    assert_eq!(response.status, StatusCode::REQUEST_TIMEOUT.as_u16());

    let expected_body = serde_json::to_vec(&json!({
        "group": "actor",
        "code": "action_timed_out",
        "message": "Action timed out",
    }))
    .expect("json body should encode");
    assert_eq!(response.body, Some(expected_body));
}
/// Checks that a `rivetkit` / `internal_error` dispatch error is rendered with
/// the `INTERNAL_SERVER_ERROR` status and a generic message.
///
/// The original message (`plain failure`) and its metadata must not appear in
/// the expected body; only `An internal error occurred` is exposed.
#[test]
fn framework_action_error_response_sanitizes_internal_message() {
    let internal_error = ActionDispatchError {
        group: "rivetkit".to_owned(),
        code: "internal_error".to_owned(),
        message: "plain failure".to_owned(),
        metadata: Some(json!({ "error": "plain failure" })),
        actor: None,
    };

    let response =
        framework_action_error_response(HttpResponseEncoding::Json, internal_error, None)
            .expect("internal error response should serialize");

    assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR.as_u16());

    let expected_body = serde_json::to_vec(&json!({
        "group": "rivetkit",
        "code": "internal_error",
        "message": "An internal error occurred",
    }))
    .expect("json body should encode");
    assert_eq!(response.body, Some(expected_body));
}
/// Checks the JSON rendering of a message-boundary error: the given status is
/// used, the content type is `application/json`, and the body carries the
/// error's group, code, and message.
#[test]
fn message_boundary_error_response_defaults_to_json() {
    let response = message_boundary_error_response(
        HttpResponseEncoding::Json,
        StatusCode::BAD_REQUEST,
        IncomingMessageTooLong.build(),
    )
    .expect("json response should serialize");

    assert_eq!(response.status, StatusCode::BAD_REQUEST.as_u16());

    let content_type = response
        .headers
        .get(http::header::CONTENT_TYPE.as_str())
        .map(String::as_str);
    assert_eq!(content_type, Some("application/json"));

    let expected_body = serde_json::to_vec(&json!({
        "group": "message",
        "code": "incoming_too_long",
        "message": "Incoming message too long",
    }))
    .expect("json body should encode");
    assert_eq!(response.body, Some(expected_body));
}
/// Checks that an error built from a `rivetkit` / `internal_error` schema is
/// rendered with the `INTERNAL_SERVER_ERROR` status and the generic message
/// `An internal error occurred` instead of the schema's own `plain failure`.
#[test]
fn framework_anyhow_error_response_sanitizes_internal_message() {
    // Hand-built schema so the test controls the group, code, and message.
    static PLAIN_FAILURE: RivetErrorSchema = RivetErrorSchema {
        group: "rivetkit",
        code: "internal_error",
        default_message: "plain failure",
        meta_type: None,
        _macro_marker: MacroMarker { _private: () },
    };

    let internal_error = PLAIN_FAILURE.build();
    let response = framework_anyhow_error_response_with_actor(
        HttpResponseEncoding::Json,
        internal_error,
        None,
    )
    .expect("internal error response should serialize");

    assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR.as_u16());

    let expected_body = serde_json::to_vec(&json!({
        "group": "rivetkit",
        "code": "internal_error",
        "message": "An internal error occurred",
    }))
    .expect("json body should encode");
    assert_eq!(response.body, Some(expected_body));
}
/// Checks that an `x-rivet-encoding: cbor` request header selects
/// `HttpResponseEncoding::Cbor`.
#[test]
fn request_encoding_reads_cbor_header() {
    let mut headers = http::HeaderMap::new();
    let cbor_value = "cbor".parse().unwrap();
    headers.insert("x-rivet-encoding", cbor_value);

    let encoding = request_encoding(&headers);
    assert_eq!(encoding, HttpResponseEncoding::Cbor);
}
/// Checks that supplying an actor specifier adds the actor to both the response
/// headers and the JSON body.
///
/// Expected headers are `x-rivet-actor`, `x-rivet-actor-generation`, and
/// `x-rivet-actor-key`; the body gains an `actor` object with `actorId`,
/// `generation`, and `key`.
#[test]
fn message_boundary_error_response_with_actor_sets_body_and_headers() {
    let specifier = ActorSpecifier::new("actor-http", 9).with_key("user/1");
    let response = message_boundary_error_response_with_actor(
        HttpResponseEncoding::Json,
        StatusCode::BAD_REQUEST,
        IncomingMessageTooLong.build(),
        Some(&specifier),
    )
    .expect("json response should serialize");

    let headers = &response.headers;
    assert_eq!(
        headers.get("x-rivet-actor").map(String::as_str),
        Some("actor-http")
    );
    assert_eq!(
        headers.get("x-rivet-actor-generation").map(String::as_str),
        Some("9")
    );
    assert_eq!(
        headers.get("x-rivet-actor-key").map(String::as_str),
        Some("user/1")
    );

    let expected_body = serde_json::to_vec(&json!({
        "group": "message",
        "code": "incoming_too_long",
        "message": "Incoming message too long",
        "actor": {
            "actorId": "actor-http",
            "generation": 9,
            "key": "user/1",
        },
    }))
    .expect("json body should encode");
    assert_eq!(response.body, Some(expected_body));
}
/// Checks the BARE rendering of a message-boundary error: the content type is
/// `application/octet-stream` and the body decodes, via the embedded-version
/// deserializer, back to the error's group, code, and message with no metadata.
#[test]
fn message_boundary_error_response_serializes_bare_v3() {
    // Local alias keeps the fully qualified call below readable.
    type VersionedError = rivetkit_client_protocol::versioned::HttpResponseError;

    let response = message_boundary_error_response(
        HttpResponseEncoding::Bare,
        StatusCode::BAD_REQUEST,
        OutgoingMessageTooLong.build(),
    )
    .expect("bare response should serialize");

    let content_type = response
        .headers
        .get(http::header::CONTENT_TYPE.as_str())
        .map(String::as_str);
    assert_eq!(content_type, Some("application/octet-stream"));

    let payload = response.body.expect("bare response should include body");
    let decoded =
        <VersionedError as OwnedVersionedData>::deserialize_with_embedded_version(&payload)
            .expect("bare error should decode");

    assert_eq!(decoded.group, "message");
    assert_eq!(decoded.code, "outgoing_too_long");
    assert_eq!(decoded.message, "Outgoing message too long");
    assert_eq!(decoded.metadata, None);
}