1use crate::{
4 diff::ComparisonResult,
5 har_export::export_to_har,
6 integration_testing::{IntegrationTestGenerator, IntegrationWorkflow, WorkflowSetup},
7 models::RecordedExchange,
8 query::{execute_query, QueryFilter, QueryResult},
9 recorder::Recorder,
10 replay::ReplayEngine,
11 test_generation::{LlmConfig, TestFormat, TestGenerationConfig, TestGenerator},
12};
13use axum::{
14 extract::{Path, Query, State},
15 http::StatusCode,
16 response::{IntoResponse, Json, Response},
17 routing::{delete, get, post},
18 Router,
19};
20use serde::{Deserialize, Serialize};
21use std::sync::Arc;
22use tracing::{debug, error};
23
24#[derive(Clone)]
26pub struct ApiState {
27 pub recorder: Arc<Recorder>,
28}
29
30pub fn create_api_router(recorder: Arc<Recorder>) -> Router {
32 let state = ApiState { recorder };
33
34 Router::new()
35 .route("/api/recorder/requests", get(list_requests))
37 .route("/api/recorder/requests/:id", get(get_request))
38 .route("/api/recorder/requests/:id/response", get(get_response))
39 .route("/api/recorder/search", post(search_requests))
40
41 .route("/api/recorder/export/har", get(export_har))
43
44 .route("/api/recorder/status", get(get_status))
46 .route("/api/recorder/enable", post(enable_recording))
47 .route("/api/recorder/disable", post(disable_recording))
48 .route("/api/recorder/clear", delete(clear_recordings))
49
50 .route("/api/recorder/replay/:id", post(replay_request))
52 .route("/api/recorder/compare/:id", post(compare_responses))
53
54 .route("/api/recorder/stats", get(get_statistics))
56
57 .route("/api/recorder/generate-tests", post(generate_tests))
59
60 .route("/api/recorder/workflows", post(create_workflow))
62 .route("/api/recorder/workflows/:id", get(get_workflow))
63 .route("/api/recorder/workflows/:id/generate", post(generate_integration_test))
64
65 .with_state(state)
66}
67
68async fn list_requests(
70 State(state): State<ApiState>,
71 Query(params): Query<ListParams>,
72) -> Result<Json<QueryResult>, ApiError> {
73 let limit = params.limit.unwrap_or(100);
74 let offset = params.offset.unwrap_or(0);
75
76 let filter = QueryFilter {
77 limit: Some(limit),
78 offset: Some(offset),
79 ..Default::default()
80 };
81
82 let result = execute_query(state.recorder.database(), filter).await?;
83 Ok(Json(result))
84}
85
86async fn get_request(
88 State(state): State<ApiState>,
89 Path(id): Path<String>,
90) -> Result<Json<RecordedExchange>, ApiError> {
91 let exchange = state
92 .recorder
93 .database()
94 .get_exchange(&id)
95 .await?
96 .ok_or_else(|| ApiError::NotFound(format!("Request {} not found", id)))?;
97
98 Ok(Json(exchange))
99}
100
101async fn get_response(
103 State(state): State<ApiState>,
104 Path(id): Path<String>,
105) -> Result<Json<serde_json::Value>, ApiError> {
106 let response = state
107 .recorder
108 .database()
109 .get_response(&id)
110 .await?
111 .ok_or_else(|| ApiError::NotFound(format!("Response for request {} not found", id)))?;
112
113 Ok(Json(serde_json::json!({
114 "request_id": response.request_id,
115 "status_code": response.status_code,
116 "headers": serde_json::from_str::<serde_json::Value>(&response.headers)?,
117 "body": response.body,
118 "body_encoding": response.body_encoding,
119 "size_bytes": response.size_bytes,
120 "timestamp": response.timestamp,
121 })))
122}
123
124async fn search_requests(
126 State(state): State<ApiState>,
127 Json(filter): Json<QueryFilter>,
128) -> Result<Json<QueryResult>, ApiError> {
129 let result = execute_query(state.recorder.database(), filter).await?;
130 Ok(Json(result))
131}
132
133async fn export_har(
135 State(state): State<ApiState>,
136 Query(params): Query<ExportParams>,
137) -> Result<Response, ApiError> {
138 let limit = params.limit.unwrap_or(1000);
139
140 let filter = QueryFilter {
141 limit: Some(limit),
142 protocol: Some(crate::models::Protocol::Http), ..Default::default()
144 };
145
146 let result = execute_query(state.recorder.database(), filter).await?;
147 let har = export_to_har(&result.exchanges)?;
148 let har_json = serde_json::to_string_pretty(&har)?;
149
150 Ok((StatusCode::OK, [("content-type", "application/json")], har_json).into_response())
151}
152
153async fn get_status(State(state): State<ApiState>) -> Json<StatusResponse> {
155 let enabled = state.recorder.is_enabled().await;
156 Json(StatusResponse { enabled })
157}
158
159async fn enable_recording(State(state): State<ApiState>) -> Json<StatusResponse> {
161 state.recorder.enable().await;
162 debug!("Recording enabled via API");
163 Json(StatusResponse { enabled: true })
164}
165
166async fn disable_recording(State(state): State<ApiState>) -> Json<StatusResponse> {
168 state.recorder.disable().await;
169 debug!("Recording disabled via API");
170 Json(StatusResponse { enabled: false })
171}
172
173async fn clear_recordings(State(state): State<ApiState>) -> Result<Json<ClearResponse>, ApiError> {
175 state.recorder.database().clear_all().await?;
176 debug!("All recordings cleared via API");
177 Ok(Json(ClearResponse {
178 message: "All recordings cleared".to_string(),
179 }))
180}
181
182async fn replay_request(
184 State(state): State<ApiState>,
185 Path(id): Path<String>,
186) -> Result<Json<serde_json::Value>, ApiError> {
187 let engine = ReplayEngine::new((**state.recorder.database()).clone());
188 let result = engine.replay_request(&id).await?;
189
190 Ok(Json(serde_json::json!({
191 "request_id": result.request_id,
192 "success": result.success,
193 "message": result.message,
194 "original_status": result.original_status,
195 "replay_status": result.replay_status,
196 })))
197}
198
199async fn compare_responses(
201 State(state): State<ApiState>,
202 Path(id): Path<String>,
203 Json(payload): Json<CompareRequest>,
204) -> Result<Json<ComparisonResult>, ApiError> {
205 let engine = ReplayEngine::new((**state.recorder.database()).clone());
206
207 let result = engine
208 .compare_responses(&id, payload.body.as_bytes(), payload.status_code, &payload.headers)
209 .await?;
210
211 Ok(Json(result))
212}
213
214async fn get_statistics(
216 State(state): State<ApiState>,
217) -> Result<Json<StatisticsResponse>, ApiError> {
218 let db = state.recorder.database();
219 let stats = db.get_statistics().await?;
220
221 Ok(Json(StatisticsResponse {
222 total_requests: stats.total_requests,
223 by_protocol: stats.by_protocol,
224 by_status_code: stats.by_status_code,
225 avg_duration_ms: stats.avg_duration_ms,
226 }))
227}
228
229#[derive(Debug, Deserialize)]
232struct ListParams {
233 limit: Option<i32>,
234 offset: Option<i32>,
235}
236
237#[derive(Debug, Deserialize)]
238struct ExportParams {
239 limit: Option<i32>,
240}
241
242#[derive(Debug, Deserialize)]
243struct CompareRequest {
244 status_code: i32,
245 headers: std::collections::HashMap<String, String>,
246 body: String,
247}
248
249#[derive(Debug, Serialize)]
250struct StatusResponse {
251 enabled: bool,
252}
253
254#[derive(Debug, Serialize)]
255struct ClearResponse {
256 message: String,
257}
258
259#[derive(Debug, Serialize)]
260struct StatisticsResponse {
261 total_requests: i64,
262 by_protocol: std::collections::HashMap<String, i64>,
263 by_status_code: std::collections::HashMap<i32, i64>,
264 avg_duration_ms: Option<f64>,
265}
266
267#[derive(Debug)]
270enum ApiError {
271 Database(sqlx::Error),
272 Serialization(serde_json::Error),
273 NotFound(String),
274 InvalidInput(String),
275 Recorder(crate::RecorderError),
276}
277
278impl From<sqlx::Error> for ApiError {
279 fn from(err: sqlx::Error) -> Self {
280 ApiError::Database(err)
281 }
282}
283
284impl From<serde_json::Error> for ApiError {
285 fn from(err: serde_json::Error) -> Self {
286 ApiError::Serialization(err)
287 }
288}
289
290impl From<crate::RecorderError> for ApiError {
291 fn from(err: crate::RecorderError) -> Self {
292 ApiError::Recorder(err)
293 }
294}
295
296impl IntoResponse for ApiError {
297 fn into_response(self) -> Response {
298 let (status, message) = match self {
299 ApiError::Database(e) => {
300 error!("Database error: {}", e);
301 (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {}", e))
302 }
303 ApiError::Serialization(e) => {
304 error!("Serialization error: {}", e);
305 (StatusCode::INTERNAL_SERVER_ERROR, format!("Serialization error: {}", e))
306 }
307 ApiError::NotFound(msg) => (StatusCode::NOT_FOUND, msg),
308 ApiError::InvalidInput(msg) => (StatusCode::BAD_REQUEST, msg),
309 ApiError::Recorder(e) => {
310 error!("Recorder error: {}", e);
311 (StatusCode::INTERNAL_SERVER_ERROR, format!("Recorder error: {}", e))
312 }
313 };
314
315 (status, Json(serde_json::json!({ "error": message }))).into_response()
316 }
317}
318
319#[derive(Debug, Deserialize)]
321pub struct GenerateTestsRequest {
322 #[serde(default = "default_format")]
324 pub format: String,
325
326 #[serde(flatten)]
328 pub filter: QueryFilter,
329
330 #[serde(default = "default_suite_name")]
332 pub suite_name: String,
333
334 pub base_url: Option<String>,
336
337 #[serde(default)]
339 pub ai_descriptions: bool,
340
341 pub llm_config: Option<LlmConfigRequest>,
343
344 #[serde(default = "default_true")]
346 pub include_assertions: bool,
347
348 #[serde(default = "default_true")]
350 pub validate_body: bool,
351
352 #[serde(default = "default_true")]
354 pub validate_status: bool,
355
356 #[serde(default)]
358 pub validate_headers: bool,
359
360 #[serde(default)]
362 pub validate_timing: bool,
363
364 pub max_duration_ms: Option<u64>,
366}
367
/// Serde default for `GenerateTestsRequest::format`.
fn default_format() -> String {
    String::from("rust_reqwest")
}

/// Serde default for `GenerateTestsRequest::suite_name`.
fn default_suite_name() -> String {
    String::from("generated_tests")
}

/// Serde default for boolean options that should be on unless disabled.
fn default_true() -> bool {
    !false
}
379
380#[derive(Debug, Deserialize)]
382pub struct LlmConfigRequest {
383 pub provider: String,
385 pub api_endpoint: String,
387 pub api_key: Option<String>,
389 pub model: String,
391 #[serde(default = "default_temperature")]
393 pub temperature: f64,
394}
395
/// Serde default for `LlmConfigRequest::temperature`.
fn default_temperature() -> f64 {
    let temperature: f64 = 0.3;
    temperature
}
399
400async fn generate_tests(
402 State(state): State<ApiState>,
403 Json(request): Json<GenerateTestsRequest>,
404) -> Result<Json<serde_json::Value>, ApiError> {
405 debug!("Generating tests with format: {}", request.format);
406
407 let test_format = match request.format.as_str() {
409 "rust_reqwest" => TestFormat::RustReqwest,
410 "http_file" => TestFormat::HttpFile,
411 "curl" => TestFormat::Curl,
412 "postman" => TestFormat::Postman,
413 "k6" => TestFormat::K6,
414 "python_pytest" => TestFormat::PythonPytest,
415 "javascript_jest" => TestFormat::JavaScriptJest,
416 "go_test" => TestFormat::GoTest,
417 _ => {
418 return Err(ApiError::NotFound(format!(
419 "Invalid test format: {}. Supported: rust_reqwest, http_file, curl, postman, k6, python_pytest, javascript_jest, go_test",
420 request.format
421 )));
422 }
423 };
424
425 let llm_config = request.llm_config.map(|cfg| LlmConfig {
427 provider: cfg.provider,
428 api_endpoint: cfg.api_endpoint,
429 api_key: cfg.api_key,
430 model: cfg.model,
431 temperature: cfg.temperature,
432 });
433
434 let config = TestGenerationConfig {
436 format: test_format,
437 include_assertions: request.include_assertions,
438 validate_body: request.validate_body,
439 validate_status: request.validate_status,
440 validate_headers: request.validate_headers,
441 validate_timing: request.validate_timing,
442 max_duration_ms: request.max_duration_ms,
443 suite_name: request.suite_name,
444 base_url: request.base_url,
445 ai_descriptions: request.ai_descriptions,
446 llm_config,
447 group_by_endpoint: true,
448 include_setup_teardown: true,
449 generate_fixtures: false,
450 suggest_edge_cases: false,
451 analyze_test_gaps: false,
452 deduplicate_tests: false,
453 optimize_test_order: false,
454 };
455
456 let generator = TestGenerator::from_arc(state.recorder.database().clone(), config);
458
459 let result = generator.generate_from_filter(request.filter).await?;
461
462 Ok(Json(serde_json::json!({
464 "success": true,
465 "metadata": {
466 "suite_name": result.metadata.name,
467 "test_count": result.metadata.test_count,
468 "endpoint_count": result.metadata.endpoint_count,
469 "protocols": result.metadata.protocols,
470 "format": result.metadata.format,
471 "generated_at": result.metadata.generated_at,
472 },
473 "tests": result.tests.iter().map(|t| serde_json::json!({
474 "name": t.name,
475 "description": t.description,
476 "endpoint": t.endpoint,
477 "method": t.method,
478 })).collect::<Vec<_>>(),
479 "test_file": result.test_file,
480 })))
481}
482
483#[derive(Debug, Deserialize)]
487struct CreateWorkflowRequest {
488 workflow: IntegrationWorkflow,
489}
490
491async fn create_workflow(
493 State(_state): State<ApiState>,
494 Json(request): Json<CreateWorkflowRequest>,
495) -> Result<Json<serde_json::Value>, ApiError> {
496 Ok(Json(serde_json::json!({
499 "success": true,
500 "workflow": request.workflow,
501 "message": "Workflow created successfully"
502 })))
503}
504
505async fn get_workflow(
507 State(_state): State<ApiState>,
508 Path(id): Path<String>,
509) -> Result<Json<serde_json::Value>, ApiError> {
510 let workflow = IntegrationWorkflow {
513 id: id.clone(),
514 name: "Sample Workflow".to_string(),
515 description: "A sample integration test workflow".to_string(),
516 steps: vec![],
517 setup: WorkflowSetup::default(),
518 cleanup: vec![],
519 created_at: chrono::Utc::now(),
520 };
521
522 Ok(Json(serde_json::json!({
523 "success": true,
524 "workflow": workflow
525 })))
526}
527
528#[derive(Debug, Deserialize)]
530struct GenerateIntegrationTestRequest {
531 workflow: IntegrationWorkflow,
532 format: String, }
534
535async fn generate_integration_test(
537 State(_state): State<ApiState>,
538 Path(_id): Path<String>,
539 Json(request): Json<GenerateIntegrationTestRequest>,
540) -> Result<Json<serde_json::Value>, ApiError> {
541 let generator = IntegrationTestGenerator::new(request.workflow);
542
543 let test_code = match request.format.as_str() {
544 "rust" => generator.generate_rust_test(),
545 "python" => generator.generate_python_test(),
546 "javascript" | "js" => generator.generate_javascript_test(),
547 _ => return Err(ApiError::InvalidInput(format!("Unsupported format: {}", request.format))),
548 };
549
550 Ok(Json(serde_json::json!({
551 "success": true,
552 "format": request.format,
553 "test_code": test_code,
554 "message": "Integration test generated successfully"
555 })))
556}