1use axum::extract::{Path, State};
7use axum::response::{IntoResponse, Response};
8use axum::routing::{get, post};
9use axum::Router;
10use axum::{http::StatusCode, Json};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13
14use mockforge_core::chain_execution::ChainExecutionEngine;
15use mockforge_core::request_chaining::RequestChainRegistry;
16
17#[derive(Clone)]
19pub struct ChainState {
20 registry: Arc<RequestChainRegistry>,
22 engine: Arc<ChainExecutionEngine>,
24}
25
26pub fn create_chain_state(
32 registry: Arc<RequestChainRegistry>,
33 engine: Arc<ChainExecutionEngine>,
34) -> ChainState {
35 ChainState { registry, engine }
36}
37
38pub fn chains_router(state: ChainState) -> Router {
44 Router::new()
45 .route("/", get(list_chains).post(create_chain))
46 .route("/{chain_id}", get(get_chain).put(update_chain).delete(delete_chain))
47 .route("/{chain_id}/execute", post(execute_chain))
48 .route("/{chain_id}/validate", post(validate_chain))
49 .route("/{chain_id}/history", get(get_chain_history))
50 .with_state(state)
51}
52
53#[derive(Debug, Serialize, Deserialize)]
55#[serde(rename_all = "camelCase")]
56pub struct ChainExecutionRequest {
57 pub variables: Option<serde_json::Value>,
59}
60
61#[derive(Debug, Serialize, Deserialize)]
63#[serde(rename_all = "camelCase")]
64pub struct ChainExecutionResponse {
65 pub chain_id: String,
67 pub status: String,
69 pub total_duration_ms: u64,
71 #[serde(skip_serializing_if = "Option::is_none")]
73 pub request_results: Option<serde_json::Value>,
74 #[serde(skip_serializing_if = "Option::is_none")]
76 pub error_message: Option<String>,
77}
78
79#[derive(Debug, Serialize, Deserialize)]
81#[serde(rename_all = "camelCase")]
82pub struct ChainListResponse {
83 pub chains: Vec<ChainSummary>,
85 pub total: usize,
87}
88
89#[derive(Debug, Serialize, Deserialize)]
91#[serde(rename_all = "camelCase")]
92pub struct ChainSummary {
93 pub id: String,
95 pub name: String,
97 pub description: Option<String>,
99 pub tags: Vec<String>,
101 pub enabled: bool,
103 pub link_count: usize,
105}
106
107#[derive(Debug, Serialize, Deserialize)]
109#[serde(rename_all = "camelCase")]
110pub struct ChainCreateRequest {
111 pub definition: String,
113}
114
115#[derive(Debug, Serialize, Deserialize)]
117#[serde(rename_all = "camelCase")]
118pub struct ChainCreateResponse {
119 pub id: String,
121 pub message: String,
123}
124
125#[derive(Debug, Serialize, Deserialize)]
127#[serde(rename_all = "camelCase")]
128pub struct ChainValidationResponse {
129 pub valid: bool,
131 pub errors: Vec<String>,
133 pub warnings: Vec<String>,
135}
136
137#[derive(Debug, Serialize, Deserialize)]
139#[serde(rename_all = "camelCase")]
140pub struct ChainExecutionHistoryResponse {
141 pub chain_id: String,
143 pub executions: Vec<ChainExecutionRecord>,
145 pub total: usize,
147}
148
149#[derive(Debug, Serialize, Deserialize)]
151#[serde(rename_all = "camelCase")]
152pub struct ChainExecutionRecord {
153 pub executed_at: String,
155 pub status: String,
157 pub total_duration_ms: u64,
159 pub request_count: usize,
161 pub error_message: Option<String>,
163}
164
165pub async fn list_chains(State(state): State<ChainState>) -> impl IntoResponse {
167 let chain_ids = state.registry.list_chains().await;
168 let mut chains = Vec::new();
169
170 for id in chain_ids {
171 if let Some(chain) = state.registry.get_chain(&id).await {
172 chains.push(ChainSummary {
173 id: chain.id.clone(),
174 name: chain.name.clone(),
175 description: chain.description.clone(),
176 tags: chain.tags.clone(),
177 enabled: chain.config.enabled,
178 link_count: chain.links.len(),
179 });
180 }
181 }
182
183 let total = chains.len();
184 Json(ChainListResponse { chains, total })
185}
186
187pub async fn get_chain(Path(chain_id): Path<String>, State(state): State<ChainState>) -> Response {
189 match state.registry.get_chain(&chain_id).await {
190 Some(chain) => Json(chain).into_response(),
191 None => (StatusCode::NOT_FOUND, format!("Chain '{}' not found", chain_id)).into_response(),
192 }
193}
194
195pub async fn create_chain(
197 State(state): State<ChainState>,
198 Json(request): Json<ChainCreateRequest>,
199) -> Response {
200 match state.registry.register_from_yaml(&request.definition).await {
201 Ok(id) => Json(ChainCreateResponse {
202 id: id.clone(),
203 message: format!("Chain '{}' created successfully", id),
204 })
205 .into_response(),
206 Err(e) => {
207 (StatusCode::BAD_REQUEST, format!("Failed to create chain: {}", e)).into_response()
208 }
209 }
210}
211
212pub async fn update_chain(
214 Path(chain_id): Path<String>,
215 State(state): State<ChainState>,
216 Json(request): Json<ChainCreateRequest>,
217) -> Response {
218 if state.registry.remove_chain(&chain_id).await.is_err() {
220 return (StatusCode::NOT_FOUND, format!("Chain '{}' not found", chain_id)).into_response();
221 }
222
223 match state.registry.register_from_yaml(&request.definition).await {
225 Ok(new_id) => {
226 if new_id != chain_id {
227 return (StatusCode::BAD_REQUEST, "Chain ID mismatch in update".to_string())
228 .into_response();
229 }
230 Json(serde_json::json!({
231 "id": new_id,
232 "message": "Chain updated successfully"
233 }))
234 .into_response()
235 }
236 Err(e) => {
237 (StatusCode::BAD_REQUEST, format!("Failed to update chain: {}", e)).into_response()
238 }
239 }
240}
241
242pub async fn delete_chain(
244 Path(chain_id): Path<String>,
245 State(state): State<ChainState>,
246) -> Response {
247 match state.registry.remove_chain(&chain_id).await {
248 Ok(_) => Json(serde_json::json!({
249 "id": chain_id,
250 "message": "Chain deleted successfully"
251 }))
252 .into_response(),
253 Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to delete chain: {}", e))
254 .into_response(),
255 }
256}
257
258pub async fn execute_chain(
260 Path(chain_id): Path<String>,
261 State(state): State<ChainState>,
262 Json(request): Json<ChainExecutionRequest>,
263) -> Response {
264 match state.engine.execute_chain(&chain_id, request.variables).await {
265 Ok(result) => Json(ChainExecutionResponse {
266 chain_id: result.chain_id,
267 status: match result.status {
268 mockforge_core::chain_execution::ChainExecutionStatus::Successful => {
269 "successful".to_string()
270 }
271 mockforge_core::chain_execution::ChainExecutionStatus::PartialSuccess => {
272 "partial_success".to_string()
273 }
274 mockforge_core::chain_execution::ChainExecutionStatus::Failed => {
275 "failed".to_string()
276 }
277 },
278 total_duration_ms: result.total_duration_ms,
279 request_results: Some(serde_json::to_value(result.request_results).unwrap_or_default()),
280 error_message: result.error_message,
281 })
282 .into_response(),
283 Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to execute chain: {}", e))
284 .into_response(),
285 }
286}
287
288pub async fn validate_chain(
290 Path(chain_id): Path<String>,
291 State(state): State<ChainState>,
292) -> Response {
293 match state.registry.get_chain(&chain_id).await {
294 Some(chain) => {
295 match state.registry.validate_chain(&chain).await {
296 Ok(()) => Json(ChainValidationResponse {
297 valid: true,
298 errors: vec![],
299 warnings: vec![], })
301 .into_response(),
302 Err(e) => Json(ChainValidationResponse {
303 valid: false,
304 errors: vec![e.to_string()],
305 warnings: vec![],
306 })
307 .into_response(),
308 }
309 }
310 None => (StatusCode::NOT_FOUND, format!("Chain '{}' not found", chain_id)).into_response(),
311 }
312}
313
314pub async fn get_chain_history(
316 Path(chain_id): Path<String>,
317 State(state): State<ChainState>,
318) -> Response {
319 if state.registry.get_chain(&chain_id).await.is_none() {
321 return (StatusCode::NOT_FOUND, format!("Chain '{}' not found", chain_id)).into_response();
322 }
323
324 let history = state.engine.get_chain_history(&chain_id).await;
325
326 let executions: Vec<ChainExecutionRecord> = history
327 .into_iter()
328 .map(|record| ChainExecutionRecord {
329 executed_at: record.executed_at,
330 status: match record.result.status {
331 mockforge_core::chain_execution::ChainExecutionStatus::Successful => {
332 "successful".to_string()
333 }
334 mockforge_core::chain_execution::ChainExecutionStatus::PartialSuccess => {
335 "partial_success".to_string()
336 }
337 mockforge_core::chain_execution::ChainExecutionStatus::Failed => {
338 "failed".to_string()
339 }
340 },
341 total_duration_ms: record.result.total_duration_ms,
342 request_count: record.result.request_results.len(),
343 error_message: record.result.error_message,
344 })
345 .collect();
346
347 let total = executions.len();
348
349 Json(ChainExecutionHistoryResponse {
350 chain_id,
351 executions,
352 total,
353 })
354 .into_response()
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use axum::extract::{Path, State};
    use axum::http::StatusCode;
    use axum::response::IntoResponse;
    use mockforge_core::chain_execution::ChainExecutionEngine;
    use mockforge_core::request_chaining::{ChainConfig, RequestChainRegistry};
    use std::sync::Arc;

    /// Builds a state around a fresh registry and an engine sharing that registry,
    /// returning the individual handles too so tests can compare identities.
    fn fresh_state() -> (ChainState, Arc<RequestChainRegistry>, Arc<ChainExecutionEngine>) {
        let registry = Arc::new(RequestChainRegistry::new(ChainConfig::default()));
        let engine = Arc::new(ChainExecutionEngine::new(registry.clone(), ChainConfig::default()));
        let state = create_chain_state(registry.clone(), engine.clone());
        (state, registry, engine)
    }

    #[tokio::test]
    async fn test_chain_state_creation() {
        let (state, registry, engine) = fresh_state();
        // The state must hold the very handles it was given, not copies.
        assert!(Arc::ptr_eq(&state.registry, &registry));
        assert!(Arc::ptr_eq(&state.engine, &engine));
    }

    #[tokio::test]
    async fn test_router_builds() {
        // `Router::route` panics on invalid or overlapping paths, so simply
        // constructing the router checks the route table.
        let (state, _, _) = fresh_state();
        let _router = chains_router(state);
    }

    #[tokio::test]
    async fn test_get_unknown_chain_is_not_found() {
        let (state, _, _) = fresh_state();
        let response = get_chain(Path("does-not-exist".to_string()), State(state)).await;
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_list_chains_on_empty_registry_is_ok() {
        let (state, _, _) = fresh_state();
        let response = list_chains(State(state)).await.into_response();
        assert_eq!(response.status(), StatusCode::OK);
    }
}