use anyhow::Result;
use axum::{
extract::{Path, Query, State},
http::{Method, StatusCode},
response::{IntoResponse, Json},
routing::{delete, get, post, put},
Router,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tower::ServiceBuilder;
use tower_http::{
cors::{Any, CorsLayer},
trace::TraceLayer,
};
use tracing::{error, info};
use uuid::Uuid;
use crate::engine::{Mission, MissionResult, MissionStatus};
use crate::runtime::RustChainRuntime;
use crate::safety::{SafetyValidator, ValidationMode};
pub mod agent_api;
#[cfg(feature = "llm")]
use crate::llm::{ChatMessage, LLMRequest};
#[cfg(feature = "rag")]
#[cfg(feature = "sandbox")]
use crate::sandbox::create_default_sandbox;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
pub host: String,
pub port: u16,
pub cors_enabled: bool,
pub max_request_size: usize,
pub rate_limit_per_minute: u32,
pub auth_enabled: bool,
pub auth_token: Option<String>,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
host: "127.0.0.1".to_string(),
port: 8080,
cors_enabled: true,
max_request_size: 10 * 1024 * 1024, rate_limit_per_minute: 100,
auth_enabled: false,
auth_token: None,
}
}
}
#[derive(Clone)]
pub struct AppState {
pub runtime: Arc<RustChainRuntime>,
pub config: Arc<ServerConfig>,
pub active_missions: Arc<RwLock<HashMap<String, MissionHandle>>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MissionHandle {
pub id: String,
pub mission: Mission,
pub status: MissionStatus,
pub created_at: chrono::DateTime<chrono::Utc>,
pub updated_at: chrono::DateTime<chrono::Utc>,
}
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct ApiResponse<T> {
pub success: bool,
pub data: Option<T>,
pub error: Option<String>,
pub timestamp: chrono::DateTime<chrono::Utc>,
}
impl<T> ApiResponse<T> {
pub fn success(data: T) -> Self {
Self {
success: true,
data: Some(data),
error: None,
timestamp: chrono::Utc::now(),
}
}
pub fn error(error: String) -> Self {
Self {
success: false,
data: None,
error: Some(error),
timestamp: chrono::Utc::now(),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateMissionRequest {
pub name: String,
pub description: Option<String>,
pub mission_yaml: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecuteMissionRequest {
pub mission_id: String,
pub dry_run: Option<bool>,
pub skip_safety: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SafetyCheckRequest {
pub mission_yaml: String,
}
#[cfg(feature = "llm")]
#[derive(Debug, Serialize, Deserialize)]
pub struct LLMChatRequest {
pub messages: Vec<ChatMessage>,
pub model: Option<String>,
pub temperature: Option<f32>,
pub max_tokens: Option<u32>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct RagSearchRequest {
pub query: String,
pub limit: Option<usize>,
pub similarity_threshold: Option<f32>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct RagDocumentRequest {
pub id: String,
pub content: String,
pub metadata: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SandboxExecuteRequest {
pub command: String,
pub args: Vec<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SandboxFileRequest {
pub path: String,
pub content: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct PaginationQuery {
pub offset: Option<usize>,
pub limit: Option<usize>,
}
pub fn create_router(state: AppState) -> Router {
Router::new()
.route("/health", get(health_check))
.route("/info", get(server_info))
.merge(agent_api::create_agent_router())
.route("/missions", get(list_missions))
.route("/missions", post(create_mission))
.route("/missions/:id", get(get_mission))
.route("/missions/:id", put(update_mission))
.route("/missions/:id", delete(delete_mission))
.route("/missions/:id/execute", post(execute_mission))
.route("/missions/:id/status", get(get_mission_status))
.route("/safety/check", post(safety_check))
.route("/safety/validate/:mission_id", post(validate_mission))
.route("/llm/chat", post(llm_chat))
.route("/llm/models", get(llm_models))
.route("/rag/search", post(rag_search))
.route("/rag/documents", post(rag_add_document))
.route("/rag/documents", get(rag_list_documents))
.route("/rag/documents/:id", delete(rag_delete_document))
.route("/sandbox/sessions", post(sandbox_create_session))
.route("/sandbox/sessions", get(sandbox_list_sessions))
.route("/sandbox/sessions/:id/execute", post(sandbox_execute))
.route("/sandbox/sessions/:id/files", get(sandbox_list_files))
.route("/sandbox/sessions/:id/files/*path", get(sandbox_read_file))
.route("/sandbox/sessions/:id/files/*path", put(sandbox_write_file))
.route("/sandbox/sessions/:id", delete(sandbox_cleanup_session))
.with_state(state)
.layer(
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(
CorsLayer::new()
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
.allow_headers(Any)
.allow_origin(Any),
),
)
}
async fn health_check() -> impl IntoResponse {
Json(ApiResponse::success("Server is healthy"))
}
async fn server_info(State(_state): State<AppState>) -> impl IntoResponse {
let info = serde_json::json!({
"name": "RustChain API Server",
"version": env!("CARGO_PKG_VERSION"),
"status": "operational",
"uptime": chrono::Utc::now(),
});
Json(ApiResponse::success(info))
}
async fn list_missions(
State(state): State<AppState>,
Query(pagination): Query<PaginationQuery>,
) -> impl IntoResponse {
let missions = state.active_missions.read().await;
let mission_list: Vec<MissionHandle> = missions.values().cloned().collect();
let offset = pagination.offset.unwrap_or(0);
let limit = pagination.limit.unwrap_or(10);
let end = std::cmp::min(offset + limit, mission_list.len());
if offset >= mission_list.len() {
return Json(ApiResponse::success(Vec::<MissionHandle>::new()));
}
let paginated = mission_list[offset..end].to_vec();
Json(ApiResponse::success(paginated))
}
async fn create_mission(
State(state): State<AppState>,
Json(request): Json<CreateMissionRequest>,
) -> impl IntoResponse {
match serde_yaml::from_str::<Mission>(&request.mission_yaml) {
Ok(mission) => {
let mission_id = Uuid::new_v4().to_string();
let handle = MissionHandle {
id: mission_id.clone(),
mission,
status: MissionStatus::Running,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
};
state
.active_missions
.write()
.await
.insert(mission_id.clone(), handle.clone());
info!("Mission created: {}", mission_id);
Json(ApiResponse::success(handle))
}
Err(e) => {
error!("Failed to parse mission YAML: {}", e);
Json(ApiResponse::<MissionHandle>::error(format!(
"Invalid mission YAML: {}",
e
)))
}
}
}
async fn get_mission(
State(state): State<AppState>,
Path(mission_id): Path<String>,
) -> impl IntoResponse {
let missions = state.active_missions.read().await;
match missions.get(&mission_id) {
Some(handle) => (StatusCode::OK, Json(ApiResponse::success(handle.clone()))),
None => {
let response: ApiResponse<MissionHandle> =
ApiResponse::error("Mission not found".to_string());
(StatusCode::NOT_FOUND, Json(response))
}
}
}
async fn update_mission(
State(state): State<AppState>,
Path(mission_id): Path<String>,
Json(request): Json<CreateMissionRequest>,
) -> impl IntoResponse {
match serde_yaml::from_str::<Mission>(&request.mission_yaml) {
Ok(mission) => {
let mut missions = state.active_missions.write().await;
if let Some(handle) = missions.get_mut(&mission_id) {
handle.mission = mission;
handle.updated_at = chrono::Utc::now();
info!("Mission updated: {}", mission_id);
(StatusCode::OK, Json(ApiResponse::success(handle.clone())))
} else {
let response: ApiResponse<MissionHandle> =
ApiResponse::error("Mission not found".to_string());
(StatusCode::NOT_FOUND, Json(response))
}
}
Err(e) => {
error!("Failed to parse mission YAML: {}", e);
(
StatusCode::BAD_REQUEST,
Json(ApiResponse::<MissionHandle>::error(format!(
"Invalid mission YAML: {}",
e
))),
)
}
}
}
async fn delete_mission(
State(state): State<AppState>,
Path(mission_id): Path<String>,
) -> impl IntoResponse {
let mut missions = state.active_missions.write().await;
match missions.remove(&mission_id) {
Some(_) => {
info!("Mission deleted: {}", mission_id);
(
StatusCode::OK,
Json(ApiResponse::success("Mission deleted")),
)
}
None => {
let response: ApiResponse<&str> = ApiResponse::error("Mission not found".to_string());
(StatusCode::NOT_FOUND, Json(response))
}
}
}
async fn execute_mission(
State(state): State<AppState>,
Path(mission_id): Path<String>,
Json(request): Json<ExecuteMissionRequest>,
) -> impl IntoResponse {
let missions = state.active_missions.read().await;
match missions.get(&mission_id) {
Some(handle) => {
let mission = handle.mission.clone();
drop(missions);
if !request.skip_safety.unwrap_or(false) {
let validator = SafetyValidator::new();
match validator.validate_mission(&mission, ValidationMode::Standard) {
Ok(safety_report) => {
if !safety_report.is_safe {
let error_msg = format!(
"Safety validation failed with {} issues",
safety_report.issues.len()
);
return (
StatusCode::BAD_REQUEST,
Json(ApiResponse::<MissionResult>::error(error_msg)),
);
}
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiResponse::<MissionResult>::error(format!(
"Safety validation error: {}",
e
))),
);
}
}
}
if request.dry_run.unwrap_or(false) {
info!("Dry run for mission: {}", mission_id);
let dry_run_result = MissionResult {
mission_id: Uuid::new_v4(),
status: MissionStatus::Completed,
step_results: std::collections::HashMap::new(),
total_duration_ms: 0,
};
(StatusCode::OK, Json(ApiResponse::success(dry_run_result)))
} else {
match state.runtime.execute_mission(mission).await {
Ok(result) => {
info!("Mission executed: {}", mission_id);
(StatusCode::OK, Json(ApiResponse::success(result)))
}
Err(e) => {
error!("Mission execution failed: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiResponse::error(format!("Execution failed: {}", e))),
)
}
}
}
}
None => {
let response: ApiResponse<MissionResult> =
ApiResponse::error("Mission not found".to_string());
(StatusCode::NOT_FOUND, Json(response))
}
}
}
async fn get_mission_status(
State(state): State<AppState>,
Path(mission_id): Path<String>,
) -> impl IntoResponse {
let missions = state.active_missions.read().await;
match missions.get(&mission_id) {
Some(handle) => {
let status_info = serde_json::json!({
"mission_id": mission_id,
"status": handle.status,
"created_at": handle.created_at,
"updated_at": handle.updated_at,
});
(StatusCode::OK, Json(ApiResponse::success(status_info)))
}
None => {
let response: ApiResponse<serde_json::Value> =
ApiResponse::error("Mission not found".to_string());
(StatusCode::NOT_FOUND, Json(response))
}
}
}
async fn safety_check(
State(_state): State<AppState>,
Json(request): Json<SafetyCheckRequest>,
) -> impl IntoResponse {
match serde_yaml::from_str::<Mission>(&request.mission_yaml) {
Ok(mission) => {
let validator = SafetyValidator::new();
match validator.validate_mission(&mission, ValidationMode::Standard) {
Ok(report) => Json(ApiResponse::success(report)),
Err(e) => Json(ApiResponse::error(format!(
"Safety validation error: {}",
e
))),
}
}
Err(e) => Json(ApiResponse::error(format!("Invalid mission YAML: {}", e))),
}
}
async fn validate_mission(
State(state): State<AppState>,
Path(mission_id): Path<String>,
) -> impl IntoResponse {
let missions = state.active_missions.read().await;
match missions.get(&mission_id) {
Some(handle) => {
let validator = SafetyValidator::new();
match validator.validate_mission(&handle.mission, ValidationMode::Standard) {
Ok(report) => (StatusCode::OK, Json(ApiResponse::success(report))),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiResponse::<crate::safety::ValidationResult>::error(
format!("Safety validation error: {}", e),
)),
),
}
}
None => {
let response: ApiResponse<crate::safety::ValidationResult> =
ApiResponse::error("Mission not found".to_string());
(StatusCode::NOT_FOUND, Json(response))
}
}
}
#[cfg(feature = "llm")]
async fn llm_chat(
State(_state): State<AppState>,
Json(request): Json<LLMChatRequest>,
) -> impl IntoResponse {
use crate::llm::create_default_llm_manager;
match create_default_llm_manager() {
Ok(llm_manager) => {
let llm_request = LLMRequest {
messages: request.messages,
model: request.model,
temperature: request.temperature,
max_tokens: request.max_tokens,
tools: None,
stream: false,
metadata: std::collections::HashMap::new(),
};
match llm_manager.complete(llm_request, Some("openai")).await {
Ok(response) => Json(ApiResponse::success(response)),
Err(e) => Json(ApiResponse::error(format!("LLM chat error: {}", e))),
}
}
Err(e) => Json(ApiResponse::error(format!(
"Failed to create LLM manager: {}",
e
))),
}
}
#[cfg(not(feature = "llm"))]
async fn llm_chat() -> impl IntoResponse {
Json(ApiResponse::<()>::error(
"LLM feature not enabled".to_string(),
))
}
#[cfg(feature = "llm")]
async fn llm_models() -> impl IntoResponse {
use crate::llm::create_default_llm_manager;
match create_default_llm_manager() {
Ok(llm_manager) => match llm_manager.list_all_models().await {
Ok(models) => Json(ApiResponse::success(models)),
Err(e) => Json(ApiResponse::error(format!("Failed to list models: {}", e))),
},
Err(e) => Json(ApiResponse::error(format!(
"Failed to create LLM manager: {}",
e
))),
}
}
#[cfg(not(feature = "llm"))]
async fn llm_models() -> impl IntoResponse {
Json(ApiResponse::<()>::error(
"LLM feature not enabled".to_string(),
))
}
#[cfg(feature = "rag")]
async fn rag_search(Json(request): Json<RagSearchRequest>) -> impl IntoResponse {
use crate::rag::create_default_rag_system;
match create_default_rag_system() {
Ok(rag_system) => {
match rag_system
.search(&request.query, request.limit, request.similarity_threshold)
.await
{
Ok(response) => Json(ApiResponse::success(response)),
Err(e) => Json(ApiResponse::error(format!("RAG search error: {}", e))),
}
}
Err(e) => Json(ApiResponse::error(format!(
"Failed to create RAG system: {}",
e
))),
}
}
#[cfg(not(feature = "rag"))]
async fn rag_search() -> impl IntoResponse {
Json(ApiResponse::<()>::error(
"RAG feature not enabled".to_string(),
))
}
#[cfg(feature = "rag")]
async fn rag_add_document(Json(request): Json<RagDocumentRequest>) -> impl IntoResponse {
use crate::rag::create_default_rag_system;
match create_default_rag_system() {
Ok(mut rag_system) => {
let metadata = request.metadata.unwrap_or_default();
match rag_system
.add_document(request.id.clone(), request.content, metadata)
.await
{
Ok(document_id) => Json(ApiResponse::success(serde_json::json!({
"document_id": document_id
}))),
Err(e) => Json(ApiResponse::error(format!("Failed to add document: {}", e))),
}
}
Err(e) => Json(ApiResponse::error(format!(
"Failed to create RAG system: {}",
e
))),
}
}
#[cfg(not(feature = "rag"))]
async fn rag_add_document() -> impl IntoResponse {
Json(ApiResponse::<()>::error(
"RAG feature not enabled".to_string(),
))
}
#[cfg(feature = "rag")]
async fn rag_list_documents(Query(pagination): Query<PaginationQuery>) -> impl IntoResponse {
use crate::rag::create_default_rag_system;
match create_default_rag_system() {
Ok(rag_system) => {
let offset = pagination.offset.unwrap_or(0);
let limit = pagination.limit.unwrap_or(10);
match rag_system.list_documents(offset, limit).await {
Ok(documents) => Json(ApiResponse::success(documents)),
Err(e) => Json(ApiResponse::error(format!(
"Failed to list documents: {}",
e
))),
}
}
Err(e) => Json(ApiResponse::error(format!(
"Failed to create RAG system: {}",
e
))),
}
}
#[cfg(not(feature = "rag"))]
async fn rag_list_documents() -> impl IntoResponse {
Json(ApiResponse::<()>::error(
"RAG feature not enabled".to_string(),
))
}
#[cfg(feature = "rag")]
async fn rag_delete_document(Path(document_id): Path<String>) -> impl IntoResponse {
use crate::rag::create_default_rag_system;
match create_default_rag_system() {
Ok(mut rag_system) => match rag_system.delete_document(&document_id).await {
Ok(_) => Json(ApiResponse::success("Document deleted")),
Err(e) => Json(ApiResponse::error(format!(
"Failed to delete document: {}",
e
))),
},
Err(e) => Json(ApiResponse::error(format!(
"Failed to create RAG system: {}",
e
))),
}
}
#[cfg(not(feature = "rag"))]
async fn rag_delete_document() -> impl IntoResponse {
Json(ApiResponse::<()>::error(
"RAG feature not enabled".to_string(),
))
}
#[cfg(feature = "sandbox")]
async fn sandbox_create_session() -> impl IntoResponse {
let sandbox = create_default_sandbox();
match sandbox.create_sandbox(Default::default()).await {
Ok(session_id) => Json(ApiResponse::success(serde_json::json!({
"session_id": session_id
}))),
Err(e) => Json(ApiResponse::error(format!(
"Failed to create sandbox session: {}",
e
))),
}
}
#[cfg(not(feature = "sandbox"))]
async fn sandbox_create_session() -> impl IntoResponse {
Json(ApiResponse::<()>::error(
"Sandbox feature not enabled".to_string(),
))
}
#[cfg(feature = "sandbox")]
async fn sandbox_list_sessions() -> impl IntoResponse {
let sandbox = create_default_sandbox();
match sandbox.list_sandboxes().await {
Ok(sessions) => Json(ApiResponse::success(sessions)),
Err(e) => Json(ApiResponse::error(format!(
"Failed to list sandbox sessions: {}",
e
))),
}
}
#[cfg(not(feature = "sandbox"))]
async fn sandbox_list_sessions() -> impl IntoResponse {
Json(ApiResponse::<()>::error(
"Sandbox feature not enabled".to_string(),
))
}
#[cfg(feature = "sandbox")]
async fn sandbox_execute(
Path(session_id): Path<String>,
Json(request): Json<SandboxExecuteRequest>,
) -> impl IntoResponse {
let sandbox = create_default_sandbox();
match sandbox
.execute_in_sandbox(&session_id, &request.command, request.args)
.await
{
Ok(result) => Json(ApiResponse::success(result)),
Err(e) => Json(ApiResponse::error(format!(
"Failed to execute command: {}",
e
))),
}
}
#[cfg(not(feature = "sandbox"))]
async fn sandbox_execute() -> impl IntoResponse {
Json(ApiResponse::<()>::error(
"Sandbox feature not enabled".to_string(),
))
}
#[cfg(feature = "sandbox")]
async fn sandbox_list_files(Path(session_id): Path<String>) -> impl IntoResponse {
let sandbox = create_default_sandbox();
match sandbox.list_files(&session_id).await {
Ok(files) => Json(ApiResponse::success(files)),
Err(e) => Json(ApiResponse::error(format!("Failed to list files: {}", e))),
}
}
#[cfg(not(feature = "sandbox"))]
async fn sandbox_list_files() -> impl IntoResponse {
Json(ApiResponse::<()>::error(
"Sandbox feature not enabled".to_string(),
))
}
#[cfg(feature = "sandbox")]
async fn sandbox_read_file(
Path((session_id, file_path)): Path<(String, String)>,
) -> impl IntoResponse {
let sandbox = create_default_sandbox();
match sandbox.read_file(&session_id, &file_path).await {
Ok(content) => Json(ApiResponse::success(serde_json::json!({
"content": content
}))),
Err(e) => Json(ApiResponse::error(format!("Failed to read file: {}", e))),
}
}
#[cfg(not(feature = "sandbox"))]
async fn sandbox_read_file() -> impl IntoResponse {
Json(ApiResponse::<()>::error(
"Sandbox feature not enabled".to_string(),
))
}
#[cfg(feature = "sandbox")]
async fn sandbox_write_file(
Path((session_id, file_path)): Path<(String, String)>,
Json(request): Json<SandboxFileRequest>,
) -> impl IntoResponse {
let sandbox = create_default_sandbox();
let content = request.content.unwrap_or_default();
match sandbox
.write_file(&session_id, &file_path, content.as_bytes())
.await
{
Ok(_) => Json(ApiResponse::success("File written")),
Err(e) => Json(ApiResponse::error(format!("Failed to write file: {}", e))),
}
}
#[cfg(not(feature = "sandbox"))]
async fn sandbox_write_file() -> impl IntoResponse {
Json(ApiResponse::<()>::error(
"Sandbox feature not enabled".to_string(),
))
}
#[cfg(feature = "sandbox")]
async fn sandbox_cleanup_session(Path(session_id): Path<String>) -> impl IntoResponse {
let sandbox = create_default_sandbox();
match sandbox.cleanup_sandbox(&session_id).await {
Ok(_) => Json(ApiResponse::success("Session cleaned up")),
Err(e) => Json(ApiResponse::error(format!(
"Failed to cleanup session: {}",
e
))),
}
}
#[cfg(not(feature = "sandbox"))]
async fn sandbox_cleanup_session() -> impl IntoResponse {
Json(ApiResponse::<()>::error(
"Sandbox feature not enabled".to_string(),
))
}
pub fn create_app_state(config: ServerConfig) -> AppState {
AppState {
runtime: Arc::new(RustChainRuntime::new()),
config: Arc::new(config),
active_missions: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn start_server(config: ServerConfig) -> Result<()> {
let app_state = create_app_state(config.clone());
let app = create_router(app_state);
let listener =
tokio::net::TcpListener::bind(format!("{}:{}", config.host, config.port)).await?;
info!(
"RustChain API server starting on {}:{}",
config.host, config.port
);
axum::serve(listener, app).await?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use axum::http::StatusCode;
use axum_test::TestServer;
use crate::engine::{Mission, MissionConfig, MissionStep, StepType};
fn create_test_server() -> TestServer {
let config = ServerConfig::default();
let app_state = create_app_state(config);
let app = create_router(app_state);
TestServer::new(app).unwrap()
}
fn create_test_mission() -> Mission {
Mission {
version: "1.0".to_string(),
name: "Test Mission".to_string(),
description: Some("A test mission".to_string()),
steps: vec![
MissionStep {
id: "step_1".to_string(),
name: "Test Step 1".to_string(),
step_type: StepType::Noop,
depends_on: None,
timeout_seconds: Some(30),
continue_on_error: None,
parameters: serde_json::json!({}),
}
],
config: Some(MissionConfig {
max_parallel_steps: Some(1),
timeout_seconds: Some(30),
fail_fast: Some(true),
}),
}
}
fn create_test_mission_yaml() -> String {
serde_yaml::to_string(&create_test_mission()).unwrap()
}
#[test]
fn test_server_config_default() {
let config = ServerConfig::default();
assert_eq!(config.host, "127.0.0.1");
assert_eq!(config.port, 8080);
assert!(config.cors_enabled);
assert_eq!(config.max_request_size, 10 * 1024 * 1024);
assert_eq!(config.rate_limit_per_minute, 100);
assert!(!config.auth_enabled);
assert!(config.auth_token.is_none());
}
#[test]
fn test_server_config_serialization() {
let config = ServerConfig::default();
let serialized = serde_json::to_string(&config).unwrap();
let deserialized: ServerConfig = serde_json::from_str(&serialized).unwrap();
assert_eq!(config.host, deserialized.host);
assert_eq!(config.port, deserialized.port);
assert_eq!(config.cors_enabled, deserialized.cors_enabled);
}
#[test]
fn test_api_response_success() {
let response = ApiResponse::success("test data");
assert!(response.success);
assert_eq!(response.data, Some("test data"));
assert!(response.error.is_none());
}
#[test]
fn test_api_response_error() {
let response: ApiResponse<String> = ApiResponse::error("test error".to_string());
assert!(!response.success);
assert!(response.data.is_none());
assert_eq!(response.error, Some("test error".to_string()));
}
#[test]
fn test_mission_handle_creation() {
let mission = create_test_mission();
let handle = MissionHandle {
id: "test-mission-id".to_string(),
mission: mission.clone(),
status: MissionStatus::Running,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
};
assert_eq!(handle.id, "test-mission-id");
assert_eq!(handle.mission.name, mission.name);
assert!(matches!(handle.status, MissionStatus::Running));
}
#[test]
fn test_create_app_state() {
let config = ServerConfig::default();
let app_state = create_app_state(config.clone());
assert_eq!(app_state.config.host, config.host);
assert_eq!(app_state.config.port, config.port);
assert!(!app_state.active_missions.try_read().unwrap().len() > 0);
}
#[tokio::test]
async fn test_health_check_endpoint() {
let server = create_test_server();
let response = server.get("/health").await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<String>>();
assert!(json.success);
assert_eq!(json.data, Some("Server is healthy".to_string()));
assert!(json.error.is_none());
}
#[tokio::test]
async fn test_server_info_endpoint() {
let server = create_test_server();
let response = server.get("/info").await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<serde_json::Value>>();
assert!(json.success);
assert!(json.data.is_some());
let info = json.data.unwrap();
assert_eq!(info["name"], "RustChain API Server");
assert_eq!(info["version"], env!("CARGO_PKG_VERSION"));
assert_eq!(info["status"], "operational");
}
#[tokio::test]
async fn test_list_missions_empty() {
let server = create_test_server();
let response = server.get("/missions").await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<Vec<MissionHandle>>>();
assert!(json.success);
assert!(json.data.is_some());
assert_eq!(json.data.unwrap().len(), 0);
}
#[tokio::test]
async fn test_list_missions_pagination() {
let server = create_test_server();
let response = server.get("/missions?offset=0&limit=5").await;
if response.status_code() == StatusCode::OK {
let json = response.json::<ApiResponse<Vec<MissionHandle>>>();
assert!(json.success);
assert!(json.data.is_some());
} else {
assert!(response.status_code() == StatusCode::NOT_FOUND);
}
}
#[tokio::test]
async fn test_create_mission_success() {
let server = create_test_server();
let request = CreateMissionRequest {
name: "Test Mission".to_string(),
description: Some("Test description".to_string()),
mission_yaml: create_test_mission_yaml(),
};
let response = server.post("/missions").json(&request).await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<MissionHandle>>();
assert!(json.success);
assert!(json.data.is_some());
let handle = json.data.unwrap();
assert_eq!(handle.mission.name, "Test Mission");
assert!(matches!(handle.status, MissionStatus::Running));
}
#[tokio::test]
async fn test_create_mission_invalid_yaml() {
let server = create_test_server();
let request = CreateMissionRequest {
name: "Invalid Mission".to_string(),
description: None,
mission_yaml: "invalid yaml content".to_string(),
};
let response = server.post("/missions").json(&request).await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<MissionHandle>>();
assert!(!json.success);
assert!(json.error.is_some());
assert!(json.error.unwrap().contains("Invalid mission YAML"));
}
#[tokio::test]
async fn test_get_mission_not_found() {
let server = create_test_server();
let response = server.get("/missions/nonexistent-id").await;
response.assert_status(StatusCode::NOT_FOUND);
let json = response.json::<ApiResponse<MissionHandle>>();
assert!(!json.success);
assert!(json.error.is_some());
assert_eq!(json.error.unwrap(), "Mission not found");
}
#[tokio::test]
async fn test_mission_crud_workflow() {
let server = create_test_server();
let create_request = CreateMissionRequest {
name: "CRUD Test Mission".to_string(),
description: Some("Testing CRUD operations".to_string()),
mission_yaml: create_test_mission_yaml(),
};
let create_response = server.post("/missions").json(&create_request).await;
create_response.assert_status(StatusCode::OK);
let create_json = create_response.json::<ApiResponse<MissionHandle>>();
assert!(create_json.success);
let mission_id = create_json.data.unwrap().id;
let get_response = server.get(&format!("/missions/{}", mission_id)).await;
get_response.assert_status(StatusCode::OK);
let get_json = get_response.json::<ApiResponse<MissionHandle>>();
assert!(get_json.success);
assert_eq!(get_json.data.unwrap().mission.name, "Test Mission");
let update_request = CreateMissionRequest {
name: "Updated CRUD Test Mission".to_string(),
description: Some("Updated description".to_string()),
mission_yaml: create_test_mission_yaml(),
};
let update_response = server.put(&format!("/missions/{}", mission_id)).json(&update_request).await;
update_response.assert_status(StatusCode::OK);
let update_json = update_response.json::<ApiResponse<MissionHandle>>();
assert!(update_json.success);
assert_eq!(update_json.data.unwrap().mission.name, "Test Mission");
let status_response = server.get(&format!("/missions/{}/status", mission_id)).await;
status_response.assert_status(StatusCode::OK);
let status_json = status_response.json::<ApiResponse<serde_json::Value>>();
assert!(status_json.success);
let status_data = status_json.data.unwrap();
assert_eq!(status_data["mission_id"], mission_id);
let delete_response = server.delete(&format!("/missions/{}", mission_id)).await;
delete_response.assert_status(StatusCode::OK);
let delete_json = delete_response.json::<ApiResponse<String>>();
assert!(delete_json.success);
let get_deleted_response = server.get(&format!("/missions/{}", mission_id)).await;
get_deleted_response.assert_status(StatusCode::NOT_FOUND);
}
#[tokio::test]
async fn test_update_mission_not_found() {
let server = create_test_server();
let update_request = CreateMissionRequest {
name: "Nonexistent Mission".to_string(),
description: None,
mission_yaml: create_test_mission_yaml(),
};
let response = server.put("/missions/nonexistent-id").json(&update_request).await;
response.assert_status(StatusCode::NOT_FOUND);
let json = response.json::<ApiResponse<MissionHandle>>();
assert!(!json.success);
assert_eq!(json.error.unwrap(), "Mission not found");
}
#[tokio::test]
async fn test_update_mission_invalid_yaml() {
let server = create_test_server();
let create_request = CreateMissionRequest {
name: "Test Mission".to_string(),
description: None,
mission_yaml: create_test_mission_yaml(),
};
let create_response = server.post("/missions").json(&create_request).await;
let mission_id = create_response.json::<ApiResponse<MissionHandle>>().data.unwrap().id;
let update_request = CreateMissionRequest {
name: "Updated Mission".to_string(),
description: None,
mission_yaml: "invalid yaml".to_string(),
};
let response = server.put(&format!("/missions/{}", mission_id)).json(&update_request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let json = response.json::<ApiResponse<MissionHandle>>();
assert!(!json.success);
assert!(json.error.unwrap().contains("Invalid mission YAML"));
}
#[tokio::test]
async fn test_delete_mission_not_found() {
let server = create_test_server();
let response = server.delete("/missions/nonexistent-id").await;
response.assert_status(StatusCode::NOT_FOUND);
let json = response.json::<ApiResponse<String>>();
assert!(!json.success);
assert_eq!(json.error.unwrap(), "Mission not found");
}
#[tokio::test]
async fn test_execute_mission_dry_run() {
let server = create_test_server();
let create_request = CreateMissionRequest {
name: "Dry Run Test".to_string(),
description: None,
mission_yaml: create_test_mission_yaml(),
};
let create_response = server.post("/missions").json(&create_request).await;
let mission_id = create_response.json::<ApiResponse<MissionHandle>>().data.unwrap().id;
let execute_request = ExecuteMissionRequest {
mission_id: mission_id.clone(),
dry_run: Some(true),
skip_safety: Some(false),
};
let response = server.post(&format!("/missions/{}/execute", mission_id)).json(&execute_request).await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<MissionResult>>();
assert!(json.success);
let result = json.data.unwrap();
assert!(matches!(result.status, MissionStatus::Completed));
}
#[tokio::test]
async fn test_execute_mission_not_found() {
let server = create_test_server();
let execute_request = ExecuteMissionRequest {
mission_id: "nonexistent".to_string(),
dry_run: Some(false),
skip_safety: Some(false),
};
let response = server.post("/missions/nonexistent/execute").json(&execute_request).await;
response.assert_status(StatusCode::NOT_FOUND);
let json = response.json::<ApiResponse<MissionResult>>();
assert!(!json.success);
assert_eq!(json.error.unwrap(), "Mission not found");
}
#[tokio::test]
async fn test_get_mission_status_not_found() {
let server = create_test_server();
let response = server.get("/missions/nonexistent/status").await;
response.assert_status(StatusCode::NOT_FOUND);
let json = response.json::<ApiResponse<serde_json::Value>>();
assert!(!json.success);
assert_eq!(json.error.unwrap(), "Mission not found");
}
#[tokio::test]
async fn test_safety_check_valid_mission() {
let server = create_test_server();
let request = SafetyCheckRequest {
mission_yaml: create_test_mission_yaml(),
};
let response = server.post("/safety/check").json(&request).await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<crate::safety::ValidationResult>>();
assert!(json.success);
assert!(json.data.is_some());
}
#[tokio::test]
async fn test_safety_check_invalid_yaml() {
let server = create_test_server();
let request = SafetyCheckRequest {
mission_yaml: "invalid yaml content".to_string(),
};
let response = server.post("/safety/check").json(&request).await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<crate::safety::ValidationResult>>();
assert!(!json.success);
assert!(json.error.unwrap().contains("Invalid mission YAML"));
}
#[tokio::test]
async fn test_validate_mission_not_found() {
let server = create_test_server();
let response = server.post("/safety/validate/nonexistent").await;
response.assert_status(StatusCode::NOT_FOUND);
let json = response.json::<ApiResponse<crate::safety::ValidationResult>>();
assert!(!json.success);
assert_eq!(json.error.unwrap(), "Mission not found");
}
#[tokio::test]
async fn test_validate_existing_mission() {
let server = create_test_server();
let create_request = CreateMissionRequest {
name: "Validation Test".to_string(),
description: None,
mission_yaml: create_test_mission_yaml(),
};
let create_response = server.post("/missions").json(&create_request).await;
let mission_id = create_response.json::<ApiResponse<MissionHandle>>().data.unwrap().id;
let response = server.post(&format!("/safety/validate/{}", mission_id)).await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<crate::safety::ValidationResult>>();
assert!(json.success);
assert!(json.data.is_some());
}
#[cfg(feature = "llm")]
mod llm_tests {
use super::*;
use crate::llm::ChatMessage;
#[tokio::test]
async fn test_llm_chat_endpoint() {
let server = create_test_server();
let request = LLMChatRequest {
messages: vec![ChatMessage {
role: crate::llm::MessageRole::User,
content: "Hello".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}],
model: Some("gpt-3.5-turbo".to_string()),
temperature: Some(0.7),
max_tokens: Some(100),
};
let response = server.post("/llm/chat").json(&request).await;
let status = response.status_code();
assert!(status.as_u16() >= 200 && status.as_u16() < 500);
}
#[tokio::test]
async fn test_llm_models_endpoint() {
let server = create_test_server();
let response = server.get("/llm/models").await;
let status = response.status_code();
assert!(status.as_u16() >= 200 && status.as_u16() < 500);
}
}
#[cfg(not(feature = "llm"))]
mod llm_disabled_tests {
use super::*;
#[tokio::test]
async fn test_llm_chat_disabled() {
let server = create_test_server();
let response = server.post("/llm/chat").await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<()>>();
assert!(!json.success);
assert!(json.error.unwrap().contains("LLM feature not enabled"));
}
#[tokio::test]
async fn test_llm_models_disabled() {
let server = create_test_server();
let response = server.get("/llm/models").await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<()>>();
assert!(!json.success);
assert!(json.error.unwrap().contains("LLM feature not enabled"));
}
}
#[cfg(feature = "rag")]
mod rag_tests {
use super::*;
#[tokio::test]
async fn test_rag_search_endpoint() {
let server = create_test_server();
let request = RagSearchRequest {
query: "test search".to_string(),
limit: Some(10),
similarity_threshold: Some(0.8),
};
let response = server.post("/rag/search").json(&request).await;
let status = response.status_code();
assert!(status.as_u16() >= 200 && status.as_u16() < 500);
}
#[tokio::test]
async fn test_rag_add_document_endpoint() {
let server = create_test_server();
let request = RagDocumentRequest {
id: "test-doc".to_string(),
content: "Test document content".to_string(),
metadata: Some(HashMap::from([("author".to_string(), serde_json::json!("test"))])),
};
let response = server.post("/rag/documents").json(&request).await;
let status = response.status_code();
assert!(status.as_u16() >= 200 && status.as_u16() < 500);
}
#[tokio::test]
async fn test_rag_list_documents_endpoint() {
let server = create_test_server();
let response = server.get("/rag/documents?offset=0&limit=10").await;
let status = response.status_code();
assert!(status.as_u16() >= 200 && status.as_u16() < 500);
}
#[tokio::test]
async fn test_rag_delete_document_endpoint() {
let server = create_test_server();
let response = server.delete("/rag/documents/test-doc").await;
let status = response.status_code();
assert!(status.as_u16() >= 200 && status.as_u16() < 500);
}
}
#[cfg(not(feature = "rag"))]
mod rag_disabled_tests {
use super::*;
#[tokio::test]
async fn test_rag_search_disabled() {
let server = create_test_server();
let response = server.post("/rag/search").await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<()>>();
assert!(!json.success);
assert!(json.error.unwrap().contains("RAG feature not enabled"));
}
#[tokio::test]
async fn test_rag_add_document_disabled() {
let server = create_test_server();
let response = server.post("/rag/documents").await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<()>>();
assert!(!json.success);
assert!(json.error.unwrap().contains("RAG feature not enabled"));
}
#[tokio::test]
async fn test_rag_list_documents_disabled() {
let server = create_test_server();
let response = server.get("/rag/documents").await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<()>>();
assert!(!json.success);
assert!(json.error.unwrap().contains("RAG feature not enabled"));
}
#[tokio::test]
async fn test_rag_delete_document_disabled() {
let server = create_test_server();
let response = server.delete("/rag/documents/test-doc").await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<()>>();
assert!(!json.success);
assert!(json.error.unwrap().contains("RAG feature not enabled"));
}
}
#[cfg(feature = "sandbox")]
mod sandbox_tests {
use super::*;
#[tokio::test]
async fn test_sandbox_create_session() {
let server = create_test_server();
let response = server.post("/sandbox/sessions").await;
let status = response.status_code();
assert!(status.as_u16() >= 200 && status.as_u16() < 500);
}
#[tokio::test]
async fn test_sandbox_list_sessions() {
let server = create_test_server();
let response = server.get("/sandbox/sessions").await;
let status = response.status_code();
assert!(status.as_u16() >= 200 && status.as_u16() < 500);
}
#[tokio::test]
async fn test_sandbox_execute() {
let server = create_test_server();
let request = SandboxExecuteRequest {
command: "echo".to_string(),
args: vec!["hello".to_string()],
};
let response = server.post("/sandbox/sessions/test-session/execute").json(&request).await;
let status = response.status_code();
assert!(status.as_u16() >= 200 && status.as_u16() < 500);
}
#[tokio::test]
async fn test_sandbox_list_files() {
let server = create_test_server();
let response = server.get("/sandbox/sessions/test-session/files").await;
let status = response.status_code();
assert!(status.as_u16() >= 200 && status.as_u16() < 500);
}
#[tokio::test]
async fn test_sandbox_read_file() {
let server = create_test_server();
let response = server.get("/sandbox/sessions/test-session/files/test.txt").await;
let status = response.status_code();
assert!(status.as_u16() >= 200 && status.as_u16() < 500);
}
#[tokio::test]
async fn test_sandbox_write_file() {
let server = create_test_server();
let request = SandboxFileRequest {
path: "test.txt".to_string(),
content: Some("test content".to_string()),
};
let response = server.put("/sandbox/sessions/test-session/files/test.txt").json(&request).await;
let status = response.status_code();
assert!(status.as_u16() >= 200 && status.as_u16() < 500);
}
#[tokio::test]
async fn test_sandbox_cleanup_session() {
let server = create_test_server();
let response = server.delete("/sandbox/sessions/test-session").await;
let status = response.status_code();
assert!(status.as_u16() >= 200 && status.as_u16() < 500);
}
}
#[cfg(not(feature = "sandbox"))]
mod sandbox_disabled_tests {
use super::*;
#[tokio::test]
async fn test_sandbox_create_session_disabled() {
let server = create_test_server();
let response = server.post("/sandbox/sessions").await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<()>>();
assert!(!json.success);
assert!(json.error.unwrap().contains("Sandbox feature not enabled"));
}
#[tokio::test]
async fn test_sandbox_list_sessions_disabled() {
let server = create_test_server();
let response = server.get("/sandbox/sessions").await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<()>>();
assert!(!json.success);
assert!(json.error.unwrap().contains("Sandbox feature not enabled"));
}
#[tokio::test]
async fn test_sandbox_execute_disabled() {
let server = create_test_server();
let response = server.post("/sandbox/sessions/test/execute").await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<()>>();
assert!(!json.success);
assert!(json.error.unwrap().contains("Sandbox feature not enabled"));
}
}
#[tokio::test]
async fn test_concurrent_mission_operations() {
let server = create_test_server();
let mut mission_ids = vec![];
for i in 0..5 {
let request = CreateMissionRequest {
name: format!("Concurrent Mission {}", i),
description: None,
mission_yaml: create_test_mission_yaml(),
};
let response = server.post("/missions").json(&request).await;
response.assert_status(StatusCode::OK);
let json = response.json::<ApiResponse<MissionHandle>>();
assert!(json.success);
mission_ids.push(json.data.unwrap().id);
}
let list_response = server.get("/missions").await;
list_response.assert_status(StatusCode::OK);
let list_json = list_response.json::<ApiResponse<Vec<MissionHandle>>>();
assert_eq!(list_json.data.unwrap().len(), 5);
}
#[tokio::test]
async fn test_malformed_json_requests() {
let server = create_test_server();
let response = server.post("/missions")
.add_header(
axum::http::HeaderName::from_static("content-type"),
axum::http::HeaderValue::from_static("application/json")
)
.text("{invalid json}")
.await;
response.assert_status(StatusCode::UNSUPPORTED_MEDIA_TYPE);
}
#[tokio::test]
async fn test_missing_content_type() {
let server = create_test_server();
let request = CreateMissionRequest {
name: "Test".to_string(),
description: None,
mission_yaml: create_test_mission_yaml(),
};
let response = server.post("/missions")
.text(&serde_json::to_string(&request).unwrap())
.await;
let status = response.status_code();
assert!(status.as_u16() >= 200 && status.as_u16() < 500);
}
#[test]
fn test_request_response_serialization() {
let create_request = CreateMissionRequest {
name: "Test Mission".to_string(),
description: Some("Description".to_string()),
mission_yaml: "version: \"1.0\"\nname: \"test\"".to_string(),
};
let serialized = serde_json::to_string(&create_request).unwrap();
let deserialized: CreateMissionRequest = serde_json::from_str(&serialized).unwrap();
assert_eq!(create_request.name, deserialized.name);
assert_eq!(create_request.description, deserialized.description);
let execute_request = ExecuteMissionRequest {
mission_id: "test-id".to_string(),
dry_run: Some(true),
skip_safety: Some(false),
};
let serialized = serde_json::to_string(&execute_request).unwrap();
let deserialized: ExecuteMissionRequest = serde_json::from_str(&serialized).unwrap();
assert_eq!(execute_request.mission_id, deserialized.mission_id);
assert_eq!(execute_request.dry_run, deserialized.dry_run);
let safety_request = SafetyCheckRequest {
mission_yaml: "test yaml".to_string(),
};
let serialized = serde_json::to_string(&safety_request).unwrap();
let deserialized: SafetyCheckRequest = serde_json::from_str(&serialized).unwrap();
assert_eq!(safety_request.mission_yaml, deserialized.mission_yaml);
}
#[test]
fn test_pagination_query_struct() {
let query = PaginationQuery {
offset: Some(10),
limit: Some(20),
};
assert_eq!(query.offset, Some(10));
assert_eq!(query.limit, Some(20));
let query = PaginationQuery {
offset: None,
limit: None,
};
assert_eq!(query.offset, None);
assert_eq!(query.limit, None);
}
}