use rivetkit_client::{Client, ClientConfig, EncodingKind, GetOrCreateOptions, TransportKind};
use fs_extra;
use portpicker;
use serde_json::json;
use std::process::{Child, Command};
use std::time::Duration;
use tempfile;
use tokio::time::sleep;
use tracing::{error, info};
/// Handle to a locally spawned Node.js rivetkit server used by the e2e test.
///
/// Owns the server process together with the temporary directory its code was
/// installed into, so both stay alive exactly as long as this value does.
struct MockServer {
    /// The `npx tsx run.ts` process serving the counter example.
    child: Child,
    /// Scratch directory holding the vendored tarballs and the copied example;
    /// only kept so it is not removed while the server is still running.
    _temp_dir: tempfile::TempDir,
}
impl MockServer {
async fn start(port: u16) -> Self {
let current_dir = std::env::current_dir().expect("Failed to get current directory");
let repo_root = current_dir
.ancestors()
.find(|p| p.join("package.json").exists())
.expect("Failed to find repo root");
let status = Command::new("yarn")
.args(["build", "-F", "rivetkit"])
.current_dir(&repo_root)
.status()
.expect("Failed to build rivetkit");
if !status.success() {
panic!("Failed to build rivetkit");
}
let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
let temp_path = temp_dir.path();
println!("Created temp directory at: {}", temp_path.display());
let vendor_dir = temp_path.join("vendor");
std::fs::create_dir_all(&vendor_dir).expect("Failed to create vendor directory");
let packages = [
("rivetkit", repo_root.join("packages/rivetkit")),
("nodejs", repo_root.join("packages/platforms/nodejs")),
("memory", repo_root.join("packages/drivers/memory")),
("file-system", repo_root.join("packages/drivers/file-system")),
];
for (name, path) in packages.iter() {
let output_path = vendor_dir.join(format!("rivetkit-{}.tgz", name));
println!(
"Packing {} from {} to {}",
name,
path.display(),
output_path.display()
);
let status = Command::new("yarn")
.args(["pack", "--out", output_path.to_str().unwrap()])
.current_dir(path)
.status()
.expect(&format!("Failed to pack {}", name));
if !status.success() {
panic!("Failed to pack {}", name);
}
}
let counter_dir = repo_root.join("examples/counter");
let options = fs_extra::dir::CopyOptions::new();
fs_extra::dir::copy(&counter_dir, temp_path, &options)
.expect("Failed to copy counter example");
let server_dir = temp_path.join("counter");
let server_script_path = server_dir.join("run.ts");
let server_script = r#"
import { app } from "./actors/app.ts";
import { serve } from "@rivetkit/nodejs";
serve(app, { port: PORT, mode: "memory" });
"#
.replace("PORT", &port.to_string());
std::fs::write(&server_script_path, server_script).expect("Failed to write server script");
let package_json_path = server_dir.join("package.json");
let package_json = format!(
r#"{{
"name": "rivetkit-rust-test",
"packageManager": "yarn@4.2.2",
"private": true,
"type": "module",
"dependencies": {{
"rivetkit": "file:{}",
"@rivetkit/nodejs": "file:{}",
"@rivetkit/memory": "file:{}",
"@rivetkit/file-system": "file:{}"
}},
"devDependencies": {{
"tsx": "^3.12.7"
}}
}}"#,
vendor_dir.join("rivetkit-rivetkit.tgz").display(),
vendor_dir.join("rivetkit-nodejs.tgz").display(),
vendor_dir.join("rivetkit-memory.tgz").display(),
vendor_dir.join("rivetkit-file-system.tgz").display()
);
std::fs::write(&package_json_path, package_json).expect("Failed to write package.json");
let yarnrc_path = server_dir.join(".yarnrc.yml");
let yarnrc_content = "nodeLinker: node-modules\n";
std::fs::write(&yarnrc_path, yarnrc_content).expect("Failed to write .yarnrc.yml");
let status = Command::new("yarn")
.current_dir(&server_dir)
.status()
.expect("Failed to install dependencies");
if !status.success() {
panic!("Failed to install dependencies");
}
let child = Command::new("npx")
.args(["tsx", "run.ts"])
.current_dir(&server_dir)
.spawn()
.expect("Failed to spawn server process");
Self {
child,
_temp_dir: temp_dir,
}
}
}
impl Drop for MockServer {
fn drop(&mut self) {
if let Err(e) = self.child.kill() {
error!("Failed to kill server: {}", e);
}
info!("Mock server terminated");
}
}
/// End-to-end test: starts the Node mock server and connects over WebSocket +
/// CBOR. It invokes the counter's `increment` action once through the actor
/// handle and once through a live connection that also subscribes to `newCount`.
#[tokio::test]
async fn e2e() {
    let subscriber = tracing_subscriber::fmt()
        .with_max_level(tracing::Level::DEBUG)
        .finish();
    let _guard = tracing::subscriber::set_default(subscriber);

    let port = portpicker::pick_unused_port().expect("Failed to pick an unused port");
    info!("Using port {}", port);
    let endpoint = format!("http://127.0.0.1:{}", port);

    // Kept alive until the end of the test; dropping it kills the server.
    let mut server = MockServer::start(port).await;

    info!("Waiting for server to start...");
    // Poll the port instead of sleeping a fixed amount. Start-up time of
    // `npx tsx` varies across machines, and a fixed sleep either races or
    // wastes time. Also fail fast if the server process has already died.
    let addr = std::net::SocketAddr::from(([127, 0, 0, 1], port));
    let startup_timeout = Duration::from_secs(60);
    let deadline = std::time::Instant::now() + startup_timeout;
    loop {
        if let Some(status) = server
            .child
            .try_wait()
            .expect("Failed to poll server process")
        {
            panic!("Server process exited before accepting connections: {}", status);
        }
        // Connecting to localhost either succeeds or is refused almost
        // immediately, so this brief blocking call is acceptable here.
        if std::net::TcpStream::connect_timeout(&addr, Duration::from_millis(250)).is_ok() {
            break;
        }
        assert!(
            std::time::Instant::now() < deadline,
            "Server did not start listening on {} within {:?}",
            addr,
            startup_timeout
        );
        sleep(Duration::from_millis(100)).await;
    }

    info!("Creating client to endpoint: {}", endpoint);
    let client = Client::new(
        ClientConfig::new(endpoint.as_str())
            .transport(TransportKind::WebSocket)
            .encoding(EncodingKind::Cbor),
    );
    let counter = client
        .get_or_create("counter", [].into(), GetOrCreateOptions::default())
        .expect("Failed to get or create counter actor");
    let conn = counter.connect();
    conn.on_event("newCount", |x| {
        info!("Received newCount event: {:?}", x);
    })
    .await;

    // Same action twice: once via the actor handle, once over the connection.
    let out = counter
        .action("increment", vec![json!(1)])
        .await
        .expect("increment via actor handle failed");
    info!("Action 1: {:?}", out);
    let out = conn
        .action("increment", vec![json!(1)])
        .await
        .expect("increment via connection failed");
    info!("Action 2: {:?}", out);

    client.disconnect();
}