use axum::{
extract::{Path, Query, State},
http::{StatusCode, HeaderMap},
response::{Json, Response},
routing::{delete, get, post},
Router,
middleware,
};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::Arc;
use tower::ServiceBuilder;
use tower_http::{
cors::{Any, CorsLayer},
trace::TraceLayer,
};
use tracing::{info, warn, error};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use ai_agent::service::{AiAgentService, ServiceConfig, TaskRequest, BatchTaskRequest, TaskResponse, BatchTaskResponse, ServiceStatus, MetricsSnapshot, ServiceError};
use ai_agent::config::AgentConfig;
/// Shared state handed to every handler through axum's `State` extractor.
#[derive(Clone)]
struct AppState {
    /// The task-execution service; wrapped in an `Arc` so cloning the state is cheap.
    service: Arc<AiAgentService>,
}
/// Query-string options for task endpoints (e.g. `?verbose=true`).
///
/// NOTE(review): no handler in this file extracts this type yet.
#[derive(Deserialize)]
struct TaskQuery {
    /// Falls back to `false` when the parameter is absent.
    #[serde(default)]
    verbose: bool,
}
/// `limit`/`offset` pagination parameters read from the query string.
///
/// NOTE(review): no handler in this file extracts this type yet.
#[derive(Deserialize)]
struct PaginationQuery {
    /// Maximum number of items to return; `default_limit()` when omitted.
    #[serde(default = "default_limit")]
    limit: usize,
    /// Number of items to skip; `default_offset()` when omitted.
    #[serde(default = "default_offset")]
    offset: usize,
}
/// Serde default for `PaginationQuery::limit`: the page size used when the client omits it.
fn default_limit() -> usize {
    const DEFAULT_LIMIT: usize = 50;
    DEFAULT_LIMIT
}
/// Serde default for `PaginationQuery::offset`: start from the first item.
fn default_offset() -> usize {
    usize::default()
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::new(
std::env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string()),
))
.with(tracing_subscriber::fmt::layer())
.init();
info!("Starting AI Agent Service HTTP Server");
let service_config = load_service_config().await?;
let agent_config = load_agent_config().await?;
let service = Arc::new(AiAgentService::new(service_config.clone(), agent_config).await?);
let app = create_router(service.clone(), service_config.clone());
let bind_addr = std::env::var("BIND_ADDRESS")
.unwrap_or_else(|_| "0.0.0.0:8080".to_string());
let addr: SocketAddr = bind_addr.parse()
.map_err(|_| anyhow::anyhow!("Invalid bind address: {}", bind_addr))?;
info!("Server listening on {}", addr);
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;
Ok(())
}
fn create_router(service: Arc<AiAgentService>, config: ServiceConfig) -> Router {
let state = AppState { service };
Router::new()
.route("/health", get(health_check))
.route("/healthz", get(health_check))
.route("/api/v1/status", get(get_service_status))
.route("/api/v1/metrics", get(get_metrics))
.route("/api/v1/tools", get(list_tools))
.route("/api/v1/tasks", post(execute_task))
.route("/api/v1/tasks/batch", post(execute_batch))
.route("/api/v1/tasks/:task_id", get(get_task_status))
.route("/api/v1/tasks/:task_id", delete(cancel_task))
.route("/tasks", post(execute_task_legacy))
.route("/config", get(get_service_status))
.nest_service("/metrics", axum::routing::get(prometheus_metrics_handler))
.layer(
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(CorsLayer::permissive())
.layer(middleware::from_fn(request_id_middleware))
)
.with_state(state)
}
async fn health_check() -> Result<Json<serde_json::Value>, StatusCode> {
Ok(Json(serde_json::json!({
"status": "healthy",
"timestamp": chrono::Utc::now(),
"service": "ai-agent-service",
"version": env!("CARGO_PKG_VERSION")
})))
}
async fn get_service_status(State(state): State<AppState>) -> Result<Json<ServiceStatus>, ServiceError> {
state.service.get_service_status().await.map(Json)
}
async fn get_metrics(State(state): State<AppState>) -> Result<Json<MetricsSnapshot>, ServiceError> {
state.service.get_metrics().await.map(Json)
}
/// `GET /api/v1/tools`: a static catalogue of the agent's tools.
///
/// Response shape: `{"tools": [{"name", "description", "parameters": [..]}, ..]}`.
async fn list_tools() -> Result<Json<serde_json::Value>, ServiceError> {
    // (name, description, parameter names)
    const CATALOGUE: [(&str, &str, &[&str]); 4] = [
        ("read_file", "Read the contents of a file", &["path"]),
        ("write_file", "Write content to a file", &["path", "content"]),
        ("run_command", "Execute a shell command", &["command", "working_dir"]),
        ("list_files", "List files and directories", &["path"]),
    ];

    let tools: Vec<serde_json::Value> = CATALOGUE
        .iter()
        .map(|&(name, description, parameters)| {
            serde_json::json!({
                "name": name,
                "description": description,
                "parameters": parameters
            })
        })
        .collect();

    Ok(Json(serde_json::json!({ "tools": tools })))
}
async fn execute_task(
State(state): State<AppState>,
Json(request): Json<TaskRequest>,
) -> Result<Json<TaskResponse>, ServiceError> {
info!("Executing task: {}", request.task);
state.service.execute_task(request).await.map(Json)
}
async fn execute_task_legacy(
State(state): State<AppState>,
Json(request): serde_json::Value,
) -> Result<Json<serde_json::Value>, ServiceError> {
let task = request.get("task")
.and_then(|v| v.as_str())
.ok_or_else(|| ServiceError {
code: "INVALID_REQUEST".to_string(),
message: "Missing 'task' field in request".to_string(),
details: None,
stack_trace: None,
timestamp: chrono::Utc::now(),
})?;
let task_request = TaskRequest {
task: task.to_string(),
task_id: request.get("task_id").and_then(|v| v.as_str()).map(|s| s.to_string()),
context: None,
priority: None,
metadata: None,
};
let response = state.service.execute_task(task_request).await?;
Ok(Json(serde_json::json!({
"success": response.result.as_ref().map(|r| r.success).unwrap_or(false),
"summary": response.result.as_ref().map(|r| r.summary.clone()).unwrap_or_default(),
"details": response.result.as_ref().and_then(|r| r.details.clone()),
"task_id": response.task_id,
"status": response.status,
"execution_time": response.result.as_ref().map(|r| r.execution_time).unwrap_or(0)
})))
}
async fn execute_batch(
State(state): State<AppState>,
Json(request): Json<BatchTaskRequest>,
) -> Result<Json<BatchTaskResponse>, ServiceError> {
info!("Executing batch with {} tasks", request.tasks.len());
state.service.execute_batch(request).await.map(Json)
}
async fn get_task_status(
State(state): State<AppState>,
Path(task_id): Path<String>,
) -> Result<Json<TaskResponse>, ServiceError> {
state.service.get_task_status(&task_id).await.map(Json)
}
async fn cancel_task(
State(state): State<AppState>,
Path(task_id): Path<String>,
) -> Result<StatusCode, ServiceError> {
state.service.cancel_task(&task_id).await?;
Ok(StatusCode::NO_CONTENT)
}
/// Tags every request/response pair with an `x-request-id` header.
///
/// A non-empty inbound `x-request-id` of at most `MAX_INBOUND_ID_LEN` bytes is
/// reused, so ids can be propagated across services. Otherwise a random UUID v4 is
/// generated. The id is written onto the request before downstream handlers run,
/// and echoed on the response.
///
/// Uses the non-generic `Request`/`Next` middleware signature: `Next` takes no
/// body type parameter in axum 0.7+.
async fn request_id_middleware(
    mut request: axum::extract::Request,
    next: axum::middleware::Next,
) -> axum::response::Response {
    const HEADER: &str = "x-request-id";
    // Bound what we accept from clients so a huge header is not echoed back verbatim.
    const MAX_INBOUND_ID_LEN: usize = 128;

    let request_id = match request.headers().get(HEADER) {
        Some(existing) if !existing.is_empty() && existing.len() <= MAX_INBOUND_ID_LEN => {
            existing.clone()
        }
        _ => axum::http::HeaderValue::from_str(&uuid::Uuid::new_v4().to_string())
            .expect("a hyphenated UUID is always a valid header value"),
    };

    request.headers_mut().insert(HEADER, request_id.clone());
    let mut response = next.run(request).await;
    response.headers_mut().insert(HEADER, request_id);
    response
}
/// `GET /metrics`: metrics in the Prometheus text exposition format.
///
/// NOTE(review): every sample is a hard-coded `0` placeholder. Nothing here reads
/// live service state; wire this to real counters before relying on it.
async fn prometheus_metrics_handler() -> Result<String, StatusCode> {
    // The exposition format requires every line, including the last one, to be
    // terminated by a line feed; the previous body omitted the final `\n`.
    const BODY: &str = concat!(
        "# HELP ai_agent_requests_total Total number of requests\n",
        "# TYPE ai_agent_requests_total counter\n",
        "ai_agent_requests_total 0\n",
        "# HELP ai_agent_request_duration_seconds Request duration\n",
        "# TYPE ai_agent_request_duration_seconds histogram\n",
        "ai_agent_request_duration_seconds_bucket{le=\"1.0\"} 0\n",
        "ai_agent_request_duration_seconds_bucket{le=\"+Inf\"} 0\n",
        "ai_agent_request_duration_seconds_count 0\n",
        "ai_agent_request_duration_seconds_sum 0\n",
        "# HELP ai_agent_active_tasks Number of active tasks\n",
        "# TYPE ai_agent_active_tasks gauge\n",
        "ai_agent_active_tasks 0\n",
        "# HELP ai_agent_completed_tasks_total Total number of completed tasks\n",
        "# TYPE ai_agent_completed_tasks_total counter\n",
        "ai_agent_completed_tasks_total 0\n",
        "# HELP ai_agent_failed_tasks_total Total number of failed tasks\n",
        "# TYPE ai_agent_failed_tasks_total counter\n",
        "ai_agent_failed_tasks_total 0\n",
    );
    Ok(BODY.to_string())
}
async fn load_service_config() -> Result<ServiceConfig, anyhow::Error> {
let max_concurrent_tasks = std::env::var("AI_AGENT_MAX_CONCURRENT_TASKS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(10);
let default_task_timeout = std::env::var("AI_AGENT_DEFAULT_TASK_TIMEOUT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(300);
let enable_metrics = std::env::var("AI_AGENT_ENABLE_METRICS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(true);
Ok(ServiceConfig {
max_concurrent_tasks,
default_task_timeout,
max_task_timeout: 3600,
enable_metrics,
log_level: std::env::var("AI_AGENT_LOG_LEVEL").unwrap_or_else(|_| "info".to_string()),
cors: ai_agent::service::CorsConfig {
allowed_origins: vec!["*".to_string()],
allowed_methods: vec!["GET".to_string(), "POST".to_string(), "DELETE".to_string()],
allowed_headers: vec!["*".to_string()],
allow_credentials: false,
},
rate_limiting: None,
})
}
async fn load_agent_config() -> Result<AgentConfig, anyhow::Error> {
let config_path = std::env::var("AI_AGENT_CONFIG_FILE")
.unwrap_or_else(|_| "config.toml".to_string());
if std::path::Path::new(&config_path).exists() {
AgentConfig::load_with_fallback(&config_path)
.map_err(|e| anyhow::anyhow!("Failed to load config: {}", e))
} else {
let provider = std::env::var("AI_AGENT_MODEL_PROVIDER")
.unwrap_or_else(|_| "zhipu".to_string());
let model_name = std::env::var("AI_AGENT_MODEL_NAME")
.unwrap_or_else(|_| "glm-4".to_string());
let api_key = std::env::var("AI_AGENT_API_KEY")
.ok_or_else(|| anyhow::anyhow!("AI_AGENT_API_KEY environment variable is required"))?;
let provider_config = match provider.as_str() {
"zhipu" => ai_agent::config::ModelProvider::Zhipu,
"openai" => ai_agent::config::ModelProvider::OpenAI,
"anthropic" => ai_agent::config::ModelProvider::Anthropic,
"local" => ai_agent::config::ModelProvider::Local(
std::env::var("AI_AGENT_LOCAL_ENDPOINT").unwrap_or_else(|_| "http://localhost:8081".to_string())
),
_ => return Err(anyhow::anyhow!("Unsupported model provider: {}", provider)),
};
Ok(AgentConfig {
model: ai_agent::config::ModelConfig {
provider: provider_config,
model_name,
api_key: Some(api_key),
endpoint: std::env::var("AI_AGENT_MODEL_ENDPOINT").ok(),
max_tokens: std::env::var("AI_AGENT_MAX_TOKENS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(4000),
temperature: std::env::var("AI_AGENT_TEMPERATURE")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(0.7),
},
execution: ai_agent::config::ExecutionConfig {
max_steps: std::env::var("AI_AGENT_MAX_STEPS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(10),
max_retries: std::env::var("AI_AGENT_MAX_RETRIES")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(3),
retry_delay_seconds: std::env::var("AI_AGENT_RETRY_DELAY")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(1),
timeout_seconds: std::env::var("AI_AGENT_TIMEOUT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(300),
},
tools: ai_agent::config::ToolsConfig {
enable_file_operations: std::env::var("AI_AGENT_ENABLE_FILE_OPS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(true),
enable_command_execution: std::env::var("AI_AGENT_ENABLE_COMMANDS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(true),
working_directory: std::env::var("AI_AGENT_WORK_DIR").ok(),
allowed_paths: std::env::var("AI_AGENT_ALLOWED_PATHS")
.ok()
.map(|s| s.split(',').map(|p| p.trim().to_string()).collect()),
forbidden_commands: std::env::var("AI_AGENT_FORBIDDEN_COMMANDS")
.ok()
.map(|s| s.split(',').map(|c| c.trim().to_string()).collect()),
},
})
}
}