1use axum::extract::{Path, State};
7use axum::response::{IntoResponse, Response};
8use axum::{http::StatusCode, Json};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11
12use mockforge_core::chain_execution::ChainExecutionEngine;
13use mockforge_core::request_chaining::RequestChainRegistry;
14
15#[derive(Clone)]
17pub struct ChainState {
18 registry: Arc<RequestChainRegistry>,
19 engine: Arc<ChainExecutionEngine>,
20}
21
22pub fn create_chain_state(
24 registry: Arc<RequestChainRegistry>,
25 engine: Arc<ChainExecutionEngine>,
26) -> ChainState {
27 ChainState { registry, engine }
28}
29
30#[derive(Debug, Serialize, Deserialize)]
31#[serde(rename_all = "camelCase")]
32pub struct ChainExecutionRequest {
33 pub variables: Option<serde_json::Value>,
34}
35
36#[derive(Debug, Serialize, Deserialize)]
37#[serde(rename_all = "camelCase")]
38pub struct ChainExecutionResponse {
39 pub chain_id: String,
40 pub status: String,
41 pub total_duration_ms: u64,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 pub request_results: Option<serde_json::Value>,
44 #[serde(skip_serializing_if = "Option::is_none")]
45 pub error_message: Option<String>,
46}
47
48#[derive(Debug, Serialize, Deserialize)]
49#[serde(rename_all = "camelCase")]
50pub struct ChainListResponse {
51 pub chains: Vec<ChainSummary>,
52 pub total: usize,
53}
54
55#[derive(Debug, Serialize, Deserialize)]
56#[serde(rename_all = "camelCase")]
57pub struct ChainSummary {
58 pub id: String,
59 pub name: String,
60 pub description: Option<String>,
61 pub tags: Vec<String>,
62 pub enabled: bool,
63 pub link_count: usize,
64}
65
66#[derive(Debug, Serialize, Deserialize)]
67#[serde(rename_all = "camelCase")]
68pub struct ChainCreateRequest {
69 pub definition: String, }
71
72#[derive(Debug, Serialize, Deserialize)]
73#[serde(rename_all = "camelCase")]
74pub struct ChainCreateResponse {
75 pub id: String,
76 pub message: String,
77}
78
79#[derive(Debug, Serialize, Deserialize)]
80#[serde(rename_all = "camelCase")]
81pub struct ChainValidationResponse {
82 pub valid: bool,
83 pub errors: Vec<String>,
84 pub warnings: Vec<String>,
85}
86
87#[derive(Debug, Serialize, Deserialize)]
88#[serde(rename_all = "camelCase")]
89pub struct ChainExecutionHistoryResponse {
90 pub chain_id: String,
91 pub executions: Vec<ChainExecutionRecord>,
92 pub total: usize,
93}
94
95#[derive(Debug, Serialize, Deserialize)]
96#[serde(rename_all = "camelCase")]
97pub struct ChainExecutionRecord {
98 pub executed_at: String,
99 pub status: String,
100 pub total_duration_ms: u64,
101 pub request_count: usize,
102 pub error_message: Option<String>,
103}
104
105pub async fn list_chains(State(state): State<ChainState>) -> impl IntoResponse {
107 let chain_ids = state.registry.list_chains().await;
108 let mut chains = Vec::new();
109
110 for id in chain_ids {
111 if let Some(chain) = state.registry.get_chain(&id).await {
112 chains.push(ChainSummary {
113 id: chain.id.clone(),
114 name: chain.name.clone(),
115 description: chain.description.clone(),
116 tags: chain.tags.clone(),
117 enabled: chain.config.enabled,
118 link_count: chain.links.len(),
119 });
120 }
121 }
122
123 let total = chains.len();
124 Json(ChainListResponse { chains, total })
125}
126
127pub async fn get_chain(Path(chain_id): Path<String>, State(state): State<ChainState>) -> Response {
129 match state.registry.get_chain(&chain_id).await {
130 Some(chain) => Json(chain).into_response(),
131 None => (StatusCode::NOT_FOUND, format!("Chain '{}' not found", chain_id)).into_response(),
132 }
133}
134
135pub async fn create_chain(
137 State(state): State<ChainState>,
138 Json(request): Json<ChainCreateRequest>,
139) -> Response {
140 match state.registry.register_from_yaml(&request.definition).await {
141 Ok(id) => Json(ChainCreateResponse {
142 id: id.clone(),
143 message: format!("Chain '{}' created successfully", id),
144 })
145 .into_response(),
146 Err(e) => {
147 (StatusCode::BAD_REQUEST, format!("Failed to create chain: {}", e)).into_response()
148 }
149 }
150}
151
152pub async fn update_chain(
154 Path(chain_id): Path<String>,
155 State(state): State<ChainState>,
156 Json(request): Json<ChainCreateRequest>,
157) -> Response {
158 if state.registry.remove_chain(&chain_id).await.is_err() {
160 return (StatusCode::NOT_FOUND, format!("Chain '{}' not found", chain_id)).into_response();
161 }
162
163 match state.registry.register_from_yaml(&request.definition).await {
165 Ok(new_id) => {
166 if new_id != chain_id {
167 return (StatusCode::BAD_REQUEST, "Chain ID mismatch in update".to_string())
168 .into_response();
169 }
170 Json(serde_json::json!({
171 "id": new_id,
172 "message": "Chain updated successfully"
173 }))
174 .into_response()
175 }
176 Err(e) => {
177 (StatusCode::BAD_REQUEST, format!("Failed to update chain: {}", e)).into_response()
178 }
179 }
180}
181
182pub async fn delete_chain(
184 Path(chain_id): Path<String>,
185 State(state): State<ChainState>,
186) -> Response {
187 match state.registry.remove_chain(&chain_id).await {
188 Ok(_) => Json(serde_json::json!({
189 "id": chain_id,
190 "message": "Chain deleted successfully"
191 }))
192 .into_response(),
193 Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to delete chain: {}", e))
194 .into_response(),
195 }
196}
197
198pub async fn execute_chain(
200 Path(chain_id): Path<String>,
201 State(state): State<ChainState>,
202 Json(request): Json<ChainExecutionRequest>,
203) -> Response {
204 match state.engine.execute_chain(&chain_id, request.variables).await {
205 Ok(result) => Json(ChainExecutionResponse {
206 chain_id: result.chain_id,
207 status: match result.status {
208 mockforge_core::chain_execution::ChainExecutionStatus::Successful => {
209 "successful".to_string()
210 }
211 mockforge_core::chain_execution::ChainExecutionStatus::PartialSuccess => {
212 "partial_success".to_string()
213 }
214 mockforge_core::chain_execution::ChainExecutionStatus::Failed => {
215 "failed".to_string()
216 }
217 },
218 total_duration_ms: result.total_duration_ms,
219 request_results: Some(serde_json::to_value(result.request_results).unwrap_or_default()),
220 error_message: result.error_message,
221 })
222 .into_response(),
223 Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to execute chain: {}", e))
224 .into_response(),
225 }
226}
227
228pub async fn validate_chain(
230 Path(chain_id): Path<String>,
231 State(state): State<ChainState>,
232) -> Response {
233 match state.registry.get_chain(&chain_id).await {
234 Some(chain) => {
235 match state.registry.validate_chain(&chain).await {
236 Ok(()) => Json(ChainValidationResponse {
237 valid: true,
238 errors: vec![],
239 warnings: vec![], })
241 .into_response(),
242 Err(e) => Json(ChainValidationResponse {
243 valid: false,
244 errors: vec![e.to_string()],
245 warnings: vec![],
246 })
247 .into_response(),
248 }
249 }
250 None => (StatusCode::NOT_FOUND, format!("Chain '{}' not found", chain_id)).into_response(),
251 }
252}
253
254pub async fn get_chain_history(
256 Path(chain_id): Path<String>,
257 State(state): State<ChainState>,
258) -> Response {
259 if state.registry.get_chain(&chain_id).await.is_none() {
261 return (StatusCode::NOT_FOUND, format!("Chain '{}' not found", chain_id)).into_response();
262 }
263
264 let history = state.engine.get_chain_history(&chain_id).await;
265
266 let executions: Vec<ChainExecutionRecord> = history
267 .into_iter()
268 .map(|record| ChainExecutionRecord {
269 executed_at: record.executed_at,
270 status: match record.result.status {
271 mockforge_core::chain_execution::ChainExecutionStatus::Successful => {
272 "successful".to_string()
273 }
274 mockforge_core::chain_execution::ChainExecutionStatus::PartialSuccess => {
275 "partial_success".to_string()
276 }
277 mockforge_core::chain_execution::ChainExecutionStatus::Failed => {
278 "failed".to_string()
279 }
280 },
281 total_duration_ms: record.result.total_duration_ms,
282 request_count: record.result.request_results.len(),
283 error_message: record.result.error_message,
284 })
285 .collect();
286
287 let total = executions.len();
288
289 Json(ChainExecutionHistoryResponse {
290 chain_id,
291 executions,
292 total,
293 })
294 .into_response()
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use mockforge_core::chain_execution::ChainExecutionEngine;
    use mockforge_core::request_chaining::{ChainConfig, RequestChainRegistry};
    use std::sync::Arc;

    /// Builds a state backed by a freshly constructed registry and engine.
    fn fresh_state() -> ChainState {
        let registry = Arc::new(RequestChainRegistry::new(ChainConfig::default()));
        let engine = Arc::new(ChainExecutionEngine::new(registry.clone(), ChainConfig::default()));
        create_chain_state(registry, engine)
    }

    #[tokio::test]
    async fn test_chain_state_creation() {
        let registry = Arc::new(RequestChainRegistry::new(ChainConfig::default()));
        let engine = Arc::new(ChainExecutionEngine::new(registry.clone(), ChainConfig::default()));
        let state = create_chain_state(registry.clone(), engine.clone());

        // The state must share the caller's instances, not hold copies of them.
        assert!(Arc::ptr_eq(&state.registry, &registry));
        assert!(Arc::ptr_eq(&state.engine, &engine));
    }

    #[tokio::test]
    async fn test_unknown_chain_is_not_found() {
        let state = fresh_state();
        let missing = || Path("missing-chain".to_string());

        assert_eq!(
            get_chain(missing(), State(state.clone())).await.status(),
            StatusCode::NOT_FOUND
        );
        assert_eq!(
            validate_chain(missing(), State(state.clone())).await.status(),
            StatusCode::NOT_FOUND
        );
        assert_eq!(
            get_chain_history(missing(), State(state)).await.status(),
            StatusCode::NOT_FOUND
        );
    }
}